Set5#
Set5 数据集由 Marco Bevilacqua 等人在论文《基于非负邻域嵌入的低复杂度单图像超分辨率》中引入,包含 5 张图像(“婴儿”、“鸟”、“蝴蝶”、“头部”、“女人”),通常用于测试图像超分辨率模型的性能。
为了更方便地使用,本项目将其放置到 `tests/data/Set5`:
%matplotlib inline
import matplotlib.pyplot as plt
# Turn off matplotlib's interactive mode.
plt.ioff()
from set_env import root_dir
# Directory holding the Set5 images, located under the project root.
data_dir = root_dir/"tests/data/Set5"
项目根目录:/media/pc/data/lxw/ai/torch-book
# Print the file name of every entry in the ground-truth folder.
gt_dir = data_dir / "GTmod12"
for entry in gt_dir.iterdir():
    print(entry.name)
baby.png
head.png
bird.png
butterfly.png
woman.png
可视化:
from PIL import Image
from torch_book.data.cv.plot import CompareGridFrame
# Sorting both listings keeps the LR and GT images aligned by position.
lr_dir = data_dir / "LRbicx4"
gt_dir = data_dir / "GTmod12"
imgs_names = sorted(lr_dir.iterdir())
gt_names = sorted(gt_dir.iterdir())

# One row of five comparison cells, one per Set5 image.
grid = CompareGridFrame(1, 5, layout="row")
imgs = list(map(Image.open, imgs_names))
gts = list(map(Image.open, gt_names))
grid(imgs, gts)
grid.figure
构建数据集#
%%file utils/Set5.py
from dataclasses import dataclass
from typing import Callable
from pathlib import Path
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import tv_tensors
@dataclass
class PairedDataset(Dataset):
    """Paired low-resolution / high-resolution image dataset.

    Both directories are listed in sorted order and paired by position.
    The pairing is validated once, at construction time, by comparing the
    file names of each positional pair.

    Attributes:
        scale: Upscaling factor (2, 3 or 4). Stored on the instance; the
            dataset itself does not read it.
        HR_path: Directory containing the high-resolution images.
        LR_path: Directory containing the low-resolution images.
        transform: Optional callable invoked as ``transform(lr, hr)`` that
            returns the transformed ``(lr, hr)`` pair.
    """

    scale: int  # upscaling factor: 2, 3 or 4
    HR_path: str | Path  # directory of HR images
    LR_path: str | Path  # directory of LR images
    transform: Callable | None = None

    def __post_init__(self):
        """Normalize the paths, list both directories and validate pairing."""
        self.HR_path = Path(self.HR_path)
        self.LR_path = Path(self.LR_path)
        # Sorted listings make the positional LR/HR pairing deterministic.
        self.lr_names = sorted(self.LR_path.iterdir())
        self.hr_names = sorted(self.HR_path.iterdir())
        self._check()

    def _check(self):
        """Verify that LR and HR files pair up one-to-one by file name.

        Raises:
            AssertionError: If the two directories hold a different number
                of entries, or if a positional pair has different names.
        """
        # Raised explicitly instead of via ``assert`` so the validation
        # still runs under ``python -O``; the exception type is unchanged.
        if len(self.lr_names) != len(self.hr_names):
            raise AssertionError(
                f"LR/HR file count mismatch: "
                f"{len(self.lr_names)} != {len(self.hr_names)}"
            )
        for a, b in zip(self.lr_names, self.hr_names):
            if a.name != b.name:
                raise AssertionError(f"文件名 {a} 和 {b} 不匹配")

    def __len__(self) -> int:
        """Return the number of image pairs."""
        return len(self.lr_names)

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Image]:
        """Load the (LR, HR) image pair at ``index``.

        Args:
            index: Position of the pair in the sorted file listings.

        Returns:
            The ``(lr, hr)`` pair. Each image is wrapped in
            ``tv_tensors.Image``; when ``transform`` is set, the pair is
            whatever ``transform(lr, hr)`` returns.
        """
        # Context managers close the underlying files once the pixel data
        # has been converted.
        with Image.open(self.lr_names[index]) as im:
            lr = tv_tensors.Image(im)
        with Image.open(self.hr_names[index]) as im:
            hr = tv_tensors.Image(im)
        if self.transform is not None:
            lr, hr = self.transform(lr, hr)
        return lr, hr
Overwriting utils/Set5.py
import torch
from torch.utils.data.dataloader import DataLoader
from utils.Set5 import PairedDataset
# Upscaling factor (2, 3 or 4); it also selects the matching LRbicx folder.
scale = 4
dataset = PairedDataset(
    scale=scale,
    HR_path=data_dir / "GTmod12",
    LR_path=data_dir / f"LRbicx{scale}",
    transform=None,
)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Only the first (LR, HR) batch is needed; stop right after fetching it.
with torch.no_grad():
    for lr, hr in data_loader:
        break

# Move the channel axis last before handing the batches to the plot helper.
lr_channels_last = lr.permute(0, 2, 3, 1)
hr_channels_last = hr.permute(0, 2, 3, 1)
canvas = CompareGridFrame(1, 1, 2.5)
left_axes, right_axes = canvas(lr_channels_last, hr_channels_last)
canvas.figure.suptitle("LQ")
plt.close()
canvas.figure
裁剪图像的边缘像素#
CropBorder 从图像的边缘裁剪掉指定数量的像素,这些被裁剪掉的像素不参与 PSNR 的计算。
from torchvision.transforms import v2
from torch_book.data.cv.plot import GridFrame
from torch_book.metrics.utils import CropBorder
data_dir = root_dir/"tests/data/Set5"  # dataset root directory
scale = 4  # upscaling factor: 2, 3 or 4
transform = v2.Compose([
    CropBorder(scale),  # crop the image border
    v2.ToDtype(torch.float32, scale=True),  # convert to float32, scaled to [0, 1]
])
dataset = PairedDataset(
    scale,
    data_dir/"GTmod12",
    data_dir/f"LRbicx{scale}",
    transform=transform
)
# Rebuild the loader for the new dataset. The ``data_loader`` created earlier
# wraps the untransformed dataset, so iterating it would skip ``transform``.
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
with torch.no_grad():
    for i, (lr, gt) in enumerate(data_loader):
        if i == 4:  # stop on the fifth pair (index 4)
            break
print(lr.shape, gt.shape,)
# Drop the batch axis and convert each tensor to a PIL image for plotting.
results = [v2.ToPILImage()(x[0]) for x in [lr, gt,]]
canvas = GridFrame(1, 2, 2.5)
gs, axes = canvas(results)
# Plain loop: ``set_title`` is called only for its side effect.
for ax, title in zip(axes, ["LR", "GT"]):
    ax.set_title(title)
plt.close()
canvas.figure
torch.Size([1, 3, 84, 57]) torch.Size([1, 3, 336, 228])