注册 aten::linalg_inv 到 ONNX
参考：PyTorch 单元测试 test_pytorch_onnx_shape_inference
import torch
# ONNX opset that the torch.onnx.export call in this script targets.
opset_version: int = 14
class CustomInverse(torch.nn.Module):
    """Toy module that adds each input matrix to its own inverse.

    It is the export target for demonstrating a custom ONNX symbolic for
    matrix inversion (``com.microsoft::Inverse``).
    """

    def forward(self, x):
        """Return ``torch.inverse(x) + x``.

        Args:
            x: Tensor of invertible square matrices. The example input used
                for export in this script is ``torch.randn(2, 3, 3)``, i.e. a
                batch of two 3x3 matrices.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # NOTE(review): torch.inverse is expected to be traced as
        # aten::linalg_inv, the op the custom symbolic is registered for;
        # confirm for the PyTorch version in use.
        return torch.inverse(x) + x
def linalg_inv_settype(g, self):
    """ONNX symbolic that lowers matrix inversion to ``com.microsoft::Inverse``.

    Args:
        g: Graph context supplied by the TorchScript ONNX exporter; ``g.op``
            appends a node to the ONNX graph being built.
        self: Graph value holding the matrix to invert. The name follows the
            ATen argument convention for symbolic functions; it is not a
            method receiver.

    Returns:
        The output value of the new ``Inverse`` node, with its type declared
        explicitly as float dtype and sizes ``[None, 3, 3]``.
    """
    inverse = g.op("com.microsoft::Inverse", self)
    # Presumably the exporter cannot infer an output type for this custom-domain
    # op, so one is declared explicitly. ``None`` leaves dim 0 unspecified,
    # matching the dynamic "batch" axis used at export. The trailing 3x3 and
    # the float dtype are hard-coded to match the example input
    # torch.randn(2, 3, 3).
    return inverse.setType(
        self.type().with_dtype(torch.float).with_sizes([None, 3, 3])
    )
# Route aten::linalg_inv through linalg_inv_settype during ONNX export.
# A leading "::" in the symbolic name denotes the default aten namespace.
# The symbolic is registered at opset 9; presumably the exporter also uses it
# for newer opsets such as the 14 exported at here (confirm for your version).
torch.onnx.register_custom_op_symbolic(
    symbolic_name="::linalg_inv",
    symbolic_fn=linalg_inv_settype,
    opset_version=9,
)
model = CustomInverse()
# Example input used for tracing: a batch of two 3x3 matrices.
x = torch.randn(2, 3, 3)

# Write the traced model to "inv.onnx" (relative to the working directory).
torch.onnx.export(
    model=model,
    args=(x,),
    f="inv.onnx",
    # Name the single input "x" and export its dim 0 as the symbolic axis "batch".
    input_names=["x"],
    dynamic_axes={"x": {0: "batch"}},
    opset_version=opset_version,
    # Declare version 1 of the com.microsoft domain used by the Inverse node.
    custom_opsets={"com.microsoft": 1},
)
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================