# Source code for tvm.relax.struct_info

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""The struct info nodes of the Relax language."""
from typing import List, Optional, Union

import tvm._ffi
import tvm

from tvm.ir import Span, EnvFunc, Array, VDevice
from tvm.tir import PrimExpr
from tvm.runtime import DataType
from .expr import StructInfo, Expr, ShapeExpr

from . import _ffi_api, ty, expr


@tvm._ffi.register_object("relax.ObjectStructInfo")
class ObjectStructInfo(StructInfo):
    """StructInfo describing an arbitrary runtime Object.

    Parameters
    ----------
    span : Span
        The source span attached to this struct info, if any.
    """

    def __init__(self, span: Span = None) -> None:
        ctor = _ffi_api.ObjectStructInfo  # type: ignore
        self.__init_handle_by_constructor__(ctor, span)  # type: ignore
@tvm._ffi.register_object("relax.PrimStructInfo")
class PrimStructInfo(StructInfo):
    """StructInfo of a primitive POD value.

    Parameters
    ----------
    dtype : Optional[Union[str, DataType]]
        The data type of the primitive value. It may be omitted when
        ``value`` is provided. A Python ``int`` value given without a
        dtype is materialized as an ``int64`` ``IntImm``.

    value : Optional[Union[int, float, PrimExpr]]
        A known expression for the primitive value, if any. When both
        ``dtype`` and a Python ``int``/``float`` are given, the scalar is
        converted with ``tvm.tir.const(value, dtype)``.

    span : Span
        The source span attached to this struct info, if any.

    Raises
    ------
    TypeError
        If a ``PrimExpr``/``int``/``float`` is passed as ``dtype``, or if
        neither ``dtype`` nor ``value`` is provided.
    AssertionError
        If ``value`` is a ``PrimExpr`` whose dtype differs from ``dtype``.
    """

    value: Optional[PrimExpr]
    dtype: str

    def __init__(
        self,
        dtype: Optional[Union[str, DataType]] = None,
        value: Optional[Union[int, float, PrimExpr]] = None,
        span: Span = None,
    ) -> None:
        # Guard against incorrect usage.  For backwards compatibility,
        # the dtype and value are in the opposite order from most
        # usages.  While PrimStructInfo could take a single positional
        # argument and check the type, this would require an API
        # difference from TVMScript's PrimProxy, which cannot.
        # (PrimProxy uses string arguments for datatype, and also for
        # inline variable definitions when used in a function
        # signature, and requires separate arguments to distinguish
        # the two cases.)
        if isinstance(dtype, (PrimExpr, int, float)):
            raise TypeError(
                "The first positional argument of PrimStructInfo must be the datatype, "
                f"but received {type(dtype)}. "
                "The value can be specified as a keyword argument "
                "without needing to specify the dtype: "
                "PrimStructInfo(value=arg)."
            )

        if dtype is None and value is None:
            raise TypeError(
                "PrimStructInfo.__init__ missing required argument. "
                "Must provide either 'dtype' or 'value'"
            )

        if dtype is not None:
            if isinstance(value, PrimExpr):
                # The message parts that reference locals must be f-strings,
                # otherwise the placeholders are emitted literally.
                assert value.dtype == dtype, (
                    "When providing both 'value' and 'dtype' to PrimStructInfo.__init__, "
                    "they must be consistent with each other. "
                    f"However, the value {value} has dtype {value.dtype}, "
                    f"but the specified dtype was {dtype}."
                )
            elif isinstance(value, (int, float)):
                # Materialize Python scalars as TIR constants of the requested dtype.
                value = tvm.tir.const(value, dtype)

        # Use relax's default integer type if not otherwise specified.
        # (Only reachable when dtype is None: the branch above already
        # converted Python ints when a dtype was given.)
        # NOTE(review): a Python float given without a dtype is forwarded
        # unchanged to the FFI constructor — confirm the intended dtype.
        if isinstance(value, int):
            value = tvm.tir.IntImm("int64", value)

        if value is None:
            self.__init_handle_by_constructor__(
                _ffi_api.PrimStructInfoFromDtype, dtype, span
            )  # type: ignore
        else:
            self.__init_handle_by_constructor__(
                _ffi_api.PrimStructInfoFromValue, value, span
            )  # type: ignore
@tvm._ffi.register_object("relax.ShapeStructInfo")
class ShapeStructInfo(StructInfo):
    """StructInfo describing a shape value.

    Parameters
    ----------
    values : Optional[List[PrimExpr]]
        The symbolic shape values, when they are known.

    ndim : int
        The number of dimensions of the shape. The default ``-1`` is
        forwarded unchanged to the FFI constructor (presumably meaning
        "unknown rank" — confirm against the C++ constructor).

    span : Span
        The source span attached to this struct info, if any.

    Note
    ----
    Do not specify ``values`` and ``ndim`` at the same time.
    """

    values: Optional[List[PrimExpr]]
    ndim: int
    span: Span

    def __init__(
        self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None
    ) -> None:
        ctor_args = (values, ndim, span)
        self.__init_handle_by_constructor__(_ffi_api.ShapeStructInfo, *ctor_args)  # type: ignore
@tvm._ffi.register_object("relax.TensorStructInfo")
class TensorStructInfo(StructInfo):
    """StructInfo describing a Tensor value.

    Parameters
    ----------
    shape : Union[Optional[Expr], List[PrimExpr]]
        The shape expression. A ``list``, ``tuple`` or ``Array`` of
        dimensions is wrapped into a ``ShapeExpr`` before construction.

    dtype : str
        The element data type (``"float32"`` by default).

    vdevice : Union[Optional[VDevice], str]
        The virtual device. A ``str`` is forwarded unchanged to the FFI
        constructor (NOTE(review): confirm the constructor accepts it).

    ndim : int
        The number of dimensions of the tensor.

    span : Span
        The source span attached to this struct info, if any.

    Note
    ----
    Do not specify ``shape`` and ``ndim`` at the same time.
    """

    shape: Optional[Expr]
    dtype: str
    vdevice: Optional[VDevice]
    ndim: int
    span: Span

    def __init__(
        self,
        shape: Union[Optional[Expr], List[PrimExpr]] = None,
        dtype: str = "float32",
        vdevice: Union[Optional[VDevice], str] = None,
        ndim: int = -1,
        span: Span = None,
    ) -> None:
        # Plain sequences of dimensions become a ShapeExpr; anything else
        # (an Expr or None) is forwarded as-is.
        shape_expr = ShapeExpr(shape) if isinstance(shape, (list, tuple, Array)) else shape
        # Note the FFI constructor takes ``ndim`` before ``vdevice``.
        self.__init_handle_by_constructor__(
            _ffi_api.TensorStructInfo,  # type: ignore
            shape_expr,
            dtype,
            ndim,
            vdevice,
            span,
        )
@tvm._ffi.register_object("relax.TupleStructInfo")
class TupleStructInfo(StructInfo):
    """StructInfo describing a Tuple value.

    Parameters
    ----------
    fields : List[StructInfo]
        The struct info of each tuple element, in order.

    span : Span
        The source span attached to this struct info, if any.
    """

    fields: List[StructInfo]
    span: Span

    def __init__(self, fields: List[StructInfo], span: Span = None) -> None:
        constructor = _ffi_api.TupleStructInfo  # type: ignore
        self.__init_handle_by_constructor__(constructor, fields, span)  # type: ignore
@tvm._ffi.register_object("relax.FuncStructInfo")
class FuncStructInfo(StructInfo):
    """StructInfo of a function value.

    Parameters
    ----------
    params: List[StructInfo]
        The struct info of the function parameters.

    ret: StructInfo
        The struct info of the return value.

    purity: bool
        Whether the function is pure (has no visible side effects).
        Note: We consider a function to be pure only if it is pure on all inputs.
        If a function can have visible side effects only in some cases,
        we still consider it impure.

    span: Span
        The source span attached to this struct info, if any.
    """

    params: Optional[List[StructInfo]]
    ret: StructInfo
    derive_func: Optional[EnvFunc]
    purity: bool
    span: Span

    def __init__(
        self, params: List[StructInfo], ret: StructInfo, purity: bool = True, span: Span = None
    ) -> None:
        self.__init_handle_by_constructor__(
            _ffi_api.FuncStructInfo, params, ret, purity, span  # type: ignore
        )

    @staticmethod
    def opaque_func(
        *,
        ret: Optional[StructInfo] = None,
        derive_func: Optional[Union[str, EnvFunc]] = None,
        purity: bool = False,
        span: Span = None,
    ) -> "FuncStructInfo":
        """
        Create an opaque FuncStructInfo.

        The opaque function takes either a ret
        that specifies the struct info of the return value
        or a derive_func that provides a customized derivation rule.

        Parameters
        ----------
        ret: Optional[StructInfo]
           The struct info of the function return value.

        derive_func: Optional[Union[str, EnvFunc]]
           The environment function used for derivation. A ``str`` is
           treated as the name of a registered global function and is
           resolved with ``tvm.ir.EnvFunc.get``.

        purity: bool
           Whether the function is pure (false by default, as most opaque functions are not pure)

        span: Optional[Span]
           Optional span information of the ast.

        Returns
        -------
        info: FuncStructInfo

        Note
        ----
        We cannot specify ret and derive_func simultaneously.
        """
        if isinstance(derive_func, str):
            # Resolve the global function the caller actually named,
            # rather than a fixed, unrelated derivation function.
            derive_func = tvm.ir.EnvFunc.get(derive_func)
        return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span)  # type: ignore