
If you want to write a PyTorch model in five minutes, there are four steps to go through:
- Import and preprocess the data (dataset), and batch it up (dataloader)
- Build the model using `nn.Module`
- Write a training loop and run it
- Validate on the validation set
We'll cover how to import a torchvision dataset and write the code end to end in under five minutes. For this reason, it won't be pretty, but it will work.
Downloading and importing the data
Because MNIST has been done to death, we're going to search the standard torchvision datasets to see if there's anything else we'd like to try and predict. Let's go with Kuzushiji-MNIST, a Hiragana (Japanese) replacement for the MNIST dataset consisting of 70,000 images. You can read about the motivations behind creating the dataset in the paper Deep Learning for Classical Japanese Literature.
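A minimal sketch of this first import (the `./data` root directory is my assumption):

```python
import torch
from torchvision import datasets, transforms

# First pass: download KMNIST with only a ToTensor transform,
# so we can inspect the raw pixel statistics before normalising
kmnist = datasets.KMNIST(
    root="./data", train=True, download=True,
    transform=transforms.ToTensor(),
)
```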
First, we find the mean and standard deviation per channel. The reasoning here is that we obviously want to normalise our training data, but PyTorch transforms require the normalisation mean and standard deviation to be given in advance. So we'll use the dataset to find these values, then reimport it and pass a `Normalize` transform with our precomputed values.
Note that `kmnist` is a dataset, so looping over it gives us an (image, label) pair at each step. Thus, if we loop over each image in the dataset and stack them all along an extra fourth dimension, we'll have a single tensor of all the images.
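Something like the following, reusing the `kmnist` object from the first import:

```python
# Each element of the dataset is an (image, label) pair, where the image
# is a 1x28x28 tensor; stacking along a new fourth dimension gives one
# big tensor of shape (1, 28, 28, 60000)
imgs = torch.stack([img for img, _ in kmnist], dim=3)
```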
We now compute the mean per channel. Note that calling `imgs.view(1, -1)` squeezes all the pixel values into the second dimension, leaving us with a leading channel dimension. We thus take the mean of the pixel values across this second dimension (hence `dim=1`). We do the same for the standard deviation.
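In code, roughly:

```python
# Flatten everything after the channel dimension, then reduce over it
mean = imgs.view(1, -1).mean(dim=1)  # one value per channel
std = imgs.view(1, -1).std(dim=1)
```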
We can now reimport our data, using a `Normalize` transform as well as a transform taking our arrays to tensors. Note that the `Normalize` transform takes the mean and standard deviation of the pixel values as arguments.
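A sketch of the reimport, reusing the `mean` and `std` values computed above:

```python
transform = transforms.Compose([
    transforms.ToTensor(),                             # arrays -> tensors in [0, 1]
    transforms.Normalize((mean.item(),), (std.item(),)),  # centre and scale
])
train_data = datasets.KMNIST(root="./data", train=True, download=True, transform=transform)
val_data = datasets.KMNIST(root="./data", train=False, download=True, transform=transform)
```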
Now that we have our datasets, we need to feed them into a `DataLoader` to be batched up. If you're on CPU, be sure to set a smaller batch size and `num_workers=1` (`num_workers` controls how many subprocesses load batches in parallel; extra workers mostly pay off when feeding a GPU, so don't worry too much about it).
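Something like this (the batch size of 64 is my assumption):

```python
from torch.utils.data import DataLoader

# Shrink batch_size and num_workers if you're on CPU
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=1)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=1)
```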
We can view a few samples from our dataset. I won't walk through the plotting code here; it should be fairly self-explanatory.
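For reference, a minimal plotting sketch using matplotlib (the grid of six samples is my choice):

```python
import matplotlib.pyplot as plt

# Grab one batch and plot the first few characters with their labels
imgs, labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for ax, img, label in zip(axes, imgs, labels):
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(int(label))
    ax.axis("off")
plt.show()
```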

Constructing the model
This is not a tutorial about how to theoretically construct deep learning models. As such, we're going to present the model here and comment on only three things. First, you define your model as a subclass of `nn.Module`. Second, you initialise the superclass with the usual Python `super().__init__()` call. Finally, you need an `__init__` method, where we define all our model layers, and then a `forward` method, where we tell the model how to take an input and pass it through these layers. And that's it.
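Here's a sketch of a small CNN following that recipe; the layer sizes and the name `KMNISTClassifier` are my assumptions, not necessarily the exact architecture used here:

```python
import torch.nn as nn
import torch.nn.functional as F

class KMNISTClassifier(nn.Module):  # hypothetical architecture
    def __init__(self):
        super().__init__()  # initialise the nn.Module superclass
        # define all the layers up front...
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # one output per class

    def forward(self, x):
        # ...then describe how an input flows through them
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(x.size(0), -1)             # flatten for the linear layers
        x = F.relu(self.fc1(x))
        return self.fc2(x)
```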
At this stage, it’s always important to debug your model by giving it a single example from the dataloader. We then pass this image through the model, and check that it outputs something of the correct size.
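For example:

```python
model = KMNISTClassifier()
imgs, labels = next(iter(train_loader))
out = model(imgs[0].unsqueeze(0))  # add a batch dimension of 1
print(out.shape)                   # expect torch.Size([1, 10])
```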
Perfect. We've constructed a model that takes a K-MNIST image and outputs 10 scores, one for each of the 10 Hiragana character classes.
Writing and running the training loop
As per usual, our training step is the old mantra: forward pass; calculate the loss; reset the gradients (a PyTorch specialty); backpropagate to get the gradients with respect to the loss; update our weights with these gradients. That's all there is to it (remember to set the model to `train` mode).
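A minimal training loop along those lines (the five-epoch default and the printed running loss are my assumptions):

```python
def train(model, loader, criterion, optimizer, epochs=5):
    model.train()  # remember: train mode
    for epoch in range(epochs):
        running_loss = 0.0
        for imgs, labels in loader:
            outputs = model(imgs)              # forward pass
            loss = criterion(outputs, labels)  # calculate loss
            optimizer.zero_grad()              # reset gradients
            loss.backward()                    # backpropagation
            optimizer.step()                   # update weights
            running_loss += loss.item()
        print(f"epoch {epoch + 1}: loss {running_loss / len(loader):.4f}")
```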
We then instantiate our model, set the Adam optimiser, and use cross-entropy loss (as this is a multiclass classification problem). If your problem requires you to change these, the optimisers and loss functions live in `torch.optim` and `torch.nn` respectively.
Then just pass these arguments to the training loop, and let it run.
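Putting it together (the learning rate of 1e-3 is Adam's common default, and an assumption here):

```python
model = KMNISTClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train(model, train_loader, criterion, optimizer, epochs=5)
```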
Validating the model
We want to keep things as simple as possible, so the structure of our validation loop will mirror the training loop. Iterate over the images and labels in the validation dataloader. Do a forward pass, and get the prediction by finding the index in the output tensor where the value is highest (remember, we output a vector of 10 scores). Remember to use `.data.squeeze()` to get the actual scalar itself. Finally, print the accuracy by summing over all the times the prediction equals the label (using `np.sum()` and `.item()` to escape the gradients) and dividing by the total number of labels.
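A sketch of that loop; I've used torch's `.sum().item()` in place of `np.sum()`, but the structure is as described:

```python
model.eval()  # eval mode: disables dropout/batchnorm updates
correct, total = 0, 0
with torch.no_grad():  # no gradients needed for validation
    for imgs, labels in val_loader:
        outputs = model(imgs)
        # index of the largest of the 10 scores is the predicted class
        preds = torch.max(outputs.data, dim=1)[1]
        correct += (preds == labels).sum().item()  # .item() escapes the tensor
        total += labels.size(0)
print(f"Validation accuracy: {correct / total:.1%}")
```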
Validation set accuracy of 95%. Not bad for five minutes of coding.
Conclusion and further steps
There are a few things you could do yourself to make your model a lot better:
- Print validation set metrics as the model trains: obviously, it’s nice to see the training loss decreasing with each epoch. But we don’t really have any idea of how the model is performing until we validate it after training. If you print validation accuracy as you go, you’ll have a better idea of the success of the model.
- Implement early stopping: once validation accuracy has failed to improve for a certain number of epochs (referred to as the patience), go back to the best performing epoch and use those weights.
- Look at other metrics: another really powerful metric is area under the curve (AUC), which can be adapted from binary to multiclass classification by using weighted average precision-recall in one-vs-rest over all classes (for an introduction to this, see here).
- Implement a residual network architecture: computer vision has come a long way since plain CNNs were introduced, and you could try other architectures to improve performance. You might even want to have a go at a vision transformer.
And that's it. Hopefully this article cut through a lot of the unnecessary pomp and circumstance that most introductory PyTorch tutorials include. I found that once I knew the theoretical foundations of deep learning, I really wanted a resource covering the purely end-to-end process of constructing these models myself. I hope this will serve as such a resource for you.