# 四元数 CNN (Quaternion CNN)

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
from torchvision.utils import make_grid
from torchvision.transforms import v2 
from torchvision import datasets, models
import matplotlib.pyplot as plt
import logging
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
from PIL import Image
from tempfile import TemporaryDirectory
from torch.profiler import profile, record_function, ProfilerActivity
from pathlib import Path
from taolib.utils.logger import config_logging

# Working directory for logs, datasets, and checkpoints.
temp_dir = Path(".").resolve() / ".temp"
temp_dir.mkdir(exist_ok=True)
# Rotating file log: ~5 MB per file, keep 7 backups.
config_logging(f'{temp_dir}/0-compile.log', "root", maxBytes=5000000, backupCount=7)
torch.cuda.empty_cache() # free cached GPU memory

# Let cuDNN benchmark conv algorithms (fast when input shapes are fixed).
cudnn.benchmark = True
plt.ion()   # interactive mode
@dataclass
class CIFAR10:
    """CIFAR-10 train/validation datasets plus a normalization module.

    The training split gets light augmentation (upscale, random square
    crop, horizontal flip); the validation split is only converted to a
    tensor image. Normalization (ImageNet channel statistics) is kept as
    a separate ``nn.Sequential`` so it can run inside the model's forward.
    """
    root_dir: str  # where torchvision downloads/stores the data

    def __post_init__(self):
        self.root_dir = Path(self.root_dir)
        # ImageNet mean/std — matches the ImageNet-pretrained backbone used later.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Upscale each image to a 40x40 square, take a random square patch
        # covering 64%-100% of the area, resize it back to 32x32, then
        # randomly flip it horizontally.
        augmentation = [
            v2.Resize(40),
            v2.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
            v2.RandomHorizontalFlip(),
            v2.ToImage(),
        ]
        self.train_transform = v2.Compose(augmentation)
        self.val_transform = v2.ToImage()
        common = dict(root=self.root_dir, download=True)
        self.train = datasets.CIFAR10(
            train=True, transform=self.train_transform, **common,
        )
        self.val = datasets.CIFAR10(
            train=False, transform=self.val_transform, **common,
        )
        # Convert uint8 [0,255] -> float32 [0,1], then standardize per channel.
        self.normalize = nn.Sequential(
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(self.mean, self.std),
        )
dataset = CIFAR10(temp_dir/"data")
batch_size = 16
# Shuffle only the training split; keep evaluation order deterministic.
dataloaders = {
    "train": DataLoader(dataset.train, batch_size=batch_size, shuffle=True),
    "val": DataLoader(dataset.val, batch_size=batch_size, shuffle=False),
}
dataset_sizes = {
    "train": len(dataset.train),
    "val": len(dataset.val),
}
# Human-readable label names, indexed by class id.
class_names = dataset.train.classes
# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

## 可视化数据

from taolib.plot.image import show_images
# Map label index back to its human-readable class name.
# (The misspelled `classe_names` duplicate of `class_names` was removed —
# it was never read.)
idx_to_class = {v: k for k, v in dataset.train.class_to_idx.items()}
# Get a batch of validation data (un-normalized, so it displays correctly).
inputs, classes = next(iter(dataloaders["val"]))
inputs, classes = inputs[:8], classes[:8]
inputs = [v2.ToPILImage()(inp) for inp in inputs]
show_images(inputs, 2, 4, scale=2, titles=[idx_to_class[x.item()] for x in classes]);

## 训练模型

from torchvision.ops.misc import Conv2dNormActivation
from core_qnn.quaternion_layers import QuaternionConv

def qconv(f_conv, in_channels, out_channels, groups):
    """Build a ``QuaternionConv`` mirroring the geometry of an existing conv.

    Args:
        f_conv: the ``nn.Conv2d`` being replaced; its kernel size, stride,
            dilation, padding and bias-presence are copied.
        in_channels/out_channels: channel counts for the quaternion layer
            (must each be divisible by 4 per quaternion convention —
            TODO confirm against core_qnn).
        groups: number of convolution groups.

    Returns:
        A ``QuaternionConv`` configured for 2-D convolution with glorot
        quaternion initialization.
    """
    return QuaternionConv(
        in_channels, out_channels,
        f_conv.kernel_size, f_conv.stride,
        f_conv.dilation, f_conv.padding, groups,
        # Pass a plain bool: `f_conv.bias` is a Parameter (or None), and
        # truth-testing a multi-element tensor raises a RuntimeError.
        # MobileNetV2 convs have bias=None, so this preserves behavior.
        f_conv.bias is not None,
        init_criterion='glorot',
        weight_init='quaternion', seed=None, operation='convolution2d',
        rotation=False, quaternion_format=True, scale=False,
    )

