FNormalize#
FNormalize is a special operator attribute in TVM Relax, used to normalize the argument format of calls to an operator.
The tests below verify the behavior of FNormalize in several scenarios, including:

- whether applying FNormalize is suppressed when parsing TVMScript
- whether FNormalize is applied correctly in both the C++ and the Python Mutator
- whether a normalized call node is well-formed
- whether an un-normalized call node is ill-formed
- in-lining of the argument tuple for specific operators (e.g. call_tir, call_tir_inplace, call_tir_with_grad)
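Since FNormalize is an ordinary operator attribute, it can be looked up in the operator registry like any other attribute. A minimal sketch, assuming the standard tvm.ir.Op.get_attr API (the in-lining sections below rely on relax.call_tir defining FNormalize):

import tvm

# FNormalize is registered per operator; look it up for relax.call_tir.
call_tir_op = tvm.ir.Op.get("relax.call_tir")
assert call_tir_op.get_attr("FNormalize") is not None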
import tvm
import tvm.testing
import tvm.relax.testing.transform
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
def register_custom_op(define_normalization=True):
    """A custom operator for testing purposes.

    This custom operator ignores its second argument. If there is no
    second argument to ignore, FNormalize adds one so that it can be
    ignored properly.

    Parameters:
        define_normalization: bool, whether to define the FNormalize
            attribute for the operator.

    Returns:
        The custom operator instance.
    """
    op_name = "custom_op.ignore_second_argument"

    def infer_struct_info(call: relax.Call, context: relax.BlockBuilder):
        """Infer the struct info of the call: simply that of the first argument"""
        return call.args[0].struct_info

    def normalize(context: relax.BlockBuilder, call: relax.Call):
        """Normalize the call arguments: if there is only one argument, append an empty tuple as the second"""
        if len(call.args) == 1:
            return relax.Call(call.op, [call.args[0], relax.Tuple([])])
        else:
            return call

    def legalize(context: relax.BlockBuilder, call: relax.Call):
        """Legalize the call: simply return the first argument"""
        return call.args[0]

    # Define the operator attributes
    op_attrs = {
        "FInferStructInfo": infer_struct_info,  # struct-info inference
        "FLegalize": legalize,  # legalization
        "FPurity": True,  # mark the operator as pure
    }
    if define_normalization:
        op_attrs["FNormalize"] = normalize  # add the normalization function if requested

    # Register the operator attributes
    for key, value in op_attrs.items():
        tvm.ir.register_op_attr(op_name, key, value)

    op = tvm.ir.Op.get(op_name)
    return op
custom_op = register_custom_op(define_normalization=True)
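As a quick sanity check (a sketch using the same tvm.ir.Op.get_attr API as above), the attributes registered by register_custom_op can be read back from the registry:

# Read back the attributes registered for the custom operator.
assert custom_op.get_attr("FNormalize") is not None
assert custom_op.get_attr("FPurity")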
Testing whether applying FNormalize is suppressed when parsing TVMScript#
@R.function(check_well_formed=False)  # disable the well-formed check
def func(A: R.Tensor) -> R.Tensor:
    return relax.Call(custom_op, [A])


func.show()

# Extract the call expression from the function body
call_expr = func.body.blocks[0].bindings[0].value
assert isinstance(
    call_expr, relax.Call
), "Test implementation error: did not extract the expected expression"
assert (
    len(call_expr.args) == 1
), "Expected TVMScript to suppress FNormalize and produce the arguments as written"
# from tvm.script import relax as R

@R.function
def func(A: R.Tensor) -> R.Tensor:
    gv: R.Tensor = custom_op.ignore_second_argument(A)
    return gv
Testing whether FNormalize is applied in a C++ ExprMutator subclass#
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A])


@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A, R.tuple()])


# Apply the empty C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Before and After must differ, while After must match Expected
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
Testing whether FNormalize is applied in a Python ExprMutator subclass#
@R.function(private=True, check_well_formed=False)
def before(A: R.Tensor):
    return relax.Call(custom_op, [A])


@R.function(private=True)
def expected(A: R.Tensor):
    return relax.Call(custom_op, [A, R.tuple()])


@relax.expr_functor.mutator
class EmptyPyExprMutator(relax.PyExprMutator):
    """Default ExprMutator implementation"""


# Apply the Python mutator
after = EmptyPyExprMutator().visit_expr(before)

# before and after must differ, while after must match expected
assert not tvm.ir.structural_equal(before, after)
tvm.ir.assert_structural_equal(expected, after)
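Both mutators delegate normalization to the underlying block builder rather than implementing it themselves. As a rough sketch of that mechanism (this assumes relax.BlockBuilder.normalize applies FNormalize directly, consistent with the mutator behavior above):

# Sketch: normalize a dangling call through the block builder directly.
# Assumption: BlockBuilder.normalize applies FNormalize, matching the
# mutators above, which delegate normalization to the block builder.
bb = relax.BlockBuilder()
A = relax.Var("A", relax.TensorStructInfo(dtype="float32"))
normalized = bb.normalize(relax.Call(custom_op, [A]))
assert len(normalized.args) == 2  # the empty tuple has been appended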
Testing that the IR is well-formed if FNormalize applies no changes#
@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A, A])


# Verify that the module is well-formed
assert relax.analysis.well_formed(Module)
Testing that the IR is ill-formed if FNormalize would apply changes#
@I.ir_module(check_well_formed=False)
class Module:
    @R.function
    def main(A: R.Tensor):
        return relax.Call(custom_op, [A])


for define_normalization in [True, False]:
    # Re-register the operator, with or without FNormalize
    for key in ["FInferStructInfo", "FLegalize", "FPurity", "FNormalize"]:
        custom_op.reset_attr(key)
    custom_op = register_custom_op(define_normalization=define_normalization)

    # With FNormalize defined the module is ill-formed; otherwise it is well-formed
    if define_normalization:
        assert not relax.analysis.well_formed(Module)
    else:
        assert relax.analysis.well_formed(Module)
[18:02:22] /media/pc/data/lxw/ai/tvm/src/relax/analysis/well_formed.cc:134: Warning: This IR is not well formed: If an operator defines an operator-specific normalization function (FNormalize), calls to that operator must be normalized with it. However, normalization of custom_op.ignore_second_argument(A) resulted in custom_op.ignore_second_argument(A, R.tuple())
Testing whether FNormalize in-lines the argument tuple for R.call_tir#
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, args],
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0


@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, relax.Tuple([A])],  # the in-lined tuple
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0


# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Verify the result
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
Testing whether FNormalize in-lines the argument tuple for R.call_tir when the tuple is a function parameter#
@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(args: R.Tuple([R.Tensor([16], "float32")])):
        cls = Before
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, args],
            sinfo_args=[args[0].struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0


@I.ir_module
class Expected:
    @R.function
    def main(args: R.Tuple([R.Tensor([16], "float32")])):
        cls = Expected
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir"),
            [cls.multiply_by_two, relax.Tuple([args[0]])],  # in-lined tuple built from the function parameter
            sinfo_args=[args[0].struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0


# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Verify the result
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
Testing whether FNormalize in-lines the argument tuple for R.call_tir_inplace#
# CallTIRInplaceAttrs is hard to construct through the Python API, so the
# expected module is declared first and its attrs object is reused below.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        args = (A,)
        return R.call_tir_inplace(
            cls.multiply_by_two,
            A,
            inplace_indices=[0],
            out_sinfo=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32")):
        for i in range(16):
            A[i] = A[i] * 2.0


# Extract the inplace attrs
inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs


@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir_inplace"),
            [cls.multiply_by_two, args],
            attrs=inplace_attrs,
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32")):
        for i in range(16):
            A[i] = A[i] * 2.0


# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Verify the result
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)
Testing whether FNormalize in-lines the argument tuple for R.call_tir_with_grad#
# CallTIRWithGradAttrs is hard to construct through the Python API, so the
# expected module is declared first and its attrs object is reused below.
@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Expected
        args = (A,)
        return R.call_tir_with_grad(
            cls.multiply_by_two,
            A,
            out_sinfo=[A.struct_info],
            te_grad_name="f_grad",
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

    @T.prim_func(private=True)
    def f_grad(
        A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32")
    ):
        for i in range(16):
            Grad[i] = 2.0


# Extract the with-grad attrs
with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs


@I.ir_module(check_well_formed=False)
class Before:
    @R.function
    def main(A: R.Tensor([16], "float32")):
        cls = Before
        args = (A,)
        return relax.Call(
            tvm.ir.Op.get("relax.call_tir_with_grad"),
            [cls.multiply_by_two, args],
            attrs=with_grad_attrs,
            sinfo_args=[A.struct_info],
        )

    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
        for i in range(16):
            B[i] = A[i] * 2.0

    @T.prim_func(private=True)
    def f_grad(
        A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: T.Buffer(16, "float32")
    ):
        for i in range(16):
            Grad[i] = 2.0


# Apply the C++ mutator
After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)

# Verify the result
assert not tvm.ir.structural_equal(Before, After)
tvm.ir.assert_structural_equal(Expected, After)