测试自定义算子合法化变换功能#

本测试验证用户是否可以为特定算子定义自定义的合法化变换逻辑。

import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.relax.transform.legalize_ops.common import register_legalize
from tvm.script import relax as R, tir as T, ir as I

自定义 legalize#

将 R.add 算子通过自定义函数变换为使用 topi.add 的 TE 调用,并验证变换后的 IR 是否与预期结果一致。

@tvm.script.ir_module
class Add:
    @R.function
    def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
        # Single broadcasted add: shapes (1, 2, 3) and (4, 3, 2, 1) broadcast
        # to the (4, 3, 2, 3) output annotated above.
        gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y)
        return gv

定义自定义的合法化变换函数,将 x 和 y 的顺序调换后调用 topi.add

def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call):
    """Custom legalization for R.add: emit a topi.add TE call with operands swapped.

    The swap has no numeric effect (addition commutes) but makes it observable
    in the legalized IR that this custom hook, not the default one, ran.
    """
    from tvm import topi  # pylint: disable=import-outside-toplevel

    lhs, rhs = call.args
    return bb.call_te(topi.add, rhs, lhs)

应用 LegalizeOps 变换,传入自定义的合法化函数

# Run LegalizeOps with a user-supplied legalization table; the entry for
# "relax.add" overrides the default lowering for that op only.
mod = LegalizeOps({"relax.add": customize_legalize_add})(Add)
mod.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(y: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), x: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(y[v_ax0, v_ax1, v_ax2, T.int64(0)], x[T.int64(0), v_ax2, v_ax3])
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = y[v_ax0, v_ax1, v_ax2, T.int64(0)] + x[T.int64(0), v_ax2, v_ax3]

    @R.function
    def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
        cls = Module
        gv = R.call_tir(cls.add, (y, x), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32"))
        return gv

测试不同类型的调用是否都能被正确合法化#

定义包含多种调用类型的测试模块

@tvm.script.ir_module
class Before:
    @R.function
    def mul2(x: R.Tensor((3, 3), "float32")):
        # Relax-level multiply by a scalar constant; should be legalized to TIR.
        gv = R.multiply(x, R.const(2.0, "float32"))
        return gv

    @T.prim_func(private=True)
    def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")):
        # Hand-written TIR copy; already low-level, so LegalizeOps should
        # leave it untouched.
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax0, v_ax1])
                T.writes(T_id[v_ax0, v_ax1])
                T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1]

    @R.function
    def main(x: R.Tensor((3, 3), "float32")):
        cls = Before
        # Three call flavors in one function: a relax-function call, an
        # existing call_tir, and a relax op call.
        gv: R.Tensor((3, 3), "float32") = cls.mul2(x)
        gv1 = R.call_tir(cls.identity, gv, R.Tensor((3, 3), dtype="float32"))
        gv2 = R.multiply(gv1, R.const(2.0, "float32"))
        return gv2
