Custom VTA Graph Pack#
from copy import deepcopy
import tvm
from tvm import relay
from vta_utils.pack_tool import graph_pack, WithVTAFunctionTransform
VTA Model Example#
from torch import nn
import torch
class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 36, 3, 1, 1, bias=True)
        self.bn = nn.BatchNorm2d(36)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
pt_model = Model().eval().float()
ishape = (1, 3, 4, 4)
input_name = "data"
input_shapes = [(input_name, ishape)]
# script_module = torch.jit.script(pt_model)
# mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
idata = torch.randn(ishape).type(torch.float32)
traced_model = torch.jit.trace(pt_model, idata)
# Translate traced_model into a TVM Relay module
mod, params = relay.frontend.from_pytorch(traced_model, input_shapes)
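Before quantizing, a quick numerical sanity check helps confirm the conversion. The executor-based comparison below is a minimal sketch (not part of the original flow), assuming a TVM build with the LLVM CPU backend:
import numpy as np

with torch.no_grad():
    torch_out = pt_model(idata).numpy()
with tvm.transform.PassContext(opt_level=3):
    executor = relay.create_executor(
        "graph", mod=mod, device=tvm.cpu(), target="llvm", params=params
    ).evaluate()
tvm_out = executor(idata.numpy()).numpy()
# The float Relay module should closely match the traced PyTorch model.
np.testing.assert_allclose(torch_out, tvm_out, rtol=1e-4, atol=1e-4)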
# Quantize the Relay module to int8
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(skip_conv_layers=[], weight_scale="max"):
        mod = relay.quantize.quantize(mod, params)
mod.show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 36, 4, 4), float32] {
%0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%1 = round(%0) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */;
%4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(36, 3, 3, 3), int8] */, padding=[1, 1, 1, 1], channels=36, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
%5 = add(%4, meta[relay.Constant][1] /* ty=Tensor[(36, 1, 1), int32] */) /* ty=Tensor[(1, 36, 4, 4), int32] */;
%6 = fixed_point_multiply(%5, multiplier=0, shift=0) /* ty=Tensor[(1, 36, 4, 4), int32] */;
%7 = cast(%6, dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
%8 = add(%7, meta[relay.Constant][2] /* ty=Tensor[(36, 1, 1), int32] */) /* ty=Tensor[(1, 36, 4, 4), int32] */;
%9 = nn.relu(%8) /* ty=Tensor[(1, 36, 4, 4), int32] */;
%10 = cast(%9, dtype="int64") /* ty=Tensor[(1, 36, 4, 4), int64] */;
%11 = fixed_point_multiply(%10, multiplier=0, shift=0) /* ty=Tensor[(1, 36, 4, 4), int64] */;
%12 = clip(%11, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 36, 4, 4), int64] */;
%13 = cast(%12, dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
%14 = cast(%13, dtype="int8") /* ty=Tensor[(1, 36, 4, 4), int8] */;
%15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 36, 4, 4), int8] */;
%16 = cast(%15, dtype="float32") /* ty=Tensor[(1, 36, 4, 4), float32] */;
multiply(%16, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 36, 4, 4), float32] */
}
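The multiply by 16f at the top and by 0.0625f (= 1/16) at the bottom are the input quantization and output dequantization scales. To gauge the error introduced by int8 quantization, a short check in the same sketch style as above (reusing np and torch_out from the earlier snippet):
with tvm.transform.PassContext(opt_level=3):
    qexecutor = relay.create_executor(
        "graph", mod=mod, device=tvm.cpu(), target="llvm"
    ).evaluate()
quant_out = qexecutor(idata.numpy()).numpy()
print("max abs error vs float:", np.abs(quant_out - torch_out).max())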
VTA Graph Pack#
import vta

env = vta.get_env()
bfactor = env.BATCH
cfactor = env.BLOCK_OUT
weight_bits = env.WGT_WIDTH
run_mod = deepcopy(mod)
new_fn = graph_pack(
    run_mod["main"],
    bfactor, cfactor, weight_bits,
)
tvm.IRModule.from_expr(new_fn).show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
%0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%1 = round(%0) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */;
%4 = nn.pad(%3, 0 /* ty=int32 */, pad_width=[[0, 0], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(1, 16, 4, 4), int8] */;
%5 = reshape(%4, newshape=[1, 1, 1, 16, 4, 4]) /* ty=Tensor[(1, 1, 1, 16, 4, 4), int8] */;
%6 = nn.pad(meta[relay.Constant][0] /* ty=Tensor[(36, 3, 3, 3), int8] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(48, 16, 3, 3), int8] */;
%7 = reshape(%6, newshape=[3, 16, 1, 16, 3, 3]) /* ty=Tensor[(3, 16, 1, 16, 3, 3), int8] */;
%8 = transpose(%5, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */;
%9 = transpose(%7, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */;
%10 = nn.pad(meta[relay.Constant][1] /* ty=Tensor[(36, 1, 1), int32] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 0], [0, 0]]) /* ty=Tensor[(48, 1, 1), int32] */;
%11 = reshape(%10, newshape=[3, 16, 1, 1, 1]) /* ty=Tensor[(3, 16, 1, 1, 1), int32] */;
%12 = transpose(%11, axes=[0, 2, 3, 4, 1]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
%13 = nn.conv2d(%8, %9, padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%14 = broadcast_to(%12, shape=[3, 1, 1, 1, 16]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
%15 = add(%13, %14) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%16 = fixed_point_multiply(%15, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%17 = nn.pad(meta[relay.Constant][2] /* ty=Tensor[(36, 1, 1), int32] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 0], [0, 0]]) /* ty=Tensor[(48, 1, 1), int32] */;
%18 = reshape(%17, newshape=[3, 16, 1, 1, 1]) /* ty=Tensor[(3, 16, 1, 1, 1), int32] */;
%19 = transpose(%18, axes=[0, 2, 3, 4, 1]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
%20 = cast(%16, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%21 = broadcast_to(%19, shape=[3, 1, 1, 1, 16]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
%22 = add(%20, %21) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%23 = nn.relu(%22) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%24 = cast(%23, dtype="int64") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%25 = fixed_point_multiply(%24, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%26 = clip(%25, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%27 = cast(%26, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%28 = cast(%27, dtype="int8") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
%29 = annotation.stop_fusion(%28) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
%30 = cast(%29, dtype="float32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */;
multiply(%30, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}
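Note how graph_pack pads the channel dimensions up to multiples of the block size before blocking the layout: the 3 input channels become 16 (pad 13) and the 36 output channels become 48 (pad 12), matching the nn.pad calls above, after which the data and kernel take the NCHW1n16c and OIHW16o16i layouts. A small illustrative reconstruction of that rounding (round_up is a hypothetical helper, not the library's code):
def round_up(n, factor):
    """Round n up to the next multiple of factor."""
    return (n + factor - 1) // factor * factor

assert round_up(3, cfactor) == 16    # input channels: 3 -> 16, pad 13
assert round_up(36, cfactor) == 48   # output channels: 36 -> 48, pad 12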
Operator Fusion for the VTA Model#
Create the fusion patterns:
from vta_utils.vta_pattern import (
    preprocessing_pattern,
    pad_reshape_transpose_pattern,
    conv_add_activate_pattern,
    output_pattern,
)
pattern_table = [
    ("vta_preprocessing", preprocessing_pattern()),
    ("vta_reshape_transpose", pad_reshape_transpose_pattern()),
    ("vta_conv2d", conv_add_activate_pattern()),
    ("vta_output", output_pattern()),
]
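The actual pattern implementations live in vta_utils.vta_pattern and are not shown here. As a rough illustration, the sketch below reconstructs what preprocessing_pattern might look like from the PartitionedFromPattern string ("multiply_round_clip_cast_") visible in the fused IR further down; the real implementation may differ:
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

def preprocessing_pattern_sketch():
    # multiply -> round -> clip -> cast: the quantizer's input stage.
    x = is_op("multiply")(wildcard(), is_constant())
    x = is_op("round")(x)
    x = is_op("clip")(x)
    return is_op("cast")(x)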
Apply the operator fusion:
import vta
env = vta.get_env()
bfactor = env.BATCH
cfactor = env.BLOCK_OUT
weight_bits = env.WGT_WIDTH
prepare_transform = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.FoldConstant(),  # fold constant parameters
    relay.transform.MergeComposite(pattern_table),  # fuse operators into composite functions
    WithVTAFunctionTransform(),  # attach the ConvAttrs attribute to the fused vta_conv2d function
    relay.transform.InferType(),
])
run_mod = deepcopy(mod)
with tvm.transform.PassContext(opt_level=3):
    new_fn = graph_pack(run_mod["main"], bfactor, cfactor, weight_bits)
    run_mod = tvm.IRModule.from_expr(new_fn)
    run_mod = prepare_transform(run_mod)
run_mod.show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
%0 = @vta_preprocessing__0(%data) /* ty=Tensor[(1, 3, 4, 4), int8] */;
%1 = @vta_reshape_transpose__1(%0, 0 /* ty=int32 */) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */;
%2 = @vta_conv2d__2(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */, meta[relay.Constant][1] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, meta[relay.Constant][2] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
@vta_output__3(%2) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}
def @vta_conv2d__2(%FunctionVar_0_0: Tensor[(1, 1, 4, 4, 1, 16), int8] /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */, %FunctionVar_0_1: Tensor[(3, 1, 3, 3, 16, 16), int8] /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */, %FunctionVar_0_2: Tensor[(3, 1, 1, 1, 16), int32] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, %FunctionVar_0_3: Tensor[(3, 1, 1, 1, 16), int32] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, PartitionedFromPattern="nn.conv2d_add_fixed_point_multiply_cast_add_nn.relu_cast_fixed_point_multiply_clip_cast_cast_annotation.stop_fusion_", Composite="vta_conv2d", ConvAttrs={padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32"}) -> Tensor[(1, 3, 4, 4, 1, 16), int8] {
%3 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%4 = add(%3, %FunctionVar_0_2) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%5 = fixed_point_multiply(%4, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%6 = cast(%5, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%7 = add(%6, %FunctionVar_0_3) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%8 = nn.relu(%7) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%9 = cast(%8, dtype="int64") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%10 = fixed_point_multiply(%9, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%11 = clip(%10, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
%12 = cast(%11, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
%13 = cast(%12, dtype="int8") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
annotation.stop_fusion(%13) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */
}
def @vta_output__3(%FunctionVar_0_01: Tensor[(1, 3, 4, 4, 1, 16), int8] /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */, PartitionedFromPattern="cast_multiply_", Composite="vta_output") -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
%14 = cast(%FunctionVar_0_01, dtype="float32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */;
multiply(%14, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}
def @vta_preprocessing__0(%FunctionVar_0_02: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */, PartitionedFromPattern="multiply_round_clip_cast_", Composite="vta_preprocessing") -> Tensor[(1, 3, 4, 4), int8] {
%15 = multiply(%FunctionVar_0_02, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%16 = round(%15) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%17 = clip(%16, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
cast(%17, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */
}
def @vta_reshape_transpose__1(%FunctionVar_0_03: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */, %FunctionVar_0_11: int32 /* ty=int32 */, PartitionedFromPattern="nn.pad_reshape_transpose_", Composite="vta_reshape_transpose") -> Tensor[(1, 1, 4, 4, 1, 16), int8] {
%18 = nn.pad(%FunctionVar_0_03, %FunctionVar_0_11, pad_width=[[0, 0], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(1, 16, 4, 4), int8] */;
%19 = reshape(%18, newshape=[1, 1, 1, 16, 4, 4]) /* ty=Tensor[(1, 1, 1, 16, 4, 4), int8] */;
transpose(%19, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */
}
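Each composite function now carries a Composite attribute, and vta_conv2d additionally carries the ConvAttrs recorded by WithVTAFunctionTransform, presumably so that a later VTA-specific lowering step can recover the original convolution parameters. A quick way to inspect them (a sketch; assumes function attributes support membership tests, which holds for DictAttrs in recent TVM):
for gvar, func in run_mod.functions.items():
    if func.attrs is not None and "Composite" in func.attrs:
        print(gvar.name_hint, "->", func.attrs["Composite"])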