{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Checkpointing trials\n\n.. hint::\n\n    In short, you should use \"{experiment.working_dir}/{trial.hash_params}\" to set the path of\n    the checkpoint directory.\n\nWhen using multi-fidelity algorithms such as Hyperband, it is preferable to checkpoint the trials\nto avoid restarting training from scratch when a trial is resumed. In this tutorial, for instance,\nHyperband will train VGG11 for 1 epoch, pick the best candidates and train them up to 7 epochs,\ndo the same again up to 30 epochs, and finally up to 120 epochs. We want each promoted trial to\nresume training from its last saved epoch instead of starting from scratch.\n\nOr\u00edon provides a unique hash for trials that can be used to define a unique checkpoint\npath: ``trial.hash_params``. This can be used with the Python API, as demonstrated in this example,\nor with `commandline_templates`.\n\n## With command line\n\nThe example below relies solely on the Python API. It is also possible to do checkpointing\nwith the command line API. To this end, your script should accept an argument for the checkpoint\npath. Supposing this argument is ``--checkpoint``, you should call your script with the\nfollowing template.\n\n::\n\n    orion hunt -n <exp name>\n        ./your_script.sh --checkpoint '{experiment.working_dir}/{trial.hash_params}'\n\nYour script is responsible for taking this checkpoint path, resuming from the checkpoint if one\nexists, and saving new checkpoints as training progresses.\nWe will demonstrate below how this can be done with PyTorch, using Or\u00edon's Python API.\n\n## Training code\n\nWe will first go through the training code piece by piece before tackling the hyperparameter\noptimization.\n\nFirst things first, the imports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import SubsetRandomSampler\n\nimport torchvision\nimport torchvision.models as models\nimport torchvision.transforms as transforms\n\nimport os\nimport argparse"
      ]
    },
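    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``argparse`` import above is only needed for a command-line version of this script, such as the one launched through the ``orion hunt`` template in the command-line section. The cell below is a minimal sketch of such an interface; the ``parse_args`` helper and the ``--epochs`` and ``--learning-rate`` flags are hypothetical and chosen for illustration, and only ``--checkpoint`` corresponds to the template shown earlier. Since Or\u00edon substitutes ``{experiment.working_dir}/{trial.hash_params}`` before launching the script, the script only receives a plain directory path. The parsed values would then be forwarded to the ``main()`` function defined later in this tutorial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\nimport os\n\n\ndef parse_args(argv=None):\n    # Hypothetical command-line interface; the flag names are illustrative.\n    parser = argparse.ArgumentParser(description=\"Resumable VGG11 training on CIFAR10\")\n    parser.add_argument(\"--epochs\", type=int, default=120)\n    parser.add_argument(\"--learning-rate\", type=float, default=0.1)\n    parser.add_argument(\n        \"--checkpoint\",\n        type=str,\n        default=None,\n        help=\"Directory in which checkpoint.pth is saved and from which training resumes\",\n    )\n    return parser.parse_args(argv)\n\n\n# Example invocation with an already-substituted (made-up) checkpoint directory.\nargs = parse_args([\"--epochs\", \"7\", \"--checkpoint\", \"/tmp/orion-example/0123abcd\"])\n\n# No checkpoint.pth exists yet in that directory, so training would start from epoch 1.\nhas_checkpoint = args.checkpoint is not None and os.path.isfile(\n    os.path.join(args.checkpoint, \"checkpoint.pth\")\n)\nprint(args.epochs, args.learning_rate, args.checkpoint, has_checkpoint)"
      ]
    },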
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We use PyTorch's ``SubsetRandomSampler`` to split the CIFAR10 training set into training and\nvalidation sets. We include the test set here for completeness but won't use it in this example,\nas we only need the training and validation data for the hyperparameter optimization.\nWe use torchvision's transforms to apply the standard transformations to CIFAR10\nimages, that is, random cropping and random horizontal flipping (for training only),\nfollowed by normalization.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def build_data_loaders(batch_size, split_seed=1):\n    normalize = [\n        transforms.ToTensor(),\n        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n    ]\n\n    augment = [\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n    ]\n\n    train_set = torchvision.datasets.CIFAR10(\n        root=\"./data\",\n        train=True,\n        download=True,\n        transform=transforms.Compose(augment + normalize),\n    )\n    # Same images as train_set, but without data augmentation, for validation.\n    valid_set = torchvision.datasets.CIFAR10(\n        root=\"./data\",\n        train=True,\n        download=True,\n        transform=transforms.Compose(normalize),\n    )\n    test_set = torchvision.datasets.CIFAR10(\n        root=\"./data\",\n        train=False,\n        download=True,\n        transform=transforms.Compose(normalize),\n    )\n\n    # Shuffle all 50000 training indices: 45000 for training, the remaining 5000 for validation.\n    num_train = 45000\n    indices = numpy.arange(len(train_set))\n    numpy.random.RandomState(split_seed).shuffle(indices)\n\n    train_idx, valid_idx = indices[:num_train], indices[num_train:]\n    train_sampler = SubsetRandomSampler(train_idx)\n    valid_sampler = SubsetRandomSampler(valid_idx)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_set, batch_size=batch_size, sampler=train_sampler, num_workers=5\n    )\n    valid_loader = torch.utils.data.DataLoader(\n        valid_set, batch_size=1000, sampler=valid_sampler, num_workers=5\n    )\n    test_loader = torch.utils.data.DataLoader(\n        test_set, batch_size=1000, shuffle=False, num_workers=5\n    )\n\n    return train_loader, valid_loader, test_loader"
      ]
    },
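    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The index split is the part of ``build_data_loaders`` that is easiest to get wrong, so the cell below checks it in isolation, without downloading CIFAR10. It is a sketch that assumes the 50,000 images of the CIFAR10 training set, shuffles all of their indices with ``split_seed``, and keeps the first 45,000 for training and the remaining 5,000 for validation. It verifies that the two subsets are disjoint and that the same seed reproduces the same split for every trial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy\n\nnum_examples = 50000  # size of the CIFAR10 training set\nnum_train = 45000  # the remaining 5000 images are held out for validation\nsplit_seed = 1\n\n# Shuffle all training indices, then slice them into two disjoint subsets.\nindices = numpy.arange(num_examples)\nnumpy.random.RandomState(split_seed).shuffle(indices)\ntrain_idx, valid_idx = indices[:num_train], indices[num_train:]\n\n# The same seed gives the same permutation, so every trial uses the same split.\nother = numpy.arange(num_examples)\nnumpy.random.RandomState(split_seed).shuffle(other)\n\nprint(len(train_idx), len(valid_idx), numpy.intersect1d(train_idx, valid_idx).size)\nprint(numpy.array_equal(indices, other))"
      ]
    },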
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we write the function to save checkpoints. It is important to include\nnot only the model in the checkpoint, but also the optimizer and the learning rate\nscheduler when using one. In this example we use an exponential learning rate schedule,\nso we checkpoint it. We also save the current epoch so that we know where to resume from.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def save_checkpoint(checkpoint, model, optimizer, lr_scheduler, epoch):\n    if checkpoint is None:\n        return\n\n    state = {\n        \"model\": model.state_dict(),\n        \"optimizer\": optimizer.state_dict(),\n        \"lr_scheduler\": lr_scheduler.state_dict(),\n        \"epoch\": epoch,\n    }\n    torch.save(state, f\"{checkpoint}/checkpoint.pth\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To resume from a checkpoint, we simply restore the states of the model, optimizer and learning rate\nscheduler from the checkpoint file. If there is no checkpoint path or if the file does not\nexist, we return epoch 1 so that training starts from scratch. Otherwise we return the epoch\nfollowing the last one saved in the checkpoint file.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def resume_from_checkpoint(checkpoint, model, optimizer, lr_scheduler):\n    if checkpoint is None:\n        return 1\n\n    try:\n        state_dict = torch.load(f\"{checkpoint}/checkpoint.pth\")\n    except FileNotFoundError:\n        return 1\n\n    model.load_state_dict(state_dict[\"model\"])\n    optimizer.load_state_dict(state_dict[\"optimizer\"])\n    lr_scheduler.load_state_dict(state_dict[\"lr_scheduler\"])\n    return state_dict[\"epoch\"] + 1  # Start from next epoch"
      ]
    },
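    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before running the full pipeline, the save/resume round trip can be checked on a toy problem. The cell below is a sketch that uses a small ``nn.Linear`` model as a stand-in for VGG11 and a hypothetical ``make_objects`` helper. It saves the same four entries as ``save_checkpoint`` into a temporary directory, restores them into freshly built objects (as a new process started for a promoted trial would), and derives the next epoch to train in the same way as ``resume_from_checkpoint``.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport tempfile\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\n\ndef make_objects():\n    # Hypothetical helper: a tiny model standing in for VGG11, same optimizer and scheduler types.\n    torch.manual_seed(0)\n    model = nn.Linear(4, 2)\n    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\n    return model, optimizer, lr_scheduler\n\n\nmodel, optimizer, lr_scheduler = make_objects()\n\n# One update so that the weights, momentum buffers and learning rate move away from their initial values.\nloss = model(torch.randn(8, 4)).pow(2).mean()\nloss.backward()\noptimizer.step()\nlr_scheduler.step()\n\nwith tempfile.TemporaryDirectory() as checkpoint:\n    path = os.path.join(checkpoint, \"checkpoint.pth\")\n    state = {\n        \"model\": model.state_dict(),\n        \"optimizer\": optimizer.state_dict(),\n        \"lr_scheduler\": lr_scheduler.state_dict(),\n        \"epoch\": 3,  # pretend that epochs 1 to 3 were completed\n    }\n    torch.save(state, path)\n\n    # Fresh objects, then restore their states from the checkpoint file.\n    new_model, new_optimizer, new_lr_scheduler = make_objects()\n    restored = torch.load(path)\n    new_model.load_state_dict(restored[\"model\"])\n    new_optimizer.load_state_dict(restored[\"optimizer\"])\n    new_lr_scheduler.load_state_dict(restored[\"lr_scheduler\"])\n    start_epoch = restored[\"epoch\"] + 1\n\nprint(\"resume at epoch\", start_epoch)\nprint(\"learning rate\", new_optimizer.param_groups[0][\"lr\"])"
      ]
    },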
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then comes the training loop for one epoch.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def train(loader, device, model, optimizer, lr_scheduler, criterion):\n    model.train()\n    for batch_idx, (inputs, targets) in enumerate(loader):\n        inputs, targets = inputs.to(device), targets.to(device)\n        optimizer.zero_grad()\n        outputs = model(inputs)\n        loss = criterion(outputs, targets)\n        loss.backward()\n        optimizer.step()\n        # The learning rate scheduler is stepped after every mini-batch, not once per epoch.\n        lr_scheduler.step()"
      ]
    },
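    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that ``train`` calls ``lr_scheduler.step()`` after every mini-batch rather than once per epoch, so ``gamma`` compounds once per batch. The arithmetic below is a sketch assuming the default ``batch_size`` of 1024 and 45,000 training images; it computes the number of batches per epoch and the resulting learning rate decay factors for ``gamma=0.97``, which is useful context for the analysis of ``gamma`` at the end of this tutorial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nnum_train = 45000\nbatch_size = 1024\ngamma = 0.97\n\n# DataLoader keeps the last, incomplete batch by default (drop_last=False).\nbatches_per_epoch = math.ceil(num_train / batch_size)\n\n# One scheduler step per batch: the learning rate is multiplied by gamma each time.\ndecay_per_epoch = gamma ** batches_per_epoch\ndecay_after_120_epochs = gamma ** (batches_per_epoch * 120)\n\nprint(batches_per_epoch, decay_per_epoch, decay_after_120_epochs)"
      ]
    },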
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, the validation loop computes the validation error rate, in percent.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def valid(loader, device, model):\n    model.eval()\n    correct = 0\n    total = 0\n    with torch.no_grad():\n        for batch_idx, (inputs, targets) in enumerate(loader):\n            inputs, targets = inputs.to(device), targets.to(device)\n            outputs = model(inputs)\n\n            _, predicted = outputs.max(1)\n            total += targets.size(0)\n            correct += predicted.eq(targets).sum().item()\n\n    return 100.0 * (1 - correct / total)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We combine all these functions into a main function for the whole training pipeline.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>We set ``batch_size`` to 1024 by default; you may need to reduce it depending on your GPU memory.</p></div>\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def main(\n    epochs=120,\n    learning_rate=0.1,\n    momentum=0.9,\n    weight_decay=0,\n    batch_size=1024,\n    gamma=0.97,\n    checkpoint=None,\n):\n\n    # We create the checkpoint directory if it does not exist.\n    if checkpoint:\n        os.makedirs(checkpoint, exist_ok=True)\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    model = models.vgg11()\n    model = model.to(device)\n\n    # We define the training criterion, optimizer and learning rate scheduler\n    criterion = nn.CrossEntropyLoss()\n    optimizer = optim.SGD(\n        model.parameters(),\n        lr=learning_rate,\n        momentum=momentum,\n        weight_decay=weight_decay,\n    )\n    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)\n\n    # We restore the states of the model, optimizer and learning rate scheduler if a checkpoint\n    # file is available. This returns the epoch following the last saved one, or 1 if there is\n    # no checkpoint.\n    start_epoch = resume_from_checkpoint(checkpoint, model, optimizer, lr_scheduler)\n\n    # We build the data loaders. test_loader is here for completeness but won't be used.\n    train_loader, valid_loader, test_loader = build_data_loaders(batch_size=batch_size)\n\n    # No training is needed if the checkpoint already covers the number of epochs requested\n    # here (``epochs``): we only evaluate the restored model.\n    if start_epoch > epochs:\n        return valid(valid_loader, device, model)\n\n    # Train from ``start_epoch`` up to and including ``epochs``, checkpointing after each epoch.\n    for epoch in range(start_epoch, epochs + 1):\n        print(\"epoch\", epoch)\n        train(train_loader, device, model, optimizer, lr_scheduler, criterion)\n        valid_error_rate = valid(valid_loader, device, model)\n        save_checkpoint(checkpoint, model, optimizer, lr_scheduler, epoch)\n\n    return valid_error_rate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can test the training pipeline before working with the hyperparameter optimization.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "main(epochs=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## HPO code\n\nWe finally implement the hyperparameter optimization loop. We will use Hyperband\nwith the number of epochs as the fidelity, using the prior ``fidelity(1, 120, base=4)``.\nHyperband will thus train VGG11 for 1, 7, 30 and 120 epochs. To explore enough candidates\nat 120 epochs, we run Hyperband with 5 repetitions.\n\nIn the optimization loop (``while not experiment.is_done``), we ask Or\u00edon to suggest a new trial\nand then pass the hyperparameter values ``**trial.params`` to ``main()``, specifying the\ncheckpoint path with ``f\"{experiment.working_dir}/{trial.hash_params}\"``.\n\n"
      ]
    },
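    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The rungs quoted above can be reproduced with a little arithmetic. The cell below is a simplified sketch, not Or\u00edon's implementation: it assumes that each rung is obtained by dividing the maximum number of epochs by a power of ``base`` and truncating to an integer. For one trial promoted through every rung, it also compares the epochs trained when resuming from the checkpoint with the epochs trained when restarting from scratch at each promotion.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nmin_epochs, max_epochs, base = 1, 120, 4\n\n# Number of rungs: how many times max_epochs can be divided by base while staying >= min_epochs, plus one.\nnum_rungs = int(math.log(max_epochs / min_epochs, base)) + 1\nrungs = [max(min_epochs, int(max_epochs / base**k)) for k in reversed(range(num_rungs))]\n\n# With checkpointing, a promoted trial runs main() from start_epoch = previous rung + 1 up to the new rung.\nprevious = [0] + rungs[:-1]\nresumed_ranges = [(prev + 1, cur) for prev, cur in zip(previous, rungs)]\nepochs_with_checkpoints = sum(cur - prev for prev, cur in zip(previous, rungs))\n\n# Without checkpointing, every promotion retrains from epoch 1.\nepochs_from_scratch = sum(rungs)\n\nprint(rungs, resumed_ranges, epochs_with_checkpoints, epochs_from_scratch)"
      ]
    },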
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from orion.client import build_experiment\n\n\ndef run_hpo():\n\n    # Specify the database where the experiments are stored. We use a local PickledDB here.\n    storage = {\n        \"type\": \"legacy\",\n        \"database\": {\n            \"type\": \"pickleddb\",\n            \"host\": \"./db.pkl\",\n        },\n    }\n\n    # Create the experiment, or load it from the database if it already exists.\n    experiment = build_experiment(\n        \"hyperband-cifar10\",\n        space={\n            \"epochs\": \"fidelity(1, 120, base=4)\",\n            \"learning_rate\": \"loguniform(1e-5, 0.1)\",\n            \"momentum\": \"uniform(0, 0.9)\",\n            \"weight_decay\": \"loguniform(1e-10, 1e-2)\",\n            \"gamma\": \"loguniform(0.97, 1)\",\n        },\n        algorithms={\n            \"hyperband\": {\n                \"seed\": 1,\n                \"repetitions\": 5,\n            },\n        },\n        storage=storage,\n    )\n\n    trials = 1\n    while not experiment.is_done:\n        print(\"trial\", trials)\n        trial = experiment.suggest()\n        if trial is None and experiment.is_done:\n            break\n        valid_error_rate = main(\n            **trial.params, checkpoint=f\"{experiment.working_dir}/{trial.hash_params}\"\n        )\n        experiment.observe(trial, valid_error_rate, name=\"valid_error_rate\")\n        trials += 1\n\n    # Return the experiment so that its results can be analysed below.\n    return experiment"
      ]
    },
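    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The checkpoint directory must stay the same when Hyperband promotes a trial to more epochs. As we understand Or\u00edon's documentation, this holds because ``trial.hash_params`` is a hash of the parameters without the fidelity. The cell below only illustrates this property with a hypothetical ``illustrative_params_hash`` function built on ``hashlib``; it is not how Or\u00edon computes ``hash_params``.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import hashlib\nimport json\n\n\ndef illustrative_params_hash(params, fidelity_name=\"epochs\"):\n    # Hypothetical stand-in for trial.hash_params: hash every hyperparameter except the fidelity.\n    kept = {name: value for name, value in params.items() if name != fidelity_name}\n    payload = json.dumps(kept, sort_keys=True).encode(\"utf-8\")\n    return hashlib.md5(payload).hexdigest()\n\n\nparams = {\"learning_rate\": 0.01, \"momentum\": 0.5, \"weight_decay\": 1e-4, \"gamma\": 0.99}\n\n# Same hyperparameters at two rungs: same hash, hence the same checkpoint directory.\nhash_rung_7 = illustrative_params_hash(dict(params, epochs=7))\nhash_rung_30 = illustrative_params_hash(dict(params, epochs=30))\n\n# Different hyperparameters: different hash, hence a separate checkpoint directory.\nhash_other = illustrative_params_hash(dict(params, learning_rate=0.1, epochs=7))\n\nprint(hash_rung_7 == hash_rung_30, hash_rung_7 == hash_other)"
      ]
    },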
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's run the optimization now. You may want to reduce the maximum number of epochs in\n``fidelity(1, 120, base=4)`` and set the number of ``repetitions`` to 1 to get results more\nquickly. With the current configuration, this example takes 2 days to run on a Titan RTX.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "experiment = run_hpo()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Analysis\n\nThat is all for the checkpointing example. We should nevertheless analyse the results\nbefore wrapping up this tutorial.\n\nWe should first look at the `sphx_glr_auto_examples_plot_1_regret.py`\nto verify the optimization with Hyperband.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig = experiment.plot.regret()\nfig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. This file is produced by docs/scripts/build_database_and_plots.py\n\n.. raw:: html\n    :file: ../_static/hyperband-cifar10_regret.html\n\n\nHovering over the points, we see that only a handful of the better results come from trials\ntrained for 1 epoch: all other trials with a validation error rate below 80% were trained for\nmore than 1 epoch.\nThe best result found is still high, with a validation error rate of 23.6%. With VGG11 we could\nexpect an error rate below 10%. To see whether the search space may be the issue, we first look\nat the `sphx_glr_auto_examples_plot_3_lpi.py`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig = experiment.plot.lpi()\nfig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. raw:: html\n    :file: ../_static/hyperband-cifar10_lpi.html\n\nThe momentum and weight decay had very wide priors, yet their values had\nno important effect on the validation error rate, so we can ignore them.\nFor the learning rate and gamma,\nit is worth looking at the `sphx_glr_auto_examples_plot_4_partial_dependencies.py`\nto see whether the search space was too narrow or too wide.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig = experiment.plot.partial_dependencies(params=[\"gamma\", \"learning_rate\"])\nfig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. This file is produced by docs/scripts/build_database_and_plots.py\n\n.. raw:: html\n    :file: ../_static/hyperband-cifar10_partial_dependencies_params.html\n\nThe main culprit for the high validation error rate seems to be the wide prior for ``gamma``.\nBecause of this, Hyperband spent most of the computation time on bad values of ``gamma``. This prior\nshould be narrowed to ``uniform(0.995, 1)``.\nThe prior for the learning rate could also be narrowed to ``loguniform(0.001, 0.1)`` to\nhelp the optimization.\n\nNote that Hyperband could also find better results without adjusting the search space, but\nit would require significantly more repetitions.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}