# Apply the LegalizeOps transform with the default legalization table.
After = LegalizeOps()(Before)
After.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")):
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax0, v_ax1])
                T.writes(T_id[v_ax0, v_ax1])
                T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def multiply(gv1: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(gv1[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = gv1[v_ax0, v_ax1] * T.float32(2.0)

    @R.function
    def mul2(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
        cls = Module
        gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="float32"))
        return gv

    @R.function
    def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
        cls = Module
        gv: R.Tensor((3, 3), dtype="float32") = cls.mul2(x)
        gv1 = R.call_tir(cls.identity, (gv,), out_sinfo=R.Tensor((3, 3), dtype="float32"))
        gv2 = R.call_tir(cls.multiply, (gv1,), out_sinfo=R.Tensor((3, 3), dtype="float32"))
        return gv2

测试无法进行合法化变换的情况#

本测试验证当算子没有对应的合法化函数或缺少必要的形状信息时,变换行为是否符合预期

情况1:算子没有对应的合法化函数

# Save the current "FLegalize" attribute of relax.add so it can be restored later.
add_legalize = tvm.ir.Op.get("relax.add").get_attr("FLegalize")
# Temporarily remove the attribute to simulate an op with no legalization function.
tvm.ir.Op.get("relax.add").reset_attr("FLegalize")

定义简单的包含 add 算子的模块

@tvm.script.ir_module
class Before0:
    @R.function
    def main(x: R.Tensor((3, 3), "float32")):
        # R.add currently has no "FLegalize" attribute (reset above), so
        # LegalizeOps should leave this call untouched.
        gv: R.Tensor((3, 3), "float32") = R.add(x, x)
        return gv
# Apply the transform (no legalization function is registered for add here).
After0 = LegalizeOps()(Before0)
# Verify the module is left unchanged.
tvm.ir.assert_structural_equal(After0, Before0)

恢复原有的合法化函数

# Restore the original legalization function for relax.add.
register_legalize("relax.add", add_legalize)
ffi.Function(0x55de62489d30)

情况2:无法确定所有形状信息

# y's shape is only known through the shape variable `s`, so LegalizeOps
# cannot resolve all dimensions at legalization time.
s = relax.Var("s", relax.ShapeStructInfo((3, 3)))
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo(s, "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x, y]):
    with bb.dataflow():
        gv = bb.emit_output(R.add(x, y))
    bb.emit_func_output(gv)
Before1 = bb.get()
# Apply the transform (y's full shape information cannot be determined).
After1 = LegalizeOps()(Before1)
# Verify the module is left unchanged.
tvm.ir.assert_structural_equal(After1, Before1)

测试不同数据类型的标量算子合法化变换#

本测试验证 LegalizeOps 在处理 float16、uint8、bool 等不同数据类型时能够正确保留类型信息

# Test module using the float16 dtype.
@tvm.script.ir_module
class Before0:
    @R.function
    def main(x: R.Tensor((3, 3), "float16")):
        # The float16 constant must keep its dtype through legalization.
        gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514, "float16"))
        return gv

# Test module using the uint8 dtype.
@tvm.script.ir_module
class Before1:
    @R.function
    def main(x: R.Tensor((3, 3), "uint8")):
        # The uint8 constant must keep its dtype through legalization.
        gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8"))
        return gv

# Test module using the bool dtype.
@tvm.script.ir_module
class Before2:
    @R.function
    def main(x: R.Tensor((3, 3), "bool")):
        # The bool constant must keep its dtype through legalization.
        gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool"))
        return gv

应用转换并验证结果

# Apply the transform to each dtype-specific module and inspect the results.
After0 = LegalizeOps()(Before0)
After1 = LegalizeOps()(Before1)
After2 = LegalizeOps()(Before2)
After0.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def multiply(x: T.Buffer((T.int64(3), T.int64(3)), "float16"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.float16(1.1455078125)

    @R.function
    def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3), dtype="float16"):
        cls = Module
        gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="float16"))
        return gv
After1.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def multiply(x: T.Buffer((T.int64(3), T.int64(3)), "uint8"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = x[v_ax0, v_ax1] * T.uint8(2)

    @R.function
    def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3), dtype="uint8"):
        cls = Module
        gv = R.call_tir(cls.multiply, (x,), out_sinfo=R.Tensor((3, 3), dtype="uint8"))
        return gv
After2.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def equal(x: T.Buffer((T.int64(3), T.int64(3)), "bool"), T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
            with T.block("T_equal"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1])
                T.writes(T_equal[v_ax0, v_ax1])
                T_equal[v_ax0, v_ax1] = x[v_ax0, v_ax1] == T.bool(True)

    @R.function
    def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), dtype="bool"):
        cls = Module
        gv = R.call_tir(cls.equal, (x,), out_sinfo=R.Tensor((3, 3), dtype="bool"))
        return gv

测试矩阵乘法算子合法化要求已知数据类型#

本测试验证当 matmul 算子缺少明确数据类型时,合法化转换应抛出适当的错误

import pytest
# Tensors without an explicit dtype: matmul legalization requires known dtypes.
@I.ir_module
class ArbitraryDtype:
    @R.function
    def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
        return R.matmul(A, B)

# Verify the expected error is raised during the transform.
with pytest.raises(AssertionError) as err:
    LegalizeOps()(ArbitraryDtype)

# The error should be caught while legalizing R.matmul, with a friendly message,
# rather than surfacing later when `BlockBuilder.call_te` tries to create a
# numeric constant of kHandle type.
err_message = err.value.args[0]
assert err_message.startswith("To legalize R.matmul")
[13:44:41] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!

测试带 vdevice 的合法化变换#

本测试验证当参数类型仅在 vdevice 上不同时,LegalizeOps 能够为不同目标生成不同的内核。

这是回归测试,之前的实现中,具有不同 vdevice 的参数类型会被合法化为使用相同的 PrimFunc

@I.ir_module
class Before:
    I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

    @R.function
    def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
        # No vdevice annotation: legalized under whatever Target scope is
        # active when the transform runs.
        C = R.add(A, B)
        return C

    @R.function
    def func_llvm(
        A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
    ):
        # Parameters pinned to the "llvm" vdevice: should receive its own
        # PrimFunc distinct from func_cuda's.
        C = R.add(A, B)
        return C

