ONNX Script chunk#
下面的例子是直接从新的 PyTorch ONNX 导出器改编而来,实现了对 torch.chunk()
的支持,该函数尝试将张量分割成指定数量的块。
from typing import Sequence
from onnxscript import opset18 as op, script, FLOAT, INT64
@script()
def aten_chunk(
tensor: FLOAT[...], chunks: int, dim: int = 0,
) -> Sequence[FLOAT[...]]:
neg_1 = op.Constant(value_ints=[-1])
# Get size of specified dim
dim_size = op.Shape(tensor)[dim]
# Compute size/chunk to get the number of data in one chunk
num_per_chunk = dim_size / chunks + op.Cast(dim_size % chunks > 0, to=INT64.dtype)
# Compute real chunk number
num_chunk = dim_size / num_per_chunk
# Get something like [n, n, n, n, ...], total num_chunk
list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
remainder = dim_size % num_per_chunk
if remainder > 0:
# Append the remainder to the [n, n, n, n, ..., r]
list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
return op.SplitToSequence(tensor, list_split, axis=dim)
我们从 onnxscript 导入我们想要使用的 ONNX opset(在这个例子中是版本18)、@script
装饰器,以及 FLOAT 和 INT64 的张量类型。在 ONNX Script 中,张量形状是通过类型下标表示的,例如 FLOAT[2, 10]
,或者符号性地表示为 FLOAT["M", "N"]
,或者在张量形状未知的情况下使用 FLOAT[...]
。如果没有下标(仅 FLOAT),该类型旨在表示标量(秩为 0 的张量)。
接下来,我们定义了一个带有类型注解的 aten_chunk
函数,并使用内置的 Python 语法和显式的 ONNX 算子调用来实现函数体。这个例子使用了各种二元表达式和一个 if
语句,但也支持许多其他的 Python 惯用构造。
我们还需要定义一个简单的模型来调用我们的 ONNX Script 函数,以便我们可以导出并验证一个端到端的例子:
@script()
def ten_chunks_model(tensor: FLOAT["M"]):
return aten_chunk(tensor, chunks=10)
这个模型将简单地将提供的张量分割成十个张量,但它也展示了 ONNX 函数当然可以调用其他 ONNX 函数,而不仅仅是内置的 ONNX 算子。
我们现在将把 ONNX Script 模型导出到 ONNX,并在 Netron 中探索它。使用 @script
装饰的函数允许它们使用 to_model_proto
函数进行导出。
import onnx
onnx.save_model(
ten_chunks_model.to_model_proto(),
"ten_chunks_model.onnx",
)
图表展示了我们的两个 ONNX 函数;我们可以观察到原始的输入张量从 ten_chunks_model
流入 aten_chunk
,以及属性 chunks=10。返回的是一系列最多包含 10 个张量的序列。正如人们所期望的,ONNX 中的函数可以定义一次,并在模型中任意多次调用。阅读更多关于核心 ONNX 概念的信息。