# -*- coding: utf-8 -*-
"""
A Novel Bandit-Based Approach to Hyperparameter Optimization
============================================================
Implements Hyperband, which efficiently allocates a fixed resource budget across randomly sampled configurations.
"""
import copy
import hashlib
import logging
import numpy
from orion.algo.base import BaseAlgorithm
from orion.algo.space import Fidelity
# Module-level logger; budget tables and bracket creation are logged at DEBUG level.
logger = logging.getLogger(__name__)
# Message used by Bracket.register when a point's fidelity matches none of the bracket's
# rung resources; formatted with the ``fidelity``, ``budgets`` and ``params`` keywords.
REGISTRATION_ERROR = """
Bad fidelity level {fidelity}. Should be in {budgets}.
Params: {params}
"""
# Message of the RuntimeError raised by Hyperband when the search space has no
# Fidelity dimension.
SPACE_ERROR = """
Hyperband cannot be used if space does not contain a fidelity dimension.
For more information on the configuration and usage of Hyperband, see
https://orion.readthedocs.io/en/develop/user/algorithms.html#hyperband
"""
# Template taking three positional values (max, min, num_rungs).
# NOTE(review): not referenced anywhere in this module; presumably imported by other
# algorithms -- confirm before removing.
BUDGET_ERROR = """
Cannot build budgets below max_resources;
(max: {}) - (min: {}) > (num_rungs: {})
"""
def compute_budgets(max_resources, reduction_factor):
    """Compute the budgets used for each execution of hyperband.

    Bracket ``bracket_id`` gets ``num_brackets - bracket_id + 1`` rungs. Its first rung
    starts at ``max_resources / reduction_factor ** (num_brackets - bracket_id)``, and each
    following rung multiplies the resources by ``reduction_factor`` while dividing the
    number of trials by it (both truncated with ``int``).

    The per-rung budgets are also logged as a table at DEBUG level.

    Parameters
    ----------
    max_resources: int or float
        Maximum resource value, reached by the last rung of every bracket.
    reduction_factor: int or float
        Factor between the resources (and trial counts) of consecutive rungs.
        Expected to be at least 2.

    Returns
    -------
    list of list of tuple
        One list per bracket, holding one ``(n_trials, resources)`` tuple per rung.
    """
    num_brackets = int(numpy.log(max_resources) / numpy.log(reduction_factor))
    # The float ratio can land one ulp below an exact integer, e.g.
    # log(1000) / log(10) == 2.9999999999999996, which ``int`` truncates to 2 and so
    # silently drops the bracket with the smallest resources. Correct it exactly.
    if reduction_factor > 1 and reduction_factor ** (num_brackets + 1) <= max_resources:
        num_brackets += 1

    budgets = []
    # Rung index -> (n_trials, resources) of every bracket at that rung; display only.
    budgets_tab = {}
    for bracket_id in range(0, num_brackets + 1):
        # Number of times resources are multiplied from this bracket's first rung.
        num_reductions = num_brackets - bracket_id
        # NOTE(review): the ratio is truncated *before* multiplying, so this can be smaller
        # than the paper's ceil((s_max + 1) / (s + 1) * eta ** s). Kept as-is to preserve
        # the budgets the algorithm has always generated.
        num_trials = int(
            numpy.ceil(
                int((num_brackets + 1) / (num_reductions + 1))
                * (reduction_factor ** num_reductions)
            )
        )
        min_resources = max_resources / reduction_factor ** num_reductions

        bracket_budgets = []
        for i in range(0, num_reductions + 1):
            n_i = int(num_trials / reduction_factor ** i)
            min_i = int(min_resources * reduction_factor ** i)
            bracket_budgets.append((n_i, min_i))
            budgets_tab.setdefault(i, []).append((n_i, min_i))
        budgets.append(bracket_budgets)

    display_budgets(budgets_tab, max_resources, reduction_factor)
    return budgets
