规范化处理#

import pytest

import tvm
import tvm.testing
from tvm import relax
from tvm import tir
from tvm.ir.base import assert_structural_equal

import tvm.script
from tvm.script import tir as T, relax as R

测试基本函数的规范化处理#

Normalize 变换将嵌套的算子分解为单独的绑定语句,使 IR 更加扁平化

  • 输入: 包含嵌套 add 和 multiply 算子的函数

  • 预期输出: 分解后的 ANF(Administrative Normal Form) 形式,每个算子都有独立的绑定变量

  • 测试重点: 验证函数体的规范化处理

  • ANF 形式特点: 每个复杂表达式都被分解为一系列绑定到变量的简单表达式

  • 核心步骤:

    1. 手动构建带嵌套算子的函数

    2. 应用 Normalize 变换

    3. 验证转换结果是否符合 ANF 形式

  • 测试方法: 本节调用 show() 打印变换前后的 IR 进行对比;后续各节则使用 assert_structural_equal 验证变换结果与预期 IR 在结构上一致

# Two symbolic int64 dimensions and a float16 input tensor of shape (m, n).
m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
x = relax.Var("x", R.Tensor([m, n], "float16"))

注意: TVMScript 解析器会自动规范化用 TVMScript 编写的 IR,因此这里需要手动构造函数。下面构建一个包含嵌套算子的函数: multiply(add(x, x), add(x, x))

# Build multiply(add(x, x), add(x, x)) directly as a function body. The two
# add calls are deliberately separate nodes, so Normalize emits one binding
# for each of them.
mul_add = relax.Function(
    params=[x],
    body=relax.op.multiply(
        relax.op.add(x, x),  # left operand
        relax.op.add(x, x),  # right operand (a distinct call node)
    ),
    ret_struct_info=R.Tensor("float16", ndim=2),
)

注意: from_expr API 将私有函数(没有 global_symbol 的函数)命名为 "main"

# Wrap the hand-built function in an IRModule and print it.
before_mod = tvm.IRModule.from_expr(mul_add)
before_mod.show()
# Printed output of before_mod.show(): the nested calls are still inline.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor(("m", "n"), dtype="float16")) -> R.Tensor(dtype="float16", ndim=2):
        m = T.int64()
        n = T.int64()
        return R.multiply(R.add(x, x), R.add(x, x))
# Apply Normalize and print the result.
after_mod = relax.transform.Normalize()(before_mod)
after_mod.show()
# Printed output of after_mod.show(): every call now has its own binding
# (gv, gv1, gv2). The return annotation is also refined from ndim=2 to the
# concrete shape ("m", "n").
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor(("m", "n"), dtype="float16")) -> R.Tensor(("m", "n"), dtype="float16"):
        m = T.int64()
        n = T.int64()
        gv: R.Tensor((m, n), dtype="float16") = R.add(x, x)
        gv1: R.Tensor((m, n), dtype="float16") = R.add(x, x)
        gv2: R.Tensor((m, n), dtype="float16") = R.multiply(gv, gv1)
        return gv2

测试条件语句(If节点)的规范化处理#

  • 输入: 包含 If 节点的函数,if 和 else 分支中包含嵌套操作

  • 预期输出: 规范化后的函数,其中 If 节点的两个分支都被变换为序列表达式(SeqExpr),每个算子都有独立的绑定

  • 测试重点: 验证条件分支内的操作是否被正确规范化

  • 变换机制: Normalize 变换会确保 If 节点的两个分支都被变换为规范化的表达式序列

cond = relax.Var("cond", R.Tensor([], "bool"))
x = relax.Var("x", R.Tensor([1], "float32"))
# TODO(relax-team): add type and shape inference for IfNode
y = relax.Var("y")

