填充运算
%cd ../..
import set_env
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te
简单示例
下面的例子用 \(0\) 在矩阵 \(A\) 的四周填充一圈,即把 \(3 \times 4\) 的矩阵扩展为 \(5 \times 6\) 的矩阵。
# Reference result computed with NumPy: the 3x4 matrix holding 1..12 is
# surrounded by a one-element border of zeros, giving a 5x6 matrix.
a_np = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
b_np = np.pad(a_np, pad_width=1, mode='constant', constant_values=0)
print(b_np)
[[ 0. 0. 0. 0. 0. 0.]
[ 0. 1. 2. 3. 4. 0.]
[ 0. 5. 6. 7. 8. 0.]
[ 0. 9. 10. 11. 12. 0.]
[ 0. 0. 0. 0. 0. 0.]]
p = 1  # padding size applied on every side of the matrix
n = te.var('n')
m = te.var('m')
A = te.placeholder((m, n), name='A')


def _zero_border_or_copy(i, j):
    """Return 0 for positions inside the border of width p, else A shifted by p."""
    # The parameter names i and j are kept on purpose: te.compute uses them
    # as the axis names of the generated loop nest.
    in_border = te.any(i < p, i >= m + p, j < p, j >= n + p)
    return te.if_then_else(in_border, 0, A[i - p, j - p])