def display_budgets(budgets_tab, max_resources, reduction_factor):
    """Log the Hyperband budget table at DEBUG level.

    Parameters
    ----------
    budgets_tab: dict
        Maps a rung index to the list of ``(n_trials, resources)`` tuples used at that
        rung; the length of entry ``0`` sets the number of table columns.
    max_resources: int or float
        Echoed in the table footer.
    reduction_factor: int or float
        Echoed in the table footer.
    """
    n_columns = len(budgets_tab[0])
    row_template = "{:<4}" + " {:<12}" * n_columns + "\n"
    separator = row_template.format("---", *(["---------"] * n_columns))
    header = row_template.format("i ", *(["n_i r_i"] * n_columns))

    pieces = ["Display Budgets:\n", separator, header, separator]
    grand_total = 0
    for rung_index, rung_budgets in budgets_tab.items():
        cells = []
        for n_trials, resources in rung_budgets:
            grand_total += n_trials
            cells.append("{:<5} {:<7}".format(n_trials, resources))
        pieces.append("{:<4} ".format(rung_index) + "".join(cells) + "\n")
    pieces.append(separator)
    pieces.append(
        "max resource={}, eta={}, trials number of one execution={}\n".format(
            max_resources, reduction_factor, grand_total
        )
    )
    logger.debug("".join(pieces))