# NOTE: the TVMScript parser normalizes IR written in TVMScript automatically,
# so the function and its If node are assembled by hand, piece by piece.
then_expr = relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x))
else_expr = relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x))
if_binding = relax.VarBinding(y, relax.If(cond, then_expr, else_expr))
f = relax.Function(
    [cond, x],
    relax.SeqExpr([relax.BindingBlock([if_binding])], y),
    ret_struct_info=R.Tensor("float32", ndim=1),
)

# Normalize the If-containing function and print it before and after.
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)
before_mod.show()
# Printed output of before_mod.show(): both branches still contain nested calls.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor(dtype="float32", ndim=1):
        if cond:
            y: R.Tensor((1,), dtype="float32") = R.multiply(R.add(x, x), R.add(x, x))
        else:
            y: R.Tensor((1,), dtype="float32") = R.add(R.multiply(x, x), R.multiply(x, x))
        return y
after_mod.show()
# Printed output of after_mod.show(): inside each branch every call has its
# own binding, and the branch result is then bound to y.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
        if cond:
            gv: R.Tensor((1,), dtype="float32") = R.add(x, x)
            gv1: R.Tensor((1,), dtype="float32") = R.add(x, x)
            gv2: R.Tensor((1,), dtype="float32") = R.multiply(gv, gv1)
            y: R.Tensor((1,), dtype="float32") = gv2
        else:
            gv3: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            gv4: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            gv5: R.Tensor((1,), dtype="float32") = R.add(gv3, gv4)
            y: R.Tensor((1,), dtype="float32") = gv5
        return y

测试已经是ANF形式的IR的规范化处理#

  • 输入: 已经符合ANF形式的IR模块

  • 预期输出: 保持不变,Normalize变换对其不产生影响

  • 测试重点: 验证 Normalize 变换不会修改已经符合 ANF 形式的 IR(即这类 IR 是该变换的不动点)

  • 幂等性说明: 由于 Normalize 的输出本身就是 ANF 形式,这也意味着该变换是幂等的——对其结果再次应用 Normalize 不会产生任何变化

# The normalize pass should be a no-op on IR that is already in ANF.
@tvm.script.ir_module
class ANFMod1:
    @R.function
    def f(x: R.Tensor(dtype="float32")):
        gv = R.add(x, x)
        gv1 = R.add(gv, gv)
        gv2 = R.add(gv, gv1)
        return (gv, gv2)

before_mod = ANFMod1
after_mod = relax.transform.Normalize()(before_mod)
assert_structural_equal(before_mod, after_mod, map_free_vars=True)

# Same no-op check for a function containing a dataflow block.
@tvm.script.ir_module
class ANFMod2:
    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")):
        m, n = T.int64(), T.int64()
        with R.dataflow():
            lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
            gv0 = R.call_dps_packed(
                "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")
            )
            R.output(gv0)
        return gv0

mod = ANFMod2
mod_post = relax.transform.Normalize()(mod)

# The module must come out of Normalize structurally unchanged.
assert_structural_equal(mod, mod_post)

测试序列表达式(SeqExpr)中非叶节点体的规范化处理#

  • 输入: 序列表达式的体(body)不是叶节点的情况

  • 预期输出: 将非叶节点体绑定到一个变量

  • 测试重点: 验证seq expr中非叶节点的处理逻辑

  • 技术细节: 规范化过程会将复杂的表达式绑定到变量,使表达式结构更加扁平

# When the body of a SeqExpr is a non-leaf expression (here a call), Normalize
# should bind that body to a variable.
x, y = (relax.Var(name, R.Tensor([], "int32")) for name in ("x", "y"))
seq = relax.SeqExpr([], relax.op.add(x, y))  # no binding blocks; a call as body
f = relax.Function([x, y], seq, ret_struct_info=R.Tensor([], "int32"))

before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

