pytorch_book.datasets.examples.hymenoptera 源代码

from pathlib import Path

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from ..imagenet import Transforms
from ..file import ZipDataset, PathType


class Hymenoptera:
    '''The `Hymenoptera dataset <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`__
    is a very small subset of ImageNet.

    It has about 120 training images each for ants and bees, and 75
    validation images per class.

    Args:
        data_dir: Dataset root directory; ``train`` and ``val`` sub-directories
            are loaded from it with ``ImageFolder``.
        download: If true, download and extract the archive even when
            ``data_dir`` already exists. The archive is also fetched
            automatically when ``data_dir`` is missing.
        batch_size: Batch size used by both data loaders (default 4).
        num_workers: Number of loader worker processes used by both data
            loaders (default 4).

    Attributes:
        dataset_sizes: Number of images per split, keyed by ``'train'`` / ``'val'``.
        dataloaders: ``DataLoader`` per split, keyed by ``'train'`` / ``'val'``.
        class_names: Class names found in the ``train`` split.

    Raises:
        FileNotFoundError: If ``data_dir`` still does not exist after downloading.
    '''

    def __init__(self, data_dir: PathType = 'data/hymenoptera_data',
                 download: bool = False,
                 batch_size: int = 4,
                 num_workers: int = 4) -> None:
        data_dir = Path(data_dir)
        if download or not data_dir.exists():
            # Download next to ``data_dir`` instead of into a fixed ``data``
            # directory, so a custom ``data_dir`` is honored. For the default
            # ``data_dir`` the parent is ``data``, i.e. the previous behavior.
            self.loader(str(data_dir.parent))
            if not data_dir.exists():
                # NOTE(review): assumes the archive unpacks into a folder named
                # 'hymenoptera_data' under the download root — confirm.
                raise FileNotFoundError(
                    f"{data_dir} not found after downloading; the archive is "
                    f"expected to extract to a folder named 'hymenoptera_data', "
                    f"so data_dir should end with that name")
        data_types = ['train', 'val']
        # ``getattr(Transforms, x)`` picks ``Transforms.train`` / ``Transforms.val``
        # for the matching split.
        image_datasets = {x: ImageFolder(data_dir / x, getattr(Transforms, x))
                          for x in data_types}
        # Shuffling is enabled for both splits.
        self.dataloaders = {x: DataLoader(image_datasets[x],
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=num_workers)
                            for x in data_types}
        self.dataset_sizes = {x: len(image_datasets[x]) for x in data_types}
        self.class_names = image_datasets['train'].classes

    def loader(self, root: PathType = 'data') -> None:
        '''Download the Hymenoptera zip archive and extract it.

        Args:
            root: Directory passed to ``ZipDataset`` (default ``'data'``).
                Presumably both the download and the extraction happen under
                it — confirm against ``ZipDataset``.
        '''
        url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
        zip_name = url.split('/')[-1]  # 'hymenoptera_data.zip'
        zipset = ZipDataset(root)
        zip_name = zipset.download(url, zip_name)  # download the archive
        zipset.extractall(zip_name)  # extract the archive