class Hyperband(BaseAlgorithm):
    """Hyperband formulates hyperparameter optimization as a pure-exploration non-stochastic
    infinite-armed bandit problem where a predefined resource like iterations, data samples,
    or features is allocated to randomly sampled configurations.

    For more information on the algorithm,
    see original paper at http://jmlr.org/papers/v18/16-558.html.

    Li, Lisha et al. "Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization"
    Journal of Machine Learning Research, 18:1-52, 2018.

    Parameters
    ----------
    space: `orion.algo.space.Space`
        Optimisation space with priors for each dimension. It must contain a fidelity
        dimension whose ``low``, ``high`` and ``base`` give the minimum resources, the
        maximum resources and the reduction factor.
    seed: None or int
        Seed for the random number generator used to sample new trials.
        Default: ``None``
    repetitions: int
        Number of executions of Hyperband. Default is numpy.inf which means to
        run Hyperband until no new trials can be suggested.

    Raises
    ------
    RuntimeError
        If ``space`` does not contain a fidelity dimension.
    """

    def __init__(self, space, seed=None, repetitions=numpy.inf):
        # Defined before the base constructor runs.
        # NOTE(review): presumably because BaseAlgorithm.__init__ may call seed_rng,
        # which iterates the brackets -- confirm.
        self.brackets = []
        super(Hyperband, self).__init__(space, seed=seed, repetitions=repetitions)

        # Point id (without fidelity) -> Bracket that observed it with the most resources.
        self.trial_info_wo_fidelity = {}
        # Ids (with fidelity) of every point suggested by this algorithm.
        self.sampled = set()
        # Point id (without fidelity) -> Bracket for the points proposed during the
        # current ``suggest`` call; prevents duplicates within one call.
        self.points_in_suggest_call = {}

        try:
            fidelity_index = self.fidelity_index
        except IndexError:
            raise RuntimeError(SPACE_ERROR)

        fidelity_dim = space.values()[fidelity_index]
        self.min_resources = fidelity_dim.low
        self.max_resources = fidelity_dim.high
        self.reduction_factor = fidelity_dim.base

        self.repetitions = repetitions
        # Counter for how many times Hyperband has been executed.
        self.executed_times = 0

        if self.reduction_factor >= 2:
            self.budgets = compute_budgets(self.max_resources, self.reduction_factor)
            self.brackets = [
                Bracket(self, bracket_budgets, 1) for bracket_budgets in self.budgets
            ]
            self.seed_rng(seed)
        else:
            # NOTE(review): with ``brackets = None`` the algorithm cannot suggest anything;
            # ``suggest`` and ``seed_rng`` would fail when iterating it.
            self.budgets = None
            self.brackets = None
            logger.warning("Reduction factor for Hyperband needs to be at least 2")

    def sample(self, num, bracket, buffer=10):
        """Sample up to ``num`` new points for the first rung of ``bracket``.

        ``bracket.rungs[0]["n_trials"] * buffer`` candidates are drawn from the space with
        the bracket's seed. A candidate is kept only if it was not already proposed in the
        current ``suggest`` call, and it was either never observed or observed in an
        earlier Hyperband execution with fewer resources than this bracket's first rung.
        Kept points get their fidelity set to that rung's resources.

        Returns
        -------
        list of tuple
            At most ``num`` points.
        """
        first_rung = bracket.rungs[0]
        sample_buffer = first_rung["n_trials"] * buffer
        samples = self.space.sample(sample_buffer, seed=bracket.seed)
        i = 0
        points = []
        while len(points) < num and i < sample_buffer:
            point = samples[i]
            _id_wo_fidelity = self.get_id(point)
            _bracket_observed = self.trial_info_wo_fidelity.get(_id_wo_fidelity)
            if _id_wo_fidelity not in self.points_in_suggest_call and (
                not _bracket_observed
                or (
                    _bracket_observed.repetition_id < bracket.repetition_id
                    and _bracket_observed.get_point_max_resource(point)
                    < first_rung["resources"]
                )
            ):
                # No duplicate found, or the duplicate comes from a previous Hyperband
                # execution where it received fewer resources.
                point = list(point)
                point[self.fidelity_index] = first_rung["resources"]
                points.append(tuple(point))
                self.points_in_suggest_call[_id_wo_fidelity] = bracket
            i += 1
        return points

    def seed_rng(self, seed):
        """Seed the state of the random number generator.

        :param seed: Integer seed for the random number generator, or ``None``.
        """
        self.seed = seed
        self._seed_brackets(seed)
        self.rng = numpy.random.RandomState(seed)

    def _seed_brackets(self, seed):
        """Give each current bracket its own seed derived from ``seed``.

        The seed includes ``executed_times`` so successive Hyperband executions get
        different bracket seeds. ``None`` leaves the brackets unseeded.
        """
        for i, bracket in enumerate(self.brackets):
            bracket.seed_rng(
                self.executed_times + seed + i if seed is not None else None
            )

    @property
    def state_dict(self):
        """Return a state dict that can be used to reset the state of the algorithm.

        ``sampled`` is copied so that later suggestions do not mutate a saved state.
        """
        return {
            "rng_state": self.rng.get_state(),
            "seed": self.seed,
            "executed_times": self.executed_times,
            "sampled": copy.deepcopy(self.sampled),
        }

    def set_state(self, state_dict):
        """Reset the state of the algorithm based on the given state_dict

        :param state_dict: Dictionary representing state of an algorithm
        """
        # Restore the execution counter first: ``seed_rng`` derives bracket seeds from it.
        self.executed_times = state_dict["executed_times"]
        self.seed_rng(state_dict["seed"])
        self.rng.set_state(state_dict["rng_state"])
        # Copy so the caller's dict is not mutated by future suggestions.
        self.sampled = copy.deepcopy(state_dict["sampled"])

    def suggest(self, num=1):
        """Suggest a number of new sets of parameters.

        Sample new points until first rung is filled. Afterwards
        waits for all trials to be completed before promoting trials
        to the next rung.

        Parameters
        ----------
        num: int, optional
            Number of points to suggest. Defaults to 1.
            NOTE(review): currently unused; every missing slot is filled at once.

        Returns
        -------
        list of points or None
            A list of lists representing points suggested by the algorithm. The algorithm may opt
            out if it cannot make a good suggestion at the moment (it may be waiting for other
            trials to complete), in which case it will return None.
        """
        samples = []
        # Temp persist of the points proposed inside this call
        self.points_in_suggest_call = {}
        # Brackets are visited from last to first.
        for bracket in reversed(self.brackets):
            if not bracket.is_filled:
                samples += bracket.sample()

        if samples:
            for sample in samples:
                self.sampled.add(self.get_id(sample, ignore_fidelity=False))
            return samples

        # All brackets are filled
        for bracket in reversed(self.brackets):
            if bracket.is_ready() and not bracket.is_done:
                samples += bracket.promote()

        if samples:
            for sample in samples:
                self.sampled.add(self.get_id(sample, ignore_fidelity=False))
            return samples

        # Either all brackets are done or none are ready and algo needs to wait for some
        # trials to complete
        if len(self.trial_info_wo_fidelity) >= self.space.cardinality:
            logger.warning(
                "The number of unique trials of bottom rungs exceeds the search space "
                "cardinality %i, Hyperband algorithm exits.",
                self.space.cardinality,
            )
        else:
            logger.warning("Hyeprband can not suggest new samples, exit.")

        return None

    def get_id(self, point, ignore_fidelity=True):
        """Compute a unique hash for a point based on params, without fidelity level by default.

        The hash is the MD5 hex digest of ``str`` of the point's values as a list.
        """
        _point = list(point)
        if ignore_fidelity:
            non_fidelity_dims = _point[0 : self.fidelity_index]
            non_fidelity_dims.extend(_point[self.fidelity_index + 1 :])
            _point = non_fidelity_dims

        return hashlib.md5(str(_point).encode("utf-8")).hexdigest()

    def _refresh_bracket(self):
        """Start a new Hyperband execution once every bracket of the current one is done."""
        if not all(bracket.is_done for bracket in self.brackets):
            return

        self.executed_times = self.brackets[0].repetition_id
        logger.debug(
            "Hyperband execution %i is done, required to execute %s times",
            self.executed_times,
            str(self.repetitions),
        )

        # Continue to the next execution if needed
        if self.executed_times < self.repetitions:
            self.brackets = [
                Bracket(self, bracket_budgets, self.executed_times + 1)
                for bracket_budgets in self.budgets
            ]
            # New brackets start with ``seed=None``. Seed them like ``seed_rng`` does,
            # otherwise later executions sample without the algorithm's seed.
            self._seed_brackets(self.seed)

    def _get_bracket(self, point):
        """Get the bracket of a point during observe.

        Raises
        ------
        ValueError
            If no bracket holds the point and none starts at the point's fidelity.
        """
        fidelity = point[self.fidelity_index]
        _id_wo_fidelity = self.get_id(point)

        brackets = []
        for bracket in self.brackets:
            # If the same point is in the first rung of a bracket,
            # the point should be registered in this bracket
            if _id_wo_fidelity in bracket.rungs[0]["results"]:
                brackets = [bracket]
                break

        if not brackets:
            # If the point shows up in the current hyperband execution for the first time,
            # the bracket whose first rung has the same fidelity should be used;
            # this assumes there are no duplicated points inside one hyperband execution.
            brackets = [
                bracket
                for bracket in self.brackets
                if bracket.rungs[0]["resources"] == fidelity
            ]

        if not brackets:
            raise ValueError(
                "No bracket found for point {0} with fidelity {1}".format(
                    _id_wo_fidelity, fidelity
                )
            )

        if len(brackets) > 1:
            logger.warning(
                "More than one bracket found for point %s, this should not happen",
                str(point),
            )

        return brackets[0]

    def observe(self, points, results):
        """Observe evaluation `results` corresponding to list of `points` in space.

        Points not suggested by this algorithm are ignored. Every other point is
        registered in its bracket with ``result["objective"]``, and the bracket that
        observed it with the most resources is remembered. When every bracket is done,
        a new Hyperband execution may start.

        Parameters
        ----------
        points: list of tuples
            Observed points, fidelity included.
        results: list of dict
            Results matching ``points``; each must contain an ``"objective"`` key.
        """
        for point, result in zip(points, results):
            full_id = self.get_id(point, ignore_fidelity=False)
            if full_id not in self.sampled:
                logger.info(
                    "Ignoring point %s because it was not sampled by current algo.",
                    full_id,
                )
                continue

            bracket = self._get_bracket(point)

            try:
                bracket.register(point, result["objective"])
            except IndexError:
                logger.warning(
                    "Point registered to wrong bracket. This is likely due "
                    "to a corrupted database, where trials of different fidelity "
                    "have a wrong timestamps."
                )
                continue

            _id_wo_fidelity = self.get_id(point)
            if (
                _id_wo_fidelity not in self.trial_info_wo_fidelity
                or bracket.get_point_max_resource(point)
                > self.trial_info_wo_fidelity[_id_wo_fidelity].get_point_max_resource(
                    point
                )
            ):
                # The point shows up for the first time, or it has a bigger resource
                # value than the same point in a previous hyperband execution
                self.trial_info_wo_fidelity[_id_wo_fidelity] = bracket

        self._refresh_bracket()

    @property
    def is_done(self):
        """Return True if all required executions have been done."""
        return self.executed_times >= self.repetitions

    @property
    def fidelity_index(self):
        """Index of the first fidelity dimension in the space.

        Raises ``IndexError`` when the space has no fidelity dimension.
        """

        def _is_fidelity(dim):
            return isinstance(dim, Fidelity) or (
                hasattr(dim, "original_dimension")
                and isinstance(dim.original_dimension, Fidelity)
            )

        return [i for i, dim in enumerate(self.space.values()) if _is_fidelity(dim)][0]
