ONNX Script 简介#
ONNX Script 使得开发者能够使用 Python 的子集自然地编写 ONNX 函数和模型。它旨在实现以下几点:
表达性强:能够编写所有 ONNX 函数。
简洁明了:函数代码自然且简单。
可调试:允许进行即时模式评估,可以使用标准的 Python 调试器来调试代码。
请注意,ONNX Script 并不打算支持整个 Python 语言。
ONNX Script 提供了一些主要功能,用于编写和调试 ONNX 模型和函数:
转换器,将 Python ONNX Script 函数转换为 ONNX 图,通过遍历 Python 抽象语法树来构建与该函数等效的 ONNX 图。
运行时填充层(shim),允许这样的函数被评估(以“即时模式”)。当前这个功能依赖于 ONNX Runtime 来执行 ONNX 操作,并且正在进行一个仅支持 Python 的 ONNX 参考运行时的开发,也将得到支持。
转换器,将 ONNX 模型和函数转换成 ONNX Script。这个功能可以用于完全地在 ONNX Script ↔ ONNX 图之间进行循环转换。
请注意,运行时旨在帮助理解和调试函数定义。性能不是此处的目标。
ONNX Script 尝鲜#
以下是使用 ONNX Script 的简单示例。
from onnxscript import script
# 使用 ONNX 算子集17来定义以下函数
from onnxscript import opset17 as op
# 使用 script 装饰器来表明接下来的函数旨在被转换为 ONNX 格式
@script()
def MatmulAdd(X, Wt, Bias):
return op.MatMul(X, Wt) + Bias
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/converter.py:823: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
装饰器解析函数的代码并将其转换为中间表示形式。如果转换失败,它将触发报错信息,指示检测到的错误。如果成功,可以生成如下所示的函数的相应 ONNX 表示(FunctionProto 类型的值):
fp = MatmulAdd.to_function_proto() # returns an onnx.FunctionProto
类似地,可以生成 ONNX 模型。ONNX 模型和 ONNX 函数之间有一些区别。例如,ONNX 模型必须指定输入和输出的类型(与 ONNX 函数不同)。以下示例说明了我们如何生成 ONNX 模型:
from onnxscript import script
from onnxscript import opset15 as op
from onnxscript import FLOAT
@script()
def MatmulAddModel(X : FLOAT[64, 128] , Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
return op.MatMul(X, Wt) + Bias
model = MatmulAddModel.to_model_proto() # returns an onnx.ModelProto
ONNX Script 即时评估模式#
即时评估模式主要用于调试和检查中间结果是否符合预期。之前定义的函数可以如下调用,并且这将在即时评估模式下执行。
import numpy as np
x = np.array([[0, 1], [2, 3]], dtype=np.float32)
wt = np.array([[0, 1], [2, 3]], dtype=np.float32)
bias = np.array([0, 1], dtype=np.float32)
result = MatmulAdd(x, wt, bias)
/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/onnxscript/evaluator.py:277: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = function.param_schemas()