# FoldConstant

参考:tvm/tests/python/relax/test_transform_fold_constant.py

import tvm
import tvm.testing
from tvm import relax
import numpy as np

import tvm.script
from tvm.script import ir as I, tir as T, relax as R
def gen_mod(mod, name, binding):
    """Select one Relax function, rename it to ``main`` and bind constant parameters.

    The returned module keeps:
    1. the Relax function called ``name``, renamed to ``main`` (its
       ``global_symbol`` attribute is set to ``"main"``);
    2. every non-Relax function of ``mod`` (e.g. TIR PrimFuncs), unchanged.
    All other Relax functions are dropped. The parameters listed in ``binding``
    are then replaced by constants with ``relax.transform.BindParams``.

    Args:
        mod: Source IRModule, possibly containing several functions.
        name: Name of the Relax function to keep and rename to ``main``.
        binding: ``{parameter name: array}``. Each value is passed through
            ``tvm.nd.array`` before binding.

    Returns:
        The new IRModule produced by ``BindParams``.

    Raises:
        ValueError: If ``mod`` has no Relax function named ``name``.
    """
    # Convert every binding value to a TVM NDArray.
    binding = {param: tvm.nd.array(value) for param, value in binding.items()}

    funcs = {}
    found = False
    for gvar, func in mod.functions.items():
        if not isinstance(func, tvm.relax.Function):
            # Keep non-Relax functions under their original GlobalVar, which is
            # what `call_tir` inside the selected function refers to.
            funcs[gvar] = func
        elif gvar.name_hint == name:
            # Rebuild the selected function under a fresh "main" GlobalVar. Only
            # params, body and return struct info are carried over; the original
            # attributes (including the old global_symbol) are not.
            funcs[tvm.ir.GlobalVar("main")] = tvm.relax.Function(
                func.params, func.body, func.ret_struct_info
            ).with_attr("global_symbol", "main")
            found = True
        # Any other Relax function is intentionally dropped.

    if not found:
        # Fail early with a clear message rather than handing BindParams a
        # module that has no "main" function.
        raise ValueError(f"no Relax function named {name!r} in the given module")

    return relax.transform.BindParams("main", binding)(tvm.IRModule(funcs))

# 折叠 +1 常量

# Module for the "fold +1" case: `before` applies the `addone` PrimFunc to its
# only parameter c0. Once gen_mod binds c0 to a constant, the call_tir has only
# constant inputs and FoldConstant can evaluate it.
@tvm.script.ir_module
class Module:
    # TIR kernel: B = A + 1, elementwise over 16x16 float32 buffers.
    @T.prim_func
    def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                # Both loop axes are remapped as spatial ("SS") block axes.
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)

    # Relax entry point: one call_tir of addone on c0, producing a 16x16 float32 tensor.
    @R.function
    def before(c0: R.Tensor((16, 16), "float32")):
        # `cls` names the module being defined, so `cls.addone` is the PrimFunc above.
        cls = Module
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32"))
        return lv0

