Relay Sequential pass
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type
def check_func(func, ref_func):
    """Assert that ``func`` and ``ref_func`` are structurally equal.

    Both functions are passed through ``run_infer_type`` first (``func`` before
    ``ref_func``) so their checked types are populated before the comparison.
    """
    inferred, reference = (run_infer_type(fn) for fn in (func, ref_func))
    assert tvm.ir.structural_equal(inferred, reference)
def extract_var_func(mod, name):
    """Look up a global function in ``mod`` by its name.

    Returns:
        A ``(global_var, function)`` tuple, where ``global_var`` comes from
        ``mod.get_global_var(name)`` and ``function`` is ``mod[global_var]``.
    """
    gvar = mod.get_global_var(name)
    return gvar, mod[gvar]
def get_rand(shape, dtype="float32"):
    """Return a TVM NDArray of ``np.random.rand`` samples with the given shape, cast to ``dtype``."""
    samples = np.random.rand(*shape)
    return tvm.nd.array(samples.astype(dtype))
def get_ref_log():
    """Build the reference Relay function ``fn(x) = log(x + x)``.

    Uses the module-level Relay variable ``x`` as its only parameter.
    """
    doubled = relay.add(x, x)
    return relay.Function([x], relay.log(doubled))
def get_ref_sub():
    """Build the reference Relay function ``fn(x, y) = (x + x) - (y + y)``.

    Uses the module-level Relay variables ``x`` and ``y`` as parameters.
    """
    lhs = relay.add(x, x)
    rhs = relay.add(y, y)
    return relay.Function([x, y], relay.subtract(lhs, rhs))
def get_ref_abs():
    """Build the reference Relay function ``fn(a) = abs(a + a)`` over a (5, 10) float32 tensor."""
    a = relay.var("a", relay.TensorType((5, 10), "float32"))
    return relay.Function([a], relay.abs(relay.add(a, a)))
# Shared fixtures: two global functions, "mySub" and "myLog", over (10,)
# float32 tensors, collected into the IRModule `mod` used by the tests below.
# `x` and `y` are also read by get_ref_log / get_ref_sub.
shape = (10,)
dtype = "float32"
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = relay.var("y", tp)
# mySub(x, y) = x - y
v_sub = relay.GlobalVar("mySub")
sub = relay.Function([x, y], relay.subtract(x, y))
z = relay.var("z", tp)
# myLog(z) = log(z)
v_log = relay.GlobalVar("myLog")
log = relay.Function([z], relay.log(z))
mod = tvm.IRModule({v_sub: sub, v_log: log})
from utils.helper import OptTester
# Register a module-level pass that delegates to the OptTester helper.
opt_tester = OptTester(mod)


@tvm.transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
    """Module pass: forward the module and pass context to ``opt_tester.transform``."""
    result = opt_tester.transform(expr, ctx)
    return result
# Register a function-level pass that also delegates to the OptTester helper.
@relay.transform.function_pass(opt_level=1)
def func_transform(expr, mod, ctx):
    """Function pass: forward the function and pass context to ``opt_tester.transform``.

    The ``mod`` argument is accepted to match the function-pass signature but is unused.
    """
    result = opt_tester.transform(expr, ctx)
    return result
# Sequential pass: wrap the module pass and the function pass defined above
# and check the PassInfo it reports.
passes = [mod_transform, func_transform]
opt_level = 2
# No name is passed to Sequential below, so this checks its default name.
pass_name = "sequential"
sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
测试 Sequential pass
空 pass 列表:
# A Sequential with no passes is applied to `mod`; the returned "mySub"
# function must be structurally equal to the original `sub`.
passes = []
sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)
模块级 pass:
# Run a Sequential containing only the module pass, with mod_transform listed
# in the PassContext's required_pass.
passes = [mod_transform]
sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with tvm.transform.PassContext(required_pass=["mod_transform"]):
    ret_mod = sequential(mod)
