测试自定义算子合法化变换功能#
本测试验证用户是否可以为特定算子定义自定义的合法化变换逻辑。
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.relax.transform.legalize_ops.common import register_legalize
from tvm.script import relax as R, tir as T, ir as I
自定义 legalize#
将 R.add 算子通过自定义函数变换为使用 topi.add 的 TE 调用并验证变换后的 IR 是否与预期结果一致。
@tvm.script.ir_module
class Add:
@R.function
def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y)
return gv
定义自定义的合法化变换函数,将 x 和 y 的顺序调换后调用 topi.add:
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call):
from tvm import topi # pylint: disable=import-outside-toplevel
return bb.call_te(topi.add, call.args[1], call.args[0])
应用 LegalizeOps 变换,传入自定义的合法化函数
mod = LegalizeOps({"relax.add": customize_legalize_add})(Add)
mod.show()
测试不同类型的调用是否都能被正确合法化#
定义包含多种调用类型的测试模块
@tvm.script.ir_module
class Before:
@R.function
def mul2(x: R.Tensor((3, 3), "float32")):
gv = R.multiply(x, R.const(2.0, "float32"))
return gv
@T.prim_func(private=True)
def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")):
for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_id[v_ax0, v_ax1])
T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1]
@R.function
def main(x: R.Tensor((3, 3), "float32")):
cls = Before
gv: R.Tensor((3, 3), "float32") = cls.mul2(x)
gv1 = R.call_tir(cls.identity, gv, R.Tensor((3, 3), dtype="float32"))
gv2 = R.multiply(gv1, R.const(2.0, "float32"))
return gv2
# 应用LegalizeOps转换
After = LegalizeOps()(Before)
After.show()
测试无法进行合法化转变换的情况#
本测试验证当算子没有对应的合法化函数或缺少必要的形状信息时,变换行为是否符合预期
情况1:算子没有对应的合法化函数
add_legalize = tvm.ir.Op.get("relax.add").get_attr("FLegalize")
# 重置属性用于测试
tvm.ir.Op.get("relax.add").reset_attr("FLegalize")
定义简单的包含 add 算子的模块
@tvm.script.ir_module
class Before0:
@R.function
def main(x: R.Tensor((3, 3), "float32")):
gv: R.Tensor((3, 3), "float32") = R.add(x, x)
return gv
# 应用转换(此时没有add操作的合法化函数)
After0 = LegalizeOps()(Before0)
# 验证模块是否保持不变
tvm.ir.assert_structural_equal(After0, Before0)
恢复原有的合法化函数
register_legalize("relax.add", add_legalize)
ffi.Function(0x55de62489d30)
情况2:无法确定所有形状信息
s = relax.Var("s", relax.ShapeStructInfo((3, 3)))
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo(s, "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x, y]):
with bb.dataflow():
gv = bb.emit_output(R.add(x, y))
bb.emit_func_output(gv)
Before1 = bb.get()
# 应用转换(此时无法确定y的完整形状信息)
After1 = LegalizeOps()(Before1)
# 验证模块是否保持不变
tvm.ir.assert_structural_equal(After1, Before1)
测试不同数据类型的标量算子合法化变换#
本测试验证 LegalizeOps 在处理 float16、uint8 和 bool 等不同数据类型时能够正确保留类型信息
# 定义float16数据类型的测试模块
@tvm.script.ir_module
class Before0:
@R.function
def main(x: R.Tensor((3, 3), "float16")):
gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514, "float16"))
return gv
# 定义uint8数据类型的测试模块
@tvm.script.ir_module
class Before1:
@R.function
def main(x: R.Tensor((3, 3), "uint8")):
gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8"))
return gv
# 定义bool数据类型的测试模块
@tvm.script.ir_module
class Before2:
@R.function
def main(x: R.Tensor((3, 3), "bool")):
gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool"))
return gv
应用转换并验证结果
After0 = LegalizeOps()(Before0)
After1 = LegalizeOps()(Before1)
After2 = LegalizeOps()(Before2)
After0.show()
After1.show()
After2.show()
测试矩阵乘法算子合法化要求已知数据类型#
本测试验证当 matmul 算子缺少明确数据类型时,合法化转换应抛出适当的错误
import pytest
@I.ir_module
class ArbitraryDtype:
@R.function
def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
return R.matmul(A, B)
# 验证转换时是否抛出预期的错误
with pytest.raises(AssertionError) as err:
LegalizeOps()(ArbitraryDtype)
# 错误应该是在尝试合法化R.matmul时捕获的,并提供友好的错误消息
# 而不是等到`BlockBuilder.call_te`实现时,尝试创建kHandle类型的数值常量时才抛出错误
err_message = err.value.args[0]
assert err_message.startswith("To legalize R.matmul")
[13:44:41] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
测试带 vdevice 的合法化变换#
本测试验证当参数类型仅在 vdevice 上不同时,LegalizeOps 能够为不同目标生成不同的内核。
这是回归测试,之前的实现中,具有不同 vdevice 的参数类型会被合法化为使用相同的 PrimFunc。
@I.ir_module
class Before:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@R.function
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
C = R.add(A, B)
return C
@R.function
def func_llvm(
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
):
C = R.add(A, B)
return C
# 在CUDA目标下应用转换
with tvm.target.Target("cuda"):
After = tvm.relax.transform.LegalizeOps()(Before)
After.show()
自定义算子,测试递归合法化功能#
定义测试参数,测试不同的合法化变换返回方式
import tvm.testing
def register_custom_op(emit_legalization_through_builder):
op_name = "custom_op.matmul_bias_add"
# 定义结构信息推断函数
def infer_struct_info(call: relax.Call, context):
activations, weight, bias = call.args
matmul_call = relax.op.matmul(activations, weight)
matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
matmul_call, context
)
matmul_var = relax.Var("dummy_var", matmul_sinfo)
add_call = matmul_var + bias
add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
return add_sinfo
# 定义合法化函数
def legalize(bb: relax.BlockBuilder, call: relax.Call):
activations, weight, bias = call.args
legalized = relax.op.matmul(activations, weight) + bias
if emit_legalization_through_builder:
legalized = bb.emit(legalized)
return legalized
# 注册操作的属性
op_attrs = {
"FInferStructInfo": infer_struct_info,
"FLegalize": legalize,
"FPurity": True,
}
for key, value in op_attrs.items():
tvm.ir.register_op_attr(op_name, key, value)
op = tvm.ir.Op.get(op_name)
return op
# yield op
# # 清理:重置属性
# for key in op_attrs:
# op.reset_attr(key)
custom_op = register_custom_op(emit_legalization_through_builder={
"return_relax_expr": False,
"return_relax_var": True,
})
本测试验证算子的合法化变换可能生成新的需要合法化的算子
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor([16, 32, 64], "float32"),
Weight: R.Tensor([64, 128], "float32"),
Bias: R.Tensor([16, 32, 128], "float32"),
):
return relax.Call(custom_op, [A, Weight, Bias])
# 应用一次LegalizeOps转换
AfterFirstIter = LegalizeOps()(Before)
# 再次应用LegalizeOps转换
AfterSecondIter = LegalizeOps()(AfterFirstIter)
# LegalizeOps后,自定义操作应被替换为`R.matmul`和`R.add`,
# 这些操作又应被替换为TIR实现。因此,第二次应用LegalizeOps()应该是无效操作。
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)