# Apply the transform under a CUDA target scope, so functions that differ
# only in vdevice must be legalized to distinct kernels.
with tvm.target.Target("cuda"):
    After = tvm.relax.transform.LegalizeOps()(Before)
After.show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global")]})
    @T.prim_func(private=True)
    def add(A: T.Buffer((T.int64(32), T.int64(32)), "float32"), B: T.Buffer((T.int64(32), T.int64(32)), "float32"), T_add: T.Buffer((T.int64(32), T.int64(32)), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def add_llvm(A: T.Buffer((T.int64(32), T.int64(32)), "float32"), B: T.Buffer((T.int64(32), T.int64(32)), "float32"), T_add: T.Buffer((T.int64(32), T.int64(32)), "float32")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @R.function
    def func_cuda(A: R.Tensor((32, 32), dtype="float32"), B: R.Tensor((32, 32), dtype="float32")) -> R.Tensor((32, 32), dtype="float32"):
        cls = Module
        C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
        return C

    @R.function
    def func_llvm(A: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"), B: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0")) -> R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"):
        cls = Module
        C = R.call_tir(cls.add_llvm, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"))
        return C

自定义算子,测试递归合法化功能#

定义测试参数,测试不同的合法化变换返回方式

import tvm.testing
def register_custom_op(emit_legalization_through_builder):
    """Register "custom_op.matmul_bias_add" and return the registered op.

    Parameters
    ----------
    emit_legalization_through_builder
        Truthy: the legalization routes its result through ``bb.emit`` and
        returns a bound relax.Var. Falsy: a bare relax expression is returned.
    """
    op_name = "custom_op.matmul_bias_add"

    # Struct-info inference: reuse the built-in matmul/add inference hooks to
    # compute the struct info of matmul(activations, weight) + bias.
    def infer_struct_info(call: relax.Call, context):
        activations, weight, bias = call.args

        matmul_call = relax.op.matmul(activations, weight)
        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
            matmul_call, context
        )

        # A dummy var carrying matmul's struct info lets us invoke add's
        # inference without building a full module.
        matmul_var = relax.Var("dummy_var", matmul_sinfo)
        add_call = matmul_var + bias
        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)

        return add_sinfo

    # Legalization: expand the fused op into matmul + add. These relax-level
    # ops themselves need a further round of legalization (exercised below).
    def legalize(bb: relax.BlockBuilder, call: relax.Call):
        activations, weight, bias = call.args
        legalized = relax.op.matmul(activations, weight) + bias
        if emit_legalization_through_builder:
            legalized = bb.emit(legalized)
        return legalized

    # Register the operator's attributes.
    op_attrs = {
        "FInferStructInfo": infer_struct_info,
        "FLegalize": legalize,
        "FPurity": True,
    }

    for key, value in op_attrs.items():
        tvm.ir.register_op_attr(op_name, key, value)

    op = tvm.ir.Op.get(op_name)
    return op
    # NOTE(review): the fixture-style cleanup below is intentionally disabled;
    # the registered attributes persist for the rest of the process.
    # yield op

    # # Cleanup: reset the attributes
    # for key in op_attrs:
    #     op.reset_attr(key)


# `emit_legalization_through_builder` is boolean-valued inside
# register_custom_op (`if emit_legalization_through_builder:`). The original
# call passed a leftover pytest-parametrize-style dict, which is truthy and
# therefore behaved exactly like True, but misleadingly suggested two
# configurations were being exercised. Pass the explicit flag instead.
custom_op = register_custom_op(emit_legalization_through_builder=True)

本测试验证算子的合法化变换可能生成新的需要合法化的算子

@I.ir_module
class Before:
    @R.function
    def main(
        A: R.Tensor([16, 32, 64], "float32"),
        Weight: R.Tensor([64, 128], "float32"),
        Bias: R.Tensor([16, 32, 128], "float32"),
    ):
        # Call the custom op directly; its legalization expands to R.matmul
        # and R.add, which must themselves be legalized in the same pass.
        return relax.Call(custom_op, [A, Weight, Bias])

# Apply LegalizeOps once.
AfterFirstIter = LegalizeOps()(Before)
# Apply LegalizeOps a second time.
AfterSecondIter = LegalizeOps()(AfterFirstIter)

# After LegalizeOps, the custom op should be replaced by `R.matmul` and `R.add`,
# which in turn should be replaced by TIR implementations. Therefore the second
# application of LegalizeOps() should be a no-op.
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)