class Model(nn.Module):
    """MobileNetV2 adapted to CIFAR-10 with quaternion projection convs.

    Changes to the ImageNet-pretrained backbone:
      * the stem conv runs at stride 1 (CIFAR images are only 32x32);
      * the classifier head is resized to ``len(class_names)`` outputs;
      * the pointwise projection conv (``conv[2]``) of inverted-residual
        blocks ``features[3]``..``features[17]`` is replaced by a
        ``QuaternionConv`` via :func:`qconv`; the first two residual
        blocks are left untouched.

    Args:
        transform: a module applied to inputs before the backbone
            (here: dtype conversion + normalization).
    """

    def __init__(self, transform: nn.Module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        model = models.mobilenet_v2(weights='IMAGENET1K_V1')
        # Stride-1 stem: avoid the aggressive early downsampling meant for
        # 224x224 ImageNet inputs.
        model.features[0] = Conv2dNormActivation(3, 32, stride=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)
        # Replace the 1000-way ImageNet head with a CIFAR-10 classifier.
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, len(class_names))
        # Swap projection convs for quaternion convs (skip the first two
        # residual blocks, i.e. enumerate indices 0 and 1).
        for index, blk in enumerate(model.features[1:18]):
            if index < 2:
                continue
            proj = blk.conv[2]
            model.features[index + 1].conv[2] = qconv(
                proj, proj.in_channels, proj.out_channels, proj.groups)
        self.model = model
        # Normalization is applied at forward time, not in the dataloader.
        self.transform = transform

    def forward(self, x):
        """Normalize ``x`` and run the (partly quaternion) backbone."""
        x = self.transform(x)
        return self.model(x)


# Hand-tuned hyper-parameters (the repeating digits look like sevenths,
# e.g. 1/700 and 6/70 — TODO confirm intent).
lr = 0.00142857
lr_decay = 0.0857142
weight_decay = 0.00857142
momentum = 0.857142
criterion = nn.CrossEntropyLoss()
# Wrap the backbone together with the input-normalization transform.
model_ft = Model(dataset.normalize)
optimizer_ft = optim.SGD(
    model_ft.parameters(), lr=lr, momentum=momentum, 
    weight_decay=weight_decay
)
# Multiply the LR by `lr_decay` (~0.086) every 4 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=4, gamma=lr_decay)
# NOTE(review): exploratory leftovers — printed each residual block's convs.
# for index, blk in enumerate(model_ft.model.features[1:18]):
#     print(blk.conv)
#     # model_ft.model.features[index+1].conv[0][0] = blk.conv[0][0]
#     # if index==2:
#     #     break
#     # break
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx
from torch import _dynamo as dynamo

# Expected input for TVM import: one 3x32x32 float32 image.
input_info = [((1, 3, 32, 32), "float32")]
# graph_module = fx.symbolic_trace(model_ft.eval())
# 
# NOTE(review): traces the bare backbone (model_ft.model), so the
# normalization transform is NOT baked into the exported graph — verify
# that downstream consumers normalize inputs themselves.
scripted_model = torch.jit.trace(model_ft.model.eval(), torch.randn((1, 3, 32, 32))).eval()
# Save the traced TorchScript model
torch.jit.save(scripted_model, temp_dir/'test.pt')
# scripted_model = torch.jit.load(temp_dir/'test.pt')
# scripted_model = torch.jit.script(scripted_model)
# mod = from_fx(scripted_model, input_info)

## 训练和评估模型

from taolib.utils.timer import Timer
from torch_book.vision.classifier import Classifier
timer = Timer()
# Bundle model, loss, optimizer, scheduler and both dataloaders into the
# project's training harness.
classifier = Classifier(
    model_ft, criterion, optimizer_ft, 
    exp_lr_scheduler, 
    dataloaders["train"], 
    dataloaders["val"],
    device, 
    timer)
# Train for 20 epochs; clip the plotted loss axis to [0, 2]; write
# checkpoints under .temp/checkpoint.
classifier.fit(20, ylim=[0, 2], checkpoint_dir=temp_dir/'checkpoint')