Blocked Matrix Multiply#

Original author: Thierry Moreau

This tutorial provides an overview of how to use TVM to map matrix multiplication efficiently onto the VTA design. We recommend covering the Simple Matrix Multiply tutorial first.

In this tutorial, we demonstrate TVM schedule optimizations that break large neural network operators down into smaller blocks, so that the computation can be carried out within the accelerator's limited hardware resources.

RPC Setup#

We start by programming the Pynq's FPGA and building its RPC runtime.

import os
import tvm
from tvm import te
import vta
import numpy as np
from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator

# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()

# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":

    # Make sure that TVM was compiled with RPC=1
    assert tvm.runtime.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

# In simulation mode, host the RPC server locally.
elif env.TARGET in ["sim", "tsim"]:
    remote = rpc.LocalSession()

Computation Declaration#

As a first step, we need to describe our matrix multiplication computation. We define the matrix multiplication as the computation found in a fully connected layer, characterized by its batch size, input channels, and output channels. These must each be an integer multiple of the VTA tensor shape: BATCH, BLOCK_IN, and BLOCK_OUT.

We add extra operators to the matrix multiplication that shift and clip the output, to mimic a fixed-point matrix multiplication followed by a rectified linear activation. The TVM dataflow graph of the fully connected layer is shown below:

[Figure fc_dataflow.png: TVM dataflow graph of the fully connected layer]

This computation is intentionally too large to fit all at once in VTA's on-chip buffers. Therefore, during the scheduling phase we will rely on computation blocking strategies to break the computation down into manageable chunks.
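
To make the tiling concrete, here is a small worked example, assuming the default vta_config.json parameters (BATCH=1, BLOCK_IN=16, BLOCK_OUT=16); the actual values in your setup come from vta.get_env() and may differ:

# Hypothetical worked example assuming BATCH=1, BLOCK_IN=16, BLOCK_OUT=16
# (the vta_config.json defaults); real values come from vta.get_env().
BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
batch_size, in_channels, out_channels = 1, 1024, 1024
print((batch_size // BATCH, in_channels // BLOCK_IN, BATCH, BLOCK_IN))            # data:   (1, 64, 1, 16)
print((out_channels // BLOCK_OUT, in_channels // BLOCK_IN, BLOCK_OUT, BLOCK_IN))  # weight: (64, 64, 16, 16)
print((batch_size // BATCH, out_channels // BLOCK_OUT, BATCH, BLOCK_OUT))         # output: (1, 64, 1, 16)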

# Fully connected layer dimensions: 1024 x 1024
batch_size = 1
in_channels = 1024
out_channels = 1024
num_ops = in_channels * out_channels * batch_size * 2

assert batch_size % env.BATCH == 0
assert in_channels % env.BLOCK_IN == 0
assert out_channels % env.BLOCK_OUT == 0

# Derive the tiled input tensor shapes
data_shape = (
    batch_size // env.BATCH, 
    in_channels // env.BLOCK_IN,
    env.BATCH, env.BLOCK_IN
)
weight_shape = (
    out_channels // env.BLOCK_OUT,
    in_channels // env.BLOCK_IN,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)
output_shape = (
    batch_size // env.BATCH, 
    out_channels // env.BLOCK_OUT, 
    env.BATCH, env.BLOCK_OUT
)

# Reduction axes
ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")

# Input placeholder tensors
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)

# Copy buffers
data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf")
weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf")

# Declare matrix multiply computation
res_gemm = te.compute(
    output_shape,
    lambda bo, co, bi, ci: te.sum(
        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
        axis=[ic, ic_tns],
    ),
    name="res_gem",
)

# Add shift stage for fixed-point normalization
res_shr = te.compute(output_shape, lambda *i: res_gemm(*i) >> env.INP_WIDTH, name="res_shr")

# Apply clipping between (0, input max value)
inp_max = (1 << (env.INP_WIDTH - 1)) - 1
res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), "res_max")
res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), "res_min")

# Apply typecast to input data type before sending results back
res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")
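
For reference, the shift, clip, and cast stages above amount to the following plain-NumPy post-processing (a sketch that assumes an 8-bit input dtype, i.e. INP_WIDTH == 8, and mirrors the verification code at the end of this tutorial):

# NumPy sketch of the post-GEMM stages (assumes INP_WIDTH == 8)
INP_WIDTH = 8
acc = np.random.randint(-(1 << 20), 1 << 20, size=(1, 1024)).astype("int32")
shifted = acc >> INP_WIDTH                                 # res_shr
clipped = np.clip(shifted, 0, (1 << (INP_WIDTH - 1)) - 1)  # res_max / res_min
out = clipped.astype("int8")                               # res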

Scheduling the Computation#

Let's look at the set of schedule transformations necessary to map the matrix multiplication onto VTA efficiently. These include:

  • Computation blocking

  • Lowering to VTA hardware intrinsics

# Create TVM schedule
s = te.create_schedule(res.op)
# Let's look at the default TVM schedule
tvm.lower(s, [data, weight, res], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(
        data: T.Buffer((1, 64, 1, 16), "int8"),
        weight: T.Buffer((64, 64, 16, 16), "int8"),
        res: T.Buffer((1, 64, 1, 16), "int8"),
    ):
        T.func_attr(
            {
                "from_legacy_te_schedule": T.bool(True),
                "global_symbol": "main",
                "tir.noalias": T.bool(True),
            }
        )
        data_buf = T.allocate([1024], "int8", "global")
        weight_buf = T.allocate([1048576], "int8", "global")
        res_gem = T.allocate([1024], "int32", "global")
        data_buf_1 = T.Buffer((1024,), "int8", data=data_buf)
        for i1, i3 in T.grid(64, 16):
            cse_var_1: T.int32 = i1 * 16 + i3
            data_1 = T.Buffer((1024,), "int8", data=data.data)
            data_buf_1[cse_var_1] = data_1[cse_var_1]
        weight_buf_1 = T.Buffer((1048576,), "int8", data=weight_buf)
        for i0, i1, i2, i3 in T.grid(64, 64, 16, 16):
            cse_var_2: T.int32 = i0 * 16384 + i1 * 256 + i2 * 16 + i3
            weight_1 = T.Buffer((1048576,), "int8", data=weight.data)
            weight_buf_1[cse_var_2] = weight_1[cse_var_2]
        res_gem_1 = T.Buffer((1024,), "int32", data=res_gem)
        for co, ci in T.grid(64, 16):
            res_gem_1[co * 16 + ci] = 0
            for ic, ic_tns in T.grid(64, 16):
                cse_var_3: T.int32 = co * 16 + ci
                res_gem_1[cse_var_3] = res_gem_1[cse_var_3] + T.Cast(
                    "int32", data_buf_1[ic * 16 + ic_tns]
                ) * T.Cast(
                    "int32", weight_buf_1[co * 16384 + ic * 256 + ci * 16 + ic_tns]
                )
        res_gem_2 = T.Buffer((1024,), "int32", data=res_gem)
        for i1, i3 in T.grid(64, 16):
            cse_var_4: T.int32 = i1 * 16 + i3
            res_gem_2[cse_var_4] = T.shift_right(res_gem_1[cse_var_4], 8)
        res_gem_3 = T.Buffer((1024,), "int32", data=res_gem)
        for i1, i3 in T.grid(64, 16):
            cse_var_5: T.int32 = i1 * 16 + i3
            res_gem_3[cse_var_5] = T.max(res_gem_2[cse_var_5], 0)
        res_gem_4 = T.Buffer((1024,), "int32", data=res_gem)
        for i1, i3 in T.grid(64, 16):
            cse_var_6: T.int32 = i1 * 16 + i3
            res_gem_4[cse_var_6] = T.min(res_gem_3[cse_var_6], 127)
        for i1, i3 in T.grid(64, 16):
            cse_var_7: T.int32 = i1 * 16 + i3
            res_1 = T.Buffer((1024,), "int8", data=res.data)
            res_1[cse_var_7] = T.Cast("int8", res_gem_4[cse_var_7])

Blocking the Computation#

Out of the box, the matrix multiplication is too large for either the activations or the weights to fit in VTA's on-chip buffers all at once. We block the (1, 1024) × (1024, 1024) matrix multiplication into smaller (1, 256) × (256, 256) matrix multiplications so that the intermediate tensors fit in the accelerator's on-chip SRAM. This approach is similar to the blocking techniques applied to CPUs and GPUs to increase cache hit rates.
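
To see why the smaller (1, 256) × (256, 256) blocks are manageable, here is a back-of-the-envelope footprint calculation (a sketch only; whether these tiles actually fit depends on the SRAM sizes configured in vta_config.json):

# Per-block footprints of the (1, 256) x (256, 256) sub-problems
# (int8 activations/weights, int32 accumulator); sketch only.
inp_tile_bytes = 1 * 256 * 1      # (1, 256) int8 activations  ->   256 B
wgt_tile_bytes = 256 * 256 * 1    # (256, 256) int8 weights    -> 65536 B (64 KiB)
acc_tile_bytes = 1 * 256 * 4      # (1, 256) int32 accumulator ->  1024 B
print(inp_tile_bytes, wgt_tile_bytes, acc_tile_bytes)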

We perform blocking along each axis (the batch axis is untouched, since we are performing single-batch inference). We also leave the innermost tensorization axes as-is so that TVM can pattern-match tensorization. The diagram below shows the result of blocking on the computation schedule:

[Figure blocking.png: computation schedule after blocking]

The code after loop splitting and reordering is equivalent to the pseudo-code below. We ignore the batch axis, since we only perform single-batch inference in this example:

for (int oc_out = 0; oc_out < 4; ++oc_out) {
  // Initialization loop
  for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
   for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
    int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;
    C[0][j] = 0;
   }
  }
  for (int ic_out = 0; ic_out < 4; ++ic_out) {
   // Block loop
   for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
    for (int ic_inn = 0; ic_inn < 16; ++ic_inn) {
     // Tensorization loop
     for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
      for (int ic_tns = 0; ic_tns < 16; ++ic_tns) {
       int i = (ic_out * 16 + ic_inn) * 16 + ic_tns;
       int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;
       C[0][j] = C[0][j] + A[0][i] * B[j][i];
      }
     }
    }
   }
  }
}
# Let's define tiling sizes (expressed in multiples of VTA tensor shape size)
b_block = 1 // env.BATCH
i_block = 256 // env.BLOCK_IN
o_block = 256 // env.BLOCK_OUT

# Tile the output tensor along the batch and output channel dimensions
# (since by default we are doing single batch inference, the split along
#  the batch dimension has no effect)
b, oc, b_tns, oc_tns = s[res].op.axis
b_out, b_inn = s[res].split(b, b_block)
oc_out, oc_inn = s[res].split(oc, o_block)
s[res].reorder(b_out, oc_out, b_inn, oc_inn)

# Move intermediate computation into each output compute tile
s[res_gemm].compute_at(s[res], oc_out)
s[res_shr].compute_at(s[res], oc_out)
s[res_max].compute_at(s[res], oc_out)
s[res_min].compute_at(s[res], oc_out)

# Apply additional loop split along reduction axis (input channel)
b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis
ic_out, ic_inn = s[res_gemm].split(ic, i_block)

# Reorder axes. We move the ic_out axis all the way out of the GEMM
# loop to block along the reduction axis
s[res_gemm].reorder(ic_out, b_inn, oc_inn, ic_inn, b_tns, oc_tns, ic_tns)

# Let's look at the current TVM schedule after blocking
tvm.lower(s, [data, weight, res], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(
        data: T.Buffer((1, 64, 1, 16), "int8"),
        weight: T.Buffer((64, 64, 16, 16), "int8"),
        res: T.Buffer((1, 64, 1, 16), "int8"),
    ):
        T.func_attr(
            {
                "from_legacy_te_schedule": T.bool(True),
                "global_symbol": "main",
                "tir.noalias": T.bool(True),
            }
        )
        data_buf = T.allocate([1024], "int8", "global")
        weight_buf = T.allocate([1048576], "int8", "global")
        res_gem = T.allocate([256], "int32", "global")
        data_buf_1 = T.Buffer((1024,), "int8", data=data_buf)
        for i1, i3 in T.grid(64, 16):
            cse_var_1: T.int32 = i1 * 16 + i3
            data_1 = T.Buffer((1024,), "int8", data=data.data)
            data_buf_1[cse_var_1] = data_1[cse_var_1]
        weight_buf_1 = T.Buffer((1048576,), "int8", data=weight_buf)
        for i0, i1, i2, i3 in T.grid(64, 64, 16, 16):
            cse_var_2: T.int32 = i0 * 16384 + i1 * 256 + i2 * 16 + i3
            weight_1 = T.Buffer((1048576,), "int8", data=weight.data)
            weight_buf_1[cse_var_2] = weight_1[cse_var_2]
        for i1_outer in range(4):
            res_gem_1 = T.Buffer((256,), "int32", data=res_gem)
            for co_init, ci_init in T.grid(16, 16):
                res_gem_1[co_init * 16 + ci_init] = 0
            for ic_outer, co, ic_inner, ci, ic_tns in T.grid(4, 16, 16, 16, 16):
                cse_var_3: T.int32 = co * 16 + ci
                res_gem_1[cse_var_3] = res_gem_1[cse_var_3] + T.Cast(
                    "int32", data_buf_1[ic_outer * 256 + ic_inner * 16 + ic_tns]
                ) * T.Cast(
                    "int32",
                    weight_buf_1[
                        i1_outer * 262144
                        + co * 16384
                        + ic_outer * 4096
                        + ic_inner * 256
                        + ci * 16
                        + ic_tns
                    ],
                )
            res_gem_2 = T.Buffer((256,), "int32", data=res_gem)
            for i1, i3 in T.grid(16, 16):
                cse_var_4: T.int32 = i1 * 16 + i3
                res_gem_2[cse_var_4] = T.shift_right(res_gem_1[cse_var_4], 8)
            res_gem_3 = T.Buffer((256,), "int32", data=res_gem)
            for i1, i3 in T.grid(16, 16):
                cse_var_5: T.int32 = i1 * 16 + i3
                res_gem_3[cse_var_5] = T.max(res_gem_2[cse_var_5], 0)
            res_gem_4 = T.Buffer((256,), "int32", data=res_gem)
            for i1, i3 in T.grid(16, 16):
                cse_var_6: T.int32 = i1 * 16 + i3
                res_gem_4[cse_var_6] = T.min(res_gem_3[cse_var_6], 127)
            for i1_inner, i3 in T.grid(16, 16):
                cse_var_7: T.int32 = i1_inner * 16
                res_1 = T.Buffer((1024,), "int8", data=res.data)
                res_1[i1_outer * 256 + cse_var_7 + i3] = T.Cast(
                    "int8", res_gem_4[cse_var_7 + i3]
                )

Lowering Copies to DMA Transfers#

Next, we set the buffer scopes to the corresponding on-chip VTA SRAM buffers. We move the load loops into the matrix multiply computation loop so that they fit into the on-chip SRAM buffers. Finally, we annotate the outer axes of the load/store loops with the DMA copy pragma to perform bulk memory transfers on VTA.

# Set scope of SRAM buffers
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
s[res_gemm].set_scope(env.acc_scope)
s[res_shr].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)

# Block data and weight cache reads
s[data_buf].compute_at(s[res_gemm], ic_out)
s[weight_buf].compute_at(s[res_gemm], ic_out)

# Use DMA copy pragma on DRAM->SRAM operations
s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)
s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy)

# Use DMA copy pragma on SRAM->DRAM operation
# (this implies that these copies should be performed along b_inn,
# or result axis 2)
s[res].pragma(s[res].op.axis[2], env.dma_copy)
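
If you are curious, you can dump the intermediate schedule again at this point (the same inspection call used earlier in this tutorial); since importing vta registers VTA's special memory scopes with TVM, the printout should show the buffers placed in env.inp_scope/env.wgt_scope/env.acc_scope and the copy loops annotated with the DMA pragma:

# Optional: inspect the schedule after buffer scoping and DMA annotation,
# before tensorization (reuses the same inspection call as above).
tvm.lower(s, [data, weight, res], simple_mode=True).show()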

Lowering Computation to VTA Compute Intrinsics#

The final phase is to lower the computation loops down to VTA hardware intrinsics by mapping the matrix multiplication to tensor intrinsics and mapping the shift and clipping operations to the vector ALU.

# Apply tensorization over the batch tensor tile axis
s[res_gemm].tensorize(b_tns, env.gemm)

# Add an ALU pragma over the shift and clipping operations
s[res_shr].pragma(s[res_shr].op.axis[0], env.alu)
s[res_min].pragma(s[res_min].op.axis[0], env.alu)
s[res_max].pragma(s[res_max].op.axis[0], env.alu)

# Let's look at the final lowered TVM schedule after lowering memory
# loads/stores down to DMA copy intrinsics, and the computation down to
# VTA compute intrinsics.
vta.lower(s, [data, weight, res], simple_mode=True).show()
[18:17:17] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_sync
[18:17:17] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
# from tvm.script import ir as I
# from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main(
        data: T.Buffer((1, 64, 1, 16), "int8"),
        weight: T.Buffer((64, 64, 16, 16), "int8"),
        res: T.Buffer((1, 64, 1, 16), "int8"),
    ):
        T.func_attr(
            {
                "from_legacy_te_schedule": T.bool(True),
                "global_symbol": "main",
                "tir.noalias": T.bool(True),
            }
        )
        T.tir.vta.coproc_dep_push(3, 2)
        for i1_outer in range(4):
            vta = T.int32()
            with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2):
                T.tir.vta.coproc_dep_pop(3, 2)
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"),
                    "coproc_uop_scope",
                    "VTAPushGEMMOp",
                ):
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 1, 0, 0)
                    T.tir.vta.uop_push(0, 1, 0, 0, 0, 0, 0, 0)
                    T.call_extern("int32", "VTAUopLoopEnd")
                T.tir.vta.coproc_dep_push(2, 1)
            for ic_outer in range(4):
                cse_var_1: T.int32 = ic_outer * 16
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 1
                ):
                    T.tir.vta.coproc_dep_pop(2, 1)
                    T.call_extern(
                        "int32",
                        "VTALoadBuffer2D",
                        T.tvm_thread_context(T.tir.vta.command_handle()),
                        data.data,
                        cse_var_1,
                        16,
                        1,
                        16,
                        0,
                        0,
                        0,
                        0,
                        0,
                        2,
                    )
                    T.call_extern(
                        "int32",
                        "VTALoadBuffer2D",
                        T.tvm_thread_context(T.tir.vta.command_handle()),
                        weight.data,
                        i1_outer * 1024 + cse_var_1,
                        16,
                        16,
                        64,
                        0,
                        0,
                        0,
                        0,
                        0,
                        1,
                    )
                    T.tir.vta.coproc_dep_push(1, 2)
                T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2)
                T.tir.vta.coproc_dep_pop(1, 2)
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"),
                    "coproc_uop_scope",
                    "VTAPushGEMMOp",
                ):
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 1, 0, 16)
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 0, 1, 1)
                    T.tir.vta.uop_push(0, 0, 0, 0, 0, 0, 0, 0)
                    T.call_extern("int32", "VTAUopLoopEnd")
                    T.call_extern("int32", "VTAUopLoopEnd")
                T.tir.vta.coproc_dep_push(2, 1)
            T.tir.vta.coproc_dep_pop(2, 1)
            with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2):
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"),
                    "coproc_uop_scope",
                    "VTAPushALUOp",
                ):
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, 0, 0, 0, 3, 1, 8)
                    T.call_extern("int32", "VTAUopLoopEnd")
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"),
                    "coproc_uop_scope",
                    "VTAPushALUOp",
                ):
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, 0, 0, 0, 1, 1, 0)
                    T.call_extern("int32", "VTAUopLoopEnd")
                with T.attr(
                    T.iter_var(vta, None, "ThreadIndex", "vta"),
                    "coproc_uop_scope",
                    "VTAPushALUOp",
                ):
                    T.call_extern("int32", "VTAUopLoopBegin", 16, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, 0, 0, 0, 0, 1, 127)
                    T.call_extern("int32", "VTAUopLoopEnd")
                T.tir.vta.coproc_dep_push(2, 3)
            T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 3)
            T.tir.vta.coproc_dep_pop(2, 3)
            for i1_inner in range(16):
                T.call_extern(
                    "int32",
                    "VTAStoreBuffer2D",
                    T.tvm_thread_context(T.tir.vta.command_handle()),
                    i1_inner,
                    4,
                    res.data,
                    i1_outer * 16 + i1_inner,
                    1,
                    1,
                    1,
                )
            T.tir.vta.coproc_dep_push(3, 2)
        T.tir.vta.coproc_sync()
        T.tir.vta.coproc_dep_pop(3, 2)

TVM Compilation and Verification#

After specifying the schedule, we can compile it into a TVM function. We save the module so we can send it over RPC. We then run the function and verify it against a numpy implementation to ensure correctness.

# Compile the TVM module
my_gemm = vta.build(
    s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm"
)
temp = utils.tempdir()
my_gemm.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")

# Get the remote device context
ctx = remote.ext_dev(0)

# Initialize the data and weight arrays randomly in the int range [-128, 128)
data_np = np.random.randint(-128, 128, size=(batch_size, in_channels)).astype(data.dtype)
weight_np = np.random.randint(-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)

# Apply packing to the data and weight arrays from a 2D to a 4D packed layout
data_packed = data_np.reshape(
    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))
weight_packed = weight_np.reshape(
    out_channels // env.BLOCK_OUT, env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))

# Format the input/output arrays with tvm.nd.array to the DLPack standard
data_nd = tvm.nd.array(data_packed, ctx)
weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)

# Clear stats
if env.TARGET in ["sim", "tsim"]:
    simulator.clear_stats()

# Invoke the module to perform the computation
f(data_nd, weight_nd, res_nd)

# Verify against numpy implementation
res_ref = np.dot(data_np.astype(env.acc_dtype), weight_np.T.astype(env.acc_dtype))
res_ref = res_ref >> env.INP_WIDTH
res_ref = np.clip(res_ref, 0, inp_max)
res_ref = res_ref.astype(res.dtype)
res_ref = res_ref.reshape(
    batch_size // env.BATCH, env.BATCH, out_channels // env.BLOCK_OUT, env.BLOCK_OUT
).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.numpy())

# Print stats
if env.TARGET in ["sim", "tsim"]:
    sim_stats = simulator.stats()
    print("Execution statistics:")
    for k, v in sim_stats.items():
        print("\t{:<16}: {:>16}".format(k, v))

print("Successful blocked matrix multiply test!")
Execution statistics:
	inp_load_nbytes :             4096
	wgt_load_nbytes :          1048576
	acc_load_nbytes :                0
	uop_load_nbytes :               20
	out_store_nbytes:             1024
	gemm_counter    :             4096
	alu_counter     :              192
Successful blocked matrix multiply test!
[18:17:19] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-04-20 18:17:19.789 INFO load_module /tmp/tmpqql8a5fm/gemm.o
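
As a rough cross-check (a hand-derived sketch based on the schedule structure, not the simulator's own accounting), the counters above can be reproduced as follows:

# Hand-derived estimates of the simulator counters (sketch only)
n_oc_blocks, n_ic_blocks = 4, 4             # output-channel / reduction blocks
print(1024 * n_oc_blocks)                   # inp_load_nbytes: the 1 KiB input row is re-loaded per output block -> 4096
print(1024 * 1024)                          # wgt_load_nbytes: the full 1 MiB weight matrix streams through once -> 1048576
print(1 * 1024)                             # out_store_nbytes: one (1, 1024) int8 result row -> 1024
print(n_oc_blocks * n_ic_blocks * 16 * 16)  # gemm_counter: GEMM micro-ops -> 4096
print(n_oc_blocks * 3 * 16)                 # alu_counter: 3 ALU stages x 16 vectors per block -> 192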

Summary#