解读 override_native_generic_func()

解读 override_native_generic_func()

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/read
import tvm
@tvm.target.override_native_generic_func("test_target_temp_strategy")
def target_generic(data):
    """Default implementation of the "test_target_temp_strategy" generic function.

    The decorator registers this body as the generic (fallback) implementation.
    It returns ``data + 1``.
    """
    result = data + 1
    return result
@target_generic.register(["cuda", "gpu"])
def target_cuda_func(data):
    """Specialization registered for the "cuda" and "gpu" target keys; returns ``data + 2``."""
    increment = 2
    return data + increment
def temp_target_cuda_func(data):
    """Replacement strategy installed temporarily via TempStrategy; returns ``data + 3``."""
    offset = 3
    return data + offset
class TempStrategy:
    """Context manager that temporarily overrides a native generic function.

    On construction, ``fstrategy`` is registered (``allow_override=True``) on
    the native generic function ``name`` for every key of ``target``. On
    ``__exit__``, the function that was dispatched for ``target`` before the
    override is registered back under those same keys.

    NOTE: the override is applied in ``__init__`` rather than ``__enter__``,
    so an instance created outside a ``with`` statement is never restored.

    Parameters
    ----------
    name : str
        Name passed to ``tvm.target.get_native_generic_func``.
    target : str or tvm.target.Target
        Target whose keys receive the temporary strategy.
    fstrategy : callable
        Temporary function registered for each of the target's keys.
    """

    def __init__(self, name, target, fstrategy):
        generic_fstrategy = tvm.target.get_native_generic_func(name)
        self.target = target
        self.name = name
        # target key -> function to re-register for that key on exit.
        self.origin_func = {}
        with tvm.target.Target(target) as target_obj:
            # Snapshot BEFORE registering anything. get_packed_func() returns the
            # function selected for the *current* target (its first matching
            # key), not a per-key entry. Querying it inside the registration
            # loop therefore returned ``fstrategy`` itself for every key after
            # the first. Those keys were then "restored" to the temporary
            # strategy on exit.
            original = generic_fstrategy.get_packed_func()
            for tgt_key in target_obj.keys:
                self.origin_func[tgt_key] = original
            for tgt_key in target_obj.keys:
                generic_fstrategy.register(fstrategy, tgt_key, allow_override=True)

    def __enter__(self):
        """Return this instance; the override is already active from ``__init__``."""
        return self

    def __exit__(self, typ, value, traceback):
        """Re-register the snapshotted function for every key of ``self.target``.

        Returns None, so any exception raised inside the ``with`` body propagates.
        """
        generic_fstrategy = tvm.target.get_native_generic_func(self.name)
        with tvm.target.Target(self.target) as target_obj:
            for tgt_key in target_obj.keys:
                generic_fstrategy.register(
                    self.origin_func[tgt_key], tgt_key, allow_override=True
                )

def _expect_under_cuda(expected):
    """Dispatch ``target_generic(1)`` with a CUDA target in scope and check the result."""
    with tvm.target.Target("cuda"):
        assert target_generic(1) == expected


# Baseline: the "cuda"/"gpu" specialization (data + 2) is dispatched.
_expect_under_cuda(3)

# Inside TempStrategy, the strategy for the CUDA target's keys is temp_target_cuda_func (data + 3).
with TempStrategy("test_target_temp_strategy", "cuda", temp_target_cuda_func):
    _expect_under_cuda(4)

# After the TempStrategy context exits, the original specialization is dispatched again.
_expect_under_cuda(3)
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead
[20:11:25] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to "-arch=sm_50" instead