解读 override_native_generic_func()
#
%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
@tvm.target.override_native_generic_func("test_target_temp_strategy")
def target_generic(data):
    """Default implementation of the "test_target_temp_strategy" generic function.

    Returns ``data + 1``. Implementations registered for specific target
    keys take over when a matching target is active.
    """
    incremented = data + 1
    return incremented
@target_generic.register(["cuda", "gpu"])
def target_cuda_func(data):
    """Implementation registered for the "cuda" and "gpu" target keys.

    Returns ``data + 2``.
    """
    shifted = data + 2
    return shifted
def temp_target_cuda_func(data):
    """Replacement implementation used for the temporary override.

    Returns ``data + 3``.
    """
    shifted = data + 3
    return shifted
class TempStrategy:
    """Context manager that temporarily overrides a native generic function.

    Constructing the object registers ``fstrategy`` for every key of
    ``target`` on the generic function called ``name``. Leaving the ``with``
    block re-registers, for each of those keys, the implementation that was
    active for ``target`` before the override was installed.

    Parameters
    ----------
    name
        Name of the native generic function, looked up with
        ``tvm.target.get_native_generic_func``.
    target
        Target specification handed to ``tvm.target.Target`` (for example
        ``"cuda"``); its keys receive the temporary override.
    fstrategy
        Callable registered as the temporary implementation.
    """

    def __init__(self, name, target, fstrategy):
        generic_fstrategy = tvm.target.get_native_generic_func(name)
        self.target = target
        self.name = name
        self.origin_func = {}
        with tvm.target.Target(target) as target_obj:
            # Capture the active implementation once, before any override is
            # registered. get_packed_func() resolves against the current
            # target, so querying it again after the first key has been
            # overridden would hand back ``fstrategy`` itself, and __exit__
            # would then "restore" the override on the remaining keys.
            original = generic_fstrategy.get_packed_func()
            for tgt_key in target_obj.keys:
                self.origin_func[tgt_key] = original
                generic_fstrategy.register(fstrategy, tgt_key, allow_override=True)

    def __enter__(self):
        """Return this object; the override is already installed by ``__init__``."""
        return self

    def __exit__(self, typ, value, traceback):
        """Re-register the saved implementation for every overridden key.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        generic_fstrategy = tvm.target.get_native_generic_func(self.name)
        # The saved mapping already holds every overridden key, so the target
        # does not need to be constructed a second time just to list them.
        for tgt_key, original in self.origin_func.items():
            generic_fstrategy.register(original, tgt_key, allow_override=True)
# Baseline: under a CUDA target the call resolves to target_cuda_func (1 + 2).
with tvm.target.Target("cuda"):
    assert target_generic(1) == 3

# While the override is active, temp_target_cuda_func handles the call (1 + 3).
with TempStrategy("test_target_temp_strategy", "cuda", temp_target_cuda_func), \
        tvm.target.Target("cuda"):
    assert target_generic(1) == 4

# After leaving the override, the original CUDA implementation is back.
with tvm.target.Target("cuda"):
    assert target_generic(1) == 3
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead