# mm-relu 示例

记

$$
\begin{cases}
\mathbf{x}_i = \begin{bmatrix} x_i^{(1)} & x_i^{(2)} & \cdots & x_i^{(n)} \end{bmatrix}^T, & i = 1, 2, \cdots, m \\
\mathbf{y}_j = \begin{bmatrix} y_j^{(1)} & y_j^{(2)} & \cdots & y_j^{(n)} \end{bmatrix}^T, & j = 1, 2, \cdots, p
\end{cases}
$$

有


$$
\begin{cases}
\langle \mathbf{x}_i, \mathbf{y}_j \rangle = \mathbf{x}_i^T \mathbf{y}_j \\
\mathbf{X} = \begin{bmatrix} \mathbf{x}_1 & \mathbf{x}_2 & \cdots & \mathbf{x}_m \end{bmatrix} \in \mathbb{R}^{n \times m} \\
\mathbf{Y} = \begin{bmatrix} \mathbf{y}_1 & \mathbf{y}_2 & \cdots & \mathbf{y}_p \end{bmatrix} \in \mathbb{R}^{n \times p}
\end{cases}
$$

可以推出

$$
\langle \mathbf{X}, \mathbf{Y} \rangle = \mathbf{X}^T \mathbf{Y} = (\mathbf{x}_i^T \mathbf{y}_j)_{m \times p}
$$

ReLU 函数（逐元素作用）定义为：

$$
\operatorname{relu}(X) = \max(X, 0)
$$

In [1]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

使用 NumPy 实现如下：

In [2]:
# Random 128x128 float32 operands and the reference result relu(A @ B).
dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype=dtype)
b_np = np.random.rand(128, 128).astype(dtype=dtype)
# The `@` operator on ndarrays dispatches to np.matmul; np.maximum(..., 0) is the ReLU.
c_mm_relu = np.maximum(np.matmul(a_np, b_np), 0)

NumPy 底层实现（用显式循环逐元素计算，不使用向量化运算）：

In [3]:
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    """Compute ``C = relu(A @ B)`` with explicit scalar loops, writing into ``C``.

    This is a loop-level reference that mirrors the TensorIR ``mm_relu``
    program: a reduction into an intermediate buffer ``Y`` followed by an
    element-wise ReLU. Sizes are taken from the inputs (the notebook uses
    128x128); ``Y`` is accumulated in float32.

    Args:
        A: matrix of shape (n, K).
        B: matrix of shape (K, m).
        C: output matrix of shape (n, m), overwritten in place.

    Raises:
        ValueError: if an input is not 2-D or the shapes do not match.
    """
    n, k_extent = A.shape
    k_extent_b, m = B.shape
    if k_extent != k_extent_b or C.shape != (n, m):
        raise ValueError(
            f"incompatible shapes: A{A.shape} @ B{B.shape} -> C{C.shape}"
        )
    # Intermediate buffer for the matmul result. Zero-initialised so that a
    # degenerate K == 0 input still yields relu(0) = 0.
    Y = np.zeros((n, m), dtype="float32")
    for i in range(n):
        for j in range(m):
            for k in range(k_extent):
                # Reset at the first reduction step, mirroring T.init() in TensorIR.
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(n):
        for j in range(m):
            C[i, j] = max(Y[i, j], 0)

验证数值一致性

