Relay 函数级 Pass

Relay 函数级 Pass

import tvm
from tvm import relay

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    """Toy function-level pass that replaces each visited function with a fixed one.

    For every function the pass infrastructure hands to ``transform_function``,
    this pass prints that function, the enclosing module and the pass context,
    then returns ``new_func``. The input function is discarded and replaced
    wholesale; this is not a parameter substitution.

    Parameters
    ----------
    new_func : relay.Function
        The function returned in place of every function the pass visits.
    """

    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        """Print ``func``, ``mod`` and ``ctx`` for inspection, then return ``self.new_func``.

        Parameters
        ----------
        func : relay.Function
            The function currently being transformed.
        mod : tvm.IRModule
            The module that contains ``func``.
        ctx : tvm.transform.PassContext
            The pass context this pass is running under.

        Returns
        -------
        relay.Function
            Always ``self.new_func``. ``func`` is only printed, never used.
        """
        title_rule = "=" * 40  # printed under each section title
        end_rule = "*" * 40  # printed after each section body
        des = f"func:\n{title_rule}\n{func}\n{end_rule}\n"
        # The colon belongs on the title ("mod:"), matching "func:" and "ctx:".
        des += f"mod:\n{title_rule}\n{mod}\n{end_rule}\n"
        des += f"ctx:\n{title_rule}\n{ctx}\n"
        print(des)
        return self.new_func
# Two functions over the same input variable: identity (f1) and elementwise log (f2).
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))

# Instantiate the pass so that it swaps every visited function for f1.
# Also check the metadata that the decorator attached to it.
fpass = TestReplaceFunc(f1)
assert fpass.info.opt_level == 1
assert fpass.info.name == "TestReplaceFunc"

# Run the pass on a module whose "main" is f2. Afterwards "main" should be f1.
mod = fpass(tvm.IRModule.from_expr(f2))

# Build the reference module directly from f1. InferType is applied so the
# reference carries the same type information as the pass output. Presumably
# the pass infrastructure type-checks the transformed module; TODO confirm.
mod2 = relay.transform.InferType()(tvm.IRModule.from_expr(f1))
assert tvm.ir.structural_equal(mod["main"], mod2["main"])
func:
========================================
fn (%x: Tensor[(10, 20), float32]) {
  log(%x)
}
****************************************
mod
========================================:
def @main(%x: Tensor[(10, 20), float32]) {
  log(%x)
}

****************************************
ctx:
========================================
Pass context information: 
	opt_level: 2
	required passes: []
	disabled passes: []
	config: {}

也可以直接用 `relay.transform.function_pass` 装饰一个普通函数（而不是类）：

@relay.transform.function_pass(opt_level=1)
def transform(expr, mod, ctx):
    """Function-pass callback written as a plain function instead of a class.

    Receives the function being visited (``expr``), the module containing it
    (``mod``) and the pass context (``ctx``). Presumably a real implementation
    should return the (possibly rewritten) function. The body shown here is
    only a placeholder ``...``; TODO fill in.
    """
    ...