MSC 工具测试

MSC 工具测试#

MSCTools 与 MSCGraph 协同工作,它们决定压缩策略并控制压缩过程。MSCTools 由 RuntimeManager 管理。

from tvm.contrib.msc.core.transform import msc_transform
from tvm.contrib.msc.core.runtime import create_runtime_manager
from tvm.contrib.msc.core.tools import create_tool, MSC_TOOL

# build runtime manager from module and mscgraphs
optimized_mod, msc_graph, msc_config = msc_transform(mod, params)
rt_manager = create_runtime_manager(optimized_mod, params, msc_config)

# pruner is used to prune the model
rt_manager.create_tool(MSC_TOOL.PRUNE, prune_config)

# quantizer is used to do the calibration and quantize the model
rt_manager.create_tool(MSC_TOOL.QUANTIZE, quantize_config)

# collector is used to collect the data of each computational node
rt_manager.create_tool(MSC_TOOL.COLLECT, collect_config)

# distiller is used to do the knowledge distillation
rt_manager.create_tool(MSC_TOOL.DISTILL, distill_config)

MSCProcessor#

MSCProcessor 为编译过程构建流水线。一个编译过程可能包括不同的阶段,每个阶段都有特殊的配置和策略。为了使编译过程易于管理,创建了 MSCProcessor。

from tvm.contrib.msc.pipeline import create_msc_processor

# get the torch model and config
model = get_torch_model()
config = get_msc_config()
processor = create_msc_processor(model, config)

if mode == "deploy":
    processor.compile()
    processor.export()
elif mode == "optimize":
    model = processor.optimize()
    for ep in EPOCHS:
        for datas in training_datas:
            train_model(model)
    processor.update_weights(get_weights(model))
    processor.compile()
    processor.export()

配置可以从文件中加载,从而可以控制、记录和重放编译过程。这对于构建编译服务和平台至关重要。

{
  "workspace": "msc_workspace",
  "verbose": "runtime",
  "log_file": "MSC_LOG",
  "baseline": {
    "check_config": {
      "atol": 0.05
    }
  },
  "quantize": {
    "strategy_file": "msc_quantize.json",
    "target": "tensorrt",
  },
  "profile": {
    "repeat": 1000
  },
  ...
}

MSCGym#

MSCGym 是 MSC 中自动压缩的平台。它的作用类似于 AutoTVM,但其架构更像 OpenAI-Gym。MSCGym 从压缩过程中提取任务,然后利用代理和环境之间的交互来为每个任务找到最佳行动。要使用 MSCGym 进行自动压缩,请为工具设置 gym 配置:

