# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""Auto-tuning Task Scheduler"""fromtypingimportCallable,List,Optional,Union# isort: offfromtyping_extensionsimportLiteral# isort: onfromtvm._ffiimportregister_objectfromtvm.runtimeimportObjectfrom..import_ffi_apifrom..builderimportBuilder,BuilderResultfrom..cost_modelimportCostModelfrom..databaseimportDatabasefrom..loggingimportget_logger,get_logging_funcfrom..measure_callbackimportMeasureCallbackfrom..runnerimportRunner,RunnerResultfrom..search_strategyimportMeasureCandidatefrom..tune_contextimportTuneContextlogger=get_logger(__name__)# pylint: disable=invalid-name@register_object("meta_schedule.TaskRecord")classTaskRecord(Object):"""The running record of a task."""ctx:TuneContexttask_weight:floatflop:floatis_terminated:boolbuild_error_count:intrun_error_count:intmeasure_candidates:List[MeasureCandidate]builder_results:List[BuilderResult]runner_results:List[RunnerResult]
def next_task_id(self) -> int:
    """Return the id of the next task.

    Returns
    -------
    next_task_id : int
        The next task id, as reported by
        ``_ffi_api.TaskSchedulerNextTaskId``.
    """
    fetch_next_id = _ffi_api.TaskSchedulerNextTaskId  # type: ignore # pylint: disable=no-member
    return fetch_next_id(self)
def join_running_task(self, task_id: int) -> List[RunnerResult]:
    """Block until the given task has finished and return its results.

    Parameters
    ----------
    task_id : int
        The id of the task to join.

    Returns
    -------
    results : List[RunnerResult]
        The list of results.
    """
    join_task = _ffi_api.TaskSchedulerJoinRunningTask  # type: ignore # pylint: disable=no-member
    return join_task(self, task_id)
def tune(
    self,
    tasks: List[TuneContext],
    task_weights: List[float],
    max_trials_global: int,
    max_trials_per_task: int,
    num_trials_per_iter: int,
    builder: Builder,
    runner: Runner,
    measure_callbacks: List[MeasureCallback],
    database: Optional[Database],
    cost_model: Optional[CostModel],
) -> None:
    """Run auto-tuning over the given tasks.

    Parameters
    ----------
    tasks : List[TuneContext]
        The tuning contexts, one per task.
    task_weights : List[float]
        The weight of each task. Every entry is converted with ``float``
        before being forwarded.
    max_trials_global : int
        The maximum number of trials globally.
    max_trials_per_task : int
        The maximum number of trials per task.
    num_trials_per_iter : int
        The number of trials per iteration.
    builder : Builder
        The builder.
    runner : Runner
        The runner.
    measure_callbacks : List[MeasureCallback]
        The list of measure callbacks.
    database : Optional[Database]
        The database.
    cost_model : Optional[CostModel]
        The cost model.
    """
    # Normalise the weights to plain Python floats before handing them over
    # to `_ffi_api.TaskSchedulerTune`.
    float_weights = list(map(float, task_weights))
    run_tune = _ffi_api.TaskSchedulerTune  # type: ignore # pylint: disable=no-member
    run_tune(
        self,
        tasks,
        float_weights,
        max_trials_global,
        max_trials_per_task,
        num_trials_per_iter,
        builder,
        runner,
        measure_callbacks,
        database,
        cost_model,
    )
def terminate_task(self, task_id: int) -> None:
    """Terminate the task with the given id.

    Parameters
    ----------
    task_id : int
        The id of the task to be terminated.
    """
    terminate = _ffi_api.TaskSchedulerTerminateTask  # type: ignore # pylint: disable=no-member
    terminate(self, task_id)
def touch_task(self, task_id: int) -> None:
    """Touch the task with the given id and update its status.

    Parameters
    ----------
    task_id : int
        The id of the task to be checked.
    """
    touch = _ffi_api.TaskSchedulerTouchTask  # type: ignore # pylint: disable=no-member
    touch(self, task_id)
