class FloatFunctional(torch.nn.Module):
    r"""State collector class for float operations.

    An instance of this class can be used instead of the ``torch.`` prefix for
    some operations. See the example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> f_add = FloatFunctional()
        >>> a = torch.tensor(3.0)
        >>> b = torch.tensor(4.0)
        >>> f_add.add(a, b)  # Equivalent to ``torch.add(a, b)``

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
    """

    def __init__(self):
        super().__init__()
        # Identity by default, so the tensor ops below return plain float
        # results. QFunctional.from_float later calls
        # ``activation_post_process.calculate_qparams()``, so this attribute is
        # expected to be replaced by an observer before conversion.
        self.activation_post_process = torch.nn.Identity()

    def forward(self, x):
        # Deliberately unusable: callers must invoke one of the named ops.
        raise RuntimeError("FloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        r = torch.add(x, y)
        r = self.activation_post_process(r)
        return r

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, float)``"""
        r = torch.add(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
        r = torch.mul(x, y)
        r = self.activation_post_process(r)
        return r

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
        r = torch.mul(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.cat``"""
        r = torch.cat(x, dim=dim)
        r = self.activation_post_process(r)
        return r

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``relu(torch.add(x,y))``"""
        r = torch.add(x, y)
        r = torch.nn.functional.relu(r)
        r = self.activation_post_process(r)
        return r
class FXFloatFunctional(torch.nn.Module):
    r"""Module used in place of ``FloatFunctional`` before FX graph mode quantization.

    Unlike ``FloatFunctional`` it holds no ``activation_post_process``: in FX
    graph mode quantization the ``activation_post_process`` modules are inserted
    in the top-level module directly, so each method here simply runs the
    underlying float ``torch`` operation.

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
    """

    def forward(self, x):
        # Deliberately unusable: callers must invoke one of the named ops.
        # The message names this class (it previously said "FloatFunctional").
        raise RuntimeError("FXFloatFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
        return torch.add(x, y)

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.add(Tensor, float)``"""
        return torch.add(x, y)

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
        return torch.mul(x, y)

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
        return torch.mul(x, y)

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.cat``"""
        return torch.cat(x, dim=dim)

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``relu(torch.add(x,y))``"""
        r = torch.add(x, y)
        return torch.nn.functional.relu(r)
class QFunctional(torch.nn.Module):
    r"""Wrapper class for quantized operations.

    An instance of this class can be used instead of the
    ``torch.ops.quantized`` prefix. See the example usage below.

    .. note::

        This class does not provide a ``forward`` hook. Instead, you must use
        one of the underlying functions (e.g. ``add``).

    Examples::

        >>> q_add = QFunctional()
        >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
        >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
        >>> q_add.add(a, b)  # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``

    Valid operation names:
        - add
        - cat
        - mul
        - add_relu
        - add_scalar
        - mul_scalar
    """

    def __init__(self):
        super().__init__()
        # Quantization parameters forwarded to the add/mul/cat/add_relu
        # kernels. Overwritten by from_float() and by state-dict loading.
        self.scale = 1.0
        self.zero_point = 0
        self.activation_post_process = torch.nn.Identity()

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # scale / zero_point are plain Python attributes (not buffers), so
        # they are serialized by hand as tensors.
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Pop the hand-serialized entries so the base loader never sees them.
        # A missing 'scale' / 'zero_point' entry raises KeyError here.
        self.scale = float(state_dict.pop(prefix + 'scale'))
        self.zero_point = int(state_dict.pop(prefix + 'zero_point'))
        # NOTE(review): strict is always passed as False to the base loader,
        # regardless of the caller's setting — confirm this is intentional.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    def _get_name(self):
        return 'QFunctional'

    def extra_repr(self):
        return 'scale={}, zero_point={}'.format(self.scale, self.zero_point)

    def forward(self, x):
        # Deliberately unusable: callers must invoke one of the named ops.
        raise RuntimeError("QFunctional is not intended to use the " +
                           "'forward'. Please use the underlying operation")

    def add(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add``"""
        r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    def add_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
        r = ops.quantized.add_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def mul(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
        r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
        r = ops.quantized.mul_scalar(x, y)
        # Note: this operation is not observed because the observation is not
        # needed for the quantized op.
        return r

    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.cat``"""
        r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
        r = self.activation_post_process(r)
        return r

    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
        r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
        r = self.activation_post_process(r)
        return r

    @classmethod
    def from_float(cls, mod):
        r"""Build a QFunctional from an observed FloatFunctional.

        The scale / zero point come from
        ``mod.activation_post_process.calculate_qparams()``, so ``mod`` must
        already carry an observer (not the default Identity).
        """
        # Exact type match (subclasses are rejected), raising AssertionError.
        assert type(mod) is FloatFunctional, \
            "QFunctional.from_float expects an instance of FloatFunctional"
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = QFunctional()
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod