# MSC demo
import set_env
from dataclasses import dataclass
from typing import Iterator
from tvm.contrib.msc.pipeline import TorchWrapper
from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.utils.message import MSCStage
from utils import *
@dataclass
class Args:
    """Options selecting which MSC tools to enable and how to compile the model."""

    gym: bool = False # Whether to use gym for tool
    prune: bool = False # Whether to use pruner
    quantize: bool = False # Whether to use quantizer
    distill: bool = False # Whether to use distiller for tool
    compile_type: str = "tvm" # The compile type of model
    test_batch: int = 1 # Batch size of the test/calibration inputs
    verbose: str = "info" # The verbose level, info|debug:1,2,3|critical
    dynamic: bool = False # Whether to use dynamic wrapper
def get_config(calib_loader: Iterator, train_loader: Iterator, args: Args):
    """Build the MSC pipeline config consumed by ``TorchWrapper``.

    Parameters
    ----------
    calib_loader: Iterator
        Yields calibration batches for the PREPARE stage.
    train_loader: Iterator
        Yields training batches; registered for the DISTILL stage only
        when ``args.distill`` is set.
    args: Args
        Tool/compile options.

    Returns
    -------
    The config dict produced by ``TorchWrapper.create_config``.
    """

    def _tool_config(use_gym: bool):
        # Pruner and quantizer share the same config shape; gym mode wraps
        # the default config so the gym agent can search over it.
        return {"gym_configs": ["default"]} if use_gym else "default"

    tools, dataset = [], {MSCStage.PREPARE: {"loader": calib_loader}}
    if args.prune:
        tools.append((ToolType.PRUNER, _tool_config(args.gym)))
    if args.quantize:
        tools.append((ToolType.QUANTIZER, _tool_config(args.gym)))
    if args.distill:
        distill_config = {
            "options": {
                "optimizer": "adam",
                "opt_config": {"lr": 0.00000001, "weight_decay": 0.08},
            }
        }
        tools.append((ToolType.DISTILLER, distill_config))
        # Distillation needs real training data in addition to calibration.
        dataset[MSCStage.DISTILL] = {"loader": train_loader}
    return TorchWrapper.create_config(
        inputs=[("input", [args.test_batch, 3, 32, 32], "float32")],
        outputs=["output"],
        compile_type=args.compile_type,
        dataset=dataset,
        tools=tools,
        verbose=args.verbose,
        dynamic=args.dynamic,
    )
from pathlib import Path

# Folder holding the CIFAR-style dataset used for calibration/training.
# mkdir(exist_ok=True) replaces the racy exists()-then-mkdir pattern.
dataset = Path(".temp/msc_dataset")
dataset.mkdir(parents=True, exist_ok=True)
train_batch = 32  # batch size of the training loader
test_batch = 1  # batch size of the test/calibration loader
trainloader, testloader = get_dataloaders(dataset, train_batch, test_batch)
def _get_calib_datas(calibrate_iter, dynamic):
    """Yield calibration batches from the global test loader.

    Stops after ``calibrate_iter`` batches when that limit is positive;
    a non-positive limit yields the whole loader. In dynamic mode the raw
    tensor is yielded, otherwise a ``{"input": tensor}`` dict.
    """
    for idx, (batch, _) in enumerate(testloader):
        if calibrate_iter > 0 and idx >= calibrate_iter:
            break
        yield batch if dynamic else {"input": batch}
def _get_train_datas(train_iter, dynamic):
    """Yield training batches from the global train loader.

    Stops after ``train_iter`` batches when that limit is positive;
    a non-positive limit yields the whole loader. In dynamic mode the raw
    tensor is yielded, otherwise a ``{"input": tensor}`` dict.
    """
    for idx, (batch, _) in enumerate(trainloader):
        if train_iter > 0 and idx >= train_iter:
            break
        yield batch if dynamic else {"input": batch}
# BUG FIX: torch was used (cuda check) before `import torch` appeared further
# down the script; hoist the torch imports so a fresh run does not NameError.
import torch
import torch.optim as optim

