注册 aten::linalg_inv 到 ONNX

注册 aten::linalg_inv 到 ONNX

参考：PyTorch 测试 test_pytorch_onnx_shape_inference

import torch
# ONNX opset that the torch.onnx.export call in this script targets.
opset_version: int = 14
class CustomInverse(torch.nn.Module):
    """Toy module returning ``torch.inverse(x) + x``.

    Used as the model to export. The script registers a custom symbolic for
    ``::linalg_inv``, presumably because ``torch.inverse`` shows up as
    ``aten::linalg_inv`` in the traced graph (TODO confirm for the torch
    version in use).
    """

    def forward(self, x):
        # torch.inverse requires x to be a (batch of) square, invertible matrices.
        inv_x = torch.inverse(x)
        return inv_x + x

def linalg_inv_settype(g, self):
    return g.op("com.microsoft::Inverse", self).setType(
        self.type().with_dtype(torch.float).with_sizes([None, 3, 3])
    )

# Route aten::linalg_inv through the custom symbolic. It is registered at
# opset 9; the export below targets `opset_version`, which presumably still
# resolves to this registration (the captured diagnostics report no errors).
torch.onnx.register_custom_op_symbolic(
    symbolic_name="::linalg_inv",
    symbolic_fn=linalg_inv_settype,
    opset_version=9,
)

model = CustomInverse()
# Example input used to trace the model: a batch of two 3x3 matrices.
x = torch.randn(2, 3, 3)
torch.onnx.export(
    model,
    (x,),
    "inv.onnx",
    input_names=["x"],
    # Mark dimension 0 of input "x" as dynamic, named "batch".
    dynamic_axes={"x": {0: "batch"}},
    opset_version=opset_version,
    # The symbolic emits an op in the "com.microsoft" domain; declare its version.
    custom_opsets={"com.microsoft": 1},
)
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================