图片数据预处理

# 图片数据预处理

from dataclasses import dataclass
from pathlib import Path

import numpy as np
from PIL import Image

@dataclass
class Preprocessing:
    width: int
    height: int
    channels: int
    mean: tuple[float] = (0,)
    std: tuple[float] = (1,)
    layout: str = "HWC"
    name: str = "data"
    format: str = "RGB"

    def __post_init__(self):
        if self.layout == "HWC":
            self.shape = self.height, self.width, self.channels
        elif self.layout == "CHW":
            self.shape = self.channels, self.height, self.width
        else:
            raise ValueError(f"Unknown layout: {self.layout}")

    def load(self, path: str | Path) -> np.ndarray:
        """加载图片"""
        img = Image.open(path).resize((self.width, self.height)) # uint8 数据
        if self.format == "GRAY":
            img = img.convert("L")
            img = np.expand_dims(img, axis=-1) # WH->HWC
        elif self.format == "RGB":
            img = np.array(img.convert("RGB")) # WHC->HWC
        elif self.format == "BGR":
            img = np.array(img.convert("RGB")) # WHC->HWC
            img = img[..., ::-1] # RGB 转 BGR
        else:
            raise TypeError(f'暂未支持数据布局 {self.format}')
        return img
    
    def __call__(self, path: str | Path) -> np.ndarray:
        img = self.load(path)/255.0 # 归一化(将 uint8 数据归一化到 [0, 1],这是神经网络的标准输入格式)
        img = (img - self.mean) / self.std # 标准化,使数据分布更接近标准正态分布
        img = img.astype("float32")
        if self.layout == "CHW":
            img = img.transpose(2, 0, 1) # HWC->CHW
        return img

    def torch_call(self, path: str | Path) -> "torch.Tensor":
        assert self.layout == "CHW", "torchvision 只支持 CHW 布局"
        from torchvision.transforms import v2
        import torch
        from torch import nn
        inp = self.load(path)
        return nn.Sequential(
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(self.mean, self.std)
        )(inp)
from pathlib import Path
from PIL import Image

# Demo: preprocess one sample image into a 32x32 CHW float32 array using the
# standard ImageNet per-channel mean/std (as used by torchvision models).
root_dir = Path('../../images')  # resolved relative to the current working directory
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
layout = "CHW"
preprocessing = Preprocessing(32, 32, 3, mean=mean, std=std, layout=layout)
# Optional cross-check against the torchvision pipeline (needs torch + torchvision):
# torch_inp = preprocessing.torch_call(root_dir/"Giant_Panda_in_Beijing_Zoo_1.jpg")
inp = preprocessing(root_dir/"Giant_Panda_in_Beijing_Zoo_1.jpg")
# np.testing.assert_almost_equal(inp, torch_inp.numpy(), decimal=6)
print(inp.shape)  # expected: (3, 32, 32)