Source code for orion.core.cli.db.set

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Module running the set command
==============================

Update data of experiments and trials in the storage

"""
import argparse
import logging
import sys

import orion.core.io.experiment_builder as experiment_builder
from orion.core.utils.pptree import print_tree
from orion.core.utils.terminal import confirm_name
from orion.storage.base import get_storage

logger = logging.getLogger(__name__)


# Long description of the `set` sub-command. It is passed as `description=`
# to the sub-parser created in `add_subparser`, which also selects
# `argparse.RawTextHelpFormatter`; the line breaks and example commands below
# are laid out for that formatter.
DESCRIPTION = """
Command to update trial attributes.

To change a trial status, simply give the experiment name,
trial id and status. (use `orion status --all` to get trial ids)
$ orion db set my-exp-name id=3cc91e851e13281ca2152c19d888e937 status=interrupted

To change all trials from a given status to another, simply give the two status
$ orion db set my-exp-name status=broken status=interrupted

Or `*` to apply the change to all trials
$ orion db set my-exp-name '*' status=interrupted

By default, trials of the last version of the experiment are selected.
Add --version to select a prior version. Note that the modification
is applied recursively to all child experiment, but not to the parents.
$ orion db set my-exp-name --version 1 status=broken status=interrupted
"""


# Confirmation prompt template. `main` fills the `{query}` and `{update}`
# placeholders with the raw command-line strings and hands the result to
# `confirm_name`. The text deliberately ends without a newline, right after
# "type again the name of the experiment: ".
CONFIRM_MESSAGE = """
Trials matching the query `{query}` for all experiments listed above
will be modified with `{update}`.
To select a specific version use --version <VERSION>.

Make sure to stop any worker currently executing one of these experiment.

