InferCorrectLayoutOutput

import set_env
import tvm
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.op import op as _op
# Register a two-input custom Relay operator named "custom_op" and sanity-check
# the attributes that were attached to it. Every statement below mutates the
# global op registry, so the order matters: `_op.register` must run before any
# `_op.get(op_name)` call. Re-running this cell presumably fails because the op
# is already registered — TODO confirm.
op_name = "custom_op"

# NOTE: r"code(...)code" mirrors C++ raw-string syntax (R"code(...)code"). In
# Python the "code(" / ")code" delimiters are kept verbatim in the description.
# NOTE(review): the text mentions broadcasting, but the "Identity" type
# relation below presumably forces inputs and output to share one type — confirm.
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# Use the built-in "Identity" type relation (only its name is given; no custom
# Python relation function is supplied).
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
# Attach the op pattern (ELEMWISE) and mark the op as not stateful.
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

# Verify that the registry reflects what was registered above.
assert _op.get(op_name).name == op_name
assert _op.get(op_name).num_inputs == 2
assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
# `== False` (rather than `not ...`) still fails if the attribute is missing
# (None), while accepting either False or 0 from the registry.
assert _op.get(op_name).get_attr("TOpIsStateful") == False
_op.register_infer_correct_layout??
Signature: _op.register_infer_correct_layout(op_name, infer_layout=None, level=10)
Source:   
def register_infer_correct_layout(op_name, infer_layout=None, level=10):
    """Register infer op layout function for an op

    Parameters
    ----------
    op_name : str
        The name of the operator

    infer_layout: function (attrs: Attrs, inputs: List[Layout]) -> InferCorrectLayoutOutput
        The function to infer correct layout

    level : int
        The priority level
    """
    return tvm.ir.register_op_attr(op_name, "FInferCorrectLayout", infer_layout, level)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type:      function

`register_infer_correct_layout()` 用于为算子注册布局推断函数(即算子属性 `FInferCorrectLayout`)。

参数说明:

  • op_name:字符串类型,表示算子的名称。

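  • infer_layout:函数类型,文档中列出的参数为 `attrs` 和 `inputs`,返回 `InferCorrectLayoutOutput` 类型的对象。该函数用于推断算子输入与输出的布局。注意:TVM 的布局变换 pass 实际以 `(attrs, new_in_layouts, old_in_layouts, old_in_types)` 四个参数调用该函数,因此自定义函数应能接受这四个参数;且 `InferCorrectLayoutOutput` 的构造需要 `input_layouts`、`output_layouts`、`new_attrs` 三个参数。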

  • level:整数类型,表示注册的优先级,默认为 10。

from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
from tvm.ir import Attrs
from tvm.tir.data_layout import Layout

def infer_layout(attrs: Attrs, inputs: list[Layout], old_in_layouts=None, old_in_types=None):
    """Layout-inference hook (``FInferCorrectLayout``) for the elementwise ``custom_op``.

    TVM's layout passes call this as
    ``infer_layout(attrs, new_in_layouts, old_in_layouts, old_in_types)``.
    ``inputs`` receives ``new_in_layouts``. The two trailing parameters default
    to ``None``, so the original two-argument call form keeps working.

    All inputs and the single output are given one common layout, chosen in
    this order:

    1. the first newly proposed layout, if any;
    2. otherwise the first non-empty original layout;
    3. otherwise an empty layout.

    Parameters
    ----------
    attrs : Attrs
        Call attributes; passed through unchanged as ``new_attrs``.
    inputs : list of Layout
        Newly proposed input layouts (may be empty).
    old_in_layouts : list of Layout, optional
        Original input layouts. Entries may be undefined or empty.
    old_in_types : optional
        Original input types. Not used by this hook.

    Returns
    -------
    InferCorrectLayoutOutput
        Built from ``(input_layouts, output_layouts, new_attrs)``. The
        constructor requires all three; calling it with no arguments raises
        ``TypeError``.
    """
    old_layouts = list(old_in_layouts) if old_in_layouts is not None else []

    if inputs:
        layout = inputs[0]
    else:
        # Presumably undefined (None) and empty Layouts are both falsy; skip
        # them when falling back to the original layouts — TODO confirm.
        layout = next((lay for lay in old_layouts if lay), None)
        if layout is None:
            layout = tvm.tir.layout("")

    # One entry per op input. Prefer the original arity; otherwise use the
    # number of proposed layouts (the two-argument call form).
    if old_in_layouts is not None:
        num_inputs = len(old_layouts)
    else:
        num_inputs = len(inputs or [])

    return InferCorrectLayoutOutput([layout] * num_inputs, [layout], attrs)
_op.register_infer_correct_layout(op_name, infer_layout)
<function __main__.infer_layout(attrs: tvm.ir.attrs.Attrs, inputs: list[tvm.tir.data_layout.Layout])>
InferCorrectLayoutOutput?
Init signature: InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs)
Docstring:      An output structure to hold results from FInferCorrectLayout calls.
File:           /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/infer_layout_utils.py
Type:           type
Subclasses: