Relay 函数级 Pass
import tvm
from tvm import relay
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    """Function-level pass that returns a fixed function for every input.

    Whatever function is handed to ``transform_function``, the pass logs its
    arguments and hands back the function supplied at construction time.
    """

    def __init__(self, new_func):
        """Remember ``new_func``, the function returned by every transform."""
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        """Print ``func``, ``mod`` and ``ctx``, then return ``self.new_func``.

        Parameters
        ----------
        func
            The function being transformed (only printed, never inspected).
        mod
            The module passed alongside ``func`` (only printed).
        ctx
            The pass context passed alongside ``func`` (only printed).

        Returns
        -------
        The ``new_func`` object stored by ``__init__``.
        """
        rule_eq = "=" * 40
        rule_star = "*" * 40
        # NOTE: the "mod" header deliberately keeps its colon after the rule
        # line so the printed text stays identical to the existing output.
        sections = (
            f"func:\n{rule_eq}\n{func}\n{rule_star}\n",
            f"mod\n{rule_eq}:\n{mod}\n{rule_star}\n",
            f"ctx:\n{rule_eq}\n{ctx}\n",
        )
        print("".join(sections))
        return self.new_func
# One shared input variable of shape (10, 20) used by both functions.
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)  # identity: what the pass should leave behind
f2 = relay.Function([x], relay.log(x))  # log(x): what the pass starts from

# Instantiate the decorated class and check the metadata exposed via `.info`.
fpass = TestReplaceFunc(f1)
pass_info = fpass.info
assert pass_info.opt_level == 1
assert pass_info.name == "TestReplaceFunc"

# Wrap f2 in a module and run the pass over it.
mod = fpass(tvm.IRModule.from_expr(f2))

# Reference result: f1 wrapped in its own module and run through InferType,
# so the two "main" functions can be compared structurally.
mod2 = relay.transform.InferType()(tvm.IRModule.from_expr(f1))
assert tvm.ir.structural_equal(mod["main"], mod2["main"])
func:
========================================
fn (%x: Tensor[(10, 20), float32]) {
log(%x)
}
****************************************
mod
========================================:
def @main(%x: Tensor[(10, 20), float32]) {
log(%x)
}
****************************************
ctx:
========================================
Pass context information:
opt_level: 2
required passes: []
disabled passes: []
config: {}
也可以直接装饰函数:
@relay.transform.function_pass(opt_level=1)
def transform(expr, mod, ctx):
    """Skeleton of a function pass written as a decorated plain function.

    Takes the same three arguments as ``transform_function`` on a decorated
    class. The body is a placeholder (``...``) and performs no transformation.

    NOTE(review): as written this returns ``None``; a usable pass presumably
    has to return the transformed function -- confirm against the
    ``function_pass`` documentation before running it.
    """
    ...