mm-relu Example#

$$
\begin{cases}
x_i = \begin{bmatrix} x_i^{(1)} & x_i^{(2)} & \cdots & x_i^{(m)} \end{bmatrix}^T \\
y_j = \begin{bmatrix} y_j^{(1)} & y_j^{(2)} & \cdots & y_j^{(p)} \end{bmatrix}^T
\end{cases}
$$

$$
\begin{cases}
\langle x_i, y_j \rangle = x_i^T y_j \\
X = \begin{bmatrix} x_1^T & x_2^T & \cdots & x_n^T \end{bmatrix}^T \\
Y = \begin{bmatrix} y_1^T & y_2^T & \cdots & y_n^T \end{bmatrix}^T
\end{cases}
$$

It follows that

$$
\langle X, Y \rangle = X^T Y = \left( x_i^T y_j \right)_{m \times p}
$$
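
As a quick sanity check of this identity (a minimal NumPy sketch; it assumes x_i and y_j denote the columns of X and Y, with illustrative sizes n, m, p):

import numpy as np

n, m, p = 5, 3, 4
X = np.random.rand(n, m)  # columns x_1, ..., x_m
Y = np.random.rand(n, p)  # columns y_1, ..., y_p
G = X.T @ Y               # (m, p) matrix of pairwise inner products
for i in range(m):
    for j in range(p):
        assert np.isclose(G[i, j], X[:, i] @ Y[:, j])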

The ReLU function is defined as:

$$
\operatorname{relu}(X) = \max(X, 0)
$$

import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T

The NumPy implementation is as follows:

dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# a @ b is equivalent to np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)

A low-level NumPy implementation:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    # allocate an intermediate array to store the matmul result
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0  # initialize the accumulator on the first reduction step
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

Verify numerical consistency:

c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

TensorIR Implementation#

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
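
As an aside, the three explicit axis bindings in a block can be written more compactly with T.axis.remap ("S" marks a spatial axis, "R" a reduction axis); the stochastic example later in this section uses this form:

vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # equivalent to the three separate T.axis.* bindings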

Create a schedule and transform it step by step:

sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])  # split j (128) into an outer loop of 32 and an inner loop of 4
sch.reorder(j0, k, j1)
block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)  # compute block C under j0, right after each tile of Y
sch.decompose_reduction(block_Y, k)  # separate the initialization of Y from the reduction update

Inspect the transformed module:

sch.mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 32):
            for j_1_init in range(4):
                with T.block("Y_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
                    T.reads()
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = T.float32(0.0)
            for k, j_1 in T.grid(128, 4):
                with T.block("Y_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(4):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

Build and run the module, then verify the result against the NumPy reference:

rt_lib = tvm.build(MyModule, target="llvm")
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
func_mm_relu = rt_lib["mm_relu"]
func_mm_relu(a_nd, b_nd, c_nd)

np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)
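
To quantify the effect of the transformation, both variants can be timed with time_evaluator (a sketch; lib_after and the f_timer_* names are illustrative, and the same API appears again during tuning below):

f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule: %.3f ms" % (f_timer_before(a_nd, b_nd, c_nd).mean * 1000))
lib_after = tvm.build(sch.mod, target="llvm")
f_timer_after = lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed MyModule: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))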

The schedule also records a trace of the transformations applied:

print(sch.trace)
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="Y", func_name="mm_relu")
  l1, l2, l3 = sch.get_loops(block=b0)
  l4, l5 = sch.split(loop=l2, factors=[None, 4], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l4, l3, l5)
  b6 = sch.get_block(name="C", func_name="mm_relu")
  sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
  b7 = sch.decompose_reduction(block=b0, loop=l3)
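
The printed trace is itself executable Python: pasting the apply_trace definition above into the session and replaying it on a fresh schedule reproduces the same transformed module (a minimal sketch):

sch2 = tvm.tir.Schedule(MyModule)
apply_trace(sch2)  # replay the recorded transformations
assert tvm.ir.structural_equal(sch.mod, sch2.mod)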

Stochastic Schedule Transformations#

Consider a simple model:

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Instead of fixing the tiling factors, the following schedule samples them:

def stochastic_schedule_mm(sch: tvm.tir.Schedule):
    block_C = sch.get_block("C", "main")
    i, j, k = sch.get_loops(block=block_C)
    # randomly sample j_factors such that j_factors[0] * j_factors[1] == 128
    j_factors = sch.sample_perfect_tile(loop=j, n=2)
    j_0, j_1 = sch.split(loop=j, factors=j_factors)
    sch.reorder(i, j_0, k, j_1)
    sch.decompose_reduction(block_C, k)
    return sch

sch = tvm.tir.Schedule(MyModule)
sch = stochastic_schedule_mm(sch)
sch.mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j_0 in T.grid(128, 64):
            for j_1_init in range(2):
                with T.block("C_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 2 + j_1_init)
                    T.reads()
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.float32(0.0)
            for k, j_1 in T.grid(128, 2):
                with T.block("C_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 2 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
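
Because sample_perfect_tile draws its tiling decision at random, rerunning the transformation can produce a different program each time (a small sketch):

for _ in range(3):
    sch = stochastic_schedule_mm(tvm.tir.Schedule(MyModule))
    print(sch.trace)  # the decision=[...] entry of sample_perfect_tile varies across runs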

Search over Stochastic Transformations#

def random_search(mod: tvm.IRModule, num_trials=5):
    best_result = None
    best_sch = None

    for i in range(num_trials):
        sch = stochastic_schedule_mm(tvm.tir.Schedule(mod))
        lib = tvm.build(sch.mod, target="llvm")
        f_timer_after = lib.time_evaluator("main", tvm.cpu())
        result = f_timer_after(a_nd, b_nd, c_nd).mean

        print("=====Attempt %d, time-cost: %.3f ms====" % (i, result * 1000))
        print(sch.trace)

        # keep track of the best result so far
        if best_result is None or result < best_result:
            best_result = result
            best_sch = sch

    return best_sch

sch = random_search(MyModule)
=====Attempt 0, time-cost: 0.379 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 1, time-cost: 0.363 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 2, time-cost: 2.229 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[128, 1])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 3, time-cost: 1.015 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[64, 2])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
=====Attempt 4, time-cost: 1.853 ms====
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[128, 1])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)

Stochastic transformations specify a search space of possible programs. The tune_tir API searches over this space and finds an optimal schedule transformation.

from tvm import meta_schedule as ms

database = ms.tune_tir(
    mod=MyModule,
    target="llvm --num-cores=1",
    max_trials_global=64,
    num_trials_per_iter=64,
    space=ms.space_generator.ScheduleFn(stochastic_schedule_mm),
    work_dir=".temp/tune_tmp",
)

sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1")
2025-04-10 13:50:25 [INFO] Logging directory: .temp/tune_tmp/logs
2025-04-10 13:50:42 [INFO] LocalBuilder: max_workers = 24
2025-04-10 13:50:42 [INFO] LocalRunner: max_workers = 1
2025-04-10 13:50:43 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |            N/A |          N/A |                   N/A |      0 |      
------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

2025-04-10 13:50:53 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 13:50:54 [INFO] [task_scheduler.cc:193] Sending 5 sample(s) to builder
2025-04-10 13:50:55 [INFO] [task_scheduler.cc:195] Sending 5 sample(s) to runner
2025-04-10 13:51:22 [DEBUG] XGB iter   0: tr-p-rmse: 0.327315	tr-a-peak@32: 1.000000	tr-rmse: 0.356233	tr-rmse: 0.356233
2025-04-10 13:51:22 [DEBUG] XGB iter  25: tr-p-rmse: 0.106004	tr-a-peak@32: 1.000000	tr-rmse: 0.062224	tr-rmse: 0.062224
2025-04-10 13:51:22 [DEBUG] XGB iter  50: tr-p-rmse: 0.101582	tr-a-peak@32: 1.000000	tr-rmse: 0.059773	tr-rmse: 0.059773
2025-04-10 13:51:22 [DEBUG] XGB stopped. Best iteration: [74] tr-p-rmse:0.10156	tr-a-peak@32:1.00000	tr-rmse:0.05977	tr-rmse:0.05977 
2025-04-10 13:51:22 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |      
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942

2025-04-10 13:51:25 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |        13.2337 |     316.9420 |              316.9420 |      5 |    Y 
------------------------------------------------------------------------------------------------------
Total trials: 5
Total latency (us): 316.942

Inspect the tuning result:

sch.trace.show()
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8])
  l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l1, l6, l3, l7)
  b8 = sch.decompose_reduction(block=b0, loop=l3)
  sch.enter_postproc()
lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
Time cost of MyModule after tuning: 0.372 ms

Leveraging the Default Auto-Scheduler#

Meta-Schedule ships with a built-in collection of generic stochastic transformations that applies to a broad range of TensorIR computations. This approach is also called auto-scheduling, because the search space is generated by the system. We can run it by removing the line space=ms.space_generator.ScheduleFn(stochastic_schedule_mm).

Under the hood, Meta-Schedule analyzes the data access and loop patterns of each TensorIR block and proposes stochastic transformations of the program. We will not discuss these generic transformations in this chapter; note that they too are just stochastic transformations combined with code analysis. The auto-scheduler can be augmented with the same mechanisms covered in the previous section.

database = ms.tune_tir(
    mod=MyModule,
    target="llvm --num-cores=1",
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir=".temp/tune_tmp",
)
sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1")
2025-04-10 14:11:05 [INFO] Logging directory: .temp/tune_tmp/logs
2025-04-10 14:11:05 [INFO] LocalBuilder: max_workers = 24
2025-04-10 14:11:06 [INFO] LocalRunner: max_workers = 1
2025-04-10 14:11:07 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |            N/A |          N/A |                   N/A |      0 |      
------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0

2025-04-10 14:11:07 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-04-10 14:11:08 [INFO] [task_scheduler.cc:193] Sending 64 sample(s) to builder
2025-04-10 14:11:15 [INFO] [task_scheduler.cc:195] Sending 64 sample(s) to runner
2025-04-10 14:11:30 [DEBUG] XGB iter   0: tr-p-rmse: 0.428030	tr-a-peak@32: 0.998691	tr-rmse: 0.277890	tr-rmse: 0.277890
2025-04-10 14:11:30 [DEBUG] XGB iter  25: tr-p-rmse: 0.049583	tr-a-peak@32: 1.000000	tr-rmse: 0.327711	tr-rmse: 0.327711
2025-04-10 14:11:30 [DEBUG] XGB stopped. Best iteration: [33] tr-p-rmse:0.04957	tr-a-peak@32:1.00000	tr-rmse:0.32773	tr-rmse:0.32773 
2025-04-10 14:11:30 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |       184.1318 |      22.7788 |               22.7788 |     64 |      
------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 22.7788

2025-04-10 14:11:30 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0

 ID | Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
------------------------------------------------------------------------------------------------------
  0 | main | 4194304 |      1 |       184.1318 |      22.7788 |               22.7788 |     64 |    Y 
------------------------------------------------------------------------------------------------------
Total trials: 64
Total latency (us): 22.7788

lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000))
Time cost of MyModule after tuning: 0.040 ms

The result is significantly faster than the original code. We can inspect the trace history and the final code. At a high level, the trace includes:

  • multi-level loop transformations

  • vectorization of intermediate computations

  • parallelization and loop unrolling

sch.trace.show()
# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="C", func_name="main")
  b1 = sch.get_block(name="root", func_name="main")
  sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
  l2, l3, l4 = sch.get_loops(block=b0)
  v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[2, 8, 2, 4])
  l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True, disable_predication=False)
  v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[4, 4, 1, 8])
  l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True, disable_predication=False)
  v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[64, 2])
  l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
  b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
  sch.reverse_compute_at(block=b25, loop=l18, preserve_unit_loops=True, index=-1)
  sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=16)
  sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64)
  v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)
  sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)
  sch.enter_postproc()
  b27 = sch.get_block(name="root", func_name="main")
  sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel")
  sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize")
  sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit")
  b28, b29 = sch.get_child_blocks(b27)
  l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28)
  l40 = sch.fuse(l30, l31, l32, preserve_unit_iters=True)
  sch.parallel(loop=l40)
  l41 = sch.fuse(l39, preserve_unit_iters=True)
  sch.vectorize(loop=l41)
  sch.annotate(block_or_loop=l40, ann_key="pragma_auto_unroll_max_step", ann_val=64)
  sch.annotate(block_or_loop=l40, ann_key="pragma_unroll_explicit", ann_val=1)
  l42, l43, l44, l45 = sch.get_loops(block=b29)
  l46 = sch.fuse(l45, preserve_unit_iters=True)
  sch.vectorize(loop=l46)
  b47 = sch.get_block(name="C", func_name="main")
  l48, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b47)
  b56 = sch.decompose_reduction(block=b47, loop=l50)
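
As a final sanity check (a short sketch following the verification pattern used earlier; note that this main computes only the matrix product, without ReLU):

lib = tvm.build(sch.mod, target="llvm")
c_nd = tvm.nd.empty((128, 128), dtype="float32")
lib["main"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(a_np @ b_np, c_nd.numpy(), rtol=1e-5)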