Source code for orion.storage.base

# -*- coding: utf-8 -*-
"""
Generic Storage Protocol
========================

Implement a generic protocol to allow Orion to communicate using different storage backends.

The storage protocol is a generic way of allowing Orion to interface with different storage
backends, such as MongoDB, track, Comet ML, MLflow, etc.

Examples
--------
>>> storage_factory.create('track', uri='file://orion_test.json')
>>> storage_factory.create('legacy', experiment=...)

Notes
-----
When retrieving an already initialized Storage object you should use `get_storage`.
`storage_factory.create()` should only be used for initialization purposes as `get_storage`
raises more granular error messages.

"""
import copy
import logging

import orion.core
from orion.core.io import resolve_config
from orion.core.utils.singleton import GenericSingletonFactory

log = logging.getLogger(__name__)


def get_uid(item=None, uid=None, force_uid=True):
    """Return uid either from `item` or directly `uid`.

    Parameters
    ----------
    item: Experiment or Trial, optional
        Object with an ``.id`` attribute.
    uid: str, optional
        str id representation.
    force_uid: bool, optional
        If True, at least one of `item` or `uid` must be passed.

    Returns
    -------
    `uid` when it is given, otherwise ``item.id``. Returns None when neither
    is given and `force_uid` is False.

    Raises
    ------
    MissingArguments
        If both `item` and `uid` are not set and `force_uid` is True.
    AssertionError
        If both `item` and `uid` are provided and they do not match.

    """
    if uid is not None:
        # Explicit raise instead of an `assert` statement so that the
        # consistency check is not stripped when running under `python -O`.
        if item is not None and item.id != uid:
            raise AssertionError(
                "item.id (%r) does not match uid (%r)" % (item.id, uid)
            )
        return uid

    if item is not None:
        return item.id

    if force_uid:
        raise MissingArguments("Either `item` or `uid` should be set")

    return None
class FailedUpdate(Exception):
    """Raised when the status of a trial could not be updated."""
class MissingArguments(Exception):
    """Raised when a function is called without its minimal set of parameters."""
