在 ONNX IR 中的张量表示#
ONNX IR提供了 ir.TensorProtocol
接口,以便使用不同的数据结构作为张量的后备数据。除了传统的 onnx.TensorProto
之外,您还可以使用 np.ndarray
、torch.Tensor
、jax.Array
以及几乎任何其他东西来表示计算图中的张量。这使得它们可以通过相同的 TensorProtocol
接口进行访问和序列化,而在初始化期间不会发生额外的复制。
ir.TensorProtocol
#
ir.TensorProtocol
定义了一个只读接口,用于表示张量。实现该接口的张量类具有 name
、 shape
、 dtype
、 size
、 nbytes
和 metadata_props
等属性,用于描述张量的基本属性。此外,它还应实现两个方法 numpy
和 __array__
,这两个方法将从底层数据生成等效的 NumPy 数组。
备注
当与初始化器、常量值和张量属性交互时,最好假设使用 ir.TensorProtocol
,只有在需要检查具体类时才使用 isinstance()
。
张量类#
ir.TensorProtoTensor
#
使用 ir.TensorProtoTensor
作为对 proto 的包装以实现 ir.TensorProtocol
接口。您可以像往常一样访问 shape
、 dtype
等。只有在调用 numpy()
时才会产生副本。
直接初始化 ir.TensorProtoTensor
,如下所示,是可能的。然而,通常建议使用 ir.serde.deserialize_tensor
,因为它可以处理所有类型的 TensorProto
(例如,ir.TensorProtoTensor
不处理外部张量)。
import onnx
from onnxscript import ir
tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
tensor = ir.TensorProtoTensor(tensor_proto)
print("tensor: ", tensor) # TensorProtoTensor<INT16,[3]>(name='tensor')
print("shape: ", tensor.shape) # ir.Shape([3])
print("dtype: ", tensor.dtype) # ir.DataType.INT16
print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization
print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00'
print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16)
tensor: TensorProtoTensor<INT16,[3]>(name='tensor')
shape: [3]
dtype: INT16
True
tobytes: b'\x01\x00\x02\x00\x03\x00'
numpy: [1 2 3]
ir.ExternalTensor
#
存储在外部磁盘上的张量数据通常很大,加载时将占用内存。ir.ExternalTensor
类使用内存映射来避免将张量加载到内存中。您可以使用张量作为普通的 NumPy 数组,内存使用量最小。
请参阅 ir.serde.deserialize_tensor
以找到将 onnx.TensorProto
转换为 ir.ExternalTensor
的示例。
ir.Tensor
#
ir.Tensor
是围绕 NumPy 数组兼容的数组对象(如 np.ndarray
和 torch.Tensor
)的包装器。它最适合创建内存中的张量,而不将其转换为 TensorProto
,以减少转换开销。
备注
如果一个数组对象定义了 __array__
方法,则它是兼容的。
从数组创建张量,只需用 NumPy 数组初始化即可
import numpy as np
tensor = ir.Tensor(np.random.rand(1, 2))
tensor
Tensor<DOUBLE,[1,2]>(array([[0.71395225, 0.48701339]]), name=None)
初始化器将从数组中获取数据类型和形状信息。
如果要从 NumPy 数组以外的对象创建张量,您需要指定数据类型:
import torch
from onnxscript import ir
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
print(tensor.numpy()) # array([1., 2., 3.], dtype=float16)
[1. 2. 3.]
字符串张量#
使用 ir.StringTensor
创建字符串张量。
回退 TensorProto
#
在以下场景中,展示了如何从 TensorProto
转换到 ir.Tensor
,运行一些计算,然后将其转换回 ir.Tensor
,最后 TensorProto
:
处理非原生 NumPy 数据类型:bfloat16
、float8
、int4
#
ir.Tensor.numpy()
生成张量值的 NumPy 数组表示。当张量的数据类型为 BFLOAT16
、 FLOAT8[...]
或 [U]INT4
时,这些类型不支持 NumPy,将使用 ml_dtypes
包中的数据类型。
uint4
/ int4
总是解包;tobyte()
生成预期的打包表示。
ir.Tensor
的初始化需要 NumPy 数组遵循以下类型约束,或者具有 ml_dtypes
数据类型。
int8
用于(未打包的)int4
,符号位扩展到 8 位。uint8
用于(解包的)uint4
。uint8
用于 8 位数据类型,如float8
。uint16
用于bfloat16
。
以下示例展示了如何创建 FLOAT8E4M3FN
张量,变换其值,并创建新的张量来存储变换后的值。
from onnxscript import ir
import numpy as np
array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938]
# Compute
times_100 = tensor.numpy() * 100
print("times_100:", times_100)
# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True])
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1875 0.5625]
new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
new_tensor == times_100 [ True True]
高级用法#
子类化 ir.Tensor
以实现更高效的访问和更广泛的支持 dtype
#
ir.Tensor
内部将任何与数组兼容的对象转换为 NumPy 数组,以生成 tobytes()
中的字节表示。由于额外的转换,这可能会降低效率。它还限制了对于 NumPy 不支持的数据类型(如 bfloat16
)的支持,因为 __array__
方法将失败。
为了完全支持来自其他框架的数组,通常创建专门的类来处理它们是个好主意。下面的 TorchTensor
类演示了您如何通过子类化 ir.Tensor
来处理 PyTorch 张量:
该类实现了 tobytes()
,以生成张量在序列化为 ONNX 文件/TensorProto 时的正确字节表示。该类还实现了 __array__()
方法,以返回 NumPy 不支持的数据类型的位表示。这样,分析阶段仍然可以对这些值进行计算。
使用不同框架进行计算#
由于 ir.Tensor
实现了 __array__
方法和 __dlpack__
方法,其内容可以在不复制的情况下与计算框架共享。例如:
这在您在图上创建需要计算具体值的传递时特别有用。您可以使用您喜欢的框架来创建传递。包含新创建的 ir.Tensor
的转换图将与下游传递兼容,即使它们利用其他计算框架也是如此。