# 量化

In [1]:
import set_env
from pathlib import Path

# Scratch directory (relative to the working directory) handed to the test
# runs below. mkdir(exist_ok=True) makes this cell safe to re-run: an
# already-existing directory is accepted rather than raising.
TEMP_DIR_NAME = ".temp"
temp_dir = Path(TEMP_DIR_NAME)
temp_dir.mkdir(exist_ok=True)

In [2]:
from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.tools.quantize import QuantizeStage
from tvm.contrib.msc.core.utils.namespace import MSCFramework

# Backend the quantizer plan is prepared for. For TensorRT the plan carries
# no strategies at all; every other backend gets the explicit three-strategy
# plan below.
run_type = MSCFramework.MSC

if run_type == MSCFramework.TENSORRT:
    strategys = []
else:
    # Op types every strategy below is restricted to.
    op_types = ["nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias"]
    strategys = [
        # GATHER stage: "gather_maxmin" for activations, a per-channel
        # variant ("gather_max_per_channel") for weights.
        {
            "methods": {
                "input": "gather_maxmin",
                "output": "gather_maxmin",
                "weights": "gather_max_per_channel",
            },
            "op_types": op_types,
            "stages": [QuantizeStage.GATHER],
        },
        # CALIBRATE stage: activations only; weights are not listed here.
        {
            "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"},
            "op_types": op_types,
            "stages": [QuantizeStage.CALIBRATE],
        },
        # No "stages" key: quantize inputs/weights and dequantize outputs.
        # NOTE(review): presumably this makes the strategy apply outside the
        # GATHER/CALIBRATE stages -- confirm against the MSC quantizer.
        {
            "methods": {
                "input": "quantize_normal",
                "weights": "quantize_normal",
                "output": "dequantize_normal",
            },
            "op_types": op_types,
        },
    ]

# The key name "strategys" is what the MSC tool config expects; keep as-is.
config = {"plan_file": "msc_quantizer.json", "strategys": strategys}
tools = [{"tool_type": ToolType.QUANTIZER, "tool_config": config}]


In [3]:
from utils import _test_from_torch, get_model_info

# Run the local test helper on the TVM backend with the quantizer tool list
# built above. temp_dir is passed positionally as the helper's fourth
# argument (presumably its working/artifact directory -- confirm in utils).
target_framework = MSCFramework.TVM
_test_from_torch(
    target_framework,
    tools,
    get_model_info(target_framework),
    temp_dir,
    training=False,
)

  state_dict = torch.load(folder.relpath(graph.name + ".pth"))
  state_dict = torch.load(folder.relpath(graph.name + ".pth"))


## 蒸馏

In [4]:
# Distiller plan: a single strategy whose "mark" method is "loss_lp_norm",
# restricted to marks named "loss" (presumably an lp-norm distillation
# loss -- confirm against the MSC distiller documentation).
config = {
    "plan_file": "msc_distiller.json",
    "strategys": [
        {
            "methods": {"mark": "loss_lp_norm"},
            "marks": ["loss"],
        },
    ],
}
# Build a fresh list instead of appending to the one from the quantization
# cell. Dropping any existing distiller entry first makes this cell
# idempotent: re-running it no longer stacks a second DISTILLER onto
# `tools`. The list object created earlier is also left unmodified.
tools = [tool for tool in tools if tool.get("tool_type") != ToolType.DISTILLER]
tools.append({"tool_type": ToolType.DISTILLER, "tool_config": config})

In [5]:
from utils import _test_from_torch, get_model_info

# Same TVM run as in the quantization section. `tools` now also carries the
# distiller entry added in the previous cell.
distill_framework = MSCFramework.TVM
distill_model_info = get_model_info(distill_framework)
_test_from_torch(distill_framework, tools, distill_model_info, temp_dir, training=False)

  state_dict = torch.load(folder.relpath(graph.name + ".pth"))
