# 融合 Transpose 和 Matmul


In [1]:
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import numpy as np

In [2]:
# Relax module that multiplies `x` by the transpose of `w`; it is the input
# handed to the fusion passes in the following cell.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
        w: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # Swap the two axes of w: (128, 256) -> (256, 128).
            wT = R.permute_dims(w, [1, 0])
            # (128, 256) @ (256, 128) -> (128, 128), matching the return annotation.
            o = R.matmul(x, wT)
            R.output(o)
        return o

In [3]:
# Run the two passes in order on `Before`; the transformed module is bound to `after`.
after = tvm.ir.transform.Sequential(
    [
        relax.transform.FuseTransposeMatmul(),
        relax.transform.FuseTIR(),  # Used here only to remove the unused primitive function
    ]
)(Before)

In [4]:
# Display the module as written, i.e. the input to the pass pipeline.
Before.show()

In [5]:
# Display the module returned by the pass pipeline, for comparison with `Before`.
after.show()

## 融合 Transpose 和常量 Matmul

In [6]:
# Weight values sampled uniformly between -1e-3 and 1e-3, wrapped as a
# float32 Relax constant of shape (128, 256).
w = relax.const(
    np.random.uniform(low=-1e-3, high=1e-3, size=(128, 256)),
    "float32",
)

# Variant of the module where the weight is not a function parameter: `main`
# takes only `x`, and `w` refers to the constant defined just above this class.
@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((128, 256), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            # Swap the two axes of the (128, 256) constant w -> (256, 128).
            wT = R.permute_dims(w, [1, 0])
            # (128, 256) @ (256, 128) -> (128, 128), matching the return annotation.
            o = R.matmul(x, wT)
            R.output(o)
        return o

In [7]:
# Run the same two passes in order on the constant-weight `Before`; the
# transformed module is bound to `after`.
after = tvm.ir.transform.Sequential(
    [
        relax.transform.FuseTransposeMatmul(),
        relax.transform.FuseTIR(),  # Used here only to remove the unused primitive function
    ]
)(Before)

In [8]:
# Display the constant-weight module as written, i.e. the input to the pass pipeline.
Before.show()

In [None]:
# Display the module returned by the pass pipeline, for comparison with `Before`.
after.show()

: 