class BaseStorageProtocol:
    """Implement a generic protocol to allow Orion to communicate using
    different storage backends.

    Every method of this base class raises :class:`NotImplementedError`;
    concrete storage backends are expected to override them.
    """

    def create_benchmark(self, config):
        """Insert a new benchmark inside the database"""
        raise NotImplementedError()

    def fetch_benchmark(self, query, selection=None):
        """Fetch all benchmarks that match the query"""
        raise NotImplementedError()

    def create_experiment(self, config):
        """Insert a new experiment inside the database"""
        raise NotImplementedError()

    def delete_experiment(self, experiment=None, uid=None):
        """Delete matching experiments from the database

        Parameters
        ----------
        experiment: Experiment, optional
            experiment object to delete from the database
        uid: str, optional
            id of the experiment to delete

        Returns
        -------
        Number of experiments deleted.

        Raises
        ------
        MissingArguments
            if both experiment and uid are not set
        AssertionError
            if both experiment and uid are provided and they do not match

        """
        raise NotImplementedError()

    def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
        """Update the fields of a given experiment

        Parameters
        ----------
        experiment: Experiment, optional
            experiment object to update in the database
        uid: str, optional
            id of the experiment to update
        where: Optional[dict]
            constraint the experiment must respect
        **kwargs: dict
            a dictionary of fields to update

        Returns
        -------
        True if the underlying storage was updated.

        Raises
        ------
        MissingArguments
            if both experiment and uid are not set
        AssertionError
            if both experiment and uid are provided and they do not match

        """
        raise NotImplementedError()

    def fetch_experiments(self, query, selection=None):
        """Fetch all experiments that match the query"""
        raise NotImplementedError()

    def register_trial(self, trial):
        """Create a new trial to be executed"""
        raise NotImplementedError()

    def delete_trials(self, experiment=None, uid=None, where=None):
        """Delete matching trials from the database

        Parameters
        ----------
        experiment: Experiment, optional
            experiment object whose trials are deleted
        uid: str, optional
            id of the experiment whose trials are deleted
        where: Optional[dict]
            constraint trials must respect

        Returns
        -------
        Number of trials deleted.

        Raises
        ------
        MissingArguments
            if both experiment and uid are not set
        AssertionError
            if both experiment and uid are provided and they do not match

        """
        raise NotImplementedError()

    def reserve_trial(self, experiment):
        """Select a pending trial and reserve it for the worker

        Returns
        -------
        The reserved trial, or None if no trials were found.

        """
        raise NotImplementedError()

    def fetch_trials(self, experiment=None, uid=None, where=None):
        """Fetch all the trials of an experiment in the database

        Parameters
        ----------
        experiment: Experiment, optional
            experiment object whose trials are fetched
        uid: str, optional
            id of the experiment whose trials are fetched
        where: Optional[dict]
            constraint trials must respect

        Returns
        -------
        The trials matching the query, or None if the experiment is not found.

        Raises
        ------
        MissingArguments
            if both experiment and uid are not set
        AssertionError
            if both experiment and uid are provided and they do not match

        """
        raise NotImplementedError()

    def update_trials(self, experiment=None, uid=None, where=None, **kwargs):
        """Update trials of a given experiment matching a query

        Parameters
        ----------
        experiment: Experiment, optional
            experiment object whose trials are updated
        uid: str, optional
            id of the experiment whose trials are updated
        where: Optional[dict]
            constraint trials must respect
        **kwargs: dict
            a dictionary of fields to update

        Raises
        ------
        MissingArguments
            if both experiment and uid are not set
        AssertionError
            if both experiment and uid are provided and they do not match

        """
        raise NotImplementedError()

    def update_trial(self, trial=None, uid=None, where=None, **kwargs):
        """Update fields of a given trial

        Parameters
        ----------
        trial: Trial, optional
            trial object to update in the database
        uid: str, optional
            id of the trial to update in the database
        where: Optional[dict]
            constraint trials must respect. Note: useful to handle race conditions.
        **kwargs: dict
            a dictionary of fields to update

        Raises
        ------
        MissingArguments
            if both trial and uid are not set
        AssertionError
            if both trial and uid are provided and they do not match

        """
        raise NotImplementedError()

    def get_trial(self, trial=None, uid=None):
        """Fetch a single trial

        Parameters
        ----------
        trial: Trial, optional
            trial object to retrieve from the database
        uid: str, optional
            trial id used to retrieve the trial object

        Returns
        -------
        The matching trial, or None if the trial is not found.

        Raises
        ------
        MissingArguments
            if both trial and uid are not set
        AssertionError
            if both trial and uid are provided and they do not match

        """
        raise NotImplementedError()

    def fetch_lost_trials(self, experiment):
        """Fetch all trials that have a heartbeat older than
        some given time delta (2 minutes by default)
        """
        raise NotImplementedError()

    def retrieve_result(self, trial, *args, **kwargs):
        """Fetch the result from a given medium (file, db, socket, etc..) for a given trial and
        insert it into the trial object
        """
        raise NotImplementedError()

    def push_trial_results(self, trial):
        """Push the trial's results to the database"""
        raise NotImplementedError()

    def set_trial_status(self, trial, status, heartbeat=None, was=None):
        """Update the trial status and the heartbeat

        Parameters
        ----------
        trial: `Trial` object
            Trial object to update in the database.
        status: str
            Status to be set to the trial
        heartbeat: datetime, optional
            New heartbeat to update simultaneously with status
        was: str, optional
            The status the trial is expected to currently have in the database
            (i.e. before this update). If None, current ``trial.status`` will
            be used. This is used to ensure coherence in the database,
            protecting against race conditions for instance.

        Raises
        ------
        FailedUpdate
            The exception is raised if the status of the trial object
            does not match the status in the database

        """
        raise NotImplementedError()

    def fetch_pending_trials(self, experiment):
        """Fetch all trials that are available to be executed by a worker,
        this includes new, suspended and interrupted trials
        """
        raise NotImplementedError()

    def fetch_noncompleted_trials(self, experiment):
        """Fetch all non completed trials"""
        raise NotImplementedError()

    def fetch_trials_by_status(self, experiment, status):
        """Fetch all trials with the given status"""
        raise NotImplementedError()

    def count_completed_trials(self, experiment):
        """Count the number of completed trials"""
        raise NotImplementedError()

    def count_broken_trials(self, experiment):
        """Count the number of broken trials"""
        raise NotImplementedError()

    def update_heartbeat(self, trial):
        """Update trial's heartbeat"""
        raise NotImplementedError()