# Expected result. private=True because from_expr registers the function
# under the name "main" without a global symbol.
@R.function(private=True)
def expected(
    x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
    # Normalization inserts this binding
    z = R.add(x, y)
    return z

assert_structural_equal(after_mod["main"], expected)
# A function whose body is not a SeqExpr at all should get its body wrapped
# in one by Normalize, with the call bound to a variable.
x, y = (relax.Var(name, R.Tensor([], "int32")) for name in ("x", "y"))
f = relax.Function([x, y], relax.op.add(x, y), ret_struct_info=R.Tensor([], "int32"))

before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

# Expected result for a function whose body was a bare call.
@R.function(private=True)
def expected(
    x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32")
) -> R.Tensor(ndim=0, dtype="int32"):
    # The result is a SeqExpr whose body is a variable
    z = R.add(x, y)
    return z

assert_structural_equal(after_mod["main"], expected)

测试If节点分支的规范化处理#

  • 输入: If节点的分支不是序列表达式的情况

  • 预期输出: 将If节点的分支转换为序列表达式

  • 测试重点: 验证If节点内部结构的规范化

  • 技术要点: 规范化过程会确保If节点的两个分支都符合ANF形式

# The branches of an If node must be SeqExprs; here they are bare calls.
x, y = (relax.Var(name, R.Tensor([], "int32")) for name in ("x", "y"))
# TODO(@relax-team): z has shape () and type TensorType(ndim=0), but
# normalization fails to infer these even though it should
z = relax.Var("z")
cond = relax.Var("cond", R.Tensor([], "bool"))
plus, mult = relax.op.add(x, y), relax.op.multiply(x, y)
if_node = relax.If(cond, plus, mult)
seq = relax.SeqExpr(
    [relax.BindingBlock([relax.VarBinding(z, if_node)])],
    z,
)
f = relax.Function([cond, x, y], seq, ret_struct_info=R.Tensor([], "int32"))

before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

# Expected result: each branch becomes a SeqExpr. The call is bound to a
# fresh variable, and that variable is then bound to z.
@R.function(private=True)
def expected(
    cond: R.Tensor((), dtype="bool"),
    x: R.Tensor((), dtype="int32"),
    y: R.Tensor((), dtype="int32"),
) -> R.Tensor(ndim=0, dtype="int32"):
    # The branch bodies will be SeqExprs with bindings
    if cond:
        w = R.add(x, y)
        z = w
    else:
        w = R.multiply(x, y)
        z = w
    return z

assert_structural_equal(after_mod["main"], expected)

测试If节点条件的规范化处理#

  • 输入: If节点的条件是复杂表达式的情况

  • 预期输出: 将复杂条件表达式分解为单独的绑定语句

  • 测试重点: 验证If条件表达式的规范化处理

  • 技术细节: 即使是if条件也会被规范化,确保所有复杂表达式都被分解

cond = relax.Var("cond", R.Tensor([], "bool"))
x = relax.Var("x", R.Tensor([1], "float32"))
# TODO(relax-team): add type and shape inference for IfNode
y = relax.Var("y")

# The condition is not a leaf: it is wrapped in a one-element tuple and then
# indexed, so Normalize has to bind it to a variable ahead of the If.
wrapped_cond = relax.TupleGetItem(relax.Tuple([cond]), 0)
if_expr = relax.If(wrapped_cond, relax.op.add(x, x), relax.op.multiply(x, x))
f = relax.Function(
    [cond, x],
    relax.SeqExpr([relax.BindingBlock([relax.VarBinding(y, if_expr)])], y),
    ret_struct_info=R.Tensor("float32", ndim=1),
)

before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

after_mod.show()
# Printed output of after_mod.show(): the condition (cond,)[0] is bound to gv2
# before the If, and each branch binds its call before assigning y.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(cond: R.Tensor((), dtype="bool"), x: R.Tensor((1,), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
        gv2: R.Tensor((), dtype="bool") = (cond,)[0]
        if gv2:
            gv: R.Tensor((1,), dtype="float32") = R.add(x, x)
            y: R.Tensor((1,), dtype="float32") = gv
        else:
            gv1: R.Tensor((1,), dtype="float32") = R.multiply(x, x)
            y: R.Tensor((1,), dtype="float32") = gv1
        return y

测试元组元素获取(TupleGetItem)的规范化处理#

  • 输入: 嵌套的TupleGetItem操作

  • 预期输出: 将嵌套的元组索引操作分解为单独的绑定语句

  • 测试重点: 验证复杂元组操作的规范化

  • 技术细节: 多层嵌套的元组索引会被分解为多个简单的索引操作

x = relax.Var("x", R.Tensor([], "int32"))
# Wrap x in two levels of tuples and index twice: ((x,),)[0][0]
nested_tuple = relax.Tuple([relax.Tuple([x])])
f = relax.Function(
    [x],
    relax.TupleGetItem(relax.TupleGetItem(nested_tuple, 0), 0),
    ret_struct_info=R.Tensor([], "int32"),
)

before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

# TODO: revisit once we normalize SeqExprs (as part of normalization?)
# The expected IR is built without the parser on purpose: writing it out in
# TVMScript would give *one* binding block, but the normalized version has *two*.
idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")]))
ret_var = relax.Var("ret", R.Tensor([], "int32"))
inner_get = relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0)
first_block = relax.BindingBlock([relax.VarBinding(idx_var, inner_get)])
second_block = relax.BindingBlock(
    [relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]
)
expected_f = relax.Function(
    [x],
    relax.SeqExpr([first_block, second_block], ret_var),
    ret_struct_info=R.Tensor([], "int32"),
)
expected_mod = tvm.IRModule.from_expr(expected_f)
# Run Normalize on the expected module too, so the type/shape annotations get
# filled in (tedious to write by hand).
final_mod = relax.transform.Normalize()(expected_mod)

