# Source code for tvm.auto_scheduler.workload_registry

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

"""
Workload registration and serialization.

We use a json string to represent a workload (a computation graph).
The format of the string is `[func_name, [args...]]`.
The dag should be the return value of this `func_name(*args)`.

Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
and matching them efficiently is not easy. Therefore, we use the above string to encode a compute
dag.
These strings are efficient for serialization/matching and won't be too long.
When we need the dag, we decode the string and call the function, which will return the dag.
"""

import json
import logging
import pickle

import tvm._ffi
from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON

from .utils import deserialize_args, get_func_name, serialize_args

logger = logging.getLogger("auto_scheduler")

# Global workload function and hash key registry
# It stores two types of workload:
# 1. User registered tasks. This type of workload is registered
#    by the decorator "register_workload"
# 2. Extracted tasks from a relay program. This type of workload is
#    registered by function "register_workload_tensors".
#
# For 1, the dictionary maps a function name to its function pointer
# For 2, the dictionary maps a hash key to a list of input/output tensors
WORKLOAD_FUNC_REGISTRY = {}


def register_workload(func_name, f=None, override=False):
    """Register a function that generates a certain workload.

    The registered function should accept hashable and jsonable arguments
    (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of
    tvm.tensor.Tensor.

    Parameters
    ----------
    func_name : Union[Function, str]
        Either the generation function that returns the compute declaration
        Tensors, or the name under which a function is to be registered.
    f : Optional[Function]
        The generation function to be registered.
    override : boolean = False
        Whether to override an existing entry.

    Returns
    -------
    Either the registered object (when a function was passed directly or via
    `f`), or a decorator that performs the registration when called.

    Raises
    ------
    ValueError
        If the resolved function name is not a string.
    RuntimeError
        Raised at registration time if the name is already registered and
        `override` is False.

    Examples
    --------
    .. code-block:: python

      @auto_scheduler.register_workload
      def matmul(N, M, K):
          A = te.placeholder((N, K), name='A')
          B = te.placeholder((K, M), name='B')
          k = te.reduce_axis((0, K), name='k')
          C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
          return [A, B, C]
    """
    # Bare-decorator usage: the first positional argument is the function itself,
    # so derive the registry name from it.
    if callable(func_name):
        f = func_name
        func_name = get_func_name(f)

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(workload):
        """internal register function"""
        is_duplicate = func_name in WORKLOAD_FUNC_REGISTRY
        if is_duplicate and not override:
            raise RuntimeError(f"{func_name} has been registered already")
        WORKLOAD_FUNC_REGISTRY[func_name] = workload
        return workload

    # Register immediately when a function was supplied; otherwise hand back
    # the decorator so the caller can apply it.
    return register(f) if f else register
def register_workload_tensors(workload_key, tensors, override=True):
    """Register a workload by providing input/output tensors.

    This function is used when extracting/deserializing tasks, so duplicated
    registrations are expected and overriding is enabled by default.

    Parameters
    ----------
    workload_key: str
        The workload key of the compute DAG in JSON string.
    tensors: List[Tensor]
        The input/output tensors of a compute DAG
    override : boolean = True
        Whether to override existing entry.

    Returns
    -------
    workload_key: str
        The workload key of the compute DAG in JSON string.
    """
    # Obtain the registration decorator bound to this key, then apply it to the
    # tensor list so the tensors are stored under the key.
    store_under_key = register_workload(workload_key, override=override)
    store_under_key(tensors)
    return workload_key
def make_workload_key(func, args):
    """Build the workload key for a registered function and its arguments.

    Parameters
    ----------
    func : Union[Function, str]
        The function that returns the compute declaration Tensors.
        Can be either the function itself or the function name.
    args : Args
        The args of the function.

    Returns
    -------
    workload_key : str
        The workload key of the function: a JSON string holding the function
        name followed by the serialized arguments.

    Raises
    ------
    ValueError
        If `func` is neither a callable nor a string, or if the function name
        has not been registered.
    """
    # Reject anything that is neither a function nor a function name up front.
    if not callable(func) and not isinstance(func, str):
        raise ValueError(
            "Invalid function: "
            + str(func)
            + " . `make_workload_key` expects a callable function or its function name"
        )

    func_name = get_func_name(func) if callable(func) else func

    if func_name not in WORKLOAD_FUNC_REGISTRY:
        raise ValueError(
            f"{func} is not registered. "
            f"Please register it with @auto_scheduler.register_workload"
        )

    # The key is the JSON encoding of (func_name, *serialized_args).
    key_fields = (func_name,) + serialize_args(args)
    return json.dumps(key_fields)
