# Test
```python
import torch
from torch import nn, Tensor
from torch.nn import functional as F


class FastGlobalAvgPool(nn.Module):
    """Global average pooling via view+mean rather than a pooling kernel."""
    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        if self.flatten:
            # Collapse the spatial dims and reduce: (N, C, H, W) -> (N, C)
            in_size = x.size()
            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
        else:
            # Same reduction, but keep a (N, C, 1, 1) output shape
            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
```
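As a quick sanity check (not part of the original script), this pooling is numerically equivalent to `nn.AdaptiveAvgPool2d((1, 1))`, which is exactly what the Relay rewrite below will recover:

```python
# Illustrative check: view+mean pooling matches AdaptiveAvgPool2d((1, 1)).
_x = torch.randn(2, 16, 8, 6)
assert torch.allclose(FastGlobalAvgPool()(_x), nn.AdaptiveAvgPool2d((1, 1))(_x), atol=1e-6)
```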
```python
class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with configurable initialization and optional weight/bias freezing."""
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False,
                 bias_freeze=False, weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)
```
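A small illustrative check of the freezing behavior: `bias_freeze=True` (as used in `Head` below) excludes the bias from gradient updates while the weight still trains:

```python
bn = BatchNorm(1024, bias_freeze=True)
assert bn.weight.requires_grad and not bn.bias.requires_grad
```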
```python
class Head(nn.Module):
    """Classification head: global pooling -> BN bottleneck -> linear -> BN on logits."""
    def __init__(self, feat_dim=1024, num_class=9):
        super().__init__()
        self.pool_layer = FastGlobalAvgPool()
        self.bottleneck = BatchNorm(feat_dim, bias_freeze=True)
        self.bnneck = nn.BatchNorm1d(num_class)
        self.weight = nn.Parameter(Tensor(num_class, feat_dim))

    def forward(self, x: Tensor) -> Tensor:
        pool_feat = self.pool_layer(x)
        neck_feat = self.bottleneck(pool_feat)
        neck_feat = neck_feat.view(neck_feat.size(0), -1)
        logits = F.linear(neck_feat, self.weight)
        logits = self.bnneck(logits)
        return logits
```
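The head maps a `(1, 1024, 8, 6)` feature map to 9 logits. A shape-only smoke test (illustrative; `self.weight` is uninitialized at this point, so the values are meaningless):

```python
head = Head().eval()  # eval mode: BatchNorm raises on batch size 1 in training mode
assert head(torch.randn(1, 1024, 8, 6)).shape == (1, 9)
```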
```python
torch.cuda.empty_cache()
model_path = "/media/pc/data/board/arria10/lxw/tasks/tools/npu_user_demos/models/telecom_pt/Nin1_helmet_small/helmet_small.pth"
model = Head()
state_dict = torch.load(model_path, weights_only=False, map_location=torch.device('cpu'))
# Keep only the head/bottleneck weights and remap their names to this module's layout.
state_dict = {
    k.replace('bottleneck.0', 'bottleneck').replace('heads.', ''): v
    for k, v in state_dict['model'].items()
    if "heads" in k or "bottleneck" in k
}
model.load_state_dict(state_dict, strict=True)
model = model.eval()
```
```python
import set_env
import numpy as np
import tvm
from tvm import relay

name = "x"
shape = (1, 1024, 8, 6)
data_np = (np.random.randint(0, 256, shape) / 255).astype("float32")
data_torch = torch.from_numpy(data_np)
scripted_model = torch.jit.trace(model, data_torch).eval()
shape_list = [(name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
print(mod["main"])
```
```
fn (%x: Tensor[(1, 1024, 8, 6), float32] /* span=aten::size_0.x:0:0 */, %aten::batch_norm_0.weight: Tensor[(1024), float32] /* span=aten::batch_norm_0.weight:0:0 */, %aten::batch_norm_0.bias: Tensor[(1024), float32] /* span=aten::batch_norm_0.bias:0:0 */, %aten::batch_norm_0.running_mean: Tensor[(1024), float32] /* span=aten::batch_norm_0.running_mean:0:0 */, %aten::batch_norm_0.running_var: Tensor[(1024), float32] /* span=aten::batch_norm_0.running_var:0:0 */, %aten::linear_0.weight: Tensor[(9, 1024), float32] /* span=aten::linear_0.weight:0:0 */, %aten::batch_norm_1.weight: Tensor[(9), float32] /* span=aten::batch_norm_1.weight:0:0 */, %aten::batch_norm_1.bias: Tensor[(9), float32] /* span=aten::batch_norm_1.bias:0:0 */, %aten::batch_norm_1.running_mean: Tensor[(9), float32] /* span=aten::batch_norm_1.running_mean:0:0 */, %aten::batch_norm_1.running_var: Tensor[(9), float32] /* span=aten::batch_norm_1.running_var:0:0 */) {
  %0 = reshape(%x, newshape=[1, 1024, -1]) /* span=aten::view_0:0:0 */;
  %1 = mean(%0, axis=[-1]) /* span=aten::mean_0:0:0 */;
  %2 = reshape(%1, newshape=[1, 1024, 1, 1]) /* span=aten::view_1:0:0 */;
  %3 = nn.batch_norm(%2, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
  %4 = %3.0 /* span=aten::batch_norm_0:0:0 */;
  %5 = reshape(%4, newshape=[1, -1]) /* span=aten::view_2:0:0 */;
  %6 = nn.dense(%5, %aten::linear_0.weight, units=None) /* span=aten::linear_0:0:0 */;
  %7 = nn.batch_norm(%6, %aten::batch_norm_1.weight, %aten::batch_norm_1.bias, %aten::batch_norm_1.running_mean, %aten::batch_norm_1.running_var) /* span=aten::batch_norm_1:0:0 */;
  %7.0 /* span=aten::batch_norm_1:0:0 */
}
```
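The `reshape → mean → reshape` chain above is what `FastGlobalAvgPoolSimplify` collapses into a single `nn.adaptive_avg_pool2d`. The actual callback lives in `tvm_book.transforms.simplify`; a minimal sketch of such a dataflow-pattern rewrite (hypothetical `FastGlobalAvgPoolSimplifySketch`, assuming 4-D NCHW input) could look like this:

```python
from tvm.relay import nn as _nn
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard


class FastGlobalAvgPoolSimplifySketch(DFPatternCallback):
    """Sketch only: fold reshape -> mean -> reshape into adaptive_avg_pool2d."""
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        flat = is_op("reshape")(self.x)          # (N, C, H, W) -> (N, C, H*W)
        pooled = is_op("mean")(flat)             # reduce the flattened spatial axis
        self.pattern = is_op("reshape")(pooled)  # (N, C) -> (N, C, 1, 1)

    def callback(self, pre, post, node_map):
        # A production version should validate the reshape newshapes and the
        # mean axis before rewriting; omitted here for brevity.
        x = node_map[self.x][0]
        return _nn.adaptive_avg_pool2d(x, output_size=[1, 1])
```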
```python
from tvm.relay.dataflow_pattern import rewrite
from tvm_book.transforms.simplify import FastGlobalAvgPoolSimplify

run_mod = tvm.IRModule()
run_mod["main"] = rewrite(FastGlobalAvgPoolSimplify(), mod["main"])
run_mod.show()
```
```
def @main(%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */, %aten::batch_norm_0.weight: Tensor[(1024), float32] /* span=aten::batch_norm_0.weight:0:0 */, %aten::batch_norm_0.bias: Tensor[(1024), float32] /* span=aten::batch_norm_0.bias:0:0 */, %aten::batch_norm_0.running_mean: Tensor[(1024), float32] /* span=aten::batch_norm_0.running_mean:0:0 */, %aten::batch_norm_0.running_var: Tensor[(1024), float32] /* span=aten::batch_norm_0.running_var:0:0 */, %aten::linear_0.weight: Tensor[(9, 1024), float32] /* span=aten::linear_0.weight:0:0 */, %aten::batch_norm_1.weight: Tensor[(9), float32] /* span=aten::batch_norm_1.weight:0:0 */, %aten::batch_norm_1.bias: Tensor[(9), float32] /* span=aten::batch_norm_1.bias:0:0 */, %aten::batch_norm_1.running_mean: Tensor[(9), float32] /* span=aten::batch_norm_1.running_mean:0:0 */, %aten::batch_norm_1.running_var: Tensor[(9), float32] /* span=aten::batch_norm_1.running_var:0:0 */) {
  %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* span=aten::batch_norm_0:0:0 */;
  %3 = reshape(%2, newshape=[1, -1]) /* span=aten::view_2:0:0 */;
  %4 = nn.dense(%3, %aten::linear_0.weight, units=None) /* span=aten::linear_0:0:0 */;
  %5 = nn.batch_norm(%4, %aten::batch_norm_1.weight, %aten::batch_norm_1.bias, %aten::batch_norm_1.running_mean, %aten::batch_norm_1.running_var) /* span=aten::batch_norm_1:0:0 */;
  %5.0 /* span=aten::batch_norm_1:0:0 */
}
```
```python
with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
    lib = relay.build(mod, target="llvm", params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(tvm.cpu(0)))
module.run(**{name: data_np})
num_outputs = module.get_num_outputs()
origin_outputs = [module.get_output(k).numpy() for k in range(num_outputs)]
```

```
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
```
```python
with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
    lib = relay.build(run_mod, target="llvm", params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(tvm.cpu(0)))
module.run(**{name: data_np})
num_outputs = module.get_num_outputs()
outputs = [module.get_output(k).numpy() for k in range(num_outputs)]
np.testing.assert_allclose(origin_outputs[0], outputs[0])
```
```python
from copy import deepcopy

import tvm
from tvm import relay
from tvm.relay.quantize.quantize import _bind_params

optimize = tvm.transform.Sequential(
    [
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
        # relay.transform.CanonicalizeOps(),
        # relay.transform.FoldConstant(),
    ]
)
run_mod = deepcopy(mod)
run_mod["main"] = rewrite(FastGlobalAvgPoolSimplify(), run_mod["main"])
run_mod["main"] = _bind_params(run_mod["main"], params)
with tvm.transform.PassContext(opt_level=3):
    run_mod2 = relay.quantize.prerequisite_optimize(deepcopy(run_mod), params)
    run_mod = optimize(run_mod)
print(run_mod["main"])
```
```
fn (%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 9), float32] {
  %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %3 = reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 1024), float32] span=aten::view_2:0:0 */;
  %4 = nn.dense(%3, meta[relay.Constant][2] /* ty=Tensor[(9, 1024), float32] */, units=None) /* ty=Tensor[(1, 9), float32] */;
  add(%4, meta[relay.Constant][3] /* ty=Tensor[(9), float32] */) /* ty=Tensor[(1, 9), float32] */
} /* ty=fn (Tensor[(1, 1024, 8, 6), float32]) -> Tensor[(1, 9), float32] */
```
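In the optimized IR, `SimplifyInference` has rewritten the first inference-mode `batch_norm` into the per-channel `multiply`/`add` pair (`%1`/`%2`), while the second `batch_norm` collapses all the way into the final bias `add` once `FoldScaleAxis`/`FoldConstant` push its scale into the `dense` weight. A quick numeric check of the underlying identity (illustrative scalar values):

```python
# batch_norm(x) == x * scale + shift at inference time, where
# scale = gamma / sqrt(var + eps) and shift = beta - mean * scale.
import numpy as np

gamma, beta, mean, var, eps = 2.0, 0.5, 0.1, 1.5, 1e-5
x = 0.3
scale = gamma / np.sqrt(var + eps)
shift = beta - mean * scale
assert np.isclose(gamma * (x - mean) / np.sqrt(var + eps) + beta, x * scale + shift)
```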
```python
print(run_mod2["main"])
```
```
fn (%x: Tensor[(1, 1024, 8, 6), float32] /* ty=Tensor[(1, 1024, 8, 6), float32] span=aten::size_0.x:0:0 */) -> Tensor[(1, 9), float32] {
  %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %1 = multiply(%0, meta[relay.Constant][0] /* ty=Tensor[(1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 1, 1), float32] */;
  %3 = reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 1024), float32] span=aten::view_2:0:0 */;
  %4 = nn.dense(%3, meta[relay.Constant][2] /* ty=Tensor[(9, 1024), float32] */, units=None) /* ty=Tensor[(1, 9), float32] */;
  add(%4, meta[relay.Constant][3] /* ty=Tensor[(9), float32] */) /* ty=Tensor[(1, 9), float32] */
} /* ty=fn (Tensor[(1, 1024, 8, 6), float32]) -> Tensor[(1, 9), float32] */
```
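As the two printouts show, the hand-assembled `optimize` pipeline reproduces `relay.quantize.prerequisite_optimize` exactly on this model; the agreement can also be asserted programmatically (illustrative):

```python
# The two optimization paths should yield structurally identical IR.
assert tvm.ir.structural_equal(run_mod["main"], run_mod2["main"])
```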