AnnotateTIROpPattern#

import pytest
import tvm
import tvm.script
import tvm.testing
from tvm import relax
from tvm.script import tir as T
from tvm_book.op.attr_types import OpPatternKind

kOutEWiseFusable 模式#

测试复杂运算算子模式的注解:

  • 验证矩阵乘法等复杂算子是否被正确识别为 kOutEWiseFusable 模式。

  • 这类算子可以把元素级算子融合到其输出上,但其后不能再串联融合另一个复杂算子。

# Dynamic-shape matmul: C[m, k] = sum over n of A[m, n] * B[n, k].
# Expected pattern (asserted below): kOutEWiseFusable.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
        T.func_attr({"global_symbol": "tir_matmul"})
        # Symbolic extents of the matmul.
        m = T.int32()
        n = T.int32()
        k = T.int32()
        A = T.match_buffer(x, (m, n))
        B = T.match_buffer(y, (n, k))
        C = T.match_buffer(z, (m, k))

        # The reduction loop var is `r` (not `k`) so it does not shadow the
        # symbolic extent `k`; `r` ranges over the shared dimension `n`.
        for i, j, r in T.grid(m, k, n):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, r])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable

测试带类型转换的复杂运算算子模式注解#

验证带有不同类型转换模式的矩阵乘法是否仍被正确识别为 kOutEWiseFusable 模式。测试覆盖三种类型转换场景:先相乘再将乘积转换为 float32、先将两个输入分别转换为 float32 再相乘、对乘积进行嵌套转换(先转为 float16 再转为 float32)。

# Deliberately not named `test*`: pytest would collect such a function as a
# test and then fail because no `cast_pattern` fixture exists.
def build_cast_matmul_module(cast_pattern):
    """Build a float16 x float16 -> float32 matmul module with a custom cast.

    Parameters
    ----------
    cast_pattern : callable
        ``cast_pattern(a, b)`` receives one element of ``A`` and one of ``B``
        (both float16 expressions) and should return the float32 term that is
        accumulated into ``C``.

    Returns
    -------
    tvm.IRModule
        A module containing a single PrimFunc named ``tir_matmul``.
    """

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
            T.func_attr({"global_symbol": "tir_matmul"})
            m = T.int32()
            n = T.int32()
            k = T.int32()
            A = T.match_buffer(x, (m, n), "float16")
            B = T.match_buffer(y, (n, k), "float16")
            C = T.match_buffer(z, (m, k), "float32")

            # `r` iterates over the shared dimension `n`; it is not named `k`
            # to avoid shadowing the symbolic extent `k`.
            for i, j, r in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, r])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + cast_pattern(A[vi, vk], B[vk, vj])

    return InputModule

cast_patterns = [
    # Multiply in float16, then cast the product to float32.
    lambda a, b: T.Cast("float32", a * b),
    # Cast each operand to float32, then multiply.
    lambda a, b: T.Cast("float32", a) * T.Cast("float32", b),
    # Nested casts around the float16 product.
    lambda a, b: T.Cast("float32", T.Cast("float16", a * b)),
]

for cast_pattern in cast_patterns:
    mod = build_cast_matmul_module(cast_pattern)
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable

测试带有整型变量签名的复杂运算算子模式注解#

验证函数签名中带有显式整型(T.int64)维度参数 m、n、k 的矩阵乘法是否被正确识别为 kOutEWiseFusable 模式。此测试确保即使维度以函数参数的形式给出,模式识别仍能正常工作。

# Matmul whose extents m, n, k are explicit int64 parameters of the PrimFunc.
# Expected pattern (asserted below): kOutEWiseFusable.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64):
        T.func_attr({"global_symbol": "tir_matmul"})
        A = T.match_buffer(x, (m, n))
        B = T.match_buffer(y, (n, k))
        C = T.match_buffer(z, (m, k))

        # The reduction loop var is `r` so it does not shadow the parameter `k`;
        # `r` ranges over the shared dimension `n`.
        for i, j, r in T.grid(m, k, n):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, r])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable

kCommReduce 模式#

验证求和等归约操作是否被正确识别为 kCommReduce 模式。这类算子使用满足交换律的归约运算(如求和),沿归约轴对输入数据进行汇总计算。

