# Source listing of torch_book.data.cv.grid (title line recovered from a Sphinx-rendered docs page).

from typing import Any
from dataclasses import dataclass
import numpy as np
import torch
from torchvision import tv_tensors

@dataclass
class GridConfig:
    """Configuration for tiling an image into a grid of square crops.

    Attributes:
        crop_size: Side length (pixels) of each square HR crop; usually preset.
        step: Stride (pixels) between the origins of successive crops on the
            HR image.
        thresh_size: If the border strip left over after the last regular crop
            is wider than this threshold, one extra crop flushed to the image
            edge is added so border pixels are not dropped.
    """
    crop_size: int = 480
    step: int = 240
    thresh_size: int = 0

    def make_space(self, h: int, w: int) -> tuple[np.ndarray, np.ndarray]:
        """Return the top-left (y, x) coordinates of every crop of an (h, w) image.

        Crop origins are spaced ``step`` apart.  If the remaining border after
        the last regular crop exceeds ``thresh_size``, an extra origin flushed
        to the image edge is appended.

        Note: the original annotation ``list[int, int]`` was invalid; the
        method actually returns a 2-tuple of 1-D integer arrays.

        Raises:
            IndexError: if ``h`` or ``w`` is smaller than ``crop_size`` (the
                coordinate array is then empty) — original behavior kept.
        """
        h_space = np.arange(0, h - self.crop_size + 1, self.step)
        if h - (h_space[-1] + self.crop_size) > self.thresh_size:
            h_space = np.append(h_space, h - self.crop_size)
        w_space = np.arange(0, w - self.crop_size + 1, self.step)
        if w - (w_space[-1] + self.crop_size) > self.thresh_size:
            w_space = np.append(w_space, w - self.crop_size)
        return h_space, w_space

    def __rshift__(self, scale: int) -> "GridConfig":
        """Scale the configuration down by ``scale`` (HR config -> LR config).

        Every size is divided with floor division, matching how an LR image
        relates to its HR counterpart.
        """
        return GridConfig(
            crop_size=self.crop_size // scale,
            step=self.step // scale,
            thresh_size=self.thresh_size // scale,
        )
class Grid(tv_tensors.TVTensor):
    """A (rows, cols, C, crop_size, crop_size) tensor of square crops tiled
    from a single (C, H, W) image, with crop origins from ``GridConfig``.
    """

    def __new__(
        cls,
        data: Any,
        *,
        config: GridConfig = GridConfig(),
    ) -> "Grid":
        C, H, W = data.shape
        h_space, w_space = config.make_space(H, W)
        height = len(h_space)
        width = len(w_space)
        shape = (height, width, C, config.crop_size, config.crop_size)
        grid = torch.full(
            shape,
            fill_value=0,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )
        for row, y in enumerate(h_space):
            for col, x in enumerate(w_space):
                grid[row, col] = data[:, y:y + config.crop_size, x:x + config.crop_size]
        grid = grid.as_subclass(cls)
        # Bug fix: the dimensions were stored as *class* attributes
        # (cls.height / cls.width), so constructing a second Grid silently
        # clobbered the dimensions of every existing instance. Store them on
        # the instance instead.
        grid.height = height
        grid.width = width
        return grid

    def _like(self, data: torch.Tensor) -> "Grid":
        """Wrap ``data`` as a Grid like ``self``, carrying over the grid dims.

        Needed because ``tv_tensors.wrap`` does not copy custom instance
        attributes (the old class-attribute bug masked this).
        """
        out = tv_tensors.wrap(data, like=self)
        out.height = self.height
        out.width = self.width
        return out

    def __len__(self) -> int:
        """Number of image crops in the grid (rows * cols)."""
        return self.height * self.width

    def randmeshgrid(self, indexing: str | None = "ij") -> tuple[torch.Tensor, torch.Tensor]:
        """Return a meshgrid of independently permuted row/col indices,
        used to shuffle the order of the crops.

        Note: the original annotation ``list[Tensor, Tensor]`` was invalid;
        ``torch.meshgrid`` returns a tuple of tensors.
        """
        indexes = torch.meshgrid(
            torch.randperm(self.height),
            torch.randperm(self.width),
            indexing=indexing,
        )
        return indexes

    def shuffle(self, indexes: tuple[torch.Tensor, torch.Tensor] | None = None) -> "Grid":
        """Return a Grid with the crops randomly reordered (rows and columns
        permuted independently); ``indexes`` overrides the random permutation."""
        if indexes is None:
            indexes = self.randmeshgrid()
        data = self[indexes]
        return self._like(data)

    def flatten(self) -> "Grid":
        """Reshape the grid to (rows * cols, C, crop_size, crop_size)."""
        data = self.reshape(self.height * self.width, *self.shape[2:])
        return self._like(data)

    def unflatten(self) -> "Grid":
        """Inverse of :meth:`flatten`: reshape back to
        (rows, cols, C, crop_size, crop_size)."""
        data = self.reshape(self.height, self.width, *self.shape[1:])
        return self._like(data)
