Source code for orion.core.cli.db.upgrade

#!/usr/bin/env python
"""
Module to upgrade DB schemes
============================

Upgrade the scheme of the databases

"""
import argparse
import logging
import sys

from orion.core.io import experiment_builder
from orion.core.io.config import ConfigurationError
from orion.core.io.database.ephemeraldb import EphemeralCollection
from orion.core.io.database.mongodb import MongoDB
from orion.core.io.database.pickleddb import PickledDB
from orion.storage.legacy import Legacy

log = logging.getLogger(__name__)
SHORT_DESCRIPTION = "Upgrade the database scheme"


# TODO: Move somewhere else to share with `db setup`.
[docs]def ask_question(question, default=None): """Ask a question to the user and receive an answer. Parameters ---------- question: str The question to be asked. default: str The default value to use if the user enters nothing. Returns ------- str The answer provided by the user. """ if default is not None: question = f"{question} (default: {default}) " answer = input(question) if answer.strip() == "": return default return answer
[docs]def add_subparser(parser): """Add the subparser that needs to be used for this command""" upgrade_db_parser = parser.add_parser( "upgrade", help=SHORT_DESCRIPTION, description=SHORT_DESCRIPTION ) upgrade_db_parser.add_argument( "-c", "--config", type=argparse.FileType("r"), metavar="path-to-config", help="user provided " "orion configuration file", ) upgrade_db_parser.add_argument( "-f", "--force", action="store_true", help="Don't prompt user" ) upgrade_db_parser.set_defaults(func=main) return upgrade_db_parser
[docs]def main(args): """Upgrade the databases for current version""" print( "Upgrading your database may damage your data. Make sure to make a backup before the " "upgrade and stop any other process that may read/write the database during the upgrade." ) if not args.get("force"): action = "" while action not in ["y", "yes", "no", "n"]: action = ask_question("Do you wish to proceed? (y/N)", "N").lower() if action in ["no", "n"]: sys.exit(0) config = experiment_builder.get_cmd_config(args) storage_config = config.get("storage") if storage_config is None: storage_config = {"type": "legacy"} storage_config["setup"] = False builder = experiment_builder.ExperimentBuilder(storage_config) upgrade_db_specifics(builder.storage) print("Updating documents...") upgrade_documents(builder.storage) print("Database upgrade completed successfully")
[docs]def upgrade_db_specifics(storage): """Make upgrades that are specific to some backends""" if isinstance(storage, Legacy): database = storage._db # pylint: disable=protected-access print("Updating indexes...") update_indexes(database) if isinstance(database, PickledDB): print("Updating pickledb scheme...") upgrade_pickledb(database) elif isinstance(database, MongoDB): print("Updating mongodb scheme...") upgrade_mongodb(database)
[docs]def upgrade_documents(storage): """Upgrade scheme of the documents""" for experiment in storage.fetch_experiments({}): add_version(experiment) uid = experiment.pop("_id") algorithm = None if "algorithms" in experiment: algorithm = experiment.pop("algorithms") if "algorithm" in experiment: algorithm = experiment.pop("algorithm") if algorithm is None: raise ConfigurationError( "The data was corrupted, as there was no algorithm in the experiment" ) storage.update_experiment(uid=experiment, **experiment) storage.initialize_algorithm_lock(uid, algorithm) for trial in storage.fetch_trials(uid=uid): # trial_config = trial.to_dict() assert trial.id_override is not None # This will overwrite the trial with correct dict structure # (ex: have _id set to id_override and id set to trial.hash_name == trial.id. # storage.update_trial(trial, uid=trial.id_override) # pylint: disable=protected-access storage._db.write( "trials", data=trial.to_dict(), query={"_id": trial.id_override} )
[docs]def add_version(experiment): """Add version 1 if not present""" experiment.setdefault("version", 1)
[docs]def update_indexes(database): """Remove user from unique indices. This is required for migration to v0.1.6+ """ # For backward compatibility index_info = database.index_information("experiments") deprecated_indices = [ ("name", "metadata.user"), ("name", "metadata.user", "version"), "name_1_metadata.user_1", "name_1_metadata.user_1_version_1", ] for deprecated_idx in deprecated_indices: if deprecated_idx in index_info: database.drop_index("experiments", deprecated_idx)
# pylint: disable=unused-argument
[docs]def upgrade_mongodb(database): """Update mongo specific db scheme."""
[docs]def upgrade_pickledb(database): """Update pickledb specific db scheme.""" # pylint: disable=protected-access def upgrade_state(self, state): """Set state while ensuring backward compatibility""" self._documents = state["_documents"] # if indexes are from <=v0.1.6 if state["_indexes"] and isinstance( next(iter(state["_indexes"].keys())), tuple ): self._indexes = {} for keys, values in state["_indexes"].items(): if isinstance(keys, str): self._indexes[keys] = values # Convert keys that were registered with old index signature else: keys = [(key, None) for key in keys] self.create_index(keys, unique=True) else: self._indexes = state["_indexes"] old_setstate = getattr(EphemeralCollection, "__setstate__", None) EphemeralCollection.__setstate__ = upgrade_state document = database.read("experiments", {})[0] # One document update is enough to fix all collections database.write("experiments", document, query={"_id": document["_id"]}) if old_setstate is not None: EphemeralCollection.__setstate__ = old_setstate else: del EphemeralCollection.__setstate__