# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Primitive operators in the TVM IR."""
import tvm_ffi
from . import _ffi_api
from .expr import RelaxExpr
[文档]
@tvm_ffi.register_object("ir.Op")
class Op(RelaxExpr):
"""Primitive operator in the IR."""
def __init__(self):
raise RuntimeError("Cannot create op, use get instead")
[文档]
@staticmethod
def get(op_name):
"""Get the Op for a given name
Parameters
----------
op_name : str
The operator name
Returns
-------
op : Op
The op of the corresponding name
"""
return _ffi_api.GetOp(op_name)
[文档]
def get_attr(self, attr_name):
"""Get additional attribute about the operator.
Parameters
----------
attr_name : str
The attribute name.
Returns
-------
value : object
The attribute value
"""
return _ffi_api.OpGetAttr(self, attr_name)
[文档]
def has_attr(self, attr_name):
"""Check whether the operator has additional attribute.
Parameters
----------
attr_name : str
The attribute name.
Returns
-------
value : bool
Whether the operator has additional attribute
"""
return _ffi_api.OpHasAttr(self, attr_name)
[文档]
def set_attr(self, attr_name, value, plevel=10):
"""Set attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
value : object
The attribute value
plevel : int
The priority level
"""
_ffi_api.OpSetAttr(self, attr_name, value, plevel)
[文档]
def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_ffi_api.OpResetAttr(self, attr_name)
[文档]
def add_argument(self, name, type, description): # pylint: disable=redefined-builtin
"""Add arguments information to the function.
Parameters
----------
name : str
The argument name.
type : str
The argument type.
description : str
The argument description.
"""
_ffi_api.OpAddArgument(self, name, type, description)
[文档]
def set_support_level(self, level):
"""Set the support level of op.
Parameters
----------
level : int
The support level.
"""
_ffi_api.OpSetSupportLevel(self, level)
[文档]
def set_attrs_type_key(self, key):
"""Set the attribute type key of op.
Parameters
----------
key : str
The type key.
"""
_ffi_api.OpSetAttrsTypeKey(self, key)
[文档]
@staticmethod
def list_op_names():
"""List all the op names in the op registry.
Returns
-------
value : List[str]
The registered op names
"""
return _ffi_api.ListOpNames()
[文档]
def register_op_attr(op_name, attr_key, value=None, level=10):
"""Register an operator property of an operator by name.
Parameters
----------
op_name : str
The name of operator
attr_key : str
The attribute name.
value : object, optional
The value to set
level : int, optional
The priority level
Returns
-------
fregister : function
Register function if value is not specified.
"""
def _register(v):
"""internal register function"""
_ffi_api.RegisterOpAttr(op_name, attr_key, v, level)
return v
return _register(value) if value is not None else _register
[文档]
def register_intrin_lowering(
op_name,
target,
*,
f=None,
level=10,
):
"""Register Op lowering function
Parameters
----------
op_name : str
The op name
target : str
The target string for given intrinsic lowering function
f : function, optional
The function to be registered.
level : int
The priority level
Returns
-------
fregister : function
Register op lowering function if f is not specified.
"""
def _register(f):
"""internal register function"""
_ffi_api.RegisterOpLowerIntrinsic(op_name, f, target, level)
return f
return _register(f) if f is not None else _register