class Bracket:
    """Bracket of rungs for the algorithm Hyperband.

    Parameters
    ----------
    hyperband: `Hyperband` algorithm
        The hyperband algorithm object which this bracket will be part of.
    budgets: list of tuple
        Each tuple gives the (n_trials, resource_budget) for the respective rung.
    repetition_id: int
        The id of hyperband execution this bracket belongs to
    """

    def __init__(self, hyperband, budgets, repetition_id):
        self.hyperband = hyperband
        # One dict per rung: its resources (fidelity value), how many trials it must hold,
        # and ``results`` mapping point id (without fidelity) -> (objective, point).
        self.rungs = [
            dict(resources=budget, n_trials=n_trials, results=dict())
            for n_trials, budget in budgets
        ]
        self.seed = None
        self.repetition_id = repetition_id

        logger.debug("Bracket budgets: %s", str(budgets))

    @property
    def is_filled(self):
        """Return True if first rung with trials is filled"""
        return self.has_rung_filled(0)

    def get_point_max_resource(self, point):
        """Return the resources of the last rung holding ``point`` (0 if none holds it)."""
        max_resource = 0
        _id_wo_fidelity = self.hyperband.get_id(point)
        for rung in self.rungs:
            if _id_wo_fidelity in rung["results"]:
                max_resource = rung["resources"]

        return max_resource

    def sample(self):
        """Sample the trials still missing from the first rung, at its fidelity."""
        should_have_n_trials = self.rungs[0]["n_trials"]
        n_trials = len(self.rungs[0]["results"])
        return self.hyperband.sample(should_have_n_trials - n_trials, self)

    def register(self, point, objective):
        """Register a point in the rung whose resources equal the point's fidelity.

        Raises
        ------
        IndexError
            If no rung of this bracket has the point's fidelity as resources.
        """
        fidelity = point[self.hyperband.fidelity_index]
        rungs = [
            rung["results"] for rung in self.rungs if rung["resources"] == fidelity
        ]
        if not rungs:
            budgets = [rung["resources"] for rung in self.rungs]
            raise IndexError(
                REGISTRATION_ERROR.format(
                    fidelity=fidelity, budgets=budgets, params=point
                )
            )

        rungs[0][self.hyperband.get_id(point)] = (objective, point)

    def get_candidates(self, rung_id):
        """Get the best points of rung ``rung_id`` that still need promotion.

        Points are taken in increasing ``(objective, point)`` order (lowest objective
        first), skipping those already in the next rung, until the next rung would hold
        its ``n_trials``. Returns ``[]`` if the next rung is already filled.
        """
        if self.has_rung_filled(rung_id + 1):
            return []

        next_rung = self.rungs[rung_id + 1]["results"]
        ranked = sorted(self.rungs[rung_id]["results"].values())
        should_have_n_trials = self.rungs[rung_id + 1]["n_trials"]

        points = []
        # Bounded by the rung size so an exhausted rung yields fewer candidates instead
        # of raising IndexError.
        for objective, point in ranked:
            if len(points) + len(next_rung) >= should_have_n_trials:
                break
            assert objective is not None
            if self.hyperband.get_id(point) not in next_rung:
                points.append(point)

        return points

    def promote(self):
        """Return the points to promote to the next rung, with their fidelity updated.

        The lowest rung whose next rung is not filled yet is used. Nothing is promoted
        (``[]``) if the bracket is done, or if that rung is not ready (not filled, or
        still holding points without an objective).
        """
        if self.is_done:
            return []

        for rung_id in range(len(self.rungs) - 1):
            if self.has_rung_filled(rung_id + 1):
                continue

            if not self.is_ready(rung_id):
                return []

            fidelity_index = self.hyperband.fidelity_index
            next_resources = self.rungs[rung_id + 1]["resources"]
            points = []
            for candidate in self.get_candidates(rung_id):
                # Copy so the point stored in the current rung is left untouched.
                promoted = list(copy.deepcopy(candidate))
                promoted[fidelity_index] = next_resources
                points.append(tuple(promoted))
            return points

        return []

    @property
    def is_done(self):
        """Return True, if the last rung is filled."""
        return self.has_rung_filled(len(self.rungs) - 1)

    def has_rung_filled(self, rung_id):
        """Return True, if the rung[rung_id] is filled."""
        n_trials = len(self.rungs[rung_id]["results"])
        return n_trials >= self.rungs[rung_id]["n_trials"]

    def is_ready(self, rung_id=None):
        """Return True, if the bracket is ready for next promote.

        With ``rung_id``: that rung is filled and every point in it has an objective.
        Without it: the same check applied to the highest filled rung in the leading run
        of filled rungs (False if rung 0 is not filled).
        """
        if rung_id is not None:
            return self.has_rung_filled(rung_id) and all(
                objective is not None
                for objective, _ in self.rungs[rung_id]["results"].values()
            )

        is_ready = False
        for _rung_id in range(len(self.rungs)):
            if self.has_rung_filled(_rung_id):
                is_ready = self.is_ready(_rung_id)
            else:
                break

        return is_ready

    def seed_rng(self, seed):
        """Seed the state of the random number generator.

        :param seed: Integer seed for the random number generator.
        """
        self.seed = seed

    def __repr__(self):
        """Return representation of bracket with fidelity levels"""
        return "Bracket(resource={}, repetition id={})".format(
            [rung["resources"] for rung in self.rungs], self.repetition_id
        )