# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""TVM operator for softmax and log_softmax compute."""
from __future__ import absolute_import

import tvm
from tvm import te, topi
[文档]@tvm.te.tag_scope(tag="softmax_output")defsoftmax(x,axis=-1):"""Perform softmax activation on the data. Parameters ---------- data : tvm.te.Tensor can be any dimension axis : int channel axis Returns ------- output : tvm.te.Tensor output shape is the same as input """returnsoftmax_common(x,axis,False)
[文档]@tvm.te.tag_scope(tag="fast_softmax_output")deffast_softmax(x,axis=-1):"""Perform softmax activation on the data. Use approximation to compute exponent for faster speed. Parameters ---------- data : tvm.te.Tensor can be any dimension axis : int channel axis Returns ------- output : tvm.te.Tensor output shape is the same as input """returnsoftmax_common(x,axis,True)
def softmax_common(x, axis, use_fast_exp):
    """The common part of softmax and fast_softmax."""
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
    if axis >= len(shape):
        raise ValueError("axis parameter should be less than input dim")

    k1 = te.reduce_axis((0, shape[axis]), name="k")
    k2 = te.reduce_axis((0, shape[axis]), name="k")

    def insert_reduce_index(indices, reduce_index):
        # Re-insert the reduction axis into an index tuple over the reduced shape.
        return indices[:axis] + (reduce_index,) + indices[axis:]

    def get_non_reduce_indices(indices):
        # Drop the softmax axis from a full index tuple.
        return tuple([var for (i, var) in enumerate(indices) if i != axis])

    def _compute_max(*indices):
        eval_range = insert_reduce_index(indices, k1)
        return tvm.te.max(x[eval_range], axis=k1)

    def _compute_delta(max_elem, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return x[indices] - max_elem[non_reduce_indices]

    def _compute_exp(max_elem, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return te.exp(x[indices] - max_elem[non_reduce_indices])

    def _compute_expsum(exp, *indices):
        eval_range = insert_reduce_index(indices, k2)
        return te.sum(exp[eval_range], axis=k2)

    def _normalize(exp, expsum, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return exp[indices] / expsum[non_reduce_indices]

    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
    # The per-slice max is subtracted before exponentiating for numerical
    # stability; the shift cancels out in the final normalization.
    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")

    if use_fast_exp:
        delta = te.compute(
            shape, lambda *indices: _compute_delta(max_elem, *indices), name="T_softmax_delta"
        )
        exp = topi.math.fast_exp(delta)
    else:
        exp = te.compute(
            shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp"
        )
    expsum = te.compute(
        reduced_shape, lambda *indices: _compute_expsum(exp, *indices), name="T_softmax_expsum"
    )
    return te.compute(
        shape,
        lambda *indices: _normalize(exp, expsum, *indices),
        name="T_softmax_norm",
        attrs={"axis": axis},
    )
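
# Why softmax_common() subtracts max_elem first: a pure-NumPy sketch of the
# standard numerical-stability argument. With large inputs, exp() overflows
# float32; shifting by the max keeps the largest exponent at exp(0) = 1
# without changing the result, since the shift cancels in the ratio.
def _max_subtraction_sketch():
    import numpy as np

    z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)

    naive = np.exp(z) / np.exp(z).sum()  # exp(1000) -> inf, result is all nan
    stable = np.exp(z - z.max())
    stable = stable / stable.sum()  # finite: ~[0.090, 0.245, 0.665]

    assert np.isnan(naive).all() and np.isfinite(stable).all()
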
@tvm.te.tag_scope(tag="log_softmax_output")
def log_softmax(x, axis=-1):
    """Perform log softmax activation on the data.

    Parameters
    ----------
    x : tvm.te.Tensor
        N-D input data

    axis : int
        channel axis along which log softmax is computed

    Returns
    -------
    output : tvm.te.Tensor
        N-D output with same shape
    """
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
    if axis >= len(shape):
        raise ValueError("axis parameter should be less than input dim")

    k1 = te.reduce_axis((0, shape[axis]), name="k")
    k2 = te.reduce_axis((0, shape[axis]), name="k")

    def insert_reduce_index(indices, reduce_index):
        return indices[:axis] + (reduce_index,) + indices[axis:]

    def get_non_reduce_indices(indices):
        return tuple([var for (i, var) in enumerate(indices) if i != axis])

    def _compute_max(*indices):
        eval_range = insert_reduce_index(indices, k1)
        return tvm.te.max(x[eval_range], axis=k1)

    def _compute_expsum(max_elem, *indices):
        eval_range = insert_reduce_index(indices, k2)
        return te.sum(te.exp(x[eval_range] - max_elem[indices]), axis=k2)

    def _normalize(max_elem, expsum, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return x[indices] - max_elem[non_reduce_indices] - te.log(expsum[non_reduce_indices])

    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
    expsum = te.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices))
    return te.compute(
        shape,
        lambda *indices: _normalize(max_elem, expsum, *indices),
        attrs={"axis": axis},
    )
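
# A hedged NumPy sketch of the identity log_softmax() evaluates:
# log(softmax(x)) == x - max(x) - log(sum(exp(x - max(x)))). Computing the
# right-hand side directly, as _normalize() does, avoids exponentiating and
# then taking the log, which would yield log(0) = -inf for probabilities
# that underflow. Illustrative only.
def _log_softmax_identity_sketch():
    import numpy as np

    x = np.random.uniform(-5, 5, size=(4, 8)).astype("float32")
    m = x.max(axis=-1, keepdims=True)
    direct = x - m - np.log(np.exp(x - m).sum(axis=-1, keepdims=True))

    e = np.exp(x - m)
    via_softmax = np.log(e / e.sum(axis=-1, keepdims=True))
    np.testing.assert_allclose(direct, via_softmax, rtol=1e-5)
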