# PyTorch 模型

In [None]:
import torch
import torch.nn.functional as F
from torch import fx
from torch.nn import Module
from torchvision.models.resnet import ResNet18_Weights, resnet18

In [None]:
from tvm.relax.frontend.torch import from_fx
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params

def verify_model(torch_model, input_info, binding, expected):
    """Convert ``torch_model`` to Relax through torch.fx and compare with ``expected``.

    The model is symbolically traced with ``torch.fx`` and translated by
    ``from_fx`` (under ``torch.no_grad()``). Every entry of ``binding`` is
    wrapped in ``tvm.nd.array`` and bound into the ``"main"`` function of
    ``expected`` via ``relax.transform.BindParams``. The converted module is
    then checked with ``tvm.ir.assert_structural_equal``, which fails if the
    two modules differ.

    Args:
        torch_model: Module to trace with ``torch.fx.symbolic_trace``.
        input_info: Input signature forwarded unchanged to ``from_fx``
            (in this notebook, a list of ``(shape, dtype)`` pairs).
        binding: Mapping of parameter name -> array-like value to bind into
            ``expected``.
        expected: Reference ``IRModule`` whose ``"main"`` function receives
            the bound parameters.
    """
    traced = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        converted = from_fx(traced, input_info)

    bound_params = {}
    for param_name, param_value in binding.items():
        bound_params[param_name] = tvm.nd.array(param_value)

    bind_pass = relax.transform.BindParams("main", bound_params)
    reference = bind_pass(expected)
    tvm.ir.assert_structural_equal(converted, reference)

In [7]:
# Load torchvision's pretrained ResNet-18 (DEFAULT weights) and switch it to
# inference mode *before* tracing. A freshly constructed nn.Module is in
# training mode; traced GraphModules share its submodules, so without
# `.eval()` the PyTorch reference run later in the notebook would normalize
# BatchNorm with per-batch statistics (and overwrite the running buffers)
# instead of using the stored running statistics.
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

# Input signature passed to from_fx: a single float32 input of shape [1, 3, 224, 224].
input_info = [([1, 3, 224, 224], "float32")]

graph_model = fx.symbolic_trace(torch_model)
with torch.no_grad():
    mod = from_fx(graph_model, input_info)

# Fold constant sub-expressions in the converted module before building.
mod = relax.transform.FoldConstant()(mod)
# Optional passes to experiment with here: relax.transform.FuseOps(), or the
# whole relax.get_pipeline("zero") pipeline.

In [8]:
# Print the converted (constant-folded) Relax IRModule as TVMScript for inspection.
mod.show()

In [None]:
import numpy as np

# Compile the Relax module for the host CPU with the LLVM backend and load the
# resulting executable into a Relax virtual machine.
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)

# Random test input, uniform in [0, 1). Shape and dtype come from `input_info`
# (the signature the model was converted with) rather than being repeated here,
# so the two cannot drift apart. The RNG is seeded so that a failing numerical
# comparison can be reproduced.
input_shape, input_dtype = input_info[0]
rng = np.random.default_rng(0)
data = rng.random(input_shape).astype(input_dtype)

tvm_data = tvm.nd.array(data, device=device)
tvm_output = vm["main"](tvm_data).numpy()

In [None]:
# PyTorch reference output for the same input. Force inference mode first: in
# training mode BatchNorm normalizes with per-batch statistics and updates its
# running buffers, which does not match inference behaviour. `eval()` recurses
# into submodules, so it also covers modules shared with `torch_model`.
# NOTE(review): assumes the Relax graph from `from_fx` lowers BatchNorm with the
# stored running statistics — confirm for the TVM version in use.
graph_model.eval()
with torch.no_grad():
    torch_output = graph_model(torch.from_numpy(data)).numpy()

In [None]:
# Element-wise check that TVM matches the PyTorch reference:
# |tvm_output - torch_output| <= atol + rtol * |torch_output|.
# NOTE(review): rtol=1e-07 is NumPy's default; float32 results from different
# backends may need a looser tolerance — confirm on the target machine.
np.testing.assert_allclose(
    actual=tvm_output,
    desired=torch_output,
    rtol=1e-07,
    atol=1e-5,
)