# -*- coding: utf-8 -*-
"""
Legacy storage
==============
Old Storage implementation.
"""
import datetime
import json
import logging
import orion.core
import orion.core.utils.backward as backward
from orion.core.io.convert import JSONConverter
from orion.core.io.database import Database, OutdatedDatabaseError
from orion.core.utils.exceptions import MissingResultFile
from orion.core.worker.trial import Trial, validate_status
from orion.storage.base import (
BaseStorageProtocol,
FailedUpdate,
MissingArguments,
get_uid,
)
log = logging.getLogger(__name__)
def get_database():
    """Give access to the current database instance.

    Thin wrapper over the ``Database`` singleton whose purpose is to offer
    a clearer failure when the singleton is requested before having been
    initialized.

    Returns
    -------
    The object obtained by calling ``Database()`` without arguments.

    Raises
    ------
    RuntimeError
        If the underlying database was not initialized prior to calling this function

    Notes
    -----
    To initialize the underlying database you must first call `Database(...)`
    with the appropriate arguments for the chosen backend

    """
    database = Database()
    return database
def setup_database(config=None):
    """Create the Database instance from a configuration.

    Parameters
    ----------
    config: dict
        Configuration for the database backend. It must contain a ``type``
        entry naming the backend; every other entry is forwarded as a keyword
        argument to ``Database``. If not defined, global configuration
        is used. The dictionary is left untouched.

    Raises
    ------
    KeyError
        If the configuration has no ``type`` entry.
    ValueError
        Re-raised from ``Database(...)`` when the class name of the instance
        returned by ``Database()`` does not match the requested type.

    """
    if config is None:
        # TODO: How could we support orion.core.config.storage.database as well?
        config = orion.core.config.database.to_dict()

    # Shallow copy: popping ``type`` below must not strip it from the
    # caller's configuration, which may be reused afterwards.
    db_opts = dict(config)
    dbtype = db_opts.pop("type")

    log.debug("Creating %s database client with args: %s", dbtype, db_opts)
    try:
        Database(of_type=dbtype, **db_opts)
    except ValueError:
        # Presumably raised when a database instance already exists (TODO
        # confirm); tolerated only if that instance is of the requested type.
        if Database().__class__.__name__.lower() != dbtype.lower():
            raise
