import torch
from torch import nn


def _replace_relu(module: nn.Module) -> None:
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Checking for explicit type instead of instance
        # as we only want to replace modules of the exact type,
        # not inherited classes
        if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value
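
# A minimal usage sketch for _replace_relu. _SmallNet below is hypothetical,
# defined only for illustration; it is not part of this module. The recursive
# call means activations nested inside containers such as nn.Sequential are
# replaced too, and every exact-type match becomes a fresh nn.ReLU(inplace=False).
class _SmallNet(nn.Module):  # hypothetical example model
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.act = nn.ReLU6(inplace=True)
        self.block = nn.Sequential(nn.Conv2d(8, 8, kernel_size=1), nn.ReLU(inplace=True))


_net = _SmallNet()
_replace_relu(_net)
assert type(_net.act) is nn.ReLU and _net.act.inplace is False  # ReLU6 swapped out
assert type(_net.block[1]) is nn.ReLU  # nested ReLU replaced recursively
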
def quantize_model(model: nn.Module, backend: str) -> None:
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported")
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]
    torch.ao.quantization.prepare(model, inplace=True)
    # Calibrate the inserted observers with a single dummy forward pass
    model(_dummy_input_data)
    torch.ao.quantization.convert(model, inplace=True)
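
# A minimal usage sketch for quantize_model, assuming the host supports the
# "fbgemm" engine (typical on x86; otherwise the RuntimeError above fires).
# _QuantReadyNet is hypothetical, defined only for illustration: quantize_model
# expects a float model in the style of torchvision's quantizable models, i.e.
# one that exposes a fuse_model() method and wraps its forward pass between
# QuantStub/DeQuantStub boundaries.
class _QuantReadyNet(nn.Module):  # hypothetical example model
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)  # float -> int8 boundary
        x = self.relu(self.conv(x))
        return self.dequant(x)  # int8 -> float boundary

    def fuse_model(self) -> None:
        # Fuse conv + relu so they are quantized as a single module
        torch.ao.quantization.fuse_modules(self, [["conv", "relu"]], inplace=True)


_float_model = _QuantReadyNet()
quantize_model(_float_model, "fbgemm")  # fuses, calibrates on the dummy input, converts
_int8_out = _float_model(torch.rand(1, 3, 299, 299))  # now runs with int8 kernels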