Testing LambdaLift#
# Import the necessary TVM modules and testing utilities
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R, tir as T, ir as I
from tvm.relax import transform
from tvm.ir.base import assert_structural_equal
# Helper: check that two IR structures are structurally equal
def _check_equal(x, y):
tvm.ir.assert_structural_equal(x, y)
tvm.ir.assert_structural_equal(y, x)
xhash = tvm.ir.structural_hash(x, map_free_vars=True)
yhash = tvm.ir.structural_hash(y, map_free_vars=True)
assert xhash == yhash
# Helper: check that an IR structure serializes and deserializes correctly
def _check_save_roundtrip(x):
y = tvm.ir.load_json(tvm.ir.save_json(x))
_check_equal(x, y)
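As a quick smoke test of these helpers (an illustrative sketch, not part of the original test suite), an empty IRModule should round-trip through JSON unchanged:
# Sanity-check the helpers on an empty IRModule (illustrative only)
_check_save_roundtrip(tvm.IRModule())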
Testing whether LambdaLift can lift locally bound functions to the top level of the IRModule#
The IRModule before the transform: the inner function is defined inside the main function
@I.ir_module
class Before:
@R.function
def main(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
        # the inner function, defined locally inside main
@R.function
def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
# Apply the LambdaLift transform
after = transform.LambdaLift()(Before)
# Verify that the transformed module contains two functions
assert len(after.functions) == 2
main_inner is the lifted inner function:
after.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(private=True)
def main_inner(x2: R.Tensor((10, 5), dtype="float32"), y2: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
s: R.Tensor((10, 5), dtype="float32") = R.add(x2, y2)
return s
@R.function
def main(x1: R.Tensor((10, 5), dtype="float32"), y1: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
cls = Module
gv1: R.Tensor((10, 5), dtype="float32") = cls.main_inner(x1, y1)
return gv1
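To confirm that lifting preserves behavior, a minimal execution sketch (assuming TVM was built with LLVM support; the "zero" pipeline and llvm target are assumptions, not part of the original test) compiles the lifted module and runs it on the Relax VM:
import numpy as np
# Lower high-level ops such as R.add to TIR, then compile for the CPU.
lowered = relax.get_pipeline("zero")(after)
ex = relax.build(lowered, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
a = tvm.nd.array(np.random.rand(10, 5).astype("float32"))
b = tvm.nd.array(np.random.rand(10, 5).astype("float32"))
out = vm["main"](a, b)
# The lifted cls.main_inner call must still compute x + y.
tvm.testing.assert_allclose(out.numpy(), a.numpy() + b.numpy())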
Testing that the transform does not modify the input module#
If the output requires a new StructInfo, a new relax variable must be created. The struct info of an existing relax variable must not be updated in place, because that variable may be in use by another IRModule.
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
@R.function
def outer_func(
c1: R.Tensor((2, 3), "float32")
) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")):
@R.function
def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
return s
return inner_func
in_call = outer_func(x)
res = in_call(y)
return res
before = Before
# Save a copy of the original module for comparison
copy_of_before = tvm.ir.load_json(tvm.ir.save_json(before))
# Apply the LambdaLift transform
transform.LambdaLift()(before)
# Verify that the original module was not modified
tvm.ir.assert_structural_equal(before, copy_of_before)
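A small follow-up sketch: capturing the transform's return value shows that the lifting happened in a fresh module, with both lambdas promoted to the top level.
after = transform.LambdaLift()(before)
# `before` still holds a single function; the result holds three:
# main plus the lifted outer_func and inner_func.
assert len(before.functions) == 1
assert len(after.functions) == 3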
Testing the closure transformation#
# The IRModule before applying the LambdaLift transform
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
@R.function
def outer_func(
c1: R.Tensor((2, 3), "float32")
) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")):
@R.function
def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            # inner_func captures c1 from the enclosing scope, forming a closure
s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
return s
return inner_func
in_call = outer_func(x)
res = in_call(y)
return res
after = transform.LambdaLift()(Before)
after.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(private=True)
def main_inner_func(x1: R.Tensor((2, 3), dtype="float32"), c1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
s: R.Tensor((2, 3), dtype="float32") = R.add(x1, c1)
return s
@R.function(private=True)
def main_outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Object:
cls = Module
inner_func: R.Object = R.make_closure(cls.main_inner_func, (c1,))
return inner_func
@R.function
def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
cls = Module
in_call: R.Object = cls.main_outer_func(x)
res: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"),))
return res
# Verify serialization and deserialization
_check_save_roundtrip(after)
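One way to make the closure conversion explicit (a small inspection sketch) is to look at the parameters of the lifted inner function: the captured variable c1 is now an ordinary parameter, bound by R.make_closure at the capture site.
inner_gvar = after.get_global_var("main_inner_func")
print([p.name_hint for p in after[inner_gvar].params])  # ['x1', 'c1']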
Testing lifting of recursive functions#
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
@R.function
def while_loop(
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
cond: R.Tensor((), "bool") = R.call_pure_packed(
"test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool"))
)
c: R.Tensor((), "int32") = R.const(1, dtype="int32")
if cond:
new_i: R.Tensor((), "int32") = R.add(i, c)
new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
                # recursive call to itself
r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s)
else:
r: R.Tensor((2, 3), "float32") = s
return r
gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x)
return gv
before = Before
# Check that the module with the recursive call is well-formed
assert relax.analysis.well_formed(before)
# Apply the LambdaLift transform
after = transform.LambdaLift()(before)
# Verify that the transformed module contains two functions
assert len(after.functions) == 2
# Verify serialization and deserialization
_check_save_roundtrip(after)
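The recursive self-call is rewritten to target the lifted function's GlobalVar; a brief sanity sketch confirms the lifted module is still well-formed.
assert relax.analysis.well_formed(after)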
Testing lifting from multiple top-level functions#
GlobalVar names in the IRModule are deduplicated by incorporating the name of the function from which each lambda was lifted.
# The IRModule before the transform: both top-level functions contain an inner function with the same name
@I.ir_module
class Before:
@R.function
def glob_func_1(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
@R.function
def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
@R.function
def glob_func_2(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
@R.function
def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
before = Before
# Apply the LambdaLift transform
after = transform.LambdaLift()(before)
# Verify that the transformed module contains four functions
assert len(after.functions) == 4
# Verify serialization and deserialization
_check_save_roundtrip(after)
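The deduplication can be observed directly. A small inspection sketch lists the global function names; the exact lifted names (e.g. glob_func_1_inner and glob_func_2_inner) are an implementation detail of current TVM builds.
# Each lifted lambda carries the name of its enclosing function.
print(sorted(gv.name_hint for gv in after.functions))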
Testing the case with no local functions#
# An IRModule that contains no local functions
@I.ir_module
class Before:
@T.prim_func
def sub(
A: T.Buffer((16, 16), "float32"),
B: T.Buffer((16, 16), "float32"),
C: T.Buffer((16, 16), "float32"),
) -> None:
for i, j in T.grid(16, 16):
with T.block("sub"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = A[vi, vj] - B[vi, vj]
@R.function
def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)):
s = R.call_tir(Before.sub, (c0, x), R.Tensor((16, 16), dtype="float32"))
return s
before = Before
# Apply the LambdaLift transform
after = transform.LambdaLift()(before)
# Verify that no local functions were lifted and the module is unchanged
assert_structural_equal(after, before, map_free_vars=True)
# Verify serialization and deserialization
_check_save_roundtrip(after)
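As a small extra check, LambdaLift also leaves TIR PrimFuncs untouched, so the sub function can be compared on its own:
tvm.ir.assert_structural_equal(after["sub"], before["sub"])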
Testing lifting of impure functions#
# The IRModule before the transform: it contains an impure inner function
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # an impure inner function that uses R.print for its side effect
@R.function(pure=False)
def inner() -> R.Tuple:
y = R.print(format="Wow!")
return y
gv1 = inner()
return x
before = Before
# Apply the LambdaLift transform
after = transform.LambdaLift()(before)
# Verify that the transformed module contains two functions
assert len(after.functions) == 2
# Verify serialization and deserialization
_check_save_roundtrip(after)
after.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(pure=False, private=True)
def main_inner() -> R.Tuple:
R.print(format=R.str("Wow!"))
return R.tuple()
@R.function(pure=False)
def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
cls = Module
cls.main_inner()
return x
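The output shows that both main and the lifted main_inner keep their pure=False annotation. A quick check sketch (assuming the is_pure accessor on relax functions) asserts the same programmatically:
# Neither function should be silently marked pure by the lift.
assert not after["main_inner"].is_pure
assert not after["main"].is_pure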
Testing lifting of a lambda that shares its name with a global function#
Tests that the name chosen for a lifted lambda does not collide with pre-existing names. The module already contains a function named main_inner, which is the first name LambdaLift would pick for the lifted function.
# The IRModule before the transform: it already contains a global function named main_inner
@I.ir_module
class Before:
@R.function
def main(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
@R.function
def inner(
x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
return s
gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
return gv1
    # a pre-existing global function named main_inner
@R.function
def main_inner():
return R.tuple()
# Apply the LambdaLift transform
after = transform.LambdaLift()(Before)
after.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main_inner() -> R.Tuple:
return R.tuple()
@R.function(private=True)
def main_inner_0(x2: R.Tensor((10, 5), dtype="float32"), y2: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
s: R.Tensor((10, 5), dtype="float32") = R.add(x2, y2)
return s
@R.function
def main(x1: R.Tensor((10, 5), dtype="float32"), y1: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
cls = Module
gv1: R.Tensor((10, 5), dtype="float32") = cls.main_inner_0(x1, y1)
return gv1
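A quick programmatic check of the renaming, mirroring the printed output above:
names = sorted(gv.name_hint for gv in after.functions)
# The pre-existing main_inner survives; the lifted lambda becomes main_inner_0.
assert "main_inner" in names
assert "main_inner_0" in names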
Testing symbolic variables defined by the inner function#
# The IRModule before the transform: the inner function uses symbolic variables for its tensor shapes
@I.ir_module
class Before:
@R.function
def main(
x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
) -> R.Tensor((10, 5), "float32"):
@R.function
def inner(x2: R.Tensor(("n", "m"), "float32"), y2: R.Tensor(("n", "m"), "float32")):
sum_inner = R.add(x2, y2)
return sum_inner
sum_main = inner(x1, y1)
return sum_main
# Apply the LambdaLift transform
After = transform.LambdaLift()(Before)
After.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(private=True)
def main_inner(x2: R.Tensor(("n", "m"), dtype="float32"), y2: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
n = T.int64()
m = T.int64()
sum_inner: R.Tensor((n, m), dtype="float32") = R.add(x2, y2)
return sum_inner
@R.function
def main(x1: R.Tensor((10, 5), dtype="float32"), y1: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
cls = Module
sum_main: R.Tensor((10, 5), dtype="float32") = cls.main_inner(x1, y1)
return sum_main
Testing symbolic variables defined by the outer function#
# The IRModule before the transform: the inner function uses symbolic variables defined by the outer function
@I.ir_module
class Before:
@R.function
def main(
x1: R.Tensor(("n", "m"), "float32"), y1: R.Tensor(("n", "m"), "float32")
) -> R.Tensor(("n", "m"), "float32"):
        # symbolic variables defined in the outer function
n = T.int64()
m = T.int64()
@R.function
def inner(x2: R.Tensor((n, m), "float32"), y2: R.Tensor((n, m), "float32")):
sum_inner = R.add(x2, y2)
return sum_inner
sum_main = inner(x1, y1)
return sum_main
# Apply the LambdaLift transform; the lifted function should correctly handle the externally defined symbolic variables
After = transform.LambdaLift()(Before)
After.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(private=True)
def main_inner(x2: R.Tensor(("n", "m"), dtype="float32"), y2: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
n = T.int64()
m = T.int64()
sum_inner: R.Tensor((n, m), dtype="float32") = R.add(x2, y2)
return sum_inner
@R.function
def main(x1: R.Tensor(("n", "m"), dtype="float32"), y1: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
n = T.int64()
m = T.int64()
cls = Module
sum_main: R.Tensor((n, m), dtype="float32") = cls.main_inner(x1, y1)
return sum_main
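As with the earlier cases, a final sanity sketch: the lifted module, whose symbolic variables n and m are rebound from the parameter shapes, should survive a serialization round trip.
_check_save_roundtrip(After)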