转换和增强图像#

Torchvision 在 torchvision.transformstorchvision.transforms.v2 模块中支持常见的计算机视觉变换。这些变换可以用于不同任务(图像分类、检测、分割、视频分类)的训练或推理过程中对数据进行变换或增强。

# Image Classification
import torch
from torchvision.transforms import v2

H, W = 32, 32
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transforms(img)
# Detection (re-using imports and transforms from above)
from torchvision import tv_tensors

img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))

# The same transforms can be used!
img, boxes = transforms(img, boxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "boxes": boxes})

支持的输入类型和约定#

大多数转换都接受 PIL 图像和张量输入。CPU 和 CUDA 张量都受支持。两种后端(PIL 或张量)的结果应该非常接近。一般来说,建议依赖张量后端以提高性能。

张量图像应具有形状 (C, H, W),其中 C 是通道数,HW 分别表示高度和宽度。大多数转换都支持批量张量输入。一批张量图像是形状为 (N, C, H, W) 的张量,其中 N 是批次中的图像数量。v2 转换通常接受任意数量的前导维度 (..., C, H, W),并可以处理批量图像或批量视频。

数据类型和预期值范围#

张量图像的值的预期范围由张量数据类型隐式定义。具有浮点数据类型的张量图像的值应为 [0, 1]。具有整数数据类型的张量图像的值应在 [0, MAX_DTYPE] 范围内,其中 MAX_DTYPE 是该数据类型可以表示的最大值。通常,dtypetorch.uint8 的图像的值应在 [0, 255] 范围内。

使用 ToDtype 将输入的数据类型和范围进行变换。

性能考虑#

为了获得最佳的变换性能,我们建议遵循以下准则:

  • 依赖于 torchvision.transforms.v2 中的 v2 转换。

  • 使用张量而不是 PIL 图像。

  • 特别是对于调整大小操作,请使用 torch.uint8 数据类型。

  • 使用双线性或双三次插值模式进行大小调整。

典型的转换管道可能的样子:

from torchvision.transforms import v2
transforms = v2.Compose([
    v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
    v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
    # ...
    v2.RandomResizedCrop(size=(224, 224), antialias=True),  # Or Resize(antialias=True)
    # ...
    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Torchscript 支持#

大多数转换类和功能都支持 torchscript。对于组合转换,请使用 torch.nn.Sequential 而不是 Compose

transforms = torch.nn.Sequential(
    v2.CenterCrop(10),
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)