import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)

from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    activation_is_memoryless)


def is_activation_post_process(module):
    return (isinstance(module, torch.ao.quantization.ObserverBase) or
            isinstance(module, torch.ao.quantization.FakeQuantizeBase))


def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix=''):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                        this config when there is no specified config for current
                        module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        _propagate_qconfig_helper(child, qconfig_dict,
                                  qconfig_with_device_check, module_prefix)


def propagate_qconfig_(module, qconfig_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign a `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration. The qconfig applies to all submodules of
            a given module unless a qconfig is specified for the submodules
            (i.e. the submodule already has a qconfig attribute)

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict)


def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input
    """
    return self.activation_post_process(input[0])


def register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        handle = module.register_forward_pre_hook(_observer_forward_pre_hook)
        module._forward_pre_hooks.move_to_end(handle.id, last=False)
    else:
        handle = module.register_forward_hook(_observer_forward_hook)
        module._forward_hooks.move_to_end(handle.id, last=False)


def add_observer_(module, qconfig_propagation_list=None,
                  non_leaf_module_list=None,
                  device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function inserts an observer module to every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have
            observers added to them if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observers to

    Return:
        None, module is modified inplace with added observer modules and
        forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)
        assert len(devices) <= 1, (
            "add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and registers
            a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            register_activation_post_process_hook(m, pre_hook=activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        # TODO remove Dropout special case after codebase is stable
        if type(child) in [nn.Dropout]:
            continue
        elif type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
            if needs_observation(child):
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process is now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif needs_observation(child) and type(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            add_observer_(child, qconfig_propagation_list, non_leaf_module_list,
                          device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
            and type(module) in qconfig_propagation_list:
        insert_activation_post_process(module)


def get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}


def add_quant_dequant(module):
    r"""Wrap the leaf child modules in QuantWrapper if they have a valid qconfig.
    Note that this function will modify the children of the module inplace and it
    can also return a new module which wraps the input module.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module
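
# An illustrative sketch (not part of this module) of how `propagate_qconfig_`
# and `add_quant_dequant` could be combined in a manual eager-mode flow;
# `MyFloatModel` is a hypothetical user-defined model.
#
#     import torch.ao.quantization as tq
#
#     float_model = MyFloatModel().eval()
#     float_model.qconfig = tq.get_default_qconfig('fbgemm')
#     tq.propagate_qconfig_(float_model)
#     # wraps every leaf child that has a qconfig in QuantWrapper, which adds
#     # QuantStub/DeQuantStub around it
#     wrapped_model = tq.add_quant_dequant(float_model)
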
def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observers to
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model
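
# A rough usage sketch for `prepare` in the eager-mode post-training flow,
# assuming a hypothetical `MyFloatModel` and `calibration_loader`; the qconfig
# choice and calibration data are up to the user.
#
#     import torch
#     import torch.ao.quantization as tq
#
#     model = MyFloatModel().eval()
#     model.qconfig = tq.get_default_qconfig('fbgemm')
#     prepared = tq.prepare(model)           # attaches observers
#     with torch.no_grad():
#         for x in calibration_loader:       # run representative data
#             prepared(x)
#     # `convert` (defined below) would then swap observed modules for
#     # quantized ones.
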
def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
            is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that a new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)

def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model
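
# An illustrative sketch of the one-shot `quantize` entry point, assuming a
# hypothetical `MyFloatModel` and `calibration_loader`; it folds the
# prepare -> calibrate -> convert steps into a single call.
#
#     import torch
#     import torch.ao.quantization as tq
#
#     def calibrate(model, data_loader):
#         with torch.no_grad():
#             for x in data_loader:
#                 model(x)
#
#     float_model = MyFloatModel().eval()
#     float_model.qconfig = tq.get_default_qconfig('fbgemm')
#     quantized_model = tq.quantize(float_model, calibrate, [calibration_loader])
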
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and
    outputs the quantized model.

    For simplest usage provide the `dtype` argument that can be float16 or qint8.
    Weight-only quantization by default is performed for layers with large
    weights size, i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig` and `mapping` that act
    similarly to `quantize()`. If `qconfig` is provided, the `dtype` argument
    is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to
              quantization configuration, qconfig applies to all submodules of a
              given module unless a qconfig for the submodules is specified (when
              the submodule already has a qconfig attribute). Entries in the
              dictionary need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic
              quantization to, in which case the `dtype` argument is used to
              specify the bit-width

        inplace: carry out model transformations in-place, the original module
            is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically
            quantized version with which the submodule needs to be replaced
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
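
# A small sketch of `quantize_dynamic` with a type set, assuming a hypothetical
# `MyFloatModel`; only nn.Linear submodules are replaced here, with weights
# quantized to qint8 and activations handled dynamically at runtime.
#
#     import torch
#     import torch.nn as nn
#     import torch.ao.quantization as tq
#
#     float_model = MyFloatModel().eval()
#     dq_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
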
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to a quantized version.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model

def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model
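
# An illustrative quantization-aware training sketch combining `prepare_qat`
# and `convert`, assuming hypothetical `MyFloatModel`, `train_one_epoch` and
# `train_loader`; `quantize_qat` above wraps the same steps behind one call.
#
#     import torch
#     import torch.ao.quantization as tq
#
#     model = MyFloatModel().train()
#     model.qconfig = tq.get_default_qat_qconfig('fbgemm')
#     qat_model = tq.prepare_qat(model)      # inserts fake-quant modules
#     for epoch in range(3):
#         train_one_epoch(qat_model, train_loader)
#     qat_model.eval()
#     quantized_model = tq.convert(qat_model)
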
def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to
    `mapping` by calling the `from_float` method on the target module class, and
    removes the qconfig at the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `convert_custom_config_dict`: custom configuration dictionary for convert function

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module, mapping, inplace=True,
        convert_custom_config_dict=convert_custom_config_dict)
    if remove_qconfig:
        _remove_qconfig(module)
    return module
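
# A sketch of calling `convert` explicitly at the end of the eager-mode flow,
# continuing the `prepare`/calibration example above; `prepared` is the
# hypothetical observed-and-calibrated model, and the custom-module mapping
# (with the class names from the docstring example) is optional.
#
#     import torch.ao.quantization as tq
#
#     quantized_model = tq.convert(
#         prepared,
#         convert_custom_config_dict={
#             "observed_to_quantized_custom_module_class": {
#                 ObservedCustomModule: QuantizedCustomModule,
#             }
#         },
#     )
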
def _convert(
        module, mapping=None, inplace=False,
        convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to
    `mapping` by calling the `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
                type(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module


def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
        swapped = False
        if type(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
            swapped = True
        elif type(mod) in mapping:
            new_mod = mapping[type(mod)].from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = get_unique_devices_(mod)
            assert len(devices) <= 1, (
                "swap_module only works with cpu or single-device CUDA modules, "
                "but got devices {}".format(devices)
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod


def get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debug

    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        get_observer_dict(child, target_dict, module_prefix)
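
# A small debugging sketch for `get_observer_dict`, assuming `prepared` is a
# hypothetical model that has already been prepared and calibrated; inspecting
# observer state (e.g. min_val/max_val on MinMax-style observers) can help
# diagnose calibration issues.
#
#     observers = {}
#     get_observer_dict(prepared, observers)
#     for name, obs in observers.items():
#         print(name, obs)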