ONNX Script chunk

ONNX Script chunk#

下面的例子是直接从新的 PyTorch ONNX 导出器改编而来,实现了对 torch.chunk() 的支持,该函数尝试将张量分割成指定数量的块。

from typing import Sequence
from onnxscript import opset18 as op, script, FLOAT, INT64

@script()
def aten_chunk(
    tensor: FLOAT[...], chunks: int, dim: int = 0,
) -> Sequence[FLOAT[...]]:
    neg_1 = op.Constant(value_ints=[-1])

    # Get size of specified dim
    dim_size = op.Shape(tensor)[dim]

    # Compute size/chunk to get the number of data in one chunk
    num_per_chunk = dim_size / chunks + op.Cast(dim_size % chunks > 0, to=INT64.dtype)

    # Compute real chunk number
    num_chunk = dim_size / num_per_chunk

    # Get something like [n, n, n, n, ...], total num_chunk
    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

    remainder = dim_size % num_per_chunk
    if remainder > 0:
        # Append the remainder to the [n, n, n, n, ..., r]
        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

    return op.SplitToSequence(tensor, list_split, axis=dim)

我们从 onnxscript 导入我们想要使用的 ONNX opset(在这个例子中是版本18)、@script 装饰器,以及 FLOAT 和 INT64 的张量类型。在 ONNX Script 中,张量形状是通过类型下标表示的,例如 FLOAT[2, 10],或者符号性地表示为 FLOAT["M", "N"],或者在张量形状未知的情况下使用 FLOAT[...]。如果没有下标(仅 FLOAT),该类型旨在表示标量(秩为 0 的张量)。

接下来,我们定义了一个带有类型注解的 aten_chunk 函数,并使用内置的 Python 语法和显式的 ONNX 算子调用来实现函数体。这个例子使用了各种二元表达式和一个 if 语句,但也支持许多其他的 Python 惯用构造。

我们还需要定义一个简单的模型来调用我们的 ONNX Script 函数,以便我们可以导出并验证一个端到端的例子:

@script()
def ten_chunks_model(tensor: FLOAT["M"]):
    return aten_chunk(tensor, chunks=10)

这个模型将简单地将提供的张量分割成十个张量,但它也展示了 ONNX 函数当然可以调用其他 ONNX 函数,而不仅仅是内置的 ONNX 算子。

我们现在将把 ONNX Script 模型导出到 ONNX,并在 Netron 中探索它。使用 @script 装饰的函数允许它们使用 to_model_proto 函数进行导出。

import onnx
onnx.save_model(
    ten_chunks_model.to_model_proto(),
    "ten_chunks_model.onnx",
)

图表展示了我们的两个 ONNX 函数;我们可以观察到原始的输入张量从 ten_chunks_model 流入 aten_chunk,以及属性 chunks=10。返回的是一系列最多包含 10 个张量的序列。正如人们所期望的,ONNX 中的函数可以定义一次,并在模型中任意多次调用。阅读更多关于核心 ONNX 概念的信息