assert_structural_equal(after_mod, final_mod)

测试相邻块合并的规范化处理#

  • 输入: 包含多个相邻数据块和绑定块的函数

  • 预期输出: 将相邻的同类块合并,并规范化变量引用

  • 测试重点: 验证块合并优化逻辑

  • 优化策略: Normalize转换会合并相邻的相同类型块,减少IR中的块数量

x = relax.Var("x", R.Tensor([], "int32"))
v0, v1, v2, v3 = (
    relax.Var(name, R.Tensor([], "int32")) for name in ("v0", "v1", "v2", "v3")
)
# Two adjacent DataflowBlocks followed by two adjacent BindingBlocks, forming
# the chain x -> v0 -> v1 -> v2 -> v3.
blocks = [
    relax.DataflowBlock([relax.VarBinding(v0, x)]),
    relax.DataflowBlock([relax.VarBinding(v1, v0)]),
    relax.BindingBlock([relax.VarBinding(v2, v1)]),
    relax.BindingBlock([relax.VarBinding(v3, v2)]),
]
f = relax.Function([x], relax.SeqExpr(blocks, v3), ret_struct_info=R.Tensor([], "int32"))

after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))
after_mod.show()
# Printed output of after_mod.show(): the two DataflowBlocks are merged into a
# single dataflow block and the two BindingBlocks into one. v0 and v1 were
# created as relax.Var (not DataflowVar), so they appear in R.output.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
        with R.dataflow():
            v0: R.Tensor((), dtype="int32") = x
            v1: R.Tensor((), dtype="int32") = v0
            R.output(v0, v1)
        v2: R.Tensor((), dtype="int32") = v1
        v3: R.Tensor((), dtype="int32") = v2
        return v3

测试嵌套序列表达式的规范化处理#

  • 输入: 包含嵌套SeqExpr的函数

  • 预期输出: 展平嵌套结构,将所有绑定提升到顶层

  • 测试重点: 验证嵌套序列的展平逻辑

  • 转换机制: Normalize转换会递归处理嵌套的序列表达式,将它们展平为单一层次的绑定

