generic_func()

generic_func()

# IPython magic: move to the parent directory (the new working directory is
# printed as the cell output). Presumably done so the local `set_env` helper
# module below is importable from there — TODO confirm.
%cd ..
# NOTE(review): `set_env` is a local helper; it presumably configures the
# environment / sys.path so the local TVM checkout is used — verify.
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
# `tvm.target.generic_func` turns `my_func` into a target-dispatched function;
# the body written here is the default used when no specialization matches.
@tvm.target.generic_func
def my_func(a):
    """Default (target-independent) implementation: add one."""
    return a + 1


# Specialization that `my_func` dispatches to under a "cuda" target.
@my_func.register("cuda")
def my_func_cuda(a):
    """CUDA-specific implementation: add two."""
    return a + 2


# Outside any target scope the default implementation runs -> prints 3.
print(my_func(2))

# Inside a CUDA target scope the registered specialization runs -> prints 4.
with tvm.target.cuda():
    print(my_func(2))
3
4
/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
  warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
[20:02:00] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
import json

import pytest
import tvm
import tvm.testing
from tvm.target import Target, arm_cpu, bifrost, cuda, intel_graphics, mali, rocm, vta


@tvm.target.generic_func
def mygeneric(data):
    """Default implementation of a target-generic test function.

    Used when the active target matches none of the keys registered below;
    returns ``data + 1``.
    """
    return data + 1


@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    """Specialization registered for the "cuda" and "gpu" keys; returns ``data + 2``."""
    return data + 2


@mygeneric.register("rocm")
def rocm_func(data):
    """Specialization registered for the "rocm" key; returns ``data + 3``."""
    return data + 3


# Given its own name so it no longer redefines (and shadows) `rocm_func`
# above. Dispatch is driven by the registered key, not by the Python name.
@mygeneric.register("cpu")
def cpu_func(data):
    """Specialization registered for the "cpu" key; returns ``data + 10``."""
    return data + 10

验证每种目标类型（外部代码生成目标除外）报告的设备类型与 `Device.STR2MASK` 中的映射一致：

# For every registered target kind, check that the device type the target
# reports matches the kind's entry in `Device.STR2MASK`.
str2mask = tvm._ffi.runtime_ctypes.Device.STR2MASK
all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]

for tgt in all_targets:
    # skip targets with hooks or otherwise intended to be used with external codegen
    relay_to_tir = tgt.get_kind_attr("RelayToTIR")
    tir_to_runtime = tgt.get_kind_attr("TIRToRuntime")
    is_external_codegen = tgt.get_kind_attr("is_external_codegen")
    if relay_to_tir is not None or tir_to_runtime is not None or is_external_codegen:
        continue

    kind_name = tgt.kind.name
    if kind_name not in str2mask:
        raise KeyError(f"Cannot find target kind: {kind_name} in Device.STR2MASK")

    actual = tgt.get_target_device_type()
    expected = str2mask[kind_name]
    # Name the offending kind: this loop covers every target kind, so a bare
    # assert would not reveal which one mismatched.
    assert actual == expected, (
        f"device type mismatch for target kind {kind_name!r}: "
        f"get_target_device_type() = {actual}, STR2MASK = {expected}"
    )
ROCm not detected, using default gfx900
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:186: Warning: Unable to detect CUDA version, default to "-mcpu=sm_50" instead
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5
def test_target_dispatch():
    """Check that `mygeneric` dispatches on the active target.

    Both the Python call and the packed function returned by
    `get_packed_func()` are checked. Each case pairs a zero-argument target
    factory with the value `mygeneric(1)` must produce inside that target's
    scope. Factories are invoked one at a time inside the loop, so targets are
    created in the same order as a hand-written sequence of `with` blocks.
    """
    cases = (
        (tvm.target.cuda, 3),  # "cuda" registration: 1 + 2
        (tvm.target.rocm, 4),  # "rocm" registration: 1 + 3
        (lambda: tvm.target.Target("cuda"), 3),
        (tvm.target.arm_cpu, 11),  # expected to hit the "cpu" registration: 1 + 10
        (lambda: tvm.target.Target("metal"), 3),  # presumably matched via the "gpu" key
    )
    for make_target, expected in cases:
        with make_target():
            assert mygeneric(1) == expected
            assert mygeneric.get_packed_func()(1) == expected

    # All `with` scopes have exited, so no target should remain current.
    assert tvm.target.Target.current() is None


test_target_dispatch()
ROCm not detected, using default gfx900
/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
  warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead