Transfer learning with MATLAB's pre-trained deep learning networks is easy to implement and gives fast, impressive results.
I obtained the image data from Unsplash: 42 cat images, 46 dog images, and 35 horse images, which serve as input to the pre-trained AlexNet model in MATLAB. For details about the AlexNet network in MATLAB, see its documentation.

AlexNet is a convolutional neural network that is 8 layers deep. MATLAB provides a version of the network pre-trained on more than a million images from the ImageNet database. The pre-trained network can classify images into 1000 predefined object categories.
Training on more than a million images across 1000 object categories has taught the network rich feature representations for a wide range of images.
Transfer learning applied on the Unsplash data using alexnet pre-trained network
You can check whether alexnet is installed by typing alexnet at the command line. If it is not installed, MATLAB shows a link to the required support package in the Add-On Explorer; simply follow the link.
Why Transfer Learning?
Transfer learning has become popular in deep learning applications because it is fast and easy to implement. One can take a pre-trained network and use it as a starting point to learn a new task. Fine-tuning such a network transfers the learned features to the new task and needs far fewer training images than training from scratch.
I obtained three categories of images (cats, dogs, and horses) from the Unsplash website, which provides millions of high-quality free images. There is no specific reason for choosing these categories, other than that cats and dogs have become a common test case for image classification models.
I saved the data in three subfolders named cats, dogs, and horses under the folder unsplashData, and kept the zipped data in the folder containing the script.
Prepare data
The first step is to unzip the data using the unzip command. We then automatically label the images based on folder names and store the data as an ImageDatastore object.
clear; close; clc;
%% Unzip and load the new images as an image datastore
filename = 'unsplashData';
unzip([filename '.zip']); % assumes the archive is named unsplashData.zip and sits next to the script
% imageDatastore automatically labels the images based on folder names and stores the data as an ImageDatastore object
imds = imageDatastore(filename, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
Divide the data into training and validation data sets
We use 70% of the randomly selected images for training and 30% for validation.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
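Because the data set is small, it is worth confirming how many images per class ended up in each split. The sketch below assumes imds has been created as above. countEachLabel is a standard datastore function; fixing the random seed with rng is an added assumption for repeatability and is not part of the original script.

```matlab
% Optional: fix the random seed BEFORE splitting so that the
% 'randomized' split is repeatable (assumption: not in the original script).
rng(0);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

% Tabulate the number of images per label in each datastore.
disp(countEachLabel(imds));            % full data set
disp(countEachLabel(imdsTrain));       % roughly 70% of each class
disp(countEachLabel(imdsValidation));  % remaining images of each class
```

Since splitEachLabel splits each label separately, the class proportions stay roughly the same in both subsets.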
Visualize the loaded images
We plot 16 randomly selected training images.
visualize = 1;
if visualize==1
    numTrainImages = numel(imdsTrain.Labels);
    idx = randperm(numTrainImages,16); % 16 distinct random indices
    fig1 = figure;
    for i = 1:16
        subplot(4,4,i)
        I = readimage(imdsTrain,idx(i));
        imshow(I)
    end
end

Load Pretrained Network
net = alexnet;
inputSize = net.Layers(1).InputSize
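To inspect the layers of the alexnet network interactively, set inspect_network to 1, which opens the network in analyzeNetwork.
inspect_network = 0;
if inspect_network==1
    analyzeNetwork(net)
end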

Replace the final three layers
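We extract all layers except the last three from the pre-trained network and replace those three with a fully connected layer, a softmax layer, and a classification output layer sized for the new classes. Increasing the WeightLearnRateFactor and BiasLearnRateFactor values of the new fully connected layer makes the new layers learn faster than the transferred layers.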
layersTransfer = net.Layers(1:end-3);
numClasses = numel(categories(imdsTrain.Labels)) %the number of classes in the new data
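layers = [layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20) % factors of 20 are assumed values, not shown in the original listing
    softmaxLayer
    classificationLayer];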
Use an augmented image datastore to automatically resize the training images
The alexnet network requires input images of a fixed size (227-by-227-by-3), so we use an augmented image datastore to resize the training images automatically. We also apply standard augmentation operations: randomly flip the training images along the vertical axis and randomly translate them by up to 30 pixels horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ... %randomly flip the training images along the vertical axis
    'RandXTranslation',pixelRange, ... %randomly translate them up to 30 pixels horizontally
    'RandYTranslation',pixelRange); %and up to 30 pixels vertically
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
% automatically resize the validation images without performing further data augmentation
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
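To see what the augmenter actually feeds the network, it can help to preview a few augmented samples. This is a minimal sketch, assuming augimdsTrain was created with the augmenter described above. The variable name minibatch is only illustrative, and imtile requires the Image Processing Toolbox.

```matlab
% Preview the first few augmented, resized training images.
% preview returns a table; its 'input' variable holds the images.
minibatch = preview(augimdsTrain);
figure
imshow(imtile(minibatch.input)); % tile the previewed images in one figure
```

Each call to preview applies the random transformations again, so the flips and shifts differ from run to run.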
Fine-tune the training options
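The goal of applying transfer learning is to keep the features from the early layers of the pre-trained network (the transferred layer weights). We therefore use a small initial learning rate, which slows down learning in the transferred layers, and train for only a few epochs.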
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ... %when performing transfer learning, you do not need to train for as many epochs
    'InitialLearnRate',1e-4, ... %slow down learning in the transferred layers (fast learning only in the new layers, slower learning in the others)
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress', ...
'ExecutionEnvironment','auto'); %Hardware resource for training network - auto, cpu, gpu, multi-gpu, parallel
Train the network
We train the network consisting of the transferred and new layers.
netTransfer = trainNetwork(augimdsTrain,layers,options); %By default, trainNetwork uses a GPU if one is available
Classify Validation Images
[YPred,scores] = classify(netTransfer,augimdsValidation); %classify using the fine-tuned network
We display four sample validation images with their predicted labels.
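classify_visualize = 1;
if classify_visualize==1
    idx = randperm(numel(imdsValidation.Files),4); % 4 distinct random validation images
    fig = figure;
    for i = 1:4
        subplot(2,2,i)
        I = readimage(imdsValidation,idx(i));
        imshow(I)
        label = YPred(idx(i));
        title(string(label)); % show the predicted label above the image
    end
end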

Classification accuracy
Classification accuracy gives the fraction of labels that the network predicts correctly.
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
% accuracy = 0.9189
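Overall accuracy hides which classes get confused with each other. As a sketch, assuming YPred and YValidation are the categorical vectors computed above, a confusion matrix breaks the result down per class. confusionmat requires the Statistics and Machine Learning Toolbox, and the names C, order, and recallPerClass are illustrative.

```matlab
% Rows of C are the true classes, columns are the predicted classes.
[C,order] = confusionmat(YValidation,YPred);

% Fraction of each true class that was predicted correctly.
recallPerClass = diag(C) ./ sum(C,2);
disp(table(order,recallPerClass));

% Graphical view of the same counts.
figure
confusionchart(YValidation,YPred);
```

With a validation set this small, a single misclassified image changes the per-class figures noticeably, so they are best read as rough indicators.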
I fine-tuned the pre-trained AlexNet network from MATLAB on the Unsplash data, but the same approach can be applied to other image classification problems. It is also worth exploring other pre-trained networks provided by MATLAB, such as squeezenet, resnet18, and googlenet, which may achieve better accuracy. The accuracy depends strongly on the quantity and quality of the data and on model choices such as the number of layers.
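As a rough sketch of how the same workflow might look with googlenet: the network is a layer graph rather than a simple layer array, so layers are replaced by name. The layer names 'loss3-classifier' and 'output' are assumed to match the googlenet shipped with your release (check with analyzeNetwork). All variable names ending in G and the learn-rate factor of 10 are illustrative, and imdsTrain, imdsValidation, and imageAugmenter are assumed to exist as defined earlier.

```matlab
netG = googlenet;                        % requires the GoogLeNet support package
inputSizeG = netG.Layers(1).InputSize;   % GoogLeNet expects 224-by-224-by-3 input
lgraph = layerGraph(netG);

% Replace the last learnable layer and the output layer by name.
numClasses = numel(categories(imdsTrain.Labels));
newFc = fullyConnectedLayer(numClasses,'Name','new_fc', ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10); % assumed factors
lgraph = replaceLayer(lgraph,'loss3-classifier',newFc);
lgraph = replaceLayer(lgraph,'output',classificationLayer('Name','new_output'));

% Datastores must be rebuilt for the new input size.
augTrainG = augmentedImageDatastore(inputSizeG(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augValG = augmentedImageDatastore(inputSizeG(1:2),imdsValidation);

optionsG = trainingOptions('sgdm', ...
    'MiniBatchSize',10,'MaxEpochs',6,'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch','ValidationData',augValG, ...
    'ValidationFrequency',3,'Verbose',false);
netTransferG = trainNetwork(augTrainG,lgraph,optionsG);
```

The existing softmax layer is kept, since it has no class-specific parameters. Only the fully connected layer and the classification output depend on the number of classes.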
