We are going to split the training manager’s logic into several stages by following the classic hierarchy of a training setup:
- The highest-level stage is the training run, which covers the entire training for one set of hyperparameters (number of epochs, model config, learning rate, optimizer, etc.)
- Next, we have the epoch; this stage represents one pass of the entire dataset through our model
- For each epoch, we have two steps: the training step and the validation step. In the training step, we perform the actual learning with the forward and backward passes, and in the validation step, we compute the chosen metrics on the hold-out samples.
As a bonus, I will also show you how to generate the parameter set for each run from a list of available values (hint: Cartesian product).
Training Process
First, let’s define a training process that already uses the Manager, so that we have an example of how it should be used; this will help later on, when we define the actual functionality.
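Here is a minimal sketch of that loop. Everything around the Manager is an assumption: Network stands in for your model class, train_set for your dataset, and run is a hyperparameter set that exposes its values as attributes (we generate these in the Bonus section below); the manager’s methods (begin_run, begin_epoch, track_loss, …) are placeholder names we will fill in as we go.

import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

manager = RunManager()

for run in runs:  # one run per hyperparameter combination
    network = Network()  # assumed model; any nn.Module works
    loader = DataLoader(train_set, batch_size=64, shuffle=True)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    manager.begin_run(run, network, loader)
    for epoch in range(run.num_epochs):
        manager.begin_epoch()

        # training step: forward pass, backward pass, weight update
        for images, labels in loader:
            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            manager.track_loss(loss, images.shape[0], phase='train')
            manager.track_num_correct(preds, labels, phase='train')

        # validation step: same tracking calls with phase='valid',
        # run under torch.no_grad() on the hold-out loader

        manager.end_epoch()
    manager.end_run()

manager.save('results')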
RunManager
Next, let’s think about what information we want our Manager to collect and log, so that we have an idea of what the RunManager should look like:
- The current run and epoch we are in
- The average loss per epoch
- The accuracy per epoch (we can use other metrics: F1, precision, recall, MSE, MAE, etc.)
- How long the epoch and run took
- It would also be nice to log everything to TensorBoard
For the loss and accuracy, we want to separate them into training and validation, so that we can have a better understanding of the model’s real performance.
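Translating that list into state, a minimal skeleton could look like the following; the attribute names are my own choice, not a fixed API:

import time
from torch.utils.tensorboard import SummaryWriter

class RunManager:
    def __init__(self):
        # run-level state
        self.run_params = None   # the hyperparameter set of the current run
        self.run_count = 0
        self.run_start_time = None
        self.run_data = []       # one record per epoch, across all runs

        # epoch-level accumulators, split into training and validation
        self.epoch_count = 0
        self.epoch_start_time = None
        self.epoch_loss = {'train': 0.0, 'valid': 0.0}
        self.epoch_correct = {'train': 0, 'valid': 0}
        self.epoch_samples = {'train': 0, 'valid': 0}

        # set in begin_run
        self.network = None
        self.loader = None
        self.tb = None           # one TensorBoard SummaryWriter per run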
Run and Epoch Beginning
As I mentioned at the beginning of this article, we have two main loops: the outer one goes over all the combinations of hyperparameters, creating the sets that define a specific run (we will see shortly how to generate them), and the inner one iterates over the number of epochs (also a potential hyperparameter); this is the portion where the learning takes place.
As such, we need to signal the RunManager when a new run and when a new epoch begins, in order to group the data accordingly and reset the state-based variables.
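Continuing the skeleton from above, a sketch of those two signals might look like this (both are methods of RunManager):

    def begin_run(self, run, network, loader):
        self.run_params = run
        self.run_count += 1
        self.run_start_time = time.time()
        self.epoch_count = 0
        self.network = network
        self.loader = loader
        # one writer per run; the comment suffix tags the log
        # directory with the hyperparameters of this run
        self.tb = SummaryWriter(comment=f'-{run}')

    def begin_epoch(self):
        self.epoch_count += 1
        self.epoch_start_time = time.time()
        # reset everything accumulated during the previous epoch
        self.epoch_loss = {'train': 0.0, 'valid': 0.0}
        self.epoch_correct = {'train': 0, 'valid': 0}
        self.epoch_samples = {'train': 0, 'valid': 0}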
Bonus
During the training process of a machine learning model, some parameters are not learned by the optimization itself; instead, they define the learning process at a higher level. We call them hyperparameters.
Usually, there is no way to determine the best value for them in advance, so we just have to try them all (disclaimer: there are methods to improve the search, but they go outside the scope of this article).
In short, our task is to generate all possible combinations of the values we propose for the hyperparameters; these combinations define each run. For this, we are going to use the product function from Python’s built-in itertools package:
from itertools import product

# candidate values for each hyperparameter
h_params = {
    "lr": [0.001, 0.0001],
    "num_epochs": [10, 50]
}

# Cartesian product of the value lists: one tuple per run,
# e.g. (0.001, 10), (0.001, 50), (0.0001, 10), (0.0001, 50)
runs = []
for run in product(*h_params.values()):
    runs.append(run)
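One refinement worth knowing: plain tuples lose the parameter names. Wrapping each combination in a namedtuple (build_runs is a small helper of my own, not a library function) lets us access the values as run.lr and run.num_epochs, which is what the training loop above assumed:

from collections import namedtuple
from itertools import product

def build_runs(h_params):
    # each combination becomes a named tuple, e.g. Run(lr=0.001, num_epochs=10)
    Run = namedtuple('Run', h_params.keys())
    return [Run(*values) for values in product(*h_params.values())]

runs = build_runs(h_params)  # 4 runs: 2 learning rates x 2 epoch counts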
Epoch and Run Ending
Next, we are going to cover several stages in one go; they are all related to each other, so it will make more sense this way.
Previously, we saw the methods for signaling the beginning of a run and of an epoch. In every epoch, we have to iterate over the entire dataset, and for various reasons related to both computational and convergence optimizations, we do so in batches (mini-batch stochastic gradient descent). As such, we have to log some information about each batch to our RunManager, which will help us compute the overall loss and accuracy for the entire epoch.
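Per batch, the manager only needs to keep running totals. A sketch, again as methods of RunManager (the phase argument separates training from validation, as discussed earlier):

    def track_loss(self, loss, batch_size, phase='train'):
        # scale by the batch size so the epoch average stays correct
        # even when the last batch is smaller than the rest
        self.epoch_loss[phase] += loss.item() * batch_size
        self.epoch_samples[phase] += batch_size

    def track_num_correct(self, preds, labels, phase='train'):
        # count the predictions that match the ground-truth labels
        self.epoch_correct[phase] += preds.argmax(dim=1).eq(labels).sum().item()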
Once the epoch finishes, we have to compute the aforementioned metrics, save them to the local collection, and log them to TensorBoard. We will have two sub-steps, one for training and one for validation:
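A sketch of the two ending signals, with a per-phase loop handling both sub-steps (the method names and the record layout are, as before, my own choices):

    def end_epoch(self):
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        record = {'run': self.run_count, 'epoch': self.epoch_count,
                  'epoch_duration': epoch_duration, 'run_duration': run_duration}

        # same computation for the training and validation sub-steps
        for phase in ('train', 'valid'):
            samples = max(self.epoch_samples[phase], 1)  # guard against empty phases
            loss = self.epoch_loss[phase] / samples
            accuracy = self.epoch_correct[phase] / samples
            self.tb.add_scalar(f'Loss/{phase}', loss, self.epoch_count)
            self.tb.add_scalar(f'Accuracy/{phase}', accuracy, self.epoch_count)
            record[f'{phase}_loss'] = loss
            record[f'{phase}_accuracy'] = accuracy

        # keep a local history so we can export everything at the end
        self.run_data.append(record)

    def end_run(self):
        self.tb.close()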
Finally
Once everything finishes, you will have nice TensorBoard plots for analyzing the model’s performance, plus a logged history of every epoch that you can either wrap into a pandas DataFrame or save to disk as CSV or JSON.
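A sketch of such an export, as one last RunManager method (assuming pandas is installed; the filename argument is used without an extension):

    def save(self, filename):
        # imported here to keep the sketch self-contained
        import json
        import pandas as pd

        # wrap the per-epoch history into a DataFrame and export it
        df = pd.DataFrame.from_records(self.run_data)
        df.to_csv(f'{filename}.csv', index=False)
        with open(f'{filename}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, indent=2)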
Conclusion
The training process presented here is pretty generic; however, it shows that by using a RunManager-type module you can have cleaner and easier-to-understand code.
As I mentioned at the beginning of the article, you should also check out PyTorch Lightning.
Thank you for reading! I hope you found this article helpful. If you want to stay up to date with the latest programming and machine learning news (and some good-quality memes :)), you can follow me on Twitter here or connect on LinkedIn here.