# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Tensor class for computation declaration."""# pylint: disable=invalid-nameimporttvm._ffifromtvm.runtimeimportObject,ObjectGeneric,convert_to_objectfromtvm.tirimportexpras_expr,DataProducerfrom.import_ffi_api
class TensorSlice(ObjectGeneric, _expr.ExprOp):
    """Auxiliary data structure that enables slicing syntax on a tensor.

    Each ``[]`` application appends its indices to the slice; calling
    :py:meth:`asobject` turns the accumulated indices into an element
    access by forwarding them to the tensor's ``__call__``.
    """

    def __init__(self, tensor, indices):
        # A single (non-tuple) index is normalized to a 1-tuple so that
        # successive ``[]`` applications can be concatenated uniformly.
        if not isinstance(indices, tuple):
            indices = (indices,)
        self.tensor = tensor
        self.indices = indices

    def __getitem__(self, indices):
        """Return a new slice with ``indices`` appended to the current ones.

        Parameters
        ----------
        indices : object or tuple
            Additional index (or tuple of indices) to append.

        Returns
        -------
        slice : TensorSlice
            A fresh slice over the same tensor; ``self`` is not modified.
        """
        if not isinstance(indices, tuple):
            indices = (indices,)
        return TensorSlice(self.tensor, self.indices + indices)

    def asobject(self):
        """Convert the slice to an object.

        Calls the underlying tensor with all accumulated indices, i.e.
        ``self.tensor(*self.indices)``.
        """
        return self.tensor.__call__(*self.indices)

    @property
    def dtype(self):
        """Data type of the tensor."""
        return self.tensor.dtype
@tvm._ffi.register_object
class Tensor(DataProducer, _expr.ExprOp):
    """Tensor object; to construct one, see function.Tensor.

    ``t(i, j, ...)`` builds a ``ProducerLoad`` expression directly, while
    ``t[i][j]...`` builds a :py:class:`TensorSlice` that accumulates the
    indices and later forwards them to ``__call__``.
    """

    def __call__(self, *indices):
        """Build a load of the element at ``indices``.

        Parameters
        ----------
        *indices
            Exactly one index per dimension. They are converted with
            ``convert_to_object`` before use (presumably PrimExpr / int
            values -- TODO confirm accepted types).

        Returns
        -------
        load : _expr.ProducerLoad
            The element-access expression.

        Raises
        ------
        ValueError
            If the number of indices differs from ``self.ndim``.
        """
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError(
                f"Need to provide {ndim} index in tensor but {len(indices)} was provided"
            )
        indices = convert_to_object(indices)
        return _expr.ProducerLoad(self, indices)

    def __getitem__(self, indices):
        """Start a slice over this tensor; see :py:class:`TensorSlice`."""
        return TensorSlice(self, indices)

    def __hash__(self):
        # Defined explicitly because overriding __eq__ would otherwise make
        # instances unhashable; hashing is delegated to the FFI.
        return _ffi_api.TensorHash(self)

    def __eq__(self, other):
        """Compare this tensor with ``other``.

        Returns
        -------
        result
            * an ``EqualOp`` expression if ``other`` is a non-Tensor ``ExprOp``;
            * ``False`` for any other non-Tensor value;
            * otherwise the result of the FFI ``TensorEqual``.

        Raises
        ------
        ValueError
            If both operands are rank-0 tensors, where ``==`` is ambiguous
            between content equivalence and reference identity.
        """
        if not isinstance(other, Tensor):
            if isinstance(other, _expr.ExprOp):
                return _expr.EqualOp(self, other)
            return False
        if self.ndim == 0 and other.ndim == 0:
            raise ValueError(
                "Equal == comparison among rank-0 tensor is ambiguous, "
                "use Tensor.equal for content expression equivalence, "
                "use Tensor.same_as for exact reference comparison"
            )
        return _ffi_api.TensorEqual(self, other)

    @property
    def ndim(self):
        """Dimension of the tensor (the length of ``shape``)."""
        return len(self.shape)

    @property
    def axis(self):
        """Axis of the tensor."""
        return self.__getattr__("axis")

    @property
    def op(self):
        """The corresponding :py:class:`Operation`."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """The output value index the tensor corresponds to."""
        return self.__getattr__("value_index")

    @property
    def shape(self):
        """The output shape of the tensor."""
        return self.__getattr__("shape")

    @property
    def name(self):
        """Name of the tensor.

        The producing op's name when that op has a single output, otherwise
        ``"<op name>.v<value_index>"``.
        """
        op = self.op
        if op.num_outputs == 1:
            return op.name
        return f"{op.name}.v{self.value_index}"
class Operation(Object):
    """Represent an operation that generates a tensor"""

    def output(self, index):
        """Get the index-th output of the operation.

        Parameters
        ----------
        index : int
            Position of the output to fetch.

        Returns
        -------
        out : Tensor
            The index-th output.
        """
        return _ffi_api.OpGetOutput(self, index)

    @property
    def num_outputs(self):
        """Number of outputs from this op."""
        return _ffi_api.OpNumOutputs(self)

    @property
    def input_tensors(self):
        """List of input tensors to this op."""
        return _ffi_api.OpInputTensors(self)
@tvm._ffi.register_object
class BaseComputeOp(Operation):
    """Compute operation.

    Exposes the iteration and reduction axes stored on the underlying
    FFI node.
    """

    @property
    def axis(self):
        """IterVar axes of the op; defined when it is a ComputeOp."""
        field_name = "axis"
        return self.__getattr__(field_name)

    @property
    def reduce_axis(self):
        """Reduction axes of the op; only defined when it is a ComputeOp."""
        field_name = "reduce_axis"
        return self.__getattr__(field_name)
@tvm._ffi.register_object
class ScanOp(Operation):
    """Scan operation."""

    @property
    def scan_axis(self):
        """The axis the scan iterates over; only defined for a ScanOp."""
        field_name = "scan_axis"
        return self.__getattr__(field_name)