# The output grows by p on each side, i.e. by 2 * p along each axis.
B = te.compute((m + p * 2, n + p * 2), _zero_border_or_copy, name='B')
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")
# from tvm.script import tir as T
@T.prim_func
def func(var_A: T.handle, var_B: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
A = T.match_buffer(var_A, [m, n], dtype="float32")
B = T.match_buffer(var_B, [m + 2, n + 2], dtype="float32")
# body
# with T.block("root")
for i0, i1 in T.grid(m + 2, n + 2):
with T.block("B"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(A[i - 1, j - 1])
T.writes(B[i, j])
B[i, j] = T.if_then_else(i < 1 or m + 1 <= i or j < 1 or n + 1 <= j, T.float32(0), A[i - 1, j - 1], dtype="float32")
验证结果:
# Run the compiled module on the NumPy input and compare the result with the
# NumPy reference b_np instead of relying on visual inspection only.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.empty(b_np.shape)
mod(a_nd, b_nd)
np.testing.assert_allclose(b_nd.numpy(), b_np)
# Keep the array as the last expression so the notebook still displays it.
b_nd
<tvm.nd.NDArray shape=(5, 6), cpu(0)>
array([[ 0., 0., 0., 0., 0., 0.],
[ 0., 1., 2., 3., 4., 0.],
[ 0., 5., 6., 7., 8., 0.],
[ 0., 9., 10., 11., 12., 0.],
[ 0., 0., 0., 0., 0., 0.]], dtype=float32)
通用 2D 填充
# Padding value and element type of the input tensor.
val = 0
dtype = "float32"

# Every extent, including the two padding sizes, is a symbolic variable.
ph = te.var("hpad")
pw = te.var("wpad")
batch_size, kernel_size, height, width = (
    te.var(label) for label in ("batch_size", "kernel_size", "height", "width")
)

shape = (batch_size, kernel_size, height, width)
# Only the last two axes grow: by ph / pw on each side.
pad_shape = (batch_size, kernel_size, height + 2 * ph, width + 2 * pw)
data = te.placeholder(shape, dtype=dtype)


def _pad_value_or_source(*i):
    """Return val inside the border of the last two axes, else the shifted input element."""
    in_border = te.any(
        i[-2] < ph, i[-2] >= height + ph, i[-1] < pw, i[-1] >= width + pw
    )
    # Leading indices pass through unchanged; the last two are shifted back by the padding.
    source_index = i[:-2] + (i[-2] - ph, i[-1] - pw)
    return te.if_then_else(in_border, val, data[source_index])


pad_data = te.compute(pad_shape, _pad_value_or_source, name='pad_data')
te_func = te.create_prim_func([data, pad_data])
te_func.show()

# Build through a schedule, passing the symbolic sizes as explicit arguments
# after the two tensors.
sch = te.create_schedule(pad_data.op)
mod = tvm.build(
    sch,
    [data, pad_data, batch_size, kernel_size, height, width, ph, pw],
    target="llvm",
)
# from tvm.script import tir as T
@T.prim_func
def func(var_placeholder: T.handle, var_pad_data: T.handle):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
batch_size = T.var("int32")
height = T.var("int32")
hpad = T.var("int32")
kernel_size = T.var("int32")
width = T.var("int32")
wpad = T.var("int32")
placeholder = T.match_buffer(var_placeholder, [batch_size, kernel_size, height, width], dtype="float32")
pad_data = T.match_buffer(var_pad_data, [batch_size, kernel_size, height + 2 * hpad, width + 2 * wpad], dtype="float32")
# body
# with T.block("root")
for i0, i1, i2, i3 in T.grid(batch_size, kernel_size, hpad * 2 + height, wpad * 2 + width):
with T.block("pad_data"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(placeholder[i0_1, i1_1, i2_1 - hpad, i3_1 - wpad])
T.writes(pad_data[i0_1, i1_1, i2_1, i3_1])
pad_data[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 < hpad or height + hpad <= i2_1 or i3_1 < wpad or width + wpad <= i3_1, T.float32(0), placeholder[i0_1, i1_1, i2_1 - hpad, i3_1 - wpad], dtype="float32")
def pad2d(X, ph, pw, val=0, name="pad_data"):
    """Surround the last two axes of X with a constant border.

    Parameters
    ----------
    X : tensor with at least two dimensions
    ph, pw : padding added on each side of the second-to-last and last axis
    val : value written into the border, default 0
    name : name passed to ``te.compute`` for the result

    Returns
    -------
    The ``te.compute`` result; its leading extents equal those of X and its
    last two extents are grown by ``2 * ph`` and ``2 * pw``.
    """
    assert len(X.shape) >= 2
    in_h = X.shape[-2]
    in_w = X.shape[-1]
    out_shape = (*X.shape[:-2], in_h + ph * 2, in_w + pw * 2)

    def _border_or_source(*idx):
        # idx holds one index per output axis; only the last two are padded.
        row, col = idx[-2], idx[-1]
        outside = te.any(row < ph, row >= in_h + ph, col < pw, col >= in_w + pw)
        # Leading indices pass through; the last two are shifted back by the padding.
        source_index = idx[:-2] + (row - ph, col - pw)
        return te.if_then_else(outside, val, X[source_index])

    return te.compute(out_shape, _border_or_source, name=name)
# Pad a (2, 3, 4) tensor by 1 along the second-to-last axis and by 2 along the
# last axis, which yields a (2, 5, 8) result.
A = te.placeholder(shape=(2, 3, 4), name="data")
B = pad2d(A, ph=1, pw=2)
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")

# Feed an all-ones input so the zero border is easy to see in the printout.
a = tvm.nd.array(np.ones(shape=(2, 3, 4), dtype=np.float32))
b = tvm.nd.array(np.empty(shape=(2, 5, 8), dtype=np.float32))
mod(a, b)
print(b)
# from tvm.script import tir as T
@T.prim_func
def func(data: T.Buffer[(2, 3, 4), "float32"], pad_data: T.Buffer[(2, 5, 8), "float32"]):
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i0, i1, i2 in T.grid(2, 5, 8):
with T.block("pad_data"):
i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(data[i0_1, i1_1 - 1, i2_1 - 2])
T.writes(pad_data[i0_1, i1_1, i2_1])
pad_data[i0_1, i1_1, i2_1] = T.if_then_else(i1_1 < 1 or 4 <= i1_1 or i2_1 < 2 or 6 <= i2_1, T.float32(0), data[i0_1, i1_1 - 1, i2_1 - 2], dtype="float32")
[[[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 1. 1. 1. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]]]