FastGlobalAvgPool

FastGlobalAvgPool

导入所需的包：

import testing
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
    wildcard, is_op, 
    is_constant, 
    is_tuple,
    # is_tuple_get_item,
    DFPatternCallback,
    rewrite
)
from tvm.relay import transform as _transform

构建 PyTorch 模型：

import torch
from torch import nn
from torch.nn import functional as F

class FastGlobalAvgPool(nn.Module):
    """Global average pooling written as ``view`` followed by ``mean``.

    All dimensions after the first two are merged into a single axis and
    averaged.  For a 4-D ``(N, C, H, W)`` input the default output has shape
    ``(N, C, 1, 1)``, matching ``nn.AdaptiveAvgPool2d(1)``.

    Args:
        flatten: if True, return the pooled ``(N, C)`` tensor directly instead
            of reshaping it back to ``(N, C, 1, 1)``.
    """

    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        n, c = x.size(0), x.size(1)
        # Collapse every trailing dim into one axis, then average over it.
        # ``view`` needs a memory layout compatible with the requested shape.
        pooled = x.view(n, c, -1).mean(dim=-1)
        if self.flatten:
            return pooled
        return pooled.view(n, c, 1, 1)

模型推理：

# 4-D random input of shape (1, 64, 8, 9). No seed is set in this cell, so the
# printed values below will differ from run to run.
xx = torch.randn(1, 64, 8, 9)
m1 = FastGlobalAvgPool()  # flatten=False -> output shape (1, 64, 1, 1)
y1 = m1(xx)
y1.squeeze()  # drop the singleton dims for display: shape (64,)
tensor([ 0.0662, -0.0284, -0.0418, -0.0225,  0.0498,  0.1133, -0.2982,  0.1316,
         0.0266,  0.0706, -0.2277,  0.0217,  0.1399, -0.1460, -0.0213,  0.0273,
        -0.1333,  0.0970, -0.0018,  0.0033, -0.1182,  0.1267, -0.2266,  0.1154,
        -0.2021,  0.0986, -0.0339,  0.0458, -0.0835, -0.2597, -0.1705, -0.0668,
         0.0603,  0.0841, -0.0064, -0.0208,  0.0534,  0.1446,  0.0612,  0.0570,
         0.0104, -0.1448, -0.0248,  0.0502, -0.0378, -0.0554, -0.0469,  0.1171,
        -0.0959, -0.0214, -0.1768,  0.0065, -0.0934, -0.0269, -0.1720,  0.1292,
        -0.0320,  0.0377,  0.0316, -0.2030,  0.0909, -0.2396, -0.0421, -0.0082])

在 `flatten=False` 时，它等价于 `nn.AdaptiveAvgPool2d(1)`：

# Reference: PyTorch's built-in adaptive average pooling to a 1x1 output,
# applied to the same input ``xx``.
m2 = nn.AdaptiveAvgPool2d(1)
y2 = m2(xx)
y2.squeeze()  # shape (64,), for side-by-side comparison with y1
tensor([ 0.0662, -0.0284, -0.0418, -0.0225,  0.0498,  0.1133, -0.2982,  0.1316,
         0.0266,  0.0706, -0.2277,  0.0217,  0.1399, -0.1460, -0.0213,  0.0273,
        -0.1333,  0.0970, -0.0018,  0.0033, -0.1182,  0.1267, -0.2266,  0.1154,
        -0.2021,  0.0986, -0.0339,  0.0458, -0.0835, -0.2597, -0.1705, -0.0668,
         0.0603,  0.0841, -0.0064, -0.0208,  0.0534,  0.1446,  0.0612,  0.0570,
         0.0104, -0.1448, -0.0248,  0.0502, -0.0378, -0.0554, -0.0469,  0.1171,
        -0.0959, -0.0214, -0.1768,  0.0065, -0.0934, -0.0269, -0.1720,  0.1292,
        -0.0320,  0.0377,  0.0316, -0.2030,  0.0909, -0.2396, -0.0421, -0.0082])

检验两者的数值一致性：

