# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The build utils in python."""
from typing import Union, Optional, Dict
import enum

import tvm
from tvm import ir
from tvm.runtime import ndarray
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.target import Target


def split_host_device_mods(mod):
    """Split an IRModule into host and device modules.

    A function is treated as a device kernel when its ``calling_conv``
    attribute equals ``CallConv.kDeviceKernelLaunch``; every other function
    (including functions with no ``calling_conv`` attribute) goes to the host
    module. Device kernels are grouped by their ``target`` attribute.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module to split

    Returns
    -------
    host_mod : tvm.IRModule
        The module containing host functions
    device_mod_dict : Dict[Target, tvm.IRModule]
        A dict mapping targets to device modules. A device function without
        a ``target`` attribute is grouped under the key ``None``. Each device
        module carries the attributes of the filtered device module.
    """

    class CallConv(enum.IntEnum):
        """Enum representing different calling conventions.
        Corresponds to the C++ tvm::ir::CallingConv enum.
        """

        kDefault = 0
        kCPackedFunc = 1
        kDeviceKernelLaunch = 2

    def _is_device_kernel(func):
        # Missing "calling_conv" means the default convention, i.e. host code.
        conv = func.attrs.get("calling_conv", CallConv.kDefault)
        return int(conv) == int(CallConv.kDeviceKernelLaunch)

    host_mod = tvm.tir.transform.Filter(lambda f: not _is_device_kernel(f))(mod)
    device_mod = tvm.tir.transform.Filter(_is_device_kernel)(mod)

    # Group the device kernels by the target they are bound to.
    funcs_by_target = {}
    for gv, func in device_mod.functions.items():
        funcs_by_target.setdefault(func.attrs.get("target", None), {})[gv] = func

    device_mod_dict = {
        dev_target: tvm.IRModule(funcs, attrs=device_mod.attrs)
        for dev_target, funcs in funcs_by_target.items()
    }
    return host_mod, device_mod_dict


def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module:
    """Build a runtime module from an IRModule and a Target.

    The code generator is looked up as the global function
    ``"target.build." + target.kind.name`` and called as ``bf(mod, target)``.
    When the current PassContext sets ``tir.disable_assert``, assertions are
    stripped with ``SkipAssert`` before code generation.

    Parameters
    ----------
    mod : IRModule
        The lowered TIR module to generate code for.
    target : Target
        The target whose code generator is used.

    Returns
    -------
    tvm.runtime.Module
        The module produced by the target's code generator.

    Raises
    ------
    ValueError
        If no code generator is registered for ``target.kind.name``.
    """
    if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False):
        mod = tvm.tir.transform.SkipAssert()(mod)
    build_f_name = "target.build." + target.kind.name
    # Without allow_missing=True the lookup raises on its own and the explicit,
    # more descriptive error below would be unreachable.
    bf = tvm.get_global_func(build_f_name, allow_missing=True)
    if bf is None:
        raise ValueError(f"{build_f_name} is not enabled")
    return bf(mod, target)


def tir_to_runtime(
    host_mod: IRModule, device_mod_dict: Dict[Target, IRModule], target_host: Target
):
    """Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module.

    Every non-empty device module is built with its own target. The host
    module is built with ``target_host``, and the device runtime modules are
    imported into the resulting host module.

    Parameters
    ----------
    host_mod : IRModule
        The module containing host functions.
    device_mod_dict : Dict[Target, IRModule]
        Device modules keyed by the target they are compiled for.
    target_host : Target
        The target used to build the host module.

    Returns
    -------
    tvm.runtime.Module
        The host runtime module with all device modules imported.
    """
    # Rebuild the host module in a fresh IRModule that carries host_mod's
    # attributes; necessary for
    # tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib
    mhost_all = ir.IRModule({}, attrs=host_mod.attrs)
    mhost_all.update(host_mod)

    # Empty device modules have nothing to generate code for.
    device_modules = [
        codegen_build(device_mod, dev_target)
        for dev_target, device_mod in device_mod_dict.items()
        if len(device_mod.functions) != 0
    ]
    mhost = codegen_build(mhost_all, target_host)
    for dev_mod in device_modules:
        if dev_mod is not None:
            mhost.import_module(dev_mod)
    return mhost
def build(
    mod: Union[PrimFunc, IRModule],
    target: Optional[Union[str, Target]] = None,
    pipeline: Union[None, str, tvm.transform.Pass] = "default",
):
    """Build a function with a signature, generating code for devices
    coupled with target information.

    Parameters
    ----------
    mod : Union[PrimFunc, IRModule]
        The input to be built. A PrimFunc is first wrapped into an IRModule.

    target : Optional[Union[str, Target]]
        The target for compilation. When None, the current target scope
        (``Target.current()``) is used. Functions without a ``target``
        attribute are bound to that target, or to ``"llvm"`` when no target
        is in scope. The pipeline/host target falls back to the first
        function-level ``target`` attribute found in the module.

    pipeline : Union[None, str, tvm.transform.Pass]
        The pipeline to use for compilation. A string is resolved with
        ``tvm.tir.get_tir_pipeline``, a Pass is applied as-is, and None
        selects ``tvm.tir.get_default_tir_pipeline`` for the resolved target.

    Returns
    -------
    tvm.runtime.Module
        A module combining both host and device code.

    Raises
    ------
    AssertionError
        If ``mod`` is neither a PrimFunc nor an IRModule.
    """
    # Convert PrimFunc to IRModule
    if isinstance(mod, PrimFunc):
        mod = tvm.IRModule.from_expr(mod)
    else:
        assert isinstance(mod, tvm.IRModule), f"Expected PrimFunc or IRModule, got {type(mod)}"

    # An explicit target wins; otherwise use the target from the current scope.
    if target is None:
        target = Target.current()

    # Step 0: Determine the target used to bind PrimFuncs that have no
    # "target" attribute; "llvm" is the last-resort default.
    target_to_bind = Target.canon_target("llvm" if target is None else target)

    # Step 1: Determine the target to search for tir pipeline. Without an
    # explicit or scoped target, use the first function-level "target" attr.
    if target is None:
        for func in mod.functions.values():
            f_target = func.attrs.get("target", None)
            if f_target is not None:
                target = f_target
                break
    if target is not None:
        target = Target.canon_target(target)

    # Step 2: Determine the host target
    target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
    if target is not None:
        if target.host is not None:
            target_host = target.host
        elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type:
            # A CPU-class target can act as its own host.
            target_host = target
    target_host = Target.canon_target(target_host)
    target_to_bind = target_to_bind.with_host(target_host)

    # Step 3: Bind the target to the input module
    mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)

    # Step 4: Apply the tir pipeline
    if pipeline is not None:
        # custom pipeline
        if isinstance(pipeline, str):
            pipeline = tvm.tir.get_tir_pipeline(pipeline)
    else:
        # default pipeline depends on the target
        pipeline = tvm.tir.get_default_tir_pipeline(target)
    mod = pipeline(mod)

    # Step 5: Get host and device modules
    host_mod, device_mod_dict = split_host_device_mods(mod)

    # Step 6: Apply finalization passes
    host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod)
    device_mod_dict = {
        dev_target: tvm.tir.pipeline.finalize_device_passes()(device_mod)
        for dev_target, device_mod in device_mod_dict.items()
    }

    # Convert TIR IRModules to runtime Module by calling target.build
    return tir_to_runtime(host_mod, device_mod_dict, target_host)