FoldConstant#
参考:tvm/tests/python/relax/test_transform_fold_constant.py
import tvm
import tvm.testing
from tvm import relax
import numpy as np
import tvm.script
from tvm.script import ir as I, tir as T, relax as R
def gen_mod(mod, name, binding):
    """Select one Relax function from *mod*, rename it to ``main``, and bind
    constant parameters into it.

    Args:
        mod: Source IRModule containing one or more functions.
        name: Name of the Relax function to keep (it is renamed ``main``).
        binding: Mapping of parameter name -> constant value; values may be
            numpy arrays or already-converted ``tvm.nd.NDArray`` objects.

    Returns:
        A new IRModule whose ``main`` is the selected function with the given
        parameters bound as constants (so FoldConstant can fold them).

    Raises:
        ValueError: If *mod* contains no Relax function called *name*.
    """
    # Accept both numpy arrays and pre-converted NDArrays; the original code
    # unconditionally re-wrapped with tvm.nd.array, which breaks when a
    # caller already passes an NDArray.
    binding = {
        k: v if isinstance(v, tvm.nd.NDArray) else tvm.nd.array(v)
        for k, v in binding.items()
    }
    funcs = {}  # functions to keep in the rebuilt module
    found = False
    for k, v in mod.functions.items():
        if isinstance(v, tvm.relax.Function):
            # Keep only the requested Relax function, renamed to ``main``.
            if k.name_hint == name:
                found = True
                gv = tvm.ir.GlobalVar("main")
                funcs[gv] = tvm.relax.Function(
                    v.params, v.body, v.ret_struct_info
                ).with_attr("global_symbol", "main")
        else:
            # Non-Relax functions (e.g. TIR prim_funcs) are kept unchanged.
            funcs[k] = v
    if not found:
        # Fail loudly; otherwise BindParams("main", ...) below would raise a
        # much less informative error about a missing function.
        raise ValueError(f"no Relax function named {name!r} in module")
    mod = tvm.IRModule(funcs)
    # Bind the constants into ``main`` so downstream passes see them as
    # relax constants rather than parameters.
    return relax.transform.BindParams("main", binding)(mod)
折叠 +1 常量#
@tvm.script.ir_module
class Module:
    # TIR kernel: element-wise add-one over a 16x16 float32 buffer.
    @T.prim_func
    def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                # B = A + 1, element-wise
                B[vi, vj] = A[vi, vj] + T.float32(1)
    # Relax wrapper: a single call_tir on c0; once c0 is bound as a constant,
    # FoldConstant can evaluate the call at compile time.
    @R.function
    def before(c0: R.Tensor((16, 16), "float32")):
        cls = Module
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32"))
        return lv0
# Input constant and the value the fold is expected to produce (c0 + 1).
c0_np = np.arange(256, dtype="float32").reshape(16, 16)
c1_np = c0_np + 1
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("addone"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] + T.float32(1.0)
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
cls = Module
lv0 = R.call_tir(cls.addone, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
return lv0
# Metadata omitted. Use show_meta=True in script() method to show it.
# Run constant folding: the call_tir on a constant collapses to a constant.
fold_pass = relax.transform.FoldConstant()
after = fold_pass(before)
after.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("addone"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] + T.float32(1.0)
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
折叠常量的转置#
@tvm.script.ir_module
class Module:
    # TIR kernel: transpose a (2, 3) float32 matrix into a (3, 2) buffer.
    @T.prim_func
    def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")) -> None:
        for i, j in T.grid(3, 2):
            with T.block("transpose"):
                vi, vj = T.axis.remap("SS", [i, j])
                # Swapped indices: B = A^T
                B[vi, vj] = A[vj, vi]
    # Relax wrapper: transpose of the constant c0; foldable once c0 is bound.
    @R.function
    def before(c0: R.Tensor((2, 3), "float32")):
        cls = Module
        lv0 = relax.call_tir(cls.func, (c0,), R.Tensor((3, 2), dtype="float32"))
        return lv0
# Input constant; the fold should produce its transpose (c0_np.T).
c0_np = np.arange(6, dtype="float32").reshape(2, 3)
c1_np = c0_np.T
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")):
# with T.block("root"):
for i, j in T.grid(3, 2):
with T.block("transpose"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vj, vi])
T.writes(B[vi, vj])
B[vi, vj] = A[vj, vi]
@R.function
def main() -> R.Tensor((3, 2), dtype="float32"):
cls = Module
lv0 = R.call_tir(cls.func, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((3, 2), dtype="float32"))
return lv0
# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")):
# with T.block("root"):
for i, j in T.grid(3, 2):
with T.block("transpose"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vj, vi])
T.writes(B[vi, vj])
B[vi, vj] = A[vj, vi]
@R.function
def main() -> R.Tensor((3, 2), dtype="float32"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
two_hop_addone#
@tvm.script.ir_module
class Module:
    # TIR kernel: element-wise add-one over a 2x2 float32 buffer.
    @T.prim_func
    def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")) -> None:
        for i, j in T.grid(2, 2):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)
    # Relax wrapper: two chained call_tir ops; folding the first yields a
    # constant that makes the second foldable too (transitive folding).
    @R.function
    def before(c0: R.Tensor((2, 2), "float32")):
        cls = Module
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((2, 2), dtype="float32"))
        lv1 = relax.call_tir(cls.addone, (lv0,), R.Tensor((2, 2), dtype="float32"))
        return lv1