In [4]:
# Run the explicit-loop reference into a fresh buffer and compare it with the
# vectorised NumPy result computed earlier.
c_np = np.empty(shape=(128, 128), dtype=dtype)
lnumpy_mm_relu(A=a_np, B=b_np, C=c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

## TensorIR 实现

In [5]:
# TensorIR version of mm_relu: block "Y" computes Y = A @ B, block "C" computes C = relu(Y).
# Only `#` comments are used here because the body is parsed as TVMScript.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        # "global_symbol" is the name the compiled function is looked up by
        # (rt_lib["mm_relu"] later); "tir.noalias" marks the buffers as non-aliasing.
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        # Intermediate buffer for the matmul result (the Y array of lnumpy_mm_relu).
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                # vi, vj are spatial block axes; vk is the reduction axis.
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                # Reduction initialisation (the `if k == 0` reset in lnumpy_mm_relu).
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # Element-wise ReLU.
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

In [6]:
sch = tvm.tir.Schedule(MyModule)

# Matmul block "Y" and its loops: i, j (spatial) and k (reduction).
block_Y = sch.get_block(name="Y", func_name="mm_relu")
i, j, k = sch.get_loops(block=block_Y)

# Tile j with an inner extent of 4, then place the reduction loop k between the two tiles.
j0, j1 = sch.split(loop=j, factors=[None, 4])
sch.reorder(j0, k, j1)

# Move the ReLU block "C" under loop j0.
block_C = sch.get_block(name="C", func_name="mm_relu")
sch.reverse_compute_at(block=block_C, loop=j0)

# Separate the initialisation of Y's elements from the reduction update
# (left as the last expression so the notebook displays the returned BlockRV).
sch.decompose_reduction(block=block_Y, loop=k)

tir.BlockRV(0x55bcd40cc720)

In [7]:
# Show the IRModule produced by the schedule transformations above.
sch.mod.show()

In [8]:
# Build both the hand-written TensorIR module and the scheduled module (sch.mod)
# for CPU, and check that each reproduces the NumPy reference. Previously only
# the untransformed MyModule was built, so the schedule was never verified.
rt_lib = tvm.build(MyModule, target="llvm")
rt_lib_after = tvm.build(sch.mod, target="llvm")
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
func_mm_relu = rt_lib["mm_relu"]
for func in (func_mm_relu, rt_lib_after["mm_relu"]):
    # Fresh output buffer per run, so a result left over from the previous run
    # cannot make a broken build pass the check.
    c_nd = tvm.nd.empty((128, 128), dtype="float32")
    func(a_nd, b_nd, c_nd)
    np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

In [9]:
# Print the schedule primitives applied to `sch`, rendered as a replayable apply_trace function.
print(sch.trace)

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="Y", func_name="mm_relu")
  l1, l2, l3 = sch.get_loops(block=b0)
  l4, l5 = sch.split(loop=l2, factors=[None, 4], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l4, l3, l5)
  b6 = sch.get_block(name="C", func_name="mm_relu")
  sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
  b7 = sch.decompose_reduction(block=b0, loop=l3)


## 随机调度变换 (Stochastic Schedule Transformation)

考虑一个简单的模型：128×128 的矩阵乘法 $C = AB$（不含 ReLU）：

In [11]:
# Plain 128x128x128 matmul C = A @ B in TensorIR, used for the stochastic-schedule experiments.
# NOTE: this rebinds the name MyModule, replacing the mm_relu module defined earlier.
# Only `#` comments are used here because the body is parsed as TVMScript.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Reduction initialisation of each output element.
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

In [12]:
def stochastic_schedule_mm(sch: tvm.tir.Schedule):
    """Apply a randomly tiled loop schedule to block "C" of function ``main``.

    Loop j is split in two with factors drawn by ``sample_perfect_tile``, the
    loops are reordered to (i, j_outer, k, j_inner), and the reduction's
    initialisation is split from its update. ``sch`` is modified in place
    and also returned, so the function can serve as a Meta-Schedule
    ``ScheduleFn`` design space.
    """
    block = sch.get_block(name="C", func_name="main")
    loop_i, loop_j, loop_k = sch.get_loops(block=block)
    # Random perfect 2-way tiling of j; the sampled factors multiply to j's extent.
    tile_factors = sch.sample_perfect_tile(loop=loop_j, n=2)
    j_outer, j_inner = sch.split(loop=loop_j, factors=tile_factors)
    sch.reorder(loop_i, j_outer, loop_k, j_inner)
    sch.decompose_reduction(block=block, loop=loop_k)
    return sch

# Draw one random schedule of the matmul and display the resulting module.
sch = stochastic_schedule_mm(tvm.tir.Schedule(MyModule))
sch.mod.show()

## 随机变换搜索

In [13]:
def random_search(mod: tvm.IRModule, num_trials=5):
    """Try ``num_trials`` random schedules of ``mod`` and return the fastest.

    Each trial applies ``stochastic_schedule_mm`` to a fresh schedule of
    ``mod``, builds it for the ``llvm`` target and times its ``main``
    function on the notebook-global arrays ``a_nd``, ``b_nd`` and ``c_nd``.
    The time and trace of every trial are printed.

    Returns the schedule with the lowest mean time, or ``None`` when
    ``num_trials`` is 0.
    """
    best_time = None
    best_schedule = None

    for trial in range(num_trials):
        candidate = stochastic_schedule_mm(tvm.tir.Schedule(mod))
        built = tvm.build(candidate.mod, target="llvm")
        timer = built.time_evaluator("main", tvm.cpu())
        elapsed = timer(a_nd, b_nd, c_nd).mean

        print("=====Attempt %d, time-cost: %.3f ms====" % (trial, elapsed * 1000))
        print(candidate.trace)

        # Keep the fastest schedule seen so far; the first trial always qualifies.
        if best_time is not None and not elapsed < best_time:
            continue
        best_time, best_schedule = elapsed, candidate

    return best_schedule

# Five random trials (the default); each one is printed and the fastest schedule is kept.
sch = random_search(mod=MyModule, num_trials=5)

=====Attempt 0, time-cost: 0.379 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 1, time-cost: 0.363 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 2, time-cost: 2.229 ms====
# from tvm import tir
def 

用随机变换定义候选程序的搜索空间，再使用 `tune_tir` API 在该搜索空间内进行搜索，找出最优的调度变换。

In [15]:
from tvm import meta_schedule as ms

# Tune with Meta-Schedule, restricting the search space to the schedules that
# stochastic_schedule_mm can produce. Logs and records go under work_dir.
database = ms.tune_tir(
    mod=MyModule,
    target="llvm --num-cores=1",
    space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir=".temp/tune_tmp",
)

# Turn the tuning database back into a Schedule for MyModule.
sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1")

2025-04-10 13:50:25 [INFO] Logging directory: .temp/tune_tmp/logs
2025-04-10 13:50:42 [INFO] LocalBuilder: max_workers = 24
2025-04-10 13:50:42 [INFO] LocalRunner: max_workers = 1
2025-04-10 13:50:43 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,,,,0,



Total trials: 0
Total latency (us): 0

2025-04-10 13:50:53 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |            N/A |          N/A |                   N/A |      0 |      
------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

2025-04-10 13:50:53 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:50:54 [INFO] [task_scheduler.cc:193] Sending 5 sample(s) to builder
2025-04-10 13:50:55 [INFO] [task_scheduler.cc:195] Sending 5 sample(s) to runner
2025-04-10 13:51:22 [DEBUG] XGB iter   0: tr-p-rmse: 0.327315	tr-a-peak@32: 1.000000	tr-rmse: 0.356233	tr-rmse: 0.356233
2025-04-10 13:51:22 [DEBUG] XGB iter  25: tr-p-rmse: 0.106004	tr-a-peak@32: 1.000000	

Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,


2025-04-10 13:51:22 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942


Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:22 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:193] Sending 0 sample(s) to builder
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:195] Sending 0 sample(s) to runner
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,


2025-04-10 13:51:23 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942


Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:23 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:193] Sending 0 sample(s) to builder
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:195] Sending 0 sample(s) to runner
2025-04-10 13:51:23 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,


2025-04-10 13:51:23 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942


Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:23 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:51:24 [INFO] [task_scheduler.cc:193] Sending 0 sample(s) to builder
2025-04-10 13:51:24 [INFO] [task_scheduler.cc:195] Sending 0 sample(s) to runner
2025-04-10 13:51:24 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,



Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:24 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:24 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:51:25 [INFO] [task_scheduler.cc:193] Sending 0 sample(s) to builder
2025-04-10 13:51:25 [INFO] [task_scheduler.cc:195] Sending 0 sample(s) to runner
2025-04-10 13:51:25 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,


2025-04-10 13:51:25 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942


Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:25 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:51:25 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,13.2337,316.942,316.942,5,Y