class Legacy(BaseStorageProtocol):
    """Legacy protocol, store all experiments and trials inside the Database()

    Parameters
    ----------
    database: dict
        Configuration of the database backend, handed to `setup_database`.
        When left to None, no setup is attempted and ``Database()`` is
        called directly.
        See `~orion.io.database.Database` for more details
    setup: bool
        Setup the database (create indexes)

    """

    def __init__(self, database=None, setup=True):
        if database is not None:
            # Hand over a copy so the caller's configuration is never altered
            # by the database setup.
            setup_database(dict(database))

        self._db = Database()

        if setup:
            self._setup_db()

    def _setup_db(self):
        """Database index setup

        Raises
        ------
        OutdatedDatabaseError
            If `backward.db_is_outdated` reports the database as outdated.

        """
        if backward.db_is_outdated(self._db):
            raise OutdatedDatabaseError(
                "The database is outdated. You can upgrade it with the "
                "command `orion db upgrade`."
            )

        # NOTE(review): the collection name here is "experiment" while every
        # other call uses "experiments", and the returned value is discarded.
        # Kept as-is; confirm whether this call is still needed.
        self._db.index_information("experiment")
        self._db.ensure_index(
            "experiments",
            [("name", Database.ASCENDING), ("version", Database.ASCENDING)],
            unique=True,
        )
        self._db.ensure_index("experiments", "metadata.datetime")

        self._db.ensure_index("benchmarks", "name", unique=True)

        self._db.ensure_index("trials", "experiment")
        self._db.ensure_index("trials", "status")
        self._db.ensure_index("trials", "results")
        self._db.ensure_index("trials", "start_time")
        self._db.ensure_index("trials", [("end_time", Database.DESCENDING)])

    def create_benchmark(self, config):
        """Insert a new benchmark inside the database"""
        return self._db.write("benchmarks", data=config, query=None)

    def fetch_benchmark(self, query, selection=None):
        """Fetch all benchmarks that match the query"""
        return self._db.read("benchmarks", query, selection)

    def create_experiment(self, config):
        """See :func:`orion.storage.base.BaseStorageProtocol.create_experiment`"""
        return self._db.write("experiments", data=config, query=None)

    def delete_experiment(self, experiment=None, uid=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
        uid = get_uid(experiment, uid)

        return self._db.remove("experiments", query={"_id": uid})

    def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`

        The ``where`` mapping given by the caller is not modified.
        """
        uid = get_uid(experiment, uid)

        # Copy so that adding the id filter does not leak into the caller's dict
        where = dict(where) if where is not None else {}

        if uid is not None:
            where["_id"] = uid

        return self._db.write("experiments", data=kwargs, query=where)

    def fetch_experiments(self, query, selection=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`"""
        return self._db.read("experiments", query, selection)

    def fetch_trials(self, experiment=None, uid=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`"""
        uid = get_uid(experiment, uid)

        return self._fetch_trials(dict(experiment=uid))

    def _fetch_trials(self, query, selection=None):
        """Read the trials matching `query` and return them sorted by submit time.

        Trials whose ``submit_time`` is None are placed before all others.
        See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`
        """

        def sort_key(item):
            submit_time = item.submit_time
            # The leading flag orders trials lacking a submit time first and
            # guarantees the placeholder 0 is never compared against an actual
            # submit time (presumably a datetime, which cannot be ordered
            # against an int).
            if submit_time is None:
                return (False, 0)
            return (True, submit_time)

        trials = Trial.build(self._db.read("trials", query=query, selection=selection))
        trials.sort(key=sort_key)

        return trials

    def register_trial(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`"""
        self._db.write("trials", trial.to_dict())
        return trial

    def delete_trials(self, experiment=None, uid=None, where=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`

        The ``where`` mapping given by the caller is not modified.
        """
        uid = get_uid(experiment, uid)

        # Copy so that adding the experiment filter does not leak into the caller's dict
        where = dict(where) if where is not None else {}

        if uid is not None:
            where["experiment"] = uid

        return self._db.remove("trials", query=where)

    def register_lie(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.register_lie`"""
        return self._db.write("lying_trials", trial.to_dict())

    def retrieve_result(self, trial, results_file=None, **kwargs):
        """Parse the results file that was generated by the trial process.

        Parameters
        ----------
        trial: Trial
            The trial object to be updated

        results_file: file-like object
            Object whose ``name`` attribute is the path of the file to read
            the result from. When None, the trial is returned unchanged.

        Returns
        -------
        returns the updated trial object

        Raises
        ------
        MissingResultFile
            If the content of the file cannot be decoded as JSON.

        Notes
        -----
        This does not update the database!

        """
        if results_file is None:
            return trial

        try:
            results = JSONConverter().parse(results_file.name)
        except json.decoder.JSONDecodeError as exc:
            # Keep the decoding error attached as the cause for easier debugging
            raise MissingResultFile() from exc

        trial.results = [
            Trial.Result(name=res["name"], type=res["type"], value=res["value"])
            for res in results
        ]

        return trial

    def get_trial(self, trial=None, uid=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`

        Returns None when no trial with the requested id is found.

        Raises
        ------
        MissingArguments
            If neither `trial` nor `uid` is given.

        """
        if trial is not None and uid is not None:
            assert trial._id == uid

        if uid is None:
            if trial is None:
                raise MissingArguments("Either `trial` or `uid` should be set")

            uid = trial.id

        result = self._db.read("trials", {"_id": uid})
        if not result:
            return None

        return Trial(**result[0])

    def update_trials(self, experiment=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`

        The ``where`` mapping given by the caller is not modified.
        """
        uid = get_uid(experiment, uid)

        # Copy so that adding the experiment filter does not leak into the caller's dict
        where = dict(where) if where is not None else {}

        where["experiment"] = uid
        return self._db.write("trials", data=kwargs, query=where)

    def update_trial(self, trial=None, uid=None, where=None, **kwargs):
        """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`

        The ``where`` mapping given by the caller is not modified.
        """
        uid = get_uid(trial, uid)

        # Copy so that adding the id filter does not leak into the caller's dict
        where = dict(where) if where is not None else {}

        where["_id"] = uid
        return self._db.write("trials", data=kwargs, query=where)

    def fetch_lost_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`

        A reserved trial is selected when its heartbeat is older than five
        times the configured worker heartbeat (in seconds).
        """
        heartbeat = orion.core.config.worker.heartbeat
        threshold = datetime.datetime.utcnow() - datetime.timedelta(
            seconds=heartbeat * 5
        )
        lte_comparison = {"$lte": threshold}
        query = {
            "experiment": experiment._id,
            "status": "reserved",
            "heartbeat": lte_comparison,
        }

        return self._fetch_trials(query)

    def push_trial_results(self, trial):
        """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`

        Raises
        ------
        FailedUpdate
            If the write, restricted to a trial still in status ``reserved``,
            returns a falsy value.

        """
        rc = self.update_trial(
            trial, **trial.to_dict(), where={"_id": trial.id, "status": "reserved"}
        )

        if not rc:
            raise FailedUpdate()

        return rc

    def set_trial_status(self, trial, status, heartbeat=None):
        """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`

        The write is conditioned on the stored status being equal to
        ``trial.status``; on success ``trial.status`` is set to `status`.

        Raises
        ------
        FailedUpdate
            If the conditional write returns a falsy value.

        """
        validate_status(status)

        if heartbeat is None:
            heartbeat = datetime.datetime.utcnow()

        update = dict(status=status, heartbeat=heartbeat, experiment=trial.experiment)

        rc = self.update_trial(
            trial, **update, where={"status": trial.status, "_id": trial.id}
        )

        if not rc:
            raise FailedUpdate()

        trial.status = status

    def fetch_pending_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`"""
        query = dict(
            experiment=experiment._id,
            status={"$in": ["new", "suspended", "interrupted"]},
        )
        return self._fetch_trials(query)

    def reserve_trial(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`

        Returns None when no trial could be reserved.
        """
        query = dict(
            experiment=experiment._id,
            status={"$in": ["interrupted", "new", "suspended"]},
        )
        # read and write works on a single document
        now = datetime.datetime.utcnow()
        trial = self._db.read_and_write(
            "trials",
            query=query,
            data=dict(status="reserved", start_time=now, heartbeat=now),
        )

        if trial is None:
            return None

        return Trial(**trial)

    def fetch_noncompleted_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`"""
        query = dict(experiment=experiment._id, status={"$ne": "completed"})
        return self._fetch_trials(query)

    def count_completed_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`"""
        query = dict(experiment=experiment._id, status="completed")
        return self._db.count("trials", query)

    def count_broken_trials(self, experiment):
        """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`"""
        query = dict(experiment=experiment._id, status="broken")
        return self._db.count("trials", query)

    def update_heartbeat(self, trial):
        """Update trial's heartbeat

        The heartbeat is only refreshed for a trial whose stored status is
        ``reserved``; the status itself is never written.

        Returns
        -------
        The value returned by the underlying database write.

        """
        # ``status`` must be a filter, not part of the written data: otherwise
        # a heartbeat would put a trial that left the reserved state back
        # into it.
        return self.update_trial(
            trial,
            heartbeat=datetime.datetime.utcnow(),
            where={"status": "reserved"},
        )

    def fetch_trials_by_status(self, experiment, status):
        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`"""
        query = dict(experiment=experiment._id, status=status)
        return self._fetch_trials(query)