FNormalize#

FNormalize 是 TVM Relax 中的一种算子属性(operator attribute),用于将该算子的调用规范化为标准的参数格式。

测试主要验证 FNormalize 在不同场景下的行为,包括:

  1. 解析 TVMScript 时是否抑制 FNormalize 应用

  2. 在 C++ 和 Python 的 Mutator 中是否正确应用 FNormalize

  3. 标准化后的调用节点是否格式良好

  4. 未标准化的调用节点是否格式不良

  5. 针对特定算子(如 call_tir、call_tir_inplace、call_tir_with_grad)的参数元组内联处理

import tvm
import tvm.testing
import tvm.relax.testing.transform

from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
def register_custom_op(define_normalization=True):
    """Register a test-only custom op and return it.

    The op ``custom_op.ignore_second_argument`` ignores its second argument.
    When ``define_normalization`` is truthy, an FNormalize hook is installed
    that appends an empty tuple to single-argument calls, so there is always
    a second argument to ignore.

    Args:
        define_normalization: whether to register the FNormalize attribute.

    Returns:
        The op object returned by ``tvm.ir.Op.get`` for the op name.
    """
    name = "custom_op.ignore_second_argument"

    def _struct_info_of_first_arg(call: relax.Call, context: relax.BlockBuilder):
        # The result has the same struct info as the first argument.
        return call.args[0].struct_info

    def _append_empty_tuple(context: relax.BlockBuilder, call: relax.Call):
        # Only calls with exactly one argument are rewritten; any other call
        # is returned unchanged.
        if len(call.args) != 1:
            return call
        return relax.Call(call.op, [call.args[0], relax.Tuple([])])

    def _forward_first_arg(context: relax.BlockBuilder, call: relax.Call):
        # Legalization replaces the call by its first argument.
        return call.args[0]

    # (attribute key, value) pairs, registered in this order.
    registrations = [
        ("FInferStructInfo", _struct_info_of_first_arg),
        ("FLegalize", _forward_first_arg),
        ("FPurity", True),  # the op is pure
    ]
    if define_normalization:
        registrations.append(("FNormalize", _append_empty_tuple))

    for attr_key, attr_value in registrations:
        tvm.ir.register_op_attr(name, attr_key, attr_value)

    return tvm.ir.Op.get(name)

# Register the op with its default configuration, which includes FNormalize.
custom_op = register_custom_op()

测试在解析 TVMScript 时是否抑制应用 FNormalize#

# Hand-build a single-argument call. The well-formed check is disabled because
# FNormalize would normally append a second argument to this call.
@R.function(check_well_formed=False)
def func(A: R.Tensor) -> R.Tensor:
    return relax.Call(custom_op, [A])


func.show()

# The returned call ends up as the value of the first binding in the first
# block; it must still carry exactly the one argument written above.
call_expr = func.body.blocks[0].bindings[0].value
assert isinstance(call_expr, relax.Call), "测试实现错误,未提取到正确的表达式"
assert len(call_expr.args) == 1, "期望 TVMScript 抑制使用 FNormalize,按原样生成参数"
# Captured output of `func.show()`: the printed TVMScript shows the call with a
# single argument `(A)`, i.e. FNormalize was not applied while parsing.
# from tvm.script import relax as R

@R.function
def func(A: R.Tensor) -> R.Tensor:
    gv: R.Tensor = custom_op.ignore_second_argument(A)
    return gv

测试在 C++ 的 ExprMutator 子类中是否应用 FNormalize#

# Hand-built un-normalized call: FNormalize would append an empty tuple as the
# second argument, so the well-formed check is disabled for `Before`.
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A])
# The normalized form: the same call with an empty tuple appended.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A, R.tuple()])

# Apply the "empty" C++ mutator (presumably it performs no rewrites of its own,
# as its name suggests), so any change is attributed to FNormalize.
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Before and After must differ, and After must match Expected.
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)

测试在 Python 的 ExprMutator 子类中是否应用 FNormalize#

# Un-normalized input: a single-argument call, so the well-formed check is off.
@R.function(private=True, check_well_formed=False)
def before(A: R.Tensor):
    return relax.Call(custom_op, [A])


# What FNormalize should turn `before` into: an empty tuple appended.
@R.function(private=True)
def expected(A: R.Tensor):
    return relax.Call(custom_op, [A, R.tuple()])


@relax.expr_functor.mutator
class EmptyPyExprMutator(relax.PyExprMutator):
    """Python ExprMutator that overrides nothing and relies on the default visitors."""


# Visiting with the default Python mutator should produce the normalized call:
# the result differs from `before` and matches `expected`.
after = EmptyPyExprMutator().visit_expr(before)
assert not tvm.ir.structural_equal(before, after)
tvm.ir.assert_structural_equal(expected, after)

测试如果 FNormalize 不应用更改,IR 是否格式良好#

# The call already has two arguments, so the op's FNormalize hook (the
# `normalize` helper in register_custom_op) returns it unchanged; the module
# therefore passes the default well-formed check performed by the decorator.
@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A, A])

# Verify the module is well-formed.
assert relax.analysis.well_formed(Module)

测试如果 FNormalize 应用更改,IR 是否格式不良#

# Single-argument call: not normalized whenever the op defines FNormalize.
@I.ir_module(check_well_formed=False)
class Module:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A])

# Attributes installed by register_custom_op. They must be cleared before the
# op can be registered again with a different configuration.
_CUSTOM_OP_ATTRS = ("FInferStructInfo", "FLegalize", "FPurity", "FNormalize")

