# Set5

Set5 数据来源:mmagic/Set5

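Set5 数据集由 Marco Bevilacqua 等人在论文《基于非负邻域嵌入的低复杂度单图像超分辨率》(*Low-Complexity Single-Image Super-Resolution based on Nonnegative Neighbor Embedding*, BMVC 2012)中引入。它包含 5 张图像(“婴儿”、“鸟”、“蝴蝶”、“头部”、“女人”),通常用于测试图像超分辨率模型的性能。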

为了方便使用,本项目已将其放置在 `tests/data/Set5` 目录下。

%matplotlib inline
import matplotlib.pyplot as plt
# 关闭交互模式
plt.ioff()
from set_env import root_dir

data_dir = root_dir/"tests/data/Set5"
项目根目录:/media/pc/data/lxw/ai/torch-book
# List the ground-truth file names in sorted order so the output is deterministic
# (Path.iterdir yields entries in arbitrary, filesystem-dependent order).
for p in sorted((data_dir/"GTmod12").iterdir()):
    print(p.name)
baby.png
head.png
bird.png
butterfly.png
woman.png

可视化:

from PIL import Image
from torch_book.data.cv.plot import CompareGridFrame

# Sort both listings so the i-th LR image lines up with the i-th GT image.
imgs_names, gt_names = (
    sorted((data_dir / subdir).iterdir()) for subdir in ("LRbicx4", "GTmod12")
)
grid = CompareGridFrame(1, 5, layout="row")  # presumably 1 row x 5 columns — one per image
imgs = list(map(Image.open, imgs_names))
gts = list(map(Image.open, gt_names))
grid(imgs, gts)
grid.figure
../../_images/9243fe5c388bc23688c32d61c4e0575c02aaf10d9db4cfabadd5d2c7a7d5e61b.png

## 构建数据集

%%file utils/Set5.py
from dataclasses import dataclass
from typing import Callable
from pathlib import Path
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import tv_tensors

@dataclass
class PairedDataset(Dataset):
    """Paired (LR, HR) image dataset for super-resolution.

    Images are matched by sorted file name: the i-th file in ``LR_path`` is
    paired with the i-th file in ``HR_path``, and the two names must be equal.

    Attributes:
        scale: Upscaling factor (e.g. 2, 3, 4). Stored for callers; the
            loading logic below does not use it.
        HR_path: Directory containing the high-resolution (ground-truth) images.
        LR_path: Directory containing the low-resolution images.
        transform: Optional callable applied jointly as ``transform(lr, hr)``;
            it must return the transformed ``(lr, hr)`` pair.
    """
    scale: int  # upscaling factor, e.g. 2, 3, 4
    HR_path: str | Path  # directory of HR images
    LR_path: str | Path  # directory of LR images
    transform: Callable | None = None

    def __post_init__(self):
        self.HR_path = Path(self.HR_path)
        self.LR_path = Path(self.LR_path)
        # Sort so pairing is deterministic: iterdir() order is arbitrary.
        self.lr_names = sorted(self.LR_path.iterdir())
        self.hr_names = sorted(self.HR_path.iterdir())
        self._check()

    def _check(self):
        """Verify that the LR and HR directories hold matching file names.

        Raises:
            ValueError: If the file counts differ or any LR/HR name pair differs.
        """
        # Explicit exceptions instead of `assert`, so validation survives `python -O`.
        if len(self.lr_names) != len(self.hr_names):
            raise ValueError(
                f"LR/HR 图片数量不一致: {len(self.lr_names)} != {len(self.hr_names)}"
            )
        for a, b in zip(self.lr_names, self.hr_names):
            if a.name != b.name:
                raise ValueError(f"文件名 {a} 与 {b} 不匹配")

    def __len__(self) -> int:
        """Return the number of (LR, HR) pairs."""
        return len(self.lr_names)

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Image]:
        """Load the ``index``-th (LR, HR) image pair.

        Args:
            index: Position of the pair in the sorted file lists.

        Returns:
            ``(lr, hr)`` as ``tv_tensors.Image`` tensors, or whatever
            ``transform(lr, hr)`` returns when a transform is set.
        """
        # Build each tensor while the file is open; `with` then closes it.
        with Image.open(self.lr_names[index]) as im:
            lr = tv_tensors.Image(im)

        with Image.open(self.hr_names[index]) as im:
            hr = tv_tensors.Image(im)
        if self.transform is not None:
            lr, hr = self.transform(lr, hr)
        return lr, hr
Overwriting utils/Set5.py
import torch
from torch.utils.data.dataloader import DataLoader
from utils.Set5 import PairedDataset

scale = 4  # upscaling factor: 2, 3 or 4
dataset = PairedDataset(
    scale=scale,
    HR_path=data_dir / "GTmod12",
    LR_path=data_dir / f"LRbicx{scale}",
    transform=None,
)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

with torch.no_grad():
    # Grab only the first (LR, HR) batch.
    for lr, hr in data_loader:
        break
    # Reorder dims (N, C, H, W) -> (N, H, W, C) for plotting.
    lr_hwc = lr.permute(0, 2, 3, 1)
    hr_hwc = hr.permute(0, 2, 3, 1)
    canvas = CompareGridFrame(1, 1, 2.5)
    left_axes, right_axes = canvas(lr_hwc, hr_hwc)
    canvas.figure.suptitle("LQ")
    plt.close()
canvas.figure
../../_images/53ca18f49dc3ae3aee05f810ffdd965f5946a6ce6536c88145707af80f05150f.png

## 裁剪图像的边缘像素

CropBorder 从图像的边缘裁剪掉指定数量的像素(下例中取放大倍数 `scale`),这些被裁剪掉的像素不参与 PSNR 的计算。

from torchvision.transforms import v2
from torch_book.data.cv.plot import GridFrame
from torch_book.metrics.utils import CropBorder

data_dir = root_dir/"tests/data/Set5"  # Set5 root directory
scale = 4  # upscaling factor: 2, 3 or 4
transform = v2.Compose([
    # Crop border pixels; the crop width here is `scale`.
    # NOTE(review): assumes CropBorder accepts the (lr, hr) pair that
    # PairedDataset passes as transform(lr, hr) — confirm.
    CropBorder(scale),
    v2.ToDtype(torch.float32, scale=True),  # convert to float32 and rescale to [0, 1]
])
dataset = PairedDataset(
    scale,
    data_dir/"GTmod12",
    data_dir/f"LRbicx{scale}",
    transform=transform
)
# Rebuild the loader: the previous `data_loader` wraps the untransformed
# dataset, so iterating it would never apply `transform`.
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
with torch.no_grad():
    # Stop at batch index 4 (the 5th pair), or the last one if there are fewer.
    for i, (lr, gt) in enumerate(data_loader):
        if i == 4:
            break
    print(lr.shape, gt.shape)
    results = [v2.ToPILImage()(x[0]) for x in (lr, gt)]

    canvas = GridFrame(1, 2, 2.5)
    gs, axes = canvas(results)
    for ax, title in zip(axes, ["LR", "GT"]):
        ax.set_title(title)
    plt.close()
canvas.figure
torch.Size([1, 3, 84, 57]) torch.Size([1, 3, 336, 228])
../../_images/ed1772c4f32075b4b8dd2d8b1d85b575eb5bb938d20dbed17143eba2c492f46b.png