class PairedGrid(torch.nn.Module):
    """Crop an (LR, HR) image pair into a matching pair of Grid tiles."""

    def __init__(self, scale: int, config: GridConfig, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scale = scale
        self.config = config

    def forward(self, lr: torch.Tensor, hr: torch.Tensor):
        """Return ``(lr_grid, hr_grid)``.

        The HR image is tiled with ``self.config``; the LR image with the
        same config scaled down by ``self.scale`` (via ``GridConfig.__rshift__``),
        so the two grids cover matching regions.
        """
        lr_grid = Grid(lr, config=self.config >> self.scale)
        hr_grid = Grid(hr, config=self.config)
        return lr_grid, hr_grid
class FlattenPairedGrid(torch.nn.Module):
    """Flatten both members of an (LR, HR) Grid pair into crop batches."""

    def forward(self, lr: Grid, hr: Grid) -> "tuple[Grid, Grid]":
        """Return ``(lr.flatten(), hr.flatten())``.

        Bug fix: the return annotation said ``torch.Tensor`` but the method
        actually returns a 2-tuple of flattened Grids.
        """
        return lr.flatten(), hr.flatten()
class PairedRandomCrop(torch.nn.Module):
    """Data augmentation producing aligned random crops from an image pair.

    Typically used to build (LR, HR) training pairs: the ground-truth (GT)
    crop is taken at the LQ crop's position scaled by ``scale``, so the two
    patches always cover the same image region and the pairing is preserved.
    """

    def __init__(self, scale: int, gt_patch_size: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.gt_patch_size = gt_patch_size
        self.scale = scale
        self.lq_patch_size = self.gt_patch_size // self.scale

    def _validate(self, h_lq: int, w_lq: int, h_gt: int, w_gt: int) -> None:
        """Reject pairs whose sizes are inconsistent or too small to crop."""
        if h_gt != h_lq * self.scale or w_gt != w_lq * self.scale:
            raise ValueError(
                f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {self.scale}x '
                f'multiplication of LQ ({h_lq}, {w_lq}).')
        if h_lq < self.lq_patch_size or w_lq < self.lq_patch_size:
            raise ValueError(
                f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                f'({self.lq_patch_size}, {self.lq_patch_size}). Please check it.')

    def forward(self, lr: torch.Tensor, hr: torch.Tensor):
        """Crop a random LQ patch and its aligned GT patch.

        Returns the two patches wrapped like their respective inputs.
        """
        h_lq, w_lq = lr.shape[-2:]
        h_gt, w_gt = hr.shape[-2:]
        self._validate(h_lq, w_lq, h_gt, w_gt)
        # Pick a random top-left corner in LQ coordinates.
        top = np.random.randint(h_lq - self.lq_patch_size + 1)
        left = np.random.randint(w_lq - self.lq_patch_size + 1)
        lq_patch = lr[..., top:top + self.lq_patch_size,
                      left:left + self.lq_patch_size]
        # The matching ground-truth patch lives at the scaled-up position.
        top_gt = int(top * self.scale)
        left_gt = int(left * self.scale)
        gt_patch = hr[..., top_gt:top_gt + self.gt_patch_size,
                      left_gt:left_gt + self.gt_patch_size]
        return tv_tensors.wrap(lq_patch, like=lr), tv_tensors.wrap(gt_patch, like=hr)