# 填充运算

In [None]:
%cd ../..
import set_env

In [2]:
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te

## 简单示例

下面的例子在 $3\times 4$ 矩阵 $A$ 的四周各填充一圈 $0$，得到 $5\times 6$ 的矩阵 $B$。

In [3]:
# NumPy reference: the values 1..12 laid out as a 3x4 float32 matrix,
# surrounded by a border of zeros one element wide (result is 5x6).
a_np = np.arange(1, 13, dtype='float32').reshape(3, 4)
b_np = np.pad(a_np, pad_width=1, mode='constant', constant_values=0)
print(b_np)

[[ 0.  0.  0.  0.  0.  0.]
 [ 0.  1.  2.  3.  4.  0.]
 [ 0.  5.  6.  7.  8.  0.]
 [ 0.  9. 10. 11. 12.  0.]
 [ 0.  0.  0.  0.  0.  0.]]


In [4]:
# Zero-pad a symbolic (m, n) matrix A by `p` elements on every side with TE.
p = 1 # padding size
n, m = te.var('n'), te.var('m')
# A has shape (m, n); both extents are symbolic and fixed only at call time.
A = te.placeholder((m, n), name='A')
# B has shape (m + 2p, n + 2p). Indices in the border band produce 0;
# interior indices read A shifted back by p along both axes.
# NOTE(review): relies on TVM type-matching to turn the int literal 0 into
# A's dtype (placeholder default float32) — confirm if A's dtype changes.
B = te.compute((m+p*2, n+p*2),
                lambda i, j: te.if_then_else(te.any(i<p, i>=m+p, j<p, j>=n+p), 
                                             0, A[i-p, j-p]),
                name='B')
# Lower the TE graph to a TensorIR PrimFunc and print it.
te_func = te.create_prim_func([A, B])
te_func.show()
# Compile the PrimFunc for the CPU via LLVM.
mod = tvm.build(te_func, target="llvm")

验证结果：

In [5]:
# Run the compiled kernel on the 3x4 reference input and display the output.
a_nd = tvm.nd.array(a_np)
# Output buffer shaped like the NumPy reference; the kernel writes every
# element of B, so uninitialized storage is fine.
b_nd = tvm.nd.empty(b_np.shape)
# Only the arrays are passed: the symbolic m and n are taken from their shapes.
mod(a_nd, b_nd)
b_nd

<tvm.nd.NDArray shape=(5, 6), cpu(0)>
array([[ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  2.,  3.,  4.,  0.],
       [ 0.,  5.,  6.,  7.,  8.,  0.],
       [ 0.,  9., 10., 11., 12.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

## 通用 2D 填充

In [16]:
# General 2-D padding of a symbolic 4-D tensor (batch_size, kernel_size,
# height, width). Only the last two axes are padded: by `ph` rows and `pw`
# columns on each side, filled with `val`.
val = 0
dtype = "float32"
ph, pw = te.var("hpad"), te.var("wpad")
batch_size = te.var("batch_size")
kernel_size = te.var("kernel_size")
height = te.var("height")
width = te.var("width")
shape = batch_size, kernel_size, height, width
pad_shape = batch_size, kernel_size, height+2*ph, width+2*pw
data = te.placeholder(shape, dtype=dtype)
pad_data = te.compute(
            pad_shape,
            # `*i` keeps the lambda rank-agnostic: i[:-2] are the untouched
            # leading indices, i[-2] / i[-1] index the padded height / width.
            lambda *i: te.if_then_else(
                te.any(i[-2]<ph, i[-2]>=height+ph, i[-1]<pw, i[-1]>=width+pw),
                val, data[i[:-2]+(i[-2]-ph, i[-1]-pw)]),
            name='pad_data')
te_func = te.create_prim_func([data, pad_data])
te_func.show()
# Build through the TE schedule API with every symbolic var listed as an
# explicit argument. Presumably this is needed because ph/pw only occur
# inside shape expressions (height+2*ph, ...) and cannot be bound from a
# buffer dimension alone — TODO confirm.
# NOTE(review): te.create_schedule is the legacy TE schedule API and may be
# unavailable in newer TVM releases; verify against the installed version.
sch = te.create_schedule(pad_data.op)
mod = tvm.build(sch, [data, pad_data, batch_size, kernel_size, height, width, ph, pw], target="llvm")

In [34]:
def pad2d(X, ph, pw, val=0, name="pad_data"):
    """Pad the last two axes of ``X`` with a constant value.

    Leading axes are left unchanged. The second-to-last axis grows by
    ``2 * ph`` and the last axis by ``2 * pw``.

    Parameters
    ----------
    X : te.Tensor
        Input tensor with at least two dimensions.
    ph, pw : int or PrimExpr
        Padding added on each side of the height (second-to-last) and
        width (last) axes.
    val : int, float or PrimExpr, default 0
        Padding value. Python scalars are cast to ``X.dtype`` so the output
        keeps the input's dtype (same convention as ``numpy.pad``).
    name : str, default "pad_data"
        Name of the resulting compute op.

    Returns
    -------
    te.Tensor
        Tensor of shape ``(*X.shape[:-2], nh + 2*ph, nw + 2*pw)`` where
        ``nh, nw = X.shape[-2:]``.

    Raises
    ------
    ValueError
        If ``X`` has fewer than two dimensions.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(X.shape) < 2:
        raise ValueError(
            f"pad2d expects a tensor with at least 2 dimensions, got {len(X.shape)}"
        )
    # if_then_else type-matches its two branches, so an untyped Python scalar
    # can change the output dtype (e.g. int8 input padded with 0 -> int32).
    # Casting to X.dtype keeps the output dtype equal to the input dtype.
    if isinstance(val, (int, float)):
        val = tvm.tir.const(val, X.dtype)
    nh, nw = X.shape[-2], X.shape[-1]
    return te.compute(
        (*X.shape[:-2], nh + ph * 2, nw + pw * 2),
        # `*i` keeps the lambda rank-agnostic: i[:-2] are the leading indices,
        # i[-2] / i[-1] index the padded height / width.
        lambda *i: te.if_then_else(
            te.any(i[-2] < ph, i[-2] >= nh + ph, i[-1] < pw, i[-1] >= nw + pw),
            val,
            X[i[:-2] + (i[-2] - ph, i[-1] - pw)],
        ),
        name=name,
    )

In [35]:
# Exercise pad2d on a 3-D input: the leading axis (2) is untouched, and the
# last two axes (3, 4) get 1 row and 2 columns per side -> shape (2, 5, 8).
A = te.placeholder((2, 3, 4), name="data")
B = pad2d(A, 1, 2)
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")
a = tvm.nd.array(np.ones((2, 3, 4), dtype='float32'))
# np.empty is fine for the output: the kernel writes every element of B.
b = tvm.nd.array(np.empty((2, 5, 8), dtype='float32'))
mod(a, b)
print(b)

[[[0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]]]
