tvm.meta_schedule.task_scheduler.task_scheduler 源代码

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Auto-tuning Task Scheduler"""
from typing import Callable, List, Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.runtime import Object

from .. import _ffi_api
from ..builder import Builder, BuilderResult
from ..cost_model import CostModel
from ..database import Database
from ..logging import get_logger, get_logging_func
from ..measure_callback import MeasureCallback
from ..runner import Runner, RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext

logger = get_logger(__name__)  # pylint: disable=invalid-name


@register_object("meta_schedule.TaskRecord")
class TaskRecord(Object):
    """The running record of a task."""

    ctx: TuneContext
    task_weight: float
    flop: float
    is_terminated: bool
    build_error_count: int
    run_error_count: int
    measure_candidates: List[MeasureCandidate]
    builder_results: List[BuilderResult]
    runner_results: List[RunnerResult]


[文档] @register_object("meta_schedule.TaskScheduler") class TaskScheduler(Object): """The abstract task scheduler interface.""" tasks_: List[TaskRecord] measure_callbacks_: List[MeasureCallback] database_: Optional[Database] cost_model_: Optional[CostModel] remaining_tasks_: int TaskSchedulerType = Union["TaskScheduler", Literal["gradient", "round-robin"]]
[文档] def next_task_id(self) -> int: """Fetch the next task id. Returns ------- next_task_id : int The next task id. """ return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member
[文档] def join_running_task(self, task_id: int) -> List[RunnerResult]: """Wait until the task is finished. Parameters ---------- task_id : int The task id to be joined. Returns ------- results : List[RunnerResult] The list of results. """ return _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member
[文档] def tune( self, tasks: List[TuneContext], task_weights: List[float], max_trials_global: int, max_trials_per_task: int, num_trials_per_iter: int, builder: Builder, runner: Runner, measure_callbacks: List[MeasureCallback], database: Optional[Database], cost_model: Optional[CostModel], ) -> None: """Auto-tuning. Parameters ---------- tasks : List[TuneContext] The list of tuning contexts as tasks. task_weights : List[float] The list of task weights. max_trials_global : int The maximum number of trials globally. max_trials_per_task : int The maximum number of trials per task. num_trials_per_iter : int The number of trials per iteration. builder : Builder The builder. runner : Runner The runner. measure_callbacks : List[MeasureCallback] The list of measure callbacks. database : Optional[Database] The database. cost_model : Optional[CostModel] The cost model. """ task_weights = [float(w) for w in task_weights] _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member self, tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, builder, runner, measure_callbacks, database, cost_model, )
[文档] def terminate_task(self, task_id: int) -> None: """Terminate the task Parameters ---------- task_id : int The task id to be terminated. """ _ffi_api.TaskSchedulerTerminateTask(self, task_id) # type: ignore # pylint: disable=no-member
[文档] def touch_task(self, task_id: int) -> None: """Touch the task and update its status Parameters ---------- task_id : int The task id to be checked. """ _ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore # pylint: disable=no-member
[文档] def print_tuning_statistics(self) -> None: """Print out a human-readable format of the tuning statistics.""" return _ffi_api.TaskSchedulerPrintTuningStatistics(self) # type: ignore # pylint: disable=no-member
[文档] @staticmethod def create( # pylint: disable=keyword-arg-before-vararg kind: Literal["round-robin", "gradient"] = "gradient", *args, **kwargs, ) -> "TaskScheduler": """Create a task scheduler.""" from . import ( # pylint: disable=import-outside-toplevel GradientBased, RoundRobin, ) if kind == "round-robin": return RoundRobin(*args, **kwargs) # type: ignore if kind == "gradient": return GradientBased(*args, **kwargs) raise ValueError(f"Unknown TaskScheduler name: {kind}")
create = TaskScheduler.create # pylint: disable=invalid-name @register_object("meta_schedule.PyTaskScheduler") class _PyTaskScheduler(TaskScheduler): """ A TVM object task scheduler to support customization on the python side. This is NOT the user facing class for function overloading inheritance. See also: PyTaskScheduler """ def __init__( self, f_next_task_id: Callable, f_join_running_task: Callable, f_tune: Callable, ): """Constructor.""" self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member get_logging_func(logger), f_next_task_id, f_join_running_task, f_tune, ) class PyTaskScheduler: """ An abstract task scheduler with customized methods on the python-side. This is the user facing class for function overloading inheritance. Note: @derived_object is required for proper usage of any inherited class. """ _tvm_metadata = { "cls": _PyTaskScheduler, "fields": [], "methods": ["next_task_id", "join_running_task", "tune"], } def __init__(self): ... def tune( self, tasks: List[TuneContext], task_weights: List[float], max_trials_global: int, max_trials_per_task: int, builder: Builder, runner: Runner, measure_callbacks: List[MeasureCallback], database: Optional[Database], cost_model: Optional[CostModel], ) -> None: """Auto-tuning.""" # Using self._outer to replace the self pointer _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member self._outer(), # type: ignore # pylint: disable=no-member tasks, task_weights, max_trials_global, max_trials_per_task, builder, runner, measure_callbacks, database, cost_model, ) def next_task_id(self) -> int: """Fetch the next task id. Returns ------- next_task_id : int The next task id. """ raise NotImplementedError def join_running_task(self, task_id: int) -> List[RunnerResult]: """Wait until the task is finished. Parameters ---------- task_id : int The task id to be joined. """ # Using self._outer to replace the self pointer return _ffi_api.TaskSchedulerJoinRunningTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member