class AverageMeter:
    """Computes and stores the average and current value of a metric.

    Args:
        name: Label printed first by ``__str__``.
        fmt: Format spec (including the leading colon, e.g. ``':f'`` or
            ``':.3f'``) applied to both ``val`` and ``avg`` in ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: The newest observation (e.g. a per-batch mean).
            n: Number of samples ``val`` stands for; it is added ``n`` times
                to the running sum.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: update(x, n=0) before any samples would otherwise divide by
        # zero; in that case the previous average is kept.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for each k in ``topk``.

    Args:
        output: Score tensor; top-k is taken along dim 1, so presumably
            shaped (batch_size, num_classes) — TODO confirm with callers.
        target: Ground-truth class indices, one per row of ``output``
            (flattened before comparison).
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per k: the percentage
        (0-100) of rows whose target is among the k highest scores.
        An empty batch yields 0 for every k.
    """
    with torch.no_grad():
        # Clamp so a k larger than the number of classes does not make
        # ``topk`` raise; such a k then simply counts every class.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds each sample's i-th guess
        # reshape (not view) so non-contiguous targets are accepted too.
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        # Avoid ZeroDivisionError on an empty batch; report 0% instead.
        scale = 100.0 / batch_size if batch_size else 0.0
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(scale))
        return res