## 融合 Transpose 和 Matmul
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import numpy as np
@I.ir_module
class Before:
    # Relax module whose `main` computes x @ w^T by first materializing the
    # transpose of `w` (permute_dims) and then calling matmul — exactly the
    # pattern that FuseTransposeMatmul is expected to rewrite into a single
    # NT_matmul TIR kernel.
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        w: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # Explicit transpose: (128, 256) -> (256, 128).
            wT = R.permute_dims(w, [1, 0])
            # (128, 256) @ (256, 128) -> (128, 128).
            o = R.matmul(x, wT)
            R.output(o)
        return o
# Apply the fusion pass, then FuseTIR so the module shown below contains
# only the functions that survive the rewrite.
after = tvm.ir.transform.Sequential(
    [
        relax.transform.FuseTransposeMatmul(),
        relax.transform.FuseTIR(),  # only used to remove unused primitive functions
    ]
)(Before)
Before.show()
# --- Printed output of Before.show(): the un-fused module still contains the
# --- explicit permute_dims followed by matmul.
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        with R.dataflow():
            wT: R.Tensor((256, 128), dtype="float32") = R.permute_dims(w, axes=[1, 0])
            o: R.Tensor((128, 128), dtype="float32") = R.matmul(x, wT, out_dtype="void")
            R.output(o)
        return o
# Print the transformed module (output transcript follows).
after.show()
# --- Printed output of after.show(): transpose + matmul have been fused into a
# --- single NT_matmul prim_func that reads `w` row-major (w[v_i1, v_k]),
# --- i.e. a "matmul with transposed B" kernel — no separate transpose remains.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def NT_matmul(x: T.Buffer((T.int64(128), T.int64(256)), "float32"), w: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], w[v_i1, v_k])
                T.writes(NT_matmul[v_i0, v_i1])
                with T.init():
                    NT_matmul[v_i0, v_i1] = T.float32(0.0)
                NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k]

    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32"))
            R.output(gv)
        return gv
## 融合 Transpose 和常量 Matmul
# Same pattern as above, but the weight is a relax constant captured by the
# function (it ends up in the module metadata) instead of a parameter.
w = relax.const(np.random.uniform(-1e-3, 1e-3, (128, 256)), "float32")


@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # `w` is the module-level relax constant defined above.
            wT = R.permute_dims(w, [1, 0])
            o = R.matmul(x, wT)
            R.output(o)
        return o
# Same pipeline as before; the fusion must also handle the constant operand.
after = tvm.ir.transform.Sequential(
    [
        relax.transform.FuseTransposeMatmul(),
        relax.transform.FuseTIR(),  # only used to remove unused primitive functions
    ]
)(Before)
Before.show()
# --- Printed output of Before.show(): the weight now appears as
# --- metadata["relax.expr.Constant"][0] rather than a parameter.
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        with R.dataflow():
            wT: R.Tensor((256, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=[1, 0])
            o: R.Tensor((128, 128), dtype="float32") = R.matmul(x, wT, out_dtype="void")
            R.output(o)
        return o
# Metadata omitted. Use show_meta=True in script() method to show it.
# Print the transformed constant-weight module (output transcript follows).
after.show()
# --- Printed output of after.show(): the constant weight is passed directly to
# --- the fused NT_matmul kernel via call_tir; the standalone transpose is gone.
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def NT_matmul(x: T.Buffer((T.int64(128), T.int64(256)), "float32"), B: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], B[v_i1, v_k])
                T.writes(NT_matmul[v_i0, v_i1])
                with T.init():
                    NT_matmul[v_i0, v_i1] = T.float32(0.0)
                NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * B[v_i1, v_k]

    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.NT_matmul, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((128, 128), dtype="float32"))
            R.output(gv)
        return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。
请查看单元格中的代码,以确定故障的可能原因。
单击<a href='https://aka.ms/vscodeJupyterKernelCrash'>此处</a>了解详细信息。
有关更多详细信息,请查看 Jupyter <a href='command:jupyter.viewOutput'>log</a>。