# Codegen 模块
codegen 模块与 MSCGraph 配合使用，用于将 MSCGraph 转译为 Python 脚本或 C++ 脚本。
%cd ..
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/read/msc
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
from tvm.contrib.msc.framework.torch import codegen as torch_codegen
from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate
from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.core.frontend import translate
from graph.model import get_model
# Input spec: a list of (shape, dtype) pairs, one entry per model input.
input_info = [((1, 3, 4, 4), "float32")]  # input shape and dtype
# get_model (local graph.model helper) returns a Relax IRModule plus a torch.fx
# model; presumably both describe the same network — TODO confirm in graph/model.py.
mod, torch_fx_model = get_model(input_info)
# Translate the Relax IRModule into an MSCGraph and its weights.
graph, weights = translate.from_relax(mod)
# Output folder for the generated Relax code (presumably the Relax snippet shown below).
build_folder = msc_utils.msc_dir(f"{temp_dir}/tvm_test")
# Rebuild a Relax IRModule from the MSCGraph through the Relax codegen.
mod = tvm_codegen.to_relax(graph, weights, build_folder=build_folder)
# Print the regenerated module as TVMScript.
mod.show()
# Output folder for the generated PyTorch code and weights.
build_folder = msc_utils.msc_dir(f"{temp_dir}/torch_test")
# Translate the torch.fx model into an MSCGraph, routing the conversion through Relax.
graph, weights = torch_translate.from_torch(torch_fx_model, input_info, via_relax=True)
# Generate PyTorch code from the MSCGraph and get back an instantiated module.
model = torch_codegen.to_torch(graph, weights, build_folder=build_folder)
# Bare expression: the notebook displays the module's repr as the cell output.
model
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(inp_0: R.Tensor((1, 3, 4, 4), dtype="float32")) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
with R.dataflow():
conv2d: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
relu: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.relu(conv2d)
R.output(relu)
return relu
# Metadata omitted. Use show_meta=True in script() method to show it.
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(folder.relpath(graph.name + ".pth"))
main(
(conv2d): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU()
)
生成的代码片段（依次为 Relax 版本和 PyTorch 版本）：
import os
import numpy as np
from typing import List, Dict, Any
import tvm
from tvm.contrib.msc.core import utils as msc_utils
from tvm import relax
# Define the helpers
def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:
    """Load a baseline tensor, or a tensor of ones when no dump exists.

    Reads ``baseline/<name>.bin`` as raw values of ``dtype`` and reshapes them
    to ``shape``. If that file is missing, returns ``np.ones(shape)`` cast to
    ``dtype`` instead.
    """
    bin_file = os.path.join("baseline", name + ".bin")
    if not os.path.isfile(bin_file):
        # No recorded data: fall back to a deterministic all-ones tensor.
        return np.ones(shape).astype(dtype)
    raw = np.fromfile(bin_file, dtype=dtype)
    return raw.reshape(shape)
# Define the graph
def main(res_0: relax.Var) -> tvm.IRModule:
    """Build the ``main`` Relax function: a 1x1 conv2d (no bias) followed by relu.

    Args:
        res_0: the graph-input var, created by the caller.

    Returns:
        The finalized ``tvm.IRModule``. The conv weight is an extra function
        parameter named ``"const"`` (after ``res_0``), so its value can be bound
        afterwards by name (the test below uses ``BindParams``).
    """
    # Weight placeholder, appended after the graph input in the parameter list.
    weight_1 = relax.Var("const", relax.TensorStructInfo([6, 3, 1, 1], "float32"))
    func_params = [res_0, weight_1]

    builder = relax.BlockBuilder()
    with builder.function(name="main", params=list(func_params)):
        # conv2d(nn.conv2d): <res_0> -> <res_1>
        conv_out = builder.emit(
            relax.op.nn.conv2d(
                res_0,
                weight_1,
                strides=[1, 1],
                padding=[0, 0, 0, 0],
                dilation=[1, 1],
                groups=1,
                data_layout="NCHW",
                kernel_layout="OIHW",
                out_layout="NCHW",
                out_dtype="float32",
            ),
            name_hint="conv2d",
        )
        # relu(nn.relu): <res_1> -> <res_2>
        relu_out = builder.emit(relax.op.nn.relu(conv_out), name_hint="relu")
        # The relu result is the function's single output.
        builder.emit_func_output(relu_out)
    return builder.finalize()