# Constant bound to parameter c0, plus the value FoldConstant should produce
# (addone adds 1 to every element).
c0_np = np.arange(256, dtype="float32").reshape((16, 16))
c1_np = np.add(c0_np, 1)
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
Hide code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + T.float32(1.0)

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        cls = Module
        lv0 = R.call_tir(cls.addone, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        return lv0

# Metadata omitted. Use show_meta=True in script() method to show it.
after = relax.transform.FoldConstant()(before)
after.show()
# With c0 bound to a constant the call_tir is fully evaluated, so main's body
# should be a single constant; check it equals the expected c1_np (= c0 + 1).
np.testing.assert_allclose(after["main"].body.body.data.numpy(), c1_np)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + T.float32(1.0)

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.

# 折叠常量的转置

# Module for the "fold transpose of a constant" case.
@tvm.script.ir_module
class Module:
    # TIR kernel: writes the transpose of the 2x3 buffer A into the 3x2 buffer B.
    @T.prim_func
    def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")) -> None:
        # Iterate over the output shape (3, 2).
        for i, j in T.grid(3, 2):
            with T.block("transpose"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vj, vi]

    # Relax entry point: transposes c0 (2x3) via call_tir, producing a 3x2 tensor.
    @R.function
    def before(c0: R.Tensor((2, 3), "float32")):
        # `cls` names the module being defined, so `cls.func` is the PrimFunc above.
        cls = Module
        lv0 = relax.call_tir(cls.func, (c0,), R.Tensor((3, 2), dtype="float32"))
        return lv0
c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
# Expected folded result: the transpose of c0.
c1_np = c0_np.T
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
# With c0 bound to a constant the call_tir is fully evaluated, so main's body
# should be a single constant equal to c1_np.
np.testing.assert_allclose(after["main"].body.body.data.numpy(), c1_np)
Hide code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")):
        # with T.block("root"):
        for i, j in T.grid(3, 2):
            with T.block("transpose"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vj, vi])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vj, vi]

    @R.function
    def main() -> R.Tensor((3, 2), dtype="float32"):
        cls = Module
        lv0 = R.call_tir(cls.func, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((3, 2), dtype="float32"))
        return lv0

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")):
        # with T.block("root"):
        for i, j in T.grid(3, 2):
            with T.block("transpose"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vj, vi])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vj, vi]

    @R.function
    def main() -> R.Tensor((3, 2), dtype="float32"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.

# two_hop_addone

# Module for the two-hop case: `before` applies `addone` twice in a row, so the
# second call_tir consumes the result of the first.
@tvm.script.ir_module
class Module:
    # TIR kernel: B = A + 1, elementwise over 2x2 float32 buffers.
    @T.prim_func
    def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")) -> None:
        for i, j in T.grid(2, 2):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)

    # Relax entry point: c0 -> addone -> addone. Once c0 is constant, folding lv0
    # makes lv1's input constant as well, so the chain can be folded end to end.
    @R.function
    def before(c0: R.Tensor((2, 2), "float32")):
        # `cls` names the module being defined, so `cls.addone` is the PrimFunc above.
        cls = Module
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((2, 2), dtype="float32"))
        lv1 = relax.call_tir(cls.addone, (lv0,), R.Tensor((2, 2), dtype="float32"))
        return lv1
c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
c1_np = c0_np + 1  # value after the first addone
c2_np = c1_np + 1  # value after the second addone: the expected folded result
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
# Both call_tir bindings fold away, leaving a single constant equal to c2_np.
np.testing.assert_allclose(after["main"].body.body.data.numpy(), c2_np)
Hide code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")):
        # with T.block("root"):
        for i, j in T.grid(2, 2):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + T.float32(1.0)

    @R.function
    def main() -> R.Tensor((2, 2), dtype="float32"):
        cls = Module
        lv0 = R.call_tir(cls.addone, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((2, 2), dtype="float32"))
        lv1 = R.call_tir(cls.addone, (lv0,), out_sinfo=R.Tensor((2, 2), dtype="float32"))
        return lv1

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")):
        # with T.block("root"):
        for i, j in T.grid(2, 2):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + T.float32(1.0)

    @R.function
    def main() -> R.Tensor((2, 2), dtype="float32"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.

# 折叠 dataflow

# Module for the dataflow case: the call_tir to fold sits inside an R.dataflow() block.
@tvm.script.ir_module
class Module:
    # TIR kernel: copies the 16x16 float32 buffer A into B.
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

    # Relax entry point: identity(c0) computed inside a dataflow block; gv0 is
    # exported from the block with R.output so it can be returned.
    @R.function
    def before(c0: R.Tensor((16, 16), "float32")):
        # `cls` names the module being defined, so `cls.identity` is the PrimFunc above.
        cls = Module
        with R.dataflow():
            gv0 = relax.call_tir(cls.identity, (c0,), R.Tensor((16, 16), dtype="float32"))
            R.output(gv0)
        return gv0
c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
# Expected folded result: identity leaves the values unchanged.
c1_np = c0_np
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
# The dataflow block folds away, leaving a single constant equal to c1_np.
np.testing.assert_allclose(after["main"].body.body.data.numpy(), c1_np)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj]

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv0 = R.call_tir(cls.identity, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
            R.output(gv0)
        return gv0

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj]

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.

# fold_mixed_case

