Source code for orion.storage.legacy

# -*- coding: utf-8 -*-
"""
Legacy storage
==============

Old Storage implementation.

"""
import datetime
import json
import logging

import orion.core
import orion.core.utils.backward as backward
from orion.core.io.convert import JSONConverter
from orion.core.io.database import Database, OutdatedDatabaseError
from orion.core.utils.exceptions import MissingResultFile
from orion.core.worker.trial import Trial, validate_status
from orion.storage.base import (
    BaseStorageProtocol,
    FailedUpdate,
    MissingArguments,
    get_uid,
)

log = logging.getLogger(__name__)


def get_database():
    """Return the current `Database` instance.

    Thin accessor that calls ``Database()`` without arguments and hands
    back whatever it returns.

    Notes
    -----
    The underlying database is expected to have been created beforehand by
    calling `Database(...)` with the appropriate arguments for the chosen
    backend. What happens when it was not (earlier documentation mentions a
    ``RuntimeError``) is decided by `Database` itself, not by this function.

    """
    current = Database()
    return current
def setup_database(config=None):
    """Create the Database instance from a configuration.

    Parameters
    ----------
    config: dict
        Configuration for the database backend. Must contain a ``"type"``
        entry; every other entry is forwarded to `Database` as a keyword
        argument. If not defined, global configuration is used.
        The given mapping is left untouched.

    Raises
    ------
    ValueError
        Re-raised from `Database(...)` when the class name of the object
        returned by ``Database()`` does not match the requested type
        (case-insensitive comparison).

    """
    if config is None:
        # TODO: How could we support orion.core.config.storage.database as well?
        config = orion.core.config.database.to_dict()

    # Work on a shallow copy: popping "type" straight out of the caller's
    # dict would make the same configuration unusable for a second call.
    db_opts = dict(config)
    dbtype = db_opts.pop("type")

    log.debug("Creating %s database client with args: %s", dbtype, db_opts)
    try:
        Database(of_type=dbtype, **db_opts)
    except ValueError:
        # Tolerate the error only when the existing instance is already of
        # the requested type.
        if Database().__class__.__name__.lower() != dbtype.lower():
            raise
