在 ONNX IR 中的张量表示#

参考:ir.TensorProtocol

ONNX IR提供了 ir.TensorProtocol 接口,以便使用不同的数据结构作为张量的后备数据。除了传统的 onnx.TensorProto 之外,您还可以使用 np.ndarraytorch.Tensorjax.Array 以及几乎任何其他东西来表示计算图中的张量。这使得它们可以通过相同的 TensorProtocol 接口进行访问和序列化,而在初始化期间不会发生额外的复制。

ir.TensorProtocol#

ir.TensorProtocol 定义了一个只读接口,用于表示张量。实现该接口的张量类具有 nameshapedtypesizenbytesmetadata_props 等属性,用于描述张量的基本属性。此外,它还应实现两个方法 numpy__array__ ,这两个方法将从底层数据生成等效的 NumPy 数组。

备注

当与初始化器、常量值和张量属性交互时,最好假设使用 ir.TensorProtocol,只有在需要检查具体类时才使用 isinstance()

张量类#

ir.TensorProtoTensor#

使用 ir.TensorProtoTensor 作为对 proto 的包装以实现 ir.TensorProtocol 接口。您可以像往常一样访问 shapedtype 等。只有在调用 numpy() 时才会产生副本。

直接初始化 ir.TensorProtoTensor,如下所示,是可能的。然而,通常建议使用 ir.serde.deserialize_tensor,因为它可以处理所有类型的 TensorProto (例如,ir.TensorProtoTensor 不处理外部张量)。

import onnx
from onnxscript import ir

tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
tensor = ir.TensorProtoTensor(tensor_proto)
print("tensor: ", tensor)  # TensorProtoTensor<INT16,[3]>(name='tensor')
print("shape: ", tensor.shape)  # ir.Shape([3])
print("dtype: ", tensor.dtype)  # ir.DataType.INT16
print(tensor.raw == tensor_proto)  # The raw field is the exact tensor_proto provided at initialization
print("tobytes: ", tensor.tobytes())  # b'\x01\x00\x02\x00\x03\x00'
print("numpy: ", tensor.numpy())  # array([1, 2, 3], dtype=int16)
tensor:  TensorProtoTensor<INT16,[3]>(name='tensor')
shape:  [3]
dtype:  INT16
True
tobytes:  b'\x01\x00\x02\x00\x03\x00'
numpy:  [1 2 3]

ir.ExternalTensor#

存储在外部磁盘上的张量数据通常很大,加载时将占用内存。ir.ExternalTensor 类使用内存映射来避免将张量加载到内存中。您可以使用张量作为普通的 NumPy 数组,内存使用量最小。

请参阅 ir.serde.deserialize_tensor 以找到将 onnx.TensorProto 转换为 ir.ExternalTensor 的示例。

ir.Tensor#

ir.Tensor 是围绕 NumPy 数组兼容的数组对象(如 np.ndarraytorch.Tensor)的包装器。它最适合创建内存中的张量,而不将其转换为 TensorProto,以减少转换开销。

备注

如果一个数组对象定义了 __array__ 方法,则它是兼容的。

从数组创建张量,只需用 NumPy 数组初始化即可

import numpy as np
tensor = ir.Tensor(np.random.rand(1, 2))
tensor
Tensor<DOUBLE,[1,2]>(array([[0.71395225, 0.48701339]]), name=None)

初始化器将从数组中获取数据类型和形状信息。

如果要从 NumPy 数组以外的对象创建张量,您需要指定数据类型:

import torch
from onnxscript import ir

torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
print(tensor.numpy())  # array([1., 2., 3.], dtype=float16)
[1. 2. 3.]

字符串张量#

使用 ir.StringTensor 创建字符串张量。

回退 TensorProto#

在以下场景中,展示了如何从 TensorProto 转换到 ir.Tensor,运行一些计算,然后将其转换回 ir.Tensor,最后 TensorProto

Hide code cell content
from onnxscript import ir
import onnx
import numpy as np

# 1. Create the TensorProto
proto = onnx.helper.make_tensor(
    "tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]
)

# 2. Create an IR Tensor from the Protobuf message
tensor = ir.serde.deserialize_tensor(proto)
# Note that we get a TensorProtoTensor that implements the TensorProtocol
print("tensor:", tensor)  # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
print("tensor.numpy():", tensor.numpy())   # [[1. 2. 3.]
                                           #  [4. 5. 6.]]
print("tensor.tobytes():", tensor.tobytes())  # b'\x00<\x00@\x00B\x00D\x00E\x00F'

# 3. Do computation using numpy
mean = tensor.numpy().mean(axis=0)
print("mean:", mean)  # array([2.5, 3.5, 4.5], dtype=float16)

# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor
tensor_mean = ir.Tensor(mean)
print("tensor_mean:", tensor_mean)  # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='')

# 5. Obtain the TensorProto from ir.Tensor
mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)
print("mean_tensor_proto:", mean_tensor_proto)
print(
    "onnx.numpy_helper.to_array(mean_tensor_proto):",
    onnx.numpy_helper.to_array(mean_tensor_proto)
    # array([2.5, 3.5, 4.5], dtype=float16)
)

# You can obtain the bytes data as well
print("tensor_mean.tobytes():", tensor_mean.tobytes())
print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes())

# Explore other methods defined by TensorProtocol:
print("\n# Explore other methods defined by TensorProtocol:")
print("tensor_mean.shape:", tensor_mean.shape)
print("tensor_mean.dtype:", tensor_mean.dtype)
print("tensor_mean.name:", tensor_mean.name)
print("tensor_mean.doc_string:", tensor_mean.doc_string)
print("tensor_mean.raw:", tensor_mean.raw)
print("tensor_mean.metadata_props:", tensor_mean.metadata_props)
print("tensor_mean.size:", tensor_mean.size)
print("tensor_mean.nbytes:", tensor_mean.nbytes)
print("tensor_mean.raw:", tensor_mean.raw)
tensor: TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')
tensor.numpy(): [[1. 2. 3.]
 [4. 5. 6.]]
tensor.tobytes(): b'\x00<\x00@\x00B\x00D\x00E\x00F'
mean: [2.5 3.5 4.5]
tensor_mean: Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name=None)
mean_tensor_proto: dims: 3
data_type: 10
raw_data: "\000A\000C\200D"

onnx.numpy_helper.to_array(mean_tensor_proto): [2.5 3.5 4.5]
tensor_mean.tobytes(): b'\x00A\x00C\x80D'
Bytes same as proto: True

# Explore other methods defined by TensorProtocol:
tensor_mean.shape: [3]
tensor_mean.dtype: FLOAT16
tensor_mean.name: None
tensor_mean.doc_string: None
tensor_mean.raw: [2.5 3.5 4.5]
tensor_mean.metadata_props: {}
tensor_mean.size: 3
tensor_mean.nbytes: 6
tensor_mean.raw: [2.5 3.5 4.5]

处理非原生 NumPy 数据类型:bfloat16float8int4#

ir.Tensor.numpy() 生成张量值的 NumPy 数组表示。当张量的数据类型为 BFLOAT16FLOAT8[...][U]INT4 时,这些类型不支持 NumPy,将使用 ml_dtypes 包中的数据类型。

uint4 / int4 总是解包;tobyte() 生成预期的打包表示。

ir.Tensor 的初始化需要 NumPy 数组遵循以下类型约束,或者具有 ml_dtypes 数据类型。

  • int8 用于(未打包的) int4,符号位扩展到 8 位。

  • uint8 用于(解包的)uint4

  • uint8 用于 8 位数据类型,如 float8

  • uint16 用于 bfloat16

以下示例展示了如何创建 FLOAT8E4M3FN 张量,变换其值,并创建新的张量来存储变换后的值。

from onnxscript import ir
import numpy as np

array = np.array([0b1, 0b11], dtype=np.uint8)
# The array is reinterpreted using the ml_dtypes package
tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
print(tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
print("tensor.numpy():", tensor.numpy())  # [0.00195312 0.00585938]

# Compute
times_100 = tensor.numpy() * 100
print("times_100:", times_100)

# Create a new tensor out of the new value; dtype must be specified
new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
# You can also directly create the tensor from the float8 array without specifying dtype
# new_tensor = ir.Tensor(times_100)
print("new_tensor:", new_tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
print("new_tensor == times_100", new_tensor.numpy() == times_100)  # array([ True,  True])
Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)
tensor.numpy(): [0.00195312 0.00585938]
times_100: [0.1875 0.5625]
new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)
new_tensor == times_100 [ True  True]

高级用法#

子类化 ir.Tensor以实现更高效的访问和更广泛的支持 dtype#

ir.Tensor 内部将任何与数组兼容的对象转换为 NumPy 数组,以生成 tobytes() 中的字节表示。由于额外的转换,这可能会降低效率。它还限制了对于 NumPy 不支持的数据类型(如 bfloat16)的支持,因为 __array__ 方法将失败。

为了完全支持来自其他框架的数组,通常创建专门的类来处理它们是个好主意。下面的 TorchTensor 类演示了您如何通过子类化 ir.Tensor 来处理 PyTorch 张量:

Hide code cell content
import ctypes
from typing import Any

import torch
from onnxscript import ir

# Define utilities to convert PyTorch data types so users do not need to specify manually
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
    torch.bfloat16: ir.DataType.BFLOAT16,
    torch.bool: ir.DataType.BOOL,
    torch.complex128: ir.DataType.COMPLEX128,
    torch.complex64: ir.DataType.COMPLEX64,
    torch.float16: ir.DataType.FLOAT16,
    torch.float32: ir.DataType.FLOAT,
    torch.float64: ir.DataType.DOUBLE,
    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
    torch.int16: ir.DataType.INT16,
    torch.int32: ir.DataType.INT32,
    torch.int64: ir.DataType.INT64,
    torch.int8: ir.DataType.INT8,
    torch.uint8: ir.DataType.UINT8,
}


def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
    return _TORCH_DTYPE_TO_ONNX[dtype]

class TorchTensor(ir.Tensor):
    def __init__(self, tensor: torch.Tensor):
        # Pass the tensor as the raw data to ir.Tensor's constructor
        super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))

    def __array__(self, dtype: Any = None) -> "np.ndarray":
        # numpy() calls __array__ in ir.Tensor
        if self.dtype == ir.DataType.BFLOAT16:
            return self.raw.view(torch.uint16).__array__(dtype)
        if self.dtype in {
            ir.DataType.FLOAT8E4M3FN,
            ir.DataType.FLOAT8E4M3FNUZ,
            ir.DataType.FLOAT8E5M2,
            ir.DataType.FLOAT8E5M2FNUZ
        }:
            return self.raw.view(torch.uint8).__array__(dtype)
        return self.raw.__array__(dtype)

    def tobytes(self) -> bytes:
        # Implement tobytes to support native PyTorch types so we can use types like bloat16
        # Reading from memory directly is also more efficient because
        # it avoids copying to a NumPy array
        tensor = self.raw.detach().cpu().contiguous()
        return bytes(
            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                tensor.data_ptr()
            )
        )

# Test the implementation
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
tensor = TorchTensor(torch_tensor)
print("tensor: ", tensor)
print("numpy: ", tensor.numpy())
print("tobytes: ", tensor.tobytes())  # b'\x80?\x00@@@'
print("nbytes: ", tensor.nbytes)  # 6
tensor:  TorchTensor<BFLOAT16,[3]>(tensor([1., 2., 3.], dtype=torch.bfloat16), name=None)
numpy:  [16256 16384 16448]
tobytes:  b'\x80?\x00@@@'
nbytes:  6

该类实现了 tobytes() ,以生成张量在序列化为 ONNX 文件/TensorProto 时的正确字节表示。该类还实现了 __array__() 方法,以返回 NumPy 不支持的数据类型的位表示。这样,分析阶段仍然可以对这些值进行计算。

使用不同框架进行计算#

由于 ir.Tensor 实现了 __array__ 方法和 __dlpack__ 方法,其内容可以在不复制的情况下与计算框架共享。例如:

Hide code cell content
from onnxscript import ir

# We can call numpy methods directly on ir.Tensor
import numpy as np
print(np.multiply(ir.Tensor(np.array([1, 2])), 42))  # array([42., 84.])

# We can transfer arrays to different frameworks
import jax.numpy as jnp
import jax
import torch

# Create ir.Tensor
jax_array = jnp.array([10., 20.])
ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)
torch_tensor = torch.tensor([30., 40.])
ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)

# Use numpy for computation
print(np.multiply(ir_tensor_jax, ir_tensor_torch))  # array([300., 800.], dtype=float32)

# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same
jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)
print(jax_array_from_ir + jax_array)  # [40. 60.]

# Use PyTorch for computation
torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)
print(torch_tensor_from_ir - torch_tensor)  # tensor([-20., -20.])

# They can all be serialized into TensorProto
proto = ir.serde.serialize_tensor(ir_tensor_jax)
print(type(proto))  # <class 'onnx.onnx_ml_pb2.TensorProto'>
print(proto)

# The value is exactly the same as jax_array
print(ir.serde.deserialize_tensor(proto).numpy())  # [10. 20.]
[42. 84.]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[300. 800.]
[40. 60.]
tensor([-20., -20.])
<class 'onnx.onnx_ml_pb2.TensorProto'>
dims: 2
data_type: 1
raw_data: "\000\000 A\000\000\240A"

[10. 20.]

这在您在图上创建需要计算具体值的传递时特别有用。您可以使用您喜欢的框架来创建传递。包含新创建的 ir.Tensor 的转换图将与下游传递兼容,即使它们利用其他计算框架也是如此。