"""
Asynchronous Successive Halving Algorithm
=========================================
"""
from __future__ import annotations
import copy
import logging
from collections import defaultdict
from typing import Sequence
import numpy as np
from orion.algo.hyperband import (
BudgetTuple,
Hyperband,
HyperbandBracket,
display_budgets,
)
from orion.algo.space import Space
from orion.core.utils.flatten import flatten
from orion.core.worker.trial import Trial
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Error template with the named placeholders ``fidelity``, ``budgets`` and
# ``params``. It is not referenced in the visible part of this file.
REGISTRATION_ERROR = """
Bad fidelity level {fidelity}. Should be in {budgets}.
Params: {params}
"""

# Error text about the search space needing exactly one fidelity dimension.
# It has no placeholders and is not referenced in the visible part of this file.
SPACE_ERROR = """
ASHA can only be used if there is one fidelity dimension.
For more information on the configuration and usage of ASHA, see
https://orion.readthedocs.io/en/develop/user/algorithms.html#asha
"""

# Error template formatted by ``compute_budgets`` when the largest computed
# budget exceeds ``max_resources``. The three positional placeholders are, in
# order: the maximum resources, the minimum resources and the number of rungs.
BUDGET_ERROR = """
Cannot build budgets below max_resources;
(max: {}) - (min: {}) > (num_rungs: {})
"""
def compute_budgets(
    min_resources: float,
    max_resources: float,
    reduction_factor: int,
    num_rungs: int,
    num_brackets: int,
) -> list[list[BudgetTuple]]:
    """Compute the budgets used for ASHA.

    Rung resources are spaced geometrically (in base ``reduction_factor``)
    between ``min_resources`` and ``max_resources``, rounded to integers and
    forced to be strictly increasing. Bracket ``b`` uses the rungs starting at
    index ``b``; inside a bracket the last rung holds 1 trial and every
    earlier rung holds ``reduction_factor`` times more trials than the next.

    Parameters
    ----------
    min_resources: float
        Resource amount of the lowest rung.
    max_resources: float
        Resource amount of the highest rung.
    reduction_factor: int
        Base of the geometric spacing and ratio of trial counts between
        consecutive rungs.
    num_rungs: int
        Number of rungs of the largest bracket.
    num_brackets: int
        Number of brackets. Values larger than ``num_rungs`` are capped to
        ``num_rungs`` and a warning is logged.

    Returns
    -------
    list of list of BudgetTuple
        One list per bracket, ordered from the lowest to the highest rung.
        Each tuple is built as ``BudgetTuple(n_trials, resources)``.

    Raises
    ------
    ValueError
        If making the rounded budgets strictly increasing pushes the last one
        above ``max_resources``.
    """
    if num_brackets > num_rungs:
        logger.warning(
            "The input num_brackets %i is larger than the number of rungs %i, "
            "set num_brackets as %i",
            num_brackets,
            num_rungs,
            num_rungs,
        )
        num_brackets = num_rungs

    budgets = np.logspace(
        np.log(min_resources) / np.log(reduction_factor),
        np.log(max_resources) / np.log(reduction_factor),
        num_rungs,
        base=reduction_factor,
    )
    # Round to the nearest integer (add one half, then truncate).
    budgets = (budgets + 0.5).astype(int)

    # Rounding may collapse neighbouring rungs onto the same value; bump each
    # offending rung so the sequence stays strictly increasing.
    for i in range(num_rungs - 1):
        if budgets[i] >= budgets[i + 1]:
            budgets[i + 1] = budgets[i] + 1

    if budgets[-1] > max_resources:
        # The template's positional slots are (max, min, num_rungs).
        raise ValueError(BUDGET_ERROR.format(max_resources, min_resources, num_rungs))

    budgets_lists: list[list[BudgetTuple]] = []
    # Rung index -> budgets of every bracket at that rung (display only).
    budgets_tab: dict[int, list[BudgetTuple]] = defaultdict(list)
    for bracket_index in range(num_brackets):
        bracket_resources = budgets[bracket_index:]
        bracket_budgets: list[BudgetTuple] = []
        # Walk from the highest rung down so that the trial count grows by a
        # factor of ``reduction_factor`` for each step towards the lowest rung.
        for i, rung_resources in enumerate(bracket_resources[::-1]):
            budget = BudgetTuple(reduction_factor**i, rung_resources)
            bracket_budgets.append(budget)
            budgets_tab[len(bracket_resources) - i - 1].append(budget)
        # Store the bracket from the lowest to the highest rung.
        budgets_lists.append(bracket_budgets[::-1])

    display_budgets(budgets_tab, max_resources, reduction_factor)
    return budgets_lists
