# 注册算子属性 (Registering operator attributes)
from tvm.ir import Op
import tvm
# 测试算子属性 (Testing operator attributes)
# Look up the Relax ``log`` operator; it is used below to check that an
# attribute registered on a different op does not show up on it.
log_op = Op.get("relax.log")


# Register ``test`` as the "ftest" attribute of the op named "exp".
@tvm.ir.register_op_attr("exp", "ftest")
def test(x):
    """Return ``x + 1``; installed as the "ftest" attribute of op "exp"."""
    return x + 1


assert log_op.num_inputs == 1
# "ftest" was registered only on "exp", so "relax.log" must not have it.
assert log_op.get_attr("ftest") is None
# The stored attribute is callable and behaves like ``test``.
assert Op.get("exp").get_attr("ftest")(1) == 2
# 重置算子属性 (Resetting operator attributes)
def add1(x):
    """Return ``x + 1``."""
    return x + 1


def add2(x):
    """Return ``x + 2``."""
    return x + 2


# Register fadd1 and fadd2 attributes.
tvm.ir.register_op_attr("exp", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd2", add2)

# Reset only the fadd1 attribute of the log op.
log_op = Op.get("log")
log_op.reset_attr("fadd1")

# Check that the fadd1 attr has been reset.
assert log_op.get_attr("fadd1") is None
# Check that the fadd1 attr of other ops is intact.
assert Op.get("exp").get_attr("fadd1")(1) == 2
# Check that other attrs of the log op are intact.
assert Op.get("log").get_attr("fadd2")(1) == 3
# 缓存算子属性 (Temporarily overriding and restoring an operator attribute)
from dataclasses import dataclass
@dataclass
class TempOpAttr:
    """Context manager that temporarily overrides an attribute of an op (RAII-style).

    On ``__enter__`` the op's current value for ``attr_key`` is saved and
    replaced by ``attr_value``. On ``__exit__`` the attribute is reset and the
    saved value, if there was one, is set back.

    Examples
    --------
    .. code-block:: python

        # Temporarily update FTVMAlterOpLayout to a user-defined packed function.
        # After the test is finished, the attr value will be set back to the original value.
        with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
            my_mod = relay.transform.AlterOpLayout()(my_mod)
    """

    op_name: str  # The op name; resolved with ``Op.get`` in ``__post_init__``.
    attr_key: str  # The attribute name.
    attr_value: object  # The value installed while the context is active.

    def __post_init__(self):
        # Resolve the op object once, at construction time.
        self.op = Op.get(self.op_name)

    def __enter__(self):
        # ``get_attr`` returns None when the op has no value for this key.
        self.older_attr = self.op.get_attr(self.attr_key)
        # Clear the current value before installing the temporary one
        # (presumably set_attr refuses to overwrite at the same level -- confirm).
        self.op.reset_attr(self.attr_key)
        self.op.set_attr(self.attr_key, self.attr_value)
        return self

    def __exit__(self, ptype, value, trace):
        self.op.reset_attr(self.attr_key)
        # Compare against None instead of relying on truthiness: a previous
        # value that happens to be falsy (e.g. 0) must still be restored.
        if self.older_attr is not None:
            self.op.set_attr(self.attr_key, self.older_attr)
# Same helpers as in the reset section, repeated so this section runs on its own.
def add1(x):
    """Return ``x + 1``."""
    return x + 1


def add2(x):
    """Return ``x + 2``."""
    return x + 2


# Set the original attr value to add1.
tvm.ir.register_op_attr("sqrt", "ftest", add1)

with TempOpAttr("sqrt", "ftest", add2):
    # Check that the attr value is updated to add2.
    assert Op.get("sqrt").get_attr("ftest")(1) == 3

# Check that the attr value is recovered to add1.
assert Op.get("sqrt").get_attr("ftest")(1) == 2
# 算子级别 (Operator support level)
from tvm import relax
x = relax.Var("x")
for op_name in ["log", "exp", "sqrt", "rsqrt", "tanh"]:
    # Build a call to the unary op, e.g. relax.op.log(x).
    y = getattr(relax.op, op_name)(x)
    # The call node refers to the op registered as "relax.<name>".
    assert y.op.name == f"relax.{op_name}"
    assert y.op.support_level == 10
    # The first argument of the call is the input variable.
    assert y.args[0] == x