2025-04-10 13:51:25 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |    Y 
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942


Total trials: 5
Total latency (us): 316.942



查看调优结果:

In [17]:
# Show the trace of the schedule returned by compile_tir for the custom search space.
sch.trace.show()

In [18]:
# Build the tuned schedule and time its `main` function.
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
# The timed runs wrote their output into c_nd; make sure the tuned program is
# still correct (`main` computes C = A @ B, without ReLU).
np.testing.assert_allclose(a_np @ b_np, c_nd.numpy(), rtol=1e-5)

Time cost of MyModule after tuning: 0.372 ms


## 利用默认的自动调度

Meta-Schedule 带有内置通用随机变换集合，能够适用于广泛的 TensorIR 计算。这种方法也称为自动调度 (auto-scheduling)，因为搜索空间是由系统生成的。可以通过删除行 `space=ms.space_generator.ScheduleFn(stochastic_schedule_mm)` 来运行它。

在底层，Meta-Schedule 分析每个 TensorIR block 的数据访问和循环模式，并提出对程序的随机变换方式。我们不会在本章中讨论这些通用的变换，但要注意它们也只是随机转换加上代码分析而已。可以使用上一节中学到的相同机制来增强自动调度。

In [20]:
# Same tuning call as before but without `space=`, so Meta-Schedule generates
# the search space itself (auto-scheduling).
# A separate work_dir is used because tune_tir keeps its tuning database there.
# Sharing ".temp/tune_tmp" with the earlier run would mix that run's records
# into this database, and compile_tir could then return a schedule from the
# hand-written search space instead of an auto-scheduled one.
database = ms.tune_tir(
    mod=MyModule,
    target="llvm --num-cores=1",
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir=".temp/tune_tmp_auto",
)
sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1")

2025-04-10 14:11:05 [INFO] Logging directory: .temp/tune_tmp/logs
2025-04-10 14:11:05 [INFO] LocalBuilder: max_workers = 24
2025-04-10 14:11:06 [INFO] LocalRunner: max_workers = 1
2025-04-10 14:11:07 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,,,,0,



Total trials: 0
Total latency (us): 0

2025-04-10 14:11:07 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |            N/A |          N/A |                   N/A |      0 |      
------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

2025-04-10 14:11:07 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 14:11:08 [INFO] [task_scheduler.cc:193] Sending 64 sample(s) to builder
2025-04-10 14:11:15 [INFO] [task_scheduler.cc:195] Sending 64 sample(s) to runner
2025-04-10 14:11:30 [DEBUG] XGB iter   0: tr-p-rmse: 0.428030	tr-a-peak@32: 0.998691	tr-rmse: 0.277890	tr-rmse: 0.277890
2025-04-10 14:11:30 [DEBUG] XGB iter  25: tr-p-rmse: 0.049583	tr-a-peak@32: 1.00000

Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,184.1318,22.7788,22.7788,64,



Total trials: 64
Total latency (us): 22.7788

2025-04-10 14:11:30 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |       184.1318 |      22.7788 |               22.7788 |     64 |      
------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 22.7788

2025-04-10 14:11:30 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,4194304,1,184.1318,22.7788,22.7788,64,Y



Total trials: 64
Total latency (us): 22.7788

2025-04-10 14:11:30 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |       184.1318 |      22.7788 |               22.7788 |     64 |    Y 
------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 22.7788



In [21]:
# Build the auto-scheduled program and time its `main` function.
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
# The timed runs wrote their output into c_nd; make sure the tuned program is
# still correct (`main` computes C = A @ B, without ReLU).
np.testing.assert_allclose(a_np @ b_np, c_nd.numpy(), rtol=1e-5)

Time cost of MyModule after tuning: 0.040 ms


结果比使用手写搜索空间调优得到的程序快得多（0.372 ms → 0.040 ms）。可以查看调度轨迹（trace）和最终代码。从高层次看，该轨迹包含：
- 更多级的循环变换
- 中间计算的矢量化
- 并行化和循环展开

In [23]:
# Show the trace of the schedule returned by compile_tir for the auto-scheduling run.
sch.trace.show()