# TorchDynamic
#
%cd ..
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc
def _get_torch_model(name, training=False):
"""Get model from torch vision"""
# pylint: disable=import-outside-toplevel
try:
import torchvision
model = getattr(torchvision.models, name)()
if training:
model = model.train()
else:
model = model.eval()
return model
except: # pylint: disable=bare-except
print("please install torchvision package")
return None
from tvm.contrib.msc.core import utils as msc_utils
def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):
    """Assemble the MSC pipeline config used by the examples in this notebook.

    Args:
        model_type: Framework of the source model; also used as the
            ``run_type`` of the baseline stage.
        compile_type: Framework used as the ``run_type`` of the compile stage.
        inputs: Stored unchanged under the ``"inputs"`` key.
        outputs: Stored unchanged under the ``"outputs"`` key.
        dynamic: Only selects the ``dynamic``/``static`` suffix of the
            workspace directory name; it is not otherwise put in the config.
        atol: Absolute tolerance of the baseline and compile checks.
        rtol: Relative tolerance of the baseline and compile checks.

    Returns:
        A config dict whose workspace lives under ``temp_dir``, with the
        ``prepare``, ``baseline`` and ``compile`` stages each benchmarked
        with 10 repeats.
    """
    mode = "dynamic" if dynamic else "static"
    workspace = msc_utils.msc_dir(
        f"{temp_dir}/test_pipe_{model_type}_{compile_type}_{mode}", keep_history=False
    )

    def _stage(run_type):
        # Build a new dict on every call so the baseline and compile stages
        # never share nested objects.
        return {
            "run_type": run_type,
            "profile": {
                "check": {"atol": atol, "rtol": rtol},
                "benchmark": {"repeat": 10},
            },
        }

    config = {"workspace": workspace, "verbose": "critical", "model_type": model_type}
    config["inputs"] = inputs
    config["outputs"] = outputs
    # The "from_random" loader presumably feeds synthetic inputs — confirm in MSC.
    config["dataset"] = {"prepare": {"loader": "from_random", "max_iter": 5}}
    config["prepare"] = {"profile": {"benchmark": {"repeat": 10}}}
    config["baseline"] = _stage(model_type)
    config["compile"] = _stage(compile_type)
    return config
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.pipeline import TorchDynamic
import torch
# Run ResNet-50 through the TorchDynamic pipeline once per compile target:
# first compiling back to torch, then compiling with TVM.
device = torch.device("cuda:0") if torch.cuda.is_available() else None
for compile_type in [MSCFramework.TORCH, MSCFramework.TVM]:
    # A fresh eval-mode model is built for each compile target.
    torch_model = _get_torch_model("resnet50", False)
    if torch_model is None:
        # _get_torch_model already printed the reason (torchvision missing);
        # every later iteration would fail the same way.
        break
    if device is not None:
        torch_model = torch_model.to(device)
    config = _get_config(
        MSCFramework.TORCH,
        compile_type,
        inputs=[["input_0", [1, 3, 224, 224], "float32"]],
        outputs=["output"],
        dynamic=True,
        atol=1e-1,
        rtol=1e-1,
    )
    pipeline = TorchDynamic(torch_model, config)
    try:
        pipeline.run_pipe()  # run the pipeline
        print(pipeline.report)  # print the run report
    finally:
        # Tear the pipeline down even if running or reporting raised.
        # `destory` (sic) is the method name this pipeline API exposes.
        pipeline.destory()
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
{'success': False, 'info': {'prepare': {'profile': {'jit_0': '46.75 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '6.19 s(49.31%)', 'parse': '0.09 s(0.68%)', 'total': '12.55 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 162, in run_pipe\n self.parse()\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 226, in parse\n info, report = self._parse()\n ^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py", line 157, in _parse\n info[name], report[name] = w_ctx["worker"].parse()\n ^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py", line 320, in parse\n self._relax_mod, _ = stage_config["parser"](self._model, **parse_config)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py", line 119, in from_torch\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 960, in from_fx\n return TorchFXImporter().from_fx(\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 837, in from_fx\n assert (\nAssertionError: Unsupported function type batch_norm\n'}
[12:54:57] /media/pc/data/lxw/ai/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
{'success': False, 'info': {'prepare': {'profile': {'jit_0': '42.50 ms @ cpu'}}}, 'duration': {'setup': '0.00 s(0.00%)', 'prepare': '4.81 s(49.18%)', 'parse': '0.08 s(0.82%)', 'total': '9.78 s(100.00%)'}, 'err_msg': 'Pipeline failed: Unsupported function type batch_norm', 'err_info': 'Traceback (most recent call last):\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 162, in run_pipe\n self.parse()\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/pipeline.py", line 226, in parse\n info, report = self._parse()\n ^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/dynamic.py", line 157, in _parse\n info[name], report[name] = w_ctx["worker"].parse()\n ^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/pipeline/worker.py", line 320, in parse\n self._relax_mod, _ = stage_config["parser"](self._model, **parse_config)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/frontend/translate.py", line 119, in from_torch\n relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 960, in from_fx\n return TorchFXImporter().from_fx(\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 837, in from_fx\n assert (\nAssertionError: Unsupported function type batch_norm\n'}