# InferCorrectLayoutOutput
import set_env
import tvm
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.op import op as _op
# Name under which the demo operator is registered.
op_name = "custom_op"

# Create the registry entry; the raw string becomes the op's description.
# NOTE(review): the description mentions broadcasting, but the "Identity"
# type relation attached below presumably forces all input/output types to
# be equal, i.e. no broadcasting — confirm against TVM's IdentityRel.
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")

# Configure the freshly registered op through a single handle.
custom_op_entry = _op.get(op_name)
custom_op_entry.set_num_inputs(2)
for arg_name in ("data_0", "data_1"):
    custom_op_entry.add_argument(arg_name, "Tensor", "The input data tensor.")
# Use the built-in "Identity" relation as the op's type relation.
custom_op_entry.add_type_rel("Identity")
custom_op_entry.set_support_level(1)

# Attach the fusion pattern (element-wise) and mark the op as stateless.
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

# Sanity-check the registration through a fresh registry lookup.
registered_op = _op.get(op_name)
assert registered_op.name == op_name
assert registered_op.num_inputs == 2
assert registered_op.get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
assert registered_op.get_attr("TOpIsStateful") == False
_op.register_infer_correct_layout??
Signature: _op.register_infer_correct_layout(op_name, infer_layout=None, level=10)
Source:
def register_infer_correct_layout(op_name, infer_layout=None, level=10):
"""Register infer op layout function for an op
Parameters
----------
op_name : str
The name of the operator
infer_layout: function (attrs: Attrs, inputs: List[Layout]) -> InferCorrectLayoutOutput
The function to infer correct layout
level : int
The priority level
"""
return tvm.ir.register_op_attr(op_name, "FInferCorrectLayout", infer_layout, level)
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type: function
`register_infer_correct_layout()` 用于为算子注册推断布局函数。

参数说明：

- `op_name`：字符串类型，表示算子的名称。
- `infer_layout`：函数类型，接受两个参数 `attrs` 和 `inputs`，返回 `InferCorrectLayoutOutput` 类型的对象。该函数用于推断正确的布局。
- `level`：整数类型，表示优先级级别（默认值为 10）。
from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
from tvm.ir import Attrs
from tvm.tir.data_layout import Layout
def infer_layout(attrs: Attrs, inputs: list[Layout]) -> InferCorrectLayoutOutput:
    """Infer layouts for ``custom_op`` by propagating the input layouts unchanged.

    Parameters
    ----------
    attrs : Attrs
        Attributes of the call being processed; passed through unchanged.
    inputs : list[Layout]
        Layouts of the operator's inputs.

    Returns
    -------
    InferCorrectLayoutOutput
        Every input keeps its current layout, and the single output takes the
        layout of the first input (the op is registered as ELEMWISE).
    """
    # InferCorrectLayoutOutput's constructor requires
    # (input_layouts, output_layouts, new_attrs); calling it with no
    # arguments raises a TypeError as soon as layout inference runs.
    # NOTE(review): some TVM versions invoke FInferCorrectLayout as
    # (attrs, new_in_layouts, old_in_layouts, old_in_types) rather than the
    # two-argument form documented in op.py — confirm for the installed TVM.
    input_layouts = list(inputs)
    # Element-wise op with one output: its layout follows the first input.
    output_layouts = input_layouts[:1]
    return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs)


_op.register_infer_correct_layout(op_name, infer_layout)
<function __main__.infer_layout(attrs: tvm.ir.attrs.Attrs, inputs: list[tvm.tir.data_layout.Layout])>
InferCorrectLayoutOutput?
Init signature: InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs)
Docstring: An output structure to hold results from FInferCorrectLayout calls.
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/infer_layout_utils.py
Type: type
Subclasses: