优化布局变换#

OptimizeLayoutTransform 能够优化计算图中的布局变换操作,消除不必要的张量布局变换,从而提高计算效率。文件通过三个测试场景全面检验了该优化 pass 的能力:

  1. 单参数布局转换优化:验证在单一输入和输出情况下,布局优化能够消除中间不必要的布局转换操作。

  2. 多参数布局转换优化:验证在多输入情况下,布局优化能够有效管理多个张量的布局转换,避免冗余操作。

  3. 填充与移除填充操作优化:验证在包含填充、计算和移除填充的复杂场景中,布局优化能够识别并消除不必要的填充/移除填充循环操作。

每个测试场景都定义了优化前的 Before 模块和预期优化后的 Expected 模块,并通过 _run_pass_compare_output 函数应用转换并验证结果。这些测试确保了 TVM Relax 布局优化转换能够正确识别和优化各种复杂场景下的张量布局转换操作。

import tvm
import numpy as np
from tvm import relax
from tvm.relax.transform import DeadCodeElimination, FuseTIR, OptimizeLayoutTransform
from tvm.script import ir as I, tir as T, relax as R
# 定义辅助函数,用于运行一系列转换并比较输出结果
def _run_pass_compare_output(Before,):
    # 顺序应用布局优化、死代码消除和 TIR 融合转换
    After = tvm.ir.transform.Sequential(
        [
            OptimizeLayoutTransform(),  # 优化布局转换
            DeadCodeElimination(),      # 消除死代码
            FuseTIR(),                  # 融合 TIR 函数
        ]
    )(Before)
    After.show()

单参数布局变换优化#

# 定义优化前的 IR 模块
@I.ir_module
class Before:
    # 定义 TIR 原语函数作为 relax.add 的替代实现
    @T.prim_func(private=True)
    def relax_add_replacement(
        arg0: T.Buffer((4, 4), "float32"),
        arg1: T.Buffer((4, 4), "float32"),
        output: T.Buffer((4, 4), "float32"),
    ):
        T.func_attr({"operator_name": "relax.add"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

    # 定义主函数,包含多次布局转换和计算操作
    @R.function
    def main(
        x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
    ) -> R.Tensor((16,), dtype="float32"):
        with R.dataflow():
            # 将一维张量转换为二维布局 (4x4)
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                x, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                y, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            # 调用 TIR 原语函数进行加法计算
            lv2 = R.call_tir(
                Before.relax_add_replacement,
                (lv, lv1),
                out_sinfo=R.Tensor((4, 4), dtype="float32"),
            )
            # 将结果转换回一维布局
            lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
            )
            # 重复不必要的布局转换和计算(这些应该被优化掉)
            lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                y, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv5 = R.call_tir(
                Before.relax_add_replacement,
                (lv4, lv3),
                out_sinfo=R.Tensor((4, 4), dtype="float32"),
            )
            lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
            )
            gv: R.Tensor((16,), dtype="float32") = lv2_1
            R.output(gv)
        return gv
_run_pass_compare_output(Before,)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

    @R.function
    def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda i: (i // 4, i % 4)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda i: (i // 4, i % 4)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv5 = R.call_tir(cls.relax_add_replacement, (lv1, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            gv: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=T.index_map(lambda axis0, axis1: (axis0 * 4 + axis1,)), pad_value=None, axis_separators=[], input_axis_separators=[])
            R.output(gv)
        return gv

多参数(三个输入)情况下的布局变换优化#

# 定义优化前的 IR 模块
@I.ir_module
class Before:
    # 定义 TIR 原语函数
    @T.prim_func(private=True)
    def relax_add_replacement(
        arg0: T.Buffer((4, 4), "float32"),
        arg1: T.Buffer((4, 4), "float32"),
        output: T.Buffer((4, 4), "float32"),
    ):
        T.func_attr({"operator_name": "relax.add"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

    # 定义主函数,包含三个输入和多次布局转换
    @R.function
    def main(
        x: R.Tensor((16,), dtype="float32"),
        y: R.Tensor((16,), dtype="float32"),
        z: R.Tensor((16,), dtype="float32"),
    ) -> R.Tensor((16,), dtype="float32"):
        with R.dataflow():
            # 将三个输入转换为二维布局
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                x, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                y, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                z, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            # 第一次加法计算
            lv3 = R.call_tir(
                Before.relax_add_replacement,
                (lv, lv1),
                out_sinfo=R.Tensor((4, 4), dtype="float32"),
            )
            # 第二次加法计算
            lv4 = R.call_tir(
                Before.relax_add_replacement,
                (lv, lv2),
                out_sinfo=R.Tensor((4, 4), dtype="float32"),
            )
            # 不必要的布局转换回一维
            lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
            )
            lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
            )
            # 再次转换为二维,用于第三次加法
            lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            lv9 = R.call_tir(
                Before.relax_add_replacement,
                (lv7, lv8),
                out_sinfo=R.Tensor((4, 4), dtype="float32"),
            )
            # 最后转换回一维
            lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
            )
            gv: R.Tensor((16,), dtype="float32") = lv10
            R.output(gv)
        return gv
_run_pass_compare_output(Before,)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

    @R.function
    def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32"), z: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda i: (i // 4, i % 4)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda i: (i // 4, i % 4)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(z, index_map=T.index_map(lambda i: (i // 4, i % 4)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv3 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv4 = R.call_tir(cls.relax_add_replacement, (lv, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv9 = R.call_tir(cls.relax_add_replacement, (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            gv: R.Tensor((16,), dtype="float32") = R.layout_transform(lv9, index_map=T.index_map(lambda axis0, axis1: (axis0 * 4 + axis1,)), pad_value=None, axis_separators=[], input_axis_separators=[])
            R.output(gv)
        return gv

包含填充和移除填充操作的布局变换优化#

# 定义优化前的 IR 模块
@I.ir_module
class Before:
    # 定义 ReLU 操作的 TIR 原语函数实现
    @T.prim_func(private=True)
    def relax_relu_replacement(
        arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
    ):
        T.func_attr({"operator_name": "relax.relu"})
        # with T.block("root"):
        for ax0 in range(16):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(16, ax0)
                T.reads(arg0[v_ax0])
                T.writes(output[v_ax0])
                output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))

    # 定义移除填充操作的 TIR 原语函数
    @T.prim_func(private=True)
    def remove_pad(var_input: T.handle, var_output: T.handle):
        T.func_attr({"operator_name": "remove_pad", "tir.noalias": True})
        p0 = T.int64()
        input = T.match_buffer(var_input, (p0,))
        i0 = T.int64()
        output = T.match_buffer(var_output, (i0,))
        # with T.block("root"):
        for ax0 in range(i0):
            with T.block("output"):
                v_ax0 = T.axis.spatial(i0, ax0)
                T.reads(input[v_ax0])
                T.writes(output[v_ax0])
                output[v_ax0] = input[v_ax0]

    # 定义主函数,包含填充、ReLU、移除填充等操作
    @R.function
    def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
        with R.dataflow():
            # 填充操作:将 14 元素的张量填充到 16 元素
            lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
                x,
                index_map=T.index_map(lambda i: (i % 16,)),
                pad_value=None,
                axis_separators=[],
            )
            # 应用 ReLU 操作
            lv1 = R.call_tir(
                Before.relax_relu_replacement,
                (lv,),
                out_sinfo=R.Tensor((16,), dtype="float32"),
            )
            # 不必要的恒等布局转换
            lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv1,
                index_map=T.index_map(lambda axis0: (axis0,)),
                pad_value=None,
                axis_separators=[],
            )
            # 移除填充,从 16 元素回到 14 元素
            lv_1 = R.call_tir(
                Before.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
            )
            # 再次进行填充、ReLU、不必要的布局转换和移除填充(这些应该被优化掉)
            lv3: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv_1,
                index_map=T.index_map(lambda i: (i % 16,)),
                pad_value=None,
                axis_separators=[],
            )
            lv4 = R.call_tir(
                Before.relax_relu_replacement,
                (lv3,),
                out_sinfo=R.Tensor((16,), dtype="float32"),
            )
            lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
                lv4,
                index_map=T.index_map(lambda axis0: (axis0,)),
                pad_value=None,
                axis_separators=[],
            )
            lv_2 = R.call_tir(
                Before.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
            )
            gv: R.Tensor((14,), dtype="float32") = lv_2
            R.output(gv)
        return gv
_run_pass_compare_output(Before,)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relax_relu_replacement(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
        T.func_attr({"operator_name": "relax.relu"})
        # with T.block("root"):
        for ax0 in range(16):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(16, ax0)
                T.reads(arg0[v_ax0])
                T.writes(output[v_ax0])
                output[v_ax0] = T.max(arg0[v_ax0], T.float32(0.0))

    @T.prim_func(private=True)
    def remove_pad(var_input: T.handle, var_output: T.handle):
        T.func_attr({"operator_name": "remove_pad", "tir.noalias": True})
        p0 = T.int64()
        input = T.match_buffer(var_input, (p0,))
        i0 = T.int64()
        output = T.match_buffer(var_output, (i0,))
        # with T.block("root"):
        for ax0 in range(i0):
            with T.block("output"):
                v_ax0 = T.axis.spatial(i0, ax0)
                T.reads(input[v_ax0])
                T.writes(output[v_ax0])
                output[v_ax0] = input[v_ax0]

    @R.function
    def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((16,), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda i: (i % 16,)), pad_value=None, axis_separators=[], input_axis_separators=[])
            lv1 = R.call_tir(cls.relax_relu_replacement, (lv,), out_sinfo=R.Tensor((16,), dtype="float32"))
            lv4 = R.call_tir(cls.relax_relu_replacement, (lv1,), out_sinfo=R.Tensor((16,), dtype="float32"))
            lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(lv4, index_map=T.index_map(lambda axis0: (axis0,)), pad_value=None, axis_separators=[], input_axis_separators=[])
            gv = R.call_tir(cls.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32"))
            R.output(gv)
        return gv