解构 TVM 量化#
import logging
import set_env
from d2py.utils.log_config import config_logging
from d2py.utils.file import mkdir
# Configure file logging for this walkthrough.
temp_dir = ".temp"
logger_name = "parse"
mkdir(temp_dir)

# Log to .temp/parse.log, truncating on each run; only keep records
# coming from the "te_compiler" module.
log_path = f"{temp_dir}/{logger_name}.log"
config_logging(
    log_path,
    logger_name,
    filemode="w",
    filter_mod_names={"te_compiler"},
)
logger = logging.getLogger(logger_name)
加载模块:
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform as _transform
from tvm.relay import expr as _expr
from tvm.relay import Call, Constant, Function
from tvm.ir.op import Op
from tvm.relay import op as _op
from tvm.relay import expr as _expr
from tvm_book.tvm_utils.llvm_utils import run_llvm_graph
定义简单网络:
def load_model(input_shape=(1, 3, 224, 224)):
    """Load and trace a pretrained frontend model.

    Builds a torchvision ResNet-18 with default ImageNet weights, puts it in
    eval mode, and returns a TorchScript trace driven by a random input.

    Args:
        input_shape: NCHW input shape used for the tracing example.
            Defaults to ``(1, 3, 224, 224)``.  (Changed from a mutable
            list default to a tuple — mutable default arguments are shared
            across calls; callers passing a list still work.)

    Returns:
        A ``torch.jit.ScriptModule`` traced from the eval-mode model.
    """
    import torch
    from torchvision.models import resnet18
    from torchvision.models.resnet import ResNet18_Weights

    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    data = torch.randn(*input_shape)
    return torch.jit.trace(model.eval(), data)
size = 224, 224              # (H, W)
input_shape = (1, 3, *size)  # NCHW
input_name = "data"

# Trace the eager PyTorch model, then translate the TorchScript graph
# into a Relay module plus its parameter dict.
traced_model = load_model(input_shape).eval()
shape_list = [(input_name, input_shape)]
origin_mod, params = relay.frontend.from_pytorch(traced_model, shape_list)
先解构 resnet18 第一个计算块:
# Slice out the first compute block (expressions up to node id 3).
# NOTE(review): "extract_intermdeiate_expr" looks like a typo, but it appears
# to be the actual (misspelled) name in TVM's relay.analysis API — confirm
# against the installed TVM version before renaming.
mod = relay.analysis.extract_intermdeiate_expr(origin_mod, 3)
# Re-run type inference so the truncated module carries checked types.
mod = _transform.InferType()(mod)
转换前端模型为 relay 模型:
def _bind_params(func, params):
    """Bind the values in ``params`` as constants onto ``func``'s parameters.

    Entries in ``params`` whose name does not match any function parameter
    are silently skipped; a name shared by several parameters is an error.
    """
    # Map each parameter name to its var; a duplicated name maps to None
    # so we can detect the ambiguity at bind time.
    vars_by_name = {}
    for var in func.params:
        key = var.name_hint
        vars_by_name[key] = None if key in vars_by_name else var

    bindings = {}
    for name, value in params.items():
        if name not in vars_by_name:
            continue
        var = vars_by_name[name]
        if var is None:
            raise ValueError(f"Multiple args in the function have name {name}")
        bindings[var] = _expr.const(value)
    return _expr.bind(func, bindings)
print('原始模型:')
mod.show()
# Bind the frontend-extracted constants onto main's parameters so the
# optimizer can see (and fold) them.
if params:
    mod["main"] = _bind_params(mod["main"], params)
print('原始模型(绑定参数):')
mod.show()
# Simplify and constant-fold. Order matters: SimplifyInference expands
# batch_norm into mul/add, FoldScaleAxis then folds the scale into the conv
# weight, and FoldConstant runs twice so constants created by the passes in
# between are folded as well.
optimize = tvm.transform.Sequential([
    _transform.SimplifyInference(),
    _transform.FoldConstant(),
    _transform.FoldScaleAxis(),
    _transform.CanonicalizeOps(),
    _transform.FoldConstant(),
])
with tvm.transform.PassContext(opt_level=3):
    run_mod = optimize(mod)
print('原始模型(化简后):')
run_mod.show()
原始模型:
原始模型(绑定参数):
原始模型(化简后):
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */, %aten::_convolution_0.weight: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] span=aten::_convolution_0.weight:0:0 */, %aten::batch_norm_0.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %aten::batch_norm_0.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %aten::batch_norm_0.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %aten::batch_norm_0.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
%0 = nn.conv2d(%data, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
%1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) span=aten::batch_norm_0:0:0 */;
%2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0:0:0 */;
nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0], strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
%1 = nn.batch_norm(%0, meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4]) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) span=aten::batch_norm_0:0:0 */;
%2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0:0:0 */;
nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}
查看化简前后卷积参数变化:
class _Transform(tvm.relay.ExprMutator):
    """Walk a Relay function and record the weight of every conv2d call.

    After ``visit``, ``binds`` maps ``"nn.conv2d_<i>"`` to the weight
    expression (second argument) of the i-th conv2d encountered.
    """

    def __init__(self):
        super().__init__()
        self.binds = {}   # "nn.conv2d_<i>" -> weight expression
        self.func_id = 0  # running index of conv2d calls seen so far

    def visit_call(self, call):
        # Recurse first (post-order), then rebuild the call node.
        op = self.visit(call.op)
        args = [self.visit(arg) for arg in call.args]
        rebuilt = Call(op, args, call.attrs, call.type_args, call.span)
        if isinstance(op, Op) and op.name == "nn.conv2d":
            # args[1] is the conv weight.
            self.binds[f"{op.name}_{self.func_id}"] = args[1]
            self.func_id += 1
        return rebuilt
# Pull the first conv weight out of the bound (pre-optimization) module...
transform = _Transform()
transform.visit(mod["main"])
weight_ori = transform.binds["nn.conv2d_0"]

# ...and out of the simplified module, for comparison.
transform = _Transform()
transform.visit(run_mod["main"])
weight = transform.binds["nn.conv2d_0"]

# Peek at a 5x5 corner of the original weight tensor.
weight_ori.data.numpy()[0, 0, :5, :5]
array([[-0.01041935, -0.00613561, -0.00180978, 0.07484142, 0.05661485],
[ 0.01108271, 0.00952757, -0.10992692, -0.28050068, -0.27123755],
[-0.00694335, 0.05908897, 0.29548222, 0.587196 , 0.5197189 ],
[ 0.03050456, -0.06701802, -0.29841137, -0.4386757 , -0.27085286],
[-0.02753477, 0.01604508, 0.07259498, -0.05410165, -0.33284944]],
dtype=float32)
# Same 5x5 corner of the simplified module's weight — the values differ from
# weight_ori because the optimization passes (presumably FoldScaleAxis, after
# SimplifyInference expanded batch_norm) folded a per-channel scale into the
# convolution kernel.
weight.data.numpy()[0, 0, :5, :5]
array([[-0.00242674, -0.00142902, -0.00042151, 0.01743106, 0.01318597],
[ 0.00258124, 0.00221903, -0.0256027 , -0.06533046, -0.06317302],
[-0.00161715, 0.01376221, 0.06881975, 0.13676181, 0.12104595],
[ 0.00710471, -0.01560894, -0.06950197, -0.10217047, -0.06308342],
[-0.00641303, 0.00373701, 0.01690785, -0.01260063, -0.07752283]],
dtype=float32)