from dataclasses import dataclass
from pathlib import Path
import numpy as np
@dataclass
class Preprocessing:
width: int
height: int
channels: int
mean: tuple[float] = (0,)
std: tuple[float] = (1,)
layout: str = "HWC"
name: str = "data"
format: str = "RGB"
def __post_init__(self):
if self.layout == "HWC":
self.shape = self.height, self.width, self.channels
elif self.layout == "CHW":
self.shape = self.channels, self.height, self.width
else:
raise ValueError(f"Unknown layout: {self.layout}")
def load(self, path: str | Path) -> np.ndarray:
"""加载图片"""
img = Image.open(path).resize((self.width, self.height)) # uint8 数据
if self.format == "GRAY":
img = img.convert("L")
img = np.expand_dims(img, axis=-1) # WH->HWC
elif self.format == "RGB":
img = np.array(img.convert("RGB")) # WHC->HWC
elif self.format == "BGR":
img = np.array(img.convert("RGB")) # WHC->HWC
img = img[..., ::-1] # RGB 转 BGR
else:
raise TypeError(f'暂未支持数据布局 {self.format}')
return img
def __call__(self, path: str | Path) -> np.ndarray:
img = self.load(path)/255.0 # 归一化(将 uint8 数据归一化到 [0, 1],这是神经网络的标准输入格式)
img = (img - self.mean) / self.std # 标准化,使数据分布更接近标准正态分布
img = img.astype("float32")
if self.layout == "CHW":
img = img.transpose(2, 0, 1) # HWC->CHW
return img
def torch_call(self, path: str | Path) -> "torch.Tensor":
assert self.layout == "CHW", "torchvision 只支持 CHW 布局"
from torchvision.transforms import v2
import torch
from torch import nn
inp = self.load(path)
return nn.Sequential(
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(self.mean, self.std)
)(inp)