generic_func()
#
%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
# Wrap the function as a target-generic function; this body is the default.
@tvm.target.generic_func
def my_func(a):
    """Default implementation of the generic function: return ``a + 1``."""
    result = a + 1
    return result
# Register a specialization of my_func under the "cuda" target key.
@my_func.register("cuda")
def my_func_cuda(a):
    """Specialization of ``my_func`` registered for "cuda": return ``a + 2``."""
    result = a + 2
    return result
# With no target context active, the default my_func runs: prints 3.
default_result = my_func(2)
print(default_result)
# Inside a cuda target context, my_func_cuda is dispatched instead: prints 4.
with tvm.target.cuda():
    cuda_result = my_func(2)
    print(cuda_result)
3
4
/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
[20:02:00] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
import json
import pytest
import tvm
import tvm.testing
from tvm.target import Target, arm_cpu, bifrost, cuda, intel_graphics, mali, rocm, vta
@tvm.target.generic_func
def mygeneric(data):
    """Default generic implementation: return ``data + 1``.

    Specializations are attached with ``mygeneric.register(...)``.
    """
    result = data + 1
    return result
@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    """Specialization of ``mygeneric`` registered under the "cuda" and "gpu" keys.

    Returns ``data + 2``.
    """
    result = data + 2
    return result
@mygeneric.register("rocm")
def rocm_func(data):
    """Specialization of ``mygeneric`` registered under the "rocm" key.

    Returns ``data + 3``.
    """
    result = data + 3
    return result
@mygeneric.register("cpu")
def cpu_func(data):
    """Specialization of ``mygeneric`` registered under the "cpu" key.

    Returns ``data + 10``.
    """
    # Named cpu_func (it was previously a second ``rocm_func``) so that this
    # definition no longer rebinds the module-level name of the "rocm"
    # specialization. Dispatch is unaffected: registration is by key.
    return data + 10
所有目标设备类型的一致性验证:
# Build one Target per registered kind up front, then check each in turn.
all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]
for tgt in all_targets:
    # Skip targets with hooks or otherwise intended to be used with external codegen.
    relay_to_tir = tgt.get_kind_attr("RelayToTIR")
    tir_to_runtime = tgt.get_kind_attr("TIRToRuntime")
    is_external_codegen = tgt.get_kind_attr("is_external_codegen")
    has_hooks = relay_to_tir is not None or tir_to_runtime is not None
    if has_hooks or is_external_codegen:
        continue

    kind_name = tgt.kind.name
    device_masks = tvm._ffi.runtime_ctypes.Device.STR2MASK
    # Fail with an explicit message when the kind has no device-mask entry.
    if kind_name not in device_masks:
        raise KeyError("Cannot find target kind: %s in Device.STR2MASK" % kind_name)
    # The target's device type must agree with the mask registered for its kind.
    assert tgt.get_target_device_type() == device_masks[kind_name]
ROCm not detected, using default gfx900
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:186: Warning: Unable to detect CUDA version, default to "-mcpu=sm_50" instead
[20:04:10] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5
def test_target_dispatch():
    """Check that ``mygeneric`` dispatches according to the active target context.

    For each target, both the direct call and the packed-function call must
    return the value of the specialization expected for that target.
    """
    # (factory building the target context, value expected from mygeneric(1)).
    # Factories are lambdas so each target is constructed only when entered.
    # NOTE(review): arm_cpu -> 11 presumably resolves through the "cpu" key and
    # metal -> 3 through the "gpu" key; confirm against the targets' key lists.
    cases = (
        (lambda: tvm.target.cuda(), 3),
        (lambda: tvm.target.rocm(), 4),
        (lambda: tvm.target.Target("cuda"), 3),
        (lambda: tvm.target.arm_cpu(), 11),
        (lambda: tvm.target.Target("metal"), 3),
    )
    for make_target, expected in cases:
        with make_target():
            assert mygeneric(1) == expected
            assert mygeneric.get_packed_func()(1) == expected

    # No target context should remain active once every ``with`` has exited.
    assert tvm.target.Target.current() is None
# Run the dispatch checks immediately.
test_target_dispatch()
ROCm not detected, using default gfx900
/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py:446: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:224: Warning: Unable to detect ROCm version, assuming >= 3.5
[20:07:57] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead