# 测试 Pass
## Split args
from tvm.ir.transform import Pass
from tvm.ir import IRModule
from tvm.relay import transform
from tvm import relay
def run_opt_pass(expr, opt_pass):
    """Apply ``opt_pass`` to ``expr`` after lifting it into a typed module.

    The expression is wrapped into an ``IRModule`` (as ``main``), type
    inference is run on it, and then ``opt_pass`` is applied.

    Returns the transformed ``main`` function when ``expr`` is itself a
    ``relay.Function``; otherwise returns only the body of ``main``.
    """
    assert isinstance(opt_pass, Pass)
    typed_module = relay.transform.InferType()(IRModule.from_expr(expr))
    main_fn = opt_pass(typed_module)["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
# ``tvm`` itself is not imported by the imports of this section.
import tvm

# Concatenate many inputs, run SplitArgs with the Metal target's argument
# limit, and compare against a hand-built expected expression.
target = tvm.target.Target("metal")
shape = (1, 1, 1, 3)
dtype = "float32"
axis = 1
inputs = [relay.var(f"p{i}", shape=shape, dtype=dtype) for i in range(100)]


def before():
    """Concatenate all ``inputs`` in a single call."""
    return relay.op.concatenate(relay.Tuple(inputs), axis)


res = run_opt_pass(before(), transform.SplitArgs(target.max_function_args))

# One argument slot is reserved for the output buffer, so at most
# ``max_function_args - 1`` inputs go into each concatenate.
limit = target.max_function_args - 1
# Ceiling division: number of chunks needed to cover every input.
splitNum = (len(inputs) + limit - 1) // limit
splitted = []
for i in range(splitNum):
    startIdx = i * limit
    chunk = inputs[startIdx : startIdx + limit]
    concat = relay.op.concatenate(relay.Tuple(chunk), axis)
    splitted.append(relay.annotation.stop_fusion(concat))
expr = relay.op.concatenate(relay.Tuple(splitted), axis)

# Type-infer the expected expression the same way ``res`` was, then assert
# (rather than merely evaluate) the comparison so a mismatch actually fails.
expected = run_opt_pass(expr, transform.InferType())
assert tvm.ir.structural_equal(res, expected)
## Pass Instrument
import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay import op
from tvm import transform
from tvm.ir.instrument import PassTimingInstrument, pass_instrument
from tvm.ir.transform import PassContext
def get_test_model():
    """Build a small Relay module over three (3, 4) float32 inputs.

    The module computes ``(x + y) * ((x + y) / (x - z)) + (x - z)``.
    """
    x = relay.var("x", shape=(3, 4), dtype="float32")
    y = relay.var("y", shape=(3, 4), dtype="float32")
    z = relay.var("z", shape=(3, 4), dtype="float32")
    total = op.add(x, y)
    diff = op.subtract(x, z)
    product = op.multiply(total, total / diff)
    return IRModule.from_expr(product + diff)
# Time every pass run under the current PassContext by temporarily
# overriding its instruments with a PassTimingInstrument.
pass_timing = PassTimingInstrument()
seq = transform.Sequential(
    [
        relay.transform.AnnotateSpans(),
        relay.transform.ToANormalForm(),
        relay.transform.InferType(),
    ]
)

# Override the instruments of the current PassContext.
PassContext.current().override_instruments([pass_timing])
try:
    mod = get_test_model()
    mod = seq(mod)
    # NOTE(review): render before resetting the instruments; presumably the
    # reset triggers the old instruments' exit callbacks, which may clear
    # the collected timings.
    profiles = pass_timing.render()
finally:
    # Always restore the global PassContext, even if a pass raises, so later
    # code is not silently instrumented.
    PassContext.current().override_instruments(None)

assert "AnnotateSpans" in profiles
assert "ToANormalForm" in profiles
assert "InferType" in profiles

# With the instruments reset to None, nothing is recorded.
mod = get_test_model()
mod = seq(mod)
profiles = pass_timing.render()
assert profiles == ""
# ``tvm.testing`` is a submodule; import it explicitly rather than relying on
# some other import having loaded it.
import tvm.testing

# Parametrize over the two supported ways of defining an instrument.
instrument_definition_type = tvm.testing.parameter("decorator", "subclass")


def test_custom_instrument(instrument_definition_type):
    """Check that a custom instrument observes context and pass events in order.

    Args:
        instrument_definition_type: ``"decorator"`` wraps a plain class with
            ``pass_instrument``; ``"subclass"`` inherits from
            ``tvm.ir.instrument.PassInstrument``.

    Raises:
        ValueError: if ``instrument_definition_type`` is neither option.
    """

    class BaseTest:
        """Record every instrumentation callback as a string event."""

        def __init__(self):
            self.events = []

        def enter_pass_ctx(self):
            self.events.append("enter ctx")

        def exit_pass_ctx(self):
            self.events.append("exit ctx")

        def run_before_pass(self, mod, info):
            self.events.append("run before " + info.name)

        def run_after_pass(self, mod, info):
            self.events.append("run after " + info.name)

    if instrument_definition_type == "decorator":
        MyTest = pass_instrument(BaseTest)
    elif instrument_definition_type == "subclass":

        class MyTest(BaseTest, tvm.ir.instrument.PassInstrument):
            def __init__(self):
                BaseTest.__init__(self)
                tvm.ir.instrument.PassInstrument.__init__(self)

    else:
        # Without this branch ``MyTest`` would be unbound -> obscure NameError.
        raise ValueError(
            f"unknown instrument_definition_type: {instrument_definition_type!r}"
        )

    mod = get_test_model()
    my_test = MyTest()
    with tvm.transform.PassContext(instruments=[my_test]):
        mod = tvm.relay.transform.InferType()(mod)

    # "exit ctx" is only recorded once the context has been left.
    assert (
        "enter ctx"
        "run before InferType"
        "run after InferType"
        "exit ctx" == "".join(my_test.events)
    )
## 禁用 Pass
@pass_instrument
class CustomPI:
    """Instrument that lets only passes whose name contains "InferType" run.

    The names of the passes that do run are recorded in ``events``.
    """

    def __init__(self):
        self.events = []

    def should_run(self, mod, info):
        # Veto every pass whose name does not contain "InferType".
        return "InferType" in info.name

    def run_before_pass(self, mod, info):
        self.events.append(info.name)


mod = get_test_model()
custom_pi = CustomPI()
# NOTE(review): the passes are applied one by one rather than through a
# ``transform.Sequential``. Presumably a Sequential is itself checked by
# ``should_run`` under its own name and would be skipped as a whole; verify.
with PassContext(instruments=[custom_pi]):
    for opt in (
        tvm.relay.transform.AnnotateSpans(),
        tvm.relay.transform.ToANormalForm(),
        tvm.relay.transform.InferType(),
    ):
        mod = opt(mod)
assert "InferType" == "".join(custom_pi.events)
@pass_instrument
class SkipPass:
    """Instrument that vetoes every pass whose name contains ``skip_pass_name``."""

    def __init__(self, skip_pass_name):
        self.skip_pass_name = skip_pass_name

    def should_run(self, mod, info):
        return self.skip_pass_name not in info.name


skip_annotate = SkipPass("AnnotateSpans")
skip_anf = SkipPass("ToANormalForm")


@pass_instrument
class PrintPassName:
    """Record the name of every pass just before it runs."""

    def __init__(self):
        self.events = []

    def run_before_pass(self, mod, info):
        self.events.append(info.name)


mod = get_test_model()
print_pass_name = PrintPassName()
# Both skip instruments are active, so only InferType should actually run.
active_instruments = [skip_annotate, skip_anf, print_pass_name]
with tvm.transform.PassContext(instruments=active_instruments):
    mod = tvm.relay.transform.AnnotateSpans()(mod)
    mod = tvm.relay.transform.ToANormalForm()(mod)
    mod = tvm.relay.transform.InferType()(mod)
assert "InferType" == "".join(print_pass_name.events)
@pass_instrument
class PassesCounter:
    """Count run_before/run_after callbacks.

    Both counters are reset when a PassContext is entered and when it is left.
    """

    def __init__(self):
        self.run_before_count = 0
        self.run_after_count = 0

    def _reset_counts(self):
        self.run_before_count = self.run_after_count = 0

    def enter_pass_ctx(self):
        self._reset_counts()

    def exit_pass_ctx(self):
        self._reset_counts()

    def run_before_pass(self, mod, info):
        self.run_before_count += 1

    def run_after_pass(self, mod, info):
        self.run_after_count += 1


mod = get_test_model()
passes_counter = PassesCounter()
with tvm.transform.PassContext(instruments=[passes_counter]):
    tvm.relay.build(mod, "llvm")
    # Must be checked inside the context: leaving it resets the counters.
    assert passes_counter.run_after_count != 0
    assert passes_counter.run_after_count == passes_counter.run_before_count

# Out of pass context scope, the counters should be reset.
assert passes_counter.run_before_count == 0
assert passes_counter.run_after_count == 0
# List all registered PassContext config options and inspect one of them.
configs = PassContext.list_configs()
assert len(configs) > 0
option_name = "relay.backend.use_auto_scheduler"
assert option_name in configs.keys()
assert configs[option_name]["type"] == "IntImm"
# Shared log that both instruments and both passes append to, so the relative
# order of all callbacks can be checked.
events = []


@pass_instrument
class PI:
    """Instrument that logs every callback, tagged with its ``id``."""

    def __init__(self, id):
        self.id = id

    def enter_pass_ctx(self):
        events.append(f"{self.id} enter_pass_ctx")

    def exit_pass_ctx(self):
        events.append(f"{self.id} exit_pass_ctx")

    def should_run(self, mod, info):
        events.append(f" {self.id} should_run")
        return True

    def run_before_pass(self, mod, info):
        events.append(f" {self.id} run_before_pass")

    def run_after_pass(self, mod, info):
        events.append(f" {self.id} run_after_pass")


@tvm.transform.module_pass(opt_level=2)
def transform1(mod, ctx):
    """No-op module pass that only logs that it ran."""
    events.append(" transform1 pass")
    return mod


@tvm.transform.module_pass(opt_level=2)
def transform2(mod, ctx):
    """No-op module pass that only logs that it ran."""
    events.append(" transform2 pass")
    return mod


mod = get_test_model()
with PassContext(instruments=[PI("%1"), PI("%2")]):
    mod = transform1(mod)
    mod = transform2(mod)

# Expected order: both instruments enter; for each pass, every instrument's
# should_run and run_before fire, then the pass, then every run_after; both
# instruments finally exit.
expected_events = ["%1 enter_pass_ctx", "%2 enter_pass_ctx"]
for pass_name in ("transform1", "transform2"):
    expected_events += [
        " %1 should_run",
        " %2 should_run",
        " %1 run_before_pass",
        " %2 run_before_pass",
        f" {pass_name} pass",
        " %1 run_after_pass",
        " %2 run_after_pass",
    ]
expected_events += ["%1 exit_pass_ctx", "%2 exit_pass_ctx"]
assert "".join(expected_events) == "".join(events)
## Pass 去函数化
from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import transform, ExprVisitor, TypeVisitor
from tvm.relay.testing import Prelude
def has_func_type(t):
    """Return True if type ``t`` is a FuncType or contains a nested FuncType."""

    class FuncTypeVisitor(TypeVisitor):
        """Type visitor that flags whether any FuncType was encountered."""

        def __init__(self):
            super().__init__()
            self.has_func = False

        def visit_func_type(self, ftt):
            # Does not recurse into the function type: one hit is enough.
            self.has_func = True

    visitor = FuncTypeVisitor()
    visitor.visit(t)
    return visitor.has_func
# 确定程序是否包含高阶函数。满足以下任一条件即视为高阶函数:
# - 具有函数类型的参数;
# - 返回函数。
def assert_no_higher_order_functions(expr, mod):
    """Assert that ``expr`` and every global function it reaches are first-order.

    A call counts as higher-order when its result type, or the type of any of
    its arguments, is (or contains) a function type. Type inference is run on
    ``mod`` first so that ``checked_type`` is available.

    Raises:
        AssertionError: listing every higher-order call found.
    """

    class CheckFirstOrderVisitor(ExprVisitor):
        """Collect every higher-order call reachable from the visited expr."""

        def __init__(self, mod):
            super().__init__()
            self.mod = mod
            self.hof = []
            self.visited_gv = set()

        def visit_call(self, call):
            # Return type first, then each argument type.
            types_to_check = [call.checked_type] + [arg.checked_type for arg in call.args]
            # Keep offending calls so they can be printed in the failure message.
            if any(has_func_type(ty) for ty in types_to_check):
                self.hof.append(call)
            super().visit_call(call)

        def visit_global_var(self, gv):
            # Follow global references so the whole program is checked,
            # visiting each global function only once.
            if gv in self.visited_gv:
                return
            self.visited_gv.add(gv)
            self.visit(self.mod[gv])

    mod = transform.InferType()(mod)
    checker = CheckFirstOrderVisitor(mod)
    checker.visit(expr)
    separator = "\n--------\n"
    errmsg = f"found {len(checker.hof)} higher order functions:\n" + separator.join(
        call.astext() for call in checker.hof
    )
    assert not checker.hof, errmsg
# 断言程序已被去函数化,并返回去函数化后的模块(假设程序入口为 mod['main'])。
def defunctionalized(mod):
    """Defunctionalize ``mod["main"]`` and return the resulting module.

    Assumes the program's entry point is ``mod["main"]``. Type inference is
    run before and after the transformation.

    Raises:
        AssertionError: if the result still contains higher-order calls.
    """
    typed = transform.InferType()(mod)
    typed["main"] = transform.Defunctionalization(typed["main"], typed)
    typed = transform.InferType()(typed)
    assert_no_higher_order_functions(typed["main"], typed)
    return typed
def to_list(mod, l):
    """Convert a Relay ``List`` ADT value into a Python list.

    Walks the ``Cons`` cells starting from the head and collects each
    element's ``.numpy()`` value until ``Nil`` is reached.

    Args:
        mod: module defining the ``List`` type, whose constructors are
            ``Cons`` (index 0) and ``Nil`` (index 1).
        l: a ``ConstructorValue`` obtained by evaluating a ``List``.

    Returns:
        list of the element values, head first.
    """
    # Named ``list_type`` rather than ``list`` to avoid shadowing the builtin.
    list_type = mod.get_global_type_var("List")
    list_adt = mod[list_type]
    cons = list_adt.constructors[0]
    nil = list_adt.constructors[1]
    assert isinstance(l, ConstructorValue)
    result = []
    node = l
    while node.tag == cons.tag:
        result.append(node.fields[0].numpy())
        node = node.fields[1]
    # Anything that is not a Cons cell must be the terminating Nil.
    assert node.tag == nil.tag
    return result
def to_adt_list(mod, arr):
    """Convert a Python sequence into an evaluated Relay ``List`` ADT value.

    Element order is preserved: ``arr[0]`` becomes the head of the list, so
    ``to_list(mod, to_adt_list(mod, arr))`` returns the elements in their
    original order.

    Args:
        mod: module defining the ``List`` type, whose constructors are
            ``Cons`` (index 0) and ``Nil`` (index 1).
        arr: iterable of scalars, each wrapped with ``relay.const``.

    Returns:
        the value produced by evaluating the constructed list expression.
    """
    saved_main = mod["main"]
    list_type = mod.get_global_type_var("List")
    list_adt = mod[list_type]
    cons = list_adt.constructors[0]
    nil = list_adt.constructors[1]
    li = nil()
    # Cons prepends, so build from the tail so the first element ends up at
    # the head (iterating forward would reverse the sequence).
    for a in reversed(list(arr)):
        li = cons(relay.const(a), li)
    try:
        return relay.create_executor(mod=mod).evaluate(li)
    finally:
        # Evaluation may replace mod["main"]; always restore the caller's entry
        # point, even if evaluation raises.
        mod["main"] = saved_main
import tvm
from tvm import relay
import numpy as np
# Defunctionalize a generic higher-order function applied to an identity
# lambda, then check the transformed program computes the same result.
code = """
#[version = "0.0.5"]
def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B {
%f(%xs)
}
def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] {
%0 = fn[A](%x: A) -> A {
%x
};
@simple(%0, %l)
}
"""
mod = tvm.parser.fromtext(code)
defunc_mod = defunctionalized(mod)
# Named ``input_data`` rather than ``input`` to avoid shadowing the builtin.
input_data = np.random.rand(5, 5).astype("float32")
out = relay.create_executor("debug", mod=mod).evaluate()(input_data)
defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(input_data)
np.testing.assert_equal(out.numpy(), defunc_out.numpy())
# Defunctionalize a recursive ``@map`` over an ADT list, passing a global
# function (``@id``) as the mapped function, and compare both programs.
code = """
#[version = "0.0.5"]
type List[A] {
Cons(A, List[A]),
Nil,
}
def @id[A](%x: A) -> A {
%x
}
def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] {
match (%xs) {
Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)),
Nil => Nil,
}
}
def @main(%l: List[float32]) -> List[float32] {
@map(@id, %l)
}
"""
mod = tvm.parser.fromtext(code)
defunc_mod = defunctionalized(mod)
# Named ``input_data`` rather than ``input`` to avoid shadowing the builtin.
input_data = np.random.rand(10).astype("float32")
out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(
    to_adt_list(mod, input_data)
)
defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(
    to_adt_list(defunc_mod, input_data)
)
np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out))
def test_recursive_datatype():
    """Defunctionalize a CPS-style ``@sum`` and compare it with the original.

    ``@sum`` builds a new closure per list element; per the original comment,
    this CPS style makes defunctionalization create a recursive datatype.
    """
    # ``tvm.testing`` is a submodule; import it explicitly before using
    # ``tvm.testing.assert_allclose`` instead of relying on a prior import.
    import tvm.testing

    # CPS will create recursive datatype
    code = """
#[version = "0.0.5"]
type List[A] {
Cons(A, List[A]),
Nil,
}
def @sum(%f: fn(int32) -> int32, %k: List[int32]) -> int32 {
match (%k) {
Cons(%x, %rest) => %0 = fn(%n) {
%x + %f(%n)
};
@sum(%0, %rest),
Nil => %f(0),
}
}
def @id[A](%x: A) -> A {
%x
}
def @main(%l: List[int32]) -> int32 {
@sum(@id, %l)
}
"""
    mod = tvm.parser.fromtext(code)
    defunc_mod = defunctionalized(mod)
    # Named ``input_data`` rather than ``input`` to avoid shadowing the builtin.
    input_data = np.random.randint(1, 100, 10)
    out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(
        to_adt_list(mod, input_data)
    )
    defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(
        to_adt_list(defunc_mod, input_data)
    )
    tvm.testing.assert_allclose(out.numpy(), defunc_out.numpy())