{
      ...
      "quantize": {
        "strategy_file": "msc_quantize.json",
        "target": "tensorrt",
        "gym": [
          {
            "record": "searched_config.json",
            "env": {
              "strategy": "distill_loss"
            },
            "agent": {
              "type": "grid_search"
            }
          }
        ]
      },
      ...
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

""" Test Tools in MSC. """

import json
import pytest
import torch

import tvm.testing
from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils

# Skip marker for TensorRT-only tests: the test is skipped when looking up the
# "relax.ext.tensorrt" global function yields None.
# NOTE(review): the second positional argument is presumably `allow_missing`, making the
# lookup return None instead of raising when the function is absent -- confirm.
requires_tensorrt = pytest.mark.skipif(
    tvm.get_global_func("relax.ext.tensorrt", True) is None,
    reason="TENSORRT is not enabled",
)


def _get_config(
    model_type,
    compile_type,
    tools,
    inputs,
    outputs,
    atol=1e-2,
    rtol=1e-2,
    optimize_type=None,
):
    """Assemble the pipeline config dict used by the tool tests.

    Parameters
    ----------
    model_type:
        Framework of the source model; also the run type of the baseline stage.
    compile_type:
        Run type of the compile stage.
    tools: list of dict
        Tool configs; each entry's "tool_type" value becomes part of the workspace name.
    inputs:
        Input description, stored unchanged under "inputs".
    outputs:
        Output description, stored unchanged under "outputs".
    atol: float
        Absolute tolerance used by the check of every profiled stage.
    rtol: float
        Relative tolerance used by the check of every profiled stage.
    optimize_type:
        Run type of the optimize stage; model_type is used when this is falsy.

    Returns
    -------
    config: dict
        The config with workspace, dataset, tools and per-stage settings.
    """

    def _stage(run_type):
        # Build fresh dicts on every call so the stages never share nested objects.
        return {
            "run_type": run_type,
            "profile": {
                "check": {"atol": atol, "rtol": rtol},
                "benchmark": {"repeat": 10},
            },
        }

    # The workspace name encodes the frameworks and every tool type in use.
    name_parts = ["test_tools", model_type, compile_type]
    name_parts.extend(tool["tool_type"] for tool in tools)
    workspace = msc_utils.msc_dir("_".join(name_parts))

    return {
        "workspace": workspace,
        "verbose": "critical",
        "model_type": model_type,
        "inputs": inputs,
        "outputs": outputs,
        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
        "tools": tools,
        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
        "baseline": _stage(model_type),
        "optimize": _stage(optimize_type or model_type),
        "compile": _stage(compile_type),
    }


def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC):
    """Build the list of tool configs for one test case.

    Parameters
    ----------
    tool_type:
        Compared against ToolType.PRUNER, ToolType.QUANTIZER and ToolType.TRACKER;
        any other value contributes no entry of its own.
    use_distill: bool
        When true, a distiller config is appended after the entry selected by tool_type.
    run_type:
        Only consulted for the quantizer: MSCFramework.TENSORRT yields an empty
        strategy list, every other value yields the gather/calibrate/quantize strategies.

    Returns
    -------
    tools: list of dict
        Entries of the form {"tool_type": ..., "tool_config": {"plan_file", "strategys"}}.
    """

    entries = []

    def _add(kind, plan_file, strategys):
        entries.append(
            {"tool_type": kind, "tool_config": {"plan_file": plan_file, "strategys": strategys}}
        )

    if tool_type == ToolType.PRUNER:
        prune_methods = {
            "weights": {"method_name": "per_channel", "density": 0.8},
            "output": {"method_name": "per_channel", "density": 0.8},
        }
        _add(ToolType.PRUNER, "msc_pruner.json", [{"methods": prune_methods}])
    elif tool_type == ToolType.QUANTIZER:
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.msc.core.tools.quantize import QuantizeStage

        if run_type == MSCFramework.TENSORRT:
            quantize_strategys = []
        else:
            # One op list shared by the gather, calibrate and quantize strategies.
            op_types = ["nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias"]
            gather = {
                "methods": {
                    "input": "gather_maxmin",
                    "output": "gather_maxmin",
                    "weights": "gather_max_per_channel",
                },
                "op_types": op_types,
                "stages": [QuantizeStage.GATHER],
            }
            calibrate = {
                "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"},
                "op_types": op_types,
                "stages": [QuantizeStage.CALIBRATE],
            }
            # The final strategy carries no "stages" key.
            quantize = {
                "methods": {
                    "input": "quantize_normal",
                    "weights": "quantize_normal",
                    "output": "dequantize_normal",
                },
                "op_types": op_types,
            }
            quantize_strategys = [gather, calibrate, quantize]
        _add(ToolType.QUANTIZER, "msc_quantizer.json", quantize_strategys)
    elif tool_type == ToolType.TRACKER:
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.msc.core.utils import MSCStage

        compare_to = {
            MSCStage.OPTIMIZE: [MSCStage.BASELINE],
            MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE],
        }
        track_strategy = {
            "methods": {"output": {"method_name": "save_compared", "compare_to": compare_to}},
            "op_types": ["nn.relu"],
        }
        _add(ToolType.TRACKER, "msc_tracker.json", [track_strategy])

    if use_distill:
        distill_strategy = {"methods": {"mark": "loss_lp_norm"}, "marks": ["loss"]}
        _add(ToolType.DISTILLER, "msc_distiller.json", [distill_strategy])

    return entries


def _get_torch_model(name, training=False):
    """Get model from torch vision"""

    # pylint: disable=import-outside-toplevel
    try:
        import torchvision

        model = getattr(torchvision.models, name)()
        if training:
            model = model.train()
        else:
            model = model.eval()
        return model
    except:  # pylint: disable=bare-except
        print("please install torchvision package")
        return None


def _check_manager(manager, expected_info):
    """Check the manager results

    Parameters
    ----------
    manager:
        Object exposing ``get_runtime()``, ``report``, ``model_type``,
        ``compile_type`` and ``destory()``.
    expected_info: dict
        Model info the runtime's ``model_info`` is compared against.

    Raises
    ------
    Exception
        If the report is not marked successful or the model info differs from
        expected_info. The message lists every failed check, followed by the report.
    """

    errors = []
    try:
        model_info = manager.get_runtime().model_info
        if not manager.report["success"]:
            errors.append(
                "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type)
            )
        if not msc_utils.dict_equal(model_info, expected_info):
            errors.append(
                "Model info {} mismatch with expected {}".format(model_info, expected_info)
            )
    finally:
        # Always release the manager, even when one of the checks above raises.
        manager.destory()
    if errors:
        raise Exception(
            "{}\nReport:{}".format("\n".join(errors), json.dumps(manager.report, indent=2))
        )


