转换和增强图像#
Torchvision 在 torchvision.transforms
和 torchvision.transforms.v2
模块中支持常见的计算机视觉变换。这些变换可以用于不同任务(图像分类、检测、分割、视频分类)的训练或推理过程中对数据进行变换或增强。
# Image Classification
import torch
from torchvision.transforms import v2
H, W = 32, 32
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transforms(img)
# Detection (re-using imports and transforms from above)
from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, boxes = transforms(img, boxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "boxes": boxes})
支持的输入类型和约定#
大多数转换都接受 PIL 图像和张量输入。CPU 和 CUDA 张量都受支持。两种后端(PIL 或张量)的结果应该非常接近。一般来说,建议依赖张量后端以提高性能。
张量图像应具有形状 (C, H, W)
,其中 C
是通道数,H
和 W
分别表示高度和宽度。大多数转换都支持批量张量输入。一批张量图像是形状为 (N, C, H, W)
的张量,其中 N
是批次中的图像数量。v2
转换通常接受任意数量的前导维度 (..., C, H, W)
,并可以处理批量图像或批量视频。
数据类型和预期值范围#
张量图像的值的预期范围由张量数据类型隐式定义。具有浮点数据类型的张量图像的值应为 [0, 1]
。具有整数数据类型的张量图像的值应在 [0, MAX_DTYPE]
范围内,其中 MAX_DTYPE
是该数据类型可以表示的最大值。通常,dtype
为 torch.uint8
的图像的值应在 [0, 255]
范围内。
使用 ToDtype
将输入的数据类型和范围进行变换。
性能考虑#
为了获得最佳的变换性能,我们建议遵循以下准则:
依赖于
torchvision.transforms.v2
中的v2
转换。使用张量而不是 PIL 图像。
特别是对于调整大小操作,请使用
torch.uint8
数据类型。使用双线性或双三次插值模式进行大小调整。
典型的转换管道可能的样子:
from torchvision.transforms import v2
transforms = v2.Compose([
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
# ...
v2.RandomResizedCrop(size=(224, 224), antialias=True), # Or Resize(antialias=True)
# ...
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Torchscript 支持#
大多数转换类和功能都支持 torchscript。对于组合转换,请使用 torch.nn.Sequential
而不是 Compose
:
transforms = torch.nn.Sequential(
v2.CenterCrop(10),
v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)