pix2pix GAN in TensorFlow 2.0

(Find the code to follow this post here.)

The original code can be found here on Google Colab, but I took this as a good opportunity to dig deeper into TensorFlow 2.0 and the updated scope of its high-level API for fast prototyping.

For a detailed discussion of the newer features, and how to subclass the Model class, look here.

In this article, I will assume a basic understanding of the methods used to train neural networks (NNs), and of how Generative Adversarial Networks (GANs) work in practice. For a more in-depth introduction to conditional GANs, have a look at my previous article here, or for super-resolution GANs here.

There are three topics to address here: first, a general overview of what image-to-image translation is and how pix2pix fits into that landscape; second, the maths behind how the loss function is defined and optimised, and how that feeds into decisions about the network architecture; and third, the results from training on two datasets, the segmented facades dataset and the satellite / Google Maps dataset, where I’ll also briefly discuss the code and training.

Image processing is hard. The data is complicated and highly structured, and as humans we intuitively know what a good image looks like, even if we can’t easily quantify what that means. Yet at the same time, we’re continuously pushing the boundaries of what we can automatically process with machine learning algorithms, from creating super-resolution images at many times the original definition, to dynamic image processing that gets performance out of phone cameras that the hardware couldn’t deliver alone. But generally all these problems are solved individually; they’re impressive, but we’re a long way from a general solution to image processing using machine learning.

Generally, image-to-image translation has needed specific algorithms and structure for each case, despite the fact that fundamentally all these cases are the same: they all involve taking one array of pixels and updating the values to map onto another space. So why the lack of generalisability, and why does each case need hand-holding through the process?

The team at Berkeley set out to design a general solution to the image-to-image translation problem, one where the underlying method could be used on almost arbitrary data to perform generic translation across a wide range of image domains with no change to the underlying algorithm. They published in the 2018 paper here, along with the pix2pix software.

But first, what are we talking about with image-to-image translation? Image processing generally means taking an image, applying some processing, and outputting a new image. It sounds trivial to say, but that second image is an updated version of the first. This is exactly what vanilla GAN methods lack, and even basic conditional GANs, where the input is by definition a random noise vector.

The idea is that by learning a mapping between one set of images and another, the network can apply the same mapping to new data. This fits more generally in the context of style transfer. I’m going to make a fairly sweeping statement that images can be loosely split into two parts. Firstly, there’s the content of the image. The content means what’s in the image – the objects, the landscape, the spatial relation of all parts in the image. The other part can be thought of as the style of the image. The style includes things like the colour, the shadow, and the way each object is expressed.

We can liken this to the difference between a photograph and an artist’s impression, or even two artistic representations of the same scene. The underlying information, the content, is the same, while their styles may be wildly different. The trick is to find a way to translate from one to the other.

The hint is in the terminology: rather than focusing on images, think about language, in the same way a translator learns a mapping between English and French. Given a sentence in one language they can translate it into the other, while retaining as much of the original information as possible (content), and still finding a natural way to express the sentiment in each language (style). That’s the aim with pix2pix.

The interplay of loss function and network architecture is subtle, and as we will see there are often multiple ways to solve a problem, but first we need to define a conditional loss for a generative adversarial network. The key principle here is the same as for any GAN: the adversarial min-max game between the two networks to optimise the loss function:

$$\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big] + \mathbb{E}_{x,z}\big[\log\big(1 - D(x, G(x, z))\big)\big]$$

where G is the generator and D the discriminator, x the conditional input image, y the true image and z a random noise vector.
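
To make the two sides of the min-max game concrete, here is a minimal NumPy sketch of how the objective could be estimated from discriminator outputs. The function name `cgan_objective` and the example probabilities are made up for illustration; this is not the code from my repository.

```python
import numpy as np

def cgan_objective(d_real, d_fake, eps=1e-12):
    """Estimate E[log D(x, y)] + E[log(1 - D(x, G(x, z)))] from samples.

    d_real: discriminator probabilities for (input, true image) pairs.
    d_fake: discriminator probabilities for (input, generated image) pairs.
    """
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1.0 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1.0 - eps)
    return float(np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake)))

# A confident, correct discriminator pushes the objective up towards 0.
strong_d = cgan_objective(d_real=[0.95, 0.90], d_fake=[0.05, 0.10])

# A generator that fools the discriminator drives the objective down.
fooled_d = cgan_objective(d_real=[0.55, 0.50], d_fake=[0.60, 0.70])
```

The discriminator is trying to raise this number and the generator is trying to lower it, which is why `strong_d` comes out larger than `fooled_d`.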

_(A more detailed discussion here, but the basic idea is G outputs an image such that D(G) is maximised, while D is optimised to output 1 for true images. Why this ever converges, and to what extent we’re barking up the wrong tree, is another interesting discussion here.)_

This already has one key difference from the vanilla GAN case: it’s conditional. Conditional here means that rather than receiving just a random noise vector, the generator takes in additional information. This could simply be class information, as in the cDCGAN, or, as in this case, the original image. Where the vanilla GAN maps G: z -> y, the conditional GAN maps G: {x, z} -> y.

Interestingly, this isn’t actually the full picture. When the network trains, it generally learns to ignore the random noise vector, so to keep the network non-deterministic dropout was used to reintroduce the stochastic behaviour.
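
As a toy illustration of dropout acting as the noise source, the NumPy sketch below runs the same input through a dropout step several times. The `dropout` helper, the feature vector and the weights are stand-ins I have invented for the example, not layers from the real generator.

```python
import numpy as np

def dropout(activations, rate, rng):
    """Inverted dropout: zero each unit with probability `rate`, rescale the rest."""
    keep_mask = rng.random(activations.shape) >= rate
    return activations * keep_mask / (1.0 - rate)

rng = np.random.default_rng(seed=0)
features = np.ones(8)            # stand-in for one decoder feature vector
weights = np.arange(1.0, 9.0)    # stand-in for the next layer: [1., 2., ..., 8.]

# Same input, dropout left switched on: each pass samples a new mask.
stochastic_outputs = [float(dropout(features, 0.5, rng) @ weights) for _ in range(5)]

# Same input, dropout switched off: the output is fully deterministic.
deterministic_output = float(features @ weights)
```

With dropout switched off, the generator collapses back to a deterministic function of the input image, which is the behaviour the paragraph above describes.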

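In many GANs, an L2 regularisation loss is added to the optimisation on the basis that by minimising the Euclidean distance between the target and generated image (MSE), the generator will learn the structure and colour of the image. However, it was found that this generally leads to blurred images, so to combat this an L1 regularisation loss is added with some pre-factor weight instead, as:

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z) \rVert_1\big]$$

Which gives the total loss function to optimise as:

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)$$

where $\mathcal{L}_{cGAN}$ is the conditional adversarial loss described above and $\lambda$ is the pre-factor weight.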
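
A rough NumPy sketch of the generator’s side of this total loss follows. The names are hypothetical, the L1 term is taken as a mean over pixels rather than a sum, and the weight of 100 is an assumed value for the pre-factor that should be treated as a tunable hyperparameter.

```python
import numpy as np

LAMBDA = 100.0  # assumed pre-factor weight on the L1 term (a tunable hyperparameter)

def l1_loss(target, generated):
    """Mean absolute error between the target y and the generated image G(x, z)."""
    return float(np.mean(np.abs(target - generated)))

def generator_objective(d_fake, target, generated, lam=LAMBDA, eps=1e-12):
    """Generator's side of the total loss: adversarial term plus lam * L1 term."""
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1.0 - eps)
    adversarial = float(np.mean(np.log(1.0 - d_fake)))  # the generator wants this small
    return adversarial + lam * l1_loss(target, generated)

target = np.zeros((4, 4, 3))
close = target + 0.01    # nearly matches the target
far = target + 0.50      # far from the target
d_fake = [0.5, 0.5]      # identical discriminator response in both cases

loss_close = generator_objective(d_fake, target, close)
loss_far = generator_objective(d_fake, target, far)
```

With the discriminator response held fixed, the only difference between `loss_close` and `loss_far` is the weighted L1 term, which is what pulls the generated image towards the target’s overall structure and colour.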

The next key question is about the structure of the networks.

In this analysis, the generator is based on a U-Net structure, a variation on an auto-encoder, while the discriminator is called patch-GAN.

Generator Architecture

The U-Net builds on an auto-encoder, where the network forces all information to pass through a tight bottleneck in the middle of the network. This forces the network to find a latent representation of the input image from which the image can be reconstructed. The idea is that only a finite amount of information can pass through, so the network is forced to learn an optimal reduced mapping and cannot simply memorise the training set.

This has one obvious limitation: a significant proportion of the output and input images ought to share the same description, since the content of the image is supposed to remain unchanged. The variation in U-Net over an auto-encoder is the addition of a skip connection between each pair of symmetric layers, as can be seen in the figure above. These concatenated layers pass information directly across the network at the appropriate scale, which reduces the amount of information that needs to pass through the latent bottleneck layer. The idea is that detail shared between input and output, such as the position of edges, can take the skip connections rather than being squeezed through the bottleneck along with the large-scale structure.
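
The shape bookkeeping behind the skip connections can be sketched in plain Python. This is only an illustration under my own assumptions: every layer changes the resolution by a factor of two, each decoder layer outputs as many filters as its mirror-image encoder layer, and the filter counts in the example are hypothetical.

```python
def unet_shapes(input_size, encoder_filters):
    """Track (spatial size, channels) through a U-Net built from stride-2 layers.

    Each encoder layer halves the spatial size. Each decoder layer doubles it,
    then concatenates the encoder output that has the same spatial size.
    """
    size = input_size
    encoder = []
    for filters in encoder_filters:
        size //= 2
        encoder.append((size, filters))

    bottleneck = encoder[-1]
    decoder = []
    # Walk back up, pairing each decoder layer with its mirror-image encoder layer.
    for skip_size, skip_channels in reversed(encoder[:-1]):
        size *= 2
        assert size == skip_size, "skip connections must join equal resolutions"
        # Upsampled features (skip_channels) concatenated with the skip (skip_channels).
        decoder.append((size, 2 * skip_channels))
    return encoder, bottleneck, decoder

encoder, bottleneck, decoder = unet_shapes(256, [64, 128, 256, 512])
# encoder    -> [(128, 64), (64, 128), (32, 256), (16, 512)]
# bottleneck -> (16, 512)
# decoder    -> [(32, 512), (64, 256), (128, 128)]
```

The channel count doubles at every decoder level because of the concatenation, so each decoder layer sees both what came up from the bottleneck and what the encoder saw at the same scale.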

The discriminator is less typical and might need more context. In general terms, the L1 or L2 regularisation is a weak constraint on the network: it doesn’t produce sharp details, because there are many ways to reach a small L1 or L2 value. That doesn’t write off this part of the loss function, though, as it encourages the generator to get the high-level structure right, which is exploited in the choice of discriminator.

Introducing crisp details to the generated image can be done via a number of paths:

  1. Tuning the weight of the lambda pre-factor on the L1/L2 loss – as discussed above, this caps out with relatively blurred images, generally correct but without sharp details.
  2. Adding an additional loss that quantifies the performance of the output image – a number of methods using a pre-trained network to assess the quality of the output images have been tried. Specifically, in the case of SRGAN, the distance between the VGG network’s feature representations of the target and output images is minimised.
  3. Updating the discriminator to promote the crisp details – patch-GAN – seems almost too obvious, doesn’t it?
Patch-GAN Discriminator [here]

The way patch-GAN works is that rather than classifying the entire image at once, the discriminator classifies only an N×N portion of the image as real or fake. The ‘patch’ scans over the image, a prediction is made at each position, and the mean of those predictions is optimised.
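
Averaging over patches can be sketched as below. The 3×3 grid of probabilities and the `patch_bce` helper are hypothetical, and I have used binary cross-entropy as the per-patch loss.

```python
import numpy as np

def patch_bce(patch_probs, is_real, eps=1e-12):
    """Binary cross-entropy averaged over every patch prediction in the grid."""
    p = np.clip(np.asarray(patch_probs, dtype=float), eps, 1.0 - eps)
    per_patch = -np.log(p) if is_real else -np.log(1.0 - p)
    return float(per_patch.mean())

# Hypothetical grid of discriminator outputs, one probability per patch position.
grid = np.array([[0.9, 0.9, 0.9],
                 [0.9, 0.2, 0.9],   # one patch looks fake to the discriminator
                 [0.9, 0.9, 0.9]])

loss_with_bad_patch = patch_bce(grid, is_real=True)
loss_all_good = patch_bce(np.full((3, 3), 0.9), is_real=True)
```

A single unconvincing patch raises the averaged loss, so local artefacts are penalised even when the rest of the image looks plausible.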

This assumes the image can be treated as a Markov random field, where pixels separated by N or more pixels are independent. This wouldn’t be true for the entire image when considering high level structure, but is a fair assumption when looking at low level detail.
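
The patch size N is set by the receptive field of the discriminator’s convolutional stack, which can be worked out with a few lines of arithmetic. The layer configuration below (4×4 kernels, three stride-2 layers followed by two stride-1 layers) is an assumption on my part, chosen as a typical example rather than read out of my training code.

```python
def receptive_field(layers):
    """Number of input pixels (per side) seen by one output unit.

    layers: list of (kernel_size, stride) tuples, ordered from input to output.
    """
    field = 1
    # Work backwards from a single output unit to the input image.
    for kernel, stride in reversed(layers):
        field = (field - 1) * stride + kernel
    return field

# Assumed stack: 4x4 kernels, three stride-2 layers then two stride-1 layers.
patch_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
n = receptive_field(patch_layers)  # 70, i.e. each output judges a 70x70 patch

# A shallower stack sees a smaller patch.
n_small = receptive_field([(4, 2), (4, 1), (4, 1)])  # 16
```

Pixels further apart than `n` never influence the same discriminator output, which is the independence assumption in the Markov random field picture.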

The advantage here is that fine detail is preserved with the patch-GAN: realistic low-level details are needed to pass the discriminator test, while the general structure is preserved by the regularisation loss. Another vital advantage of the patch method over a whole-image classifier is that it’s fast, as a discriminator that only looks at small N×N patches has fewer parameters and is quicker to evaluate and train than one covering the whole image. It also scales efficiently to arbitrarily sized images.

There are two technical points on the training of the network. Firstly, rather than training the generator to minimise log(1 – D(x, G(x, z))), it is trained to maximise log D(x, G(x, z)). Secondly, the discriminator objective is halved during training, which slows the rate at which the discriminator learns relative to the generator.
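
Both points can be sketched in NumPy. The function names are hypothetical, and the gradient helper assumes the discriminator ends in a sigmoid, so that its output is D = sigmoid(a) for some logit a.

```python
import numpy as np

def generator_adv_loss(d_fake, eps=1e-12):
    """Minimise -log D(x, G(x, z)), which is the same as maximising log D."""
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1.0 - eps)
    return float(-np.mean(np.log(d_fake)))

def discriminator_loss(d_real, d_fake, eps=1e-12):
    """Real/fake cross-entropy, halved to slow the discriminator relative to G."""
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1.0 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1.0 - eps)
    real_term = -np.mean(np.log(d_real))
    fake_term = -np.mean(np.log(1.0 - d_fake))
    return float(0.5 * (real_term + fake_term))

def grads_wrt_logit(d_fake):
    """Gradients of the two generator objectives with respect to the logit a."""
    saturating = -d_fake               # d/da log(1 - sigmoid(a))
    non_saturating = -(1.0 - d_fake)   # d/da [-log sigmoid(a)]
    return saturating, non_saturating

# Early in training the discriminator easily rejects fakes, so D(G) is small.
sat, non_sat = grads_wrt_logit(0.01)   # -0.01 versus -0.99
```

When the discriminator is confidently rejecting the generator’s output, the original form gives the generator almost no gradient, while the log D form still gives a strong signal.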

Finally, we’ll look at some results from the experiments I’ve carried out replicating the results from the paper. As mentioned above, my code to reproduce the results can be found here, written using the TensorFlow 2.0 high-level API. Reproducing the facades results took 200 epochs of training on an Amazon AWS EC2 instance using a single GPU compute node, with each epoch taking roughly 4 minutes. During training, snapshot images were taken every 10 epochs, and the loss for both the generator and discriminator was recorded throughout.

As you can see in these results, the structure of the building has been found, and the style is consistent across each mapped segment.

Some detailing has obviously been lost in the segmentation, as would be expected, particularly around the edges of the facades and in the fine structure within each feature, such as the architraves around the windows and the column capitals.

Clearly the fine details are not recovered as well as in the paper. Additionally, the fuzziness around the edges of the building is not very visually appealing – this is the area showing the most variation in the training GIFs, and it would likely stabilise with more epochs.

We might expect slightly better results with a longer training duration, as clearly the loss function hasn’t fully saturated.

The maps data has a slightly different issue. While the general road layout and features are recovered fairly consistently in these examples, the generated results seem less stable than the facade outputs, and errors affecting the entire image occur more often. The images also come out fairly dark, perhaps because the contrast between roads and buildings or grass is lower than the contrast between elements on a facade.

The takeaway is that while the pix2pix method clearly has impressive generative power and generalises well, it’s still relatively expensive to train, and more work is needed over the coming years to recover the finest details consistently.
