ONNX Script 模型本地函数

ONNX Script 模型本地函数#

ONNX 中的模型可能包含模型本地函数。当将 onnxscript 函数转换为 ModelProto 时,默认行为是将所有被传递调用的 function-ops 的函数定义作为生成模型中的模型本地函数包含在内(对于这些函数,已经看到了 onnxscript 函数定义)。调用者可以通过明确提供要包含在生成模型中的 FunctionProtos 列表来覆盖此行为。

首先,让我们定义一个调用其他 ONNXScript 函数的 ONNXScript 函数。

import onnx

from onnxscript import FLOAT, script
from onnxscript import opset15 as op
from onnxscript.values import Opset

# A dummy opset used for model-local functions
local = Opset("local", 1)


@script(local, default_opset=op)
def diff_square(x, y):
    diff = x - y
    return diff * diff


@script(local)
def sum(z):
    return op.ReduceSum(z, keepdims=1)


@script()
def l2norm(x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT[1]:  # noqa: F821
    return op.Sqrt(sum(diff_square(x, y)))

让我们看看默认生成的模型是什么样的:

model = l2norm.to_model_proto()
print(onnx.printer.to_text(model))
<
   ir_version: 8,
   opset_import: ["local" : 1, "" : 15]
>
l2norm (float[N] x, float[N] y) => (float[1] return_val) {
   tmp = local.diff_square (x, y)
   tmp_0 = local.sum (tmp)
   return_val = Sqrt (tmp_0)
}
<
  domain: "local",
  opset_import: ["" : 15]
>
sum (z) => (return_val)
{
   return_val = ReduceSum <keepdims: int = 1> (z)
}
<
  domain: "local",
  opset_import: ["" : 15]
>
diff_square (x, y) => (return_val)
{
   diff = Sub (x, y)
   return_val = Mul (diff, diff)
}

现在,让我们明确指定要包含哪些函数。首先,生成一个不包含模型本地函数的模型:

model = l2norm.to_model_proto(functions=[])
print(onnx.printer.to_text(model))
<
   ir_version: 8,
   opset_import: ["local" : 1, "" : 15]
>
l2norm (float[N] x, float[N] y) => (float[1] return_val) {
   tmp = local.diff_square (x, y)
   tmp_0 = local.sum (tmp)
   return_val = Sqrt (tmp_0)
}

现在,生成一个包含一个模型本地函数的模型:

model = l2norm.to_model_proto(functions=[sum])
print(onnx.printer.to_text(model))
<
   ir_version: 8,
   opset_import: ["local" : 1, "" : 15]
>
l2norm (float[N] x, float[N] y) => (float[1] return_val) {
   tmp = local.diff_square (x, y)
   tmp_0 = local.sum (tmp)
   return_val = Sqrt (tmp_0)
}
<
  domain: "local",
  opset_import: ["" : 15]
>
sum (z) => (return_val)
{
   return_val = ReduceSum <keepdims: int = 1> (z)
}