tvm.tir.build 源代码

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
"""The build utils in python."""
from typing import Dict, Optional, Tuple, Union

import tvm
from tvm import ir
from tvm.ir.module import IRModule
from tvm.target import Target
from tvm.tir import PrimFunc


def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
    """Split an IRModule into host and device modules.

    This function takes an IRModule containing functions with different target attributes
    and separates them into host (CPU) and device (GPU/accelerator) modules. Functions
    are categorized based on their target attribute in func_attr.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module to split.
        The module should contain functions with target attributes in their func_attr.
        Functions with "cpu" in their target string are considered host functions,
        while others are considered device functions.

    Returns
    -------
    host_mod : tvm.IRModule
        The module containing host functions (CPU-targeted functions)
    device_mod_dict : Dict[Target, tvm.IRModule]
        A dict mapping targets to device modules. Each device module contains
        functions targeting the same device (e.g., CUDA GPU, OpenCL, etc.)

        Examples
    --------
    Given an IRModule with the following functions:

    .. code-block:: python

        @I.ir_module
        class Module:
            @T.prim_func(private=True)
            def add(a: T.int32, b: T.int32) -> T.int32:
                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
                                                "kind": "cuda", "max_num_threads": 1024}))
                return a + b

            @T.prim_func(private=True)
            def add_host(a: T.int32, b: T.int32) -> T.int32:
                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}))
                return a + b

            @T.prim_func
            def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32):
                T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"],
                                                "kind": "cuda"}),
                            "calling_conv": 2,  # kDeviceKernelLaunch for device kernels
                            "tir.is_global_func": True})
                # ... kernel implementation

            @T.prim_func
            def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle):
                T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}),
                            "calling_conv": 1,  # kCPackedFunc for entry functions
                            "tir.is_entry_func": True})
                # ... main function implementation

    The function will return:
    - host_mod: Contains `add_host` and `main` functions (CPU targets)
    - device_mod_dict: Contains a CUDA module with `add` and `main_kernel` functions

    Notes
    -----
    - Functions are categorized based on string matching of their target attribute
    - Functions with "cpu" in the target string are considered host functions
    - Device functions are grouped by their target to create separate modules
    - The function uses string-based target matching due to target hash limitations
    - All functions must have a `calling_conv` attribute in their func_attr:
        - Private helper functions (private=True): use `calling_conv: 0` (kDefault, by default)
        - Public entry functions: use `calling_conv: 1` (kCPackedFunc)
        - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
    """

    def is_host_func(f):
        target = f.attrs.get("target", tvm.target.Target("llvm"))
        return str(target.kind) in ["llvm", "c"]

    host_mod = tvm.tir.transform.Filter(is_host_func)(mod)
    device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod)
    # TODO(syfeng): Here we use str as key since target hash is not correct
    target_str2target = {}
    device_func_dict = {}
    device_mod_dict: Dict[Target, IRModule] = {}
    for gv, func in device_mod.functions.items():
        target = func.attrs.get("target", None)
        target_str = str(target) if target is not None else ""
        target_str2target[target_str] = target  # This might be overridden by the last one
        device_func_dict.setdefault(target_str, dict()).update({gv: func})
    for target_str in target_str2target.keys():
        target = target_str2target[target_str]
        device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
    return host_mod, device_mod_dict


def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module:
    """Build a runtime module from an IRModule and a Target."""
    if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False):
        mod = tvm.tir.transform.SkipAssert()(mod)
    build_f_name = "target.build." + target.kind.name
    bf = tvm.get_global_func(build_f_name)
    if bf is None:
        raise ValueError(f"{build_f_name} is not enabled")
    return bf(mod, target)


def tir_to_runtime(
    host_mod: IRModule, device_mod_dict: Dict[Target, IRModule], target_host: Target
):
    """Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module."""

    # Get the first module to get the attributes
    # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib
    mhost_all = ir.IRModule({}, attrs=host_mod.attrs)

    mhost_all.update(host_mod)
    device_modules = []
    for target, device_mod in device_mod_dict.items():
        if len(device_mod.functions) != 0:
            device_modules.append(codegen_build(device_mod, target))

    mhost = codegen_build(mhost_all, target_host)
    for dev_mod in device_modules:
        if dev_mod is not None:
            mhost.import_module(dev_mod)
    return mhost


[文档] def build( mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, pipeline: Union[None, str, tvm.transform.Pass] = "default", ): """Build a function with a signature, generating code for devices coupled with target information. Parameters ---------- mod : Union[PrimFunc, IRModule] The input to be built. target : Optional[Union[str, Target]] The target for compilation. pipeline : Union[None, str, tvm.transform.Pass] The pipeline to use for compilation. Returns ------- tvm.runtime.Module A module combining both host and device code. """ # Convert PrimFunc to IRModule if isinstance(mod, PrimFunc): mod = tvm.IRModule.from_expr(mod) else: assert isinstance(mod, tvm.IRModule) # Step 0: Determine the target in environment # It's used to bind the PrimFunc without target attr to serve as a default target target_to_bind = Target.current() if target is None else target if target_to_bind is None: target_to_bind = "llvm" assert target_to_bind is not None target_to_bind = Target.canon_target(target_to_bind) # Step 1: Determine the target to search for tir pipeline target = Target.current() if target is None else target if target is None: for func in mod.functions.values(): f_target = func.attrs.get("target", None) if f_target is not None: target = f_target break if target is not None: target = Target.canon_target(target) # Step 2: Determine the host target target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" if target is not None: if target.host is not None: target_host = target.host elif ( tvm.device(target.kind.name, 0).dlpack_device_type() == tvm.cpu(0).dlpack_device_type() ): target_host = target target_host = Target.canon_target(target_host) target_to_bind = target_to_bind.with_host(target_host) # Step 3: Bind the target to the input module mod = tvm.tir.transform.BindTarget(target_to_bind)(mod) # Step 4: Apply the tir pipeline if pipeline is not None: # custom pipeline if isinstance(pipeline, str): pipeline = tvm.tir.get_tir_pipeline(pipeline) else: # default pipeline depends on the target pipeline = tvm.tir.get_default_tir_pipeline(target) mod = pipeline(mod) # Step 5: Get host and device modules host_mod, device_mod_dict = split_host_device_mods(mod) # Step 6: Apply finalization passes host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) device_mod_dict = { target: tvm.tir.pipeline.finalize_device_passes()(device_mod) for target, device_mod in device_mod_dict.items() } # Convert TIR IRModules to runtime Module by calling target.build return tir_to_runtime(host_mod, device_mod_dict, target_host)
tvm.register_global_func("tir.build", build)