注册自定义 ONNX 算子

注册自定义 ONNX 算子

import functools
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
from torch.onnx import register_custom_op_symbolic

class Model(torch.nn.Module):
    """Minimal module whose forward pass is a single ``torch.asinh`` call.

    It is the export target for the custom ``aten::asinh`` ONNX symbolic
    and holds no parameters or buffers of its own.
    """

    def forward(self, x):
        # Element-wise inverse hyperbolic sine of the input.
        y = torch.asinh(x)
        return y

# Registration decorator from torch's private ``registration`` module, pinned
# to opset 14. It stays defined at module level so other snippets can reuse
# it, but it writes into torch's internal (non-custom) symbolic registry, so it
# is intentionally NOT applied to ``asinh_symbolic`` below.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


def asinh_symbolic(g: jit_utils.GraphContext, self, *, out=None):
    """Emit an ONNX ``Asinh`` node for ``aten::asinh``.

    Args:
        g: Graph context used to create ONNX nodes.
        self: The input value of ``aten::asinh``.
        out: Unused; accepted keyword-only for signature compatibility.

    Returns:
        The value produced by the new ``Asinh`` node.
    """
    return g.op("Asinh", self)


# Register exactly once, through the public API only. Previously the function
# was also decorated with ``_onnx_symbolic``, which registered it a second time
# in torch's internal non-custom registry. That duplicate was redundant and is
# a global side effect that the public ``unregister_custom_op_symbolic`` API
# is not meant to manage.
register_custom_op_symbolic("aten::asinh", asinh_symbolic, 14)

# Export the demo model to ``asinh.onnx``. The opset is pinned to 14 because
# the custom ``aten::asinh`` symbolic is registered for opset 14. The default
# export opset differs between PyTorch releases, and exporting at an older
# default would not pick up that registration.
model = Model()
inputs = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, inputs, 'asinh.onnx', opset_version=14)
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
linalg_inv