# y1 (view + mean) and y2 (AdaptiveAvgPool2d) come from different ops, so
# compare with float32-appropriate tolerances instead of the default
# rtol=1e-7 / atol=0. Those defaults are below float32 epsilon and leave no
# slack for means close to zero.
np.testing.assert_allclose(y1.numpy(), y2.numpy(), rtol=1e-5, atol=1e-6)

用 `torch.jit.trace` 追踪模型，再转换为 Relay 模型：

# Input name and static shape handed to the Relay PyTorch frontend.
name = "x"
shape = (1, 1024, 8, 6)
# Random integers in [0, 255] scaled into [0, 1] and cast to float32.
data_np = (np.random.randint(0, 256, shape)/255).astype("float32")
data_torch = torch.from_numpy(data_np)

model = FastGlobalAvgPool().eval()
# Trace with a concrete example input; the traced module is what the frontend imports.
scripted_model = torch.jit.trace(model, data_torch).eval()
shape_list = [(name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
# Print the imported function with inferred types (reshape -> mean -> reshape, see output).
_transform.InferType()(tvm.IRModule.from_expr(mod["main"])).show()
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 1024, 1, 1), float32] {
  %0 = reshape(%x, newshape=[1, 1024, -1]) /* ty=Tensor[(1, 1024, 48), float32] span=aten::view_0:0:0 */;
  %1 = mean(%0, axis=[-1]) /* ty=Tensor[(1, 1024), float32] span=aten::mean_0:0:0 */;
  reshape(%1, newshape=[1, 1024, 1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] span=aten::view_1:0:0 */
}
from tvm_book.transforms.simplify import FastGlobalAvgPoolSimplify

FastGlobalAvgPoolSimplify?
Init signature: FastGlobalAvgPoolSimplify()
Docstring:     
简化 reshape+mean+reshape 为 nn.adaptive_avg_pool2d(%x, output_size=[1]) 

简化 
    def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 1024, 1, 1), float32] {
        %0 = reshape(%x, newshape=[1, 1024, -1]) /* ty=Tensor[(1, 1024, 48), float32] span=aten::view_0:0:0 */;
        %1 = mean(%0, axis=[-1]) /* ty=Tensor[(1, 1024), float32] span=aten::mean_0:0:0 */;
        reshape(%1, newshape=[1, 1024, 1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] span=aten::view_1:0:0 */
        }
为
    def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) {
        nn.adaptive_avg_pool2d(%x, output_size=[1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */
        }
File:           /media/pc/data/lxw/ai/tvm-book/src/tvm_book/transforms/simplify.py
Type:           type
Subclasses:
# Apply the FastGlobalAvgPoolSimplify pattern callback to the imported function
# and store the rewritten function as "main" in a fresh module. The printed IR
# shows the reshape + mean + reshape chain replaced by nn.adaptive_avg_pool2d.
run_mod = tvm.IRModule()
run_mod["main"] = rewrite(FastGlobalAvgPoolSimplify(), mod["main"])
run_mod.show()
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) {
  nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */
}

分别编译改写前后的 Relay 模型，验证二者输出的数值一致性：

def _build_and_run(ir_mod, params, inputs):
    """Compile ``ir_mod`` for ``llvm``, run it once on CPU, and return every output as NumPy.

    Args:
        ir_mod: Relay module containing the function to compile.
        params: parameter dict passed to ``relay.build``.
        inputs: mapping from input name to NumPy array, forwarded to ``module.run``.

    Returns:
        A list with one NumPy array per graph output.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
        lib = relay.build(ir_mod, target="llvm", params=params)
    module = tvm.contrib.graph_executor.GraphModule(lib[lib.libmod_name](tvm.cpu(0)))
    module.run(**inputs)
    return [module.get_output(k).numpy() for k in range(module.get_num_outputs())]


# Reference: the imported reshape + mean + reshape graph.
origin_outputs = _build_and_run(mod, params, {name: data_np})
# Candidate: the graph rewritten to nn.adaptive_avg_pool2d.
outputs = _build_and_run(run_mod, params, {name: data_np})

assert len(outputs) == len(origin_outputs), (len(outputs), len(origin_outputs))
# Different operators are compiled on each side, so allow float32 round-off
# rather than requiring bit-exact equality.
for actual, expected in zip(outputs, origin_outputs):
    np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-6)