DLPack 协议#

DLPack 协议是一种稳定的内存数据结构,允许在处理多维数组或张量的主要框架之间进行数据交换。它旨在支持跨硬件操作,意味着它允许在 CPU 之外的设备(例如 GPU)上交换数据。

DLPack 协议已被 Python 数据 API 标准联盟选为 Python 数组 API 标准,以便在 Python 生态系统中的数组/张量库之间启用设备感知的数据交换。有关该标准的更多信息,请参阅协议文档;有关 DLPack 的更多信息,请参阅 Python DLPack 规范

DLPack 在 PyArrow 中的实现#

DLPack 协议的生产端为 pa.Array 实现了,并且可以用于在 PyArrow 和其他张量库之间交换数据。支持的数据类型包括整数、无符号整数和浮点数。该协议不支持缺失数据,这意味着具有缺失值的 PyArrow 数组不能通过 DLPack 协议传输。目前,Arrow 对该协议的实现仅支持 CPU 设备上的数据。

协议的数据交换语法包括

  • from_dlpack(x):消费实现了 __dlpack__ 方法的数组对象,并在共享内存的同时创建新数组。

  • __dlpack__(self, stream=None)__dlpack_device__:生成一个包含 DLPack 结构的 PyCapsule,该结构由 from_dlpack(x) 内部调用。

PyArrow 实现了协议的第二部分(__dlpack__(self, stream=None)__dlpack_device__),因此可以被实现了 from_dlpack 的库所消费。

示例#

将 PyArrow CPU 数组转换为 NumPy 数组:

import pyarrow as pa
array = pa.array([2, 0, 2, 4])

import numpy as np
np.from_dlpack(array)
array([2, 0, 2, 4])

将 PyArrow CPU 数组转换为 PyTorch 张量:

import torch
torch.from_dlpack(array)
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[2], line 1
----> 1 import torch
      2 torch.from_dlpack(array)

ModuleNotFoundError: No module named 'torch'