PyTorch Relax 测试

PyTorch Relax 测试

import torch
from torchvision import models
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# Serialized TorchScript archive of the NanoTrack head (machine-specific path).
path = "/media/pc/data/board/arria10/lxw/tasks/tools/npuusertools/models/test/nanotrack_head/nanotrack_head.pt"

# Input signature in the [(shape, dtype), ...] form used by from_fx.
shape = [1, 96, 8, 8]
input_info = [(shape, "float32")]

# torch.jit.load returns a TorchScript ScriptModule -- despite the variable
# name, this is NOT a torch.fx.GraphModule. map_location="cpu" lets an archive
# that was saved with GPU tensors load on a CPU-only machine. Loading does not
# build an autograd graph, so no torch.no_grad() context is needed here.
graph_module = torch.jit.load(path, map_location="cpu")

# NOTE: from_fx(graph_module, input_info) cannot be used on this object.
# from_fx expects a torch.fx.GraphModule (e.g. from torch.fx.symbolic_trace on
# the original nn.Module), whereas graph_module.graph here is a TorchScript
# torch._C.Graph. TODO: trace the original nn.Module with torch.fx to convert.

# Graph.nodes() returns a one-shot iterator whose repr shows nothing useful;
# materialize it so the nodes can be inspected/counted, and print the full IR.
nodes = list(graph_module.graph.nodes())
print(graph_module.graph)
<pybind11_builtins.iterator at 0x7fdf0df84a30>