def try_gpu(i=0):
    """Return gpu(i) if it exists, otherwise return cpu().

    Args:
        i: zero-based index of the desired CUDA device.

    Returns:
        ``torch.device('cuda:i')`` when at least ``i + 1`` CUDA devices
        are visible, else ``torch.device('cpu')``.
    """
    if i < torch.cuda.device_count():
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists.

    Returns:
        A non-empty list of ``torch.device`` objects: one entry per
        visible CUDA device, or a single CPU device as a fallback.
    """
    n = torch.cuda.device_count()
    if n == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{i}') for i in range(n)]
class CV:
    """Computer-vision helpers: dataset loading, accuracy evaluation and
    training loops (including quantization-aware-training support)."""

    @staticmethod
    def get_dataloader_workers():
        """Use 4 processes to read the data."""
        return 4

    @staticmethod
    def load_data_fashion_mnist(batch_size, resize=None):
        """Download the Fashion-MNIST dataset and then load it into memory.

        Defined in :numref:`sec_fashion_mnist`

        Args:
            batch_size: minibatch size for both loaders.
            resize: optional target size passed to ``transforms.Resize``.

        Returns:
            ``(train_loader, test_loader)`` tuple of ``DataLoader``s.
        """
        trans = [transforms.ToTensor()]
        if resize:
            # Resize must run before ToTensor, hence the prepend.
            trans.insert(0, transforms.Resize(resize))
        trans = transforms.Compose(trans)
        mnist_train = torchvision.datasets.FashionMNIST(
            root="../data", train=True, transform=trans, download=True)
        mnist_test = torchvision.datasets.FashionMNIST(
            root="../data", train=False, transform=trans, download=True)
        return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                                num_workers=CV.get_dataloader_workers()),
                data.DataLoader(mnist_test, batch_size, shuffle=False,
                                num_workers=CV.get_dataloader_workers()))

    @staticmethod
    def load_data_cifar10(batch_size, resize=None, num_workers=4):
        """Download the Cifar10 dataset and then load it into memory.

        Args:
            batch_size: minibatch size for both loaders.
            resize: optional target size passed to ``transforms.Resize``.
            num_workers: worker processes per ``DataLoader``.

        Returns:
            ``(train_loader, test_loader)`` tuple of ``DataLoader``s.
        """
        trans = [transforms.ToTensor()]
        if resize:
            trans.insert(0, transforms.Resize(resize))
        trans = transforms.Compose(trans)
        _train = torchvision.datasets.CIFAR10(
            root="../data", train=True, transform=trans, download=True)
        _test = torchvision.datasets.CIFAR10(
            root="../data", train=False, transform=trans, download=True)
        return (data.DataLoader(_train, batch_size, shuffle=True,
                                num_workers=num_workers),
                data.DataLoader(_test, batch_size, shuffle=False,
                                num_workers=num_workers))

    @staticmethod
    def accuracy(y_hat, y):
        """Compute the number of correct predictions.

        If ``y_hat`` is a 2-D score matrix, predictions are taken as the
        argmax over axis 1; otherwise ``y_hat`` is compared directly.
        """
        if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
            y_hat = Fx.argmax(y_hat, axis=1)
        cmp = Fx.astype(y_hat, y.dtype) == y
        return float(Fx.reduce_sum(Fx.astype(cmp, y.dtype)))

    @staticmethod
    def evaluate_accuracy(net, data_iter, device='cpu'):
        """Compute the accuracy of the model on the given dataset."""
        net = net.to(device)
        if isinstance(net, torch.nn.Module):
            net.eval()  # set the model to evaluation mode
        metric = Accumulator(2)  # no. of correct predictions, no. of predictions
        with torch.no_grad():
            for X, y in data_iter:
                X = X.to(device)
                y = y.to(device)
                metric.add(CV.accuracy(net(X), y), Fx.size(y))
        return metric[0] / metric[1]

    @staticmethod
    def evaluate_accuracy_gpu(net, data_iter, device=None):
        """Compute the accuracy for a model on a dataset using a GPU.

        Defined in :numref:`sec_lenet`

        When ``device`` is None, the device of the model's first parameter
        is used.
        """
        if isinstance(net, nn.Module):
            net.eval()  # set the model to evaluation mode
            if not device:
                device = next(iter(net.parameters())).device
        metric = Accumulator(2)  # no. of correct predictions, no. of predictions
        with torch.no_grad():
            for X, y in data_iter:
                if isinstance(X, list):
                    # Required for BERT fine-tuning (to be covered later)
                    X = [x.to(device) for x in X]
                else:
                    X = X.to(device)
                y = y.to(device)
                metric.add(CV.accuracy(net(X), y), Fx.size(y))
        return metric[0] / metric[1]

    @staticmethod
    def train_batch(net, X, y, loss, trainer, device):
        """Train for a minibatch (with multiple GPUs).

        Returns:
            ``(loss_sum, correct_count)`` for the minibatch.
        """
        if isinstance(X, list):
            # Required for BERT fine-tuning (to be covered later)
            X = [x.to(device) for x in X]
        else:
            X = X.to(device)
        y = y.to(device)
        net.train()
        trainer.zero_grad()
        pred = net(X)
        l = loss(pred, y)
        # Sum once and reuse it for both backward() and the return value
        # (the original computed l.sum() twice).
        train_loss_sum = l.sum()
        train_loss_sum.backward()
        trainer.step()
        train_acc_sum = CV.accuracy(pred, y)
        return train_loss_sum, train_acc_sum

    @staticmethod
    def train(net, train_iter, test_iter, loss, trainer, num_epochs,
              device='cpu', need_prepare=False, is_freeze=False,
              is_quantized_acc=False, backend='fbgemm', ylim=(0, 1)):
        """Train a model (with multiple GPUs).

        Args:
            net: the model to train.
            train_iter, test_iter: training / test ``DataLoader``s.
            loss: per-sample loss (``reduction='none'`` expected).
            trainer: optimizer.
            num_epochs: number of passes over ``train_iter``.
            device: device to train on.
            need_prepare: fuse modules and prepare for QAT before training.
            is_freeze: freeze observers (epoch > 3) and BN stats (epoch > 2)
                during QAT.
            is_quantized_acc: evaluate test accuracy on the converted
                (quantized, CPU) copy of the model instead of ``net``.
            backend: quantization backend passed to
                ``get_default_qat_qconfig``.
            ylim: y-axis limits for the plot; ``ylim[0]`` is also added to
                the plotted training loss as an offset.
        """
        timer, num_batches = Timer(), len(train_iter)
        _ylim = '' if ylim[0] == 0 else f'{ylim[0]}+'
        animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=ylim,
                            legend=[f'{_ylim}train loss', 'train acc',
                                    'test acc'])
        net = net.to(device)
        if need_prepare:
            net.fuse_model()
            net.qconfig = get_default_qat_qconfig(backend)
            net = prepare_qat(net)
        # Plot ~5 times per epoch; guard against num_batches < 5, where the
        # original `num_batches // 5` modulus raised ZeroDivisionError.
        log_interval = max(1, num_batches // 5)
        for epoch in range(num_epochs):
            # Sum of loss, sum of correct predictions, no. of examples,
            # no. of labels.
            metric = Accumulator(4)
            if is_freeze:
                if epoch > 3:
                    # Freeze quantizer (observer) parameters.
                    net.apply(disable_observer)
                if epoch > 2:
                    # Freeze batch-norm mean/variance estimates.
                    net.apply(nn.intrinsic.qat.freeze_bn_stats)
            for i, (features, labels) in enumerate(train_iter):
                timer.start()
                l, acc = CV.train_batch(net, features, labels, loss,
                                        trainer, device)
                metric.add(l, acc, labels.shape[0], labels.numel())
                timer.stop()
                if (i + 1) % log_interval == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches,
                                 ((metric[0] / metric[2]) + ylim[0],
                                  metric[1] / metric[3], None))
            if is_quantized_acc:
                # Convert a CPU copy of the QAT model and score that.
                quantized_model = deepcopy(net).to('cpu').eval()
                quantized_model = convert(quantized_model, inplace=False)
                test_acc = CV.evaluate_accuracy(quantized_model, test_iter)
            else:
                test_acc = CV.evaluate_accuracy_gpu(net, test_iter)
            animator.add(epoch + 1, (None, None, test_acc))
        print(f'loss {metric[0] / metric[2]:.3f}, train acc '
              f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
        print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
              f'{str(device)}')

    @staticmethod
    def train_fine_tuning(net, train_iter, test_iter, learning_rate,
                          num_epochs=5, device='cuda:0', is_freeze=False,
                          is_quantized_acc=False, need_prepare=False,
                          param_group=True, ylim=(0, 1),
                          output_layer='classifier'):
        """Fine-tune ``net`` with cross-entropy loss and SGD.

        If ``param_group`` is True, the parameters of the output layer are
        trained with 10x the base learning rate.  ``output_layer`` is the
        attribute name of that layer, e.g. 'fc' or 'classifier'.
        """
        loss = nn.CrossEntropyLoss(reduction="none")
        if param_group:
            # Everything except the output layer trains at the base rate.
            params_1x = [param for name, param in net.named_parameters()
                         if name.split('.')[0] != output_layer]
            trainer = torch.optim.SGD(
                [{'params': params_1x},
                 {'params': getattr(net, output_layer).parameters(),
                  'lr': learning_rate * 10}],
                lr=learning_rate, weight_decay=0.001)
        else:
            trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                      weight_decay=0.001)
        CV.train(net, train_iter, test_iter, loss, trainer, num_epochs,
                 device, ylim=ylim, need_prepare=need_prepare,
                 is_freeze=is_freeze, is_quantized_acc=is_quantized_acc)