开始使用 ONNX IR
ONNX IR 与 ONNX Script 包一起提供,可以通过 onnxscript.ir 获取。要从 ONNX 文件创建 IR 对象,请将其加载为 ModelProto,然后调用 ir.from_proto() 或 ir.serde.deserialize_model():
参考:《ONNX IR 快速上手》
# Imports first, per PEP 8 (they were previously placed after MODEL_TEXT).
import onnx

from onnxscript import ir

# An example model in the ONNX textual format.  producer_name "pytorch"
# suggests it was exported from PyTorch; the graph reduces `input_0`
# along axis 1 (val_1 = [1] is used as the `axes` input of the ReduceMeans).
MODEL_TEXT = r"""
<
ir_version: 8,
opset_import: ["" : 18],
producer_name: "pytorch",
producer_version: "2.0.0"
>
torch_jit (float[5,5,5] input_0) => (float[5,5] val_19, float[5,5] val_6) {
val_1 = Constant <value_int: ints = [1]> ()
val_2 = Shape <start: int = 0> (val_1)
val_3 = Size (val_2)
val_4 = Constant <value: tensor = int64 {0}> ()
val_5 = Equal (val_3, val_4)
val_6 = ReduceMean <keepdims: int = 0, noop_with_empty_axes: int = 0> (input_0, val_1)
val_7 = ReduceMean <keepdims: int = 1, noop_with_empty_axes: int = 0> (input_0, val_1)
val_8 = Shape <start: int = 0> (input_0)
val_9 = Gather <axis: int = 0> (val_8, val_1)
val_10 = ReduceProd <keepdims: int = 0, noop_with_empty_axes: int = 0> (val_9)
val_11 = Sub (input_0, val_7)
val_12 = Mul (val_11, val_11)
val_13 = ReduceMean <keepdims: int = 0, noop_with_empty_axes: int = 0> (val_12, val_1)
val_14 = Cast <to: int = 1> (val_10)
val_15 = Mul (val_13, val_14)
val_16 = Constant <value: tensor = int64 {1}> ()
val_17 = Sub (val_10, val_16)
val_18 = Cast <to: int = 1> (val_17)
val_19 = Div (val_15, val_18)
}
"""

# Parse the textual representation into an onnx.ModelProto.
# A model stored on disk can instead be loaded with onnx.load("model.onnx").
model_proto = onnx.parser.parse_model(MODEL_TEXT)

# Deserialize the ModelProto into an ONNX IR model object.
model = ir.serde.deserialize_model(model_proto)
现在我们可以探索 IR 对象了:
# Report how many nodes the model's main graph contains.
node_count = len(model.graph)
print(f"The main graph has {node_count} nodes.")
The main graph has 19 nodes.
输入信息:
# Inspect the graph's input values.
graph_inputs = model.graph.inputs
print(graph_inputs)
[Value('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None)]
输出信息:
# Inspect the graph's output values.
graph_outputs = model.graph.outputs
print(graph_outputs)
[Value('val_19', type=Tensor(FLOAT), shape=[5,5], producer=, index=0), Value('val_6', type=Tensor(FLOAT), shape=[5,5], producer=, index=0)]
使用第一个输入的节点:
# List every (node, input-index) pair that consumes the first graph input.
first_input = model.graph.inputs[0]
print(list(first_input.uses()))
[(Node(name='', domain='', op_type='ReduceMean', inputs=(Value('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_1', type=None, shape=None, producer=, index=0)), attributes=OrderedDict({'keepdims': Attr('keepdims', INT, 0), 'noop_with_empty_axes': Attr('noop_with_empty_axes', INT, 0)}), overload='', outputs=(Value('val_6', type=Tensor(FLOAT), shape=[5,5], producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='ReduceMean', inputs=(Value('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_1', type=None, shape=None, producer=, index=0)), attributes=OrderedDict({'keepdims': Attr('keepdims', INT, 1), 'noop_with_empty_axes': Attr('noop_with_empty_axes', INT, 0)}), overload='', outputs=(Value('val_7', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='Shape', inputs=(Value('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None),), attributes=OrderedDict({'start': Attr('start', INT, 0)}), overload='', outputs=(Value('val_8', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='Sub', inputs=(Value('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_7', type=None, shape=None, producer=, index=0)), attributes=OrderedDict(), overload='', outputs=(Value('val_11', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0)]
产生最后一个输出的节点,以及该输出在其生产者节点输出中的索引:
# Show which node produces the last graph output, and at which output index.
last_output = model.graph.outputs[-1]
print(last_output.producer())
print(last_output.index())
%"val_6"<FLOAT,[5,5]> ⬅️ ::ReduceMean(%"input_0", %"val_1") {keepdims=0, noop_with_empty_axes=0}
0
打印计算图:
# Pretty-print the whole graph.  Pass page=True to scroll long output
# through a terminal pager instead of dumping it all at once.
model.graph.display(page=False)
graph(
name=torch_jit,
inputs=(
%"input_0"<FLOAT,[5,5,5]>
),
outputs=(
%"val_19"<FLOAT,[5,5]>,
%"val_6"<FLOAT,[5,5]>
),
) {
0 | # :anonymous_node:140400872897376
%"val_1"<?,?> ⬅️ ::Constant() {value_int=[1]}
1 | # :anonymous_node:140398669428592
%"val_2"<?,?> ⬅️ ::Shape(%"val_1") {start=0}
2 | # :anonymous_node:140398669429456
%"val_3"<?,?> ⬅️ ::Size(%"val_2")
3 | # :anonymous_node:140398669429600
%"val_4"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')}
4 | # :anonymous_node:140398669429744
%"val_5"<?,?> ⬅️ ::Equal(%"val_3", %"val_4")
5 | # :anonymous_node:140398669429888
%"val_6"<FLOAT,[5,5]> ⬅️ ::ReduceMean(%"input_0", %"val_1") {keepdims=0, noop_with_empty_axes=0}
6 | # :anonymous_node:140398669430032
%"val_7"<?,?> ⬅️ ::ReduceMean(%"input_0", %"val_1") {keepdims=1, noop_with_empty_axes=0}
7 | # :anonymous_node:140398669430176
%"val_8"<?,?> ⬅️ ::Shape(%"input_0") {start=0}
8 | # :anonymous_node:140398669430320
%"val_9"<?,?> ⬅️ ::Gather(%"val_8", %"val_1") {axis=0}
9 | # :anonymous_node:140398669430464
%"val_10"<?,?> ⬅️ ::ReduceProd(%"val_9") {keepdims=0, noop_with_empty_axes=0}
10 | # :anonymous_node:140398669430608
%"val_11"<?,?> ⬅️ ::Sub(%"input_0", %"val_7")
11 | # :anonymous_node:140398668218448
%"val_12"<?,?> ⬅️ ::Mul(%"val_11", %"val_11")
12 | # :anonymous_node:140398668218592
%"val_13"<?,?> ⬅️ ::ReduceMean(%"val_12", %"val_1") {keepdims=0, noop_with_empty_axes=0}
13 | # :anonymous_node:140398668218880
%"val_14"<?,?> ⬅️ ::Cast(%"val_10") {to=1}
14 | # :anonymous_node:140398668219024
%"val_15"<?,?> ⬅️ ::Mul(%"val_13", %"val_14")
15 | # :anonymous_node:140398668219456
%"val_16"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')}
16 | # :anonymous_node:140398668219600
%"val_17"<?,?> ⬅️ ::Sub(%"val_10", %"val_16")
17 | # :anonymous_node:140398668219744
%"val_18"<?,?> ⬅️ ::Cast(%"val_17") {to=1}
18 | # :anonymous_node:140398668219888
%"val_19"<FLOAT,[5,5]> ⬅️ ::Div(%"val_15", %"val_18")
return %"val_19"<FLOAT,[5,5]>, %"val_6"<FLOAT,[5,5]>
}
Tip: Install the rich library with 'pip install rich' to pretty print this Graph.
将 IR 对象转换回 ModelProto:
# Serialize the IR model back into an onnx.ModelProto.
model_proto_back = ir.serde.serialize_model(model)