Relay to ONNX

import set_env

from collections import OrderedDict

import numpy as np
import onnxruntime as rt

import tvm
from tvm import relay
from tvm.contrib.target.onnx import to_onnx
from tvm.relay.testing import run_infer_type

def func_to_onnx(mod, params, name):
    """Convert a Relay module to a serialized ONNX model."""
    onnx_model = to_onnx(mod, params, name, path=None)
    return onnx_model.SerializeToString()

def run_onnx(mod, params, name, input_data):
    """Convert the Relay module to ONNX and execute it with ONNX Runtime."""
    onnx_model = func_to_onnx(mod, params, name)
    sess = rt.InferenceSession(onnx_model)
    input_names = {}
    for inp, data in zip(sess.get_inputs(), input_data):
        input_names[inp.name] = data
    output_names = [output.name for output in sess.get_outputs()]
    res = sess.run(output_names, input_names)
    return res[0]

def get_data(in_data_shapes, dtype="float32"):
    """Generate random input tensors for the given name -> shape mapping."""
    in_data = OrderedDict()
    for name, shape in in_data_shapes.items():
        in_data[name] = np.random.uniform(size=shape).astype(dtype)
    return in_data


def run_relay(mod, params, in_data):
    """Run the Relay module on the LLVM (CPU) backend via the graph executor."""
    target = "llvm"
    dev = tvm.device(target, 0)
    in_data = [tvm.nd.array(value) for value in in_data.values()]
    return (
        relay.create_executor("graph", mod, device=dev, target=target)
        .evaluate()(*in_data, **params)
        .numpy()
    )


def _verify_results(mod, params, name, in_data):
    """Check that Relay and ONNX Runtime produce numerically matching results."""
    a = run_relay(mod, params, in_data)
    b = run_onnx(mod, params, name, in_data.values())
    np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7)
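
With these helpers in place, an end-to-end check only requires building a small Relay module and comparing both runtimes. The sketch below is illustrative and not part of the original helpers; the model name `test_add` and the 1-D shapes are arbitrary choices for the example.

# Minimal usage sketch (illustrative): verify element-wise addition
# by running it through Relay and through the exported ONNX model.
in_data_shapes = OrderedDict({"x": (8,), "y": (8,)})
in_data = get_data(in_data_shapes, dtype="float32")

x = relay.var("x", shape=(8,), dtype="float32")
y = relay.var("y", shape=(8,), dtype="float32")
func = relay.Function([x, y], relay.add(x, y))
mod = tvm.IRModule.from_expr(func)

_verify_results(mod, {}, "test_add", in_data)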