To proceed, type again the name of the experiment: """


def add_subparser(parser):
    """Register the ``set`` sub-command and return its parser.

    Parameters
    ----------
    parser
        Object exposing ``add_parser`` (as returned by
        ``ArgumentParser.add_subparsers()``) on which the ``set`` command
        is registered.

    Returns
    -------
    The newly created parser for the ``set`` command. Its ``func`` default
    is set to :func:`main`.

    """
    set_parser = parser.add_parser(
        "set",
        description=DESCRIPTION,
        help="Update trials' attributes",
        # Keep the line breaks and example commands of DESCRIPTION as written.
        formatter_class=argparse.RawTextHelpFormatter,
    )

    set_parser.set_defaults(func=main)

    # This command updates trials; it never deletes the experiment, so the
    # help text must not claim otherwise.
    set_parser.add_argument(
        "name", help="Name of the experiment whose trials will be updated."
    )

    set_parser.add_argument(
        "-c",
        "--config",
        type=argparse.FileType("r"),
        metavar="path-to-config",
        help="user provided orion configuration file",
    )

    set_parser.add_argument(
        "-v",
        "--version",
        type=int,
        default=None,
        help="specific version of experiment to fetch; "
        "(default: last version matching.)",
    )

    set_parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="Force modify without asking to enter experiment name twice.",
    )

    # Only the last fragment of each help text interpolates a value, so the
    # other fragments are plain strings.
    set_parser.add_argument(
        "query",
        help=(
            "Query for trials to update. Can be `*` to update all trials. "
            "Otherwise format must be <attribute>=<value>. Ex: status=broken. "
            f"Supported attributes are {VALID_QUERY_ATTRS}"
        ),
    )

    set_parser.add_argument(
        "update",
        help=(
            "Update for trials. Format must be <attribute>=<value>. "
            "Ex: status=interrupted. "
            f"Supported attributes are {VALID_UPDATE_ATTRS}"
        ),
    )

    return set_parser
def process_updates(storage, root, query, update):
    """Apply ``update`` to the trials matching ``query`` in every node of ``root``.

    Each element yielded by iterating ``root`` is expected to expose the
    experiment as ``.item``. For each of them ``storage.update_trials`` is
    called with ``where=query`` and the entries of ``update`` as keyword
    arguments. The value it returns is logged at debug level and added to a
    running total, which is printed once all nodes have been processed.
    """
    modified_so_far = 0

    for experiment_node in root:
        n_modified = storage.update_trials(
            experiment_node.item, where=query, **update
        )
        logger.debug(
            "%d trials modified in experiment %s-v%d",
            n_modified,
            experiment_node.item.name,
            experiment_node.item.version,
        )
        modified_so_far = modified_so_far + n_modified

    print(f"{modified_so_far} trials modified")
# Trial attributes accepted on the left-hand side of the `query` argument
# (checked in `build_query`, listed in the help text of `add_subparser`).
VALID_QUERY_ATTRS = ["status", "id"]
# Trial attributes accepted on the left-hand side of the `update` argument
# (checked in `build_update`, listed in the help text of `add_subparser`).
VALID_UPDATE_ATTRS = ["status"]
def build_query(query):
    """Convert query string to dict format

    String format must be ``<attr name>=<value>``, or ``*`` (surrounding
    whitespace allowed) to select everything.

    Parameters
    ----------
    query: str
        Query as typed on the command line.

    Returns
    -------
    dict
        ``{}`` for ``*``; otherwise a single-entry dict mapping the attribute
        to the value. The attribute ``id`` is renamed to ``_id``.

    Raises
    ------
    ValueError
        If the string is not of the form ``<attribute>=<value>`` (exactly one
        ``=``), or if the attribute is not in ``VALID_QUERY_ATTRS``.

    """
    if query.strip() == "*":
        return {}

    # Check the shape explicitly. Unpacking the split directly would raise an
    # opaque "not enough/too many values to unpack" message for malformed
    # input, and the caller prints that message to the user.
    parts = query.split("=")
    if len(parts) != 2:
        raise ValueError(
            f"Invalid query `{query}`. Format must be <attribute>=<value> or `*`"
        )
    attribute, value = parts

    if attribute not in VALID_QUERY_ATTRS:
        raise ValueError(
            f"Invalid query attribute `{attribute}`. Must be one of {VALID_QUERY_ATTRS}"
        )

    # The user-facing `id` is looked up under `_id` (presumably the key used
    # by the storage for trial ids -- confirm against the storage backend).
    if attribute == "id":
        attribute = "_id"

    return {attribute: value}
def build_update(update):
    """Convert update string to dict format

    String format must be ``<attr name>=<value>``.

    Parameters
    ----------
    update: str
        Update as typed on the command line.

    Returns
    -------
    dict
        Single-entry dict mapping the attribute to the new value.

    Raises
    ------
    ValueError
        If the string is not of the form ``<attribute>=<value>`` (exactly one
        ``=``), or if the attribute is not in ``VALID_UPDATE_ATTRS``.

    """
    # Check the shape explicitly. Unpacking the split directly would raise an
    # opaque "not enough/too many values to unpack" message for malformed
    # input, and the caller prints that message to the user.
    parts = update.split("=")
    if len(parts) != 2:
        raise ValueError(
            f"Invalid update `{update}`. Format must be <attribute>=<value>"
        )
    attribute, value = parts

    if attribute not in VALID_UPDATE_ATTRS:
        raise ValueError(
            f"Invalid update attribute `{attribute}`. Must be one of {VALID_UPDATE_ATTRS}"
        )

    return {attribute: value}
def main(args):
    """Update the trials matching a query in an experiment and its children.

    ``args`` is the mapping of parsed command-line arguments; the keys
    ``name``, ``query``, ``update`` and ``force`` are read directly and
    ``version`` is optional.

    Returns 0 once the updates have been applied, and 1 when the query or
    update string is rejected or when the confirmation fails.
    """
    cmd_config = experiment_builder.get_cmd_config(args)
    experiment_builder.setup_storage(cmd_config.get("storage"))

    # Locate the root experiment; the update is applied from this node down.
    experiment = experiment_builder.load(
        name=args["name"], version=args.get("version", None)
    )
    root = experiment.node

    try:
        query, update = build_query(args["query"]), build_update(args["update"])
    except ValueError as error:
        print(f"Error: {error}", file=sys.stderr)
        return 1

    # Show every experiment that is about to be touched before asking.
    print_tree(root, nameattr="tree_name")

    prompt = CONFIRM_MESSAGE.format(query=args["query"], update=args["update"])
    if not confirm_name(prompt, args["name"], args["force"]):
        print("Confirmation failed, aborting operation.")
        return 1

    process_updates(get_storage(), root, query, update)
    return 0