Transfer Learning and Twin Network for Image Classification using Flux.jl

Solving challenges with DataLoaders, Metalhead.jl, and Twin Network design

Garrett Kinman
Towards Data Science



Earlier this year, I was working on a project in PyTorch to create a deep learning model that could detect plant disease in unseen species. Recently, I decided to rebuild the project in Julia and use it as an exercise in learning Flux.jl [1], Julia’s most popular deep learning package (at least as rated by number of stars on GitHub). But in doing so, I encountered several challenges for which I could not find good examples online or in the documentation. Hence, I decided to write up this post as a resource for anyone else wanting to do similar things in Flux.

Who is this for?

Because Flux.jl (henceforth referred to as “Flux”) is a deep learning package, I wrote this primarily for an audience familiar with deep learning concepts such as transfer learning.

I also wrote this with a semi-newbie to Flux (like myself) in mind, though others may find it valuable too. Just know that this is not a comprehensive introduction or tutorial to Julia or Flux; for that, I defer to other resources, such as the official Julia and Flux documentation.

Lastly, I make several comparisons to PyTorch. Experience with PyTorch is not required to understand the points I make, but those with PyTorch experience may find it particularly interesting.

Why Julia? And why Flux.jl?

If you already use Julia and/or Flux, you can probably skip this section. And further, many others have written plenty of posts about this very question, so I’ll be brief.

Ultimately, it’s that I like Julia. It’s great at numerical computing, it’s a genuine joy to program in, and it’s fast. Natively fast: there is no need for NumPy or other wrappers around underlying C or C++ code.

As for why Flux, it’s because it’s the most popular deep learning framework in Julia, written in pure Julia, and composable with the Julia ecosystem. And this project seemed as good as any to finally learn Flux.

The project itself

Okay, so now that I’m done shamelessly proselytizing Julia, time for information about the project itself. I used three datasets — PlantVillage [2], PlantLeaves [3], and PlantaeK [4] — covering a number of different species. I used PlantVillage for the training set, and the other two combined for the test set. This meant that the models would have to learn something generalizable to unseen species, as the test set would comprise species not trained on.

Knowing this, I created three models:

  1. A baseline using Transfer Learning from ResNet
  2. A Twin (a.k.a. Siamese) neural network with custom CNN architecture
  3. A Twin neural network with Transfer Learning twin paths

Most of the rest of this post details the challenges and pain points of handling the data and of creating and training the models.

Dealing with data

The first challenge was that the datasets were in the wrong form. I won’t go into detail about how I preprocessed them here, but in short, I created two directories of images, train and test. Both were filled with a long list of images named img0.jpg, img1.jpg, img2.jpg, and so on. I also created two CSVs, one for the training set and one for the test set, each containing one column of filenames and one column with the binary label.

This structure is key, because the full dataset is over 10 GB and certainly won’t fit in my PC’s memory, let alone my GPU’s. Because of this, we need a DataLoader that reads the images from disk one batch at a time. (If you’ve ever used PyTorch, this will be familiar; it’s basically the same concept as in PyTorch.)

To do this in Flux, we create a custom struct that wraps our dataset so that it can be loaded in batches. All our custom struct needs in order to be used with a DataLoader is two methods for its type: length, which returns the number of observations, and getindex, which returns the observations at the given indices.

Essentially, when Flux retrieves a batch of images, it calls getindex on our dataset struct with the indices of that batch, for example getindex(data, i:i+batchsize-1), which in Julia is equivalent to data[i:i+batchsize-1] (with shuffling enabled, the indices arrive as a vector of shuffled integers instead). Our custom getindex therefore looks up the filenames at those indices, loads the corresponding images, then processes and reshapes them into a single HEIGHT × WIDTH × CHANNELS × BATCH array. The labels are handled the same way.
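As a minimal sketch of such a container (not the original implementation), assume the CSV has already been read into a vector of filenames and a vector of labels. The struct name ImageDataContainer and its loader field, a function from filename to a HEIGHT × WIDTH × CHANNELS Float32 array, are hypothetical; passing the loader in keeps the example runnable without real image files.

```julia
# Hypothetical lazily loading image container. Flux's DataLoader only
# needs `length` and `getindex` to be defined for it.
struct ImageDataContainer{F}
    filenames::Vector{String}   # e.g. "img0.jpg", "img1.jpg", ... (from the CSV)
    labels::Vector{Int}         # binary label (0 or 1) for each file
    loader::F                   # assumed: filename -> HEIGHT×WIDTH×CHANNELS Float32 array
end

# Number of observations in the dataset.
Base.length(d::ImageDataContainer) = length(d.filenames)

# Load only the requested images (a range or a shuffled vector of indices)
# and stack them into one HEIGHT×WIDTH×CHANNELS×BATCH array.
function Base.getindex(d::ImageDataContainer, idxs::AbstractVector{<:Integer})
    imgs = [d.loader(d.filenames[i]) for i in idxs]
    X = cat(imgs...; dims = 4)
    y = reshape(Float32.(d.labels[idxs]), 1, :)   # 1×BATCH labels
    return X, y
end

# Usage with a fake loader standing in for real image decoding:
# each "image" is a 4×4×3 array filled with the length of its filename.
fake_loader(name) = fill(Float32(length(name)), 4, 4, 3)
data = ImageDataContainer(["img0.jpg", "img1.jpg", "img10.jpg"], [0, 1, 1], fake_loader)
X, y = data[1:2]   # size(X) == (4, 4, 3, 2), y == Float32[0 1]
```

For real JPEGs, the loader could be built with Images.jl, for instance f -> permutedims(Float32.(channelview(RGB.(imresize(load(f), (224, 224))))), (2, 3, 1)), which resizes each image and moves the color channel last; the 224 × 224 target size is an assumption.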

Then, our training, validation, and testing dataloaders can be made easily by wrapping an instance of this struct for each split in Flux.DataLoader.
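Building the dataloaders can be sketched as follows, assuming a recent Flux whose DataLoader (from MLUtils.jl) falls back to length and getindex for custom types. ToyImages is a hypothetical in-memory stand-in for the image container so the batching can be seen without any files; the split sizes and batch size are illustrative.

```julia
using Flux

# Hypothetical toy container: observation i is a 2×2×1 "image" filled with i,
# labelled 1 if i is odd. Like the real container, it only defines
# `length` and `getindex`.
struct ToyImages
    n::Int
end
Base.length(d::ToyImages) = d.n
function Base.getindex(d::ToyImages, idxs::AbstractVector{<:Integer})
    X = cat((fill(Float32(i), 2, 2, 1) for i in idxs)...; dims = 4)
    y = reshape(Float32.(isodd.(idxs)), 1, :)
    return X, y
end

# One DataLoader per split; only the training set is shuffled.
train_loader = Flux.DataLoader(ToyImages(10); batchsize = 4, shuffle = true)
val_loader   = Flux.DataLoader(ToyImages(6); batchsize = 4)
test_loader  = Flux.DataLoader(ToyImages(6); batchsize = 4)

for (X, y) in train_loader
    println(size(X), " ", vec(y))
end
```

Because DataLoader keeps partial batches by default, 10 observations with batchsize = 4 come out as batches of 4, 4, and 2.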

Making some models

With the dataloaders ready to go, the next step is to create the models, and first among these is the ResNet-based model for Transfer Learning. This proved relatively challenging to get working, although hopefully the underlying issue will be fixed soon.

In the Metalhead.jl package, which contains computer vision models for Flux, including ones for Transfer Learning, it should be as simple as model = ResNet(18; pretrain = true) to create a ResNet18 model with pretrained weights. However, at least at the time of writing, creating a pretrained model results in an error, likely because Metalhead.jl is still in the process of adding the pretrained weights. I eventually found a .tar.gz file containing the weights on HuggingFace. After extracting the tarball to get a plain .bson file, we can load the weights with BSON.jl and build our own custom Flux model with a single binary output.

(Note: if there is a more elegant way than this to change the last layer of ResNet, please let me know.)
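On the question of changing ResNet’s last layer, one general Flux approach is to slice the pretrained Chain and attach a new head. The sketch below uses a tiny hypothetical stand-in network instead of ResNet18 (whose real classifier maps 512 features to 1000 classes) and assumes the pretrained model is, or contains, a plain Flux Chain; how Metalhead exposes that Chain may vary between versions.

```julia
using Flux

# Hypothetical stand-in for a pretrained model ending in a 1000-class head.
pretrained = Chain(
    Conv((3, 3), 3 => 8, relu; pad = 1),
    GlobalMeanPool(),                        # 1×1×8×N
    x -> reshape(x, :, size(x, 4)),          # 8×N feature matrix
    Dense(8 => 1000),                        # original classifier
)

# Keep every layer except the classifier, then attach a single-logit head.
model = Chain(pretrained[1:end-1], Dense(8 => 1))

x = rand(Float32, 32, 32, 3, 4)              # batch of 4 RGB images
model(x)                                     # 1×4 logits
```

Slicing with pretrained[1:end-1] reuses the same layer objects rather than copying them, so the pretrained weights carry over directly. The single-logit head pairs naturally with Flux.Losses.logitbinarycrossentropy.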

With the pretrained Transfer Learning model created, that leaves the two Twin network models. Unlike the Transfer Learning model, however, these have to be built by hand. (If you’re used to PyTorch, this is where Flux diverges a lot from PyTorch.)

Creating a CNN is relatively easy using the Flux docs and other online resources. However, Flux does not have a built-in layer to represent a Twin network with parameter sharing. The closest it has is the Parallel layer, whose branches are normally separate layers, each with its own parameters.

Flux’s documentation does, however, explain how to create custom layers with multiple inputs or outputs, and we can follow the same approach to create a custom Twin layer.

The layer begins with a simple struct, Twin, with two fields: combine and path. The path field holds the network that both image inputs go through, and combine is the function that merges the outputs of path at the end.

Then, Flux.@functor Twin tells Flux to treat our struct like a regular Flux layer (so it finds the parameters inside its fields for training and for moving to the GPU), and (m::Twin)(Xs::Tuple) = m.combine(map(X -> m.path(X), Xs)...) defines the forward pass: every input X in the tuple Xs is fed through the same path, and all of the outputs are then passed to combine.
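The Twin layer described above can be written out as follows. This is a reconstruction from the description, with the field order (combine, path) taken from the text; the extra convenience method m(x1, x2) and the toy usage with a Dense path and subtraction as combine are illustrative additions.

```julia
using Flux

# A Twin (Siamese) layer: one shared `path` network applied to every input,
# with the per-input outputs merged by `combine`.
struct Twin{F,T}
    combine::F
    path::T
end

# Tell Flux to look inside both fields for trainable parameters
# (and to move them to the GPU with the rest of the model).
Flux.@functor Twin

# Forward pass: every input in the tuple goes through the same `path`,
# then all of the outputs are splatted into `combine`.
(m::Twin)(Xs::Tuple) = m.combine(map(X -> m.path(X), Xs)...)

# Convenience method (an addition): allow m(x1, x2) as well as m((x1, x2)).
(m::Twin)(Xs...) = m(Xs)

# Toy usage: a Dense layer as the shared path, subtraction as combine.
path = Dense(3 => 2)
twin = Twin((a, b) -> a .- b, path)
x1, x2 = rand(Float32, 3, 5), rand(Float32, 3, 5)
twin(x1, x2)   # 2×5, equal to path(x1) .- path(x2)
```

Because both inputs go through the very same path object, the Twin holds only one set of path parameters (8 numbers for this Dense(3 => 2)). Newer Flux releases also provide Flux.@layer as the recommended replacement for Flux.@functor.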

To create the Twin network with the custom CNN architecture, we then construct a Twin whose path is the custom CNN.

In this case, we use a Flux.Bilinear layer as combine. Given two inputs x and y, Bilinear computes each output i as σ(x' * W[i,:,:] * y + b[i]), so every output is connected to both inputs at once, including their pairwise interactions. Here, the two inputs are the outputs of path, i.e., the custom CNN architecture. Alternatively, we could use vcat as combine to concatenate the two feature vectors and then add a Dense layer at the end, but the Bilinear solution seems more elegant for this problem.
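A sketch of the custom-CNN Twin might look like the following. The CNN layers, the 16-dimensional embedding, and the 32 × 32 input size are assumptions for illustration, not the project’s architecture; the Bilinear layer uses Flux’s (in1, in2) => out constructor.

```julia
using Flux

struct Twin{F,T}
    combine::F
    path::T
end
Flux.@functor Twin
(m::Twin)(Xs::Tuple) = m.combine(map(X -> m.path(X), Xs)...)
(m::Twin)(Xs...) = m(Xs)

# Hypothetical small CNN used as the shared path: image -> 16-dim embedding.
path = Chain(
    Conv((3, 3), 3 => 8, relu; pad = 1),   # 32×32×3×N -> 32×32×8×N
    MaxPool((2, 2)),                       # -> 16×16×8×N
    Conv((3, 3), 8 => 16, relu; pad = 1),  # -> 16×16×16×N
    GlobalMeanPool(),                      # -> 1×1×16×N
    x -> reshape(x, :, size(x, 4)),        # -> 16×N
)

# Bilinear combine: logit[n] = e1[:, n]' * W[1, :, :] * e2[:, n] + b[1]
twin_cnn = Twin(Flux.Bilinear((16, 16) => 1), path)

x1 = rand(Float32, 32, 32, 3, 4)
x2 = rand(Float32, 32, 32, 3, 4)
logits = twin_cnn(x1, x2)                  # 1×4: one "same label?" logit per pair
```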

Now, to create the Twin network with ResNet, we build a Twin in the same way.

Note that we again use a Flux.Bilinear layer as combine and, much as in the Transfer Learning model, use the pretrained ResNet as path.
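For the ResNet Twin, the same pieces can be combined with a frozen pretrained path. The sketch below again uses a small hypothetical stand-in for the pretrained ResNet, assumes the explicit training API of recent Flux releases (Flux.setup, Flux.gradient, Flux.update!), and calls Flux.freeze! on the optimiser state so that one update step changes only the Bilinear head.

```julia
using Flux

struct Twin{F,T}
    combine::F
    path::T
end
Flux.@functor Twin
(m::Twin)(Xs::Tuple) = m.combine(map(X -> m.path(X), Xs)...)
(m::Twin)(Xs...) = m(Xs)

# Hypothetical stand-in for a pretrained ResNet18 (which has 512 features
# before its 512 => 1000 classifier); here 8 features keep it small.
pretrained = Chain(
    Conv((3, 3), 3 => 8, tanh; pad = 1),
    GlobalMeanPool(),
    x -> reshape(x, :, size(x, 4)),
    Dense(8 => 1000),
)

# Drop the classifier and use the feature extractor as the shared path.
twin_resnet = Twin(Flux.Bilinear((8, 8) => 1), pretrained[1:end-1])

# Freeze the pretrained path in the optimiser state; only the head will train.
opt_state = Flux.setup(Adam(1f-3), twin_resnet)
Flux.freeze!(opt_state.path)

# One training step on a toy batch of 4 image pairs.
x1, x2 = rand(Float32, 32, 32, 3, 4), rand(Float32, 32, 32, 3, 4)
same = Float32[1 0 0 1]                    # 1 = same label, 0 = different
pairloss(m) = Flux.Losses.logitbinarycrossentropy(m(x1, x2), same)

conv_before = copy(twin_resnet.path[1].weight)
head_before = copy(twin_resnet.combine.weight)
grads = Flux.gradient(pairloss, twin_resnet)
Flux.update!(opt_state, twin_resnet, grads[1])
```

Because freezing acts on the optimiser state rather than on the model itself, the layers are untouched and can be unfrozen later for fine-tuning.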

Training time

Now that our dataloaders and models are ready, all that’s left is to train. Ordinarily in Flux, one could use a simple one-liner, @epochs 2 Flux.train!(loss, ps, dataset, opt), for the training loop, but we want some custom behavior in ours.
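For reference, the plain loop in the explicit style of more recent Flux releases, which moved away from @epochs and the implicit ps form, might look like this. The model, toy data, and learning rate are made up for illustration.

```julia
using Flux

# Toy model and data, for illustration only.
model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
X = rand(Float32, 4, 64)
Y = Float32.(sum(X; dims = 1) .> 2)        # 1×64 binary labels
data = Flux.DataLoader((X, Y); batchsize = 16, shuffle = true)

# Explicit style: the loss takes the model as its first argument.
loss(m, x, y) = Flux.Losses.logitbinarycrossentropy(m(x), y)
opt_state = Flux.setup(Adam(1f-2), model)

initial_loss = loss(model, X, Y)
for epoch in 1:2                           # replaces `@epochs 2`
    Flux.train!(loss, model, data, opt_state)
end
final_loss = loss(model, X, Y)
```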

First, consider the training loop for the non-Twin network.

There is a lot to unpack, but essentially the loop does three things:

  1. It creates a helper struct for keeping track of whatever validation metrics we want. In this case, loss and accuracy for each epoch.
  2. It selects only the parameters of the last layer(s) to train. We could train the whole model, but that would be more computationally taxing, and it isn’t necessary because the earlier layers already have pretrained weights.
  3. For each epoch, it iterates through all the batches of the training set to train. Then it calculates accuracy and loss across the whole validation set (batched, of course). If the validation accuracy improves, it saves the model; if not, it moves on to the next epoch.

Note that we could do more here, e.g., early stopping, but the above is enough to get the general idea.
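A minimal sketch of those three steps follows, under several assumptions: the model is Chain(backbone, head) with a single-logit head, labels are 1 × N Float32 arrays, and the function gets the hypothetical name fit! to avoid clashing with Flux.train!. Instead of saving the best model to disk (for example with BSON.jl), the sketch keeps an in-memory deepcopy.

```julia
using Flux

# Step 1: helper struct recording validation metrics, one entry per epoch.
struct ValMetrics
    loss::Vector{Float32}
    accuracy::Vector{Float32}
end
ValMetrics() = ValMetrics(Float32[], Float32[])

# Mean loss and accuracy over a batched validation set (single-logit outputs).
function evaluate(model, loader)
    total, correct, n = 0f0, 0, 0
    for (x, y) in loader
        ŷ = model(x)
        k = size(y, 2)                                   # batch size
        total += Flux.Losses.logitbinarycrossentropy(ŷ, y) * k
        correct += sum((ŷ .> 0) .== (y .> 0.5f0))
        n += k
    end
    return total / n, correct / n
end

# Hypothetical name `fit!`; `model` is assumed to be Chain(backbone, head).
function fit!(model::Chain, train_loader, val_loader; epochs = 3, lr = 1f-3)
    opt_state = Flux.setup(Adam(lr), model)
    Flux.freeze!(opt_state.layers[1])                    # Step 2: train only the head
    metrics, best_acc, best_model = ValMetrics(), -Inf, deepcopy(model)
    for epoch in 1:epochs                                # Step 3
        for (x, y) in train_loader
            grads = Flux.gradient(m -> Flux.Losses.logitbinarycrossentropy(m(x), y), model)
            Flux.update!(opt_state, model, grads[1])
        end
        val_loss, val_acc = evaluate(model, val_loader)
        push!(metrics.loss, val_loss)
        push!(metrics.accuracy, val_acc)
        if val_acc > best_acc                            # improved: keep this model
            best_acc, best_model = val_acc, deepcopy(model)
        end
    end
    return best_model, metrics
end

# Toy usage: a stand-in "pretrained" backbone plus a fresh single-logit head.
model = Chain(Chain(Dense(4 => 8, relu)), Dense(8 => 1))
w_backbone = copy(model[1][1].weight)
X, Xv = rand(Float32, 4, 64), rand(Float32, 4, 32)
Y, Yv = Float32.(sum(X; dims = 1) .> 2), Float32.(sum(Xv; dims = 1) .> 2)
train_loader = Flux.DataLoader((X, Y); batchsize = 16, shuffle = true)
val_loader = Flux.DataLoader((Xv, Yv); batchsize = 16)
best, metrics = fit!(model, train_loader, val_loader; epochs = 3)
```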

Next, the training loop for the Twin networks is very similar, with a few differences.

First, note that we use a function with the same name, train!, but with a different signature. Julia’s multiple dispatch then selects the correct method depending on which type of network we’re training.

Also note that the Twin ResNet model freezes its pretrained parameters, whereas we train all the Twin custom CNN parameters.

Other than that, the rest of the training loop is largely the same, with the exception that we have to use two training dataloaders and two validation dataloaders. These give us two inputs and two sets of labels for each batch, which we feed into the Twin models appropriately. Finally, note that the Twin model is predicting whether the two input images have the same label, whereas the regular non-Twin network merely predicted the label directly.
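The Twin version can be sketched as a method of the same hypothetical fit! with a different signature, which Julia selects whenever the model is a Twin. It assumes the two loaders yield equal batch sizes (hence partial = false below), builds pair targets by checking whether the two labels match, and offers an optional freeze_path keyword for the pretrained-ResNet case. The toy CNN and data are illustrative only.

```julia
using Flux

struct Twin{F,T}
    combine::F
    path::T
end
Flux.@functor Twin
(m::Twin)(Xs::Tuple) = m.combine(map(X -> m.path(X), Xs)...)
(m::Twin)(Xs...) = m(Xs)

# Pair target: 1 if the two images share a label, 0 otherwise.
same_label(y1, y2) = Float32.(y1 .== y2)

# Hypothetical Twin method of `fit!`: two training loaders instead of one,
# so multiple dispatch picks it whenever `model` is a Twin.
function fit!(model::Twin, train1, train2; epochs = 2, lr = 1f-3, freeze_path = false)
    opt_state = Flux.setup(Adam(lr), model)
    freeze_path && Flux.freeze!(opt_state.path)   # e.g. for the pretrained ResNet path
    losses = Float32[]
    for epoch in 1:epochs
        for ((x1, y1), (x2, y2)) in zip(train1, train2)
            t = same_label(y1, y2)
            l, grads = Flux.withgradient(
                m -> Flux.Losses.logitbinarycrossentropy(m(x1, x2), t), model)
            Flux.update!(opt_state, model, grads[1])
            push!(losses, l)
        end
    end
    return losses
end

# Pair accuracy over two validation loaders.
function pair_accuracy(model::Twin, val1, val2)
    correct, n = 0, 0
    for ((x1, y1), (x2, y2)) in zip(val1, val2)
        t = same_label(y1, y2)
        correct += sum((model(x1, x2) .> 0) .== (t .> 0.5f0))
        n += length(t)
    end
    return correct / n
end

# Toy usage with a small CNN path; partial = false keeps batch sizes equal.
path = Chain(Conv((3, 3), 3 => 4, relu; pad = 1), GlobalMeanPool(), x -> reshape(x, :, size(x, 4)))
twin = Twin(Flux.Bilinear((4, 4) => 1), path)
X, Y = rand(Float32, 16, 16, 3, 32), Float32.(rand(0:1, 1, 32))
train1 = Flux.DataLoader((X, Y); batchsize = 8, shuffle = true, partial = false)
train2 = Flux.DataLoader((X, Y); batchsize = 8, shuffle = true, partial = false)
losses = fit!(twin, train1, train2; epochs = 2)
acc = pair_accuracy(twin, Flux.DataLoader((X, Y); batchsize = 8),
                    Flux.DataLoader((X, Y); batchsize = 8, shuffle = true))
```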

And with that, it shouldn’t be too hard to build the testing loop for the test set for all three models. Because the purpose of this post was to go through the major pain points for which I couldn’t easily find examples online, I will leave the testing part as an exercise for the reader.

Final thoughts

There is an element of truth to what they say about the current state of the Julia ecosystem, that it can be immature and that one needs to be willing to deal with limited documentation and/or scarce examples.

The biggest challenge was bridging the gap from relatively easy, simple examples to more advanced techniques for which examples are scarce. But this also reveals a strength of Julia: because the language is natively fast, packages like Flux are written in Julia itself rather than wrapping C++, so it is often remarkably easy to search a package’s source code for answers. Several times I found myself looking through the Flux source code to see how something worked, and every time I found my answer quickly. I’m not sure I’m brave enough to try something similar for PyTorch.

Another challenge was the immature state of Metalhead.jl, which is certainly not the only feature-incomplete package in the Julia ecosystem.

My last challenge was actually one of mindset. A fundamental difference between the Python and Julia ecosystems is the scope of many important packages. Whereas PyTorch needs to do everything under the sun because its functionality lives in its own internal C++ core, pure Julia packages can interoperate and compose easily and painlessly. So you won’t see Flux offering nearly as many features as PyTorch; instead, much of that functionality lives in other packages, like MLUtils.jl, CUDA.jl, BSON.jl, Metalhead.jl, and countless others. Remember this, and you may save yourself some time and effort whenever programming in Julia.

And as one last thought, I found Flux to be rather enjoyable and elegant… once I got the hang of it. I’ll definitely be doing more deep learning with Flux in the future.

References

[1] M. Innes, Flux: Elegant Machine Learning with Julia (2018), Journal of Open Source Software

[2] Arun Pandian J. and G. Gopal, Data for: Identification of Plant Leaf Diseases Using a 9-layer Deep Convolutional Neural Network (2019), Mendeley Data

[3] S. S. Chouhan, A. Kaul, and U. P. Singh, A Database of Leaf Images: Practice towards Plant Conservation with Plant Pathology (2019), Mendeley Data

[4] V. P. Kour and S. Arora, PlantaeK: A leaf database of native plants of Jammu and Kashmir (2019), Mendeley Data


Embedded ML research engineer | MSc electric engineering | BEng computer engineering | Experimenting with new tools, languages, and ideas