@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
def workload_key_to_tensors(workload_key):
    """Get the input/output tensors from the workload key.

    This method is usually used to create a ComputeDAG by workload key.

    Parameters
    ----------
    workload_key : str
        The input workload key in JSON string. The format is either
        (func_name, arguments...) for compute functions, or (hash, shapes...)
        for ComputeDAG.

    Returns
    -------
    tensors : List[Tensor]
        The registered compute declaration Tensors.
    """
    global WORKLOAD_FUNC_REGISTRY

    # We register ComputeDAG with both hash and arguments, which are fixed in ComputeDAG,
    # so we use an entire workload key to query the ComputeDAG.
    if workload_key in WORKLOAD_FUNC_REGISTRY:
        return WORKLOAD_FUNC_REGISTRY[workload_key]

    # We register compute function with only the function name since
    # it does not bind to specific arguments, so we use the function name to query
    # the function and call the function with arguments to get the tensors.
    workload = json.loads(workload_key)
    name = workload[0]
    value = WORKLOAD_FUNC_REGISTRY[name]
    assert callable(value)

    args = deserialize_args(workload[1:])
    result = value(*args)
    # Normalize a tuple result to a list.
    if isinstance(result, tuple):
        result = list(result)
    return result


def serialize_workload_registry_entry(workload_key):
    """
    Serialize a workload registry entry.

    This is used when the start method of multiprocessing is spawn.
    We need to serialize the entry and register it in the new processes.

    Parameters
    ----------
    workload_key : str
        The workload key

    Returns
    -------
    data: Tuple
        The serialized picklable data, as a (name, value) pair
    """
    global WORKLOAD_FUNC_REGISTRY

    # Entries are stored either under the full workload key or under the
    # function name, which is the first element of the decoded key.
    if workload_key in WORKLOAD_FUNC_REGISTRY:
        sname = workload_key
    else:
        workload = json.loads(workload_key)
        sname = workload[0]

    svalue = WORKLOAD_FUNC_REGISTRY[sname]
    # Callables are passed through as-is; anything else goes through SaveJSON.
    if not callable(svalue):
        # pylint: disable=assignment-from-no-return
        svalue = SaveJSON(svalue)

    return sname, svalue


def deserialize_workload_registry_entry(data):
    """
    Deserialize a workload registry entry.
    This should be used along with :code:`serialize_workload_registry_entry`

    Parameters
    ----------
    data: Tuple
        The return value of :code:`serialize_workload_registry_entry`
    """
    global WORKLOAD_FUNC_REGISTRY

    name, value = data
    # Existing entries are left untouched; only missing names are registered.
    if name not in WORKLOAD_FUNC_REGISTRY:
        # pylint: disable=assignment-from-no-return
        if not callable(value):
            value = LoadJSON(value)
        WORKLOAD_FUNC_REGISTRY[name] = value


def save_workload_func_registry(filename):
    """Dump workload function registry to a pickle binary file.

    Parameters
    ----------
    filename : str
        The filename to dump workload function registry to.
    """
    global WORKLOAD_FUNC_REGISTRY

    # Use a context manager so the file is flushed and closed even if
    # pickling raises.
    with open(filename, "wb") as registry_file:
        pickle.dump(WORKLOAD_FUNC_REGISTRY, registry_file)


def load_workload_func_registry(filename):
    """Load workload function registry from a pickle binary file.

    The loaded object replaces the current registry.

    Parameters
    ----------
    filename : str
        The filename to load workload function registry from.
    """
    global WORKLOAD_FUNC_REGISTRY

    # NOTE(review): unpickling can execute arbitrary code; only load registry
    # files that come from a trusted source.
    with open(filename, "rb") as registry_file:
        WORKLOAD_FUNC_REGISTRY = pickle.load(registry_file)