翻译 PyTorch 代码

目录

翻译 PyTorch 代码

import set_env
import numpy as np

import torch
from torch.nn import Module

import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen
def verify_model(torch_model, input_info, via_relax=True, *, atol=1e-5, rtol=1e-5):
    """Check that a torch model survives a round trip through MSC.

    The model is translated into an MSC graph plus weights, regenerated as a
    torch module with ``codegen.to_torch``, and both the original and the
    regenerated model are run on the same random inputs. Their outputs must
    match.

    Parameters
    ----------
    torch_model : torch.nn.Module
        The reference model whose outputs serve as the golden values.
    input_info : list
        One entry per model input; ``entry[0]`` is the shape and
        ``entry[1]`` the dtype string, e.g. ``[([1, 3, 10], "float32")]``.
        It is passed to the translator and used to build random inputs.
    via_relax : bool
        Forwarded to ``translate.from_torch``.
    atol, rtol : float
        Absolute / relative tolerances used when comparing tensor outputs.

    Raises
    ------
    AssertionError
        If the number of outputs differs, or any output does not match.
    """
    graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax)
    model = codegen.to_torch(graph, weights)
    # Uniform [0, 1) random data of the requested shape and dtype for each input.
    torch_datas = [
        torch.from_numpy(np.random.rand(*info[0]).astype(info[1])) for info in input_info
    ]
    with torch.no_grad():
        golden = torch_model(*torch_datas)
        # When the translated graph reports no inputs, the generated module
        # is invoked without arguments.
        result = model(*torch_datas) if graph.get_inputs() else model()

    # Normalize single outputs to one-element lists so both sides compare uniformly.
    if not isinstance(golden, (list, tuple)):
        golden = [golden]
    if not isinstance(result, (list, tuple)):
        result = [result]
    assert len(golden) == len(result), f"golden {len(golden)} mismatch with result {len(result)}"

    for gol_r, new_r in zip(golden, result):
        if isinstance(gol_r, torch.Tensor):
            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
            tvm.testing.assert_allclose(
                gol_r.detach().cpu().numpy(),
                new_r.detach().cpu().numpy(),
                atol=atol,
                rtol=rtol,
            )
        else:
            assert gol_r == new_r

conv1d

class Conv1D1(Module):
    """A single ``Conv1d`` layer: 3 -> 6 channels, kernel size 7, with bias."""

    def __init__(self):
        super().__init__()
        # The attribute name ``conv`` fixes the state_dict keys
        # (``conv.weight`` / ``conv.bias``), so it is kept as-is.
        self.conv = torch.nn.Conv1d(
            in_channels=3,
            out_channels=6,
            kernel_size=7,
            bias=True,
        )

    def forward(self, data):
        """Convolve ``data`` (3 input channels, e.g. shape (N, 3, L))."""
        output = self.conv(data)
        return output

class Conv1D2(Module):
    """A single ``Conv1d`` layer: 3 -> 6 channels, kernel size 7, no bias."""

    def __init__(self):
        super().__init__()
        # The attribute name ``conv`` fixes the state_dict key
        # (``conv.weight``), so it is kept as-is.
        self.conv = torch.nn.Conv1d(
            in_channels=3,
            out_channels=6,
            kernel_size=7,
            bias=False,
        )

    def forward(self, data):
        """Convolve ``data`` (3 input channels, e.g. shape (N, 3, L))."""
        output = self.conv(data)
        return output

# Check both conv1d variants through both translation paths
# (via relax first, then without).
input_info = [([1, 3, 10], "float32")]
for via_relax in (True, False):
    for model_cls in (Conv1D1, Conv1D2):
        verify_model(model_cls(), input_info, via_relax)