规范化处理#
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm import tir
from tvm.ir.base import assert_structural_equal
import tvm.script
from tvm.script import tir as T, relax as R
测试基本函数的规范化处理#
Normalize 变换将嵌套的算子分解为单独的绑定语句,使 IR 更加扁平化
输入: 包含嵌套的 add 和 multiply 算子的函数
预期输出: 分解后的 ANF(Administrative Normal Form) 形式,每个算子都有独立的绑定变量
测试重点: 验证函数体的规范化处理
ANF 形式特点: 每个复杂表达式都被分解为一系列绑定到变量的简单表达式
核心步骤:
手动构建带嵌套算子的函数
应用 Normalize 变换
验证转换结果是否符合 ANF 形式
测试方法: 使用 assert_structural_equal 验证变换结果与预期 IR 结构是否一致
# Symbolic shape dims (m, n) and the input tensor variable for the test function.
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor([m, n], "float16"))
注意: TVMScript 解析器会自动规范化用 TVMScript 编写的 IR,因此手动构造函数,这里构建包含嵌套操作的函数: multiply(add(x,x), add(x,x))
# Manually construct a function whose body nests calls:
# multiply(add(x, x), add(x, x)).  Built through the relax API because the
# TVMScript parser normalizes IR automatically, which would defeat the test.
mul_add = relax.Function(
    [x],
    relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
    ret_struct_info=R.Tensor("float16", ndim=2),
)
注意: from_expr API 将私有函数(没有 global_symbol 的函数)命名为 "main":
before_mod = tvm.IRModule.from_expr(mul_add)
before_mod.show()
# Printed output of `before_mod.show()` above — the body is still a single
# nested expression (not yet in ANF):
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor(("m", "n"), dtype="float16")) -> R.Tensor(dtype="float16", ndim=2):
        m = T.int64()
        n = T.int64()
        return R.multiply(R.add(x, x), R.add(x, x))
# Apply the Normalize pass: nested calls are flattened into ANF bindings.
after_mod = relax.transform.Normalize()(before_mod)
after_mod.show()
# Printed output of `after_mod.show()` — every call now has its own binding:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor(("m", "n"), dtype="float16")) -> R.Tensor(("m", "n"), dtype="float16"):
        m = T.int64()
        n = T.int64()
        gv: R.Tensor((m, n), dtype="float16") = R.add(x, x)
        gv1: R.Tensor((m, n), dtype="float16") = R.add(x, x)
        gv2: R.Tensor((m, n), dtype="float16") = R.multiply(gv, gv1)
        return gv2
测试条件语句(If节点)的规范化处理#
输入: 包含 If 节点的函数,if 和 else 分支中包含嵌套操作
预期输出: 规范化后的函数,其中 If 节点的分支被变换为 seq exprs,每个算子都有独立绑定
测试重点: 验证条件分支内的操作是否被正确规范化
变换机制: Normalize 变换会确保 If 节点的两个分支都被变换为规范化的表达式序列
# Inputs for the If-node test: a boolean scalar condition and a 1-element tensor.
cond = relax.Var("cond", R.Tensor([], "bool"))
x = relax.Var("x", R.Tensor([1], "float32"))
# TODO(relax-team): add type and shape deduction for IfNode
y = relax.Var("y")
# NOTE: constructed manually (not via TVMScript) because the parser would
# normalize the IR on its own; each If branch holds a nested expression.
f = relax.Function(
    [cond, x],
    relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(
                        y,
                        relax.If(
                            cond,
                            relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
                            relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
                        ),
                    )
                ]
            )
        ],
        y,
    ),
    ret_struct_info=R.Tensor("float32", ndim=1),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
before_mod.show()
# Printed output of `before_mod.show()` — branch bodies are still nested calls:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor(dtype="float32", ndim=1):
        if cond:
            y: R.Tensor((1,), dtype="float32") = R.multiply(R.add(x, x), R.add(x, x))
        else:
            y: R.Tensor((1,), dtype="float32") = R.add(R.multiply(x, x), R.multiply(x, x))
        return y
