# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""ONNX: Open Neural Network Exchange importer for Relax.

This module implements the required functionality to read ONNX models
and convert them into equivalent Relax functions. The entry point that
encapsulates this functionality is the function from_onnx.

In order to extend the functionality of the importer, you can add new
operators to the operator registry. The operator registry is a dictionary
that maps operator names to operator converters. The registry is defined
in the _get_converter_map function. To add a new operator, you can define
a new class that inherits from the OnnxOpConverter class and implement
the _impl method.

By default, ONNX defines models in terms of dynamic shapes. The ONNX importer
retains dynamic shapes upon import, and when possible, the compiler attempts to
convert the model to use static shapes at compile time.
If this fails, there may still be dynamic operations in the model.
Not all TVM kernels currently support dynamic shapes, please file an issue on
github.com/apache/tvm/issues if you hit an error with dynamic kernels.
"""
import math
import operator
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as _np

import onnx.onnx_ml_pb2
import tvm
from tvm import TVMError, relax, tir, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
from tvm.tir.generic import cast
from tvm.topi.utils import get_const_tuple

from ..common import autopad


def get_type(elem_type: Union[str, int]) -> str:
    """Converts onnx integer datatype to numpy datatype"""
    # If a string was passed instead of a tensor type, it does not need
    # conversion and can be returned.
    if isinstance(elem_type, str):
        return elem_type

    try:
        from onnx.helper import (  # pylint: disable=import-outside-toplevel
            tensor_dtype_to_np_dtype,
        )
    except ImportError as exception:
        raise ImportError("Unable to import onnx which is required {}".format(exception))

    return str(tensor_dtype_to_np_dtype(elem_type))
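# Note (illustration only, not used by the importer): as described in the module
# docstring, support for a new ONNX op is added by subclassing OnnxOpConverter
# (defined below) and implementing an `_impl_v<opset>` classmethod, then
# registering the class in _get_converter_map. A minimal sketch, where "MyOp"
# is a hypothetical operator name:
#
#     class MyOp(OnnxOpConverter):
#         """Converts a hypothetical onnx MyOp node into a Relax expression."""
#
#         @classmethod
#         def _impl_v13(cls, bb, inputs, attr, params):
#             # `inputs` holds the already-converted Relax inputs of the node.
#             return relax.op.negative(inputs[0])
#
# and an entry "MyOp": MyOp would be added to the dictionary returned by
# _get_converter_map.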
"""# Params is actually both the graph nodes and param dictionary, unpack them.graph_nodes,params=params# Convert if possibleifisinstance(var,relax.Var)andvar.name_hintinparams:# When converting a parameter to a constant, update references to it as well._,value=params[var.name_hint]const_value=relax.const(value)graph_nodes[var.name_hint]=const_valuereturnconst_value# Otherwise return variable.else:returnvardefget_value(token,value_dict:Dict[str,tvm.tir.SizeVar])->Union[int,tvm.tir.SizeVar]:"""Converts to token to an integer value if it a constant, otherwise it generates a SizeVar Parameters ---------- token: str current token to decode. value_dict: Dict The Dictionary mapping from the name of ValueInfoProto to SizeVar. Returns ------- Union[int, tvm.tir.SizeVar] The decoded token """try:returnint(token)exceptValueError:iftokennotinvalue_dictortoken=="?":value_dict[token]=tvm.tir.SizeVar(token,"int64")value=value_dict[token]returnvaluedefparse_shape_name(name:str,value_dict:Dict[str,tvm.tir.SizeVar])->Union[tir.PrimExpr,tvm.tir.SizeVar]:"""Converts expressions in the shape dimension name to prim expressions. Parameters ---------- name: str name of shape dimension. value_dict: Dict The Dictionary mapping from the name of ValueInfoProto to SizeVar. Returns ------- Union[tir.PrimExpr, tvm.tir.SizeVar] The expression of the shape dimension. """tokens=re.split(r"(\+|\-|\*|\/\/|\/)",name.replace(" ",""))operators={"+":operator.add,"-":operator.sub,"*":operator.mul,"/":operator.floordiv,# is floordiv since the operands are always int"//":operator.floordiv,}value_stack=[]operator_stack=[]fortokenintokens:iftokeninoperators:operator_stack.append(token)else:value=get_value(token,value_dict)ifvalue_stackandoperator_stack:prev_value=value_stack.pop()op=operator_stack.pop()result=operators[op](prev_value,value)value_stack.append(result)else:value_stack.append(value)ifvalue_stack:returnvalue_stack[0]else:raiseException("Shape dimension could not be inferred")defget_info(info_proto:onnx.onnx_ml_pb2.ValueInfoProto,value_dict:Dict[str,tvm.tir.SizeVar])->Tuple[str,List,str,List,Dict]:"""Extract the shape from a ValueInfoProto. Parameters ---------- info_proto: onnx.onnx_ml_pb2.ValueInfoProto The ValueInfoProto to extract the info from. value_dict: Dict The Dictionary mapping from the name of ValueInfoProto to SizeVar Returns ------- Tuple[str, List, str, List, Dict] The name, shape, type, and shape name of the ValueInfoProto, and the value_dict. """shape=[]shape_name=[]fordimininfo_proto.type.tensor_type.shape.dim:name=dim.dim_paramvalue=dim.dim_valueifvalueisNoneorvalue==0:value=parse_shape_name(name,value_dict)shape_name.append(name)else:shape_name.append(value)shape.append(value)name=info_proto.nameifinfo_proto.type.tensor_type.elem_type:dtype=get_type(info_proto.type.tensor_type.elem_type)else:dtype=Nonereturnname,shape,dtype,shape_name,value_dictdefget_numpy(tensor_proto:onnx.onnx_ml_pb2.TensorProto)->_np.ndarray:"""Grab data in TensorProto and convert to numpy array."""try:fromonnx.numpy_helperimportto_array# pylint: disable=import-outside-toplevelexceptImportErrorasexception:raiseImportError("Unable to import onnx which is required {}".format(exception))returnto_array(tensor_proto)defget_prim_expr_list(inputs:Union[relax.Constant,relax.ShapeExpr],)->List[Union[int,tir.PrimExpr]]:"""Attempt to convert a variable to list of PrimExpr if possible. Parameters ---------- inputs : Union[relax.Constant, relax.ShapeExpr, relax.PrimValue] The input value to try to convert to a list of PrimExpr. 
def get_prim_expr_list(
    inputs: Union[relax.Constant, relax.ShapeExpr],
) -> List[Union[int, tir.PrimExpr]]:
    """Attempt to convert a variable to a list of PrimExpr if possible.

    Parameters
    ----------
    inputs : Union[relax.Constant, relax.ShapeExpr, relax.PrimValue]
        The input value to try to convert to a list of PrimExpr.

    Returns
    -------
    ret : List[Union[int, tir.PrimExpr]]
        The input value converted to a list of PrimExpr if possible.
    """
    if isinstance(inputs, relax.Constant):
        np_value = inputs.data.numpy()
        if np_value.ndim != 1:
            raise ValueError("Cannot cast {} to list of PrimExpr".format(type(inputs)))
        return np_value.tolist()
    elif isinstance(inputs, relax.ShapeExpr):
        return inputs.values
    elif isinstance(inputs, relax.PrimValue):
        return [inputs.value.value]
    else:
        raise ValueError("Cannot cast {} to list of PrimExpr".format(type(inputs)))


class onnx_input(list):  # pylint: disable=invalid-name
    """A list that returns None when out-of-bounds indices are accessed."""

    def __getitem__(self, item):
        if isinstance(item, slice):
            if item.stop is None:
                stop = len(self)
            else:
                stop = item.stop
            indices = list(range(stop)[item])
            return [self[i] for i in indices]
        if isinstance(item, int):
            return list(self)[item] if item < len(self) else None
        raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__)


# pylint: disable=invalid-name, len-as-condition, unused-argument, too-many-lines, redefined-builtin


class OnnxOpConverter(object):
    """A helper class for holding the common logic for ONNX op converters.
    Each converter maps to a single ONNX op and defines the equivalent
    functionality using Relax expressions. The converter can define multiple versions
    of the op and the version is selected based on the opset version of the model.
    """

    @classmethod
    def get_converter(cls, opset):
        """Get the converter matching the given opset.

        Parameters
        ----------
        opset: int
            The opset from the model.

        Returns
        -------
        converter, which should be `_impl_vx`, where x is the largest supported
            version that is smaller than or equal to the opset of the model.
        """
        versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
        versions = sorted(versions + [opset])
        version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, "_impl_v{}".format(version)):
            return getattr(cls, "_impl_v{}".format(version))
        raise NotImplementedError(
            "opset version {} of {} not implemented".format(version, cls.__name__)
        )


class MatMul(OnnxOpConverter):
    """Converts an onnx MatMul node into an equivalent Relax expression."""

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        return relax.op.matmul(inputs[0], inputs[1])


def _to_numpy(x):
    if isinstance(x, relax.PrimValue):
        x = x.value
        if isinstance(x, (tir.IntImm, tir.FloatImm)):
            x = x.value
        return _np.array(x)
    else:
        return x.data.numpy()


class BinaryBase(OnnxOpConverter):
    """Converts an onnx BinaryBase node into an equivalent Relax expression."""

    numpy_op: Callable = None
    relax_op: Callable = None

    @classmethod
    def base_impl(cls, bb, inputs, attr, params):
        """Base implementation for binary operations."""
        if cls.numpy_op is None or cls.relax_op is None:
            raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
        if all([isinstance(inp, relax.Constant) for inp in inputs]):
            output = cls.numpy_op(  # pylint: disable=not-callable
                inputs[0].data.numpy(), inputs[1].data.numpy()
            )
            return relax.const(output, inputs[0].struct_info.dtype)
        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
            x = _to_numpy(inputs[0])
            y = _to_numpy(inputs[1])
            return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable

        return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable


class Add(BinaryBase):
    """Converts an onnx Add node into an equivalent Relax expression."""

    numpy_op = _np.add
    relax_op = relax.op.add

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        return cls.base_impl(bb, inputs, attr, params)


class Sub(BinaryBase):
    """Converts an onnx Sub node into an equivalent Relax
expression."""numpy_op=_np.subtractrelax_op=relax.op.subtract@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classMul(BinaryBase):"""Converts an onnx Mul node into an equivalent Relax expression."""numpy_op=_np.multiplyrelax_op=relax.op.multiply@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classDiv(BinaryBase):"""Converts an onnx Div node into an equivalent Relax expression."""numpy_op=_np.dividerelax_op=relax.op.divide@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classPow(BinaryBase):"""Converts an onnx Pow node into an equivalent Relax expression."""numpy_op=_np.powerrelax_op=relax.op.power@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classMod(BinaryBase):"""Converts an onnx Mod node into an equivalent Relax expression."""numpy_op=_np.modrelax_op=relax.op.mod@classmethoddef_impl_v10(cls,bb,inputs,attr,params):ifattr.get("fmod",0)==0:cls.numpy_op=_np.fmodcls.relax_op=relax.op.floor_modelse:cls.numpy_op=_np.modcls.relax_op=relax.op.modreturncls.base_impl(bb,inputs,attr,params)classAnd(BinaryBase):"""Converts an onnx And node into an equivalent Relax expression."""numpy_op=_np.logical_andrelax_op=relax.op.logical_and@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classOr(BinaryBase):"""Converts an onnx Or node into an equivalent Relax expression."""numpy_op=_np.logical_orrelax_op=relax.op.logical_or@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classXor(BinaryBase):"""Converts an onnx Xor node into an equivalent Relax expression."""numpy_op=_np.logical_xorrelax_op=relax.op.logical_xor@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classLess(BinaryBase):"""Converts an onnx Less node into an equivalent Relax expression."""numpy_op=_np.lessrelax_op=relax.op.less@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classLessOrEqual(BinaryBase):"""Converts an onnx LessEqual node into an equivalent Relax expression."""numpy_op=_np.less_equalrelax_op=relax.op.less_equal@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classGreater(BinaryBase):"""Converts an onnx Greater node into an equivalent Relax expression."""numpy_op=_np.greaterrelax_op=relax.op.greater@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classGreaterOrEqual(BinaryBase):"""Converts an onnx GreaterEqual node into an equivalent Relax expression."""numpy_op=_np.greater_equalrelax_op=relax.op.greater_equal@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classEqual(OnnxOpConverter):"""Converts an onnx Equal node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):ifall([isinstance(inp,relax.Constant)forinpininputs]):output=inputs[0].data.numpy()==inputs[1].data.numpy()returnrelax.const(output,output.dtype)elifall([isinstance(inp,(relax.Constant,relax.ShapeExpr))forinpininputs]):lhs=get_prim_expr_list(inputs[0])rhs=get_prim_expr_list(inputs[1])iflen(lhs)!=len(rhs):raiseValueError("Cannot compare two tensors with different shapes")output=[tvm.ir.structural_equal(l,r)forl,rinzip(lhs,rhs)]returnrelax.const(output,"bool")returnrelax.op.equal(inputs[0],inputs[1])classBitwiseBase(BinaryBase):"""Converts an onnx 
BitwiseBase node into an equivalent Relax expression."""@classmethoddefbase_impl(cls,bb,inputs,attr,params):"""Base implementation for bitwise operations."""valid_types=["int8","int16","int32","int64","uint8","uint16","uint32","uint64"]fornum,inpinenumerate(inputs):ifinp.struct_info.dtypenotinvalid_types:raiseValueError(f"Bitwise operations expect all inputs to have integer types, "f"got {inp.struct_info.dtype} for input {num}")returnsuper().base_impl(bb,inputs,attr,params)classBitwiseAnd(BitwiseBase):"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""numpy_op=_np.bitwise_andrelax_op=relax.op.bitwise_and@classmethoddef_impl_v18(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classBitwiseOr(BitwiseBase):"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""numpy_op=_np.bitwise_orrelax_op=relax.op.bitwise_or@classmethoddef_impl_v18(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classBitwiseXor(BitwiseBase):"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""numpy_op=_np.bitwise_xorrelax_op=relax.op.bitwise_xor@classmethoddef_impl_v18(cls,bb,inputs,attr,params):returncls.base_impl(bb,inputs,attr,params)classBitwiseNot(OnnxOpConverter):"""Converts an onnx BitwiseNot node into an equivalent Relax expression."""@classmethoddef_impl_v18(cls,bb,inputs,attr,params):ifisinstance(inputs[0],relax.Constant):returnrelax.const(_np.bitwise_not(inputs[0].data.numpy()),inputs[0].struct_info.dtype)returnrelax.op.bitwise_not(inputs[0])classBitShift(BitwiseBase):"""Converts an onnx BitShift node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):direction=attr.get("direction","LEFT").decode("ascii")ifdirection=="LEFT":cls.numpy_op=_np.left_shiftcls.relax_op=relax.op.left_shiftelifdirection=="RIGHT":cls.numpy_op=_np.right_shiftcls.relax_op=relax.op.right_shiftelse:raiseValueError("Unsupported Shift Direction: "+direction)returncls.base_impl(bb,inputs,attr,params)classSigmoid(OnnxOpConverter):"""Converts an onnx Sigmoid node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):returnrelax.op.sigmoid(inputs[0])classSoftmax(OnnxOpConverter):"""Converts an onnx Softmax node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",-1)returnrelax.op.nn.softmax(inputs[0],axis=axis)classLogSoftmax(OnnxOpConverter):"""Converts an onnx LogSoftmax node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",-1)returnrelax.op.nn.log_softmax(inputs[0],axis=axis)classHardmax(OnnxOpConverter):"""Converts an onnx Hardmax node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",-1)indices=inputs[0]dtype=indices.struct_info.dtypeaxis_len=int(inputs[0].struct_info.shape[axis])argmax=relax.op.argmax(indices,axis=axis)on_value=relax.PrimValue(tvm.tir.const(1.0,dtype))off_value=relax.PrimValue(tvm.tir.const(0.0,dtype))one_hot=relax.op.one_hot(argmax,on_value,off_value,axis_len,axis)returnone_hotclassTranspose(OnnxOpConverter):"""Converts an onnx Transpose node into an equivalent Relax 
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axes=attr.get("perm",None)ifisinstance(inputs[0],relax.Constant):output=_np.transpose(inputs[0].data.numpy(),axes)returnrelax.const(output,output.dtype)returnrelax.op.permute_dims(inputs[0],axes)classUnsqueeze(OnnxOpConverter):"""Converts an onnx Unsqueeze node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):axes=list(attr.get("axes"))inputs=inputs+[relax.const(axes,"int64")]returncls._impl_v13(bb,inputs,attr,params)@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=get_constant(inputs[1],params)# Handle ONNX shape inferenceifisinstance(data,relax.PrimValue)andisinstance(axes,relax.Constant):axes=axes.data.numpy().tolist()ifaxes==[0]:returnrelax.ShapeExpr([data.value])else:raiseNotImplementedError("Unsqueeze with symbolic axes and non-zero axes is not supported.")# If input is a constant, compute directlyifisinstance(data,relax.Constant)andisinstance(axes,relax.Constant):axes=axes.data.numpy().tolist()expanded=data.data.numpy()iflen(expanded.shape)==0:# Special case implying input is a scalar, wrap it as a list.if0inaxes:axes.remove(0)expanded=[expanded]foraxisinaxes:expanded=_np.expand_dims(expanded,axis=axis)returnrelax.const(expanded,data.struct_info.dtype)ifisinstance(axes,relax.Constant):constant_axes=list(axes.data.numpy())constant_axes=list(map(int,constant_axes))constant_axes=sorted(constant_axes)foraxisinconstant_axes:data=relax.op.expand_dims(data,axis=axis)returndataraiseNotImplementedError("Unsqueeze with dynamic axes is not supported.")classConcat(OnnxOpConverter):"""Convert an onnx Concat node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",0)defis_shape_like(x:Any)->bool:ifisinstance(x,relax.ShapeExpr):returnTrueelifisinstance(x,relax.Constant):returnx.struct_info.ndim==1andx.struct_info.dtype=="int64"else:returnFalse# If all inputs are shape expr, perform computation directly.ifall([is_shape_like(inp)forinpininputs]):const_inputs=[]forinpininputs:ifisinstance(inp,relax.ShapeExpr):const_inputs.extend(inp.values)elifisinstance(inp,relax.Constant):const_inputs.extend(inp.data.numpy().tolist())else:raiseNotImplementedError("Unsupported input type: {}".format(type(inp)))returnrelax.ShapeExpr(const_inputs)# If all inputs are constant, perform computation directly.ifall([isinstance(inp,relax.Constant)forinpininputs]):const_inputs=[]forinpininputs:const_inputs.append(inp.data.numpy())out=_np.concatenate(const_inputs,axis=axis)dtype=inputs[0].struct_info.dtypereturnrelax.const(out,dtype)returnrelax.op.concat(inputs,axis=axis)classCast(OnnxOpConverter):"""Convert an onnx Cast node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):to_type=get_type(attr["to"])ifisinstance(inputs[0],relax.ShapeExpr):shape=inputs[0]ifall([isinstance(x,tir.IntImm)forxinshape]):shape=[int(x)forxinshape]returnrelax.const(shape,to_type)ifisinstance(inputs[0],relax.Constant):output=inputs[0].data.numpy().astype(to_type)returnrelax.const(output,to_type)ifisinstance(inputs[0],relax.PrimValue):returnrelax.PrimValue(inputs[0].value.astype(to_type))returnrelax.op.astype(inputs[0],to_type)classGather(OnnxOpConverter):"""Convert an onnx Gather node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):# Unpack inputsdata=inputs[0]indices=inputs[1]axis=attr.get("axis",0)# If all inputs are constant, we can compute 
directly.ifall([isinstance(inp,relax.Constant)forinpin[data,indices]]):output=_np.take(data.data.numpy(),indices.data.numpy(),axis=axis)returnrelax.const(output,output.dtype)# If input is a shape expression, take a value from that shape and return it as a constant.ifisinstance(data,relax.ShapeExpr):assertisinstance(indices,relax.Constant),"Only constant indices supported for shape gather."np_index=indices.data.numpy()iflen(np_index.shape)==1:np_index=np_index[0]np_index=int(np_index)shape_val=data[np_index]returnrelax.PrimValue(shape_val)returnrelax.op.take(data,indices,axis)classGatherElements(OnnxOpConverter):"""Convert an onnx GatherElements node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",0)returnrelax.op.gather_elements(inputs[0],inputs[1],axis)classGatherND(OnnxOpConverter):"""Convert an onnx GatherND node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):batch_dims=attr.get("batch_dims",0)returnrelax.op.gather_nd(inputs[0],inputs[1],batch_dims)classScatter(OnnxOpConverter):"""Convert an onnx Scatter node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):axis=attr.get("axis",0)returnrelax.op.scatter_elements(inputs[0],inputs[1],inputs[2],axis=axis)@classmethoddef_impl_v11(cls,bb,inputs,attr,params):raiseValueError("Scatter is deprecated in ONNX 11")classScatterElements(OnnxOpConverter):"""Convert an onnx ScatterElements node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):axis=attr.get("axis",0)returnrelax.op.scatter_elements(inputs[0],inputs[1],inputs[2],axis=axis)classScatterND(OnnxOpConverter):"""Convert an onnx ScatterND node into an equivalent Relax expression."""@staticmethoddef_reduction_check(attr,valid_reductions:List[str]):reduction=attr.get("reduction",None)reduction=reductionorb"update"reduction=reduction.decode("utf-8")reduction="update"ifreduction=="none"elsereductionassert(reductioninvalid_reductions),f"Only {valid_reductions} reductions are supported, but {reduction} is gotten"returnreduction@classmethoddef_impl_v11(cls,bb,inputs,attr,params):returnrelax.op.scatter_nd(inputs[0],inputs[1],inputs[2])@classmethoddef_impl_v16(cls,bb,inputs,attr,params):reduction=cls._reduction_check(attr,["update","add","mul"])returnrelax.op.scatter_nd(inputs[0],inputs[1],inputs[2],reduction)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):reduction=cls._reduction_check(attr,["update","add","mul","min","max"])returnrelax.op.scatter_nd(inputs[0],inputs[1],inputs[2],reduction)classCompress(OnnxOpConverter):"""Convert an onnx Compress node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):tensor,condition=inputsaxis=attr.get("axis",None)# Change one hot tensor to indices e.g. 
[0, 1, 1, 0, 1] -> [1, 2, 4]ifcondition.struct_info.dtype!="bool":raiseValueError("Condition tensor is expected to be a boolean tensor")ifcondition.struct_info.ndim!=1:raiseValueError("Condition tensor is expected to be a 1D boolean tensor")indices=relax.op.nonzero(condition)num_nonzero=tir.Var("num_nonzero","int64")indices=bb.match_cast(indices,relax.TensorStructInfo([1,num_nonzero],"int64"))indices=relax.op.reshape(indices,[-1])ifaxisisnotNone:returnrelax.op.take(tensor,indices,axis=axis)# if axis is None, flatten input tensor before selectiontensor=relax.op.reshape(tensor,(-1,))returnrelax.op.take(tensor,indices,axis=0)classSize(OnnxOpConverter):"""Convert an onnx Size node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):# TODO(tvm-team): add native support for size opreturnrelax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))classEyeLike(OnnxOpConverter):"""Convert an onnx EyeLike node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):k=attr.get("k",0)input_dtype=inputs[0].struct_info.dtypeif"dtype"inattrandget_type(attr["dtype"])!=input_dtype:raiseValueError(f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})")returnrelax.op.eye_like(inputs[0],k,input_dtype)classGemm(OnnxOpConverter):"""Convert an onnx Gemm node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):alpha=attr.get("alpha",None)beta=attr.get("beta",None)transA=attr.get("transA",False)transB=attr.get("transB",False)A=inputs[0]B=inputs[1]C=inputs[2]dtype=A.struct_info.dtype# Compute Y = alpha * A X B + beta * CifalphaisnotNoneandalpha!=1.0:A=relax.op.multiply(A,relax.const(alpha,dtype=dtype))iftransA:A=relax.op.permute_dims(A,[1,0])iftransB:B=relax.op.permute_dims(B,[1,0])Y=relax.op.matmul(A,B)ifCisnotNone:ifbetaisnotNoneandbeta!=1.0:C=relax.op.multiply(C,relax.const(beta,dtype=dtype))Y=relax.op.add(Y,C)returnYclassReshape(OnnxOpConverter):"""Convert an onnx Reshape node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]new_shape=get_constant(inputs[1],params)ifisinstance(data,relax.ShapeExpr)andisinstance(new_shape,relax.Constant):new_shape=new_shape.data.numpy().tolist()ifnew_shape!=[-1]:raiseNotImplementedError("Need to fix this case")returndataifisinstance(data,relax.Constant)andisinstance(new_shape,relax.Constant):out=_np.reshape(data.data.numpy(),new_shape.data.numpy().tolist())returnrelax.const(out,out.dtype)ifisinstance(new_shape,relax.Constant):new_shape=new_shape.data.numpy().tolist()out=relax.op.reshape(data,new_shape)returnoutclassWhere(OnnxOpConverter):"""Convert an onnx Where node into an equivalent Relax expression."""@classmethoddef_impl_v16(cls,bb,inputs,attr,params):ifall([isinstance(inp,relax.Constant)forinpininputs]):np_inputs=[inp.data.numpy()forinpininputs]output=_np.where(*np_inputs)returnrelax.const(output,output.dtype)ifall([isinstance(inp,(relax.Constant,relax.ShapeExpr))forinpininputs]):condition,x,y=[get_prim_expr_list(inp)forinpininputs]iflen(condition)!=len(x)orlen(condition)!=len(y):raiseValueError("Cannot broadcast condition to x and y")output=[xifcelseyforc,x,yinzip(condition,x,y)]returnrelax.ShapeExpr(output)returnrelax.op.where(inputs[0],inputs[1],inputs[2])classClip(OnnxOpConverter):"""Converts an onnx Clip node into an equivalent Relax 
expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):min=float(attr.get("min",-_np.inf))max=float(attr.get("max",_np.inf))results=inputs[0]results=bb.emit_te(topi.maximum,results,min)results=bb.emit_te(topi.minimum,results,max)returnresults@classmethoddef_impl_v13(cls,bb,inputs,attr,params):results=inputs[0]ifinputs[1]isnotNone:results=bb.emit_te(topi.maximum,results,inputs[1])ifinputs[2]isnotNone:results=bb.emit_te(topi.minimum,results,inputs[2])returnresultsclassShape(OnnxOpConverter):"""Converts an onnx Equal node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data_info=inputs[0].struct_infoifisinstance(data_info,relax.ShapeStructInfo):ifdata_info.ndim==-1:raiseValueError("The ndim of ShapeExpr is expected to a real number, but got -1.")returnrelax.ShapeExpr([data_info.ndim])# If no shape is defined in the struct info, it must be computed at runtime.ifnotdata_info.shape:data_shape=bb.normalize(relax.op.shape_of(inputs[0]))returndata_shapereturndata_info.shapeclassTrilu(OnnxOpConverter):"""Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s) """@classmethoddef_impl_v14(cls,bb,inputs,attr,params):upper=attr.get("upper",True)x=inputs[0]k=inputs[1]iflen(inputs)>1else0iflen(inputs)>1:k=get_constant(inputs[1],params)ifisinstance(k,relax.Constant):k=int(k.data.numpy().item())else:raiseValueError("Currently only support constant k for Trilu op.")else:k=0ifupper:returnrelax.op.triu(x,k)else:returnrelax.op.tril(x,k)classRelu(OnnxOpConverter):"""Converts an onnx Relu node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):returnrelax.op.nn.relu(inputs[0])classElu(OnnxOpConverter):"""Converts an onnx Elu node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):alpha=float(attr.get("alpha",1.0))returnrelax.expr.const(-alpha)*relax.op.nn.relu(relax.expr.const(1.0)-relax.op.exp(inputs[0]))+relax.op.nn.relu(inputs[0])classSelu(OnnxOpConverter):"""Converts an onnx Selu node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):alpha=attr.get("alpha",1.67326319217681884765625)gamma=attr.get("gamma",1.05070102214813232421875)returnrelax.const(gamma)*(relax.const(-alpha)*relax.op.nn.relu(relax.const(1.0)-relax.op.exp(inputs[0]))+relax.op.nn.relu(inputs[0]))classMish(OnnxOpConverter):"""Converts an onnx Mish node into an equivalent Relax expression. mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) """@classmethoddef_impl_v18(cls,bb,inputs,attr,params):dtype=inputs[0].struct_info.dtypereturninputs[0]*relax.op.tanh(relax.op.log(relax.const(1.0,dtype)+relax.op.exp(inputs[0])))classPRelu(OnnxOpConverter):"""Converts an onnx PRelu node into an equivalent Relax expression. f(x) = slope * x for x < 0, x for x >= 0 """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):x=inputs[0]slope=inputs[1]# TODO(tvm-team): Should add a new op for this.returnx*slope+relax.op.nn.relu(x)*(relax.const(1.0)-slope)classThresholdedRelu(OnnxOpConverter):"""Converts an onnx ThresholdedRelu node into an equivalent Relax expression. f(x) = x for x > alpha, 0 otherwise """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):x=inputs[0]alpha=attr.get("alpha",1.0)returnrelax.op.greater(x,relax.const(alpha)).astype("float32")*xclassLeakyRelu(OnnxOpConverter):"""Converts an onnx LeakyRelu node into an equivalent Relax expression. 
f(x) = x for x > 0, alpha * x otherwise """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):x=inputs[0]alpha=attr.get("alpha",0.01)returnrelax.op.nn.leakyrelu(x,alpha)classGelu(OnnxOpConverter):"""Operator converter for Gelu from Microsoft onnxruntime contrib opset. gelu(x) = 0.5x(1 + erf(x/sqrt(2))) """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.nn.gelu(inputs[0])classFastGelu(OnnxOpConverter):"""Operator converter for FastGelu from Microsoft onnxruntime contrib opset. fast_gelu(x) = 0.5x(1 + tanh(sqrt(2/pi)(x + 0.044715x^3))) = 0.5x(1 + tanh((sqrt(2/pi)x + 0.044715(sqrt(2/pi)x^3))) = 0.5x(1 + tanh(c1 * x + c2 * x^3))) , where c1 = sqrt(2/pi) c2 = 0.044715 * sqrt(2/pi) """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):ifinputs[1]:bias=inputs[1]bias_shape=bias.struct_info.shapeassertlen(bias_shape)==1,"bias term must be a 1D tensor"x+=bias# Declare constsconst_dtype=x.struct_info.dtypehalf=relax.const(0.5,dtype=const_dtype)one=relax.const(1.0,dtype=const_dtype)const1=relax.const(math.sqrt(2/math.pi),dtype=const_dtype)const2=relax.const(0.044715*math.sqrt(2/math.pi),dtype=const_dtype)# Compute FastGeluterm1=relax.op.multiply(half,x)term2=relax.op.multiply(const1,x)term3=relax.op.multiply(const2,relax.op.power(x,relax.const(3,const_dtype)))tanh=relax.op.tanh(relax.op.add(term2,term3))returnrelax.op.multiply(term1,relax.op.add(one,tanh))classBiasGelu(OnnxOpConverter):"""Operator converter for BiasGelu from Microsoft onnxruntime contrib opset. bias_gelu(x, b) = 0.5(x + b)(1 + erf((x + b)/sqrt(2))) """@classmethoddef_impl_v1(cls,bb,inputs,attr,params):inp=relax.op.add(inputs[0],inputs[1])returnrelax.op.nn.gelu(inp)classShrink(OnnxOpConverter):"""Converts an onnx Shrink node into an equivalent Relax expression. f(x) = x + bias if x > lambd, x - bias if x < -lambd, 0 otherwise """@classmethoddef_impl_v9(cls,bb,inputs,attr,params):x=inputs[0]dtype=x.struct_info.dtypelambd=relax.const(attr.get("lambd",0.5),dtype)bias=relax.const(attr.get("bias",0.0),dtype)zeros=relax.op.zeros_like(x)returnrelax.op.where(x>lambd,x-bias,zeros)+relax.op.where(x<-lambd,x+bias,zeros)classConv(OnnxOpConverter):"""Convert an onnx Conv node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]ifhasattr(inputs[0].struct_info,"ndim"):ndim=inputs[0].struct_info.ndimelse:ndim=len(inputs[0].struct_info.shape)if"kernel_shape"notinattr:attr["kernel_shape"]=inputs[1].struct_info.shape.values[2:]ifndim==3:op=relax.op.nn.conv1ddata_layout="NCW"kernel_layout="OIW"elifndim==4:op=relax.op.nn.conv2ddata_layout="NCHW"kernel_layout="OIHW"elifndim==5:op=relax.op.nn.conv3ddata_layout="NCDHW"kernel_layout="OIDHW"else:raiseNotImplementedError("Ndim > 5 not supported for convolution.")if"auto_pad"inattr:attr["auto_pad"]=attr["auto_pad"].decode("utf-8")ifattr["auto_pad"]in("SAME_UPPER","SAME_LOWER"):data=autopad(bb,inputs[0],attr.get("strides",[1]*(ndim-2)),attr["kernel_shape"],attr.get("dilations",[1]*(ndim-2)),mode=attr["auto_pad"],deconv=False,)elifattr["auto_pad"]=="VALID":attr["pads"]=[0for_inrange(ndim-2)]elifattr["auto_pad"]=="NOTSET":passelse:msg=(f'Value {attr["auto_pad"]} in attribute "auto_pad" of operator Conv 'f"is 
invalid.")raisetvm.error.OpAttributeInvalid(msg)attr.pop("auto_pad")conv_out=bb.normalize(op(data=data,weight=inputs[1],strides=attr.get("strides",1),padding=attr.get("pads",0),dilation=attr.get("dilations",1),groups=attr.get("group",1),data_layout=data_layout,kernel_layout=kernel_layout,))ifinputs[2]isnotNone:bias=relax.op.reshape(inputs[2],[1,-1]+[1]*(ndim-2))conv_out=relax.op.add(conv_out,bias)returnconv_outclassConvTranspose(OnnxOpConverter):"""Converts an onnx ConvTranspose node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):ifhasattr(inputs[0].struct_info,"ndim"):ndim=inputs[0].struct_info.ndimelse:ndim=len(inputs[0].struct_info.shape)ifndim==3:op=relax.op.nn.conv1d_transposedata_layout="NCW"kernel_layout="IOW"elifndim==4:op=relax.op.nn.conv2d_transposedata_layout="NCHW"kernel_layout="IOHW"elifndim==5:raiseNotImplementedError("Relax ConvTranspose3d not supported yet")else:raiseNotImplementedError("Ndim > 5 not supported for convolution.")conv_out=op(data=inputs[0],weight=inputs[1],strides=attr.get("strides",1),padding=attr.get("pads",0),dilation=attr.get("dilations",1),groups=attr.get("group",1),data_layout=data_layout,kernel_layout=kernel_layout,)ifinputs[2]isnotNone:bias=relax.op.reshape(inputs[2],[1,-1]+[1]*(ndim-2))conv_out=relax.op.add(conv_out,bias)returnconv_outclassErf(OnnxOpConverter):"""Converts an onnx Erf node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):returnrelax.op.erf(inputs[0])classCumSum(OnnxOpConverter):"""Converts an onnx CumSum node into an equivalent Relax expression."""@classmethoddef_impl_v14(cls,bb,inputs,attr,params):data=inputs[0]axis=get_constant(inputs[1],params)assertnotattr.get("exclusive",False),"Exclusive option not yet supported."ifisinstance(axis,relax.Constant):axis=int(axis.data.numpy())elifisinstance(axis,relax.Var):axis=0ifattr.get("reverse",0)!=0:data=bb.emit_te(topi.flip,data,axis=axisifaxiselse0)data=relax.op.cumsum(data,axis)data=bb.normalize(data)ifattr.get("reverse",0)!=0:data=bb.emit_te(topi.flip,data,axis=axisifaxiselse0)returndataclassSqueeze(OnnxOpConverter):"""Converts an onnx Squeeze node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axis=get_constant(inputs[1],params)ifisinstance(axis,relax.Constant):axis=tuple([int(x)forxinaxis.data.numpy()])# If data is constant, perform computation directly.ifisinstance(data,relax.Constant):ifisinstance(axis,(tuple,type(None))):out_data=_np.squeeze(data.data.numpy(),axis)else:raiseNotImplementedError("Squeeze with symbolic axes not supported")returnrelax.const(out_data,data.struct_info.dtype)ifisinstance(data,relax.ShapeExpr):ifaxis==(0,):returnrelax.PrimValue(data[0])else:raiseNotImplementedError("Squeeze with symbolic axes and non-zero axes is not supported.")returnrelax.op.squeeze(data,axis)classConstant(OnnxOpConverter):"""Converts an onnx Constant node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):if"value"notinattr:raiseValueError("no value in Constant")value=attr.pop("value")# Constants may rarely have string types. These are likely exported# from other frameworks and not actually used in TVM. 
We'll just use# a zero valued constant for compatibility.ifisinstance(value,bytes):np_value=_np.asarray([0]).astype("int64")else:np_value=get_numpy(value)dtype=np_value.dtype.namevalue=relax.const(np_value,dtype)returnvalueclassConstantOfShape(OnnxOpConverter):"""Converts an onnx ConstantOfShape node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):shape=inputs[0]value=get_numpy(attr.get("value",0))ifisinstance(value,_np.ndarray):dtype=str(value.dtype)else:dtype="float32"# If shape is a constant, treat it as a ShapeExpr.ifisinstance(shape,relax.Constant):shape=relax.ShapeExpr(list(shape.data.numpy()))# Special case where requested shape are constantiflen(shape)==1andall([isinstance(x,tir.IntImm)forxinshape]):shape=[int(x)forxinshape]returnrelax.const(_np.full(shape,value,dtype),dtype)# Convert to shape expression from tensor if needed.ifnotisinstance(shape,relax.ShapeExpr):shape=relax.op.tensor_to_shape(shape)returnrelax.op.broadcast_to(relax.const(value,dtype),shape)classSin(OnnxOpConverter):"""Converts an onnx Sin node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.sin(inputs[0])classSinh(OnnxOpConverter):"""Converts an onnx Sinh node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.sinh(inputs[0])classCos(OnnxOpConverter):"""Converts an onnx Cos node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.cos(inputs[0])classCosh(OnnxOpConverter):"""Converts an onnx Cosh node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.cosh(inputs[0])classTan(OnnxOpConverter):"""Converts an onnx Tan node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.tan(inputs[0])classTanh(OnnxOpConverter):"""Converts an onnx Tanh node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.tanh(inputs[0])classAcos(OnnxOpConverter):"""Converts an onnx Acos node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.acos(inputs[0])classAcosh(OnnxOpConverter):"""Converts an onnx Acosh node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.acosh(inputs[0])classAsin(OnnxOpConverter):"""Converts an onnx Asin node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.asin(inputs[0])classAsinh(OnnxOpConverter):"""Converts an onnx Asinh node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.asinh(inputs[0])classAtan(OnnxOpConverter):"""Converts an onnx Atan node into an equivalent Relax expression."""@classmethoddef_impl_v7(cls,bb,inputs,attr,params):returnrelax.op.atan(inputs[0])classAtanh(OnnxOpConverter):"""Converts an onnx Atanh node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.atanh(inputs[0])classNeg(OnnxOpConverter):"""Converts an onnx Neg node into an equivalent Relax 
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):ifisinstance(inputs[0],relax.Constant):data_np=inputs[0].data.numpy()returnrelax.const(_np.negative(data_np),inputs[0].struct_info.dtype)ifisinstance(inputs[0],relax.PrimValue):returnrelax.PrimValue(-inputs[0].value)returnrelax.op.negative(inputs[0])classAbs(OnnxOpConverter):"""Converts an onnx Abs node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):ifisinstance(inputs[0],relax.Constant):output=_np.abs(inputs[0].data.numpy())returnrelax.const(output,output.dtype)returnrelax.op.abs(inputs[0])classReciprocal(OnnxOpConverter):"""Converts an onnx Reciprocal node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):input_dtype=inputs[0].struct_info.dtypereturnrelax.op.divide(relax.const(1,dtype=input_dtype),inputs[0])classFloor(OnnxOpConverter):"""Converts an onnx Floor node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.floor(inputs[0])classCeil(OnnxOpConverter):"""Converts an onnx Ceil node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.ceil(inputs[0])classRound(OnnxOpConverter):"""Converts an onnx Round node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.round(inputs[0])classIsInf(OnnxOpConverter):"""Converts an onnx IsInf node into an equivalent Relax expression."""@classmethoddef_impl_v10(cls,bb,inputs,attr,params):returnrelax.op.isinf(inputs[0])classIsNaN(OnnxOpConverter):"""Converts an onnx IsNaN node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.isnan(inputs[0])classSqrt(OnnxOpConverter):"""Converts an onnx Sqrt node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.sqrt(inputs[0])classMultiInputBase(OnnxOpConverter):"""Converts an onnx MultiInputBase node into an equivalent Relax expression."""numpy_op:Callable=Nonerelax_op:Callable=None@classmethoddef_impl_v1(cls,bb,inputs,attr,params):ifcls.numpy_opisNoneorcls.relax_opisNone:raiseNotImplementedError("numpy_op and relax_op must be defined for MultiInputBase")ifall([isinstance(inp,relax.Constant)forinpininputs]):np_inputs=[inp.data.numpy()forinpininputs]output=cls.numpy_op(*np_inputs)# pylint: disable=not-callablereturnrelax.const(output,output.dtype)# Expand inputs, stack them, then perform minimum over the new axis.inputs=[bb.normalize(relax.op.expand_dims(i,axis=0))foriininputs]stacked_tensor=relax.op.concat(inputs,axis=0)returncls.relax_op(stacked_tensor,axis=0)# pylint: disable=not-callableclassMin(MultiInputBase):"""Converts an onnx Min node into an equivalent Relax expression."""numpy_op=_np.minrelax_op=relax.op.minclassMax(MultiInputBase):"""Converts an onnx Max node into an equivalent Relax expression."""numpy_op=_np.maxrelax_op=relax.op.maxclassMean(MultiInputBase):"""Converts an onnx Mean node into an equivalent Relax expression."""numpy_op=_np.meanrelax_op=relax.op.meanclassSum(MultiInputBase):"""Converts an onnx Sum node into an equivalent Relax expression."""numpy_op=_np.sumrelax_op=relax.op.sumclassLog(OnnxOpConverter):"""Converts an onnx Log node into an equivalent Relax 
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):ifisinstance(inputs[0],relax.Constant):returnrelax.const(_np.log(inputs[0].data.numpy()),inputs[0].struct_info.dtype)returnrelax.op.log(inputs[0])classExp(OnnxOpConverter):"""Converts an onnx Exp node into an equivalent Relax expression."""@classmethoddef_check_type(cls,dtype,valid_types):assertdtypeinvalid_types,"Types {} are supported only, but {} is given".format(valid_types,dtype)@classmethoddef_impl_v1(cls,bb,inputs,attr,params):data=inputs[0]valid_types=["float","float32","double","float64","float16"]cls._check_type(data.struct_info.dtype,valid_types)returnrelax.op.exp(data)@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]valid_types=["float","float32","double","float64","float16","bfloat16"]cls._check_type(data.struct_info.dtype,valid_types)returnrelax.op.exp(data)classSoftplus(OnnxOpConverter):"""Converts an onnx Softplus node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):dtype=inputs[0].struct_info.dtypereturnrelax.op.log(relax.op.exp(inputs[0])+relax.const(1,dtype=dtype))classSoftsign(OnnxOpConverter):"""Converts an onnx Softsign node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):dtype=inputs[0].struct_info.dtypereturninputs[0]/(relax.op.abs(inputs[0])+relax.const(1,dtype=dtype))classSplit(OnnxOpConverter):"""Converts an onnx Split node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):splits=attr.get("split",None)ifsplitsisnotNoneandlen(splits)>1:indices=[]index=0foriinsplits[:-1]:index+=iindices.append(index)# When splits isnt specified divide evenly over axis.else:indices=attr["tvm_custom"]["num_outputs"]returnrelax.op.split(inputs[0],indices,attr.get("axis",0))@classmethoddef_impl_v13(cls,bb,inputs,attr,params):splits=inputs[1]splits_rank=NoneifsplitsisnotNone:splits_rank=splits.struct_info.ndimifsplitsisnotNoneandsplits_rank>0:ifisinstance(splits,relax.Constant):splits=splits.data.numpy()indices=[]index=0foriinsplits[:-1]:index+=iindices.append(index.item())else:raiseValueError("Dynamic Split not yet supported")# When splits isnt specified divide evenly over axis.else:indices=attr["tvm_custom"]["num_outputs"]returnrelax.op.split(inputs[0],indices,attr.get("axis",0))defget_prim_value_list(values):new_values=[]forvinlist(values):ifisinstance(v,relax.expr.PrimExpr):new_values.append(relax.PrimValue(v))else:new_values.append(v)returnnew_valuesclassSlice(OnnxOpConverter):"""Converts an onnx Splice node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):# TODO (jwfromm) currently only supports constant parameters.data=inputs[0]starts=get_constant(inputs[1],params)ends=get_constant(inputs[2],params)axes=get_constant(inputs[3],params)steps=get_constant(inputs[4],params)ifnotall([(isinstance(param,(relax.Constant,relax.ShapeExpr,relax.PrimValue))orparamisNone)forparamin[starts,ends,axes,steps]]):raiseValueError("Only constant Slice parameters are currently supported.")# Convert parameters to constant lists.starts=get_prim_expr_list(starts)ends=get_prim_expr_list(ends)ifaxesisnotNone:axes=get_prim_expr_list(axes)else:axes=list(range(len(starts)))# Convert negative axis to positive if needed.fori,axisinenumerate(axes):ifaxis<0:axes[i]=axis+len(data.struct_info.shape)ifstepsisnotNone:steps=get_prim_expr_list(steps)else:steps=[1]*len(axes)# If input is a shape tensor, we can directly extract 
it.ifisinstance(data,relax.ShapeExpr):shape_data=list(data)# Starts, ends, and steps must be 1-d for shape operation.assertall(len(i)==1foriin[starts,ends,steps])sliced_values=shape_data[starts[0]:ends[0]:steps[0]]ifall([isinstance(val,(tir.IntImm,int))forvalinsliced_values]):returnrelax.const([x.valueforxinsliced_values],"int64")else:returnrelax.ShapeExpr(sliced_values)# If all `starts`, `ends`, and `steps` are constant, use strict mode# Otherwise, we assume the slice is inbound.assume_inbound=notall([isinstance(param,(tir.IntImm,int))forparamin[*starts,*ends,*steps]])# Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExprstarts=get_prim_value_list(starts)ends=get_prim_value_list(ends)steps=get_prim_value_list(steps)returnrelax.op.strided_slice(data,axes,starts,ends,steps,assume_inbound=assume_inbound)classPad(OnnxOpConverter):"""Converts an onnx Pad node into an equivalent Relax expression."""@classmethoddef_impl_v2(cls,bb,inputs,attr,params):pads=attr.get("pads")pads=relax.const(_np.array(pads),inputs[0].struct_info.shape[0].dtype)constant_value=attr.get("value")ifconstant_valueisNone:constant_value=0.0ifisinstance(pads,relax.Constant):pad_before,pad_after=_np.split(pads.data.numpy(),2)pad_before=_np.ndarray.tolist(pad_before)pad_after=_np.ndarray.tolist(pad_after)else:raiseValueError("Dynamic pads are not supported yet.")pad_mode=attr.get("mode",b"constant").decode("utf-8")ifnotpad_modein["constant","edge","reflect"]:raisetvm.error.OpAttributeInvalid("Value "+pad_mode+' in attribute "mode" is invalid for operator Pad.')ifpad_mode=="constant":returnbb.emit_te(topi.nn.pad,inputs[0],pad_before,pad_after,constant_value)elifpad_mode=="reflect":returnbb.emit_te(topi.nn.mirror_pad,inputs[0],pad_before,pad_after,"REFLECT")else:# TODO(gigiblender) Support edge mode.raiseNotImplementedError("Pad mode {} not implemented".format(pad_mode))@classmethoddef_impl_v11(cls,bb,inputs,attr,params):pads=get_constant(inputs[1],params)constant_value=get_constant(inputs[2],params)ifconstant_valueisnotNone:constant_value=constant_value.data.numpy().item()else:constant_value=0.0ifisinstance(pads,relax.Constant):pad_before,pad_after=_np.split(pads.data.numpy(),2)pad_before=_np.ndarray.tolist(pad_before)pad_after=_np.ndarray.tolist(pad_after)else:raiseValueError("Dynamic pads are not supported yet.")pad_mode=attr.get("mode",b"constant").decode("utf-8")ifnotpad_modein["constant","edge","reflect"]:raisetvm.error.OpAttributeInvalid("Value "+pad_mode+' in attribute "mode" is invalid for operator Pad.')ifpad_mode=="constant":returnbb.emit_te(topi.nn.pad,inputs[0],pad_before,pad_after,constant_value)elifpad_mode=="reflect":returnbb.emit_te(topi.nn.mirror_pad,inputs[0],pad_before,pad_after,"REFLECT")else:# TODO(gigiblender) Support edge mode.raiseNotImplementedError("Pad mode {} not implemented".format(pad_mode))classTile(OnnxOpConverter):"""Converts an onnx Tile node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):reps=get_constant(inputs[1],params)ifisinstance(reps,relax.Constant):reps=reps.data.numpy().tolist()else:raiseValueError("Dynamic reps for Tile are supported yet.")returnbb.emit_te(topi.tile,inputs[0],reps)classExpand(OnnxOpConverter):"""Converts an onnx Expand node into an equivalent Relax 
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]shape=inputs[1]ifisinstance(shape,relax.ShapeExpr):data_shape=list(data.struct_info.shape)target_shape=list(shape.values)data_shape=[1]*(len(target_shape)-len(data_shape))+data_shapeassertlen(data_shape)==len(target_shape)# Fix small target shapes or target shapes assigned to -1fori,sinenumerate(target_shape):ifisinstance(s,tvm.tir.IntImm)and((isinstance(data_shape[i],tvm.tir.IntImm)ands<data_shape[i])ors.value==-1):target_shape[i]=data_shape[i]iftarget_shape==data_shape:returndatareturnrelax.op.broadcast_to(data,relax.ShapeExpr(target_shape))# If possible, directly expand to constant shape.ifisinstance(shape,relax.Constant):new_shape=shape.data.numpy().tolist()# ONNX Expand operator requires preserving target rank and broadcasting# according to standard rules. Dimensions are right-aligned.data_shape=[dim.valuefordimindata.struct_info.shape]# Right-align the shapesiflen(new_shape)>len(data_shape):data_shape=[1]*(len(new_shape)-len(data_shape))+data_shapeelse:new_shape=[1]*(len(data_shape)-len(new_shape))+new_shape# Fix small target shapes - if target dim is smaller than input dim# use the input dim (ONNX-specific behavior).foriinrange(len(new_shape)):ifnew_shape[i]<data_shape[i]:new_shape[i]=data_shape[i]returnrelax.op.broadcast_to(data,relax.ShapeExpr(new_shape))# Otherwise handle dynamic shapes.shape_ndim=[dim.valuefordiminshape.struct_info.shape.values][0]shape_dataflow_var=bb.emit(relax.Call(relax.ExternFunc("vm.builtin.tensor_to_shape"),[shape],sinfo_args=[relax.ShapeStructInfo(ndim=shape_ndim)],))shape_vars=[]foriinrange(shape_ndim):shape_vars.append(tvm.tir.Var("x_%d"%i,"int64"))bb.match_cast(shape_dataflow_var,relax.ShapeStructInfo(shape_vars))returnbb.normalize(relax.op.broadcast_to(data,relax.ShapeExpr(shape_vars)))classAttention(OnnxOpConverter):"""Converts an onnx.microsoft Attention node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):num_heads=attr["num_heads"]assert"do_rotary"notinattr,"rotary position embedding is not currently supported"assert("past_present_share_buffer"notinattr),"past state for key and value is not currently supported"assert"scale"notinattr,"custom scale is not currently supported"assert"unidirectional"notinattr,"unidirectional attention is not currently supported"if"mask_filter_value"inattr:mask_filter_value=attr["mask_filter_value"]else:mask_filter_value=-10000.0# (batch_size, sequence_length, input_hidden_size)input_emb=bb.normalize(inputs[0])# (input_hidden_size, hidden_size + hidden_size + v_hidden_size)weight=bb.normalize(inputs[1])defoptional_input(k:int):ifinputs[k]isnotNone:returnbb.normalize(inputs[k])else:returnNone# (hidden_size + hidden_size + v_hidden_size)bias=optional_input(2)# 1. ( batch_size, 1, max_seq_len, max_seq_len,)# 2. ( batch_size, total_seq_len,)# 3. ( batch_size, seq_len, total_seq_len,)# 4. ( batch_size,)# 5. 
(2 * batch_size,)# For now, we only support case 2 & 3.mask_index=optional_input(3)# (2, batch_size, num_heads, past_sequence_length, head_size)assertinputs[4]isNone,"past state for key and value is not currently supported"# (batch_size, num_heads, sequence_length, total_sequence_length)qk_bias=optional_input(5)assertinputs[6]isNone,"past_sequence_length is not currently supported"(batch_size,seq_len,input_hidden_size)=[val.valueforvalininput_emb.struct_info.shape.values]weight_shape=[val.valueforvalinweight.struct_info.shape.values]assert(weight_shape[0]==input_hidden_size),"input and weight should share the same input hiden size"if"qkv_hidden_sizes"inattr:assert(attr["qkv_hidden_sizes"][0]==attr["qkv_hidden_sizes"][1]),"Q and K should share the same hidden sizes"hidden_size,_,hidden_size_v=attr["qkv_hidden_sizes"]else:hidden_size=hidden_size_v=weight_shape[1]//3assert(hidden_size%num_heads==0),"hidden size should be divisible by number of attention heads"head_size=hidden_size//num_headshead_size_v=hidden_size_v//num_headsifmask_indexisnotNone:mask_index_shape=[val.valueforvalinmask_index.struct_info.shape.values]assertmask_index_shapein([batch_size,seq_len],[batch_size,seq_len,seq_len,],),"""mask index should be in shape of (batch_size, seq_len), or (batch_size, seq_len, seq_len)"""mask_bias=relax.op.subtract(relax.const(1,dtype=mask_index.struct_info.dtype),mask_index)mask_bias=relax.op.astype(mask_bias,dtype=input_emb.struct_info.dtype)mask_bias=bb.normalize(relax.op.multiply(mask_bias,relax.const(mask_filter_value,dtype=input_emb.struct_info.dtype),))ifqk_biasisNone:qk_bias=mask_biaselse:iflen(mask_index_shape)==2:mask_bias=bb.normalize(relax.op.reshape(mask_bias,[batch_size,1,1,seq_len]))eliflen(mask_index_shape)==3:mask_bias=bb.normalize(relax.op.reshape(mask_bias,[batch_size,1,seq_len,seq_len]))qk_bias=bb.normalize(relax.op.add(qk_bias,mask_bias))QKV=relax.op.matmul(input_emb,weight)ifbias:bias_shape=[val.valueforvalinbias.struct_info.shape.values]assert(bias_shape[0]==weight_shape[1]),"bias and weight should share the same hidden size sum"QKV=relax.op.add(QKV,bias)QKV=relax.op.split(QKV,[hidden_size,hidden_size*2],2)Q,K,V=QKV[0],QKV[1],QKV[2]Q=bb.normalize(relax.op.reshape(Q,(batch_size,seq_len,num_heads,head_size)))K=bb.normalize(relax.op.reshape(K,(batch_size,seq_len,num_heads,head_size)))V=bb.normalize(relax.op.reshape(V,(batch_size,seq_len,num_heads,head_size_v)))output=relax.op.nn.attention(Q,K,V,qk_bias)output=bb.normalize(relax.op.reshape(output,(batch_size,seq_len,num_heads*head_size_v)))# add placeholder for optional present state supported in the futureplaceholder=relax.const(0,dtype="float32")returnrelax.Tuple([output,placeholder])classIdentity(OnnxOpConverter):"""Converts an onnx Identity node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returninputs[0]classResize(OnnxOpConverter):"""Converts an onnx Resize node into an equivalent Relax expression."""@classmethoddef_impl_v18(cls,bb,inputs,attr,params):# Extract the many attributes of resize.coord_mode=attr.get("coordinate_transformation_mode",b"half_pixel").decode("ascii")cubic_coeff_a=attr.get("cubic_coeff_a",-0.75)exclude_outside=attr.get("exclude_outside",0)extrapolation_value=attr.get("extrapolation_value",0.0)mode=attr.get("mode",b"nearest").decode("ascii")rounding_method=attr.get("nearest_mode",b"round_prefer_floor").decode("ascii")# Adapt attributes to fit TVM definition.ifmode=="nearest":mode="nearest_neighbor"# Unpack 
inputs.x=inputs[0]roi=get_constant(inputs[1],params)scales=get_constant(inputs[2],params)sizes=get_constant(inputs[3],params)ndims=len(x.struct_info.shape)assertndims==4,"Only resize2d is currently supported."assert(scalesisNoneorsizesisNone),"Only one of scales and sizes can be provided in Resize."# Define relax implementation.ifroiisnotNone:ifisinstance(roi,relax.Constant):roi=roi.data.numpy().tolist()else:roi=relax.op.concat([relax.op.strided_slice(roi,axes=[0],begin=[2],end=[ndims]),relax.op.strided_slice(roi,axes=[0],begin=[ndims+2],end=[2*ndims]),],axis=0,)# TODO The backend C++ func resize2d does not support dynamic ROI for now.raiseNotImplementedError("Dynamic ROI is not supported in resize2d for now.")else:roi=[0.0]*4# Convert scales to sizes if needed.ifscalesisnotNone:assertisinstance(scales,relax.Constant),"Only constant scales currently supported."scales=scales.data.numpy()sizes=[]fori,diminenumerate(x.struct_info.shape):sizes.append(cast(scales[i]*dim,"int64"))sizes=sizes[2:]else:assertisinstance(sizes,relax.Constant),"Only constant output size currently supported."sizes=sizes.data.numpy().astype("int64").tolist()[2:]returnrelax.op.image.resize2d(x,size=relax.ShapeExpr(sizes),roi=roi,layout="NCHW",method=mode,coordinate_transformation_mode=coord_mode,rounding_method=rounding_method,cubic_alpha=cubic_coeff_a,cubic_exclude=exclude_outside,extrapolation_value=extrapolation_value,)classEinsum(OnnxOpConverter):"""Converts an onnx Einsum node into an equivalent Relax expression."""@classmethoddef_impl_v12(cls,bb,inputs,attr,params):equation=attr["equation"].decode("utf-8")returnbb.emit_te(topi.einsum,equation,*inputs)classRange(OnnxOpConverter):"""Converts an onnx Range node into an equivalent Relax expression."""@classmethoddef_impl_v12(cls,bb,inputs,attr,params):start=get_constant(inputs[0],params)limit=get_constant(inputs[1],params)delta=get_constant(inputs[2],params)out_dtype=start.struct_info.dtypeifisinstance(start,relax.Constant):start=start.data.numpy().tolist()ifisinstance(limit,relax.Constant):limit=limit.data.numpy().tolist()assertisinstance(delta,relax.Constant),"Constant delta required for Range."step=delta.data.numpy().tolist()# If all inputs are constant, compute directly.ifisinstance(start,int)andisinstance(limit,int):out_range=_np.arange(start=start,stop=limit,step=step)returnrelax.const(out_range,out_dtype)# Otherwise compute in graph.returnrelax.op.arange(start,limit,step,out_dtype)classInstanceNormalization(OnnxOpConverter):"""Converts an onnx InstanceNormalization node into an equivalent Relax expression."""@classmethoddef_impl_v6(cls,bb,inputs,attr,params):data=inputs[0]scale=inputs[1]B=inputs[2]epsilon=attr.get("epsilon",1e-05)epsilon=relax.const(epsilon,dtype=data.struct_info.dtype)ndim=len(data.struct_info.shape)redux_axes=list(range(2,ndim))mean=relax.op.mean(data,axis=redux_axes,keepdims=True)var=relax.op.variance(data,axis=redux_axes,keepdims=True)sqrt=relax.op.sqrt(relax.op.add(var,epsilon))out=relax.op.divide(relax.op.subtract(data,mean),sqrt)broadcast_shape=[-1]+[1,]*(ndim-2)ifscaleisnotNone:scale=relax.op.reshape(scale,broadcast_shape)out=relax.op.multiply(out,scale)ifBisnotNone:B=relax.op.reshape(B,broadcast_shape)out=relax.op.add(out,B)returnoutclassBatchNormalization(OnnxOpConverter):"""Converts an onnx BatchNormalization node into an equivalent Relax expression."""@classmethoddef_impl_v15(cls,bb,inputs,attr,params):# Unpack 
inputsdata=inputs[0]scale=inputs[1]bias=inputs[2]mean=inputs[3]var=inputs[4]epsilon=attr.get("epsilon",1e-05)returnrelax.op.nn.batch_norm(data,gamma=scale,beta=bias,moving_mean=mean,moving_var=var,epsilon=epsilon,axis=1)classMeanVarianceNormalization(OnnxOpConverter):"""Converts an onnx MeanVarianceNormalization node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):data=inputs[0]axis=attr.get("axes",(0,2,3))data_mean=relax.op.mean(data,axis=axis,keepdims=True)data_mean_squared=relax.op.power(data_mean,relax.const(2,dtype="float32"))data_squared=relax.op.power(data,relax.const(2,dtype="float32"))data_squared_mean=relax.op.mean(data_squared,axis=axis,keepdims=True)return(data-data_mean)/relax.op.sqrt(data_squared_mean-data_mean_squared)classPool(OnnxOpConverter):"""A helper class for pool op converters."""name=""@classmethoddefget_pad_pair(cls,input1d,kernel1d,stride1d,mode):"""infer pad size"""ifinput1d%stride1d==0:pad=max(kernel1d-stride1d,0)else:pad=max(kernel1d-(input1d%stride1d),0)pad_before=pad//2pad_after=pad-pad_beforeif"LOWER"inmode:return[pad_after,pad_before]return[pad_before,pad_after]@classmethoddef_impl_v1(cls,bb,inputs,attr,params):# Unpack inputs and attributes.data=inputs[0]input_shape=data.struct_info.shapendim=len(input_shape)auto_pad=attr.get("auto_pad",b"NOTSET").decode("utf-8")ceil_mode=attr.get("ceil_mode",0)dilations=attr.get("dilations",[1]*(ndim-2))kernel_shape=attr.get("kernel_shape")pads=attr.get("pads",0)strides=attr.get("strides",[1]*(ndim-2))count_include_pad=attr.get("count_include_pad",False)assertlen(kernel_shape)in[1,2,3],"Currently only 1D/2D/3D/ pooling is supported."assertauto_padin["NOTSET","SAME_UPPER","SAME_LOWER","VALID",],f"Value {auto_pad} in attribute auto_pad is invalid."ifauto_padin("SAME_UPPER","SAME_LOWER"):pads=[]ifcls.name=="avg_pool":foraxisinrange(len(input_shape)-2):axis_shape=input_shape[2+axis]stride=strides[axis]kernel=kernel_shape[axis]pad=cls.get_pad_pair(axis_shape,kernel,stride,auto_pad)pads.append(pad)else:input_spatial_shape=cls._get_input_spatial_shape(data)output_spatial_shape=[0for_ininput_spatial_shape]fori,_inenumerate(input_spatial_shape):ifauto_pad=="SAME_UPPER":output_spatial_shape[i]=int(_np.ceil(input_spatial_shape[i]/strides[i]))else:output_spatial_shape[i]=int(_np.floor(input_spatial_shape[i]/strides[i]))pad_i=((output_spatial_shape[i]-1)*strides[i]+((kernel_shape[i]-1)*dilations[i]+1)-input_spatial_shape[i])ifauto_pad=="SAME_UPPER":pads.append([pad_i//2,pad_i-pad_i//2])else:pads.append([pad_i-pad_i//2,pad_i//2])pads=tuple([valforpairinzip(*pads)forvalinpair])op=getattr(relax.op.nn,cls.name+str(len(kernel_shape))+"d")returnop(data,kernel_shape,strides,pads,dilations,ceil_mode,count_include_pad)@classmethoddef_get_input_spatial_shape(cls,tensor):# shape is (N x C x D1 x D2 ... 
Dn)
        return _np.array([int(d) for d in tensor.struct_info.shape], dtype="int64")[2:]


class MaxPool(Pool):
    """Converts an onnx MaxPool node into an equivalent Relax expression."""

    name = "max_pool"


class AveragePool(Pool):
    """Converts an onnx AveragePool node into an equivalent Relax expression."""

    name = "avg_pool"


class LpPool(OnnxOpConverter):
    """Converts an onnx LpPool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        dtype = inputs[0].struct_info.dtype
        p = attr.get("p", 2.0)
        reci_p = relax.const(1.0 / p, dtype=dtype)
        # Emit so that struct_info is available.
        data = bb.emit(relax.op.power(inputs[0], relax.const(p, dtype=dtype)))

        attr.update({"count_include_pad": True})
        avg_pool = AveragePool._impl_v1(bb, [data], attr, params)
        kernels = attr["kernel_shape"]
        out = avg_pool * relax.const(_np.prod(kernels).astype(dtype))
        return relax.op.power(out, reci_p)


class GlobalAveragePool(OnnxOpConverter):
    """Converts an onnx GlobalAveragePool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        rank = len(inputs[0].struct_info.shape)
        axes = list(range(2, rank))
        return relax.op.mean(inputs[0], axis=axes, keepdims=True)


class GlobalMaxPool(OnnxOpConverter):
    """Converts an onnx GlobalMaxPool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        rank = len(inputs[0].struct_info.shape)
        axes = list(range(2, rank))
        return relax.op.max(inputs[0], axis=axes, keepdims=True)


class GlobalLpPool(OnnxOpConverter):
    """Converts an onnx GlobalLpPool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v2(cls, bb, inputs, attr, params):
        p = attr.get("p", 2.0)
        dtype = inputs[0].struct_info.dtype
        rank = len(inputs[0].struct_info.shape)
        axes = list(range(2, rank))
        x_abs = relax.op.abs(inputs[0])
        x_p = relax.op.power(x_abs, relax.const(p, dtype=dtype))
        x_sum = relax.op.sum(x_p, axes, keepdims=True)
        return relax.op.power(x_sum, relax.const(1.0 / p, dtype=dtype))


class MaxUnpool(OnnxOpConverter):
    """Converts an onnx MaxUnpool node into an equivalent Relax expression."""

    @classmethod
    def _impl_v9(cls, bb, inputs, attr, params):
        data = inputs[0]
        indices = inputs[1]
        output_shape = inputs[2]
        kernel_shape = attr.get("kernel_shape")
        pads = attr.get("pads", [0] * len(kernel_shape) * 2)
        strides = attr.get("strides", [1] * len(kernel_shape))

        multiplier = _np.concatenate([[1, 1], list(strides)])
        shape = [v.value for v in data.struct_info.shape]
        total_output_shape = multiplier * shape
        # Add extra dimensions from kernel size and stride mismatch.
        total_output_shape += _np.concatenate([[0, 0], list(kernel_shape)], axis=0)
        total_output_shape -= _np.concatenate([[0, 0], list(strides)], axis=0)

        if output_shape is not None:
            total_output_shape = output_shape
        elif pads is not None:
            # Get pads in the proper format.
            pads = _np.concatenate([[0, 0, 0, 0], list(pads)], axis=0)
            pads = _np.reshape(pads, [-1, 2])
            # Compute the total padding per axis.
            total_pad = _np.sum(pads, axis=-1)
            # Reversing maxpool means that padding actually makes our output smaller.
            total_output_shape = total_output_shape - total_pad

        # Create a tensor of zeros then scatter our data through it.
        relax_shape = relax.ShapeExpr(total_output_shape.tolist())
        zeros_tensor = bb.emit(relax.op.zeros(relax_shape, data.struct_info.dtype))

        # We need to flatten all our tensors before scattering.
        flat_tensor = relax.op.scatter_elements(
            relax.op.reshape(zeros_tensor, [-1]),
            relax.op.reshape(indices, [-1]),
            relax.op.reshape(data, [-1]),
            axis=0,
        )
        # Reshape our flattened data back to normal.
        output = relax.op.reshape(flat_tensor, relax_shape)
        return output


class Flatten(OnnxOpConverter):
    """Converts an onnx Flatten node into an equivalent Relax
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):axis=attr.get("axis",1)data_shape=list(inputs[0].struct_info.shape)ifaxis==0:new_shape=(1,-1)else:shape_flags=[isinstance(x,tvm.script.tir.IntImm)forxindata_shape[0:axis]]ifall(shape_flags):data_shape=[x.valueforxindata_shape[0:axis]]new_shape=(_np.prod(data_shape).astype("int64"),-1)else:batch_size=1forelindata_shape[0:axis]:batch_size=batch_size*elnew_shape=(batch_size,-1)returnrelax.op.reshape(inputs[0],new_shape)classLayerNormalization(OnnxOpConverter):"""Converts an onnx LayerNormalization node into an equivalent Relax expression."""@classmethoddef_impl_v17(cls,bb,inputs,attr,params):data=inputs[0]scale=inputs[1]bias=inputs[2]axis=attr.get("axis",-1)epsilon=attr.get("epsilon",1e-05)gamma_shape=get_const_tuple(scale.struct_info.shape)ifbiasisNone:seq_len=data.struct_info.shape[1].valuebias=relax.const([0.0]*seq_len,dtype="float32")else:beta_shape=get_const_tuple(bias.struct_info.shape)ifgamma_shape!=beta_shape:raiseValueError("gamma and beta shapes do not match")axis=list(axis)ifisinstance(axis,(list,tuple))else[axis]iflen(axis)<len(gamma_shape):axis.extend(range(axis[-1]+1,axis[-1]+1+len(gamma_shape)-len(axis)))output=relax.op.nn.layer_norm(data,scale,bias,axis,epsilon)# Onnx layernorm has 3 outputs but only the first is used.# We construct two empty constants for this.placeholder=relax.const(0,dtype="float32")returnrelax.Tuple([output,placeholder,placeholder])classReduceMax(OnnxOpConverter):"""Converts an onnx ReduceMax node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.max(data,axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is False, reduce all dimsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.max(data,None,keepdims)# If axes is empty and noop_with_empty_axes is True, return input unchangedelifnotaxesandnoop_with_empty_axes:returndata# Otherwise reduce over specified axeselse:returnrelax.op.max(data,axes,keepdims)classReduceMin(OnnxOpConverter):"""Converts an onnx ReduceMin node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.min(data,axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is False, reduce all dimsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.min(data,None,keepdims)# If axes is empty and noop_with_empty_axes is True, return input unchangedelifnotaxesandnoop_with_empty_axes:returndata# Otherwise reduce over specified axeselse:returnrelax.op.min(data,axes,keepdims)classReduceSum(OnnxOpConverter):"""Converts an onnx ReduceSum node into an equivalent Relax 
expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.sum(data,axes,keepdims)@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.sum(data,None,keepdims)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.sum(data,axes,keepdims)classReduceMean(OnnxOpConverter):"""Converts an onnx ReduceMean node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.mean(data,axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.mean(data,None,keepdims)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.mean(data,axes,keepdims)classReduceProd(OnnxOpConverter):"""Converts an onnx ReduceProd node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.prod(data,axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.prod(data,None,keepdims)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.prod(data,axes,keepdims)classReduceLogSumExp(OnnxOpConverter):"""Converts an onnx ReduceLogSumExp node into an equivalent Relax 
expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):x=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)max_x=relax.op.max(x,axes,True)exp_x=relax.op.exp(relax.op.subtract(x,max_x))sum_x=relax.op.sum(exp_x,axes,True)out_x=relax.op.add(relax.op.log(sum_x),max_x)ifnotkeepdims:out_x=relax.op.squeeze(out_x,axes)returnout_x@classmethoddef_impl_v18(cls,bb,inputs,attr,params):x=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes input (second input)axes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# Calculate LogSumExplog_sum_exp=lambdaaxes:(max_x:=relax.op.max(x,axes,True),exp_x:=relax.op.exp(relax.op.subtract(x,max_x)),sum_x:=relax.op.sum(exp_x,axes,True),out_x:=relax.op.add(relax.op.log(sum_x),max_x),relax.op.squeeze(out_x,axes)ifnotkeepdimselseout_x,)[-1]# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnlog_sum_exp(None)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returnx# If axes is provided, reduce over the specified axeselse:returnlog_sum_exp(axes)classReduceLogSum(OnnxOpConverter):"""Converts an onnx ReduceLogSum node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.log(relax.op.sum(data,axes,keepdims))@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.log(relax.op.sum(data,None,keepdims))# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.log(relax.op.sum(data,axes,keepdims))classReduceSumSquare(OnnxOpConverter):"""Converts an onnx ReduceSumSquare node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.sum(relax.op.multiply(data,data),axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.sum(relax.op.multiply(data,data),None,keepdims)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.sum(relax.op.multiply(data,data),axes,keepdims)classReduceL1(OnnxOpConverter):"""Converts an 
onnx ReduceL1 node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.sum(relax.op.abs(data),axes,keepdims)@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.sum(relax.op.abs(data),None,keepdims)# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.sum(relax.op.abs(data),axes,keepdims)classReduceL2(OnnxOpConverter):"""Converts an onnx ReduceL2 node into an equivalent Relax expression."""@classmethoddef_impl_v13(cls,bb,inputs,attr,params):data=inputs[0]axes=attr.get("axes",None)keepdims=attr.get("keepdims",1)returnrelax.op.sqrt(relax.op.sum(relax.op.multiply(data,data),axes,keepdims))@classmethoddef_impl_v18(cls,bb,inputs,attr,params):data=inputs[0]keepdims=attr.get("keepdims",1)noop_with_empty_axes=attr.get("noop_with_empty_axes",0)# Optional axes inputaxes=Noneiflen(inputs)>1andinputs[1]isnotNone:axes_const=get_constant(inputs[1],params)assertisinstance(axes_const,relax.Constant),"Only constant axes currently supported"axes=axes_const.data.numpy().tolist()# If axes is empty and noop_with_empty_axes is 0, reduce all dimensionsifnotaxesandnotnoop_with_empty_axes:returnrelax.op.sqrt(relax.op.sum(relax.op.multiply(data,data),None,keepdims))# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.elifnotaxesandnoop_with_empty_axes:returndata# If axes is provided, reduce over the specified axeselse:returnrelax.op.sqrt(relax.op.sum(relax.op.multiply(data,data),axes,keepdims))classArgMax(OnnxOpConverter):"""Converts an onnx ArgMax node into an equivalent Relax expression."""@classmethoddef_check_attrs(cls,data,attr,shift_axis=True):dims_num=len(data.struct_info.shape)axis=attr.get("axis",0)ifshift_axisandaxis<0:axis+=dims_numassert0<=axis<dims_num,"Axis is out of bounds"keepdims=attr.get("keepdims",True)returnaxis,keepdims@classmethoddef_impl_v1(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr,False)returnrelax.op.argmax(data,axis,keepdims)@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr)returnrelax.op.argmax(data,axis,keepdims)@classmethoddef_impl_v12(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr)select_last_index=attr.get("select_last_index",False)ifselect_last_index:# TODO(vvchernov): support attrraisetvm.error.OpAttributeUnImplemented("'select_last_index' attribute has not been supported yet")returnrelax.op.argmax(data,axis,keepdims)classArgMin(OnnxOpConverter):"""Converts an onnx ArgMin node into an equivalent Relax expression."""@classmethoddef_check_attrs(cls,data,attr,shift_axis=True):dims_num=len(data.struct_info.shape)axis=attr.get("axis",0)ifshift_axisandaxis<0:axis+=dims_numassert0<=axis<dims_num,"Axis is out of 
bounds"keepdims=attr.get("keepdims",True)returnaxis,keepdims@classmethoddef_impl_v1(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr,False)returnrelax.op.argmin(data,axis,keepdims)@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr)returnrelax.op.argmin(data,axis,keepdims)@classmethoddef_impl_v12(cls,bb,inputs,attr,params):data=inputs[0]axis,keepdims=cls._check_attrs(data,attr)select_last_index=attr.get("select_last_index",False)ifselect_last_index:# TODO(vvchernov): support attrraisetvm.error.OpAttributeUnImplemented("'select_last_index' attribute has not been supported yet")returnrelax.op.argmin(data,axis,keepdims)classTopK(OnnxOpConverter):"""Converts an onnx TopK node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]k=inputs[1]ifnotisinstance(k,relax.Constant):raiseValueError("TopK k must be a constant")k=int(k.data.numpy())axis=attr.get("axis",-1)largest=attr.get("largest",1)sorted=attr.get("sorted",1)ifsorted!=1:raiseValueError("TopK sorted must be 1 for Relax frontend")returnrelax.op.topk(data,k,axis,ret_type="both",largest=largest)@classmethoddef_impl_v1(cls,bb,inputs,attr,params):data=inputs[0]k=attr.get("k",1)axis=attr.get("axis",-1)returnrelax.op.topk(data,k,axis,ret_type="both")classSkipLayerNormalization(OnnxOpConverter):"""Converts a microsoft contrib SkipLayerNormalization node into a Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):data=inputs[0]skip=inputs[1]gamma=inputs[2]beta=inputs[3]bias=inputs[4]assert(betaisnotNoneandbiasisnotNone),"SkipLayerNormalization import currently only supports required beta and bias"epsilon=attr.get("epsilon",1e-12)data=relax.op.add(data,skip)ifbiasisnotNone:data=relax.op.add(data,bias)output=relax.op.nn.layer_norm(data,gamma,beta,axes=-1,epsilon=epsilon)# Expects three outputs though only the first is used. 
Construct a placeholder for others.placeholder=relax.const(0,dtype="float32")returnrelax.Tuple([output,placeholder,placeholder])classEmbedLayerNormalization(OnnxOpConverter):"""Converts a microsoft contrib EmbedLayerNormalization node into a Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):input_ids=inputs[0]segment_ids=inputs[1]word_emb=inputs[2]pos_emb=inputs[3]segment_emb=inputs[4]gamma=inputs[5]beta=inputs[6]mask=inputs[7]pos_ids=inputs[8]epsilon=attr.get("epsilon",1e-12)(batch_size,seq_len)=[dim.valuefordimininput_ids.struct_info.shape]ifsegment_ids:assertsegment_embifpos_idsisNone:pos_ids=relax.const([list(range(seq_len))]*batch_size,dtype="int64")# TODO(jwfromm) Replace with relax ops once take has better support.word_vec=bb.emit_te(topi.take,word_emb,input_ids,0)ifsegment_ids:segment_vec=bb.emit_te(topi.take,segment_emb,segment_ids,0)pos_vec=bb.emit_te(topi.take,pos_emb,pos_ids,0)vec_sum=relax.op.add(word_vec,pos_vec)ifsegment_ids:vec_sum=relax.op.add(vec_sum,segment_vec)ln=relax.op.nn.layer_norm(vec_sum,gamma,beta,axes=-1,epsilon=epsilon)mask_index=relax.const(_np.zeros((batch_size,),dtype="int64"))ifmask:# Caculate number of words per sentence.mask_index=relax.op.sum(mask,axis=1)returnrelax.Tuple([ln,mask_index])classOneHot(OnnxOpConverter):"""Converts an onnx OneHot node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):indices=inputs[0]depth=get_constant(inputs[1],params)values=get_constant(inputs[2],params)axis=attr.get("axis",-1)assertisinstance(depth,relax.Constant),"Only constant depth currently supported."depth=depth.data.numpy().tolist()assertisinstance(values,relax.Constant),"Only constant values currently supported."values=values.data.numpy().tolist()off_value,on_value=valuesoff_value,on_value=relax.PrimValue(off_value),relax.PrimValue(on_value)returnrelax.op.one_hot(indices,on_value,off_value,depth,axis)classUnique(OnnxOpConverter):"""Converts an onnx Unique node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):data=inputs[0]axis=attr.get("axis",None)sorted=bool(attr.get("sorted",1))# TODO(tvm-team): Add support for return_index, return_inverse, return_countsunique=relax.op.unique(data,sorted=sorted,axis=axis)unique_numbers=tir.Var("unique_numbers","int64")input_shape=data.struct_info.shapedtype=data.struct_info.dtypeifaxisisNone:# flatten the input tensorreturnbb.match_cast(unique,relax.TensorStructInfo((unique_numbers,),dtype))axis=axisifaxis>=0elselen(input_shape)+axisifaxis<0oraxis>=len(input_shape):raiseValueError(f"Axis {axis} is out of bounds")output_shape=[input_shape[i]ifi!=axiselseunique_numbersforiinrange(len(input_shape))]returnbb.match_cast(unique,relax.TensorStructInfo(output_shape,dtype))classNonZero(OnnxOpConverter):"""Converts an onnx NonZero node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):ndim=inputs[0].struct_info.ndimndim=1ifndim==0elsendimnonzero_numbers=tir.Var("nonzero_numbers","int64")returnbb.match_cast(relax.op.nonzero(inputs[0]),relax.TensorStructInfo((ndim,nonzero_numbers),"int64"))classHardSigmoid(OnnxOpConverter):"""Converts an onnx HardSigmoid node into an equivalent Relax 
expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):x=inputs[0]dtype=x.struct_info.dtypealpha=float(attr.get("alpha",0.2))alpha=relax.const(alpha,dtype=dtype)beta=float(attr.get("beta",0.5))beta=relax.const(beta,dtype=dtype)returnrelax.op.clip(relax.op.add(relax.op.multiply(alpha,x),beta),0,1)classHardSwish(OnnxOpConverter):"""Converts an onnx HardSwish node into an equivalent Relax expression."""@classmethoddef_impl_v14(cls,bb,inputs,attr,params):x=inputs[0]dtype=x.struct_info.dtypereturnrelax.op.multiply(x,relax.op.divide(relax.op.clip(relax.op.add(x,relax.const(3,dtype)),0,6),relax.expr.const(6,dtype),),)classSign(OnnxOpConverter):"""Converts an onnx Sign node into an equivalent Relax expression."""@classmethoddef_impl_v9(cls,bb,inputs,attr,params):returnrelax.op.sign(inputs[0])classNot(OnnxOpConverter):"""Converts an onnx Not node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):returnrelax.op.logical_not(inputs[0])classDepthToSpace(OnnxOpConverter):"""Converts an onnx DepthToSpace node into an equivalent Relax expression."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):block_size=int(attr["blocksize"])mode=attr.get("mode",b"DCR").decode("utf-8")b,c,h,w=inputs[0].struct_info.shapeifmode=="DCR":x=relax.op.reshape(inputs[0],(b,block_size,block_size,c//(block_size**2),h,w))x=relax.op.permute_dims(x,[0,3,4,1,5,2])returnrelax.op.reshape(x,(b,c//(block_size**2),h*block_size,w*block_size))elifmode=="CRD":x=relax.op.reshape(inputs[0],(b,c//(block_size**2),block_size,block_size,h,w))x=relax.op.permute_dims(x,[0,1,4,2,5,3])returnrelax.op.reshape(x,(b,c//(block_size**2),h*block_size,w*block_size))else:raiseValueError(f"Unsupported mode: {mode}, expected DCR or CRD")classSpaceToDepth(OnnxOpConverter):"""Converts an onnx SpaceToDepth node into an equivalent Relax expression."""@classmethoddef_impl_v1(cls,bb,inputs,attr,params):block_size=int(attr["blocksize"])b,c,h,w=inputs[0].struct_info.shapex=relax.op.reshape(inputs[0],(b,c,h//block_size,block_size,w//block_size,block_size))x=relax.op.permute_dims(x,[0,3,5,1,2,4])returnrelax.op.reshape(x,(b,c*block_size*block_size,h//block_size,w//block_size))classSequenceConstruct(OnnxOpConverter):"""Operator converter for sequence construction op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):# Construct a tuple from input tensors.returnrelax.Tuple(inputs)classSequenceEmpty(OnnxOpConverter):"""Operator converter for sequence empty op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):# Construct an empty tuple.returnrelax.Tuple([])classSequenceErase(OnnxOpConverter):"""Operator converter for sequence erase op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):# Erase tensor from sequence on specified positioninput_sequence=inputs[0]iflen(inputs)==2:position=inputs[1]# Non constant position is not supported.ifisinstance(position,relax.Constant):position=int(position.data.numpy())else:raiseNotImplementedError("Position must be a constant.")else:position=-1seq_len=len(input_sequence)ifnot-seq_len<=position<seq_len:raiseValueError(f"Position is out of bounds, expected [-{seq_len}, {seq_len}), got {position}")ifposition<0:position=seq_len+position# Convert sequence to a list, insert tensors before erased, and repackage as Tuple.tensor_list=[input_sequence[i]foriinrange(seq_len)ifi!=position]# Create new tuple and return.returnrelax.Tuple(tensor_list)classSequenceInsert(OnnxOpConverter):"""Operator converter for sequence insert op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):# 
Insert a new tensor into a tuple of tensors.input_sequence=inputs[0]tensor_to_insert=inputs[1]iflen(inputs)==3:position=inputs[2]# Non constant position is not supported.ifisinstance(position,relax.Constant):position=position.data.numpy()else:raiseNotImplementedError("Position must be a constant.")else:position=-1ifposition<0:position=len(input_sequence)+position+1# Convert sequence to a list, insert new tensor, and repackage as Tuple.tensor_list=[input_sequence[i]foriinrange(len(input_sequence))]# Insert new tensor.tensor_list.insert(position,tensor_to_insert)# Create new tuple and return.returnrelax.Tuple(tensor_list)classSequenceLength(OnnxOpConverter):"""Operator converter for sequence length op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):# Get length of input sequencereturnrelax.const(len(inputs[0]),dtype="int64")classConcatFromSequence(OnnxOpConverter):"""Operator converter for sequence concatenation op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):axis=attr.get("axis",0)new_axis=attr.get("new_axis",0)ifnew_axis==1:raiseNotImplementedError("Insert new axis is not supported yet.")returnrelax.op.concat(inputs[0],axis=axis)classSplitToSequence(OnnxOpConverter):"""Operator converter for split to sequence op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):axis=attr.get("axis",0)keepdims=attr.get("keepdims",1)input_tensor=inputs[0]input_shape=input_tensor.struct_info.shape# If split is not provided, we split all values along axis.iflen(inputs)==1:split=_np.array(1)ifnotkeepdims:raiseNotImplementedError("Only keepdims=1 is supported for now")else:split=inputs[1]ifnotisinstance(split,relax.Constant):raiseValueError("Only constant split supported for SplitToSequence")split=split.data.numpy()iflen(split.shape)==1andsplit.shape[0]>1:split=_np.cumsum(split)split=list(split[:-1])else:chunk_size,dim_size=int(split),input_shape[axis]ifdim_size%chunk_size!=0:raiseValueError(f"Dimension of size {dim_size} along axis {axis} must be "f"evenly divisible by chunk size {chunk_size}")split=dim_size//chunk_sizeoutput=relax.op.split(input_tensor,split,axis=axis)returnoutputclassSequenceAt(OnnxOpConverter):"""Operator converter for sequence at op."""@classmethoddef_impl_v11(cls,bb,inputs,attr,params):input_sequence=inputs[0]position=inputs[1]assertisinstance(position,relax.Constant),"Only constant position supported for SequenceAt"position=int(position.data.numpy())returninput_sequence[position]def_get_convert_map():return{# defs/experimental# "Optional": Optional_,# "OptionalHasElement": OptionalHasElement,# "OptionalGetElement": OptionalGetElement,# Binary operators"Add":Add,"Sub":Sub,"Mul":Mul,"Div":Div,"Mod":Mod,"Less":Less,"LessOrEqual":LessOrEqual,"Greater":Greater,"GreaterOrEqual":GreaterOrEqual,"Equal":Equal,"BitwiseAnd":BitwiseAnd,"BitwiseOr":BitwiseOr,"BitwiseXor":BitwiseXor,"BitShift":BitShift,"And":And,"Or":Or,"Xor":Xor,"Not":Not,# Unary 
operators"BitwiseNot":BitwiseNot,"Log":Log,"Exp":Exp,"Acos":Acos,"Acosh":Acosh,"Asin":Asin,"Asinh":Asinh,"Atan":Atan,"Atanh":Atanh,"Cos":Cos,"Cosh":Cosh,"Sin":Sin,"Sinh":Sinh,"Tan":Tan,"Tanh":Tanh,"Neg":Neg,"Abs":Abs,"Reciprocal":Reciprocal,"Floor":Floor,"Ceil":Ceil,"Round":Round,"IsInf":IsInf,"IsNaN":IsNaN,"Sqrt":Sqrt,"Relu":Relu,"Selu":Selu,"Mish":Mish,"Trilu":Trilu,"PRelu":PRelu,"LeakyRelu":LeakyRelu,"ThresholdedRelu":ThresholdedRelu,"Elu":Elu,"Gelu":Gelu,"FastGelu":FastGelu,"BiasGelu":BiasGelu,"HardSigmoid":HardSigmoid,"HardSwish":HardSwish,"Sign":Sign,"Softplus":Softplus,"Softsign":Softsign,"Shrink":Shrink,"Erf":Erf,"Sum":Sum,"Min":Min,"Max":Max,"Mean":Mean,"Cast":Cast,"Gemm":Gemm,"MatMul":MatMul,# "MatMulInteger": MatMulInteger,# "MatMulInteger16": MatMulInteger16,"Reshape":Reshape,"Sigmoid":Sigmoid,"Softmax":Softmax,"LogSoftmax":LogSoftmax,"Hardmax":Hardmax,"Transpose":Transpose,"Unsqueeze":Unsqueeze,"Where":Where,"Concat":Concat,"Clip":Clip,"Shape":Shape,"Pow":Pow,"CumSum":CumSum,"Squeeze":Squeeze,"Constant":Constant,"Gather":Gather,"GatherElements":GatherElements,"GatherND":GatherND,"Scatter":Scatter,"ScatterElements":ScatterElements,"ScatterND":ScatterND,"Compress":Compress,"Size":Size,"EyeLike":EyeLike,# Normalization"BatchNormalization":BatchNormalization,"LayerNormalization":LayerNormalization,"SkipLayerNormalization":SkipLayerNormalization,"EmbedLayerNormalization":EmbedLayerNormalization,"InstanceNormalization":InstanceNormalization,"MeanVarianceNormalization":MeanVarianceNormalization,# defs/reduction"ReduceMax":ReduceMax,"ReduceMin":ReduceMin,"ReduceSum":ReduceSum,"ReduceMean":ReduceMean,"ReduceProd":ReduceProd,"ReduceLogSumExp":ReduceLogSumExp,"ReduceLogSum":ReduceLogSum,"ReduceSumSquare":ReduceSumSquare,"ReduceL1":ReduceL1,"ReduceL2":ReduceL2,"ArgMax":ArgMax,"ArgMin":ArgMin,"TopK":TopK,"Expand":Expand,"ConstantOfShape":ConstantOfShape,"Slice":Slice,"Attention":Attention,"Pad":Pad,"Split":Split,"Tile":Tile,"AveragePool":AveragePool,"MaxPool":MaxPool,"LpPool":LpPool,"GlobalAveragePool":GlobalAveragePool,"GlobalMaxPool":GlobalMaxPool,"GlobalLpPool":GlobalLpPool,"MaxUnpool":MaxUnpool,"Conv":Conv,"ConvTranspose":ConvTranspose,"Flatten":Flatten,"Identity":Identity,"Resize":Resize,"Einsum":Einsum,"Range":Range,"OneHot":OneHot,"Unique":Unique,"NonZero":NonZero,# "If": If,# "LRN": LRN,# "MaxRoiPool": MaxRoiPool,# "RoiAlign": RoiAlign,# "NonMaxSuppression": NonMaxSuppression,# "GridSample": GridSample,# "Upsample": Upsample,# others"DepthToSpace":DepthToSpace,"SpaceToDepth":SpaceToDepth,# Sequence operators"SequenceConstruct":SequenceConstruct,"SequenceEmpty":SequenceEmpty,"SequenceErase":SequenceErase,"SequenceInsert":SequenceInsert,"SequenceLength":SequenceLength,"ConcatFromSequence":ConcatFromSequence,"SplitToSequence":SplitToSequence,"SequenceAt":SequenceAt,}classONNXGraphImporter:"""A helper class for handling Relax expression copying from pb2.GraphProto. Definition: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto Parameters ---------- shape_dict : dict of str to tuple, optional The input shape to the graph dtype_dict : str or dict of str to str The input types to the graph keep_params_in_input : bool If True, parameters will be treated as input variables. If false, parameters are treated as constant and folded directly into the graph. sanitize : bool Whether to sanitize the input names to be valid Relax identifiers. 
"""current=Nonedef__init__(self,shape_dict:Dict[str,List],dtype_dict:Union[str,Dict[str,str]],keep_params_in_input:bool=False,sanitize:bool=True,):self._nodes:Dict[str,relax.Expr]={}self._inputs:Dict[str,relax.Var]={}self._num_input:int=0self._shape=shape_dict.copy()ifshape_dictelse{}self._input_names:List[str]=[]self._dtype=dtype_dictself.opset:int=Noneself._name_supply=NameSupply()self._keep_params_in_input=keep_params_in_inputself._sanitize:bool=sanitizeself.bb:relax.BlockBuilder=relax.BlockBuilder()# pylint: disable=invalid-nameself._params={}deffrom_onnx(self,graph:onnx.onnx_ml_pb2.ModelProto,opset:int)->IRModule:"""Construct Relax expressions from the ONNX graph. Onnx graph is a python protobuf object. Parameters ---------- graph : onnx protobuf object The loaded onnx graph opset : opset version Returns ------- mod : tvm.IRModule The returned relax module """withself.bb.function("main"):withself.bb.dataflow()asdf:# pylint: disable=invalid-name, unused-variableself.opset=opsetself._parse_graph_initializers(graph)self._parse_graph_input(graph)self._check_for_unsupported_ops(graph)self._construct_nodes(graph)# now return the outputsoutputs=[self._nodes[self._parse_value_proto(i)]foriingraph.output]outputs=outputs[0]iflen(outputs)==1elserelax.Tuple(outputs)output_var=self.bb.emit_output(outputs)# Create function attributes for this modulefunc_attrs={"num_input":self._num_input}# Create a function from our output expression and all input variables.input_list=[valueforvalueinself._inputs.values()ifisinstance(value,relax.Var)]# Attach params if they are available.ifself._keep_params_in_inputandself._params:param_var_list,param_value_list=map(list,zip(*self._params.values()))input_list=input_list+param_var_listfunc_attrs["params"]=param_value_listself.bb.emit_func_output(output_var,params=input_list)relax_mod=self.bb.get()# Attach attributes.relax_mod["main"]=relax_mod["main"].with_attrs(func_attrs)returnrelax_moddef_parse_graph_initializers(self,graph:onnx.onnx_ml_pb2.GraphProto):"""Parse network inputs to relax, aka parameters."""forinit_tensoringraph.initializer:# There are two cases for handling parameters, they are either# treated as variables or constants.ifnotinit_tensor.name.strip():raiseValueError("Tensor's name is required.")array=self._parse_array(init_tensor)# Create variables for constants.ifself._keep_params_in_input:# Pytorch sometimes inserts silly weight prefix. Remove it.var_name=init_tensor.name.strip("onnx::")init_var=self._new_var(var_name,shape=array.shape,dtype=array.dtype)self._nodes[init_tensor.name]=init_var# We need to keep track of both the real value and variable for this variable.self._params[var_name]=(init_var,array)# Otherwise we can use the weight as a constant.else:self._nodes[init_tensor.name]=relax.const(array)def_sanitize_name(self,name:str)->str:"""Sanitize a name to make it a valid identifier. If the name is None, returns a string input_0, input_1, etc. If the input is an empty string, returns empty_0, empty_1, etc. If the input is a string that does not start with a letter or underscore, returns input_<name>. Otherwise, returns an unique input name. 
Parameters ---------- name : str The name to sanitize Returns ------- new_name : str """ifname=="":returnself._name_supply.fresh_name("empty_")new_name=name.replace(".","_")ifnotnew_name[0].isalpha()andnew_name[0]!="_":new_name=str(self._name_supply.fresh_name("input_"+new_name))else:new_name=str(self._name_supply.fresh_name(new_name))ifnew_name!=name:warnings.warn(("Renaming name %s to %s"%(name,new_name)))returnnew_namedef_new_var(self,var_name:str,shape:List,dtype:str="float32"):"""Creates a new Relax variable."""returnrelax.Var(name_hint=var_name,struct_info=relax.TensorStructInfo(shape=shape,dtype=dtype))def_parse_graph_input(self,graph:onnx.onnx_ml_pb2.GraphProto):"""Parse model inputs to Relax parameters."""value_dict={}foriingraph.input:# from onnx v0.2, GraphProto.input has type ValueInfoProto,# and the name is 'i.name'i_name,i_shape,d_type,i_shape_name,value_dict=get_info(i,value_dict)ifi_namenotinself._nodes:self._num_input+=1self._input_names.append(i_name)ifi_nameinself._shape:i_shape=self._shape[i_name]else:if"?"instr(i_shape):warning_msg=("Input %s has unknown dimension shapes: %s. ""Specifying static values may improve performance"%(i_name,str(i_shape_name)))warnings.warn(warning_msg)ifisinstance(self._dtype,dict):dtype=self._dtype[i_name]ifi_nameinself._dtypeelsed_typeelse:dtype=d_typevar_name=self._sanitize_name(i_name)ifself._sanitizeelsei_nameself._nodes[i_name]=self._new_var(var_name,shape=i_shape,dtype=dtype)self._inputs[i_name]=self._nodes[i_name]def_check_for_unsupported_ops(self,graph:onnx.onnx_ml_pb2.GraphProto):convert_map=_get_convert_map()unsupported_ops=set()fornodeingraph.node:op_name=node.op_typeif(op_namenotinconvert_mapandop_name!="Constant"# and op_name not in _identity_list):unsupported_ops.add(op_name)ifunsupported_ops:msg="The following operators are not supported for frontend ONNX: "msg+=", ".join(unsupported_ops)raisetvm.error.OpNotImplemented(msg)def_construct_nodes(self,graph:onnx.onnx_ml_pb2.GraphProto):"""Nodes are stored as directed acyclic graph."""fornodeingraph.node:op_name=node.op_typeattr=self._parse_attr(node.attribute)# Create and populate input list.inputs=onnx_input()foriinnode.input:ifi!="":inputs.append(self._nodes[i])else:inputs.append(None)i_name=self._parse_value_proto(node)outputs=node.outputattr["tvm_custom"]={}attr["tvm_custom"]["name"]=i_nameattr["tvm_custom"]["num_outputs"]=len(outputs)# Perform special handling for shape expressions. If an input is a# shape expr, make sure the current op can handle it, otherwise# convert it to a tensor.shape_compatible_ops=["Reshape","ConstantOfShape","Gather","Slice","Shape","Expand","Concat","Equal","Where","Cast","Squeeze",]return_tuple_ops=["SequenceConstruct","SequenceEmpty","SequenceErase","SequenceInsert","ConcatFromSequence","SplitToSequence",]fori,inpinenumerate(inputs):if(inpisnotNoneandisinstance(inp,relax.Expr)andisinstance(inp.struct_info,relax.ShapeStructInfo)andop_namenotinshape_compatible_ops):raiseValueError(f"Node {node.name} cannot handle ShapeExpr inputs.")try:op=self._convert_operator(op_name,inputs,attr,self.opset)# Create struct information for the new operator.op=self.bb.normalize(op)exceptTVMErroraserr:print(f"Error converting operator {op_name}, with inputs: {inputs}")raiseerrifop_nameinreturn_tuple_ops:outputs_num=1elifnotisinstance(op,relax.Tuple):ifisinstance(op.struct_info,relax.TupleStructInfo):# This is a var bound to a tuple. 
We need to unpack it and create# a new tuple.tuple_items=[]foriinrange(len(op.struct_info.fields)):tuple_items.append(self.bb.emit(relax.TupleGetItem(op,i)))op=relax.Tuple(tuple_items)outputs_num=len(tuple_items)else:outputs_num=1else:outputs_num=len(op)assert(len(outputs)<=outputs_num),"Missing outputs during conversion. Expected {} but Got {} in {}.".format(len(outputs),outputs_num,op_name)ifoutputs_num==1:self._nodes[outputs[0]]=opelse:fork,iinzip(list(outputs),range(len(outputs))):self._nodes[k]=op[i]def_parse_value_proto(self,value_proto:onnx.onnx_ml_pb2.GraphProto):"""Parse ValueProto or raw str."""try:name=value_proto.nameexceptAttributeError:name=value_protoreturnnamedef_parse_array(self,tensor_proto:onnx.onnx_ml_pb2.TensorProto)->tvm.nd.array:np_array=get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))returntvm.nd.array(np_array)def_parse_attr(self,attr_proto:onnx.onnx_ml_pb2.AttributeProto)->Dict[str,Any]:"""Convert a list of AttributeProto to a dict, with names as keys."""attrs={}forainattr_proto:forfin["f","i","s","g"]:ifa.HasField(f):attrs[a.name]=getattr(a,f)forfin["floats","ints","strings"]:iflist(getattr(a,f)):asserta.namenotinattrs,"Only one type of attr is allowed"attrs[a.name]=tuple(getattr(a,f))forfin["t"]:ifa.HasField(f):attrs[a.name]=getattr(a,f)forfin["tensors"]:iflist(getattr(a,f)):asserta.namenotinattrs,"Only one type of attr is allowed"attrs[a.name]=tuple(getattr(a,f))forfin["graphs"]:iflist(getattr(a,f)):raiseNotImplementedError("Field {} is not supported in relax.".format(f))ifa.namenotinattrs:raiseValueError("Cannot parse attribute: \n{}\n.".format(a))returnattrsdef_convert_operator(self,op_name:str,inputs:List[relax.Expr],attrs:Dict,opset:int,)->relax.Expr:"""Convert ONNX operator into a Relax operator. The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters ---------- op_name : str Operator name, such as Convolution, FullyConnected inputs : list of tvm.relax.function.Function List of inputs. attrs : dict Dict of operator attributes opset : int Opset version Returns ------- sym : tvm.relax.function.Function Converted relax function """convert_map=_get_convert_map()ifop_nameinconvert_map:convert_class=convert_map[op_name]op_function=convert_class.get_converter(opset)sym=op_function(self.bb,inputs,attrs,[self._nodes,self._params])else:raiseNotImplementedError("Operator {} not implemented.".format(op_name))returnsym
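

# Note: a hedged, illustrative sketch (not part of the importer). New ONNX ops are
# supported by subclassing OnnxOpConverter with an ``_impl_v<opset>`` classmethod
# taking (bb, inputs, attr, params), and by adding an entry for the op name to the
# dictionary returned by _get_convert_map() above. The op semantics and the class
# name below are hypothetical placeholders.
class ExampleSoftsignLike(OnnxOpConverter):
    """Illustrative only: computes x / (1 + |x|) for a hypothetical unary op."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        x = inputs[0]
        one = relax.const(1, dtype=x.struct_info.dtype)
        return relax.op.divide(x, relax.op.add(one, relax.op.abs(x)))

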
def from_onnx(
    model: onnx.onnx_ml_pb2.ModelProto,
    shape_dict: Optional[Dict[str, List]] = None,
    dtype_dict: Optional[Union[str, Dict[str, str]]] = "float32",
    opset: Optional[int] = None,
    keep_params_in_input: bool = False,
    sanitize_input_names: bool = True,
) -> IRModule:
    """Convert an ONNX model into an equivalent Relax Function.

    ONNX graphs are represented as Python Protobuf objects. The current
    implementation assumes that the input model is after ONNX v1.1.0.

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0
    shape_dict : dict of str to tuple, optional
        The input shape to the graph
    dtype_dict : str or dict of str to str, optional
        The input types to the graph
    opset : int, optional
        Override the autodetected opset. This can be helpful for some testing.
    keep_params_in_input : bool
        If True, parameters will be treated as input variables. If False,
        parameters are treated as constant and folded directly into the graph.
    sanitize_input_names : bool, optional
        Whether to sanitize the input names to ensure they are valid Relax identifiers.

    Returns
    -------
    mod : tvm.IRModule
        The relax module for compilation
    """
    # Error if the model version is below 1.1.0.
    if model.ir_version < 3:
        raise ValueError(
            "Model IR version {} not supported. Must be at least 1.1.0.".format(model.ir_version)
        )

    try:
        import onnx  # pylint: disable=import-outside-toplevel, redefined-outer-name

        if hasattr(onnx.checker, "check_model"):
            # Try to use onnx's own model checker before converting any model.
            try:
                onnx.checker.check_model(model)
            except Exception as exception:  # pylint: disable=c-extension-no-member, broad-except
                # The checker is a bit violent about errors, so simply print warnings here.
                warnings.warn(str(exception))
    except ImportError as error:
        raise ImportError("Unable to import onnx which is required {}".format(error))

    g = ONNXGraphImporter(
        shape_dict,
        dtype_dict,
        keep_params_in_input=keep_params_in_input,
        sanitize=sanitize_input_names,
    )
    graph = model.graph

    try:
        opset_in_model = 1
        if model.opset_import:
            # TODO: for now we only really support the ai.onnx op set.
            # TODO: handle other namespaces well, see https://github.com/apache/tvm/issues/10950
            for opset_identifier in model.opset_import:
                # As per https://github.com/onnx/onnx/blob/main/docs/IR.md
                # all operator sets except the default one must specify the operator version.
                if str(opset_identifier.domain) in ["ai.onnx", ""]:
                    opset_in_model = opset_identifier.version
                    break
    except AttributeError:
        opset_in_model = 1

    if opset is None:
        opset = opset_in_model
    elif opset < opset_in_model:
        warnings.warn(
            f"You are overwriting the original opset ver = {opset_in_model} with a lower ver = {opset}. "
            f"That might cause model conversion errors."
        )

    # Use the graph proto as a scope so that ops can access other nodes if needed.
    return g.from_onnx(graph, opset)
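

# Illustrative usage sketch (not part of this module). Assuming this file is exposed
# as ``tvm.relax.frontend.onnx``, an ONNX model on disk can be imported roughly as
# follows; the file name, input name, shape, and dtype are placeholders.
#
#   import onnx
#   from tvm.relax.frontend.onnx import from_onnx
#
#   onnx_model = onnx.load("model.onnx")           # hypothetical path
#   mod = from_onnx(
#       onnx_model,
#       shape_dict={"input": (1, 3, 224, 224)},    # hypothetical input name and static shape
#       dtype_dict={"input": "float32"},
#       keep_params_in_input=False,                # fold weights into the graph as constants
#   )
#   print(mod)  # inspect the imported Relax IRModule before compilation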