# Module-level factory for BaseStorageProtocol backends.
# `setup_storage` creates the storage through `storage_factory.create(of_type=...)`
# and `get_storage` retrieves it with `storage_factory.create()` (no arguments).
# NOTE(review): single-instance semantics come from GenericSingletonFactory,
# which is defined outside this file — confirm there.
storage_factory = GenericSingletonFactory(BaseStorageProtocol)
def get_storage():
    """Return current storage

    This is a wrapper around ``storage_factory`` that provides better error
    messages when the storage is used without being initialized. The factory
    is called without a backend type, so it only retrieves an existing
    storage instance.

    Returns
    -------
    The storage instance held by ``storage_factory``.

    Raises
    ------
    RuntimeError
        If the underlying storage was not initialized prior to calling this
        function. NOTE(review): the error is produced by
        ``GenericSingletonFactory.create``; confirm the exact exception type
        it raises.

    Notes
    -----
    To initialize the underlying storage, first call ``setup_storage(...)``
    (or ``storage_factory.create(of_type=..., ...)``) with the appropriate
    arguments for the chosen backend.

    """
    return storage_factory.create()
def setup_storage(storage=None, debug=False):
    """Create the storage instance from a configuration.

    Parameters
    ----------
    storage: dict, optional
        Configuration for the storage backend. If not defined, global
        configuration is used. The given dict is deep-copied and not mutated.
    debug: bool, optional
        If True, `storage` is ignored and the storage config is overridden
        with legacy:EphemeralDB. Defaults to False.

    Raises
    ------
    KeyError
        If the configuration has neither a ``type`` entry nor a ``database``
        entry from which the legacy type can be inferred.
    ValueError
        Re-raised from ``storage_factory.create`` unless the storage that
        already exists is of the requested type.

    """
    if debug:
        # Checked first: the override discards `storage` entirely, so an
        # incomplete configuration must not prevent debug mode from working.
        storage = {"type": "legacy", "database": {"type": "EphemeralDB"}}
    else:
        global_storage = orion.core.config.storage
        if storage is None:
            storage = global_storage.to_dict()
        # Work on a copy so the caller's configuration is never mutated.
        storage = copy.deepcopy(storage)

        if storage.get("type") == "legacy" and "database" not in storage:
            storage["database"] = global_storage.database.to_dict()
        elif storage.get("type") is None and "database" in storage:
            # A bare `database` section implies the legacy backend.
            storage["type"] = "legacy"

        # Same backend type as the global configuration: merge the given
        # config over the global one (presumably `storage` takes precedence).
        if storage["type"] == global_storage.type:
            storage = resolve_config.merge_configs(
                global_storage.to_dict(), storage
            )

    storage_type = storage.pop("type")

    log.debug("Creating %s storage client with args: %s", storage_type, storage)
    try:
        storage_factory.create(of_type=storage_type, **storage)
    except ValueError:
        # Tolerate a repeated setup with the same backend type; any other
        # failure (e.g. a different backend already exists) is re-raised.
        existing = storage_factory.create()
        if existing.__class__.__name__.lower() != storage_type.lower():
            raise
# pylint: disable=too-few-public-methods
class ReadOnlyStorageProtocol:
    """Read-only view over a storage protocol.

    Only the lookup methods named in ``valid_attributes`` are forwarded to the
    wrapped storage. Any other attribute that is not defined on this wrapper
    itself raises ``AttributeError``.

    .. seealso::

        :py:class:`BaseStorageProtocol`
    """

    __slots__ = ("_storage",)

    # Attributes of the wrapped storage that callers are allowed to reach.
    valid_attributes = {
        "get_trial",
        "fetch_trials",
        "fetch_experiments",
        "count_broken_trials",
        "count_completed_trials",
        "fetch_noncompleted_trials",
        "fetch_pending_trials",
        "fetch_lost_trials",
        "fetch_trials_by_status",
    }

    def __init__(self, protocol):
        """Wrap `protocol`, see attributes of :class:`BaseStorageProtocol`."""
        self._storage = protocol

    def __getattr__(self, attr):
        """Forward `attr` to the wrapped storage when it is whitelisted."""
        if attr in self.valid_attributes:
            return getattr(self._storage, attr)

        raise AttributeError(
            "Cannot access attribute %s on ReadOnlyStorageProtocol." % attr
        )