import warnings
from typing import Callable, Any, Optional, List

import torch
from torch import Tensor
from torch import nn

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..utils import _log_api_usage_once
from ._utils import _make_divisible


__all__ = ["MobileNetV2", "mobilenet_v2"]


model_urls = {
    "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}


# necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. "
            "Use torchvision.ops.misc.ConvNormActivation instead.",
            FutureWarning,
        )
        if kwargs.get("norm_layer", None) is None:
            kwargs["norm_layer"] = nn.BatchNorm2d
        if kwargs.get("activation_layer", None) is None:
            kwargs["activation_layer"] = nn.ReLU6
        super().__init__(*args, **kwargs)


ConvBNReLU = _DeprecatedConvBNAct
ConvBNActivation = _DeprecatedConvBNAct


class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw
                ConvNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
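
# Illustrative usage of `InvertedResidual` (a sketch added for clarity, not part
# of the upstream module): with stride 1 and inp == oup the block enables the
# residual connection, so the output tensor has the same shape as the input.
#
#     block = InvertedResidual(inp=32, oup=32, stride=1, expand_ratio=6)
#     x = torch.rand(1, 32, 56, 56)
#     assert block(x).shape == x.shape  # residual path preserves shape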
class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): The dropout probability

        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be non-empty and each element should be a 4-element list, "
                f"got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            ConvNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
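
# Illustrative usage of `MobileNetV2` (a sketch added for clarity, not part of
# the upstream module): building a slimmer variant via `width_mult`; channel
# counts are rounded to a multiple of `round_nearest` by `_make_divisible`.
#
#     model = MobileNetV2(num_classes=10, width_mult=0.5)
#     logits = model(torch.rand(1, 3, 224, 224))
#     assert logits.shape == (1, 10)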
def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress)
        model.load_state_dict(state_dict)
    return model
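
# Illustrative end-to-end usage (a sketch added for clarity; the relative
# imports above mean this module is meant to be imported via torchvision
# rather than run directly):
#
#     from torchvision.models import mobilenet_v2
#     net = mobilenet_v2(pretrained=True)  # downloads weights from model_urls
#     net.eval()
#     with torch.no_grad():
#         out = net(torch.rand(2, 3, 224, 224))
#     assert out.shape == (2, 1000)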