Checkpointing trials

Hint

In short, you should use “{experiment.working_dir}/{trial.hash_params}” to set the path of the checkpointing file.

When using multi-fidelity algorithms such as Hyperband it is preferable to checkpoint the trials to avoid starting training from scratch when resuming a trial. In this tutorial for instance, hyperband will train VGG11 for 1 epoch, pick the best candidates and train them for 7 more epochs, doing the same again for 30 epoch, and then 120 epochs. We want to resume training at last epoch instead of starting from scratch.

Oríon provides a unique hash for trials that can be used to define the unique checkpoint file path: trial.hash_params. This can be used with the Python API as demonstrated in this example or with Command-line templating.

With command line

The example below is based on the Python API solely. It is also possible to do checkpointing using the command line API. To this end, your script should accept an argument for the checkpoint file path. Suppose this argument is --checkpoint, you should call your script with the following template.

orion hunt -n <exp name>
    ./your_script.sh --checkpoint '{experiment.working_dir}/{trial.hash_params}'

Your script is responsible to take this checkpoint path, resume from checkpoints or same checkpoints. We will demonstrate below how this can be done with PyTorch, but using Oríon’s Python API.

Note for macos users : You will need to either run this page as a jupyter notebook in order for it to compile, or encapsulate the code in a main function and running it under if __name__ == '__main__'.

Training code

We will first go through the training code piece by piece before tackling the hyperparameter optimization.

First things first, the imports.

# noqa
import os

# flake8: noqa
import numpy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import SubsetRandomSampler

We will use the data SubsetRandomSampler data loader from PyTorch to split the training set into a training and validation sets. We include test set here for completeness but won’t use it in this example as we only need the training data and the validation data for the hyperparameter optimization. We use torchvision’s transformers to apply the standard transformations on CIFAR10 images, that is, random cropping, random horizontal flipping and normalization.

def build_data_loaders(batch_size, split_seed=1):
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]

    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    train_set = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(augment + normalize),
    )
    valid_set = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(normalize),
    )
    test_set = torchvision.datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=transforms.Compose(normalize),
    )

    num_train = 45000
    # num_valid = 5000
    indices = numpy.arange(num_train)
    numpy.random.RandomState(split_seed).shuffle(indices)

    train_idx, valid_idx = indices[:num_train], indices[num_train:]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, sampler=train_sampler, num_workers=5
    )
    valid_loader = torch.utils.data.DataLoader(
        train_set, batch_size=1000, sampler=train_sampler, num_workers=5
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1000, shuffle=False, num_workers=5
    )

    return train_loader, valid_loader, test_loader

Next, we write the function to save checkpoints. It is important to include not only the model in the checkpoint, but also the optimizer and the learning rate schedule when using one. In this example we will use the exponential learning rate schedule, so we checkpoint it. We save the current epoch as well so that we know where we resume from.

def save_checkpoint(checkpoint, model, optimizer, lr_scheduler, epoch):
    if checkpoint is None:
        return

    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
    }
    torch.save(state, f"{checkpoint}/checkpoint.pth")

To resume from checkpoints, we simply restore the states of the model, optimizer and learning rate schedules based on the checkpoint file. If there is no checkpoint path or if the file does not exist, we return epoch 1 so that training starts from scratch. Otherwise we return the last trained epoch number found in checkpoint file.

def resume_from_checkpoint(checkpoint, model, optimizer, lr_scheduler):
    if checkpoint is None:
        return 1

    try:
        state_dict = torch.load(f"{checkpoint}/checkpoint.pth")
    except FileNotFoundError:
        return 1

    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
    return state_dict["epoch"] + 1  # Start from next epoch

Then comes the training loop for one epoch.

def train(loader, device, model, optimizer, lr_scheduler, criterion):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

Finally the validation loop to compute the validation error rate.

def valid(loader, device, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return 100.0 * (1 - correct / total)

We combine all these functions into a main function for the whole training pipeline.

Note

We set batch_size to 1024 by default, you may need to reduce it depending on your GPU.

def main(
    epochs=120,
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=0,
    batch_size=1024,
    gamma=0.97,
    checkpoint=None,
):

    # We create the checkpointing folder if it does not exist.
    if checkpoint and not os.path.isdir(checkpoint):
        os.makedirs(checkpoint)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = models.vgg11()
    model = model.to(device)

    # We define the training criterion, optimizer and learning rate scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    # We restore the states of model, optimizer and learning rate scheduler if a checkpoint file is
    # available. This will return the last epoch number of the checkpoint or 1 if no checkpoint.
    start_epoch = resume_from_checkpoint(checkpoint, model, optimizer, lr_scheduler)

    # We build the data loaders. test_loader is here for completeness but won't be used.
    train_loader, valid_loader, test_loader = build_data_loaders(batch_size=batch_size)

    # If no training needed, because the trial was resumed from an epoch equal or greater to number
    # of epochs requested here (``epochs``).
    if start_epoch >= epochs + 1:
        return valid(valid_loader, device, model)

    # Training from last epoch until ``epochs + 1``, checkpointing at end of each epoch.
    for epoch in range(start_epoch, epochs + 1):
        print("epoch", epoch)
        train(train_loader, device, model, optimizer, lr_scheduler, criterion)
        valid_error_rate = valid(valid_loader, device, model)
        save_checkpoint(checkpoint, model, optimizer, lr_scheduler, epoch)

    return valid_error_rate

You can test the training pipeline before working with the hyperparameter optimization.

main(epochs=4)

HPO code

We finally implement the hyperparameter optimization loop. We will use Hyperband with the number of epochs as the fidelity, using the prior fidelity(1, 120, base=4). Hyperband will thus train VGG11 for 1, 7, 30 and 120 epochs. To explore enough candidates at 120 epochs, we set Hyperband with 5 repetitions.

In the optimization loop (while not experiment.is_done), we ask Oríon to suggest a new trial and then pass the hyperparameter values **trial.params to main(), specifying the checkpoint file with f"{experiment.working_dir}/{trial.hash_params}".

from orion.client import build_experiment


def run_hpo():

    # Specify the database where the experiments are stored. We use a local PickleDB here.
    storage = {
        "type": "legacy",
        "database": {
            "type": "pickleddb",
            "host": "./db.pkl",
        },
    }

    # Load the data for the specified experiment
    experiment = build_experiment(
        "hyperband-cifar10",
        space={
            "epochs": "fidelity(1, 120, base=4)",
            "learning_rate": "loguniform(1e-5, 0.1)",
            "momentum": "uniform(0, 0.9)",
            "weight_decay": "loguniform(1e-10, 1e-2)",
            "gamma": "loguniform(0.97, 1)",
        },
        algorithm={
            "hyperband": {
                "seed": 1,
                "repetitions": 5,
            },
        },
        storage=storage,
    )

    trials = 1
    while not experiment.is_done:
        print("trial", trials)
        trial = experiment.suggest()
        if trial is None and experiment.is_done:
            break
        valid_error_rate = main(
            **trial.params, checkpoint=f"{experiment.working_dir}/{trial.hash_params}"
        )
        experiment.observe(trial, valid_error_rate, name="valid_error_rate")
        trials += 1

Let’s run the optimization now. You may want to reduce the maximum number of epochs in fidelity(1, 120, base=4) and set the number of repetitions to 1 to get results more quickly. With current configuration, this example takes 2 days to run on a Titan RTX.

experiment = run_hpo()

Analysis

That is all for the checkpointing example. We should nevertheless analyse the results before wrapping up this tutorial.

We should first look at the Regret curves to verify the optimization with Hyperband.

fig = experiment.plot.regret()
fig.show()