2D 卷积

2D 卷积

约定:

  1. 尽量使用 NumPy 低级接口与 TIR 进行对比。

  2. 以高级接口版本(NumPy 或本节使用的 PyTorch)的计算结果作为基准。

import numpy as np
import tvm
from tvm.script import tir as T

使用 NCHW 布局的卷积的数学定义:

\[ \text{Conv}[b, k, i, j] = \sum_{d_i, d_j, q} A[b, q, \text{strides} * i + d_i, \text{strides} * j + d_j] * W[k, q, d_i, d_j], \]

其中,\(A\) 是输入张量,\(W\) 是权重张量,\(b\) 是批次索引,\(k\) 是输出通道,\(i\) 和 \(j\) 是输出图像高度和宽度方向的索引,\(d_i\) 和 \(d_j\) 是卷积核(权重)在高度和宽度方向的索引,\(q\) 是输入通道,strides 是卷积窗口(过滤器窗口)滑动的步幅。

下面考虑简单的情况:stride=1,padding=0。此时输出的高和宽分别为 H - K + 1 和 W - K + 1。

# Problem size: batch N, input channels CI, input height/width H x W,
# output channels CO and a square K x K kernel.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# stride=1, padding=0  ->  output spatial size is (H - K + 1) x (W - K + 1).
OUT_H, OUT_W = H - K + 1, W - K + 1
# Pin int64 explicitly: the TIR buffers are declared "int64", while the
# default integer dtype of np.arange is platform dependent.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)

torch 版本:

import torch

# conv2d needs floating-point tensors, so make float32 copies of the inputs.
data_torch, weight_torch = torch.Tensor(data), torch.Tensor(weight)
# Reference result, converted back to int64 so it can be compared with the
# integer TIR output.
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch
array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

TVM 版本:

# A scratch cell `T.int64(A)` used to sit here. `A` only exists as a parameter
# name inside the `conv` PrimFunc below, not as a module-level variable, so the
# cell raised NameError and aborted a top-to-bottom run before the TVM version
# executed. The cell and its traceback output were removed.
@tvm.script.ir_module
class MyConv:
    # Direct NCHW convolution with stride=1, padding=0:
    #   C[n, co, h, w] = sum_{ci, kh, kw} A[n, ci, h + kh, w + kw] * B[co, ci, kh, kw]
    # Shapes come from the module-level size constants, so the buffers, the
    # loop nest and the block axes stay consistent if those constants change.
    @T.prim_func
    def conv(A: T.Buffer((N, CI, H, W), "int64"),
             B: T.Buffer((CO, CI, K, K), "int64"),
             C: T.Buffer((N, CO, OUT_H, OUT_W), "int64")):
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for n, co, h, w, ci, kh, kw in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
            with T.block("C"):
                # n/co/h/w index the output ("S" = spatial); ci/kh/kw are summed
                # over ("R" = reduce). The input channel must be a reduction
                # axis: if it were spatial, T.init() would re-zero C for every
                # ci, and only the last channel's contribution would survive
                # once CI > 1.
                vn, vc, vh, vw, vi, vk1, vk2 = T.axis.remap(
                    "SSSSRRR", [n, co, h, w, ci, kh, kw]
                )
                with T.init():
                    C[vn, vc, vh, vw] = T.int64(0)
                C[vn, vc, vh, vw] = (
                    C[vn, vc, vh, vw]
                    + A[vn, vi, vh + vk1, vw + vk2] * B[vc, vi, vk1, vk2]
                )
# Compile the IRModule for the host CPU and run it on the same inputs.
rt_lib = tvm.build(MyConv, target="llvm")
# The PrimFunc's buffers are "int64"; cast explicitly because the default
# integer dtype of np.arange is platform dependent.
data_tvm = tvm.nd.array(data.astype(np.int64))
weight_tvm = tvm.nd.array(weight.astype(np.int64))
# Output buffer; every element is written by the kernel, so np.empty is fine.
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# Integer convolution: the result must match the PyTorch reference exactly,
# so use an exact comparison rather than a relative tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)

使用 TOPI 构建卷积

# Grouped 2D convolution built with TOPI, then lowered to a TIR PrimFunc.
c = 64            # input channels == output channels
groups = c // 16  # 4 groups -> 16 input channels per group
inp = tvm.te.placeholder((1, c, 8, 8), name="data", dtype="float32")
# For NCHW, TOPI's filter layout is (num_filter, in_channel // groups, kh, kw):
# each output channel only convolves the input channels of its own group.
kernel = tvm.te.placeholder((c, c // groups, 3, 3), name="kernel", dtype="float32")
conv = tvm.topi.nn.conv(
    inp, kernel, stride=1, padding=0, groups=groups, dilation=1, data_layout="NCHW"
)
mod = tvm.te.create_prim_func([inp, kernel, conv])
mod.show()
# from tvm.script import tir as T
# Output of `mod.show()` above: the TVMScript PrimFunc that
# `tvm.te.create_prim_func` produced for the TOPI grouped convolution
# (64 channels, groups = 64 // 16 = 4, i.e. 16 input channels per group).


@T.prim_func
def main(
    data: T.Buffer((1, 64, 8, 8), "float32"),
    # NOTE(review): only kernel[:, 0:16, :, :] is read below (rc < 16), so the
    # second dimension of 64 is larger than the grouped computation uses;
    # presumably TOPI expects (num_filter, in_channel // groups, kh, kw) here.
    # Confirm against the topi.nn.conv docs.
    kernel: T.Buffer((64, 64, 3, 3), "float32"),
    # 6 = 8 - 3 + 1: the window is read at yy + ry / xx + rx (stride 1) with no padding.
    group_conv2d_nchw: T.Buffer((1, 64, 6, 6), "float32"),
):
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # with T.block("root"):
    # "pad_temp" has the same extent as `data` and is a plain element-wise copy
    # of it (no padding is applied).
    pad_temp = T.alloc_buffer((1, 64, 8, 8))
    for i0, i1, i2, i3 in T.grid(1, 64, 8, 8):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(data[v_i0, v_i1, v_i2, v_i3])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = data[v_i0, v_i1, v_i2, v_i3]
    # Output channel ff belongs to group ff // 16. It sums over that group's 16
    # input channels (channel ff // 16 * 16 + rc) and the 3x3 window (ry, rx).
    for nn, ff, yy, xx, rc, ry, rx in T.grid(1, 64, 6, 6, 16, 3, 3):
        with T.block("group_conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap(
                "SSSSRRR", [nn, ff, yy, xx, rc, ry, rx]
            )
            T.reads(
                pad_temp[v_nn, v_ff // 16 * 16 + v_rc, v_yy + v_ry, v_xx + v_rx],
                kernel[v_ff, v_rc, v_ry, v_rx],
            )
            T.writes(group_conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                group_conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
            group_conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = (
                group_conv2d_nchw[v_nn, v_ff, v_yy, v_xx]
                + pad_temp[v_nn, v_ff // 16 * 16 + v_rc, v_yy + v_ry, v_xx + v_rx]
                * kernel[v_ff, v_rc, v_ry, v_rx]
            )