# MSC demo
#
# End-to-end demo of the TVM MSC tool pipeline (prune / quantize / distill)
# on a CIFAR10 ResNet50, followed by compilation and export.

import set_env  # NOTE: must stay first — prepares the environment before tvm imports

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator

import torch
import torch.optim as optim

from tvm.contrib.msc.pipeline import TorchWrapper
from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.utils.message import MSCStage
from utils import *

@dataclass
class Args:
    """Options controlling which MSC tools the demo applies and how."""

    gym: bool = False  # Whether to drive the tools through gym configs
    prune: bool = False  # Whether to apply the pruner
    quantize: bool = False  # Whether to apply the quantizer
    distill: bool = False  # Whether to apply the distiller
    compile_type: str = "tvm"  # The compile type of the model
    test_batch: int = 1  # Batch size of the test/compile input tensor
    verbose: str = "info"  # The verbose level, info|debug:1,2,3|critical
    dynamic: bool = False  # Whether to use the dynamic wrapper

def get_config(calib_loader: Iterator, train_loader: Iterator, args: Args):
    """Build the TorchWrapper pipeline config for the requested tools.

    Registers pruner/quantizer/distiller entries according to *args*, wires
    the calibration loader into the prepare stage and, when distilling, the
    training loader into the distill stage.
    """
    dataset = {MSCStage.PREPARE: {"loader": calib_loader}}
    tools = []
    if args.prune:
        prune_config = {"gym_configs": ["default"]} if args.gym else "default"
        tools.append((ToolType.PRUNER, prune_config))
    if args.quantize:
        quant_config = {"gym_configs": ["default"]} if args.gym else "default"
        tools.append((ToolType.QUANTIZER, quant_config))
    if args.distill:
        distill_config = {
            "options": {
                "optimizer": "adam",
                "opt_config": {"lr": 0.00000001, "weight_decay": 0.08},
            }
        }
        tools.append((ToolType.DISTILLER, distill_config))
        dataset[MSCStage.DISTILL] = {"loader": train_loader}
    return TorchWrapper.create_config(
        inputs=[("input", [args.test_batch, 3, 32, 32], "float32")],
        outputs=["output"],
        compile_type=args.compile_type,
        dataset=dataset,
        tools=tools,
        verbose=args.verbose,
        dynamic=args.dynamic,
    )
from pathlib import Path

# Folder holding the CIFAR10 data used for calibration/training/testing.
# mkdir(..., exist_ok=True) replaces the racy exists()-then-mkdir check.
dataset = Path(".temp/msc_dataset")
dataset.mkdir(parents=True, exist_ok=True)
train_batch = 32  # batch size for the training loader
test_batch = 1  # batch size for the test loader (matches get_config's input shape)
trainloader, testloader = get_dataloaders(dataset, train_batch, test_batch)
def _get_calib_datas(calibrate_iter, dynamic):
    """Yield up to *calibrate_iter* test batches (all batches when <= 0)."""
    for idx, (batch, _label) in enumerate(testloader):
        if idx >= calibrate_iter > 0:
            break
        if dynamic:
            yield batch
        else:
            yield {"input": batch}

def _get_train_datas(train_iter, dynamic):
    """Yield up to *train_iter* training batches (all batches when <= 0)."""
    step = 0
    for batch, _label in trainloader:
        if step >= train_iter > 0:
            break
        yield batch if dynamic else {"input": batch}
        step += 1
from _resnet import resnet50

checkpoint = ".temp/msc_models"  # Folder caching the pretrained model checkpoint
checkpoint = Path(checkpoint)
checkpoint.mkdir(exist_ok=True, parents=True)
# Pretrained CIFAR10 weights, see
# https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master?tab=readme-ov-file


model = resnet50(pretrained=str(checkpoint))
if torch.cuda.is_available():
    # NOTE(review): `torch` is only imported explicitly further below — it must
    # reach this point via `from utils import *`; confirm.
    model = model.to(torch.device("cuda:0"))