class Legacy(BaseStorageProtocol):
    """Legacy protocol, store all experiments and trials inside the Database()

    Parameters
    ----------
    database: dict, optional
        Database configuration. When given, it is passed to `setup_database`
        before the `Database()` instance is retrieved.
        See `~orion.io.database.Database` for more details
    setup: bool
        Setup the database (create indexes)

    """

    def __init__(self, database=None, setup=True):
        if database is not None:
            setup_database(database)

        self._db = Database()

        if setup:
            self._setup_db()

    @staticmethod
    def _copy_query(where):
        """Return a fresh dict built from `where` (empty when `where` is None).

        Methods below add their own keys to the query; doing so on a copy
        keeps the caller's dictionary unchanged.
        """
        return {} if where is None else dict(where)

    def _setup_db(self):
        """Database index setup

        Raises
        ------
        OutdatedDatabaseError
            If `backward.db_is_outdated` reports the database as outdated.

        """
        if backward.db_is_outdated(self._db):
            raise OutdatedDatabaseError(
                "The database is outdated. You can upgrade it with the "
                "command `orion db upgrade`."
            )

        # NOTE(review): the collection is spelled "experiment" here while
        # every other call uses "experiments", and the result is discarded.
        # Kept as-is; confirm the intent.
        self._db.index_information("experiment")
        self._db.ensure_index(
            "experiments",
            [("name", Database.ASCENDING), ("version", Database.ASCENDING)],
            unique=True,
        )
        self._db.ensure_index("experiments", "metadata.datetime")

        self._db.ensure_index("benchmarks", "name", unique=True)

        self._db.ensure_index("trials", "experiment")
        self._db.ensure_index("trials", "status")
        self._db.ensure_index("trials", "results")
        self._db.ensure_index("trials", "start_time")
        self._db.ensure_index("trials", [("end_time", Database.DESCENDING)])

    def create_benchmark(self, config):
        """Insert a new benchmark inside the database"""
        return self._db.write("benchmarks", data=config, query=None)

    def fetch_benchmark(self, query, selection=None):
        """Fetch all benchmarks that match the query"""
        return self._db.read("benchmarks", query, selection)

    def create_experiment(self, config):
        """See :func:`orion.storage.base.BaseStorageProtocol.create_experiment`"""
        return self._db.write("experiments", data=config, query=None)

    def delete_experiment(self, experiment=None, uid=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
        uid = get_uid(experiment, uid)

        return self._db.remove("experiments", query={"_id": uid})

    def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
        uid = get_uid(experiment, uid)

        where = self._copy_query(where)

        if uid is not None:
            where["_id"] = uid

        return self._db.write("experiments", data=kwargs, query=where)

    def fetch_experiments(self, query, selection=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`"""
        return self._db.read("experiments", query, selection)

    def fetch_trials(self, experiment=None, uid=None, where=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`"""
        uid = get_uid(experiment, uid)

        where = self._copy_query(where)
        where["experiment"] = uid

        return self._fetch_trials(where)

    def _fetch_trials(self, query, selection=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`

        Trials are returned sorted by ``submit_time``; trials without a
        submit time come first.
        """

        def sort_key(item):
            submit_time = item.submit_time
            # The leading flag decides the order between "no submit time" and
            # "has submit time", so None is never compared with a datetime
            # (an int placeholder such as 0 cannot be ordered against one).
            return (submit_time is not None, submit_time)

        trials = Trial.build(self._db.read("trials", query=query, selection=selection))
        trials.sort(key=sort_key)

        return trials

    def register_trial(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`"""
        self._db.write("trials", trial.to_dict())
        return trial

    def delete_trials(self, experiment=None, uid=None, where=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`"""
        uid = get_uid(experiment, uid)

        where = self._copy_query(where)

        if uid is not None:
            where["experiment"] = uid

        return self._db.remove("trials", query=where)

    def register_lie(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.register_lie`"""
        return self._db.write("lying_trials", trial.to_dict())

    def retrieve_result(self, trial, results_file=None, **kwargs):
        """Parse the results file that was generated by the trial process.

        Parameters
        ----------
        trial: Trial
            The trial object to be updated

        results_file: file-like object, optional
            Object whose ``name`` attribute is the path handed to
            `JSONConverter().parse`. When None, `trial` is returned unchanged.

        Returns
        -------
        returns the updated trial object

        Raises
        ------
        MissingResultFile
            If parsing the results file fails with a JSON decode error.

        Notes
        -----
        This does not update the database!

        """
        if results_file is None:
            return trial

        try:
            results = JSONConverter().parse(results_file.name)
        except json.decoder.JSONDecodeError as exc:
            raise MissingResultFile() from exc

        trial.results = [
            Trial.Result(name=res["name"], type=res["type"], value=res["value"])
            for res in results
        ]

        return trial

    def get_trial(self, trial=None, uid=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`

        Returns None when no trial document matches the id.
        """
        if trial is not None and uid is not None:
            assert trial._id == uid

        if uid is None:
            if trial is None:
                raise MissingArguments("Either `trial` or `uid` should be set")

            uid = trial.id

        result = self._db.read("trials", {"_id": uid})
        if not result:
            return None

        return Trial(**result[0])

    def update_trials(self, experiment=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`"""
        uid = get_uid(experiment, uid)

        where = self._copy_query(where)
        where["experiment"] = uid

        return self._db.write("trials", data=kwargs, query=where)

    def update_trial(self, trial=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`"""
        uid = get_uid(trial, uid)

        where = self._copy_query(where)
        where["_id"] = uid

        return self._db.write("trials", data=kwargs, query=where)

    def fetch_lost_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`

        Selects reserved trials whose heartbeat is at least five configured
        heartbeat periods old.
        """
        heartbeat = orion.core.config.worker.heartbeat
        threshold = datetime.datetime.utcnow() - datetime.timedelta(
            seconds=heartbeat * 5
        )
        lte_comparison = {"$lte": threshold}
        query = {
            "experiment": experiment._id,
            "status": "reserved",
            "heartbeat": lte_comparison,
        }

        return self._fetch_trials(query)

    def push_trial_results(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`

        Raises
        ------
        FailedUpdate
            If the update result is falsy. The query only matches the trial
            while its status is "reserved".

        """
        rc = self.update_trial(
            trial, **trial.to_dict(), where={"_id": trial.id, "status": "reserved"}
        )
        if not rc:
            raise FailedUpdate()

        return rc

    def set_trial_status(self, trial, status, heartbeat=None, was=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`

        Raises
        ------
        FailedUpdate
            If the update result is falsy. The query only matches the trial
            while its stored status equals `was` (default: ``trial.status``).

        """
        heartbeat = heartbeat or datetime.datetime.utcnow()
        was = was or trial.status

        validate_status(status)
        validate_status(was)

        update = dict(status=status, heartbeat=heartbeat, experiment=trial.experiment)

        rc = self.update_trial(trial, **update, where={"status": was, "_id": trial.id})

        if not rc:
            raise FailedUpdate()

        # Only mirror the new status on the object once the write went through.
        trial.status = status

    def fetch_pending_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`"""
        query = dict(
            experiment=experiment._id,
            status={"$in": ["new", "suspended", "interrupted"]},
        )

        return self._fetch_trials(query)

    def reserve_trial(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`

        Returns None when `read_and_write` finds no matching trial.
        """
        query = dict(
            experiment=experiment._id,
            status={"$in": ["interrupted", "new", "suspended"]},
        )
        # read and write works on a single document
        now = datetime.datetime.utcnow()
        trial = self._db.read_and_write(
            "trials",
            query=query,
            data=dict(status="reserved", start_time=now, heartbeat=now),
        )

        if trial is None:
            return None

        return Trial(**trial)

    def fetch_noncompleted_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`"""
        query = dict(experiment=experiment._id, status={"$ne": "completed"})

        return self._fetch_trials(query)

    def count_completed_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`"""
        query = dict(experiment=experiment._id, status="completed")

        return self._db.count("trials", query)

    def count_broken_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`"""
        query = dict(experiment=experiment._id, status="broken")

        return self._db.count("trials", query)

    def update_heartbeat(self, trial):
        """Update trial's heartbeat

        The query only matches the trial while its status is "reserved".
        """
        return self.update_trial(
            trial, heartbeat=datetime.datetime.utcnow(), where={"status": "reserved"}
        )

    def fetch_trials_by_status(self, experiment, status):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`"""
        query = dict(experiment=experiment._id, status=status)

        return self._fetch_trials(query)