2D Convolution Optimization#

Author: Thierry Moreau

This tutorial provides an overview of how to use TVM to map a 2D convolution workload efficiently onto the VTA design. We recommend covering the Matrix Multiply Blocking tutorial first.

2D convolution dominates most computer-vision deep neural networks. In this tutorial, we demonstrate TVM schedule optimizations that map a 2D convolution operator in NCHW layout onto VTA. We also introduce the concept of latency hiding, which lets us maximize the utilization of VTA's compute and memory resources.

RPC Setup#

We start by programming the Pynq's FPGA and building its RPC runtime.

import set_env
import os
import tvm
import tvm.testing
from tvm import te
import vta
import numpy as np

from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator

# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()

# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
    # Make sure that TVM was compiled with RPC=1
    assert tvm.runtime.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

# In simulation mode, host the RPC server locally.
elif env.TARGET in ["sim", "tsim"]:
    remote = rpc.LocalSession()

Computation Declaration#

As a first step, we need to describe our 2D convolution computation in NCHW format.

We define the 2D convolution shape by the batch size, spatial dimensions, input channels, output channels, kernel dimensions, padding dimensions, and stride dimensions.

We pick the shape of the 9th convolutional layer of the ResNet-18 architecture as our convolution workload parameters.

We add extra operators to the 2D convolution that shift and clip the output, to mimic a fixed-point convolution followed by a rectified linear activation. The TVM dataflow graph of the 2D convolution layer is described below:

Figure: TVM dataflow graph of the 2D convolution layer (../../../../../_images/conv2d_dataflow.png)

This computation is deliberately made too large to fit entirely into VTA's on-chip buffers at once. Therefore, in the scheduling phase we'll rely on computation blocking strategies to break the computation down into manageable chunks.
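
To make "too large" concrete, here is a rough footprint estimate (a small sketch; the literal values match the layer parameters declared below, and the variable names are only for illustration). The weights alone exceed typical on-chip buffer capacities, so these totals have to be compared against the buffer sizes set in the vta_config.json file.

# Rough footprint of the ResNet-18 layer used below, assuming int8 activations/weights
# and int32 accumulation (illustrative only).
inp_bytes = 1 * 256 * 14 * 14           # input activations:  50176 B  (~49 KiB)
wgt_bytes = 256 * 256 * 3 * 3           # kernel weights:     589824 B (~576 KiB)
acc_bytes = 1 * 256 * 14 * 14 * 4       # int32 partial sums: 200704 B (~196 KiB)
print(inp_bytes, wgt_bytes, acc_bytes)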

Spatial padding

Note that we'll need to import the TOPI library to apply spatial padding to the input feature map tensor. Spatial padding helps with blocking in the context of 2D convolutions, because the same (x, y) spatial location of the input feature map of any given layer is read more than once if the convolution kernel window size is greater than one. On CPUs and GPUs, one way to increase the efficiency of memory accesses when parallelizing work is spatial packing, which requires re-laying-out the data. VTA's load DMA engine can insert padding automatically, so the original input feature map does not have to be repacked in memory.

Below we show the effect of VTA's on-the-fly spatial padding when data is loaded from DRAM into VTA's SRAM via a 2D strided and padded memory read.

Figure: dynamic spatial padding during a 2D strided and padded memory read (../../../../../_images/padding.png)
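
As a quick worked check of the padding arithmetic (plain Python with the literal workload values declared just below; the names are only illustrative): a one-pixel border grows the 14x14 input to 16x16 in SRAM, while the output spatial size stays 14x14.

# Worked padding/output-size arithmetic for this workload (illustrative only).
h, w, p, k, stride = 14, 14, 1, 3, 1
padded_h, padded_w = h + 2 * p, w + 2 * p      # 16 x 16 after on-the-fly padding
out_h = (h + 2 * p - k) // stride + 1          # 14: a "same"-size convolution
out_w = (w + 2 * p - k) // stride + 1          # 14
print(padded_h, padded_w, out_h, out_w)
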
from tvm import topi

# 2D convolution layer dimensions taken from ResNet-18 architecture
# (9th convolutional layer)
batch_size = 1
height = 14
width = 14
in_channels = 256
out_channels = 256
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1
assert batch_size % env.BATCH == 0
assert in_channels % env.BLOCK_IN == 0
assert out_channels % env.BLOCK_OUT == 0

# Input feature map: (N, IC, H, W, n, ic)
data_shape = (
    batch_size // env.BATCH,
    in_channels // env.BLOCK_IN,
    height,
    width,
    env.BATCH,
    env.BLOCK_IN,
)
# Kernel: (OC, IC, H, W, oc, ic)
kernel_shape = (
    out_channels // env.BLOCK_OUT,
    in_channels // env.BLOCK_IN,
    kernel_h,
    kernel_w,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)
# Derive output feature map dimensions
fout_height = (height + 2 * pad_h - kernel_h) // stride_h + 1
fout_width = (width + 2 * pad_w - kernel_w) // stride_w + 1
# Output feature map: (N, OC, H, W, n, oc)
output_shape = (
    batch_size // env.BATCH,
    out_channels // env.BLOCK_OUT,
    fout_height,
    fout_width,
    env.BATCH,
    env.BLOCK_OUT,
)

# Convolution reduction axes
dy = te.reduce_axis((0, kernel_h), name="dy")
dx = te.reduce_axis((0, kernel_w), name="dx")
ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")

# Input placeholder tensors
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)

# Copy buffers:
#   Apply spatial padding to input feature map
data_buf = topi.nn.pad(data, [0, 0, pad_h, pad_w, 0, 0], name="data_buf")
kernel_buf = te.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf")

# Declare 2D convolution
res_conv = te.compute(
    output_shape,
    lambda bo, co, i, j, bi, ci: te.sum(
        data_buf[bo, ic, i * stride_h + dy, j * stride_w + dx, bi, ic_tns].astype(env.acc_dtype)
        * kernel_buf[co, ic, dy, dx, ci, ic_tns].astype(env.acc_dtype),
        axis=[ic, dy, dx, ic_tns],
    ),
    name="res_conv",
)

# Add shift stage for fixed-point normalization
res_shr = te.compute(output_shape, lambda *i: res_conv(*i) >> 8, name="res_shr")

# Apply clipping between (0, input max value)
inp_max = (1 << (env.INP_WIDTH - 1)) - 1
res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), "res_max")
res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), "res_min")

# Result Tensor
res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")

Scheduling the Computation#

We'll look at the set of schedule transformations necessary to map the 2D convolution onto VTA in an efficient fashion. These include:

  • Computation blocking

  • Virtual threading to increase compute utilization

  • Lowering to VTA hardware intrinsics

# Create TVM schedule
s = te.create_schedule(res.op)
# Let's look at the default TVM schedule
tvm.lower(s, [data, kernel, res], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 16, 14, 14, 1, 16), "int8"), kernel: T.Buffer((16, 16, 3, 3, 16, 16), "int8"), res: T.Buffer((1, 16, 14, 14, 1, 16), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        data_buf = T.allocate([65536], "int8", "global")
        kernel_buf = T.allocate([589824], "int8", "global")
        res_conv = T.allocate([50176], "int32", "global")
        data_buf_1 = T.Buffer((65536,), "int8", data=data_buf)
        for i1, i2, i3, i5 in T.grid(16, 16, 16, 16):
            cse_var_1: T.int32 = i3 * 16
            data_1 = T.Buffer((50176,), "int8", data=data.data)
            data_buf_1[i1 * 4096 + i2 * 256 + cse_var_1 + i5] = T.if_then_else(1 <= i2 and i2 < 15 and 1 <= i3 and i3 < 15, data_1[i1 * 3136 + i2 * 224 + cse_var_1 + i5 - 240], T.int8(0))
        kernel_buf_1 = T.Buffer((589824,), "int8", data=kernel_buf)
        for i0, i1, i2, i3, i4, i5 in T.grid(16, 16, 3, 3, 16, 16):
            cse_var_2: T.int32 = i0 * 36864 + i1 * 2304 + i2 * 768 + i3 * 256 + i4 * 16 + i5
            kernel_1 = T.Buffer((589824,), "int8", data=kernel.data)
            kernel_buf_1[cse_var_2] = kernel_1[cse_var_2]
        res_conv_1 = T.Buffer((50176,), "int32", data=res_conv)
        for co, i, j, ci in T.grid(16, 14, 14, 16):
            res_conv_1[co * 3136 + i * 224 + j * 16 + ci] = 0
            for ic, dy, dx, ic_tns in T.grid(16, 3, 3, 16):
                cse_var_4: T.int32 = j * 16
                cse_var_3: T.int32 = co * 3136 + i * 224 + cse_var_4 + ci
                res_conv_1[cse_var_3] = res_conv_1[cse_var_3] + T.Cast("int32", data_buf_1[ic * 4096 + i * 256 + dy * 256 + cse_var_4 + dx * 16 + ic_tns]) * T.Cast("int32", kernel_buf_1[co * 36864 + ic * 2304 + dy * 768 + dx * 256 + ci * 16 + ic_tns])
        res_conv_2 = T.Buffer((50176,), "int32", data=res_conv)
        for i1, i2, i3, i5 in T.grid(16, 14, 14, 16):
            cse_var_5: T.int32 = i1 * 3136 + i2 * 224 + i3 * 16 + i5
            res_conv_2[cse_var_5] = T.shift_right(res_conv_1[cse_var_5], 8)
        res_conv_3 = T.Buffer((50176,), "int32", data=res_conv)
        for i1, i2, i3, i5 in T.grid(16, 14, 14, 16):
            cse_var_6: T.int32 = i1 * 3136 + i2 * 224 + i3 * 16 + i5
            res_conv_3[cse_var_6] = T.max(res_conv_2[cse_var_6], 0)
        res_conv_4 = T.Buffer((50176,), "int32", data=res_conv)
        for i1, i2, i3, i5 in T.grid(16, 14, 14, 16):
            cse_var_7: T.int32 = i1 * 3136 + i2 * 224 + i3 * 16 + i5
            res_conv_4[cse_var_7] = T.min(res_conv_3[cse_var_7], 127)
        for i1, i2, i3, i5 in T.grid(16, 14, 14, 16):
            cse_var_8: T.int32 = i1 * 3136 + i2 * 224 + i3 * 16 + i5
            res_1 = T.Buffer((50176,), "int8", data=res.data)
            res_1[cse_var_8] = T.Cast("int8", res_conv_4[cse_var_8])

Blocking the Computation#

The 2D convolution is, by default, too large for either the activations or the kernel weights to fit in VTA's on-chip buffers all at once. We block along the input channel, output channel, and height spatial dimensions. We don't block along the width spatial dimension, since it's the innermost dimension in the NCHW layout (and therefore, to increase locality, it's best not to block along the innermost dimension).
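
Before splitting any loops, it's worth checking that the chosen tiles actually fit on chip. The sketch below (hedged: it only redoes the arithmetic implied by the block sizes chosen next) estimates the per-tile footprint; each figure has to fit within the corresponding input, weight, and accumulator buffer sizes configured in vta_config.json, and the accumulator footprint doubles once two virtual threads are introduced later.

# Per-tile footprint implied by the block sizes chosen below (illustrative only).
# A 7-row output block needs 7 + 2 halo rows of padded input (3x3 kernel, stride 1),
# and the width is not blocked, so all 14 + 2 padded columns stay resident.
tile_inp = 1 * 16 * (7 + 2) * (14 + 2)   # one 16-channel input block, int8:       2304 B
tile_wgt = 128 * 16 * 3 * 3              # 128 out x 16 in channels, 3x3, int8:   18432 B
tile_acc = 1 * 128 * 7 * 14 * 4          # int32 partial sums of the output tile: 50176 B
print(tile_inp, tile_wgt, tile_acc)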

# Let's define tiling sizes
b_block = 1 // env.BATCH
oc_block = 128 // env.BLOCK_OUT
ic_block = 16 // env.BLOCK_IN
h_block = 7
w_block = 14

# We tile the output tensor along the spatial and output-channel dimensions (since we do single-batch inference by default, splitting along the batch dimension has no effect)
b, oc, y, x, b_tns, oc_tns = s[res].op.axis
b_out, b_inn = s[res].split(b, factor=b_block)
oc_out, oc_inn = s[res].split(oc, factor=oc_block)
y_out, y_inn = s[res].split(y, factor=h_block)
x_out, x_inn = s[res].split(x, factor=w_block)
s[res].reorder(b_out, oc_out, y_out, x_out, b_inn, oc_inn, y_inn, x_inn, b_tns, oc_tns)

# Move intermediate computation into each output compute tile
s[res_conv].compute_at(s[res], x_out)
s[res_shr].compute_at(s[res], x_out)
s[res_max].compute_at(s[res], x_out)
s[res_min].compute_at(s[res], x_out)

# Apply additional loop split along the reduction axis (input channel)
b_inn, oc_inn, y_inn, x_inn, b_tns, oc_tns = s[res_conv].op.axis
ic_out, ic_inn = s[res_conv].split(ic, factor=ic_block)

Reordering Axes#

  1. Group the VTA tensor axes in the innermost position: b_tns, oc_tns, ic_tns, to allow TVM to tensorize.

  2. Move the ic_out axis all the way out of the convolution loop to block along the reduction axis.

  3. Now reorder the block axes: b_inn, oc_inn, y_inn, x_inn, ic_inn, dy, dx. The VTA runtime/hardware requires each VTA tensor operation to write to a different output feature map location. This restriction requires us to order one of oc_inn, y_inn, or x_inn right before b_tns, since they all affect the output feature map indexing. We choose to bring x_inn inside, as shown below.

s[res_conv].reorder(ic_out, b_inn, oc_inn, y_inn, ic_inn, dy, dx, x_inn, b_tns, oc_tns, ic_tns)

Virtual Threading#

Virtual threading is a mechanism for increasing task-level pipeline parallelism in the VTA hardware design. Put another way, it increases compute resource utilization by hiding memory access latency.

In the implementation below, virtual threading distributes work across two threads split along the output channel axis. The figure below shows how the work is split when computing the 2D convolution.

Figure: splitting the 2D convolution workload across two virtual threads (../../../../../_images/virtual_threading.png)
# VTA only supports 2 virtual threads
v_threads = 2

# Perform the virtual thread split along the output-channel outer axis
_, tx = s[res].split(oc_out, factor=v_threads)
s[res].reorder(tx, b_out)
s[res].bind(tx, te.thread_axis("cthread"))
# Let's look at the current TVM schedule after blocking and virtual threading
tvm.lower(s, [data, kernel, res], simple_mode=True).show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 16, 14, 14, 1, 16), "int8"), kernel: T.Buffer((16, 16, 3, 3, 16, 16), "int8"), res: T.Buffer((1, 16, 14, 14, 1, 16), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        data_buf = T.allocate([65536], "int8", "global")
        kernel_buf = T.allocate([589824], "int8", "global")
        res_conv = T.allocate([25088], "int32", "global")
        data_buf_1 = T.Buffer((65536,), "int8", data=data_buf)
        for i1, i2, i3, i5 in T.grid(16, 16, 16, 16):
            cse_var_1: T.int32 = i3 * 16
            data_1 = T.Buffer((50176,), "int8", data=data.data)
            data_buf_1[i1 * 4096 + i2 * 256 + cse_var_1 + i5] = T.if_then_else(1 <= i2 and i2 < 15 and 1 <= i3 and i3 < 15, data_1[i1 * 3136 + i2 * 224 + cse_var_1 + i5 - 240], T.int8(0))
        kernel_buf_1 = T.Buffer((589824,), "int8", data=kernel_buf)
        for i0, i1, i2, i3, i4, i5 in T.grid(16, 16, 3, 3, 16, 16):
            cse_var_2: T.int32 = i0 * 36864 + i1 * 2304 + i2 * 768 + i3 * 256 + i4 * 16 + i5
            kernel_1 = T.Buffer((589824,), "int8", data=kernel.data)
            kernel_buf_1[cse_var_2] = kernel_1[cse_var_2]
        for i2_outer in range(2):
            res_conv_1 = T.Buffer((157351936,), "int32", data=res_conv)
            for co_init, i_init, j_init, ci_init in T.grid(8, 7, 14, 16):
                cse_var_3: T.int32 = co_init * 1568 + i_init * 224 + j_init * 16 + ci_init
                res_conv_1[cse_var_3] = 0
                res_conv_1[cse_var_3 + 12544] = 0
            for ic_outer, co, i, dy, dx, j, ci, ic_tns in T.grid(16, 8, 7, 3, 3, 14, 16, 16):
                cse_var_8: T.int32 = j * 16
                cse_var_7: T.int32 = co * 1568 + i * 224 + cse_var_8 + ci
                cse_var_6: T.int32 = cse_var_7 + 12544
                cse_var_5: T.int32 = co * 36864 + ic_outer * 2304 + dy * 768 + dx * 256 + ci * 16 + ic_tns
                cse_var_4: T.int32 = ic_outer * 4096 + i2_outer * 1792 + i * 256 + dy * 256 + cse_var_8 + dx * 16 + ic_tns
                res_conv_1[cse_var_7] = res_conv_1[cse_var_7] + T.Cast("int32", data_buf_1[cse_var_4]) * T.Cast("int32", kernel_buf_1[cse_var_5])
                res_conv_1[cse_var_6] = res_conv_1[cse_var_6] + T.Cast("int32", data_buf_1[cse_var_4]) * T.Cast("int32", kernel_buf_1[cse_var_5 + 294912])
            res_conv_2 = T.Buffer((157351936,), "int32", data=res_conv)
            for i1, i2, i3, i5 in T.grid(8, 7, 14, 16):
                cse_var_10: T.int32 = i1 * 1568 + i2 * 224 + i3 * 16 + i5
                cse_var_9: T.int32 = cse_var_10 + 12544
                res_conv_2[cse_var_10] = T.shift_right(res_conv_1[cse_var_10], 8)
                res_conv_2[cse_var_9] = T.shift_right(res_conv_1[cse_var_9], 8)
            res_conv_3 = T.Buffer((157351936,), "int32", data=res_conv)
            for i1, i2, i3, i5 in T.grid(8, 7, 14, 16):
                cse_var_12: T.int32 = i1 * 1568 + i2 * 224 + i3 * 16 + i5
                cse_var_11: T.int32 = cse_var_12 + 12544
                res_conv_3[cse_var_12] = T.max(res_conv_2[cse_var_12], 0)
                res_conv_3[cse_var_11] = T.max(res_conv_2[cse_var_11], 0)
            res_conv_4 = T.Buffer((157351936,), "int32", data=res_conv)
            for i1, i2, i3, i5 in T.grid(8, 7, 14, 16):
                cse_var_14: T.int32 = i1 * 1568 + i2 * 224 + i3 * 16 + i5
                cse_var_13: T.int32 = cse_var_14 + 12544
                res_conv_4[cse_var_14] = T.min(res_conv_3[cse_var_14], 127)
                res_conv_4[cse_var_13] = T.min(res_conv_3[cse_var_13], 127)
            for i1_inner, i2_inner, i3_inner, i5 in T.grid(8, 7, 14, 16):
                cse_var_18: T.int32 = i2_inner * 224
                cse_var_17: T.int32 = i3_inner * 16
                cse_var_16: T.int32 = i1_inner * 1568 + cse_var_18 + cse_var_17 + i5
                cse_var_15: T.int32 = i1_inner * 3136 + i2_outer * 1568 + cse_var_18 + cse_var_17 + i5
                res_1 = T.Buffer((50176,), "int8", data=res.data)
                res_1[cse_var_15] = T.Cast("int8", res_conv_4[cse_var_16])
                res_1[cse_var_15 + 25088] = T.Cast("int8", res_conv_4[cse_var_16 + 12544])

Lowering Copies to DMA Transfers#

Next, we set the buffer scopes to the corresponding on-chip VTA SRAM buffers. We move the load loops into the 2D convolution computation loop to stage memory loads so that they fit in the on-chip SRAM buffers. Finally, we annotate the outer axes of the load/store loops with a DMA copy pragma to perform bulk memory transfers on VTA.

# Set scope of SRAM buffers
s[data_buf].set_scope(env.inp_scope)
s[kernel_buf].set_scope(env.wgt_scope)
s[res_conv].set_scope(env.acc_scope)
s[res_shr].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)

# Block data and kernel cache reads
s[data_buf].compute_at(s[res_conv], ic_out)
s[kernel_buf].compute_at(s[res_conv], ic_out)

# Use DMA copy pragma on DRAM->SRAM operations
s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)
s[kernel_buf].pragma(s[kernel_buf].op.axis[0], env.dma_copy)

# Use DMA copy pragma on SRAM->DRAM operation in each result block (this implies that these copies should be performed along b_inn, or result axis 4)
s[res].pragma(s[res].op.axis[4], env.dma_copy)

Lowering Computation to VTA Compute Intrinsics#

The last phase is to lower the computation loops down to VTA hardware intrinsics by mapping the 2D convolution to tensor intrinsics, and mapping the shift and clipping computations to the vector ALU.

# Apply tensorization over the batch tensor tile axis
s[res_conv].tensorize(b_tns, env.gemm)

# Add an ALU pragma over the shift and clipping operations
s[res_shr].pragma(s[res_shr].op.axis[0], env.alu)
s[res_min].pragma(s[res_min].op.axis[0], env.alu)
s[res_max].pragma(s[res_max].op.axis[0], env.alu)

Let's take a look at the final lowered TVM schedule after lowering memory loads/stores down to DMA copy intrinsics, and the computation down to VTA compute intrinsics.

vta.lower(s, [data, kernel, res], simple_mode=True).show()
[22:39:39] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.uop_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.command_handle
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_push
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_dep_pop
[22:39:39] /media/pc/data/lxw/ai/tvm/src/script/printer/tir/expr.cc:246: Warning: No TScriptPrinterName attribute for tir.vta.coproc_sync
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(data: T.Buffer((1, 16, 14, 14, 1, 16), "int8"), kernel: T.Buffer((16, 16, 3, 3, 16, 16), "int8"), res: T.Buffer((1, 16, 14, 14, 1, 16), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        T.tir.vta.coproc_dep_push(3, 2)
        T.tir.vta.coproc_dep_push(3, 2)
        for i2_outer in range(2):
            for cthread_s in range(2):
                vta = T.int32()
                T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2)
                T.tir.vta.coproc_dep_pop(3, 2)
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_uop_scope", "VTAPushGEMMOp"):
                    T.call_extern("int32", "VTAUopLoopBegin", 8, 98, 0, 0)
                    T.call_extern("int32", "VTAUopLoopBegin", 7, 14, 0, 0)
                    for j_init in range(14):
                        T.tir.vta.uop_push(0, 1, cthread_s * 784 + j_init, 0, 0, 0, 0, 0)
                    T.call_extern("int32", "VTAUopLoopEnd")
                    T.call_extern("int32", "VTAUopLoopEnd")
                T.tir.vta.coproc_dep_push(2, 1)
            vta = T.int32()
            for ic_outer in range(16):
                cse_var_6: T.int32 = i2_outer * 7
                cse_var_5: T.int32 = ic_outer * 9
                cse_var_4: T.int32 = T.max(1 - cse_var_6, 0)
                cse_var_3: T.int32 = T.max(cse_var_6 - 6, 0)
                cse_var_2: T.int32 = 9 - cse_var_4 - cse_var_3
                cse_var_1: T.int32 = ic_outer * 196 + i2_outer * 98 + cse_var_4 * 14 - 14
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 1):
                    T.tir.vta.coproc_dep_pop(2, 1)
                    T.call_extern("int32", "VTALoadBuffer2D", T.tvm_thread_context(T.tir.vta.command_handle()), data.data, cse_var_1, 14, cse_var_2, 14, 1, cse_var_4, 1, cse_var_3, 0, 2)
                    T.call_extern("int32", "VTALoadBuffer2D", T.tvm_thread_context(T.tir.vta.command_handle()), kernel.data, cse_var_5, 9, 8, 144, 0, 0, 0, 0, 0, 1)
                    T.tir.vta.coproc_dep_push(1, 2)
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 1):
                    T.tir.vta.coproc_dep_pop(2, 1)
                    T.call_extern("int32", "VTALoadBuffer2D", T.tvm_thread_context(T.tir.vta.command_handle()), data.data, cse_var_1, 14, cse_var_2, 14, 1, cse_var_4, 1, cse_var_3, 144, 2)
                    T.call_extern("int32", "VTALoadBuffer2D", T.tvm_thread_context(T.tir.vta.command_handle()), kernel.data, cse_var_5 + 1152, 9, 8, 144, 0, 0, 0, 0, 72, 1)
                    T.tir.vta.coproc_dep_push(1, 2)
                for cthread_s in range(2):
                    T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2)
                    T.tir.vta.coproc_dep_pop(1, 2)
                    with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_uop_scope", "VTAPushGEMMOp"):
                        T.call_extern("int32", "VTAUopLoopBegin", 8, 98, 0, 9)
                        T.call_extern("int32", "VTAUopLoopBegin", 7, 14, 16, 0)
                        for dy, dx, j in T.grid(3, 3, 14):
                            T.tir.vta.uop_push(0, 0, cthread_s * 784 + j, cthread_s * 144 + dy * 16 + j + dx, cthread_s * 72 + dy * 3 + dx, 0, 0, 0)
                        T.call_extern("int32", "VTAUopLoopEnd")
                        T.call_extern("int32", "VTAUopLoopEnd")
                    T.tir.vta.coproc_dep_push(2, 1)
            T.tir.vta.coproc_dep_pop(2, 1)
            T.tir.vta.coproc_dep_pop(2, 1)
            for cthread_s in range(2):
                cse_var_7: T.int32 = cthread_s * 784
                T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 2)
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_uop_scope", "VTAPushALUOp"):
                    T.call_extern("int32", "VTAUopLoopBegin", 784, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 3, 1, 8)
                    T.call_extern("int32", "VTAUopLoopEnd")
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_uop_scope", "VTAPushALUOp"):
                    T.call_extern("int32", "VTAUopLoopBegin", 784, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 1, 1, 0)
                    T.call_extern("int32", "VTAUopLoopEnd")
                with T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_uop_scope", "VTAPushALUOp"):
                    T.call_extern("int32", "VTAUopLoopBegin", 784, 1, 1, 0)
                    T.tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 0, 1, 127)
                    T.call_extern("int32", "VTAUopLoopEnd")
                T.tir.vta.coproc_dep_push(2, 3)
            for cthread_s in range(2):
                T.attr(T.iter_var(vta, None, "ThreadIndex", "vta"), "coproc_scope", 3)
                T.tir.vta.coproc_dep_pop(2, 3)
                for i1_inner, i2_inner, i3_inner in T.grid(8, 7, 14):
                    cse_var_8: T.int32 = i2_inner * 14
                    T.call_extern("int32", "VTAStoreBuffer2D", T.tvm_thread_context(T.tir.vta.command_handle()), cthread_s * 784 + i1_inner * 98 + cse_var_8 + i3_inner, 4, res.data, cthread_s * 1568 + i1_inner * 196 + i2_outer * 98 + cse_var_8 + i3_inner, 1, 1, 1)
                T.tir.vta.coproc_dep_push(3, 2)
        T.tir.vta.coproc_dep_pop(3, 2)
        T.tir.vta.coproc_dep_pop(3, 2)
        T.tir.vta.coproc_sync()

TVM Compute and Verification#

After we finish specifying the schedule, we can compile it into a TVM function. We save the module so we can send it over RPC. We then run the function and verify it against a numpy implementation to ensure correctness.

# This library facilitates 2D convolution testing
from tvm.topi.testing import conv2d_nchw_python

# Compile the TVM module
with vta.build_config(disabled_pass={"tir.CommonSubexprElimTIR"}):
    my_conv = vta.build(
        s, [data, kernel, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_conv"
    )
temp = utils.tempdir()
my_conv.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")

# Get the remote device context
ctx = remote.ext_dev(0)

# Initialize the data and kernel arrays randomly in the int range
# of [-128, 128) in NCHW layout
data_np = np.random.randint(-128, 128, size=(batch_size, in_channels, height, width)).astype(
    data.dtype
)
kernel_np = np.random.randint(
    -128, 128, size=(out_channels, in_channels, kernel_h, kernel_w)
).astype(kernel.dtype)

# Apply packing to the data and kernel arrays from the 4D NCHW
# layout to the 6D NCHWnc packed layout
data_packed = data_np.reshape(
    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN, height, width
).transpose((0, 2, 4, 5, 1, 3))

kernel_packed = kernel_np.reshape(
    out_channels // env.BLOCK_OUT,
    env.BLOCK_OUT,
    in_channels // env.BLOCK_IN,
    env.BLOCK_IN,
    kernel_h,
    kernel_w,
).transpose((0, 2, 4, 5, 1, 3))

# Format the input/output arrays with tvm.nd.array to the DLPack standard
data_nd = tvm.nd.array(data_packed, ctx)
kernel_nd = tvm.nd.array(kernel_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)

# Clear stats
if env.TARGET in ["sim", "tsim"]:
    simulator.clear_stats()

# Invoke the module to perform the computation
f(data_nd, kernel_nd, res_nd)

# Verify against numpy implementation
res_ref = conv2d_nchw_python(
    data_np.astype(env.acc_dtype),
    kernel_np.astype(env.acc_dtype),
    (stride_h, stride_w),
    (pad_h, pad_w),
).astype(env.acc_dtype)
res_ref = res_ref >> env.INP_WIDTH
res_ref = np.clip(res_ref, 0, inp_max)
res_ref = res_ref.astype(res.dtype)
res_ref = res_ref.reshape(
    (
        batch_size // env.BATCH,
        env.BATCH,
        out_channels // env.BLOCK_OUT,
        env.BLOCK_OUT,
        fout_height,
        fout_width,
    )
).transpose((0, 2, 4, 5, 1, 3))
tvm.testing.assert_allclose(res_ref, res_nd.numpy())

# Print stats
if env.TARGET in ["sim", "tsim"]:
    sim_stats = simulator.stats()
    print("Execution statistics:")
    for k, v in sim_stats.items():
        print("\t{:<16}: {:>16}".format(k, v))

print("Successful 2D convolution test!")
Execution statistics:
	inp_load_nbytes :           114688
	wgt_load_nbytes :          1179648
	acc_load_nbytes :                0
	uop_load_nbytes :             1144
	out_store_nbytes:            50176
	gemm_counter    :           451584
	alu_counter     :             9408
Successful 2D convolution test!
[22:39:42] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-12-20 22:39:42.499 INFO load_module /tmp/tmpuq10oquo/conv2d.o
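
As a rough cross-check of these statistics (a hedged sketch: it assumes the simulator counters tally one GEMM micro-op per input-channel block and kernel tap of each output tensor tile, one ALU micro-op per output tile for each of the three elementwise passes, and raw DMA bytes), the totals above can be re-derived from the workload parameters:

# Illustrative cross-check of the simulator statistics (assumptions noted above).
out_tiles = (out_channels // env.BLOCK_OUT) * fout_height * fout_width      # 3136 output tensor tiles
gemm_ops = out_tiles * (in_channels // env.BLOCK_IN) * kernel_h * kernel_w  # 451584, matches gemm_counter
alu_ops = 3 * out_tiles                                                     # 9408, matches alu_counter (shift, max, min)
out_bytes = out_tiles * env.BATCH * env.BLOCK_OUT                           # 50176, matches out_store_nbytes (int8)
wgt_bytes = 2 * out_channels * in_channels * kernel_h * kernel_w            # 1179648, matches wgt_load_nbytes (kernel reloaded once per height block)
print(gemm_ops, alu_ops, out_bytes, wgt_bytes)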