# tvm.contrib.utils source code
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common system utilities"""
import atexit
import contextlib
import datetime
import os
import pathlib
import tempfile
import threading
import shutil
try:
import fcntl
except ImportError:
fcntl = None
class DirectoryCreatedPastAtExit(Exception):
    """Signals an attempt to build a TempDirectory after the atexit cleanup hook already ran."""

    pass
class TempDirectory(object):
    """Helper object to manage temp directory during testing.

    Automatically removes the directory when it goes out of scope, unless it
    was created in keep-for-debug mode (then it is left on disk under a
    per-process debug parent directory).
    """

    # When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory
    # tree.
    _KEEP_FOR_DEBUG = False

    # In debug mode, each tempdir is named after the sequence
    _NUM_TEMPDIR_CREATED = 0
    _NUM_TEMPDIR_CREATED_LOCK = threading.Lock()

    @classmethod
    def _increment_num_tempdir_created(cls):
        """Return the current debug-mode sequence number and advance the counter."""
        with cls._NUM_TEMPDIR_CREATED_LOCK:
            to_return = cls._NUM_TEMPDIR_CREATED
            cls._NUM_TEMPDIR_CREATED += 1
        return to_return

    _DEBUG_PARENT_DIR = None

    @classmethod
    def _get_debug_parent_dir(cls):
        """Return the timestamped parent dir for debug-mode tempdirs, creating it on first use."""
        if cls._DEBUG_PARENT_DIR is None:
            all_parents = f"{tempfile.gettempdir()}/tvm-debug-mode-tempdirs"
            # exist_ok avoids a race where another process creates the directory
            # between an existence check and the mkdir.
            os.makedirs(all_parents, exist_ok=True)
            cls._DEBUG_PARENT_DIR = tempfile.mkdtemp(
                prefix=datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S___"), dir=all_parents
            )
        return cls._DEBUG_PARENT_DIR

    # Paths of live, non-debug temp directories. Set to None once
    # remove_tempdirs() has run; after that, new instances are refused.
    TEMPDIRS = set()

    @classmethod
    def remove_tempdirs(cls):
        """Delete every tracked temp directory and disable further creation."""
        temp_dirs = getattr(cls, "TEMPDIRS", None)
        if temp_dirs is None:
            return
        # Detach the set before iterating so that a remove()/__del__ triggered
        # during cleanup (e.g. by garbage collection) sees None and does not
        # mutate the set while it is being iterated.
        cls.TEMPDIRS = None
        for path in temp_dirs:
            shutil.rmtree(path, ignore_errors=True)

    @classmethod
    @contextlib.contextmanager
    def set_keep_for_debug(cls, set_to=True):
        """Keep temporary directories past program exit for debugging."""
        old_keep_for_debug = cls._KEEP_FOR_DEBUG
        try:
            cls._KEEP_FOR_DEBUG = set_to
            yield
        finally:
            cls._KEEP_FOR_DEBUG = old_keep_for_debug

    def __init__(self, custom_path=None, keep_for_debug=None):
        """Create the directory.

        Parameters
        ----------
        custom_path : str, optional
            Exact path to create (must not already exist). A generated
            directory is used when this is falsy.
        keep_for_debug : bool, optional
            Overrides the class-wide ``_KEEP_FOR_DEBUG`` setting when not None.

        Raises
        ------
        DirectoryCreatedPastAtExit
            If the atexit cleanup hook has already run.
        """
        if self.TEMPDIRS is None:
            raise DirectoryCreatedPastAtExit()

        if keep_for_debug is not None:
            self._created_with_keep_for_debug = keep_for_debug
        else:
            self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG

        if custom_path:
            os.mkdir(custom_path)
            self.temp_dir = custom_path
        else:
            if self._created_with_keep_for_debug:
                parent_dir = self._get_debug_parent_dir()
                self.temp_dir = f"{parent_dir}/{self._increment_num_tempdir_created():05d}"
                os.mkdir(self.temp_dir)
            else:
                self.temp_dir = tempfile.mkdtemp()

        # Only directories that should be deleted are tracked for atexit cleanup.
        if not self._created_with_keep_for_debug:
            self.TEMPDIRS.add(self.temp_dir)

    def remove(self):
        """Remove the tmp dir.

        Safe to call repeatedly, after the atexit hook has run, and on an
        instance whose constructor failed before ``temp_dir`` was assigned.
        """
        # getattr: __init__ may have raised before temp_dir was set.
        temp_dir = getattr(self, "temp_dir", None)
        if not temp_dir:
            return
        if not self._created_with_keep_for_debug:
            shutil.rmtree(temp_dir, ignore_errors=True)
            temp_dirs = self.TEMPDIRS
            # TEMPDIRS is None once the atexit hook has run.
            if temp_dirs is not None:
                temp_dirs.discard(temp_dir)
        self.temp_dir = None

    @property
    def path(self):
        """The directory as a ``pathlib.Path``."""
        return pathlib.Path(self.temp_dir)

    def __truediv__(self, other):
        """Join ``other`` onto the directory path, returning a ``pathlib.Path``."""
        if not isinstance(other, (str, pathlib.Path)):
            raise TypeError(
                f"TempDirectory / operator: must supply str or pathlib.Path; got {repr(other)}"
            )
        return self.path / other

    def __del__(self):
        temp_dirs = getattr(self, "TEMPDIRS", None)
        if temp_dirs is None:
            # Do nothing if the atexit hook has already run.
            return
        self.remove()

    def relpath(self, name):
        """Relative path in temp dir

        Parameters
        ----------
        name : str
            The name of the file.

        Returns
        -------
        path : str
            The concatenated path.
        """
        return os.path.join(self.temp_dir, name)

    def listdir(self):
        """List contents in the dir.

        Returns
        -------
        names : list
            The content of directory
        """
        return os.listdir(self.temp_dir)
# At interpreter exit, delete every tracked (non-keep-for-debug) TempDirectory
# and mark TEMPDIRS as None so later constructions are refused.
atexit.register(TempDirectory.remove_tempdirs)
def tempdir(custom_path=None, keep_for_debug=None):
    """Create a temporary directory that is cleaned up automatically.

    Parameters
    ----------
    custom_path : str, optional
        Exact path of the directory to create instead of a generated one.
    keep_for_debug : bool, optional
        When not None, overrides the class-wide keep-for-debug setting;
        directories created in keep-for-debug mode are not deleted.

    Returns
    -------
    temp : TempDirectory
        The temp directory object
    """
    created = TempDirectory(custom_path=custom_path, keep_for_debug=keep_for_debug)
    return created
class FileLock(object):
    """File lock object

    Opens ``path`` and, when ``fcntl`` is available, takes an exclusive lock
    on it at construction time. The lock is held until :py:meth:`release` is
    called. The object can also be used as a context manager, which releases
    the lock on exit.

    Parameters
    ----------
    path : str
        The path to the lock
    """

    def __init__(self, path):
        # "w" mode creates the lock file if it does not exist.
        self.lock_file = open(path, "w")
        if fcntl:
            try:
                fcntl.lockf(self.lock_file, fcntl.LOCK_EX)
            except BaseException:
                # Don't leak the file handle if locking fails or is interrupted
                # (e.g. KeyboardInterrupt while blocked); re-raise to the caller.
                self.lock_file.close()
                self.lock_file = None
                raise

    def release(self):
        """Release the lock. Calling it again after release is a no-op."""
        if self.lock_file:
            try:
                if fcntl:
                    fcntl.lockf(self.lock_file, fcntl.LOCK_UN)
            finally:
                # Always close the handle, even if unlocking raised.
                self.lock_file.close()
                self.lock_file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
def filelock(path):
    """Acquire a file lock on ``path``.

    Parameters
    ----------
    path : str
        The path to the lock

    Returns
    -------
    lock : FileLock
        The file lock object; call ``release()`` to unlock.
    """
    lock = FileLock(path)
    return lock
def is_source_path(path):
    """Heuristically decide whether ``path`` names a source file.

    Existing paths are always accepted. Otherwise, a string containing a
    newline is rejected as inline code. Any other string is accepted only if it
    contains a ``.`` and the text after the last ``.`` has no leading or
    trailing whitespace.

    Parameters
    ----------
    path : str
        A possible path

    Returns
    -------
    valid : bool
        Whether path is a possible source path
    """
    if os.path.exists(path):
        return True
    if "\n" in path:
        return False
    stem_and_ext = path.rsplit(".", 1)
    if len(stem_and_ext) != 2:
        return False
    ext = stem_and_ext[1]
    return ext == ext.strip()
def which(exec_name):
    """Locate an executable file.

    Candidates are checked in order:

    1. ``exec_name`` itself, which covers absolute paths and paths relative
       to the current directory.
    2. ``/bin``.
    3. Each entry of ``$PATH``.

    The first candidate that is a regular file with execute permission wins.

    Parameters
    ----------
    exec_name : str
        The executable name

    Returns
    -------
    path : str
        The full path of executable if found, otherwise returns None
    """
    search_dirs = ["", "/bin"]
    search_dirs.extend(os.environ.get("PATH", "").split(os.pathsep))
    candidates = (os.path.join(directory, exec_name) for directory in search_dirs)
    return next(
        (cand for cand in candidates if os.path.isfile(cand) and os.access(cand, os.X_OK)),
        None,
    )