op IR#
import testing
import tvm
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.op import op as _op
op 属性#
属性访问:
log_op = relay.op.get("log")
assert log_op.num_inputs == 1
注册 op 属性函数:
@tvm.ir.register_op_attr("exp", "ftest")
def test(x):
return x + 1
assert log_op.get_attr("ftest") is None
assert relay.op.get("exp").get_attr("ftest")(1) == 2
重置属性函数:
def add1(x):
return x + 1
def add2(x):
return x + 2
# 注册 fadd1 和 fadd2 属性
tvm.ir.register_op_attr("exp", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd2", add2)
<function __main__.add2(x)>
重置 log
属性函数:
log_op = relay.op.get("log")
log_op.reset_attr("fadd1")
# 检查 fadd1 属性是否已重置。
assert log_op.get_attr("fadd1") is None
# 检查其他算子的 fadd1 属性是否完好无损。
assert relay.op.get("exp").get_attr("fadd1")(1) == 2
# 检查 log 算子的其他属性是否完好无损。
assert relay.op.get("log").get_attr("fadd2")(1) == 3
op 临时属性#
def add1(x):
return x + 1
def add2(x):
return x + 2
# 将原始 attr 值设置为add1。
tvm.ir.register_op_attr("sqrt", "ftest", add1)
with TempOpAttr("sqrt", "ftest", add2):
# 检查 attr 值是否已更新为 add2。
assert relay.op.get("sqrt").get_attr("ftest")(1) == 3
# 检查 attr 值是否已恢复为 add1。
assert relay.op.get("sqrt").get_attr("ftest")(1) == 2
op 注册#
op_name = "custom_op"
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# 调用默认关系函数
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)
assert _op.get(op_name).name == op_name
assert _op.get(op_name).num_inputs == 2
assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
assert _op.get(op_name).get_attr("TOpIsStateful") == False
备注
"TOpIsStateful"
为 True
表示算子是有状态的或包含内部状态。
我们总是可以通过添加额外的句柄参数并返回它来处理有状态的算子。
_op.register??
Signature: _op.register(op_name, describe='')
Source:
def register(op_name, describe=""):
"""Get the Op for a given name.
when the op_name is not registered, create a new empty op with the given name.
when the op_name has been registered, abort with an error message.
Parameters
----------
op_name : str
The operator name
describe : Optional[str]
The operator description
"""
tvm.ir._ffi_api.RegisterOp(op_name, describe)
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type: function