class ASHA(Hyperband):
    """Asynchronous Successive Halving Algorithm

    `A simple and robust hyperparameter tuning algorithm with solid theoretical underpinnings
    that exploits parallelism and aggressive early-stopping.`

    For more information on the algorithm, see original paper at https://arxiv.org/abs/1810.05934.

    Li, Liam, et al. "Massively parallel hyperparameter tuning."
    arXiv preprint arXiv:1810.05934 (2018)

    Parameters
    ----------
    space: `orion.algo.space.Space`
        Optimisation space with priors for each dimension.
    seed: None, int or sequence of int
        Seed for the random number generator used to sample new trials.
        Default: ``None``
    num_rungs: int, optional
        Number of rungs for the largest bracket. If not defined, it is computed as
        ``int(log(max_resources / min_resources) / log(reduction_factor) + 1)``.
        In the original paper,
        ``num_rungs == log(fidelity.high/fidelity.low) / log(fidelity.base) + 1``.
        Default: ``log(fidelity.high/fidelity.low) / log(fidelity.base) + 1``
    num_brackets: int
        Using a grace period that is too small may bias ASHA too strongly towards
        fast converging trials that do not lead to best results at convergence (stragglers). To
        overcome this, you can increase the number of brackets, which increases the amount of
        resource required for optimisation but decreases the bias towards stragglers.
        Default: 1
    repetitions: int
        Number of execution of ASHA. Default is np.inf which means to
        run ASHA until no new trials can be suggested.

    Raises
    ------
    AttributeError
        If the reduction factor is smaller than 2.

    """

    def __init__(
        self,
        space: Space,
        seed: int | Sequence[int] | None = None,
        num_rungs: int | None = None,
        num_brackets: int = 1,
        repetitions: int | float = np.inf,
    ):
        super().__init__(space, seed=seed, repetitions=repetitions)

        self.num_rungs = num_rungs
        self.num_brackets = num_brackets
        self._param_names += ["num_rungs", "num_brackets"]

        if self.reduction_factor < 2:
            raise AttributeError("Reduction factor for ASHA needs to be at least 2.")

        if num_rungs is None:
            # NOTE(review): int() truncates, so a ratio that is mathematically an
            # exact power of the reduction factor may land just below the integer
            # in floating point and give one rung fewer - confirm this is intended.
            num_rungs = int(
                np.log(self.max_resources / self.min_resources)
                / np.log(self.reduction_factor)
                + 1
            )

        self.num_rungs = num_rungs

        self.budgets = compute_budgets(
            self.min_resources,
            self.max_resources,
            self.reduction_factor,
            self.num_rungs,
            self.num_brackets,
        )

        self.brackets = self.create_brackets()
        self.seed_rng(seed)

    # pylint: disable=missing-function-docstring
    def compute_bracket_idx(self, num: int) -> np.ndarray:
        """Split ``num`` new trials between the brackets.

        Trials are handed out one at a time to the bracket whose
        ``remainings / n_trials`` ratio on its first rung is currently the
        largest. Allocation stops early once the brackets' ``remainings`` sum
        to zero.

        Parameters
        ----------
        num: int
            Number of trials to distribute. Non-positive values allocate nothing.

        Returns
        -------
        numpy.ndarray
            Integer array with one entry per bracket: how many trials that
            bracket receives.
        """
        assert self.brackets is not None
        remainings = np.asarray(
            [bracket.remainings for bracket in self.brackets], dtype=np.int64
        )
        totals = np.asarray(
            [bracket.rungs[0]["n_trials"] for bracket in self.brackets], dtype=np.int64
        )

        remainings_after = copy.deepcopy(remainings)
        # A plain loop instead of one recursive call per trial, so that a large
        # ``num`` cannot exhaust the interpreter's recursion limit.
        for _ in range(num):
            if remainings_after.sum() == 0:
                break
            ratios = remainings_after / totals
            fractions = np.nan_to_num(ratios / ratios.sum(), nan=0)
            remainings_after[np.argmax(fractions)] -= 1

        # What was consumed from each bracket is what it has been allocated.
        return remainings - remainings_after

    def sample(self, num: int) -> list[Trial]:
        """Sample up to ``num`` new trials, spread over the brackets.

        Each bracket samples the share given by :meth:`compute_bracket_idx`
        and the sampled trials are registered in that bracket.
        """
        assert self.brackets is not None
        samples: list[Trial] = []
        bracket_nums = self.compute_bracket_idx(num)
        for idx, bracket_num in enumerate(bracket_nums):
            bracket = self.brackets[idx]
            bracket_samples = self.sample_from_bracket(bracket, bracket_num)
            self.register_samples(bracket, bracket_samples)
            samples += bracket_samples
        return samples

    def suggest(self, num: int) -> list[Trial]:
        """Suggest ``num`` trials; delegates to the parent implementation."""
        return super().suggest(num)

    def create_bracket(self, budgets: list[BudgetTuple], iteration: int) -> ASHABracket:
        """Build the :class:`ASHABracket` for ``budgets`` at repetition ``iteration``."""
        return ASHABracket(self, budgets, iteration)
