Source code for orion.algo.pbt.exploit

"""
Exploit classes for Population Based Training
---------------------------------------------

Formulation of a general exploit function for population based training.
Implementations must inherit from ``orion.algo.pbt.BaseExploit``.

Exploit objects can be created using `exploit_factory.create()`.

Examples
--------
>>> exploit_factory.create('TruncateExploit')
>>> exploit_factory.create('TruncateExploit', min_forking_population=10)

"""

import logging

import numpy

from orion.core.utils import GenericFactory

logger = logging.getLogger(__name__)


class BaseExploit:
    """Abstract class for Exploit in :py:class:`orion.algo.pbt.pbt.PBT`

    The exploit class is responsible for deciding whether the Population Based
    Training algorithm should continue training a trial configuration at the next
    fidelity level or whether it should fork from another trial configuration.

    This class is expected to be stateless and serve as a configurable callable
    object.
    """

    def __call__(self, rng, trial, lineages):
        """Execute exploit

        The method receives the current trial under examination and all lineages of
        population based training. It must then decide whether the trial should be
        promoted (continue with a higher fidelity) or if another trial should be
        forked instead.

        Parameters
        ----------
        rng: numpy.random.Generator
            A random number generator. It is not contained in ``BaseExploit`` because
            the exploit class must be stateless.
        trial: Trial
            The :py:class:`orion.core.worker.trial.Trial` that is currently under
            examination.
        lineages: Lineages
            All :py:class:`orion.algo.pbt.pbt.Lineages` created by the population based
            training algorithm that is using this exploit class.

        Returns
        -------
        ``None``
            The exploit class signals that there are not enough completed trials in
            lineages to make a decision for the current trial.
        ``Trial``
            If the returned trial is the same as the one received as argument, it means
            that population based training should continue with the same parameters.
            If another trial from the lineages is returned, it means that population
            based training should try to explore new parameters.

        """

    @property
    def configuration(self):
        """Configuration of the exploit object"""
        return dict(of_type=self.__class__.__name__.lower())
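

# Illustrative sketch (not part of the original module): a minimal, stateless
# ``BaseExploit`` subclass showing the expected interface. Returning the received
# trial promotes it, returning another trial requests a fork, and returning ``None``
# postpones the decision. The class name and its promotion rule are hypothetical and
# only meant for demonstration.
class _SketchPromoteIfBestExploit(BaseExploit):
    """Toy exploit: promote the trial only if it has the best objective so far."""

    def __call__(self, rng, trial, lineages):
        # Only completed trials at the same fidelity level have a final objective.
        completed = [
            other
            for other in lineages.get_trials_at_depth(trial)
            if other.status == "completed"
        ]
        if not completed:
            # Not enough information yet; ask PBT to reconsider this trial later.
            return None
        best = min(completed, key=lambda other: other.objective.value)
        # Keep training the current trial if it is the best one, otherwise fork.
        return trial if best == trial else best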


class PipelineExploit(BaseExploit):
    """Pipeline of BaseExploit objects

    The pipeline executes the BaseExploit objects sequentially. If one object returns
    `None`, the pipeline is stopped and it returns `None`. Likewise, if one object
    returns a trial different from the one passed, the pipeline is stopped and this
    trial is returned. Otherwise, if all BaseExploit objects return the same trial as
    the one passed to the pipeline, then the pipeline returns it.

    Parameters
    ----------
    exploit_configs: list of dict
        List of dictionaries representing the configurations of BaseExploit children.

    Examples
    --------
    >>> PipelineExploit(
    ...     exploit_configs=[
    ...         {'of_type': 'BacktrackExploit'},
    ...         {'of_type': 'TruncateExploit'}
    ...     ])

    """

    # pylint: disable=super-init-not-called
    def __init__(self, exploit_configs):
        self.pipeline = []
        for exploit_config in exploit_configs:
            self.pipeline.append(exploit_factory.create(**exploit_config))

    def __call__(self, rng, trial, lineages):
        """Execute exploit objects sequentially

        If one object returns `None`, the pipeline is stopped and it returns `None`.
        Likewise, if one object returns a trial different from the one passed, the
        pipeline is stopped and this trial is returned. Otherwise, if all BaseExploit
        objects return the same trial as the one passed to the pipeline, then the
        pipeline returns it.

        Parameters
        ----------
        rng: numpy.random.Generator
            A random number generator. It is not contained in ``BaseExploit`` because
            the exploit class must be stateless.
        trial: Trial
            The :py:class:`orion.core.worker.trial.Trial` that is currently under
            examination.
        lineages: Lineages
            All :py:class:`orion.algo.pbt.pbt.Lineages` created by the population based
            training algorithm that is using this exploit class.

        Returns
        -------
        ``None``
            The exploit class signals that there are not enough completed trials in
            lineages to make a decision for the current trial.
        ``Trial``
            If the returned trial is the same as the one received as argument, it means
            that population based training should continue with the same parameters.
            If another trial from the lineages is returned, it means that population
            based training should try to explore new parameters.

        """
        for exploit in self.pipeline:
            logger.debug("Executing %s", exploit.__class__.__name__)
            selected_trial = exploit(rng, trial, lineages)
            if selected_trial is not trial:
                logger.debug(
                    "Exploit %s selected trial %s over %s",
                    exploit.__class__.__name__,
                    selected_trial,
                    trial,
                )
                return selected_trial
            else:
                logger.debug(
                    "Exploit %s is skipping for trial %s",
                    exploit.__class__.__name__,
                    trial,
                )

        return trial

    @property
    def configuration(self):
        """Configuration of the exploit object"""
        configuration = super().configuration
        configuration["exploit_configs"] = [
            exploit.configuration for exploit in self.pipeline
        ]
        return configuration
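

# Usage sketch (illustrative only; shown as a doctest-style comment because
# ``exploit_factory`` is defined at the end of this module): build a pipeline that
# first tries backtracking and falls back to truncation. The parameter values are
# arbitrary examples.
#
# >>> exploit = exploit_factory.create(
# ...     "PipelineExploit",
# ...     exploit_configs=[
# ...         {"of_type": "BacktrackExploit"},
# ...         {"of_type": "TruncateExploit", "min_forking_population": 10},
# ...     ],
# ... )
# >>> selected = exploit(rng, trial, lineages)  # arguments provided by PBT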


class TruncateExploit(BaseExploit):
    """Truncate Exploit

    If the given trial is under a ``truncation_quantile`` compared to all other trials
    that have reached the same fidelity level, then a new candidate trial is selected
    for forking. The new candidate is selected from a pool of the best
    ``candidate_pool_ratio``\\% of the available trials at the same fidelity level.

    If there are fewer than ``min_forking_population`` trials that have reached the
    same fidelity level as the passed trial, then `None` is returned to signal that we
    should reconsider this trial later on, when more trials are completed at this
    fidelity level.

    Parameters
    ----------
    min_forking_population: int, optional
        Minimum number of trials that should be completed up to the fidelity level of
        the current trial passed. TruncateExploit will return ``None`` when this
        requirement is not met. Default: 5
    truncation_quantile: float, optional
        If the passed trial's objective is above the quantile ``truncation_quantile``,
        then another candidate is considered for forking. Default: 0.8
    candidate_pool_ratio: float, optional
        When choosing another candidate for forking, it will be randomly selected from
        the best ``candidate_pool_ratio``\\% of the available trials. Default: 0.2

    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        min_forking_population=5,
        truncation_quantile=0.8,
        candidate_pool_ratio=0.2,
    ):
        self.min_forking_population = min_forking_population
        self.truncation_quantile = truncation_quantile
        self.candidate_pool_ratio = candidate_pool_ratio

    def __call__(self, rng, trial, lineages):
        """Select another trial if the current one is not good enough

        If the given trial is under a ``self.truncation_quantile`` compared to all
        other trials that have reached the same fidelity level, then a new candidate
        trial is selected for forking. The new candidate is selected from a pool of the
        best ``self.candidate_pool_ratio``\\% of the available trials at the same
        fidelity level.

        If there are fewer than ``self.min_forking_population`` trials that have
        reached the same fidelity level as the passed trial, then `None` is returned to
        signal that we should reconsider this trial later on, when more trials are
        completed at this fidelity level.

        Parameters
        ----------
        rng: numpy.random.Generator
            A random number generator. It is not contained in ``BaseExploit`` because
            the exploit class must be stateless.
        trial: Trial
            The :py:class:`orion.core.worker.trial.Trial` that is currently under
            examination.
        lineages: Lineages
            All :py:class:`orion.algo.pbt.pbt.Lineages` created by the population based
            training algorithm that is using this exploit class.

        Returns
        -------
        ``None``
            The exploit class signals that there are not enough completed trials in
            lineages to make a decision for the current trial.
        ``Trial``
            If the returned trial is the same as the one received as argument, it means
            that population based training should continue with the same parameters.
            If another trial from the lineages is returned, it means that population
            based training should try to explore new parameters.

        """
        trials = lineages.get_trials_at_depth(trial)
        return self._truncate(rng, trial, trials)

    def _truncate(
        self,
        rng,
        trial,
        trials,
    ):
        # Only completed trials have a final objective to compare against.
        completed_trials = [trial for trial in trials if trial.status == "completed"]

        if len(completed_trials) < self.min_forking_population:
            logger.debug(
                "Not enough trials completed to exploit: %s", len(completed_trials)
            )
            return None

        if trial not in completed_trials:
            raise ValueError(
                f"Trial {trial.id} not included in list of completed trials."
            )

        # Sort by objective, best first.
        sorted_trials = sorted(
            completed_trials, key=lambda trial: trial.objective.value
        )

        # Trials in the worst (1 - truncation_quantile) fraction are candidates for
        # replacement.
        worse_trials = sorted_trials[
            int(self.truncation_quantile * len(sorted_trials)) :
        ]

        if trial not in worse_trials:
            logger.debug("Trial %s is good enough, no need to exploit.", trial)
            return trial

        # Fork from a trial drawn uniformly from the best candidate_pool_ratio
        # fraction of the completed trials.
        candidate_threshold_index = int(self.candidate_pool_ratio * len(sorted_trials))

        if candidate_threshold_index == 0:
            logger.warning(
                "Not enough completed trials to have a candidate pool. "
                "You should consider increasing min_forking_population or candidate_pool_ratio"
            )
            return None

        index = rng.choice(numpy.arange(0, candidate_threshold_index))
        return sorted_trials[index]

    @property
    def configuration(self):
        """Configuration of the exploit object"""
        configuration = super().configuration
        configuration.update(
            dict(
                min_forking_population=self.min_forking_population,
                truncation_quantile=self.truncation_quantile,
                candidate_pool_ratio=self.candidate_pool_ratio,
            )
        )
        return configuration
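

# Worked example of the truncation arithmetic above (illustrative): with the default
# settings and 10 completed trials sorted by objective (best first),
# ``int(0.8 * 10) == 8``, so only the 2 worst trials (indices 8 and 9) trigger a fork,
# and ``int(0.2 * 10) == 2``, so the replacement is drawn uniformly from the 2 best
# trials (indices 0 and 1). With fewer than ``min_forking_population`` completed
# trials, ``__call__`` returns ``None`` and the decision is postponed.
#
# Usage sketch (parameter values are arbitrary examples; ``trial`` and ``lineages``
# come from the running PBT algorithm):
#
# >>> exploit = TruncateExploit(min_forking_population=10, candidate_pool_ratio=0.3)
# >>> selected = exploit(numpy.random.default_rng(), trial, lineages)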


class BacktrackExploit(TruncateExploit):
    """Backtracking Exploit

    This exploit is inspired by PBT with backtracking proposed in [1]. Instead of using
    all trials at the same fidelity level as in ``TruncateExploit``, it selects the
    best trial from each lineage (worker), one per lineage. The objective of the best
    trial is compared to the objective of the trial under analysis, and if the ratio is
    higher than some threshold, the current trial is not promoted. A trial from the
    pool of best trials is selected randomly.

    The backtracking threshold defined in [1] is unstable, however, and causes a
    division-by-zero error when the best candidate trial has an objective of 0. Also,
    if we selected trials at any fidelity level, we would likely drop any trial at a
    low fidelity in favor of the best trials at high fidelity. This class therefore
    uses a quantile threshold instead of the ratio of [1] to determine whether a trial
    should be continued at the next fidelity level. The candidates for forking are
    selected from the best trials of all running lineages (workers), as proposed in
    [1], but limited to trials up to the fidelity level of the current trial under
    analysis.

    [1] Zhang, Baohe, Raghu Rajan, Luis Pineda, Nathan Lambert, André Biedenkapp,
    Kurtland Chua, Frank Hutter, and Roberto Calandra. "On the importance of
    hyperparameter optimization for model-based reinforcement learning." In
    International Conference on Artificial Intelligence and Statistics,
    pp. 4015-4023. PMLR, 2021.
    """

    def __call__(self, rng, trial, lineages):
        """Select another trial if the current one is not good enough

        If the given trial is under a ``self.truncation_quantile`` compared to all
        other best trials with lower or equal fidelity level, then a new candidate
        trial is selected for forking. The new candidate is selected from a pool of the
        best ``self.candidate_pool_ratio``\\% of the best trials with lower or equal
        fidelity level. See the class description for more explanation of the
        rationale.

        If there are fewer than ``self.min_forking_population`` trials that have
        reached the same fidelity level as the passed trial, then `None` is returned to
        signal that we should reconsider this trial later on, when more trials are
        completed at this fidelity level.

        Parameters
        ----------
        rng: numpy.random.Generator
            A random number generator. It is not contained in ``BaseExploit`` because
            the exploit class must be stateless.
        trial: Trial
            The :py:class:`orion.core.worker.trial.Trial` that is currently under
            examination.
        lineages: Lineages
            All :py:class:`orion.algo.pbt.pbt.Lineages` created by the population based
            training algorithm that is using this exploit class.

        Returns
        -------
        ``None``
            The exploit class signals that there are not enough completed trials in
            lineages to make a decision for the current trial.
        ``Trial``
            If the returned trial is the same as the one received as argument, it means
            that population based training should continue with the same parameters.
            If another trial from the lineages is returned, it means that population
            based training should try to explore new parameters.

        """
        elites = lineages.get_elites(max_depth=trial)
        return self._truncate(rng, trial, elites + [trial])
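

# Usage sketch (illustrative; parameter values are arbitrary examples). Compared to
# TruncateExploit, the comparison pool here is the best trial of each lineage up to
# the current trial's fidelity (via ``lineages.get_elites``) rather than all trials
# at the same fidelity level.
#
# >>> exploit = exploit_factory.create(
# ...     "BacktrackExploit", min_forking_population=10, truncation_quantile=0.9
# ... )
# >>> selected = exploit(rng, trial, lineages)  # arguments provided by PBT

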
exploit_factory = GenericFactory(BaseExploit)