from _resnet import resnet50

# Folder caching the pretrained model checkpoints.
checkpoint = Path(".temp/msc_models")
checkpoint.mkdir(exist_ok=True, parents=True)
# Pretrained CIFAR10 weights, see
# https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master?tab=readme-ov-file
model = resnet50(pretrained=str(checkpoint))
if torch.cuda.is_available():
    model = model.to(torch.device("cuda:0"))
test_iter = 100  # number of evaluation iterations
acc = eval_model(model, testloader, max_iter=test_iter)
print("Baseline acc: " + str(acc))

model = TorchWrapper(model, get_config(_get_calib_datas, _get_train_datas, Args()))
# optimize the model with tool
model.optimize()
acc = eval_model(model, testloader, max_iter=test_iter)
print("Optimized acc: " + str(acc))

train_iter = 100  # iterations per fine-tune epoch
train_epoch = 5  # number of fine-tune epochs
# train the model with tool
optimizer = optim.Adam(model.parameters(), lr=0.0000001, weight_decay=0.08)
for ep in range(train_epoch):
    train_model(model, trainloader, optimizer, max_iter=train_iter)
    acc = eval_model(model, testloader, max_iter=test_iter)
    print("Train[{}] acc: {}".format(ep, acc))

# compile the model
model.compile()
acc = eval_model(model, testloader, max_iter=test_iter)
print("Compiled acc: " + str(acc))

# export the model
path = model.export()
print("Export model to " + str(path))
# Path to the cost-time log emitted by the NPU inference run.
file_costtime = "/media/pc/data/board/arria10/lxw/tasks/tools/npuusertools/temp/xmdemo.fa_color_bertha/inference/outputs/costtime_perf_first_L0.log"
with open(file_costtime) as fp:
    time_info = fp.read()
import re

match = re.search(r"first run time:\s*(\d+)\s*ms\n", time_info)
if match is not None:
    # First-run latency in ms, as a string, from the capture group.
    time_value = match.group(1)
    # Drop everything up to (and including) the matched line.
    segments = time_info.replace(match.group(0), "$").split("$")
    time_info = segments[1]
    # Everything after the timer-end separator is the summary of interest.
    time_info = time_info.split("================================================================\ntimer end")
    print(time_info[-1])
# Pasted profiler output from the NPU run — kept for reference, commented out
# so the file remains valid Python:
# profiler:
# {"VTA_TOPI_yuv420sp2rgb_N0": {
# "inp_load_nbytes":5760,
# "wgt_load_nbytes":0,
# "acc_load_nbytes":0,
# "uop_load_nbytes":0,
# "alu_uop_load_nbytes":0,
# "alu_tab_load_nbytes":0,
# "bias_load_nbytes":0,
# "out_store_nbytes":10800,
# "gemm_counter":0,
# "alu_counter":0
# }
# }
# profiler_end
# {
# "mem_weight":59648,
# "mem_op":34848,
# "mem_lib":21408,
# "mem_total":115904,
# "run_time":28
# }
match = re.search(r"first run time:\s*(\d+)\s*ms\n", time_info)
if match:
    # First-run latency value and its unit, parsed from the matched line.
    pre_infer_time = match.group(1)
    unit_time = match.group(0)[match.group(0).find(pre_infer_time):].strip().split(" ")[1]
    # BUG FIX: the original `assert "ms" in unit_time.strip() == "ms"` is a
    # chained comparison — `("ms" in s) and (s == "ms")` — not the intended
    # single unit check; assert the equality directly.
    assert unit_time.strip() == "ms", f"时间单位 {unit_time.strip()} 出错"
    # Drop everything up to (and including) the matched line.
    time_info = time_info.replace(match.group(0), "$").split("$")[1]
    time_info = time_info.split("\n================================================================\ntimer end")
    # SECURITY: the trailing segment is a plain dict literal; parse it with
    # ast.literal_eval instead of eval(), which would execute arbitrary code
    # from the log file.
    import ast
    summary = ast.literal_eval(time_info[-1])
# 'ms'  # stray notebook cell-output residue, commented out