# Module for the mixed case: some call_tir bindings in `before` have only constant
# inputs and a static output shape (foldable); others depend on the symbolic
# dimension n or on the non-constant parameter x (not foldable).
@tvm.script.ir_module
class Module:
    # TIR function can handle different cases.
    # n and m are taken from the buffer shapes via match_buffer, so this kernel
    # is not tied to one static shape.
    @T.prim_func
    def addone(a: T.handle, b: T.handle) -> None:
        n = T.int32()
        m = T.int32()
        A = T.match_buffer(a, (n, m))
        B = T.match_buffer(b, (n, m))
        for i, j in T.grid(n, m):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)

    # TIR kernel: C = A - B, elementwise over 16x16 float32 buffers.
    @T.prim_func
    def sub(
        A: T.Buffer((16, 16), "float32"),
        B: T.Buffer((16, 16), "float32"),
        C: T.Buffer((16, 16), "float32"),
    ) -> None:
        for i, j in T.grid(16, 16):
            with T.block("sub"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] - B[vi, vj]

    # Relax entry point: c0 becomes a constant via gen_mod; x stays a runtime input.
    @R.function
    def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)):
        # Symbolic dimensions, bound by the match_cast on x below.
        n, m = T.int64(), T.int64()
        # `cls` names the module being defined, so `cls.addone` / `cls.sub` are the PrimFuncs above.
        cls = Module
        # Binds n and m (n is used in lv0's output shape); x0 itself is not used afterwards.
        x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
        # this line cannot be folded because n is unknown
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
        # this line can be folded
        lv1 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32"))
        # this line can be folded because all inputs are const
        lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
        # this line can not be folded because x's shape is unknown
        lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
        return (lv0, lv3)
c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
# Reference values for the foldable bindings: lv1 = addone(c0), lv2 = sub(c0, lv1).
c1_np = c0_np + 1
c2_np = c0_np - c1_np

# Bind c0 to a constant. Without this line, `before` would still refer to the
# module built for the previous (dataflow) case, and this cell would show and
# fold the wrong module.
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
Hide code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj]

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv0 = R.call_tir(cls.identity, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
            R.output(gv0)
        return gv0

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        # with T.block("root"):
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj]

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float32"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.

# fold_shape_computation

# Module for the shape-computation case: every value in `before` is derived from
# the literal shape [5, 4, 3, 2] and the `indices` parameter. Once indices is
# bound to a constant, the whole dataflow block has only constant inputs.
@I.ir_module
class Module:
    @R.function
    def before(
        data: R.Tensor((5, 4, 3, 2), dtype="float32"),
        indices: R.Tensor((1,), dtype="int64"),
    ) -> R.Tensor((1, 1), dtype="int64"):
        # `data` is not referenced in the body; the shape below is written out literally.
        with R.dataflow():
            # The shape [5, 4, 3, 2] as a 1-D int64 tensor of length 4.
            lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(R.shape([5, 4, 3, 2]))
            # Select element(s) of lv by `indices` along axis 0 -> shape (1,).
            lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices, axis=0)
            # Insert a leading axis: (1,) -> (1, 1).
            lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0])
            # Concatenation of a single tensor: same values, shape (1, 1).
            gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
            R.output(gv)
        return gv
# gen_mod already converts each binding value with tvm.nd.array, so a plain
# numpy array is enough here (consistent with the other cases).
before = gen_mod(Module, "before", {"indices": np.array([0]).astype("int64")})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
# shape_to_tensor([5, 4, 3, 2]) taken at index 0 gives [5]; expand_dims/concat
# turn it into [[5]], which should now be the single constant returned by main.
np.testing.assert_array_equal(
    after["main"].body.body.data.numpy(), np.array([[5]], dtype="int64")
)
Hide code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((5, 4, 3, 2), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"):
        with R.dataflow():
            lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(R.shape([5, 4, 3, 2]))
            lv1: R.Tensor((1,), dtype="int64") = R.take(lv, metadata["relax.expr.Constant"][0], axis=0)
            lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0])
            gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((5, 4, 3, 2), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"):
        return metadata["relax.expr.Constant"][0]

# Metadata omitted. Use show_meta=True in script() method to show it.
在当前单元格或上一个单元格中执行代码时 Kernel 崩溃

请查看单元格中的代码以确定故障的可能原因

单击<a href='https://aka.ms/vscodeJupyterKernelCrash'>此处</a>了解详细信息

有关更多详细信息请查看 Jupyter <a href='command:jupyter.viewOutput'>log</a>