# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Binary Neural Network (BNN) Operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import te
from .. import tag
from ..utils import simplify, get_const_int
def binarize_pack(data, axis=None, name="PackedInput"):
    """Binarization and bit-packing along a certain axis.

    Each group of 32 consecutive elements along ``axis`` is reduced to a
    single uint32 word: bit ``b`` of the word is 1 iff the corresponding
    input element is >= 0 (the sign bit of the binarized value).

    Parameters
    ----------
    data : tvm.te.Tensor
        n-D input, can be any layout.

    axis : None or int
        The axis along which to do binarization and bit-packing,
        default is the last axis.

    name : str, optional
        The name prefix operators generate.

    Returns
    -------
    output : tvm.te.Tensor
        n-D, the same layout as input, dtype is uint32.
    """
    ishape = data.shape
    if axis is None:
        axis = len(ishape) - 1
    # The packed axis must hold a whole number of 32-bit words.
    assert get_const_int(ishape[axis]) % 32 == 0
    n = len(ishape)
    oshape = tuple(simplify(ishape[i] // 32) if i == axis else ishape[i] for i in range(n))

    def _binarize_pack(*indices):
        # Map the packed output index back to the first of the 32 input
        # elements it covers along `axis`.
        start_idx = [indices[i] * 32 if i == axis else indices[i] for i in range(n)]
        packed = tvm.tir.const(0, "uint32")
        for j in range(32):
            idx = [start_idx[i] + j if i == axis else start_idx[i] for i in range(n)]
            # 1 if the element is non-negative, else 0.
            sign = (data(*idx) >= 0).astype("uint32")
            packed = packed | sign
            if j == 31:
                # Last bit: return without the trailing shift so the first
                # element lands in the most-significant bit.
                return packed
            packed = packed << 1
        # Fixed typo in the original message ("not resach"); this guard is
        # unreachable because the loop always returns at j == 31.
        raise RuntimeError("unreachable: loop must return at j == 31")

    return te.compute(oshape, _binarize_pack, name=name, tag="binarize_pack")
def binary_dense(data, weight):
    """Binary matrix multiplication using xor and bit-count.

    Parameters
    ----------
    data : tvm.te.Tensor
        2-D with shape [batch, in_dim], dtype is uint32.

    weight : tvm.te.Tensor
        2-D with shape [out_dim, in_dim], dtype is uint32.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [batch, out_dim], dtype is float32.
    """
    assert (
        data.dtype == "uint32" and weight.dtype == "uint32"
    ), "dtype of data and weight should be uint32"
    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim binary dense"

    batch, in_dim = data.shape
    out_dim, _ = weight.shape
    rk = te.reduce_axis((0, in_dim), name="k")

    def _mismatch_count(i, j):
        # Number of differing bits between packed row i of data and
        # packed row j of weight, summed over the packed reduction axis.
        return te.sum(tvm.tir.popcount(data[i, rk] ^ weight[j, rk]), axis=rk)

    matmul = te.compute((batch, out_dim), _mismatch_count, tag="binary_dense")

    def _rescale(i, j):
        # Turn the mismatch count over 32*in_dim bits into the
        # equivalent +1/-1 dot product: matches - mismatches.
        return 32 * in_dim - 2.0 * matmul(i, j)

    return te.compute((batch, out_dim), _rescale, tag=tag.ELEMWISE)