Source code for tvm.contrib.pickle_memoize
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Memoize result of function via pickle, used for cache testcases."""
# pylint: disable=broad-except,superfluous-parens
import atexit
import os
import pathlib
import sys
from decorator import decorate
from .._ffi.base import string_types
try:
import cPickle as pickle
except ImportError:
import pickle
def _get_global_cache_dir() -> pathlib.Path:
    """Return the directory in which memoization pickle files are stored.

    The base cache directory is taken from the ``XDG_CACHE_HOME`` environment
    variable. When that variable is unset or empty, ``~/.cache`` is used
    instead. The returned path is ``<base>/tvm/pkl_memoize_py<major>``, where
    ``<major>`` is the major version of the running Python interpreter.

    Returns
    -------
    cache_dir : pathlib.Path
        The directory for pickle cache files. It is not created here.
    """
    # An empty XDG_CACHE_HOME must be treated like an unset one; otherwise
    # pathlib.Path("") would silently resolve to the current working directory.
    xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
    if xdg_cache_home:
        cache_home = pathlib.Path(xdg_cache_home)
    else:
        cache_home = pathlib.Path.home().joinpath(".cache")
    return cache_home.joinpath("tvm", f"pkl_memoize_py{sys.version_info[0]}")
# Directory holding every memoization pickle file; computed once, when this
# module is imported.
GLOBAL_CACHE_DIR = _get_global_cache_dir()
[docs]
class Cache(object):
    """A cache object for result cache.

    The cached results live in a dict that is loaded lazily from a pickle
    file located at ``GLOBAL_CACHE_DIR / key`` and written back by ``save``.

    Parameters
    ----------
    key: str
        The file key to the function
    save_at_exit: bool
        Whether save the cache to file when the program exits
    """

    # Registry of Cache instances shared by the whole process, indexed by key.
    cache_by_key = {}

    def __init__(self, key, save_at_exit):
        self._cache = None  # Loaded on first access of the ``cache`` property.
        self.path = GLOBAL_CACHE_DIR.joinpath(key)
        self.dirty = False  # True while there are entries not yet written to disk.
        self.save_at_exit = save_at_exit

    @property
    def cache(self):
        """Return the cache, initializing on first use.

        On first access the pickle file at ``self.path`` is loaded if it
        exists. A missing file, or a file that cannot be unpickled, yields an
        empty dict so that the results are simply recomputed.

        Returns
        -------
        cache : dict
            The mapping from call keys to memoized results.
        """
        if self._cache is not None:
            return self._cache

        loaded = {}
        if self.path.exists():
            # NOTE: pickle.load can execute arbitrary code contained in the
            # file, so the cache directory must only hold trusted content.
            with self.path.open("rb") as cache_file:
                try:
                    loaded = pickle.load(cache_file)
                except (
                    pickle.UnpicklingError,
                    EOFError,
                    AttributeError,
                    ImportError,
                    IndexError,
                ):
                    # An empty or truncated file raises EOFError rather than
                    # UnpicklingError; a stale or corrupt file can raise the
                    # others. Any of them means the file is unusable.
                    loaded = {}

        self._cache = loaded
        return self._cache

    def save(self):
        """Write the cache to ``self.path`` if it has unsaved entries.

        Missing parent directories are created. Nothing is written when the
        cache is not dirty.
        """
        if not self.dirty:
            return
        self.path.parent.mkdir(parents=True, exist_ok=True)
        with self.path.open("wb") as out_file:
            pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL)
        # The file now matches memory; skip rewriting until a new entry arrives.
        self.dirty = False
[docs]
@atexit.register
def _atexit():
    """Save every registered cache that asked to be saved at exit.

    Registered with ``atexit``; walks ``Cache.cache_by_key`` and calls
    ``save`` on each entry whose ``save_at_exit`` flag is set.
    """
    for entry in Cache.cache_by_key.values():
        if not entry.save_at_exit:
            continue
        entry.save()
[docs]
def memoize(key, save_at_exit=False):
    """Build a decorator that memoizes a function's results through a pickle cache.

    Results are kept in the ``Cache`` registered under
    ``<key>.<function name>.pkl``. A result is looked up by the values
    captured in the function's closure together with the positional
    arguments of the call.

    Parameters
    ----------
    key: str
        The unique key to the file
    save_at_exit: bool
        Whether save the cache to file when the program exits

    Returns
    -------
    fmemoize : function
        The decorator function to perform memoization.
    """

    def _register(f):
        """Wrap ``f`` so that its results are served from the cache."""
        allow_types = (string_types, int, float, tuple)
        fkey = key + "." + f.__name__ + ".pkl"

        # Reuse the cache already registered for this file key, if any.
        try:
            store = Cache.cache_by_key[fkey]
        except KeyError:
            store = Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)

        # Values captured by the closure are part of the lookup key; the
        # leading count separates them from the call arguments.
        captured = tuple(cell.cell_contents for cell in (f.__closure__ or ()))
        prefix = (len(captured),) + captured

        def _memoized_f(func, *args, **kwargs):
            assert not kwargs, "Only allow positional call"
            lookup = prefix + args
            for item in lookup:
                # A tuple is validated element by element, anything else as is.
                members = item if isinstance(item, tuple) else (item,)
                for member in members:
                    assert isinstance(member, allow_types)

            table = store.cache
            if lookup not in table:
                table[lookup] = func(*args)
                store.dirty = True
            return table[lookup]

        return decorate(f, _memoized_f)

    return _register