# TorchDynamic

%cd ..
from pathlib import Path

# Scratch directory that holds each pipeline's workspace (see _get_config below).
# exist_ok=True makes re-running the notebook cell idempotent.
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc
def _get_torch_model(name, training=False):
    """Instantiate a torchvision model by name, in train or eval mode.

    Parameters
    ----------
    name : str
        Attribute name of a model constructor in ``torchvision.models``
        (e.g. ``"resnet50"``).
    training : bool
        If True, return the model in training mode, otherwise in eval mode.

    Returns
    -------
    torch.nn.Module or None
        The instantiated model, or ``None`` when torchvision is not installed.
    """

    # pylint: disable=import-outside-toplevel
    try:
        import torchvision
    except ImportError:
        # Only a missing package is reported here; a wrong model name now
        # raises AttributeError instead of being silently misreported as a
        # missing-torchvision problem (the old bare `except:` hid that).
        print("please install torchvision package")
        return None

    model = getattr(torchvision.models, name)()
    # nn.Module.train()/eval() toggle the mode in place and return the module.
    return model.train() if training else model.eval()
from tvm.contrib.msc.core import utils as msc_utils

def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):
    """Assemble an MSC pipeline config comparing a baseline run against a compiled run.

    The workspace directory name encodes model type, compile backend, and
    whether the run is dynamic, so runs do not overwrite each other.
    """

    mode = "dynamic" if dynamic else "static"
    workdir = f"test_pipe_{model_type}_{compile_type}_{mode}"

    def _stage_profile():
        # A fresh dict per stage, so mutating one stage's profile downstream
        # can never leak into the other stage.
        return {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}

    config = {
        "workspace": msc_utils.msc_dir(f"{temp_dir}/{workdir}", keep_history=False),
        "verbose": "critical",
        "model_type": model_type,
        "inputs": inputs,
        "outputs": outputs,
        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
        "baseline": {"run_type": model_type, "profile": _stage_profile()},
        "compile": {"run_type": compile_type, "profile": _stage_profile()},
    }
    return config
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.pipeline import TorchDynamic
import torch

# Run the same resnet50 through the dynamic MSC pipeline once per compile
# backend, comparing each against the torch baseline defined in the config.
for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:
    torch_model = _get_torch_model("resnet50", False)
    # NOTE(review): torch_model is None when torchvision is missing — the .to()
    # below would then fail; confirm torchvision is installed before running.
    if torch.cuda.is_available():
        torch_model = torch_model.to(torch.device("cuda:0"))
    config = _get_config(
        MSCFramework.TORCH,
        compile_type,
        inputs=[["input_0", [1, 3, 224, 224], "float32"]],
        outputs=["output"],
        dynamic = True,
        atol = 1e-1,
        rtol = 1e-1,
    )
    pipeline = TorchDynamic(torch_model, config)
    pipeline.run_pipe() # run the pipeline
    print(pipeline.report) # print the pipeline report
    pipeline.destory() # tear down the pipeline (sic: the MSC API spells it "destory")
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
{'success': False, 'info': {'prepare': {'profile': {'jit_0': '46.75 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '6.19 s(49.31%)', 'parse': '0.09 s(0.68%)', 'total': '12.55 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 162, in run_pipe\n    self.parse()\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 226, in parse\n    info, report = self._parse()\n                   ^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py", line 157, in _parse\n    info[name], report[name] = w_ctx["worker"].parse()\n                               ^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py", line 320, in parse\n    self._relax_mod, _ = stage_config["parser"](self._model, **parse_config)\n                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py", line 119, in from_torch\n    relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\n                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 960, in from_fx\n    return TorchFXImporter().from_fx(\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 837, in from_fx\n    assert (\nAssertionError: Unsupported function type batch_norm\n'}
[12:54:57] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
{'success': False, 'info': {'prepare': {'profile': {'jit_0': '42.50 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '4.81 s(49.18%)', 'parse': '0.08 s(0.82%)', 'total': '9.78 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 162, in run_pipe\n    self.parse()\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 226, in parse\n    info, report = self._parse()\n                   ^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py", line 157, in _parse\n    info[name], report[name] = w_ctx["worker"].parse()\n                               ^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py", line 320, in parse\n    self._relax_mod, _ = stage_config["parser"](self._model, **parse_config)\n                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py", line 119, in from_torch\n    relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\n                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 960, in from_fx\n    return TorchFXImporter().from_fx(\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 837, in from_fx\n    assert (\nAssertionError: Unsupported function type batch_norm\n'}