tvm.relax.op.distributed.distributed 源代码

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""Operators for distributed Relax.
"""
from typing import Union, List, Tuple, Optional

from tvm.relax.distributed.struct_info import DeviceMesh, Placement
from tvm.ir import PrimExpr
from tvm.relax.utils import args_converter
from tvm.relax.distributed import DTensorStructInfo
from ...expr import Tuple as RxTuple
from . import _ffi_api
from ...expr import Expr, ShapeExpr, Call, GlobalVar


[文档] def annotate_sharding(input: Expr, device_mesh: DeviceMesh, placement: Placement) -> Expr: """Annotate sharding plan for tensor Parameters ---------- input : relax.Expr The input tensor. device_mesh: DeviceMesh The device mesh of the sharding plan placement: Placement The placement of the sharding plan Returns ------- result : relax.Expr The tensor unmodified. """ return _ffi_api.annotate_sharding(input, device_mesh, placement) # type: ignore
[文档] def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) -> Expr: """Redistribute tensor Parameters ---------- input : relax.Expr The input tensor. device_mesh: DeviceMesh The device mesh after redistribution placement: Placement The placement after redistribution Returns ------- result : relax.Expr The tensor after redistribution. """ return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore
[文档] @args_converter.auto def call_tir_local_view( gvar: GlobalVar, args: Expr, out_sinfo: Union[DTensorStructInfo, List[DTensorStructInfo]], tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, ) -> Call: """ Call a tir.prim_func and return the output. The prim_func should be a worker-local function that is actually executed on each worker, instead of the unpartitioned function. The output of this operator is DTensor or a tuple of DTensors. Parameters ---------- gvar : GlobalVar The GlobalVar referring to a tir PrimFunc. args : Expr The input arguments. out_sinfo : Union[DTensorStructInfo, List[DTensorStructInfo]] The structure info of the call_tir output. It should be a single or a list of DTensorStructInfo. Each one denotes the structure info of a returned tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns ------- ret: Call A call node for the call_tir_local_view operator. """ if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] if isinstance(tir_vars, (list, tuple)): tir_vars = ShapeExpr(tir_vars) return _ffi_api.call_tir_local_view(gvar, args, out_sinfo, tir_vars) # type: ignore
[文档] def redistribute_replica_to_shard(input: Expr, num_workers: int, axis: int) -> Expr: """Slice tensor into several parts along one axis, and each worker takes one part. input.struct_info.shape[axis] % num_workers == 0 is required. Each worker must have an identical copy of the input. This is a specialized version of redistribute op. Parameters ---------- input : relax.Expr The buffer to be sliced into equal parts. num_worker : int The number of workers, i.e. the number of parts the given buffer should be sliced into. axis : int The axis of the tensor to be sliced. Returns ------- result : relax.Expr Sliced Tensor kept by each device. """ return _ffi_api.redistribute_replica_to_shard(input, num_workers, axis)