# FastGlobalAvgPool
加载所需包:
import testing
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
wildcard, is_op,
is_constant,
is_tuple,
# is_tuple_get_item,
DFPatternCallback,
rewrite
)
from tvm.relay import transform as _transform
构建模型:
import torch
from torch import nn
from torch.nn import functional as F
class FastGlobalAvgPool(nn.Module):
    """Global average pooling over the spatial dims of an NCHW tensor.

    Numerically equivalent to ``nn.AdaptiveAvgPool2d(1)`` (verified below),
    but implemented as a flatten-then-mean, which traces to a simple
    reshape/mean/reshape graph.

    Args:
        flatten: if True, return shape ``(N, C)``; otherwise keep the
            conventional pooled shape ``(N, C, 1, 1)``.
    """

    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        if self.flatten:
            # Collapse H*W into one axis and average it away -> (N, C).
            in_size = x.size()
            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
        else:
            # Same mean, then restore the 1x1 spatial dims -> (N, C, 1, 1).
            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
模型推理:
# Random NCHW input and one forward pass through the custom pooling layer.
xx = torch.randn(1, 64, 8, 9)
m1 = FastGlobalAvgPool()
y1 = m1(xx)
# Display-only: drops the trailing 1x1 spatial dims; y1 keeps (1, 64, 1, 1).
y1.squeeze()
tensor([ 0.0662, -0.0284, -0.0418, -0.0225, 0.0498, 0.1133, -0.2982, 0.1316,
0.0266, 0.0706, -0.2277, 0.0217, 0.1399, -0.1460, -0.0213, 0.0273,
-0.1333, 0.0970, -0.0018, 0.0033, -0.1182, 0.1267, -0.2266, 0.1154,
-0.2021, 0.0986, -0.0339, 0.0458, -0.0835, -0.2597, -0.1705, -0.0668,
0.0603, 0.0841, -0.0064, -0.0208, 0.0534, 0.1446, 0.0612, 0.0570,
0.0104, -0.1448, -0.0248, 0.0502, -0.0378, -0.0554, -0.0469, 0.1171,
-0.0959, -0.0214, -0.1768, 0.0065, -0.0934, -0.0269, -0.1720, 0.1292,
-0.0320, 0.0377, 0.0316, -0.2030, 0.0909, -0.2396, -0.0421, -0.0082])
等价于:
# Reference implementation: PyTorch's built-in adaptive average pooling with
# output size 1x1 should yield identical values on the same input.
m2 = nn.AdaptiveAvgPool2d(1)
y2 = m2(xx)
# Display-only squeeze; y2 itself keeps shape (1, 64, 1, 1).
y2.squeeze()
tensor([ 0.0662, -0.0284, -0.0418, -0.0225, 0.0498, 0.1133, -0.2982, 0.1316,
0.0266, 0.0706, -0.2277, 0.0217, 0.1399, -0.1460, -0.0213, 0.0273,
-0.1333, 0.0970, -0.0018, 0.0033, -0.1182, 0.1267, -0.2266, 0.1154,
-0.2021, 0.0986, -0.0339, 0.0458, -0.0835, -0.2597, -0.1705, -0.0668,
0.0603, 0.0841, -0.0064, -0.0208, 0.0534, 0.1446, 0.0612, 0.0570,
0.0104, -0.1448, -0.0248, 0.0502, -0.0378, -0.0554, -0.0469, 0.1171,
-0.0959, -0.0214, -0.1768, 0.0065, -0.0934, -0.0269, -0.1720, 0.1292,
-0.0320, 0.0377, 0.0316, -0.2030, 0.0909, -0.2396, -0.0421, -0.0082])
数值一致性检验:
# Both pooling variants must agree within the default (tight) tolerances.
np.testing.assert_allclose(y1.numpy(), y2.numpy())
转换为 Relay 模型:
# Graph input name and static shape expected by the Relay frontend.
name = "x"
shape = (1, 1024, 8, 6)
# Float input in [0, 1); same array is reused for both executor runs below.
data_np = (np.random.randint(0, 256, shape)/255).astype("float32")
data_torch = torch.from_numpy(data_np)
model = FastGlobalAvgPool().eval()
# TorchScript tracing fixes the input shape, so the Relay graph is static.
scripted_model = torch.jit.trace(model, data_torch).eval()
shape_list = [(name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
# Run type inference first so the printed IR carries full tensor types.
_transform.InferType()(tvm.IRModule.from_expr(mod["main"])).show()
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 1024, 1, 1), float32] {
%0 = reshape(%x, newshape=[1, 1024, -1]) /* ty=Tensor[(1, 1024, 48), float32] span=aten::view_0:0:0 */;
%1 = mean(%0, axis=[-1]) /* ty=Tensor[(1, 1024), float32] span=aten::mean_0:0:0 */;
reshape(%1, newshape=[1, 1024, 1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] span=aten::view_1:0:0 */
}
from tvm_book.transforms.simplify import FastGlobalAvgPoolSimplify
# NOTE: the trailing "?" is IPython help syntax (not plain Python); it prints
# the class docstring shown below.
FastGlobalAvgPoolSimplify?
Init signature: FastGlobalAvgPoolSimplify()
Docstring:
简化 reshape+mean+reshape 为 nn.adaptive_avg_pool2d(%x, output_size=[1])
简化
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 1024, 1, 1), float32] {
%0 = reshape(%x, newshape=[1, 1024, -1]) /* ty=Tensor[(1, 1024, 48), float32] span=aten::view_0:0:0 */;
%1 = mean(%0, axis=[-1]) /* ty=Tensor[(1, 1024), float32] span=aten::mean_0:0:0 */;
reshape(%1, newshape=[1, 1024, 1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] span=aten::view_1:0:0 */
}
为
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) {
nn.adaptive_avg_pool2d(%x, output_size=[1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */
}
File: /media/pc/data/lxw/ai/tvm-book/src/tvm_book/transforms/simplify.py
Type: type
Subclasses:
# Apply the dataflow-pattern rewrite: reshape+mean+reshape collapses into a
# single nn.adaptive_avg_pool2d call in the new module.
run_mod = tvm.IRModule()
run_mod["main"] = rewrite(FastGlobalAvgPoolSimplify(), mod["main"])
run_mod.show()
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) {
nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */
}
验证数值一致性:
def _build_and_run(ir_mod):
    """Compile *ir_mod* for llvm, run once on ``data_np``, return all outputs.

    Shared by the original and rewritten modules so both go through an
    identical build/execute path (AlterOpLayout disabled in both cases).
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
        lib = relay.build(ir_mod, target="llvm", params=params)
    func = lib[lib.libmod_name]
    module = tvm.contrib.graph_executor.GraphModule(func(tvm.cpu(0)))
    module.run(**{name: data_np})
    return [module.get_output(k).numpy() for k in range(module.get_num_outputs())]

# The pre-rewrite and post-rewrite modules must agree numerically.
origin_outputs = _build_and_run(mod)
outputs = _build_and_run(run_mod)
np.testing.assert_allclose(origin_outputs[0], outputs[0])