def _test_from_torch(
    compile_type,
    tools,
    expected_info,
    training=False,
    atol=1e-1,
    rtol=1e-1,
    optimize_type=None,
):
    """Run the pipeline on a torchvision resnet50 and check the result.

    Parameters
    ----------
    compile_type:
        Run type of the compile stage, forwarded to _get_config.
    tools: list of dict
        Tool configs, forwarded to _get_config.
    expected_info: dict
        Model info the manager's runtime is expected to report.
    training: bool
        Whether the torch model is created in train mode.
    atol: float
        Absolute tolerance forwarded to _get_config.
    rtol: float
        Relative tolerance forwarded to _get_config.
    optimize_type:
        Run type of the optimize stage, forwarded to _get_config.
    """

    torch_model = _get_torch_model("resnet50", training)
    if torch_model is None:
        # Report the test as skipped rather than letting it pass without running anything.
        pytest.skip("torch model is unavailable (is torchvision installed?)")
    if torch.cuda.is_available():
        torch_model = torch_model.to(torch.device("cuda:0"))
    config = _get_config(
        MSCFramework.TORCH,
        compile_type,
        tools,
        inputs=[["input_0", [1, 3, 224, 224], "float32"]],
        outputs=["output"],
        atol=atol,
        rtol=rtol,
        optimize_type=optimize_type,
    )
    manager = MSCManager(torch_model, config)
    manager.run_pipe()
    _check_manager(manager, expected_info)


def get_model_info(compile_type):
    """Return the model info expected for resnet50 under the given compile type.

    Parameters
    ----------
    compile_type:
        MSCFramework.TVM or MSCFramework.TENSORRT.

    Returns
    -------
    info: dict
        Expected "inputs", "outputs" and per-op "nodes" counts.

    Raises
    ------
    TypeError
        If compile_type is neither of the two supported values.
    """

    def _io_info(name, shape, layout):
        return {"name": name, "shape": shape, "dtype": "float32", "layout": layout}

    input_infos = [_io_info("input_0", [1, 3, 224, 224], "NCHW")]

    if compile_type == MSCFramework.TVM:
        # The per-op counts below add up to "total".
        node_counts = {
            "total": 229,
            "input": 1,
            "nn.conv2d": 53,
            "nn.batch_norm": 53,
            "get_item": 53,
            "nn.relu": 49,
            "nn.max_pool2d": 1,
            "add": 16,
            "nn.adaptive_avg_pool2d": 1,
            "reshape": 1,
            "msc.linear_bias": 1,
        }
        return {
            "inputs": input_infos,
            "outputs": [_io_info("output", [1, 1000], "NC")],
            "nodes": node_counts,
        }

    if compile_type == MSCFramework.TENSORRT:
        # Expected graph is the input plus a single "msc_tensorrt" node; output layout is empty.
        return {
            "inputs": input_infos,
            "outputs": [_io_info("output", [1, 1000], "")],
            "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
        }

    raise TypeError("Unexpected compile_type " + str(compile_type))


@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER])
def test_tvm_tool(tool_type):
    """Run the torch -> tvm pipeline with a single tool enabled."""

    target = MSCFramework.TVM
    tool_configs = get_tools(tool_type)
    expected_info = get_model_info(target)
    _test_from_torch(target, tool_configs, expected_info, training=False)


@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER])
def test_tvm_distill(tool_type):
    """Run the torch -> tvm pipeline with a tool plus the distiller."""

    target = MSCFramework.TVM
    tool_configs = get_tools(tool_type, use_distill=True)
    expected_info = get_model_info(target)
    _test_from_torch(target, tool_configs, expected_info, training=False)


@requires_tensorrt
@pytest.mark.parametrize(
    "tool_type",
    [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER],
)
def test_tensorrt_tool(tool_type):
    """Run the torch -> tensorrt pipeline with a single tool enabled."""

    target = MSCFramework.TENSORRT
    tool_configs = get_tools(tool_type, run_type=target)
    # Only the quantizer sets the optimize run type to tensorrt; for the other
    # tools it stays None.
    optimize_type = target if tool_type == ToolType.QUANTIZER else None
    expected_info = get_model_info(target)
    _test_from_torch(
        target,
        tool_configs,
        expected_info,
        training=False,
        atol=1e-1,
        rtol=1e-1,
        optimize_type=optimize_type,
    )


@requires_tensorrt
@pytest.mark.parametrize("tool_type", [ToolType.PRUNER])
def test_tensorrt_distill(tool_type):
    """Run the torch -> tensorrt pipeline with a tool plus the distiller."""

    target = MSCFramework.TENSORRT
    tool_configs = get_tools(tool_type, use_distill=True)
    expected_info = get_model_info(target)
    _test_from_torch(target, tool_configs, expected_info, training=False)


# When the file is executed directly, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()