try:
    for define_normalization in (True, False):
        for key in _CUSTOM_OP_ATTRS:
            custom_op.reset_attr(key)
        custom_op = register_custom_op(define_normalization=define_normalization)
        # With FNormalize defined, the one-argument call is not in normalized
        # form, so the module is ill-formed; without it, the same IR is fine.
        if define_normalization:
            assert not relax.analysis.well_formed(Module)
        else:
            assert relax.analysis.well_formed(Module)
finally:
    # Restore the FNormalize-enabled registration used everywhere else in this
    # notebook, so re-running earlier cells (or a failed assert above) does not
    # leave the global op registry in the no-FNormalize configuration.
    for key in _CUSTOM_OP_ATTRS:
        custom_op.reset_attr(key)
    custom_op = register_custom_op(define_normalization=True)
[18:02:22] /media/pc/data/lxw/ai/tvm/src/relax/analysis/well_formed.cc:134: Warning: This IR is not well formed: If an operator defines an operator-specific normalization function (FNormalize), calls to that operator must be normalized with it.  However, normalization of custom_op.ignore_second_argument(A) resulted in custom_op.ignore_second_argument(A, R.tuple())

测试 FNormalize 是否为 R.call_tir 内联参数元组(参数元组绑定在局部变量中)#

# `Before` passes the argument tuple to call_tir through a variable (`args`)
# instead of writing it inline, so it is not in the form call_tir's
# FNormalize produces; the well-formed check is therefore disabled.
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        # `cls` names the enclosing module so `cls.multiply_by_two` refers to
        # the PrimFunc defined below.
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, args],
            sinfo_args=[A.struct_info],
        )

    # Writes B[i] = A[i] * 2.0 for each of the 16 elements.
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

# Expected result: the same call with the argument tuple written inline.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        # Unused, but kept on purpose: the mutated `Before` is expected to
        # still contain this binding.
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, relax.Tuple([A])],  # inline tuple
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Check the result: the module changed, and it now matches `Expected`.
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)

测试 FNormalize 是否为 R.call_tir 内联参数元组(参数元组是函数参数)#

# Here the argument tuple is a function parameter of tuple type and is passed
# to call_tir directly rather than as an inline tuple; the well-formed check
# is therefore disabled.
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(args: R.Tuple([R.Tensor([16], "float32")])):
        # `cls` names the enclosing module so `cls.multiply_by_two` refers to
        # the PrimFunc defined below.
        cls = Before
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, args],
            sinfo_args=[args[0].struct_info],
        )

    # Writes B[i] = A[i] * 2.0 for each of the 16 elements.
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

# Expected result: an inline tuple rebuilt from the parameter's element.
@I.ir_module
class Expected:
    @R.function
    def main(args: R.Tuple([R.Tensor([16], "float32")])):
        cls = Expected
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, relax.Tuple([args[0]])],  # inline tuple built from the function parameter
            sinfo_args=[args[0].struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Check the result: the module changed, and it now matches `Expected`.
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)

测试 FNormalize 是否为 R.call_tir_inplace 内联参数元组#

# CallTIRInplaceAttrs is hard to construct from the Python API, so declare the
# expected module first and reuse the attrs of its call.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        # Unused, but kept on purpose: it mirrors the `args` binding in
        # `Before`, and it places the call below at bindings[1].
        args = (A,)
        # A is passed directly; presumably R.call_tir_inplace wraps it into an
        # inline tuple. inplace_indices=[0] matches the PrimFunc, which writes
        # its result back into its first (and only) buffer.
        return R.call_tir_inplace(
            cls.multiply_by_two,
            A,
            inplace_indices=[0],
            out_sinfo=[A.struct_info],
        )

    # In place: A[i] = A[i] * 2.0 for each of the 16 elements.
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32")):
        for i in range(16):
            A[i] = A[i] * 2.0

# Extract inplace_attrs from the call_tir_inplace call bound after `args`.
inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs

# Same call, but with the argument tuple passed through the `args` variable
# instead of inline, so the well-formed check is disabled.
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir_inplace"),
            [cls.multiply_by_two, args],
            attrs=inplace_attrs,
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32")):
        for i in range(16):
            A[i] = A[i] * 2.0

# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Check the result: the module changed, and it now matches `Expected`.
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)

测试 FNormalize 是否为 R.call_tir_with_grad 内联参数元组#

# CallTIRWithGradAttrs is hard to construct from the Python API, so declare the
# expected module first and reuse the attrs of its call.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        # Unused, but kept on purpose: it mirrors the `args` binding in
        # `Before`, and it places the call below at bindings[1].
        args = (A,)
        # A is passed directly; presumably R.call_tir_with_grad wraps it into
        # an inline tuple. te_grad_name names the gradient function f_grad.
        return R.call_tir_with_grad(
            cls.multiply_by_two,
            A,
            out_sinfo=[A.struct_info],
            te_grad_name="f_grad",
        )

    # Writes B[i] = A[i] * 2.0 for each of the 16 elements.
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

    # Gradient function: fills Grad with the constant 2.0.
    @T.prim_func(private=True)
    def f_grad(
        A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32")
    ):
        for i in range(16):
            Grad[i] = 2.0

# Extract with_grad_attrs from the call_tir_with_grad call bound after `args`.
with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs

# Same call, but with the argument tuple passed through the `args` variable
# instead of inline, so the well-formed check is disabled.
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir_with_grad"),
            [cls.multiply_by_two, args],
            attrs=with_grad_attrs,
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

    @T.prim_func(private=True)
    def f_grad(
        A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32")
    ):
        for i in range(16):
            Grad[i] = 2.0

# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Check the result: the module changed, and it now matches `Expected`.
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)