The aim of this article is to illustrate the structure and training method of Generative Adversarial Networks (GANs), highlighting the key ideas behind them and elucidating the topic with a working example: using TensorFlow and Keras to train a GAN on the MNIST dataset to produce handwritten digits. We will then explore methods of adding levels of control over the output of the generator.
When I recently began looking for tutorials on training GANs, I found many overly complicated scripts that offered little intuition regarding what was going on behind the scenes during training. It is my hope that the reader will have a clear understanding of why the code provided here works, and appreciate the simplicity of it. Familiarity with the Keras functional API as well as some general knowledge of Deep Learning will be useful.
GAN Overview
First, let us discuss what we would like to accomplish in this tutorial. A GAN comprises two separate neural networks, a Generator and a Discriminator. Both of these networks could have any structure you care to imagine; it depends on the task at hand.
The objective of the generator is to produce objects that look as though they belong to a given dataset, given some random input – canonically 100 numbers drawn from the standard normal distribution. In this tutorial, we will be considering the MNIST dataset. This dataset consists of 28 x 28 arrays (making it a 784-dimensional dataset), and each array represents an image of a handwritten digit, 0–9. Each class of digit can be thought of as a manifold in 784-dimensional space, with an image lying on a manifold if and only if it is recognisable to a human as the corresponding digit. A good generator must therefore be good at mapping random inputs onto these manifolds, so that it only generates images that look as if they belong to the true dataset.
The second network, the discriminator, has the opposite objective. It must learn to discriminate between real examples from the dataset and the ‘fakes’ created by the generator. The combined structure is as follows:

The ‘Adversarial’ part of the name refers to the method of training GANs. During training they compete (as adversaries), each trying to beat the other. The networks are trained alternately on each batch of data. First the discriminator will be trained to classify real data as ‘real’, and the fake images created by the generator as ‘fake’. Next, the generator will be trained to produce output that the discriminator classifies as ‘real’. This process is repeated for each batch in the data. It really is a simple idea, but it is a powerful one.
It is interesting to note that the generator will never see any real data – it will simply learn how to fool the discriminator by learning from the gradients propagated through the discriminator via the backpropagation algorithm. For this reason GANs are particularly susceptible to the vanishing gradients problem. After all, if the gradients vanish before reaching the generator there is no way for it to learn! This is particularly important to consider when using very deep GANs, but it should not be a worry for us here as the networks we use are relatively small.
Another common issue is that of ‘Mode Collapse’. In a way, this is where the generator starts thinking outside the box in order to be lazy. The generator can simply learn to generate the exact same output regardless of the input! If this output is convincing to the discriminator, then the generator has completed its task despite being totally useless to us.
We will consider overcoming the issue of mode collapse in this tutorial, but first we shall consider a very simple model in order to focus on understanding the training algorithm. Without further ado, let’s get stuck in.
Building the Models with Keras
There are a few parameters that are used throughout the code; they are presented here for clarity.
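As a concrete reference, these are the values used throughout this tutorial (the constant names themselves are my own choice):

```python
CODINGS_SIZE = 100   # length of the random input vector fed to the generator
BATCH_SIZE = 128     # images per training batch
BUFFER_SIZE = 60000  # shuffle buffer: the full size of the MNIST training set
EPOCHS = 100         # number of passes over the dataset
```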
We will be using the MNIST dataset provided by the tensorflow_datasets module. The data must be loaded, and the pixel values scaled to range between -1 and 1 – the range of the tanh activation function which will be used in the last layer of the generator. The data is then shuffled and batched according to the BATCH_SIZE and BUFFER_SIZE parameters. We use a buffer size of 60,000 (the length of the dataset) so that the data is fully shuffled, and a batch size of 128. The prefetch(1) call means that while one batch is being used to train the network, the next batch is being loaded into memory, which can help to prevent bottlenecking. This may not be a problem for MNIST, as each image contains relatively little data, but it can make a difference for high-resolution images.
The simplicity of MNIST will allow us to get to grips with GANs using a simple dense structure. We will use a generator and discriminator with 3 dense hidden layers each. Since we need to generate 28 x 28 images, the final layer of the generator will have 784 units, which can then be reshaped into the desired format. The other parameters can be played with. We will use the standard tanh activation in the last layer, with the other layers having ReLU activation for simplicity.
The discriminator is the mirror image of the generator. This makes sense intuitively, as it is trying to undo what the generator has done. It will predict 1 for ‘real’ and 0 for ‘fake’, so the sigmoid activation is used in the final layer. We use ReLU in the other layers again.
Now these two networks can be combined into a GAN as if they are just layers. It’s that simple!
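Putting the last few paragraphs together, here is a minimal sketch of the two networks and their combination. The hidden-layer widths (128/256/512) are one reasonable choice; as noted above, they can be played with:

```python
import tensorflow as tf
from tensorflow import keras

CODINGS_SIZE = 100

# Generator: three dense hidden layers with ReLU, then 784 tanh units
# reshaped into a 28 x 28 image.
generator = keras.models.Sequential([
    keras.layers.Dense(128, activation="relu", input_shape=[CODINGS_SIZE]),
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(512, activation="relu"),
    keras.layers.Dense(784, activation="tanh"),
    keras.layers.Reshape([28, 28, 1]),
])

# Discriminator: the mirror image, ending in a single sigmoid unit
# (1 = 'real', 0 = 'fake').
discriminator = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28, 1]),
    keras.layers.Dense(512, activation="relu"),
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

# The combined GAN: noise in, real/fake score out.
gan = keras.models.Sequential([generator, discriminator])
```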
Training the GAN with Keras
As outlined previously, GANs can’t be trained using the model.fit() method that is used for simpler deep learning models in Keras. This is because we have two different networks that must be trained concurrently, but with opposite objectives. Therefore we must create our own training loop to iterate over the batched data and train each model separately. There are a few subtleties in the following code, so I would encourage the reader to take in the comments carefully, but the key points are these:
- (10–12) The discriminator and GAN must be compiled separately, and we make the discriminator untrainable when we compile the GAN. This allows us to train only the generator by calling gan.train_on_batch(), and only the discriminator when we call discriminator.train_on_batch(). This is a handy feature of Keras that is often overlooked, leading to less elegant code.
- (25–35) We train the discriminator to map ‘real’ images to 1 and ‘fake’ images to 0. Softening these values – that is, using random numbers close to 1 and 0 – is a standard trick that helps GANs learn.
- (40–47) We train the generator by training the GAN (with the discriminator’s weights untrainable) but with the labels reversed, that is, mapping the ‘fake’ images to 1. In this way, we ask the generator to learn how to trick the discriminator.
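The steps above can be sketched as follows. The rmsprop optimizer and the compile-then-freeze pattern follow the standard Keras recipe for GANs; the helper function names are mine:

```python
import tensorflow as tf
from tensorflow import keras

def compile_gan(gan):
    # Compile the discriminator while it is trainable, then freeze it and
    # compile the combined GAN, so gan.train_on_batch() updates only the
    # generator's weights.
    generator, discriminator = gan.layers
    discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
    discriminator.trainable = False
    gan.compile(loss="binary_crossentropy", optimizer="rmsprop")

def train_gan(gan, dataset, codings_size=100, epochs=100):
    generator, discriminator = gan.layers
    for epoch in range(epochs):
        for real_images in dataset:
            batch = tf.shape(real_images)[0]
            # Phase 1: train the discriminator on fakes (soft labels near 0)
            # and reals (soft labels near 1).
            noise = tf.random.normal([batch, codings_size])
            fake_images = generator(noise)
            x = tf.concat([fake_images, real_images], axis=0)
            y = tf.concat([tf.random.uniform([batch, 1], 0.0, 0.1),
                           tf.random.uniform([batch, 1], 0.9, 1.0)], axis=0)
            discriminator.trainable = True
            discriminator.train_on_batch(x, y)
            # Phase 2: train the generator through the frozen discriminator,
            # with the labels reversed: fakes are marked as 'real'.
            noise = tf.random.normal([batch, codings_size])
            discriminator.trainable = False
            gan.train_on_batch(noise, tf.ones([batch, 1]))
```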
We observe the following convergence through 100 epochs of training. Although we see different digits appearing throughout training, it seems to have mostly settled on drawing 1s, which could be an example of mode collapse. This is where the generator learns to generate just one class from the dataset. If left to train for longer, the discriminator would likely learn to classify that class as fake, at which point the generator would learn to generate another class, with this cycle continuing. Since we made no effort to prevent such behaviour, I believe we can count this result as a success: we clearly have a working training method for this class of neural network.

Improving the Performance of our Model – Adding Functionality
While this first attempt has been successful to a certain degree, we can certainly do much better. After all, we have only used half of the dataset! We have neglected to make use of the labels that are available to us. With just a few tweaks to the above code, we can create a GAN which uses the labelled data, so that the generator will produce a digit from a given class chosen by us. As fun as generating random digits is, having control over the generated images would add a beautiful layer of sophistication to our model. This will also help the GAN to learn faster, and help to prevent mode collapse. The new structure of the model will be as follows:

The generator must take random noise as input as before, but now also a randomly selected class that it should generate. This means that, once trained, the generator should produce a digit from any class we choose. The discriminator will also take two inputs: the first being images, and the second being the corresponding class labels. This will allow the discriminator to decide whether an image is real or fake based not just on how convincing the image is as a digit, but also on whether the digit belongs to the given class, which will force the generator to draw images from the correct class in order to perform well. We need only make some very simple tweaks to our code.
First, we will need a variable telling us how many features to expect for our data. Since we will be one-hot encoding the labels, each label will be a vector of length 10. We must also create the labelled dataset.
The following is where we see the advantage of using the Keras functional API compared to the sequential API. It is a triviality to add multiple inputs and concatenate them in this framework as follows:
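Here is a sketch of all three conditional models. The layer widths, and the exact point at which the label is concatenated into the discriminator, are my own choices:

```python
import tensorflow as tf
from tensorflow import keras

CODINGS_SIZE = 100
NUM_FEATURES = 10

# Conditional generator: concatenate the random codings with a one-hot label.
noise_input = keras.layers.Input(shape=[CODINGS_SIZE])
class_input = keras.layers.Input(shape=[NUM_FEATURES])
g = keras.layers.Concatenate()([noise_input, class_input])
g = keras.layers.Dense(128, activation="relu")(g)
g = keras.layers.Dense(256, activation="relu")(g)
g = keras.layers.Dense(512, activation="relu")(g)
g = keras.layers.Dense(784, activation="tanh")(g)
g_out = keras.layers.Reshape([28, 28, 1])(g)
generator = keras.Model([noise_input, class_input], g_out)

# Conditional discriminator: the label joins the network partway through,
# with one extra hidden layer after the concatenation so that the label
# information can be processed before the output neuron.
image_input = keras.layers.Input(shape=[28, 28, 1])
label_input = keras.layers.Input(shape=[NUM_FEATURES])
d = keras.layers.Flatten()(image_input)
d = keras.layers.Dense(512, activation="relu")(d)
d = keras.layers.Dense(256, activation="relu")(d)
d = keras.layers.Dense(128, activation="relu")(d)
d = keras.layers.Concatenate()([d, label_input])
d = keras.layers.Dense(64, activation="relu")(d)  # the extra hidden layer
d_out = keras.layers.Dense(1, activation="sigmoid")(d)
discriminator = keras.Model([image_input, label_input], d_out)

# The combined conditional GAN: the same label feeds both networks.
gan_noise = keras.layers.Input(shape=[CODINGS_SIZE])
gan_class = keras.layers.Input(shape=[NUM_FEATURES])
gan_out = discriminator([generator([gan_noise, gan_class]), gan_class])
gan = keras.Model([gan_noise, gan_class], gan_out)
```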
The discriminator is tweaked similarly, with another hidden layer added after the class labels enter the model. The extra layer is necessary because, if the concatenated layer were connected directly to the output neuron, the discriminator would lack the flexibility to process the information given by the class label.
Creating the GAN follows intuitively as before.
Finally, training is precisely the same idea as before, but with the additional inputs generated and added in. Because of the way we generated the dataset, the code looks much the same.
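A sketch of that conditional loop, assuming the models have been compiled exactly as before (discriminator first, then frozen inside the combined GAN); the function name and argument order are mine:

```python
import tensorflow as tf

def train_conditional_gan(generator, discriminator, gan, dataset,
                          codings_size=100, n_classes=10, epochs=100):
    for epoch in range(epochs):
        for real_images, real_labels in dataset:
            batch = tf.shape(real_images)[0]
            # Phase 1: the discriminator now sees (image, label) pairs.
            # Fakes get randomly drawn classes and soft targets near 0;
            # real pairs get soft targets near 1.
            noise = tf.random.normal([batch, codings_size])
            fake_classes = tf.one_hot(
                tf.random.uniform([batch], 0, n_classes, dtype=tf.int32),
                n_classes)
            fake_images = generator([noise, fake_classes])
            x_images = tf.concat([fake_images, real_images], axis=0)
            x_labels = tf.concat([fake_classes, real_labels], axis=0)
            y = tf.concat([tf.random.uniform([batch, 1], 0.0, 0.1),
                           tf.random.uniform([batch, 1], 0.9, 1.0)], axis=0)
            discriminator.trainable = True
            discriminator.train_on_batch([x_images, x_labels], y)
            # Phase 2: reversed labels again; the generator is rewarded when
            # its (image, class) pair is scored as real.
            noise = tf.random.normal([batch, codings_size])
            discriminator.trainable = False
            gan.train_on_batch([noise, fake_classes], tf.ones([batch, 1]))
```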
This model was then trained for 100 epochs as before, with the progress shown here. The generator should be producing a full sequence of digits, 0–9. It seems we have yet more success! With the exception of the digit 9, we have a convincing sequence of digits by the end of the training.

Exploring Feature Mapping
We are also able to observe a fascinating characteristic of GANs. They actually learn to encode certain random inputs as features of the data they generate. This is best seen through an example:

The images in each column are generated from the same random input, but with different class labels. It is clear that encoded within the random input are features of the resulting digits. For example, in the fifth column all strokes are quite thin, whereas all digits in column 8 have thick line strokes. Similarly all digits in column 2 appear to slant to the right, while those in column 3 slant slightly to the left. While these encodings would not be clear to any human trying to interpret them, it is possible to control the style of the generated digits to a certain extent by studying the behaviour of the generator on a sample of random inputs.
Averaging Features
In order to demonstrate this, I have inspected the output of the generator on a large number of inputs, and sorted the inputs based on the type of output the generator produced. The categories I considered were ‘thick stroke, straight’, ‘thick stroke, slanted right’, ‘thin stroke, straight’ and ‘thin stroke, slanted right’. Each category was then averaged, so that I had four input vectors, one for each category. When these inputs are passed through the generator, they produce new images from that category, as shown below. Here I’m showing 1s, 4s, and 8s as these digits show the most exaggerated slanting.
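The averaging step itself is simple. Here `category_inputs` stands in for my hand-sorted collections of latent vectors; the values below are random placeholders, purely to illustrate the shapes involved:

```python
import numpy as np

rng = np.random.default_rng(42)
# Placeholder latent vectors, grouped by the style of digit they produced.
category_inputs = {
    "thick_straight": rng.standard_normal((25, 100)),
    "thick_slanted":  rng.standard_normal((30, 100)),
    "thin_straight":  rng.standard_normal((20, 100)),
    "thin_slanted":   rng.standard_normal((25, 100)),
}

# One representative input per category: the mean of its members.
category_vectors = {name: vecs.mean(axis=0)
                    for name, vecs in category_inputs.items()}

# Each averaged vector can then be passed through the conditional generator,
# e.g. generator.predict([vector[None, :], one_hot_label]).
```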

Scaling Features
We can actually do even better than that! If we multiply these inputs by a number larger than 1, we can exaggerate the features of that input, and if we use a number less than one, we can reduce that feature. In fact, it seems that if we multiply by a negative number we can generate digits with the reversed features! See the below gifs – they represent the output of our generator when the inputs for each category are multiplied by scalars ranging between 2 and -2. See how they start with exaggerated versions of the features above, and end as the reverse!
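In code, the scaling is nothing more than multiplying the averaged input by a scalar; here `base` is a random stand-in for one of the averaged category vectors:

```python
import numpy as np

base = np.random.default_rng(0).standard_normal(100)  # stand-in averaged input

# Sweep the scale from 2 down to -2: exaggerated features at one end,
# reversed features at the other.
scales = np.linspace(2.0, -2.0, 9)
scaled_inputs = scales[:, None] * base  # shape (9, 100), one input per scale

# images = generator.predict([scaled_inputs, np.tile(one_hot_label, (9, 1))])
```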

What we are seeing here is a common feature in GANs. While the inputs may be random, the generator’s interpretation of them is anything but. It has learned to use certain vectors in the latent space to represent the features it draws.
Adding and Subtracting Features
You can even add and subtract meaningfully in this latent space. What might you expect to get if you added the input for thick, straight digits to the input for thin, slanted digits?

You get an input for thick, slanted digits! And what about subtracting the input for thick, straight digits from the input for thick, slanted digits?

You get an input for thin, slanted digits! How cool is that? If one had the time and resources, the output of the GAN could be categorised more fully and one could gain absolute control of the output being generated using this method.
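The arithmetic here is plain vector addition and subtraction in the latent space; the vectors below are random stand-ins for the averaged category inputs found earlier:

```python
import numpy as np

rng = np.random.default_rng(1)
# Stand-ins for the averaged category inputs (hypothetical values).
thick_straight = rng.standard_normal(100)
thin_slanted = rng.standard_normal(100)
thick_slanted = rng.standard_normal(100)

# Adding the thick, straight input to the thin, slanted input gave an input
# for thick, slanted digits:
combined = thick_straight + thin_slanted
# Subtracting the thick, straight input from the thick, slanted input gave
# an input for thin, slanted digits:
difference = thick_slanted - thick_straight
# image = generator.predict([combined[None, :], one_hot_label])
```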
In future, I aim to apply the principles laid out in this article to train GANs on more complex datasets, and to explore the feature mappings I obtain. The full code that I have used is available in this GitHub repository: