TVTensors 常见问题解答#

TVTensors 是与 torchvision.transforms.v2 同时引入的张量子类。这个例子展示了这些 TVTensors 是什么以及它们的行为。

import PIL.Image

import torch
from torchvision import tv_tensors

TVTensors 是什么?#

TVTensors 是零拷贝张量子类:

tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

在幕后,它们被用于 torchvision.transforms.v2 中,以便正确地将输入数据分配给适当的函数。

torchvision.tv_tensors 支持四种类型的 TVTensors:

  • Image

  • Video

  • BoundingBoxes

  • Mask

使用 TVTensor 做什么?#

TVTensors 的外观和感觉就像普通的张量——它们就是张量。所有在普通 torch.Tensor 上支持的操作,比如 .sum() 或任何 torch.* 算子,也同样适用于 TVTensors。

如何构建 TVTensor?#

使用构造函数#

每个 TVTensor 类都可以接受任何类似张量的数据,这些数据可以被转换成 Tensor

image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
Image([[[[0, 1],
         [1, 0]]]], )

与PyTorch中的其他创建操作类似,该构造函数也接受dtypedevicerequires_grad参数。

float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
Image([[[0., 1.],
        [1., 0.]]], grad_fn=<AliasBackward0>, )

此外,ImageMask 类也可以直接接受 PIL.Image.Image 类型的对象:

image = tv_tensors.Image(PIL.Image.open("../images/astronaut.jpg"))
print(image.shape, image.dtype)
torch.Size([3, 512, 512]) torch.uint8

某些TVTensors在构建时需要传递额外的元数据。例如,BoundingBoxes 类不仅需要实际的数值,还需要坐标格式以及对应图像的大小(canvas_size)。这些元数据对于正确转换边界框是必不可少的。

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)
BoundingBoxes([[ 17,  16, 344, 495],
               [  0,  10,   0,  10]], format=BoundingBoxFormat.XYXY, canvas_size=torch.Size([512, 512]))

您还可以使用 torchvision.tv_tensors.wrap() 函数将张量对象封装为 TVTensor。当您已经拥有所需类型的对象时,这非常有用,这种情况通常发生在编写转换时:您只需像处理输入一样处理输出即可。

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

new_bboxes 的元数据与 bboxes 相同,但您可以将其作为参数传递以覆盖它。

有一个 TVTensor,但现在得到了 Tensor。怎么办!#

默认情况下,对 TVTensor 对象进行的操作会返回纯 Tensor:

assert isinstance(bboxes, tv_tensors.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)

请注意,此行为仅影响原生的 torch 运算。如果您使用的是内置的torchvision转换或函数,您将始终得到与输入相同类型的输出(纯TensorTVTensor)。

但我想要 TVTensor!#

您可以简单地调用TVTensor构造函数,或将纯张量重新包装为TVTensor,或者使用:wrap()函数:

new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

或者,你可以使用 set_return_type() 作为整个程序的全局配置设置,或者作为上下文管理器:

with tv_tensors.set_return_type("TVTensor"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

为什么会发生这种情况?#

出于性能的考虑TVTensor类是Tensor的子类,因此任何涉及 TVTensor对象的操作都将经过__torch_function__协议。这会导致一些额外的开销,我们希望尽可能地避免这种开销。 对于内置的“torchvision”转换来说,这不是问题,因为我们可以避免那里的开销,但在你的模型的“forward”过程中,这可能是个问题。

**另一种选择也不见得更好。**对于每个保留 TVTensor类型有意义的操作,都有同样多的操作更适合返回纯Tensor:例如,img.sum()仍然是 Image吗?如果我们一直保留 TVTensor类型,那么即使是模型的逻辑值或损失函数的输出也会变成 Image类型,而这显然是不可取的。

例外情况#

有几个例外情况适用于这个“解包”规则: clone(), to(), torch.Tensor.detach(), 和 requires_grad_() 保持TVTensor类型。

在TVTensor上的原地操作(如obj.add_())将保留“obj”的类型。然而,原地操作的返回值将是纯张量:

image = tv_tensors.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
Image([[[2, 4],
        [4, 2]]], )