# 测试 (Test)
from matplotlib import pyplot as plt
import torch
from mod import load_mod
# Turn on matplotlib interactive mode (figures update without blocking on plt.show()).
plt.ion()
# Load custom modules.
# NOTE(review): presumably load_mod() makes the local `xinet` package importable
# (e.g. by adjusting sys.path) -- confirm in mod.load_mod. It must run before the
# `from xinet import CV` line below.
load_mod()
from xinet import CV
# Mini-batch size passed to the CIFAR-10 loader below.
batch_size = 128
# Build the CIFAR-10 train/test data iterators. The "Files already downloaded and
# verified" lines in the logged output suggest a torchvision-style dataset download
# check inside CV.load_data_cifar10 -- TODO confirm.
train_iter, test_iter = CV.load_data_cifar10(batch_size=batch_size)
Files already downloaded and verified
Files already downloaded and verified
/home/pc/xinet/anaconda3/envs/ai/lib/python3.9/site-packages/torch/ao/quantization/observer.py:172: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
warnings.warn(
# Fine-tune `model_ft` on the CIFAR-10 iterators built above.
# These are the hyperparameters used for the logged run (lr 1e-3, 100 epochs,
# param_group disabled -- see CV.train_fine_tuning for what param_group controls).
CV.train_fine_tuning(
    model_ft,
    train_iter,
    test_iter,
    learning_rate=1e-3,
    num_epochs=100,
    param_group=False,
)
loss 0.008, train acc 0.997, test acc 0.756
814.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2)]
from pathlib import Path

# `torch.quantization` is a deprecated alias of `torch.ao.quantization`; this
# environment already ships the latter (see the observer.py warning path above).
from torch.ao.quantization import convert

# Eager-mode quantized kernels run on CPU, so move the model off the GPU first.
# NOTE(review): assumes model_ft was prepared for quantization (prepare/prepare_qat)
# earlier in the notebook -- convert() needs its observers/fake-quant modules.
model_ft.cpu()
# The PyTorch quantization workflow switches the prepared model to eval mode
# before converting, so the converted model is produced in inference mode.
model_ft.eval()
# inplace=False leaves model_ft (the un-converted model) intact so both can be saved.
model_convert = convert(model_ft, inplace=False)

# Create the output directory if missing; torch.save does not create parent dirs.
models_dir = Path('../models')
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(model_ft.state_dict(), models_dir / 'cifar10.pt')
torch.save(model_convert.state_dict(), models_dir / 'cifar10_convert.pt')