Pruning with Surrogate Models
Note
You can find the corresponding Python scripts here:
https://github.com/Helmholtz-AI-Energy/propulate/tree/master/tutorials/surrogate
We all know that hyperparameter optimization is a critical aspect in machine learning that tries to find the most
effective settings for a model’s non-trainable parameters to optimize the trained model’s predictive performance.
Automated approaches like random search, grid search, Bayesian optimization, or population-based algorithms as
implemented in Propulate 🧬 train the neural network over and over again, testing new hyperparameters every
time. As each evaluation typically corresponds to a full training of a neural network model, we aim at finding effective
hyperparameter settings with as little evaluations as possible. Even though Propulate already makes smart choices
about which hyperparameters to test next compared to, e.g., plain grid search, this is still very compute-intensive,
especially with newer models getting bigger and bigger.
Predicting the performance of hyperparameter configurations during the training process allows for early termination of
less promising configurations. To this end, Propulate features so-called surrogate models, which have access to
interim loss values from each evaluated neural network’s training during the hyperparameter optimization and decide
whether to stop it early.
Our evaluation of static and probabilistic surrogate models for hyperparameter optimization in Propulate with
different datasets and neural networks showed a significant decrease in total run time and energy consumption while
still finding a loss within small bounds of the best loss found without early stopping.
Below, we will guide you through a basic example of how to use surrogate models for pruning in Propulate.
As in the hyperparameter optimization tutorial, let us again consider the problem of MNIST
classification with a simple convolutional neural network. The neural network class Net and the get_dataloaders()
function are left unchanged and reused as before:
GPUS_PER_NODE: int = 1
NUM_WORKERS: int = (
2 # Set this to the recommended number of workers in the PyTorch dataloader.
)
log_path = "torch_ckpts"
log = logging.getLogger(__name__) # Get logger instance.
class Net(nn.Module):
"""Convolutional neural network class."""
... # Reused from hyperparameter optimization tutorial without changes!
def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader]:
"""
Get MNIST train and validation dataloaders.
Parameters
----------
batch_size : int
The batch size.
Returns
-------
torch.utils.data.DataLoader
The training dataloader.
torch.utils.data.DataLoader
The validation dataloader.
"""
... # Reused from hyperparameter optimization tutorial without changes!
return train_loader, val_loader
The only thing that is different when using surrogate models is the individual’s loss function ind_loss(). As
mentioned before, surrogate models predict the performance of hyperparameter configurations during the training process
to stop less promising individuals early. To decide whether to stop an individual early, we need access to interim loss
values from each evaluated neural network’s training. This is achieved by yielding the average validation loss of each
evaluated candidate in regular intervals (e.g., after each epoch) during training. These interim loss values are fed
into the surrogate model which decides whether to continue or cancel the training and updates itself accordingly based
on the provided value.
def ind_loss(
params: Dict[str, Union[int, float, str]],
) -> Generator[float, None, None]:
"""
Loss function for evolutionary optimization with Propulate. Minimize the model's negative validation accuracy.
Parameters
----------
params : Dict[str, int | float | str]
The parameters to be optimized.
Returns
-------
Generator[float, None, None]
Yields the negative validation accuracy in regular intervals during training of the model.
"""
# Extract hyperparameter combination to test from input dictionary.
conv_layers = int(params["conv_layers"]) # Number of convolutional layers
activation = str(params["activation"]) # Activation function
lr = float(params["lr"]) # Learning rate
epochs: int = 2 # Number of epochs to train
rank: int = MPI.COMM_WORLD.rank # Get rank of current worker.
num_gpus = torch.cuda.device_count() # Number of GPUs available
if num_gpus == 0:
device = torch.device("cpu")
else:
device_index = rank % num_gpus
device = torch.device(
f"cuda:{device_index}" if torch.cuda.is_available() else "cpu"
)
log.info(f"Rank: {rank}, Using device: {device}")
activations = {
"relu": nn.ReLU,
"sigmoid": nn.Sigmoid,
"tanh": nn.Tanh,
} # Define activation function mapping.
activation = activations[activation] # Get activation function.
loss_fn = (
torch.nn.CrossEntropyLoss()
) # Use cross-entropy loss for multi-class classification.
model = Net(conv_layers, activation, lr, loss_fn).to(
device
) # Set up neural network with specified hyperparameters.
model.best_accuracy = 0.0 # Initialize the model's best validation accuracy.
train_loader, val_loader = get_data_loaders(
batch_size=8
) # Get training and validation dataloaders.
# Configure optimizer.
optimizer = model.configure_optimizers()
for epoch in range(epochs):
model.train()
total_train_loss = 0
# Training loop
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
# Zero out gradients.
optimizer.zero_grad()
# Forward + backward pass and optimizer step to update parameters.
loss = model.training_step((data, target))
loss.backward()
optimizer.step()
# Update loss.
total_train_loss += loss.item()
avg_train_loss = total_train_loss / len(train_loader)
log.info(f"Epoch {epoch+1}: Avg Training Loss: {avg_train_loss}")
# Validation loop
model.eval()
total_val_loss = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(val_loader):
data, target = data.to(device), target.to(device)
# Forward pass
loss = model.validation_step((data, target))
# Update loss.
total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_loader)
log.info(f"Epoch {epoch+1}: Avg Validation Loss: {avg_val_loss}")
yield avg_val_loss
Now we have all the ingredients to perform a hyperparameter optimization with early stopping in Propulate 🧬:
if __name__ == "__main__":
comm = MPI.COMM_WORLD
if comm.rank == 0: # Download data at the top, then we don't need to later.
MNIST(download=True, root=".", transform=None, train=True)
MNIST(download=True, root=".", transform=None, train=False)
comm.Barrier()
num_generations = 3 # Number of generations
pop_size = 2 * comm.size # Breeding population size
limits = {
"conv_layers": (2, 10),
"activation": ("relu", "sigmoid", "tanh"),
"lr": (0.01, 0.0001),
} # Define search space.
rng = random.Random(
comm.rank
) # Set up separate random number generator for evolutionary optimizer.
set_seeds(42 * comm.rank) # Set seed for torch.
propagator = get_default_propagator( # Get default evolutionary operator.
pop_size=pop_size, # Breeding population size
limits=limits, # Search space
crossover_prob=0.7, # Crossover probability
mutation_prob=0.4, # Mutation probability
random_init_prob=0.1, # Random-initialization probability
rng=rng, # Random number generator for evolutionary optimizer
)
islands = Islands( # Set up island model.
loss_fn=ind_loss, # Loss function to optimize
propagator=propagator, # Evolutionary operator
rng=rng, # Random number generator
generations=num_generations, # Number of generations per worker
num_islands=1, # Number of islands
checkpoint_path=log_path,
surrogate_factory=lambda: surrogate.StaticSurrogate(),
# Alternatively, you can use a dynamic surrogate model here:
# surrogate_factory=lambda: surrogate.DynamicSurrogate(limits),
)
islands.evolve( # Run evolutionary optimization.
top_n=1, # Print top-n best individuals on each island in summary.
logging_interval=1, # Logging interval
debug=2, # Verbosity level
)