test_iter = 100  # iteration cap passed to eval_model per evaluation
acc = eval_model(model, testloader, max_iter=test_iter)
print("Baseline acc: " + str(acc))
# Wrap the model so the MSC pipeline (see get_config) can optimize/compile it.
model = TorchWrapper(model, get_config(_get_calib_datas, _get_train_datas, Args()))

# Optimize the model with the configured MSC tools.
model.optimize()
acc = eval_model(model, testloader, max_iter=test_iter)
print("Optimized acc: " + str(acc))
import torch
import torch.optim as optim
train_iter = 100  # iteration cap passed to train_model per epoch
train_epoch = 5  # number of fine-tune epochs
# Fine-tune the optimized model with the tools still attached.
optimizer = optim.Adam(model.parameters(), lr=0.0000001, weight_decay=0.08)
for ep in range(train_epoch):
    train_model(model, trainloader, optimizer, max_iter=train_iter)
    acc = eval_model(model, testloader, max_iter=test_iter)
    print("Train[{}] acc: {}".format(ep, acc))

# Compile the model for the configured compile_type backend.
model.compile()
acc = eval_model(model, testloader, max_iter=test_iter)
print("Compiled acc: " + str(acc))

# Export the compiled model artifacts to disk.
path = model.export()
print("Export model to " + str(path))
# Parse the first-run latency out of the NPU cost-time log.
file_costtime = "/media/pc/data/board/arria10/lxw/tasks/tools/npuusertools/temp/xmdemo.fa_color_bertha/inference/outputs/costtime_perf_first_L0.log"
with open(file_costtime) as fp:
    time_info = fp.read()
import re
match = re.search(r"first run time:\s*(\d+)\s*ms\n", time_info)
if match is None:
    # Fail with a clear message instead of the AttributeError the original
    # raised by using `match` outside the `if match:` guard.
    raise ValueError("'first run time' not found in " + file_costtime)
time_value = match.group(1)  # latency value in ms (as a string)
# Drop everything up to (and including) the matched line, then keep the tail
# after the timer banner.
time_info = time_info.replace(match.group(0), "$").split("$")[1]
time_info = time_info.split("================================================================\ntimer end")
print(time_info[-1])
# Example output printed above (VTA profiler dump captured from the log):
#
# profiler:
# {"VTA_TOPI_yuv420sp2rgb_N0": {
#  "inp_load_nbytes":5760,
#  "wgt_load_nbytes":0,
#  "acc_load_nbytes":0,
#  "uop_load_nbytes":0,
#  "alu_uop_load_nbytes":0,
#  "alu_tab_load_nbytes":0,
#  "bias_load_nbytes":0,
#  "out_store_nbytes":10800,
#  "gemm_counter":0,
#  "alu_counter":0
# }
# }
# profiler_end
# {
#  "mem_weight":59648,
#  "mem_op":34848,
#  "mem_lib":21408,
#  "mem_total":115904,
#  "run_time":28
# }
import json

# Re-scan for the first-run latency and parse the summary block that follows
# the timer banner.
# NOTE(review): after the split above, `time_info` is a list, not a string —
# this cell likely intended to re-read the raw log; confirm the data source.
match = re.search(r"first run time:\s*(\d+)\s*ms\n", time_info)
if match is None:
    raise ValueError("'first run time' not found in log tail")
pre_infer_time = match.group(1)  # first-run latency value (string, in ms)
matched_line = match.group(0)
# The token right after the number must be the "ms" unit.
unit_time = matched_line[matched_line.find(pre_infer_time):].strip().split(" ")[1]
# The original chained comparison (`"ms" in x == "ms"`) tested membership AND
# equality; the intended check is plain equality with "ms".
assert unit_time.strip() == "ms", f"时间单位 {unit_time.strip()} 出错"
time_info = time_info.replace(match.group(0), "$").split("$")[1]
time_info = time_info.split("\n================================================================\ntimer end")
# The tail is a JSON object; json.loads is a safe replacement for the
# original eval() on log text.
summary = json.loads(time_info[-1])