# Reference values for the two folding hops: c1 = c0 + 1, c2 = c1 + 1.
c0_np = np.arange(4, dtype="float32").reshape(2, 2)
c1_np = c0_np + 1
c2_np = c1_np + 1
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")):
# with T.block("root"):
for i, j in T.grid(2, 2):
with T.block("addone"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] + T.float32(1.0)
@R.function
def main() -> R.Tensor((2, 2), dtype="float32"):
cls = Module
lv0 = R.call_tir(cls.addone, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((2, 2), dtype="float32"))
lv1 = R.call_tir(cls.addone, (lv0,), out_sinfo=R.Tensor((2, 2), dtype="float32"))
return lv1
# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")):
# with T.block("root"):
for i, j in T.grid(2, 2):
with T.block("addone"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] + T.float32(1.0)
@R.function
def main() -> R.Tensor((2, 2), dtype="float32"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
折叠 dataflow#
@tvm.script.ir_module
class Module:
    # TIR kernel: identity copy of a 16x16 float32 buffer.
    @T.prim_func
    def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
        for i, j in T.grid(16, 16):
            with T.block("identity"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]
    # Relax wrapper: the call_tir sits inside a dataflow block, showing that
    # FoldConstant also folds bindings inside R.dataflow().
    @R.function
    def before(c0: R.Tensor((16, 16), "float32")):
        cls = Module
        with R.dataflow():
            gv0 = relax.call_tir(cls.identity, (c0,), R.Tensor((16, 16), dtype="float32"))
            R.output(gv0)
        return gv0
# Identity kernel: folding should reproduce c0_np unchanged.
c0_np = np.arange(256, dtype="float32").reshape(16, 16)
c1_np = c0_np
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("identity"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj]
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
cls = Module
with R.dataflow():
gv0 = R.call_tir(cls.identity, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
R.output(gv0)
return gv0
# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("identity"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj]
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
fold_mixed_case#
@tvm.script.ir_module
class Module:
    # TIR function can handle different cases.
    # Shape-generic add-one: buffer extents (n, m) are matched at call time.
    @T.prim_func
    def addone(a: T.handle, b: T.handle) -> None:
        n = T.int32()
        m = T.int32()
        A = T.match_buffer(a, (n, m))
        B = T.match_buffer(b, (n, m))
        for i, j in T.grid(n, m):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)
    # Fixed-shape element-wise subtraction: C = A - B.
    @T.prim_func
    def sub(
        A: T.Buffer((16, 16), "float32"),
        B: T.Buffer((16, 16), "float32"),
        C: T.Buffer((16, 16), "float32"),
    ) -> None:
        for i, j in T.grid(16, 16):
            with T.block("sub"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = A[vi, vj] - B[vi, vj]
    # Mix of foldable and non-foldable calls — see the per-line comments.
    @R.function
    def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)):
        n, m = T.int64(), T.int64()
        cls = Module
        x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
        # this line cannot be folded because n is unknown
        lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
        # this line can be folded
        lv1 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32"))
        # this line can be folded because all inputs are const
        lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
        # this line can not be folded because x's shape is unknown
        lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
        return (lv0, lv3)
c0_np = np.arange(16 * 16).astype("float32").reshape(16, 16)
c1_np = c0_np + 1        # expected fold of lv1
c2_np = c0_np - c1_np    # expected fold of lv2
# BUG FIX: the original cell never called gen_mod here, so ``before`` was
# still the module left over from the previous (dataflow) section — the
# displayed output therefore showed the wrong module entirely.
before = gen_mod(Module, "before", {"c0": c0_np})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("identity"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj]
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
cls = Module
with R.dataflow():
gv0 = R.call_tir(cls.identity, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((16, 16), dtype="float32"))
R.output(gv0)
return gv0
# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
# with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("identity"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj]
@R.function
def main() -> R.Tensor((16, 16), dtype="float32"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
fold_shape_computation#
@I.ir_module
class Module:
    # Pure Relax example: a chain of shape-derived tensor ops
    # (shape_to_tensor -> take -> expand_dims -> concat) that FoldConstant
    # can evaluate completely once ``indices`` is bound as a constant.
    @R.function
    def before(
        data: R.Tensor((5, 4, 3, 2), dtype="float32"),
        indices: R.Tensor((1,), dtype="int64"),
    ) -> R.Tensor((1, 1), dtype="int64"):
        with R.dataflow():
            lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(R.shape([5, 4, 3, 2]))
            lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices, axis=0)
            lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0])
            gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
            R.output(gv)
        return gv
# Pass a plain numpy array, consistent with every other section: gen_mod
# itself converts bindings via tvm.nd.array, so the original call wrapped
# an already-converted NDArray a second time.
before = gen_mod(Module, "before", {"indices": np.array([0], dtype="int64")})
before.show()
after = relax.transform.FoldConstant()(before)
after.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(data: R.Tensor((5, 4, 3, 2), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"):
with R.dataflow():
lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(R.shape([5, 4, 3, 2]))
lv1: R.Tensor((1,), dtype="int64") = R.take(lv, metadata["relax.expr.Constant"][0], axis=0)
lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0])
gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(data: R.Tensor((5, 4, 3, 2), dtype="float32")) -> R.Tensor((1, 1), dtype="int64"):
return metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.
在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。
请查看单元格中的代码,以确定故障的可能原因。
单击<a href='https://aka.ms/vscodeJupyterKernelCrash'>此处</a>了解详细信息。
有关更多详细信息,请查看 Jupyter <a href='command:jupyter.viewOutput'>log</a>。