def print_tuning_statistics(self) -> None:
    """Print the tuning statistics in a human-readable format."""
    print_stats = _ffi_api.TaskSchedulerPrintTuningStatistics  # type: ignore # pylint: disable=no-member
    # The callee's result is passed through unchanged.
    return print_stats(self)
create = TaskScheduler.create  # pylint: disable=invalid-name


@register_object("meta_schedule.PyTaskScheduler")
class _PyTaskScheduler(TaskScheduler):
    """
    A TVM object task scheduler to support customization on the python side.
    This is NOT the user facing class for function overloading inheritance.

    See also: PyTaskScheduler
    """

    def __init__(
        self,
        f_next_task_id: Callable,
        f_join_running_task: Callable,
        f_tune: Callable,
    ):
        """Constructor.

        Parameters
        ----------
        f_next_task_id : Callable
            The callable forwarded as the ``next_task_id`` implementation.
        f_join_running_task : Callable
            The callable forwarded as the ``join_running_task`` implementation.
        f_tune : Callable
            The callable forwarded as the ``tune`` implementation.
        """
        self.__init_handle_by_constructor__(
            _ffi_api.TaskSchedulerPyTaskScheduler,  # type: ignore # pylint: disable=no-member
            get_logging_func(logger),
            f_next_task_id,
            f_join_running_task,
            f_tune,
        )


class PyTaskScheduler:
    """
    An abstract task scheduler with customized methods on the python-side.
    This is the user facing class for function overloading inheritance.

    Note: @derived_object is required for proper usage of any inherited class.
    """

    _tvm_metadata = {
        "cls": _PyTaskScheduler,
        "fields": [],
        "methods": ["next_task_id", "join_running_task", "tune"],
    }

    def __init__(self):
        ...

    def tune(
        self,
        tasks: List[TuneContext],
        task_weights: List[float],
        max_trials_global: int,
        max_trials_per_task: int,
        num_trials_per_iter: int,
        builder: Builder,
        runner: Runner,
        measure_callbacks: List[MeasureCallback],
        database: Optional[Database],
        cost_model: Optional[CostModel],
    ) -> None:
        """Auto-tuning.

        The parameter list and the arguments forwarded to
        ``_ffi_api.TaskSchedulerTune`` match ``TaskScheduler.tune``, including
        ``num_trials_per_iter``.

        Parameters
        ----------
        tasks : List[TuneContext]
            The list of tuning contexts as tasks.
        task_weights : List[float]
            The list of task weights. Every entry is converted with ``float``
            before being forwarded.
        max_trials_global : int
            The maximum number of trials globally.
        max_trials_per_task : int
            The maximum number of trials per task.
        num_trials_per_iter : int
            The number of trials per iteration.
        builder : Builder
            The builder.
        runner : Runner
            The runner.
        measure_callbacks : List[MeasureCallback]
            The list of measure callbacks.
        database : Optional[Database]
            The database.
        cost_model : Optional[CostModel]
            The cost model.
        """
        task_weights = [float(w) for w in task_weights]
        # Using self._outer to replace the self pointer
        _ffi_api.TaskSchedulerTune(  # type: ignore # pylint: disable=no-member
            self._outer(),  # type: ignore # pylint: disable=no-member
            tasks,
            task_weights,
            max_trials_global,
            max_trials_per_task,
            num_trials_per_iter,
            builder,
            runner,
            measure_callbacks,
            database,
            cost_model,
        )

    def next_task_id(self) -> int:
        """Fetch the next task id.

        Returns
        -------
        next_task_id : int
            The next task id.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def join_running_task(self, task_id: int) -> List[RunnerResult]:
        """Wait until the task is finished.

        Parameters
        ----------
        task_id : int
            The task id to be joined.

        Returns
        -------
        results : List[RunnerResult]
            The list of results.
        """
        # Using self._outer to replace the self pointer
        return _ffi_api.TaskSchedulerJoinRunningTask(self._outer(), task_id)  # type: ignore # pylint: disable=no-member