# Row-wise sum over a 16x16 input: B[i] = sum over j of A[i, j].
# Expected pattern (asserted below): kCommReduce.
@tvm.script.ir_module
class InputModule:
    # The Python name `sum` shadows the builtin only inside this class body; it
    # is kept because it is the function name looked up in `new_mod` below.
    @T.prim_func
    def sum(x: T.handle, y: T.handle) -> None:
        # global_symbol matches the function name instead of a copy-pasted
        # "elemwise".
        T.func_attr({"global_symbol": "sum"})
        A = T.match_buffer(x, (16, 16))
        B = T.match_buffer(y, (16,))

        for i, j in T.grid(16, 16):
            with T.block("sum"):
                # i is spatial, j is the reduction axis.
                vi, vj = T.axis.remap("SR", [i, j])
                with T.init():
                    B[vi] = 0.0
                B[vi] += A[vi, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce

kElemWise 模式#

验证简单的元素级操作(如加法)是否被正确识别为kElemWise模式。这类算子对输入张量的每个元素进行独立计算。

# Elementwise add of a constant: B[i, j] = A[i, j] + 1.0.
# Expected pattern (asserted below): kElemWise.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def elemwise(x: T.handle, y: T.handle) -> None:
        T.func_attr({"global_symbol": "elemwise"})
        A = T.match_buffer(x, (16, 16))
        B = T.match_buffer(y, (16, 16))

        for i, j in T.grid(16, 16):
            # Block named after what it computes (this is not a matmul).
            with T.block("elemwise"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + 1.0

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise

kBroadcast 模式#

验证广播操作是否被正确识别为 kBroadcast 模式。这类算子将低维输入广播到高维输出,且输出轴总能按顺序映射到输入轴,例如本例中的 B[i0, j0, i1, j1] = A[j0, j1]。

# Broadcast of a 2-D input into a 4-D output: B[i0, j0, i1, j1] = A[j0, j1].
# The two input axes appear in the output in the same order.
# Expected pattern (asserted below): kBroadcast.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def broadcast(x: T.handle, y: T.handle) -> None:
        # global_symbol matches the function name instead of a copy-pasted
        # "elemwise".
        T.func_attr({"global_symbol": "broadcast"})
        A = T.match_buffer(x, (16, 16))
        B = T.match_buffer(y, (16, 16, 16, 16))

        for i0, j0, i1, j1 in T.grid(16, 16, 16, 16):
            with T.block("broadcast"):
                vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1])
                B[vi0, vj0, vi1, vj1] = A[vj0, vj1]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast

kInjective 模式#

验证单射算子是否被正确识别为 kInjective 模式。这类算子的输出轴可以单射映射到单个输入轴,可安全地与其他单射算子和归约算子融合。

# Layout transform from (4, 4, 4, 4) to (16, 16):
# B[i, j] = A[i // 4, j // 4, i % 4, j % 4]; each output element reads exactly
# one input element.
# Expected pattern (asserted below): kInjective.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def injective(x: T.handle, y: T.handle) -> None:
        # global_symbol matches the function name instead of a copy-pasted
        # "elemwise".
        T.func_attr({"global_symbol": "injective"})
        A = T.match_buffer(x, (4, 4, 4, 4))
        B = T.match_buffer(y, (16, 16))

        for i, j in T.grid(16, 16):
            with T.block("injective"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective

测试偏置加法算子模式的注解#

验证偏置加法算子是否被正确识别为 kElemWise 模式。偏置加法是一种特殊的元素级算子,其中一个输入被广播到另一个输入的维度。

# Bias add: C[0, j] = A[0, j] + B[j], where B is indexed only by the last axis.
# Expected pattern (asserted below): kElemWise.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def tir_bias_add(
        A: T.Buffer((1, 1000), "float32"),
        B: T.Buffer((1000,), "float32"),
        C: T.Buffer((1, 1000), "float32"),
    ) -> None:
        # Function attribute dictionary
        T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1 in T.grid(1, 1000):
            with T.block("T_add"):
                ax0, ax1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[ax0, ax1], B[ax1])
                T.writes(C[ax0, ax1])
                C[ax0, ax1] = A[ax0, ax1] + B[ax1]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise

测试带单位维度形状的广播加法算子模式注解#

验证带有单位维度(size=1)的广播加法是否被正确识别为 kElemWise 模式。此测试确保优化器能够正确处理常见的广播场景。

# Add with a broadcast operand that has unit dims:
# C[0, c, h, w] = A[0, c, h, w] + B[c, 0, 0].
# Expected pattern (asserted below): kElemWise.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def add_with_unit_dim_len_broadcast(
        A: T.Buffer((1, 64, 112, 112), "float32"),
        B: T.Buffer((64, 1, 1), "float32"),
        C: T.Buffer((1, 64, 112, 112), "float32"),
    ) -> None:
        # global_symbol matches the function name, as in the other test cases
        # (it was a leftover "add5").
        T.func_attr({"global_symbol": "add_with_unit_dim_len_broadcast", "tir.noalias": True})
        for i0, i1, i2, i3 in T.grid(1, 64, 112, 112):
            with T.block("T_add"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0])
                T.writes(C[ax0, ax1, ax2, ax3])
                C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise

测试零维元素级加法算子模式注解#

验证标量(零维张量)与向量的加法是否被正确识别为 kElemWise 模式。此测试确保优化器能够正确处理标量与数组的运算。

# Vector + scalar, where the scalar is a 0-d buffer: C[i] = A[i] + B[()].
# Expected pattern (asserted below): kElemWise.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def add_zero_dim(
        A: T.Buffer((128,), "float32"),
        B: T.Buffer((), "float32"),
        C: T.Buffer((128,), "float32"),
    ) -> None:
        # global_symbol matches the function name, as in the other test cases
        # (it was a leftover "add8").
        T.func_attr({"global_symbol": "add_zero_dim", "tir.noalias": True})
        for i0 in T.serial(128):
            with T.block("T_add"):
                ax0 = T.axis.spatial(128, i0)
                T.reads(A[ax0], B[()])
                T.writes(C[ax0])
                C[ax0] = A[ax0] + B[()]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise

测试池化算子模式的注解#

验证最大池化算子是否被正确识别为 kOutEWiseFusable 模式。池化是一种复杂运算,可以融合元素级算子到其输出中。

# 3x3 / stride-2 max pooling with a 1-pixel pad, done in two blocks:
# "pad_temp" pads the input with the lowest float32 value, and "tensor" reduces
# the 3x3 windows with T.max.
# Expected pattern (asserted below): kOutEWiseFusable.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def max_pool2d(
        rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"),
        tensor_1: T.Buffer((1, 64, 56, 56), "float32"),
    ) -> None:
        # Function attribute dictionary ("tir.noalias" is the TIR attribute key;
        # "T.noalias" was a typo).
        T.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
        # body
        # with T.block("root")
        pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32")
        for i0, i1, i2, i3 in T.grid(1, 64, 114, 114):
            with T.block("pad_temp"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])
                T.writes(pad_temp_1[ax0, ax1, ax2, ax3])
                pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else(
                    1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113,
                    rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1],
                    T.float32(-3.4028234663852886e38),
                    dtype="float32",
                )
        for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3):
            with T.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                T.reads(
                    tensor_1[ax0, ax1, ax2, ax3],
                    pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                )
                T.writes(tensor_1[ax0, ax1, ax2, ax3])
                with T.init():
                    tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38)
                tensor_1[ax0, ax1, ax2, ax3] = T.max(
                    tensor_1[ax0, ax1, ax2, ax3],
                    pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                )

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable

测试softmax算子模式的注解#

验证 softmax 算子是否被正确识别为 kOutEWiseFusable 模式。softmax 是一种复杂运算,包含多个步骤但仍可融合元素级算子。

# Row-wise softmax over a 16x16 input, done in four blocks:
# max-reduce, exp(x - max), sum-reduce, then divide by the sum.
# Expected pattern (asserted below): kOutEWiseFusable.
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def softmax(
        rxplaceholder_1: T.Buffer((16, 16), "float32"),
        T_softmax_norm_1: T.Buffer((16, 16), "float32"),
    ) -> None:
        # Function attribute dictionary ("tir.noalias" is the TIR attribute key;
        # "T.noalias" was a typo).
        T.func_attr({"global_symbol": "softmax", "tir.noalias": True})
        # body
        # with T.block("root")
        T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32")
        T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32")
        T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32")
        for i0_7, i1_3 in T.grid(16, 16):
            with T.block("T_softmax_maxelem"):
                i0_8, k = T.axis.remap("SR", [i0_7, i1_3])
                T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
                T.writes(T_softmax_maxelem_1[i0_8])
                with T.init():
                    T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38)
                T_softmax_maxelem_1[i0_8] = T.max(
                    T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]
                )
        for i0_9, i1_4 in T.grid(16, 16):
            with T.block("T_softmax_exp"):
                i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4])
                T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
                T.writes(T_softmax_exp_1[i0_10, i1_5])
                T_softmax_exp_1[i0_10, i1_5] = T.exp(
                    rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32"
                )
        for i0_11, i1_6 in T.grid(16, 16):
            with T.block("T_softmax_expsum"):
                i0_12, k = T.axis.remap("SR", [i0_11, i1_6])
                T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
                T.writes(T_softmax_expsum_1[i0_12])
                with T.init():
                    T_softmax_expsum_1[i0_12] = T.float32(0)
                T_softmax_expsum_1[i0_12] = (
                    T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
                )
        for i0_13, i1_7 in T.grid(16, 16):
            with T.block("T_softmax_norm"):
                i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7])
                T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
                T.writes(T_softmax_norm_1[i0_14, i1_8])
                T.block_attr({"axis": 1})
                T_softmax_norm_1[i0_14, i1_8] = (
                    T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
                )

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable

测试多缓冲区存储的算子模式注解回退行为#

验证累积和(cumsum)算子是否被正确识别为 kOpaque 模式。该算子在同一个块内对输出缓冲区 out_buf 进行了多处写入(循环外写入首元素一次、内层串行循环中再逐个写入),此时 pass 无法推断出明确的模式,会回退为不透明(kOpaque)算子。

# Cumulative sum of a flattened [10, 16] input into a 160-element output.
# out_buf is written at two different places inside one block (the first
# element, then each later element inside the serial scan). That is the
# "multiple buffer stores" case; the expected pattern (asserted below) is
# kOpaque.
@tvm.script.ir_module
class CumsumModule:
    @T.prim_func
    def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")):
        rxplaceholder = T.match_buffer(
            var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1
        )
        with T.block("cumsum_generic"):
            T.reads(rxplaceholder[0:10, 0:16])
            T.writes(out_buf[0:160])
            # Single-iteration parallel loop, so fused == 0 throughout.
            for fused in T.parallel(1):
                # First store: seed the scan with the first input element.
                out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16]
                for v_k in T.serial(159):
                    # Second store: each element is the previous output plus
                    # the next (flattened) input element.
                    out_buf[fused * 160 + (v_k + 1)] = (
                        out_buf[fused * 160 + (v_k + 1 - 1)]
                        + rxplaceholder[
                            (fused * 160 + (v_k + 1)) // 16,
                            (fused * 160 + (v_k + 1)) % 16,
                        ]
                    )

mod = CumsumModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque

测试同时计算和与平方和的算子模式注解#

验证同时计算元素和与平方和的操作是否被正确识别为kCommReduce模式。此测试确保优化器能够正确处理多输出的归约操作。

# One reduction block that produces two outputs per row:
# vsum[i] = sum over k of A[i, k], and sqsum[i] = sum over k of A[i, k] ** 2.
# Expected pattern (asserted below): kCommReduce.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def sum_sqsum(
        A: T.Buffer((32, 64), "float32"),
        vsum: T.Buffer((32,), "float32"),
        sqsum: T.Buffer((32,), "float32"),
    ):
        for ax0, k0 in T.grid(32, 64):
            with T.block("block"):
                v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0])
                T.reads(A[v_ax0, v_k0])
                T.writes(vsum[v_ax0], sqsum[v_ax0])
                # Both accumulators are initialized in the same T.init().
                with T.init():
                    vsum[v_ax0] = T.float32(0)
                    sqsum[v_ax0] = T.float32(0)
                # Both updated values are computed into locals first, then stored.
                v_vsum: T.float32 = vsum[v_ax0] + A[v_ax0, v_k0]
                v_sqsum: T.float32 = sqsum[v_ax0] + A[v_ax0, v_k0] * A[v_ax0, v_k0]
                vsum[v_ax0] = v_vsum
                sqsum[v_ax0] = v_sqsum

mod = Module
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["sum_sqsum"].attrs["op_pattern"] == OpPatternKind.kCommReduce

测试无缓冲区存储的算子模式注解#

验证块内只有外部调用(T.call_packed)、没有任何缓冲区写入(BufferStore)的算子是否被正确识别为 kOpaque 模式。当 pass 无法通过缓冲区写入推断算子的具体行为时,会将其视为不透明算子。

# The block declares a write to vsum via T.writes but contains no buffer store;
# its only statement is an external packed call.
# Expected pattern (asserted below): kOpaque.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def no_buffer_stores(A: T.Buffer((32, 64), "float32"), vsum: T.Buffer((32,), "float32")):
        for ax0, k0 in T.grid(32, 64):
            with T.block("block"):
                v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0])
                T.reads(A[v_ax0, v_k0])
                T.writes(vsum[v_ax0])
                # A block with no buffer stores usually means the computation is
                # done by an external call.
                # In that case we treat the function as an opaque op.
                T.call_packed("some_func")

mod = Module
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["no_buffer_stores"].attrs["op_pattern"] == OpPatternKind.kOpaque