# Relay Sequential pass

import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type


def check_func(func, ref_func):
    """Assert that *func* and *ref_func* are structurally equal after type inference."""
    inferred = run_infer_type(func)
    reference = run_infer_type(ref_func)
    assert tvm.ir.structural_equal(inferred, reference)
    
def extract_var_func(mod, name):
    """Look up *name* in *mod* and return the (global var, function) pair."""
    gvar = mod.get_global_var(name)
    return gvar, mod[gvar]
    
def get_rand(shape, dtype="float32"):
    """Return a TVM NDArray of uniform random values with the given shape and dtype."""
    data = np.random.rand(*shape).astype(dtype)
    return tvm.nd.array(data)

def get_ref_log():
    """Build the reference function log(add(x, x)) over the module-level var ``x``."""
    body = relay.log(relay.add(x, x))
    return relay.Function([x], body)

def get_ref_sub():
    """Build the reference function subtract(add(x, x), add(y, y)) over the module-level vars."""
    doubled_x = relay.add(x, x)
    doubled_y = relay.add(y, y)
    return relay.Function([x, y], relay.subtract(doubled_x, doubled_y))

def get_ref_abs():
    """Build the reference function abs(add(a, a)) over a fresh (5, 10) float32 var."""
    abs_tp = relay.TensorType((5, 10), "float32")
    a = relay.var("a", abs_tp)
    return relay.Function([a], relay.abs(relay.add(a, a)))
# --- Shared fixtures: a module holding "mySub" and "myLog" global functions ---
shape = (10,)
dtype = "float32"
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = relay.var("y", tp)
v_sub = relay.GlobalVar("mySub")
sub = relay.Function([x, y], relay.subtract(x, y))

z = relay.var("z", tp)
v_log = relay.GlobalVar("myLog")
log = relay.Function([z], relay.log(z))

# The module used by the passes below; maps each global var to its function.
mod = tvm.IRModule({v_sub: sub, v_log: log})
from utils.helper import OptTester

# Register the module-level pass.
opt_tester = OptTester(mod)
@tvm.transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
    return opt_tester.transform(expr, ctx)
# Register the function-level pass.
@relay.transform.function_pass(opt_level=1)
def func_transform(expr, mod, ctx):
    return opt_tester.transform(expr, ctx)
# Sequence-level pass: chain both passes, then verify the reported metadata.
passes = [mod_transform, func_transform]
opt_level = 2
pass_name = "sequential"
sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level

## Testing the sequential pass

### Empty pass list:

# A Sequential with no passes must leave the module unchanged:
# "mySub" comes back structurally equal to the original `sub`.
passes = []
sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)

### Module-level pass:

# Run only the module-level pass, explicitly marking it as required in the
# pass context, then check the subtract function survives structurally intact.
passes = [mod_transform]
sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with tvm.transform.PassContext(required_pass=["mod_transform"]):
    ret_mod = sequential(mod)
# Check the subtract function.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub)

## Scoped passes

# --- Fixtures for the scoped-pass test ---
shape = (1, 2, 3)
# NOTE: np.array(shape) makes a length-3 vector with *values* [1., 2., 3.],
# not an array of shape (1, 2, 3).
c_data = np.array(shape).astype("float32")
tp = relay.TensorType(shape, "float32")

def before():
    """Unoptimized function: constant arithmetic plus duplicated subexpressions.

    FoldConstant should collapse the constant math and
    EliminateCommonSubexpr should merge the duplicated adds.
    """
    const_c = relay.const(c_data)
    inp = relay.var("x", tp)
    folded = relay.add(const_c, const_c)
    folded = relay.multiply(folded, relay.const(2, "float32"))
    summed = relay.add(inp, folded)
    left = relay.add(summed, const_c)
    right = relay.add(summed, const_c)
    return relay.Function([inp], relay.add(left, right))

def expected():
    """Reference result after constant folding and common-subexpression elimination."""
    inp = relay.var("x", tp)
    # Constant math folded on the numpy side, mirroring FoldConstant.
    folded_np = (c_data + c_data) * 2
    summed = relay.add(inp, relay.const(folded_np))
    shared = relay.add(summed, relay.const(c_data))
    return relay.Function([inp], relay.add(shared, shared))
# Standard optimization sequence; FoldConstant + EliminateCommonSubexpr
# should rewrite before() into expected().
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.AlterOpLayout(),
    ]
)

mod = tvm.IRModule({"main": before()})
# NOTE(review): the target scope is presumably needed by AlterOpLayout — confirm.
with tvm.transform.PassContext(opt_level=3):
    with tvm.target.Target("llvm"):
        mod = seq(mod)

zz = mod["main"]
zexpected = run_infer_type(expected())
assert tvm.ir.structural_equal(zz, zexpected)
> NOTE: running the sequence emits a `UserWarning` from `tvm/driver/build_module.py`:
> the `target_host` parameter is deprecated; pass `tvm.target.Target(target, host=target_host)` instead.

## Nested passes

def before():
    """conv2d followed by a chain of reshapes that SimplifyExpr should collapse."""
    data = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
    weight = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, weight, padding=(1, 1))
    out = relay.reshape(out, newshape=(1, 16, -1))
    out = relay.reshape(out, newshape=(4, 8, -1, 16))
    out = relay.reverse_reshape(out, newshape=(32, 0, -1))
    return tvm.IRModule.from_expr(out)

def expected():
    """Reference module: the reshape chain collapsed into a single reshape."""
    data = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
    weight = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, weight, padding=(1, 1))
    out = relay.reshape(out, newshape=(32, 16, 16))
    return tvm.IRModule.from_expr(out)

z = before()
# A Sequential nested inside another Sequential: the inner one wraps
# SimplifyExpr, the outer one runs the inner as its only pass.
passes = [
    tvm.transform.Sequential([relay.transform.SimplifyExpr()]),
]
with tvm.transform.PassContext(opt_level=1):
    zz = tvm.transform.Sequential(passes)(z)

expected = relay.transform.InferType()(expected())
assert tvm.ir.structural_equal(zz, expected)