在 ONNX IR 中的张量表示#
ONNX IR提供了 ir.TensorProtocol
接口,以便使用不同的数据结构作为张量的后备数据。除了传统的 onnx.TensorProto
之外,您还可以使用 np.ndarray
、torch.Tensor
、jax.Array
以及几乎任何其他东西来表示计算图中的张量。这使得它们可以通过相同的 TensorProtocol
接口进行访问和序列化,而在初始化期间不会发生额外的复制。
ir.TensorProtocol
#
ir.TensorProtocol
定义了一个只读接口,用于表示张量。实现该接口的张量类具有 name
、 shape
、 dtype
、 size
、 nbytes
和 metadata_props
等属性,用于描述张量的基本属性。此外,它还应实现两个方法 numpy
和 __array__
,这两个方法将从底层数据生成等效的 NumPy 数组。
备注
当与初始化器、常量值和张量属性交互时,最好假设使用 ir.TensorProtocol
,只有在需要检查具体类时才使用 isinstance()
。
张量类#
ir.TensorProtoTensor
#
使用 ir.TensorProtoTensor
作为对 proto 的包装以实现 ir.TensorProtocol
接口。您可以像往常一样访问 shape
、 dtype
等。只有在调用 numpy()
时才会产生副本。
直接初始化 ir.TensorProtoTensor
,如下所示,是可能的。然而,通常建议使用 ir.serde.deserialize_tensor
,因为它可以处理所有类型的 TensorProto
(例如,ir.TensorProtoTensor
不处理外部张量)。
import onnx
from onnxscript import ir
tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
tensor = ir.TensorProtoTensor(tensor_proto)
print("tensor: ", tensor) # TensorProtoTensor<INT16,[3]>(name='tensor')
print("shape: ", tensor.shape) # ir.Shape([3])
print("dtype: ", tensor.dtype) # ir.DataType.INT16
print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization
print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00'
print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16)
tensor: TensorProtoTensor<INT16,[3]>(name='tensor')
shape: [3]
dtype: INT16
True
tobytes: b'\x01\x00\x02\x00\x03\x00'
numpy: [1 2 3]
ir.ExternalTensor
#
存储在外部磁盘上的张量数据通常很大,加载时将占用内存。ir.ExternalTensor
类使用内存映射来避免将张量加载到内存中。您可以使用张量作为普通的 NumPy 数组,内存使用量最小。
请参阅 ir.serde.deserialize_tensor
以找到将 onnx.TensorProto
转换为 ir.ExternalTensor
的示例。
ir.Tensor
#
ir.Tensor
是围绕 NumPy 数组兼容的数组对象(如 np.ndarray
和 torch.Tensor
)的包装器。它最适合创建内存中的张量,而不将其转换为 TensorProto
,以减少转换开销。
备注
如果一个数组对象定义了 __array__
方法,则它是兼容的。
从数组创建张量,只需用 NumPy 数组初始化即可
import numpy as np
tensor = ir.Tensor(np.random.rand(1, 2))
tensor
Tensor<DOUBLE,[1,2]>(array([[0.71395225, 0.48701339]]), name=None)
初始化器将从数组中获取数据类型和形状信息。
如果要从 NumPy 数组以外的对象创建张量,您需要指定数据类型:
import torch
from onnxscript import ir
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
print(tensor.numpy()) # array([1., 2., 3.], dtype=float16)
[1. 2. 3.]
字符串张量#
使用 ir.StringTensor
创建字符串张量。
回退 TensorProto
#
在以下场景中,展示了如何从 TensorProto
转换到 ir.Tensor
,运行一些计算,然后将其转换回 ir.Tensor
,最后 TensorProto
:
Show code cell content
from onnxscript import ir
import onnx
import numpy as np
# 1. Create the TensorProto
proto = onnx.helper.make_tensor(
"tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]
)
# 2. Create an IR Tensor from the Protobuf message
tensor = ir.serde.deserialize_tensor(proto)
# Note that we get a TensorProtoTensor that implements the TensorProtocol
print("tensor:", tensor) # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.]
# [4. 5. 6.]]
print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F'
# 3. Do computation using numpy
mean = tensor.numpy().mean(axis=0)
print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16)
# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor
tensor_mean = ir.Tensor(mean)
print("tensor_mean:", tensor_mean) # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='')
# 5. Obtain the TensorProto from ir.Tensor
mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)
print("mean_tensor_proto:", mean_tensor_proto)
print(
"onnx.numpy_helper.to_array(mean_tensor_proto):",
onnx.numpy_helper.to_array(mean_tensor_proto)
# array([2.5, 3.5, 4.5], dtype=float16)
)
# You can obtain the bytes data as well
print("tensor_mean.tobytes():", tensor_mean.tobytes())
print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes())
# Explore other methods defined by TensorProtocol:
print("\n# Explore other methods defined by TensorProtocol:")
print("tensor_mean.shape:", tensor_mean.shape)
print("tensor_mean.dtype:", tensor_mean.dtype)
print("tensor_mean.name:", tensor_mean.name)
print("tensor_mean.doc_string:", tensor_mean.doc_string)
print("tensor_mean.raw:", tensor_mean.raw)
print("tensor_mean.metadata_props:", tensor_mean.metadata_props)
print("tensor_mean.size:", tensor_mean.size)
print("tensor_mean.nbytes:", tensor_mean.nbytes)
print("tensor_mean.raw:", tensor_mean.raw)
tensor: TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
tensor.numpy(): [[1. 2. 3.]
[4. 5. 6.]]
tensor.tobytes(): b'\x00<\x00@\x00B\x00D\x00E\x00F'
mean: [2.5 3.5 4.5]
tensor_mean: Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name=None)
mean_tensor_proto: dims: 3
data_type: 10
raw_data: "\000A\000C\200D"
onnx.numpy_helper.to_array(mean_tensor_proto): [2.5 3.5 4.5]
tensor_mean.tobytes(): b'\x00A\x00C\x80D'
Bytes same as proto: True
# Explore other methods defined by TensorProtocol:
tensor_mean.shape: [3]
tensor_mean.dtype: FLOAT16
tensor_mean.name: None
tensor_mean.doc_string: None
tensor_mean.raw: [2.5 3.5 4.5]
tensor_mean.metadata_props: {}
tensor_mean.size: 3
tensor_mean.nbytes: 6
tensor_mean.raw: [2.5 3.5 4.5]
处理非原生 NumPy 数据类型:bfloat16
、float8
、int4
#
ir.Tensor.numpy()
生成张量值的 NumPy 数组表示。当张量的数据类型为 BFLOAT16
、 FLOAT8[...]
或 [U]INT4
时,这些类型不支持 NumPy,将使用 ml_dtypes
包中的数据类型。
uint4
/ int4
总是解包;tobyte()
生成预期的打包表示。
ir.Tensor
的初始化需要 NumPy 数组遵循以下类型约束,或者具有 ml_dtypes
数据类型。
int8
用于(未打包的)int4
,符号位扩展到 8 位。uint8
用于(解包的)uint4
。uint8
用于 8 位数据类型,如float8
。uint16
用于bfloat16
。
以下示例展示了如何创建 FLOAT8E4M3FN
张量,变换其值,并创建新的张量来存储变换后的值。
from onnxscript import ir
import numpy as np
array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938]
# Compute
times_100 = tensor.numpy() * 100
print("times_100:", times_100)
# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor) # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True])
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1875 0.5625]
new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
new_tensor == times_100 [ True True]
高级用法#
子类化 ir.Tensor
以实现更高效的访问和更广泛的支持 dtype
#
ir.Tensor
内部将任何与数组兼容的对象转换为 NumPy 数组,以生成 tobytes()
中的字节表示。由于额外的转换,这可能会降低效率。它还限制了对于 NumPy 不支持的数据类型(如 bfloat16
)的支持,因为 __array__
方法将失败。
为了完全支持来自其他框架的数组,通常创建专门的类来处理它们是个好主意。下面的 TorchTensor
类演示了您如何通过子类化 ir.Tensor
来处理 PyTorch 张量:
Show code cell content
import ctypes
from typing import Any
import torch
from onnxscript import ir
# Define utilities to convert PyTorch data types so users do not need to specify manually
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
torch.bfloat16: ir.DataType.BFLOAT16,
torch.bool: ir.DataType.BOOL,
torch.complex128: ir.DataType.COMPLEX128,
torch.complex64: ir.DataType.COMPLEX64,
torch.float16: ir.DataType.FLOAT16,
torch.float32: ir.DataType.FLOAT,
torch.float64: ir.DataType.DOUBLE,
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
torch.int8: ir.DataType.INT8,
torch.uint8: ir.DataType.UINT8,
}
def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
return _TORCH_DTYPE_TO_ONNX[dtype]
class TorchTensor(ir.Tensor):
def __init__(self, tensor: torch.Tensor):
# Pass the tensor as the raw data to ir.Tensor's constructor
super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
def __array__(self, dtype: Any = None) -> "np.ndarray":
# numpy() calls __array__ in ir.Tensor
if self.dtype == ir.DataType.BFLOAT16:
return self.raw.view(torch.uint16).__array__(dtype)
if self.dtype in {
ir.DataType.FLOAT8E4M3FN,
ir.DataType.FLOAT8E4M3FNUZ,
ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ
}:
return self.raw.view(torch.uint8).__array__(dtype)
return self.raw.__array__(dtype)
def tobytes(self) -> bytes:
# Implement tobytes to support native PyTorch types so we can use types like bloat16
# Reading from memory directly is also more efficient because
# it avoids copying to a NumPy array
tensor = self.raw.detach().cpu().contiguous()
return bytes(
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
tensor.data_ptr()
)
)
# Test the implementation
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
tensor = TorchTensor(torch_tensor)
print("tensor: ", tensor)
print("numpy: ", tensor.numpy())
print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@'
print("nbytes: ", tensor.nbytes) # 6
tensor: TorchTensor<BFLOAT16,[3]>(tensor([1., 2., 3.], dtype=torch.bfloat16), name=None)
numpy: [16256 16384 16448]
tobytes: b'\x80?\x00@@@'
nbytes: 6
该类实现了 tobytes()
,以生成张量在序列化为 ONNX 文件/TensorProto 时的正确字节表示。该类还实现了 __array__()
方法,以返回 NumPy 不支持的数据类型的位表示。这样,分析阶段仍然可以对这些值进行计算。
使用不同框架进行计算#
由于 ir.Tensor
实现了 __array__
方法和 __dlpack__
方法,其内容可以在不复制的情况下与计算框架共享。例如:
Show code cell content
from onnxscript import ir
# We can call numpy methods directly on ir.Tensor
import numpy as np
print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.])
# We can transfer arrays to different frameworks
import jax.numpy as jnp
import jax
import torch
# Create ir.Tensor
jax_array = jnp.array([10., 20.])
ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)
torch_tensor = torch.tensor([30., 40.])
ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
# Use numpy for computation
print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32)
# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same
jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)
print(jax_array_from_ir + jax_array) # [40. 60.]
# Use PyTorch for computation
torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)
print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.])
# They can all be serialized into TensorProto
proto = ir.serde.serialize_tensor(ir_tensor_jax)
print(type(proto)) # <class 'onnx.onnx_ml_pb2.TensorProto'>
print(proto)
# The value is exactly the same as jax_array
print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.]
[42. 84.]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[300. 800.]
[40. 60.]
tensor([-20., -20.])
<class 'onnx.onnx_ml_pb2.TensorProto'>
dims: 2
data_type: 1
raw_data: "\000\000 A\000\000\240A"
[10. 20.]
这在您在图上创建需要计算具体值的传递时特别有用。您可以使用您喜欢的框架来创建传递。包含新创建的 ir.Tensor
的转换图将与下游传递兼容,即使它们利用其他计算框架也是如此。