# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Primitive operators in the TVM IR."""
import tvm._ffi
from . import _ffi_api
from .expr import RelayExpr
@tvm._ffi.register_object("Op")
class Op(RelayExpr):
    """Primitive operator in the IR.

    Ops cannot be constructed directly; look one up by name with
    :py:meth:`Op.get` instead.
    """

    def __init__(self):
        # Direct construction is rejected; callers must go through Op.get().
        raise RuntimeError("Cannot create op, use get instead")

    def astext(self, show_meta_data=True, annotate=None):
        """Get the text format of the expression.

        Parameters
        ----------
        show_meta_data : bool
            Whether to include meta data section in the text
            if there is meta data.

        annotate: Optional[Object->str]
            Optionally annotate function to provide additional
            information in the comment block.

        Returns
        -------
        text : str
            The text format of the expression.

        Notes
        -----
        The meta data section is necessary to fully parse the text format.
        However, it can contain dumps that are big (e.g constant weights),
        so it can be helpful to skip printing the meta data section.
        """
        # Imported at call time on purpose (see the pylint disable); presumably
        # this avoids an import cycle between tvm.ir and tvm.relay -- confirm.
        from tvm.relay import astext  # pylint: disable=import-outside-toplevel

        return astext(self, show_meta_data, annotate)

    @staticmethod
    def get(op_name):
        """Get the Op for a given name.

        Parameters
        ----------
        op_name : str
            The operator name

        Returns
        -------
        op : Op
            The op of the corresponding name
        """
        return _ffi_api.GetOp(op_name)

    def get_attr(self, attr_name):
        """Get additional attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name.

        Returns
        -------
        value : object
            The attribute value
        """
        return _ffi_api.OpGetAttr(self, attr_name)

    def has_attr(self, attr_name):
        """Check whether the operator has additional attribute.

        Parameters
        ----------
        attr_name : str
            The attribute name.

        Returns
        -------
        value : bool
            Whether the operator has additional attribute
        """
        return _ffi_api.OpHasAttr(self, attr_name)

    def set_attr(self, attr_name, value, plevel=10):
        """Set attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name

        value : object
            The attribute value

        plevel : int
            The priority level
        """
        _ffi_api.OpSetAttr(self, attr_name, value, plevel)

    def reset_attr(self, attr_name):
        """Reset attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name
        """
        _ffi_api.OpResetAttr(self, attr_name)

    def add_type_rel(self, rel_name, type_rel_func=None):
        """Attach the type function corresponding to the return type.

        Parameters
        ----------
        rel_name : str
            The type relation name to register.

        type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]
            The backing relation function which can solve an arbitrary relation on variables.
            Differences with type_rel_func in C++:

            1) When type_rel_func is not None:

               a) OpAddTypeRel on the C++ side adapts type_rel_func with a TypeReporter
                  to the calling convention of the relay type system.

               b) type_rel_func returns the output argument's type; returning None
                  means the output type cannot be inferred.

               c) Only single-output operators are supported for now; the last
                  argument is the output tensor.

            2) When type_rel_func is None, the predefined type relation in relay named
               ``tvm.relay.type_relation.`` + rel_name is used.
        """
        _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)

    def add_argument(self, name, type, description):  # pylint: disable=redefined-builtin
        """Add arguments information to the function.

        Parameters
        ----------
        name : str
            The argument name.

        type : str
            The argument type.

        description : str
            The argument description.
        """
        # ``type`` shadows the builtin but is part of the public keyword interface.
        _ffi_api.OpAddArgument(self, name, type, description)

    def set_support_level(self, level):
        """Set the support level of op.

        Parameters
        ----------
        level : int
            The support level.
        """
        _ffi_api.OpSetSupportLevel(self, level)

    def set_attrs_type_key(self, key):
        """Set the attribute type key of op.

        Parameters
        ----------
        key : str
            The type key.
        """
        _ffi_api.OpSetAttrsTypeKey(self, key)

    @staticmethod
    def list_op_names():
        """List all the op names in the op registry.

        Returns
        -------
        value : List[str]
            The registered op names
        """
        return _ffi_api.ListOpNames()
def register_op_attr(op_name, attr_key, value=None, level=10):
    """Register an operator property of an operator by name.

    Can be called directly with ``value`` or used as a decorator factory
    when ``value`` is omitted.

    Parameters
    ----------
    op_name : str
        The name of operator

    attr_key : str
        The attribute name.

    value : object, optional
        The value to set. When omitted (or ``None``), nothing is registered
        yet and a decorator is returned; the decorated object becomes the value.

    level : int, optional
        The priority level

    Returns
    -------
    fregister : function or object
        ``value`` itself when it is given; otherwise the register function
        (decorator) that registers and returns its argument.
    """

    def _register(v):
        """Register ``v`` as the attribute value and return it unchanged."""
        _ffi_api.RegisterOpAttr(op_name, attr_key, v, level)
        return v

    # ``None`` doubles as the "not given" marker, so an explicit value=None
    # also yields the decorator rather than registering None.
    return _register(value) if value is not None else _register
def register_intrin_lowering(
    op_name,
    target,
    *,
    f=None,
    level=10,
):
    """Register Op lowering function.

    Can be called directly with ``f`` or used as a decorator factory when
    ``f`` is omitted.

    Parameters
    ----------
    op_name : str
        The op name

    target : str
        The target string for given intrinsic lowering function

    f : function, optional
        The function to be registered.

    level : int, optional
        The priority level

    Returns
    -------
    fregister : function
        ``f`` itself when it is given; otherwise the register function
        (decorator) that registers and returns its argument.
    """

    def _register(func):
        """Register ``func`` as the lowering of ``op_name`` for ``target``."""
        # The FFI entry point takes the function before the target, unlike
        # this wrapper's public signature.
        _ffi_api.RegisterOpLowerIntrinsic(op_name, func, target, level)
        return func

    return _register(f) if f is not None else _register