x, y, z = (relax.Var(name, R.Tensor([], "int32")) for name in ("x", "y", "z"))
# y is bound to an inner SeqExpr:  y = { z = 2; z }
inner_seq = relax.SeqExpr(
    [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
    z,
)
# Outer SeqExpr:  { x = 1; y = <inner_seq>; y }
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(y, inner_seq),
            ]
        )
    ],
    y,
)

f = relax.Function([], seq, ret_struct_info=R.Tensor([], "int32"))
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

# Expected result: the inner binding of z is lifted to the outer level, and
# y is then bound to z. The binding of x is kept even though x is unused.
@R.function(private=True)
def expected():
    x = relax.const(1)
    z = relax.const(2)
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)

测试包含数据流块的嵌套序列表达式的规范化处理#

  • 输入: 包含DataflowBlock的嵌套SeqExpr

  • 预期输出: 展平嵌套结构,同时保留数据流块的特性

  • 测试重点: 验证嵌套数据流块的处理逻辑

  • 技术挑战: 需要在展平嵌套结构的同时,保持数据流块的语义完整性

# Every variable's name hint matches its Python name, so the printed IR is
# unambiguous. The var for q must not reuse the hint "u": that would print two
# distinct variables under the same name.
x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
z = relax.Var("z", R.Tensor([], "int32"))
q = relax.Var("q", R.Tensor([], "int32"))
# w is a DataflowVar (local to the dataflow block); u is an ordinary Var, so
# it stays visible after the dataflow block ends.
w = relax.DataflowVar("w", R.Tensor([], "int32"))
u = relax.Var("u", R.Tensor([], "int32"))
# Outer SeqExpr:  { x = 1; y = <inner>; y }
# Inner SeqExpr:  BindingBlock{q = 2}, DataflowBlock{w = q; u = w},
#                 BindingBlock{z = u}; body z
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(
                    y,
                    relax.SeqExpr(
                        [
                            relax.BindingBlock([relax.VarBinding(q, relax.const(2))]),
                            relax.DataflowBlock(
                                [
                                    relax.VarBinding(w, q),
                                    relax.VarBinding(u, w),
                                ]
                            ),
                            relax.BindingBlock([relax.VarBinding(z, u)]),
                        ],
                        z,
                    ),
                ),
            ]
        )
    ],
    y,
)

f = relax.Function(
    [],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

# Expected result: the nested blocks are flattened into the outer function
# while the dataflow block is kept intact. u is emitted through R.output so
# the following binding z = u can use it.
@R.function(private=True)
def expected():
    x = relax.const(1)
    q = relax.const(2)
    with R.dataflow():
        w = q
        u = w
        R.output(u)
    z = u
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)

测试深层嵌套序列表达式的规范化处理#

  • 输入: 多层嵌套的SeqExpr

  • 预期输出: 完全展平深层嵌套结构

  • 测试重点: 验证深层嵌套的处理能力

  • 边界测试: 确保规范化转换能够处理任意深度的嵌套结构

x = relax.Var("x", R.Tensor([], "int32"))
y = relax.Var("y", R.Tensor([], "int32"))
z = relax.Var("z", R.Tensor([], "int32"))
u = relax.Var("u", R.Tensor([], "int32"))
v = relax.Var("v", R.Tensor([], "int32"))
w = relax.Var("w", R.Tensor([], "int32"))
# Target of the first match_cast. Its name hint is "_" so it does not collide
# with the separate variable w (a duplicate "w" hint makes printed IR ambiguous).
_ = relax.Var("_", R.Tensor([], "int32"))
# Three levels of nesting:
#   y = { z = { u = 2; _ = match_cast(u); v = u; w = match_cast(v); w }; z }
seq = relax.SeqExpr(
    [
        relax.BindingBlock(
            [
                relax.VarBinding(x, relax.const(1)),
                relax.VarBinding(
                    y,
                    relax.SeqExpr(
                        [
                            relax.BindingBlock(
                                [
                                    relax.VarBinding(
                                        z,
                                        relax.SeqExpr(
                                            [
                                                relax.BindingBlock(
                                                    [
                                                        relax.VarBinding(u, relax.const(2)),
                                                        relax.MatchCast(
                                                            _, u, R.Tensor([], "int32")
                                                        ),
                                                        relax.VarBinding(v, u),
                                                        relax.MatchCast(
                                                            w, v, R.Tensor([], "int32")
                                                        ),
                                                    ]
                                                )
                                            ],
                                            w,
                                        ),
                                    )
                                ]
                            )
                        ],
                        z,
                    ),
                ),
            ]
        )
    ],
    y,
)

