torch_book.data.cv.grid#

Classes#

GridConfig

网格的配置

Grid

Base class for all TVTensors.

PairedGrid

将 LR 和 HR 图像裁剪成 Grid 数据对

FlattenPairedGrid

将 LR 和 HR 图像裁剪成 Grid 数据对展平

PairedRandomCrop

一种用于图像数据增强的技术,通常用于生成图像对(例如高分辨率图像和低分辨率图像)的训练数据。

Module Contents#

class torch_book.data.cv.grid.GridConfig[源代码]#

网格的配置

crop_size: int = 480#
step: int = 240#
thresh_size: int = 0#
make_space(h, w) list[int, int][源代码]#
__rshift__(scale: int)[源代码]#

将网格配置按 scale 因子缩小

class torch_book.data.cv.grid.Grid(*args: Any, device: torch._prims_common.DeviceLikeType | None = None)[源代码]#
class torch_book.data.cv.grid.Grid(storage: torch.types.Storage)
class torch_book.data.cv.grid.Grid(other: torch.Tensor)
class torch_book.data.cv.grid.Grid(size: torch.types._size, *, device: torch._prims_common.DeviceLikeType | None = None)

Bases: torchvision.tv_tensors.TVTensor

Base class for all TVTensors.

You probably don't want to use this class unless you're defining your own custom TVTensors. See sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py for details.

__len__() int[源代码]#

返回 Grid 中包含的图像块数量

randmeshgrid(indexing: str | None = 'ij') list[torch.Tensor, torch.Tensor][源代码]#

返回随机排列的索引,用于打乱 Grid 中图像块的顺序

shuffle(indexes: list[torch.Tensor, torch.Tensor] | None = None)[源代码]#

随机打乱 Grid 中图像块的顺序

flatten() Grid[源代码]#

将 Grid 数据展平为 (h*w, C, H, W) 形状的 Tensor

unflatten() Grid[源代码]#

将展平的 Grid 数据恢复为 (h, w, num_cols, C, H, W) 形状

class torch_book.data.cv.grid.PairedGrid(scale: int, config: GridConfig, *args, **kwargs)[源代码]#

Bases: torch.nn.Module

将 LR 和 HR 图像裁剪成 Grid 数据对

scale#
config#
forward(lr: torch.Tensor, hr: torch.Tensor)[源代码]#
class torch_book.data.cv.grid.FlattenPairedGrid(*args, **kwargs)[源代码]#

Bases: torch.nn.Module

将 LR 和 HR 图像裁剪成 Grid 数据对展平

forward(lr: Grid, hr: Grid) torch.Tensor[源代码]#
class torch_book.data.cv.grid.PairedRandomCrop(scale: int, gt_patch_size: int, *args, **kwargs)[源代码]#

Bases: torch.nn.Module

一种用于图像数据增强的技术,通常用于生成图像对(例如高分辨率图像和低分辨率图像)的训练数据。

主要目的是确保在数据增强过程中,高分辨率图像和低分辨率图像的裁剪区域保持一致,从而保证训练数据的配对关系。

gt_patch_size#
scale#
lq_patch_size#
forward(lr: torch.Tensor, hr: torch.Tensor)[源代码]#