Saving Multiple Images in TensorBoard with tf.summary.image

How to store multiple images in TensorBoard to watch your neural network evolve, with an application to a GAN model

Gabriel Furnieles
Towards Data Science


Photo by Sanni Sahil on Unsplash. Edited by author.

Recently I’ve been developing an image-generation model and needed to save several images to TensorBoard to see how the model evolved at each epoch.

After a long search through several documentation pages, I finally arrived at the result you can see in the image below:

Screenshot from the TensorBoard project. The top image is the input given to the neural network, the middle image is the network’s output (the translated image), and the bottom image is the ground truth (what the model was supposed to output). The slider above the images selects the epoch at which the images were generated. Image by the author

In the example shown above, the model attempts to “translate” a given image into a Monet-style painting using the Pix2Pix model [link to paper]. You can find the dataset in the Kaggle repository [link]

Note 1. The model was trained for very few epochs, since I only wanted to demonstrate the dashboard, so the results look quite poor.

Why use TensorBoard

TensorBoard is a visualization tool that provides a framework for analyzing the performance of a neural network. It can be integrated with both the TensorFlow and PyTorch libraries, generating an interface on a local port of your machine where you can plot metrics such as the evolution of the loss function, or track custom ones, such as the generated images in this case.

These visualizations are essential during the development stage of a neural network model: they allow us to detect key problems such as overfitting/underfitting, to measure and compare multiple models, and to gain insight into which changes might improve overall performance (hyperparameter tuning).
For instance, analyzing the generated images at each epoch can help us detect weaknesses in the generator, such as grainy patterns or stagnation caused by mode collapse, but those are topics for another article.

For more information about TensorBoard, you can check the documentation here.

Code

I don’t want to waste your time, so here is the code to save the images:

keras.callbacks.Callback subclass. Code by the author
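Here is a minimal sketch of such a callback. It assumes batched tf.data datasets with a batch size of 1 and pixel values already scaled for display; the class name TensorboardImageCallback and the helper generate_imgs are illustrative choices that follow the explanation in the next section.

import tensorflow as tf
from tensorflow import keras


class TensorboardImageCallback(keras.callbacks.Callback):
    """Writes input / output / ground-truth image triplets to TensorBoard after every epoch."""

    def __init__(self, train_imgs, val_imgs, g, logdir):
        super().__init__()
        self.train_imgs = train_imgs  # small tf.data sample of (x, y) pairs from the training set
        self.val_imgs = val_imgs      # small tf.data sample of (x, y) pairs from the validation set
        self.g = g                    # the model's generator
        # Summary file writer for the given log directory
        self.writer = tf.summary.create_file_writer(logdir)

    def on_epoch_end(self, epoch, logs=None):
        def generate_imgs(imgs, g):
            displays = []
            for x, y in imgs:
                out = g(x, training=False)  # translated image
                # Drop the batch dimension, then stack input / output / truth vertically
                display = tf.concat([tf.squeeze(x, axis=0),
                                     tf.squeeze(out, axis=0),
                                     tf.squeeze(y, axis=0)], axis=0)
                displays.append(display)
            return displays

        train_displays = generate_imgs(self.train_imgs, self.g)
        val_displays = generate_imgs(self.val_imgs, self.g)

        with self.writer.as_default():
            with tf.name_scope("train"):
                tf.summary.image("train", tf.stack(train_displays), step=epoch,
                                 max_outputs=len(train_displays))
            with tf.name_scope("val"):
                tf.summary.image("val", tf.stack(val_displays), step=epoch,
                                 max_outputs=len(val_displays))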

Next, to launch training using the custom callback:

Training Pix2Pix model using custom Callback. Code by the author
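And a sketch of the launch itself. Here pix2pix stands for the compiled keras.Model and train_ds / val_ds for the batched datasets, all defined elsewhere in the project; the variable names are my own.

# Take 4 sample images from each set, one image per batch
train_sample = train_ds.take(4)
val_sample = val_ds.take(4)

# Initialize the custom callback
tb_images = TensorboardImageCallback(train_sample, val_sample,
                                     g=pix2pix.g, logdir="logs/")

pix2pix.fit(train_ds,
            epochs=5,  # intentionally low for this demo (see Note 1)
            callbacks=[tb_images])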

Code explanation

For those who don’t just copy and paste the code, let me explain what happens when you run it.

1st. We need to create a keras.callbacks.Callback subclass in which to define our custom callback.
Callbacks are functions that are executed at a specified frequency during training [see doc]. To tell TensorFlow to use callbacks during training, we simply pass them as an argument in the form of a list object (see the 5th step)

Screenshot taken from the code. Image by the author

In our case, the callback class receives as arguments a batch of images from each of the training and validation sets, the model’s generator (so we can generate images), and the path (logdir) where the images will be stored for later display in TensorBoard.

We also define the instance variable self.writer, which creates a summary file writer for the given log directory (where we are going to store the information for later display in TensorBoard).
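Concretely, that is a single call to the TensorFlow 2 summary API:

# Inside __init__: one writer per log directory
self.writer = tf.summary.create_file_writer(logdir)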

2nd. We define the class method on_epoch_end which, as its name indicates, will be executed after every epoch.

Note 2. The arguments this method receives are fixed by the Keras API, so don’t try to change them.

Within the on_epoch_end method, we also define a function to generate the images (this way the code looks cleaner and better organized).

Screenshot taken from the code. Image by the author

The function generate_imgs takes a set of images (a TakeDataset element from tf.data) and the generator g, and returns a list of display images, each built by vertically concatenating the input image x, the image translated by the model out, and the ground truth y. If we didn’t concatenate the images, they would be displayed on separate cards in TensorBoard.

Note 3. Before concatenating, we have to remove the batch dimension using tf.squeeze(); otherwise the shapes won’t match and TensorFlow raises an exception.
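For example, with a batch size of 1 and 256x256 RGB images (illustrative sizes):

x.shape                      # (1, 256, 256, 3): batched input
x_s = tf.squeeze(x, axis=0)  # (256, 256, 3): batch dimension removed
# After squeezing out and y as well, vertical concatenation along axis 0 gives
tf.concat([x_s, out_s, y_s], axis=0)  # shape (768, 256, 3)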

3rd. We save the images using tf.summary.image().

Screenshot taken from the code. Image by the author

The first line, with self.writer.as_default(), tells TensorFlow to record the following operations with the same writer (self.writer), so all the images generated by the callback after each epoch are written to the same summary file [check doc, link].
Next, the tf.name_scope() function adds the prefix “train/” or “val/” to the name of every image, so the training and validation images appear as different sections in TensorBoard (both sets of images still belong to the same summary file; the scope only changes their tag names).

Note 4. In this case, I’ve used the scope name again as the image name, so the tags end up as “train/train” or “val/val”; for future projects, I recommend choosing a distinct name.
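For example, a distinct summary name such as "examples" (a hypothetical choice) yields cleaner tags:

with self.writer.as_default():
    with tf.name_scope("train"):
        # The tag becomes "train/examples" instead of "train/train"
        tf.summary.image("examples", tf.stack(train_displays), step=epoch,
                         max_outputs=len(train_displays))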

4th. Initialize the class.
Back in our train.py file, where we created our model, we initialize the class by calling it (line 9). For this project, I decided to take 4 images from the training set and another 4 from the validation set, each from a different batch. Then I specify the generator (pix2pix is a keras.Model, so I can access its generator as an attribute) and the logdir where the summary will be saved.

Screenshot taken from the code. Image by the author

Note 5. You might think that, because we pass the generator as an argument, the callback would hold a stale copy whose weights never get updated during training, making the logged results incorrect. However, Python passes object references: pix2pix.g and the callback’s self.g point to the same generator object, so when training updates the generator’s weights, the callback sees the updated weights too.
I’ve personally checked this (and you can do it too if you don’t believe me) by adding the following line inside the custom callback:

print(self.g.get_weights()[0][0][0][0])

The result is that the first weight of the first layer is printed at every epoch, so you can watch it change.
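The print can go anywhere inside on_epoch_end, for example at the top:

def on_epoch_end(self, epoch, logs=None):
    # First weight of the generator's first layer; it should change every epoch
    print(self.g.get_weights()[0][0][0][0])
    ...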

5th. Train the model using the custom callback.

Screenshot taken from the code. Image by the author

Conclusions

Being able to visualize the evolution of your model during the training phase is crucial for good development and provides key insights into which direction to take to improve the accuracy and performance of the model.

However, it’s also important to take into account the cost of evaluating custom metrics, since time is an important factor when training models. In this example project, unpacking the images and running the generator can add a few extra minutes per epoch, depending on your hardware and the image size, considerably slowing down training.

A good solution to this problem is to introduce a frequency variable that is compared with the epoch number inside the custom callback class, so that the generator is only tested and the results saved when epoch % frequency == 0, as sketched below.
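A sketch of that change, with freq as a new constructor argument (the name is my choice):

def __init__(self, train_imgs, val_imgs, g, logdir, freq=5):
    ...
    self.freq = freq  # save images only every `freq` epochs

def on_epoch_end(self, epoch, logs=None):
    if epoch % self.freq != 0:
        return  # skip the expensive image generation this epoch
    ...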

I hope this post has been useful. For more articles on artificial intelligence and data science, you can check out my other posts.