# Check the subtract function.
# NOTE(review): this expects "mySub" to still equal `sub` afterwards, which
# presumably means OptTester.transform leaves it unchanged -- its behavior
# is not visible here; confirm in utils.helper.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub)
带作用域（PassContext + Target）的 Sequential pass
# Fixtures for the scoped Sequential test (these rebind the earlier `shape`/`tp`).
shape = (1, 2, 3)
# NOTE: np.array(shape) builds the 1-D array [1., 2., 3.] from the tuple's
# values; it is NOT an array of shape (1, 2, 3). Adding it to the (1, 2, 3)
# input below presumably relies on broadcasting.
c_data = np.array(shape).astype("float32")
tp = relay.TensorType(shape, "float32")
def before():
    """Unoptimized graph with a foldable constant subexpression and a duplicated add.

    Builds ``fn(x) = (t + c) + (t + c)`` with ``t = x + (c + c) * 2`` and
    ``c = const(c_data)``; the two ``t + c`` nodes are separate calls.
    """
    const_c = relay.const(c_data)
    inp = relay.var("x", tp)
    foldable = relay.multiply(relay.add(const_c, const_c), relay.const(2, "float32"))
    shifted = relay.add(inp, foldable)
    lhs = relay.add(shifted, const_c)
    rhs = relay.add(shifted, const_c)
    return relay.Function([inp], relay.add(lhs, rhs))
def expected():
    """Expected graph: ``(c + c) * 2`` folded to a constant and the duplicated add shared.

    Builds ``fn(x) = s + s`` with ``s = (x + const((c_data + c_data) * 2)) + const(c_data)``.
    """
    inp = relay.var("x", tp)
    folded = relay.const((c_data + c_data) * 2)
    shared = relay.add(relay.add(inp, folded), relay.const(c_data))
    return relay.Function([inp], relay.add(shared, shared))
# Pipeline under test: type inference, constant folding, common-subexpression
# elimination and layout alteration, run inside an opt_level=3 PassContext and
# an "llvm" Target scope.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.AlterOpLayout(),
    ]
)
# Note: this rebinds the global `mod` defined earlier.
mod = tvm.IRModule({"main": before()})
with tvm.transform.PassContext(opt_level=3):
    with tvm.target.Target("llvm"):
        mod = seq(mod)
zz = mod["main"]
# The reference is type-inferred so both sides carry checked types.
zexpected = run_infer_type(expected())
assert tvm.ir.structural_equal(zz, zexpected)
/media/pc/data/4tb/lxw/libs/anaconda3/envs/py38/lib/python3.8/site-packages/tvm/driver/build_module.py:267: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
warnings.warn(
嵌套的 Sequential pass
def before():
    """Module whose chain of reshapes after a conv2d should be simplified.

    The conv2d result is reshaped to (1, 16, -1), then (4, 8, -1, 16), and
    finally reverse-reshaped with (32, 0, -1).
    """
    data = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
    weight = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, weight, padding=(1, 1))
    for new_shape in ((1, 16, -1), (4, 8, -1, 16)):
        out = relay.reshape(out, newshape=new_shape)
    out = relay.reverse_reshape(out, newshape=(32, 0, -1))
    return tvm.IRModule.from_expr(out)
def expected():
    """Expected module: the same conv2d followed by one reshape to (32, 16, 16)."""
    conv = relay.nn.conv2d(
        relay.var("x", shape=(1, 16, 16, 16), dtype="float32"),
        relay.var("w", shape=(32, 16, 3, 3), dtype="float32"),
        padding=(1, 1),
    )
    return tvm.IRModule.from_expr(relay.reshape(conv, newshape=(32, 16, 16)))
# Nested Sequential: an inner Sequential holding SimplifyExpr is itself used
# as a pass inside an outer Sequential, run under an opt_level=1 PassContext.
z = before()
passes = [
    tvm.transform.Sequential([relay.transform.SimplifyExpr()]),
]
with tvm.transform.PassContext(opt_level=1):
    zz = tvm.transform.Sequential(passes)(z)
# Bind the type-inferred reference module to its own name instead of
# rebinding `expected`. Overwriting the builder function with an IRModule
# makes any later `expected()` call (e.g. re-running this cell) fail.
expected_mod = relay.transform.InferType()(expected())
assert tvm.ir.structural_equal(zz, expected_mod)