class ASHABracket(HyperbandBracket[ASHA]):
    """Bracket of rungs for the algorithm ASHA.

    Parameters
    ----------
    owner: `ASHA` algorithm
        The asha algorithm object which this bracket will be part of.
    budgets: list of tuple
        Each tuple gives the (n_trials, resource_budget) for the respective rung.
    repetition_id: int
        The id of hyperband execution this bracket belongs to

    """

    def __init__(
        self,
        owner: ASHA,
        budgets: list[BudgetTuple],
        repetition_id: int,
    ):
        super().__init__(owner=owner, budgets=budgets, repetition_id=repetition_id)

    def get_candidates(self, rung_id: int) -> list[Trial]:
        """Get the trials of rung ``rung_id`` that can be promoted.

        Only trials with a non-``None`` objective are ranked. The best
        ``len(ranked) // reduction_factor`` of them are considered, and those
        whose id (ignoring fidelity) is already present in the next rung are
        left out.
        """
        current_results = self.rungs[rung_id]["results"]
        promoted_ids = self.rungs[rung_id + 1]["results"]

        # Rank the evaluated trials by objective, smallest first. The sort is
        # stable, so ties keep the order of the rung's result mapping.
        ranked = [
            (objective, trial)
            for objective, trial in current_results.values()
            if objective is not None
        ]
        ranked.sort(key=lambda pair: pair[0])

        n_top = min(len(ranked) // self.owner.reduction_factor, len(ranked))

        return [
            trial
            for trial in (ranked[position][1] for position in range(n_top))
            if self.owner.get_id(trial, ignore_fidelity=True) not in promoted_ids
        ]

    @property
    def is_filled(self) -> bool:
        """Always ``False``: ASHA's first rung can always sample new trials."""
        return False

    def is_ready(self, rung_id: int | None = None) -> bool:
        """Always ``True``: ASHA is always ready for promotions."""
        return True