after_mod.show()
# Printed output of `after_mod.show()` — each branch became a normalized
# sequence of bindings ending with an assignment to `y`:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
        if cond:
            gv: R.Tensor((1,), dtype="float32") = R.add(x, x)
            gv1: R.Tensor((1,), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((1,), dtype="float32") = R.multiply(gv, gv1)
            y: R.Tensor((1,), dtype="float32") = gv2
        else:
            gv3: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            gv4: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            gv5: R.Tensor((1,), dtype="float32") = R.add(gv3, gv4)
            y: R.Tensor((1,), dtype="float32") = gv5
        return y
测试已经是ANF形式的IR的规范化处理#
输入: 已经符合ANF形式的IR模块
预期输出: 保持不变,Normalize变换对其不产生影响
测试重点: 验证Normalize变换对已经规范化的IR是幂等的
幂等性说明: 多次应用同一个变换应产生相同的结果,不会改变已经符合要求的IR
# The normalize pass should be a no-op on IR that is already in ANF.
@tvm.script.ir_module
class ANFMod1:
    @R.function
    def f(x: R.Tensor(dtype="float32")):
        gv = R.add(x, x)
        gv1 = R.add(gv, gv)
        gv2 = R.add(gv, gv1)
        return (gv, gv2)

before_mod = ANFMod1
after_mod = relax.transform.Normalize()(before_mod)
# Idempotence: normalizing already-normalized IR leaves it structurally equal.
assert_structural_equal(before_mod, after_mod, map_free_vars=True)
# Same idempotence check for a function containing a dataflow block.
@tvm.script.ir_module
class ANFMod2:
    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")):
        m, n = T.int64(), T.int64()
        with R.dataflow():
            lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
            gv0 = R.call_dps_packed(
                "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")
            )
            R.output(gv0)
        return gv0

