## 注册算子属性

from tvm.ir import Op
import tvm

## 测试算子属性

# Look up the registered "relax.log" operator in the global op registry.
log_op = Op.get("relax.log")

@tvm.ir.register_op_attr("exp", "ftest")
def test(x):
    """Attribute function registered under "ftest" on the "exp" op.

    Returns its argument incremented by one.
    """
    return 1 + x

# relax.log is a unary op: exactly one input.
assert log_op.num_inputs == 1
# "ftest" was registered on "exp" above, not on "relax.log", so it is unset here.
assert log_op.get_attr("ftest") is None
# The registered attribute is the callable `test`: test(1) == 2.
assert Op.get("exp").get_attr("ftest")(1) == 2

## 重置算子属性

def add1(x):
    """Return *x* incremented by one."""
    return 1 + x

def add2(x):
    """Return *x* incremented by two."""
    return 2 + x

# Register fadd1 and fadd2 attributes on the "exp" and "log" ops.
tvm.ir.register_op_attr("exp", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd2", add2)

# Reset the "fadd1" attr on the "log" op only.
log_op = Op.get("log")
log_op.reset_attr("fadd1")

# Check that the fadd1 attr has been reset (get_attr returns None when unset).
assert log_op.get_attr("fadd1") is None

# Check that the fadd1 attr of other ops is intact.
assert Op.get("exp").get_attr("fadd1")(1) == 2

# Check that other attrs of the log op are intact.
assert Op.get("log").get_attr("fadd2")(1) == 3

## 缓存算子属性

from dataclasses import dataclass

@dataclass
class TempOpAttr:
    """Temporarily override an operator attribute (RAII-style context manager).

    On entry the current value of ``attr_key`` on ``op_name`` is saved and
    replaced by ``attr_value``; on exit the saved value is restored, or the
    attribute is left unset if there was no previous value.

    Examples
    --------
    .. code-block:: python

        # Temporarily update FTVMAlterOpLayout to a user-defined packed function.
        # After the test is finished, the attr value will be set back to the
        # original value.

        with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
            my_mod = relay.transform.AlterOpLayout()(my_mod)
    """

    op_name: str  # Name of the operator whose attribute is overridden.
    attr_key: str  # Attribute key to override.
    attr_value: object  # Temporary attribute value installed on entry.

    def __post_init__(self):
        # Resolve the op object once so __enter__/__exit__ reuse it.
        self.op = Op.get(self.op_name)

    def __enter__(self):
        # Save the current value (None if unset), then install the temporary
        # one. reset_attr first, because set_attr cannot overwrite an
        # existing registration.
        self.older_attr = self.op.get_attr(self.attr_key)
        self.op.reset_attr(self.attr_key)
        self.op.set_attr(self.attr_key, self.attr_value)
        return self

    def __exit__(self, ptype, value, trace):
        self.op.reset_attr(self.attr_key)
        # Compare against None explicitly: a falsy-but-valid saved value
        # (0, "", False) must still be restored, not silently dropped.
        if self.older_attr is not None:
            self.op.set_attr(self.attr_key, self.older_attr)
def add1(x):
    """Return *x* plus one."""
    return 1 + x

def add2(x):
    """Return *x* plus two."""
    return 2 + x

# Register add1 as the original value of "ftest" on the "sqrt" op.
tvm.ir.register_op_attr("sqrt", "ftest", add1)

with TempOpAttr("sqrt", "ftest", add2):
    # Inside the context the attr value is temporarily add2: add2(1) == 3.
    assert Op.get("sqrt").get_attr("ftest")(1) == 3

# After exiting, the attr value is restored to add1: add1(1) == 2.
assert Op.get("sqrt").get_attr("ftest")(1) == 2

## 算子级别

from tvm import relax

# A free relax variable used as the single argument to each unary op below.
x = relax.Var("x")

# Each relax.op.<name> helper builds a call to the op named "relax.<name>",
# with the original argument preserved as args[0].
for op_name in ["log", "exp", "sqrt", "rsqrt", "tanh"]:
    y = getattr(relax.op, op_name)(x)
    assert y.op.name == f"relax.{op_name}"
    # NOTE(review): these unary ops all appear to be registered at
    # support level 10 — confirm against the op registrations.
    assert y.op.support_level == 10
    assert y.args[0] == x