from __future__ import annotations
import copy
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
def train_model(model, loader, criterion, optimizer, scheduler, num_epochs=25, device='cpu'):
"""
模型训练所支持的函数
Args:
model: 待训练的模型
criterion: 优化准则 (loss)
optimizer: 用于训练的 Optimizer
scheduler: :mod:`torch.optim.lr_scheduler` 中的 LR 调度器对象
num_epochs: epochs 数量
device: 用来进行训练的设备。必须是 'cpu' 或 'cuda'
"""
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    model.to(device)
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for inputs, labels in loader.dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
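                # loss.item() is the batch-mean loss (assuming the criterion
                # uses the default 'mean' reduction), so multiply by the batch
                # size to accumulate a per-sample sum.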
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
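            # Step the LR scheduler once per training epoch; since PyTorch 1.1
            # this is expected to happen after the optimizer updates.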
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / loader.dataset_sizes[phase]
            epoch_acc = running_corrects.double() / loader.dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
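

# Minimal usage sketch (illustrative, not part of the original API): the
# `loader` argument is assumed to be any object exposing `dataloaders` and
# `dataset_sizes` dicts keyed by phase, as accessed above. The synthetic
# tensors, SimpleNamespace container, and hyperparameters below are
# assumptions chosen only to make the example self-contained and runnable.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic 10-class classification data standing in for a real dataset.
    datasets = {
        phase: TensorDataset(torch.randn(n, 16), torch.randint(0, 10, (n,)))
        for phase, n in [('train', 128), ('val', 32)]
    }
    loader = SimpleNamespace(
        dataloaders={p: DataLoader(d, batch_size=16, shuffle=(p == 'train'))
                     for p, d in datasets.items()},
        dataset_sizes={p: len(d) for p, d in datasets.items()},
    )

    model = nn.Linear(16, 10)  # toy classifier
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # Decay the learning rate by a factor of 10 every 7 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_model = train_model(model, loader, criterion, optimizer, scheduler,
                             num_epochs=2, device='cpu')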