mod = ANFMod2
mod_post = relax.transform.Normalize()(mod)
assert_structural_equal(mod, mod_post)
测试序列表达式(SeqExpr)中非叶节点体的规范化处理#
输入: 序列表达式的体(body)不是叶节点的情况
预期输出: 将非叶节点体绑定到一个变量
测试重点: 验证seq expr中非叶节点的处理逻辑
技术细节: 规范化过程会将复杂的表达式绑定到变量,使表达式结构更加扁平
# A SeqExpr whose body is not a leaf node should have that body bound to a var.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
seq = relax.SeqExpr([], relax.op.add(x, y))
f = relax.Function(
    [x, y],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

@R.function(private=True)
def expected(
    x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
    # normalization inserts a binding like this one
    z = R.add(x, y)
    return z

assert_structural_equal(after_mod["main"], expected)
# A function whose body is not a SeqExpr should have it wrapped in one.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
f = relax.Function(
    [x, y],
    relax.op.add(x, y),
    ret_struct_info=R.Tensor([], "int32"),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

@R.function(private=True)
def expected(
    x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
    # the result will be a seq expr whose body is a var
    z = R.add(x, y)
    return z

assert_structural_equal(after_mod["main"], expected)
测试If节点分支的规范化处理#
输入: If节点的分支不是序列表达式的情况
预期输出: 将If节点的分支转换为序列表达式
测试重点: 验证If节点内部结构的规范化
技术要点: 规范化过程会确保If节点的两个分支都符合ANF形式
# The branches of an If node must themselves be SeqExprs.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
# TODO(@relax-team): z has shape () and type TensorType(ndim=0), but
# normalization fails to infer these even though it should be able to.
z = relax.Var("z")
cond = relax.Var("cond", R.Tensor([], "bool"))
plus = relax.op.add(x, y)
mult = relax.op.multiply(x, y)
if_node = relax.If(cond, plus, mult)
seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z)
f = relax.Function(
    [cond, x, y],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

@R.function(private=True)
def expected(
    cond: R.Tensor((), dtype="bool"),
    x: R.Tensor((), dtype="int32"),
    y: R.Tensor((), dtype="int32"),
) -> R.Tensor(ndim=0, dtype="int32"):
    # the bodies of the branches will be seq exprs with a binding
    if cond:
        w = R.add(x, y)
        z = w
    else:
        w = R.multiply(x, y)
        z = w
    return z

assert_structural_equal(after_mod["main"], expected)
测试If节点条件的规范化处理#
输入: If节点的条件是复杂表达式的情况
预期输出: 将复杂条件表达式分解为单独的绑定语句
测试重点: 验证If条件表达式的规范化处理
技术细节: 即使是if条件也会被规范化,确保所有复杂表达式都被分解
# The If condition itself is also normalized: a complex condition expression
# must be hoisted into its own binding before the If.
cond = relax.Var("cond", R.Tensor([], "bool"))
x = relax.Var("x", R.Tensor([1], "float32"))
# TODO(relax-team): add type and shape deduction for IfNode
y = relax.Var("y")
# the condition is wrapped in a tuple and then indexed
f = relax.Function(
    [cond, x],
    relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(
                        y,
                        relax.If(
                            relax.TupleGetItem(relax.Tuple([cond]), 0),
                            relax.op.add(x, x),
                            relax.op.multiply(x, x),
                        ),
                    )
                ]
            )
        ],
        y,
    ),
    ret_struct_info=R.Tensor("float32", ndim=1),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
after_mod.show()
# Printed output of `after_mod.show()` — the tuple-index condition was bound
# to `gv2` before the If:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
        gv2: R.Tensor((), dtype="bool") = (cond,)[0]
        if gv2:
            gv: R.Tensor((1,), dtype="float32") = R.add(x, x)
            y: R.Tensor((1,), dtype="float32") = gv
        else:
            gv1: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            y: R.Tensor((1,), dtype="float32") = gv1
        return y
测试元组元素获取(TupleGetItem)的规范化处理#
输入: 嵌套的TupleGetItem操作
预期输出: 将嵌套的元组索引操作分解为单独的绑定语句
测试重点: 验证复杂元组操作的规范化
技术细节: 多层嵌套的元组索引会被分解为多个简单的索引操作
# Nested TupleGetItem operations are split into separate bindings.
x = relax.Var("x", R.Tensor([], "int32"))
f = relax.Function(
    [x],
    relax.TupleGetItem(
        relax.TupleGetItem(
            relax.Tuple([relax.Tuple([x])]),
            0,
        ),
        0,
    ),
    ret_struct_info=R.Tensor([], "int32"),
)
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
# TODO: revisit when we normalize SeqExprs (as part of normalization?)
# Not using the parser here because spelling the expected IR out directly
# would yield *one* binding block, while the normalized version has *two*.
idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")]))
ret_var = relax.Var("ret", R.Tensor([], "int32"))
expected_f = relax.Function(
    [x],
    relax.SeqExpr(
        [
            relax.BindingBlock(
                [
                    relax.VarBinding(
                        idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0)
                    )
                ]
            ),
            relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]),
        ],
        ret_var,
    ),
    ret_struct_info=R.Tensor([], "int32"),
)
expected_mod = tvm.IRModule.from_expr(expected_f)
# apply normalization to fill in type/shape annotations (tedious by hand)
final_mod = relax.transform.Normalize()(expected_mod)
assert_structural_equal(after_mod, final_mod)
测试相邻块合并的规范化处理#
输入: 包含多个相邻数据块和绑定块的函数
预期输出: 将相邻的同类块合并,并规范化变量引用
测试重点: 验证块合并优化逻辑
优化策略: Normalize转换会合并相邻的相同类型块,减少IR中的块数量
# Adjacent blocks of the same kind (dataflow/dataflow, binding/binding) are
# merged by Normalize, and variable references are canonicalized.
x = relax.Var("x", R.Tensor([], "int32"))
v0 = relax.Var("v0", R.Tensor([], "int32"))
v1 = relax.Var("v1", R.Tensor([], "int32"))
v2 = relax.Var("v2", R.Tensor([], "int32"))
v3 = relax.Var("v3", R.Tensor([], "int32"))
f = relax.Function(
    [x],
    relax.SeqExpr(
        [
            relax.DataflowBlock([relax.VarBinding(v0, x)]),
            relax.DataflowBlock([relax.VarBinding(v1, v0)]),
            relax.BindingBlock([relax.VarBinding(v2, v1)]),
            relax.BindingBlock([relax.VarBinding(v3, v2)]),
        ],
        v3,
    ),
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
after_mod.show()
# Printed output of `after_mod.show()` — the two dataflow blocks merged into
# one, as did the two plain binding blocks:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
        with R.dataflow():
            v0: R.Tensor((), dtype="int32") = x
            v1: R.Tensor((), dtype="int32") = v0
            R.output(v0, v1)
        v2: R.Tensor((), dtype="int32") = v1
        v3: R.Tensor((), dtype="int32") = v2
        return v3
测试嵌套序列表达式的规范化处理#
输入: 包含嵌套SeqExpr的函数
预期输出: 展平嵌套结构,将所有绑定提升到顶层
测试重点: 验证嵌套序列的展平逻辑
转换机制: Normalize转换会递归处理嵌套的序列表达式,将它们展平为单一层次的绑定
# A SeqExpr nested as the RHS of a binding is flattened: its inner bindings
# are hoisted to the outer level and its body var is bound directly.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
z = relax.Var("z", R.Tensor([], "int32"))
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(
                    y,
                    relax.SeqExpr(
                        [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
                        z,
                    ),
                ),
            ]
        )
    ],
    y,
)
f = relax.Function(
    [],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

@R.function(private=True)
def expected():
    x = relax.const(1)
    z = relax.const(2)
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)
测试包含数据流块的嵌套序列表达式的规范化处理#
输入: 包含DataflowBlock的嵌套SeqExpr
预期输出: 展平嵌套结构,同时保留数据流块的特性
测试重点: 验证嵌套数据流块的处理逻辑
技术挑战: 需要在展平嵌套结构的同时,保持数据流块的语义完整性
# Flattening a nested SeqExpr that contains a DataflowBlock: the nesting is
# removed while the dataflow block's boundaries and semantics are preserved.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
z = relax.Var("z", R.Tensor([], "int32"))
# NOTE(review): name hint "u" duplicates that of `u` below — likely intended
# "q"; harmless because structural equality ignores name hints. TODO confirm.
q = relax.Var("u", R.Tensor([], "int32"))
w = relax.DataflowVar("w", R.Tensor([], "int32"))
u = relax.Var("u", R.Tensor([], "int32"))
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(
                    y,
                    relax.SeqExpr(
                        [
                            relax.BindingBlock([relax.VarBinding(q, relax.const(2))]),
                            relax.DataflowBlock(
                                [
                                    relax.VarBinding(w, q),
                                    relax.VarBinding(u, w),
                                ]
                            ),
                            relax.BindingBlock([relax.VarBinding(z, u)]),
                        ],
                        z,
                    ),
                ),
            ]
        )
    ],
    y,
)
f = relax.Function(
    [],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

@R.function(private=True)
def expected():
    x = relax.const(1)
    q = relax.const(2)
    with R.dataflow():
        w = q
        u = w
        R.output(u)
    z = u
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)
测试深层嵌套序列表达式的规范化处理#
输入: 多层嵌套的SeqExpr
预期输出: 完全展平深层嵌套结构
测试重点: 验证深层嵌套的处理能力
边界测试: 确保规范化转换能够处理任意深度的嵌套结构
# Deeply nested SeqExprs (including MatchCast bindings) are fully flattened
# into a single level of bindings.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
z = relax.Var("z", R.Tensor([], "int32"))
u = relax.Var("u", R.Tensor([], "int32"))
v = relax.Var("v", R.Tensor([], "int32"))
w = relax.Var("w", R.Tensor([], "int32"))
# NOTE(review): the throwaway var reuses name hint "w" — harmless since
# structural equality ignores name hints, but possibly unintended.
_ = relax.Var("w", R.Tensor([], "int32"))
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(
                    y,
                    relax.SeqExpr(
                        [
                            relax.BindingBlock(
                                [
                                    relax.VarBinding(
                                        z,
                                        relax.SeqExpr(
                                            [
                                                relax.BindingBlock(
                                                    [
                                                        relax.VarBinding(u, relax.const(2)),
                                                        relax.MatchCast(
                                                            _, u, R.Tensor([], "int32")
                                                        ),
                                                        relax.VarBinding(v, u),
                                                        relax.MatchCast(
                                                            w, v, R.Tensor([], "int32")
                                                        ),
                                                    ]
                                                )
                                            ],
                                            w,
                                        ),
                                    )
                                ]
                            )
                        ],
                        z,
                    ),
                ),
            ]
        )
    ],
    y,
)
f = relax.Function(
    [],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

@R.function(private=True)
def expected():
    x = relax.const(1)
    u = relax.const(2)
    _ = R.match_cast(u, R.Tensor((), "int32"))
    v = u
    w = R.match_cast(v, R.Tensor((), "int32"))
    z = w
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)
测试在数据流块中嵌套非数据流块的错误情况#
标记为预期失败的测试
验证: 在数据流块中嵌套普通绑定块应该失败
@pytest.mark.xfail()
# xfail: the current implementation does not support nesting an ordinary
# BindingBlock inside a DataflowBlock; this case checks that IR structural
# constraint.
def test_nesting_non_dataflow_in_dataflow_error():
    x = relax.DataflowVar("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.DataflowBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )
    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    relax.transform.Normalize()(tvm.IRModule.from_expr(f))
    # should fail because of the normal binding block inside a dataflow block
测试移除void类型变量的使用#
验证: 所有空元组都应该内联构造
技术细节: 为了可读性,TVMScript隐藏了void类型变量的绑定,但在Relax中使用空元组表示void
优化处理: 通过规范化将void类型变量的使用替换为内联的R.tuple()
def test_remove_usage_of_void_type_variables():
    """All empty tuples should be constructed inline.

    For readability, TVMScript hides the binding of void-typed variables:
    e.g. `R.assert_op(condition)` rather than
    `void_var: R.Tuple([]) = R.assert_op(condition)`.

    However, Relax follows the standard functional-language convention of
    using an empty tuple to represent void.  Since the empty tuple might be
    legitimately used later in the function, `void_var` may need a binding.
    Normalizing every use of a void-typed variable into an inline
    `R.tuple()` avoids this.
    """
    x = relax.Var("x", R.Tuple([]))
    bindings = [
        relax.VarBinding(x, R.assert_op(R.const(True, "bool"))),
    ]
    seq = relax.SeqExpr([relax.BindingBlock(bindings)], x)
    before = relax.Function([], seq, ret_struct_info=R.Tuple([]), is_pure=False)
    after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"]

    @R.function(private=True, pure=False)
    def expected():
        x = R.assert_op(R.const(True, "bool"))
        return R.tuple()
    # NOTE(review): SOURCE appears truncated here — the comparison of `after`
    # against `expected` (e.g. assert_structural_equal) is not visible.