融合 Transpose 和 Matmul

融合 Transpose 和 Matmul

import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import numpy as np
# Input module: computes x @ transpose(w) with an explicit `R.permute_dims`
# followed by `R.matmul`. This is the pattern that the FuseTransposeMatmul
# pass applied below rewrites.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        w: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # Swap the two axes of w: (128, 256) -> (256, 128), so it can be
            # the right-hand operand of the matmul.
            wT = R.permute_dims(w, [1, 0])
            # (128, 256) @ (256, 128) -> (128, 128)
            o = R.matmul(x, wT)
            R.output(o)
        return o
after = tvm.ir.transform.Sequential(
    passes=[
        relax.transform.FuseTransposeMatmul(),
        # Only used here to remove unused primitive functions.
        relax.transform.FuseTIR(),
    ]
)(Before)
Before.show()
# ---------------------------------------------------------------------------
# Captured notebook output of `Before.show()` (printed TVMScript, not source).
# Compared with the input above, the printer annotates every binding with its
# inferred shape/dtype and prints matmul's default `out_dtype` as "void".
# NOTE(review): if this file is executed as a script, this defines a
# module-level class named `Module`.
# ---------------------------------------------------------------------------
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        with R.dataflow():
            wT: R.Tensor((256, 128), dtype="float32") = R.permute_dims(w, axes=[1, 0])
            o: R.Tensor((128, 128), dtype="float32") = R.matmul(x, wT, out_dtype="void")
            R.output(o)
        return o
after.show()
# ---------------------------------------------------------------------------
# Captured notebook output of `after.show()` (printed TVMScript, not source).
# The permute_dims + matmul pair in `main` has been replaced by a single
# call_tir to the TIR function NT_matmul. NT_matmul reads `w` as w[j, k]
# directly, so `main` no longer produces a transposed copy of `w`.
# NOTE(review): if this file is executed as a script, this re-binds `Module`.
# ---------------------------------------------------------------------------
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def NT_matmul(x: T.Buffer((T.int64(128), T.int64(256)), "float32"), w: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Computes NT_matmul[i, j] = sum over k of x[i, k] * w[j, k]
        # (zero-initialised, then accumulated along the reduction axis k).
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], w[v_i1, v_k])
                T.writes(NT_matmul[v_i0, v_i1])
                with T.init():
                    NT_matmul[v_i0, v_i1] = T.float32(0.0)
                NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * w[v_i1, v_k]

    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32"), w: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        with R.dataflow():
            # One fused kernel call takes the original (untransposed) w.
            gv = R.call_tir(cls.NT_matmul, (x, w), out_sinfo=R.Tensor((128, 128), dtype="float32"))
            R.output(gv)
        return gv

融合 Transpose 和常量 Matmul

# Constant right-hand weight of shape (128, 256), converted to float32 by
# relax.const. A seeded generator (instead of the unseeded global
# np.random state) keeps the example reproducible from run to run.
w = relax.const(
    np.random.default_rng(0).uniform(-1e-3, 1e-3, (128, 256)), "float32"
)

# Same transpose + matmul computation as the first example. Here the
# right-hand operand is the module-level constant `w` (created with
# relax.const above) rather than a function parameter, so `main` takes only x.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # `w` is not defined inside the function; it refers to the
            # module-level constant, which shows up as a metadata constant in
            # the printed IR.
            wT = R.permute_dims(w, [1, 0])
            # (128, 256) @ (256, 128) -> (128, 128)
            o = R.matmul(x, wT)
            R.output(o)
        return o
after = tvm.ir.transform.Sequential(
    passes=[
        relax.transform.FuseTransposeMatmul(),
        # Only used here to remove unused primitive functions.
        relax.transform.FuseTIR(),
    ]
)(Before)
Before.show()
# ---------------------------------------------------------------------------
# Captured notebook output of `Before.show()` for the constant-weight case
# (printed TVMScript, not source). `metadata["relax.expr.Constant"][0]` stands
# in for the constant `w`. The metadata section itself was omitted from this
# printout (see the last line), and `metadata` is not defined in the visible
# source, so this captured text is not runnable as-is.
# ---------------------------------------------------------------------------
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        with R.dataflow():
            wT: R.Tensor((256, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=[1, 0])
            o: R.Tensor((128, 128), dtype="float32") = R.matmul(x, wT, out_dtype="void")
            R.output(o)
        return o

# Metadata omitted. Use show_meta=True in script() method to show it.
after.show()
# ---------------------------------------------------------------------------
# Captured notebook output of `after.show()` for the constant-weight case
# (printed TVMScript, not source). As in the first example, the transpose is
# folded into NT_matmul, and the constant is passed to it untransposed.
# The TIR buffer for the constant is named `B` rather than `w`, presumably
# because the constant operand carries no variable name. `metadata` is not
# defined in the visible source, so this captured text is not runnable as-is.
# ---------------------------------------------------------------------------
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def NT_matmul(x: T.Buffer((T.int64(128), T.int64(256)), "float32"), B: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Computes NT_matmul[i, j] = sum over k of x[i, k] * B[j, k].
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], B[v_i1, v_k])
                T.writes(NT_matmul[v_i0, v_i1])
                with T.init():
                    NT_matmul[v_i0, v_i1] = T.float32(0.0)
                NT_matmul[v_i0, v_i1] = NT_matmul[v_i0, v_i1] + x[v_i0, v_k] * B[v_i1, v_k]

    @R.function
    def main(x: R.Tensor((128, 256), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        cls = Module
        with R.dataflow():
            # The constant weight is passed directly; no permute_dims remains.
            gv = R.call_tir(cls.NT_matmul, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((128, 128), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
在当前单元格或上一个单元格中执行代码时，Kernel 崩溃。

请查看单元格中的代码，以确定故障的可能原因。

单击[此处](https://aka.ms/vscodeJupyterKernelCrash)了解详细信息。

有关更多详细信息，请查看 Jupyter [log](command:jupyter.viewOutput)。