# Define the test
if __name__ == "__main__":
    # Prepare test data: baseline dumps when present, otherwise ones (see load_data).
    inputs = {}
    golden = {}
    inputs["inp_0"] = load_data("inp_0", [1, 3, 4, 4], "float32")
    golden["relu"] = load_data("relu", [1, 6, 4, 4], "float32")
    # Build the graph-input var and the module.
    res_0 = relax.Var("inp_0", relax.TensorStructInfo([1, 3, 4, 4], "float32"))
    mod = main(res_0)
    # Load the saved weights and bind them to main's parameters by name.
    with open("main_params.bin", "rb") as f:
        params = tvm.runtime.load_param_dict(f.read())
    bind_params = tvm.relax.transform.BindParams("main", params)
    mod = bind_params(mod)
    # Lower the high-level ops and build for the CPU.
    target = tvm.target.Target("llvm")
    mod = tvm.relax.transform.LegalizeOps()(mod)
    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(mod, target)
    dev = tvm.cpu()
    vm = relax.VirtualMachine(ex, dev)
    f_main = vm["main"]
    # Pass a runtime NDArray rather than the raw numpy array: the VM's packed
    # functions are not guaranteed to convert numpy inputs themselves.
    outputs = f_main(tvm.nd.array(inputs["inp_0"], dev))
    msc_utils.compare_arrays(golden, outputs, verbose="detail")
import os
import numpy as np
from typing import List, Dict, Any
import tvm
from tvm.contrib.msc.core import utils as msc_utils
import torch
from torch import nn
from torch.nn import functional
# Define the helpers
def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:
    """Return the tensor dumped at ``baseline/<name>.bin``, else ones.

    The file, when present, is read as raw ``dtype`` values and reshaped to
    ``shape``. Otherwise ``np.ones(shape)`` cast to ``dtype`` is returned so
    the test can still run without recorded data.
    """
    path = os.path.join("baseline", name + ".bin")
    return (
        np.fromfile(path, dtype=dtype).reshape(shape)
        if os.path.isfile(path)
        else np.ones(shape).astype(dtype)
    )
# Define the graph
class main(torch.nn.Module):
    """Generated PyTorch counterpart of the MSC graph ``main``: conv2d -> relu.

    The submodule attribute names (``conv2d``, ``relu``) determine the
    state_dict keys (e.g. ``conv2d.weight``) that ``load_state_dict`` matches
    against the saved weights, so they must not be renamed.
    """

    def __init__(self) -> None:
        super().__init__()
        # conv2d(nn.conv2d): <res_0> -> <res_1>
        self.conv2d = nn.Conv2d(
            in_channels=3,
            out_channels=6,
            kernel_size=[1, 1],
            stride=[1, 1],
            padding=[0, 0],
            dilation=[1, 1],
            groups=1,
            bias=False,
        )
        # relu(nn.relu): <res_1> -> <res_2>
        self.relu = nn.ReLU()

    def forward(self, res_0: torch.Tensor) -> torch.Tensor:
        """Apply conv2d then relu to ``res_0`` and return the single output tensor."""
        # conv2d(nn.conv2d): <res_0> -> <res_1>
        res_1 = self.conv2d(res_0)
        # relu(nn.relu): <res_1> -> <res_2>
        res_2 = self.relu(res_1)
        return res_2
# Define the test
if __name__ == "__main__":
    # Prepare test data: baseline dumps when present, otherwise ones (see load_data).
    inputs = {}
    golden = {}
    inputs["inp_0"] = load_data("inp_0", [1, 3, 4, 4], "float32")
    golden["relu"] = load_data("relu", [1, 6, 4, 4], "float32")
    # Build the model.
    model = main()
    # Load weights. The checkpoint is fed straight into load_state_dict, so only
    # tensor data is expected: weights_only=True restricts unpickling accordingly
    # (no arbitrary-code execution, no FutureWarning about the unset default).
    weights = torch.load("main.pth", weights_only=True)
    model.load_state_dict(weights)
    model.eval()
    res_0 = torch.from_numpy(inputs["inp_0"])
    # Inference only: skip autograd so the outputs carry no grad history.
    with torch.no_grad():
        outputs = model(res_0)
    msc_utils.compare_arrays(golden, outputs, verbose="detail")