f = relax.Function(
    [],
    seq,
    ret_struct_info=R.Tensor([], "int32"),
)
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))

# Expected result: all three nesting levels are flattened into one sequence of
# bindings. Both match_cast bindings are kept in their original order.
@R.function(private=True)
def expected():
    x = relax.const(1)
    u = relax.const(2)
    _ = R.match_cast(u, R.Tensor((), "int32"))
    v = u
    w = R.match_cast(v, R.Tensor((), "int32"))
    z = w
    y = z
    return y

assert_structural_equal(after_mod["main"], expected)

测试在数据流块中嵌套非数据流块的错误情况#

  • 标记为预期失败的测试

  • 验证: 在数据流块中嵌套普通绑定块应该失败

@pytest.mark.xfail(
    reason="Normalize should reject a non-dataflow BindingBlock "
    "nested inside a DataflowBlock"
)
def test_nesting_non_dataflow_in_dataflow_error():
    """Nest a plain BindingBlock inside a DataflowBlock and run Normalize.

    Inside a DataflowBlock, the binding of ``y`` is a SeqExpr whose only block
    is an ordinary (non-dataflow) BindingBlock. This is an IR structure
    constraint violation, so Normalize is expected to fail on it; that is why
    the test is marked ``xfail``.
    """
    x = relax.DataflowVar("x", R.Tensor([], "int32"))
    y = relax.Var("y", R.Tensor([], "int32"))
    z = relax.Var("z", R.Tensor([], "int32"))
    seq = relax.SeqExpr(
        [
            relax.DataflowBlock(
                [
                    relax.VarBinding(x, relax.const(1)),
                    relax.VarBinding(
                        y,
                        relax.SeqExpr(
                            [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])],
                            z,
                        ),
                    ),
                ]
            )
        ],
        y,
    )
    f = relax.Function(
        [],
        seq,
        ret_struct_info=R.Tensor([], "int32"),
    )
    # Should fail: there is a plain binding block inside the dataflow block.
    relax.transform.Normalize()(tvm.IRModule.from_expr(f))

测试移除void类型变量的使用#

  • 验证: 所有空元组都应该内联构造

  • 技术细节: 为了可读性,TVMScript隐藏了void类型变量的绑定,但在Relax中使用空元组表示void

  • 优化处理: 通过规范化将void类型变量的使用替换为内联的R.tuple()

def test_remove_usage_of_void_type_variables():
    """所有空元组都应该内联构造

    为了可读性,TVMScript隐藏了类型为void的变量的绑定。例如,`R.assert_op(condition)`
    而不是`void_var: R.Tuple([]) = R.assert_op(condition)`。
    然而,Relax遵循函数式语言的标准约定,使用空元组表示void。由于空元组
    可能在函数后面被合法使用,`void_var`可能需要一个绑定。

    通过使用内联的`R.tuple()`规范化所有void类型变量的使用,可以避免这种情况。
    """
    x = relax.Var("x", R.Tuple([]))
    bindings = [
        relax.VarBinding(x, R.assert_op(R.const(True, "bool"))),
    ]
    seq = relax.SeqExpr([relax.BindingBlock(bindings)], x)
    before = relax.Function([], seq, ret_struct_info=R.Tuple([]), is_pure=False)

    after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"]

    @R.function(private=True, pure=False)
    def expected():
        x = R.assert_op(R.const(True, "bool"))
        return R.tuple()