VAE: Variational Autoencoders – How to Employ Neural Networks to Generate New Images

An overview of VAEs with a complete Python example that teaches you how to build one yourself

Variational Autoencoders (VAE). Image by author.

Intro

This article will take you through Variational Autoencoders (VAE), which fall into a broader group of Deep Generative Models alongside the famous GANs (Generative Adversarial Networks).

Unlike GAN, VAE uses an Autoencoder architecture instead of a pair of Generator-Discriminator networks. So, the ideas used in VAEs should be relatively straightforward to understand, especially if you have used Autoencoders in the past.

Contents

  • VAE’s place in the universe of Machine Learning algorithms
  • The structure of VAEs and an explanation of how they work
  • A complete Python example showing you how to build a VAE with Keras/Tensorflow

VAE’s place in the universe of Machine Learning algorithms

The chart below is my attempt to organise the most common Machine Learning algorithms. It is not an easy task, since we can categorise them along multiple dimensions, such as an algorithm’s underlying structure or the problems it is designed to solve.

I have tried to take both dimensions into account, which led me to placing Neural Networks into their own category. While we typically use Neural Networks in a Supervised manner, it is essential to acknowledge that some examples, such as Autoencoders, are more like Unsupervised/Self-Supervised algorithms.

Although Variational Autoencoders (VAE) have similar objectives to GANs, their architecture is closer to other types of Autoencoders, such as Undercomplete Autoencoders. Hence, you will find VAEs by clicking on the Autoencoders group in the interactive chart below.

The structure of VAEs and an explanation of how they work

Let’s start by analysing the architecture of a standard Undercomplete Autoencoder (AE) before diving into the elements that make VAEs different.

Undercomplete AE

Below is an illustration of a typical AE.

Undercomplete Autoencoder architecture. Image by author, created using AlexNail's NN-SVG tool.

The goal of an Undercomplete AE is to efficiently encode information from input data into a lower-dimensional latent space (bottleneck). We achieve this objective by ensuring that the inputs can be recreated with minimal loss using a decoder.

Note that during training, we pass the same set of data into input and output layers as we attempt to discover the parameter values for an "optimal" latent space.

Variational AE

Now let’s look at how VAE differs from an Undercomplete AE by analysing its architecture:

VAE architecture. Image by author.

We notice that the VAE’s latent space is not made up of point vectors (individual nodes). Instead, each input is mapped onto a Normal distribution described by a mean Zμ and a standard deviation Zσ, which the encoder learns to output during model training.

Meanwhile, the latent vector Z is sampled from the distribution with mean Zμ and standard deviation Zσ and passed to the decoder to obtain the predicted outputs.

It is crucial to understand that by design, the latent space of a VAE is continuous, which enables us to sample from any part of it to generate new outputs (e.g. new images), making VAE a generative model.

The need for regularisation

Encoding inputs into a distribution takes us only halfway to creating a latent space that is suitable for generating "meaningful" outputs.

However, we can achieve the desired regularity by adding a regularisation term expressed as the Kullback-Leibler divergence (KL divergence). We will talk more about it in the Python section later on.

Intuition about the latent space

We can use the following illustration to visualise how the information is spread within the latent space.

An intuitive way to think about regularised continuous latent space. Image by author.

As you can see, mapping data as individual points does not train the model to understand the similarities/differences between those points. Hence, we cannot use such a space to generate new "meaningful" data.

In the case of Variational Autoencoders, we have mapped data as distributions and regularised the latent space, which gives us the "gradient" or "smooth transition" between distributions. Hence, when we sample a point from such latent space, we generate new data closely resembling the training data.
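
To make the idea of a "smooth transition" a little more concrete, here is a minimal NumPy sketch that walks in a straight line between two points of a 2D latent space. The function name interpolate_latent and the two coordinates are made up for illustration and are not part of the model built later in this article.

```python
import numpy as np

def interpolate_latent(z_start, z_end, steps=5):
    """Return `steps` points evenly spaced on the line from z_start to z_end."""
    z_start = np.asarray(z_start, dtype="float32")
    z_end = np.asarray(z_end, dtype="float32")
    # Weights run from 0 (all z_start) to 1 (all z_end)
    weights = np.linspace(0.0, 1.0, steps, dtype="float32").reshape(-1, 1)
    return (1.0 - weights) * z_start + weights * z_end

# Two made-up points in a 2D latent space
path = interpolate_latent([-1.0, 0.5], [1.0, -0.5], steps=5)
print(path.shape)  # (5, 2)
```

In a regularised latent space, decoding each row of path in turn should produce outputs that change gradually from one end of the line to the other, rather than jumping between unrelated shapes.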

A complete Python example showing you how to build a VAE with Keras/Tensorflow

Finally, it’s time to build our own VAE!

Setup

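We’ll need the following data and libraries:

  • MNIST handwritten digit data (loaded through Keras)
  • Tensorflow/Keras for building Neural Networks
  • NumPy for data manipulation
  • Matplotlib, Graphviz and Plotly for visualisations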

Let’s import all the libraries:

# Tensorflow / Keras
from tensorflow import keras # for building Neural Networks
print('Tensorflow/Keras: %s' % keras.__version__) # print version
from tensorflow.keras.models import Model # for assembling a Neural Network model
from tensorflow.keras import Input # for instantiating a Keras tensor and specifying input dimensions
from tensorflow.keras.layers import Dense, Lambda # for adding layers to the Neural Network model
from tensorflow.keras.utils import plot_model # for plotting model diagram
from tensorflow.keras import backend as K # for access to Keras backend for reparameterization and creating custom loss function

# Data manipulation
import numpy as np # for data manipulation
print('numpy: %s' % np.__version__) # print version

# Visualization
import matplotlib 
import matplotlib.pyplot as plt # for plotting model loss
print('matplotlib: %s' % matplotlib.__version__) # print version
import graphviz # for showing model diagram
print('graphviz: %s' % graphviz.__version__) # print version
import plotly
import plotly.express as px # for data visualization
print('plotly: %s' % plotly.__version__) # print version

# Other utilities
import sys
import os

# Assign main directory to a variable
main_dir=os.path.dirname(sys.path[0])

The above code prints package versions used in this example:

Tensorflow/Keras: 2.7.0
numpy: 1.21.4
matplotlib: 3.5.1
graphviz: 0.19.1
plotly: 5.4.0

Next, we load MNIST handwritten digit data and display the first ten digits. Note that we will only use digit labels (y_train, y_test) in visualisations and not for model training.

# Load digits data 
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Print shapes
print("Shape of X_train: ", X_train.shape)
print("Shape of y_train: ", y_train.shape)
print("Shape of X_test: ", X_test.shape)
print("Shape of y_test: ", y_test.shape)

# Normalize input data (divide by 255) 
X_train = X_train.astype("float32") / 255
X_test = X_test.astype("float32") / 255

# Display images of the first 10 digits in the training set and their true labels
fig, axs = plt.subplots(2, 5, sharey=False, tight_layout=True, figsize=(12,6), facecolor='white')
n=0
for i in range(0,2):
    for j in range(0,5):
        axs[i,j].matshow(X_train[n])
        axs[i,j].set(title=y_train[n])
        n=n+1
plt.show()
The first ten digits of the MNIST dataset. Image by author.

As you can see, we have 60,000 images in the training set and 10,000 in the test set. Note that their dimensions are 28 x 28 pixels.

The final step in the setup is to flatten the images by reshaping them from 28×28 to 784.

Typically, we would use Convolutional layers instead of flattening images, especially when working with larger pictures. However, I wanted to keep this example simple, hence using Dense layers with flat data instead of Convolutional ones.

# Reshape input data
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)

# Print shapes
print("New shape of X_train: ", X_train.shape)
print("New shape of X_test: ", X_test.shape)
New shape of X_train:  (60000, 784)
New shape of X_test:  (10000, 784)
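
If you want to convince yourself that flattening does not lose any information, the following small sketch (using a made-up batch of three images rather than MNIST) reshapes the data to 784 columns and back again. Using -1 for the first dimension is an optional alternative to hard-coding 60000 and 10000.

```python
import numpy as np

# Stand-in for a small batch of 28x28 grayscale images (values are arbitrary)
images = np.arange(3 * 28 * 28, dtype="float32").reshape(3, 28, 28)

# -1 lets NumPy infer the number of images, so the same line works for any batch size
flat = images.reshape(-1, 28 * 28)
print(flat.shape)  # (3, 784)

# Reshaping back recovers the original pixel layout exactly
restored = flat.reshape(-1, 28, 28)
print(np.array_equal(images, restored))  # True
```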

Building a Variational Autoencoder model

We will start by defining a function that will help us to sample from a latent space distribution Z.

Here we employ the reparameterisation trick, which allows the loss to backpropagate through the mean (z-mean) and log standard deviation (z-log-sigma) nodes since they are deterministic.

At the same time, we separate the sampling node by adding a non-deterministic parameter, epsilon, which is sampled from a standard Normal distribution.

#--- Create a function, which we will use to randomly sample from latent space distribution
# Note, epsilon is sampled from a standard normal distribution and is used to maintain the required stochasticity of Z
# Meanwhile, z-mean and z-log-sigma remain deterministic allowing the loss to backpropagate through the layers.
def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.)
    return z_mean + K.exp(z_log_sigma) * epsilon
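
As a quick sanity check of the idea, the NumPy sketch below applies the same formula outside of Keras. The values of z_mean and z_log_sigma are hypothetical encoder outputs chosen for illustration, and z_log_sigma is assumed to hold the log of the standard deviation.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical encoder outputs for one input: mean and log of the standard deviation
z_mean = np.array([1.0, -2.0])
z_log_sigma = np.log(np.array([0.5, 2.0]))

# Reparameterisation: all randomness comes from epsilon ~ N(0, 1)
epsilon = rng.standard_normal(size=(100_000, 2))
z = z_mean + np.exp(z_log_sigma) * epsilon

print(z.mean(axis=0))  # close to [1, -2]
print(z.std(axis=0))   # close to [0.5, 2]
```

The samples follow the intended distribution, yet z is just a deterministic function of z_mean, z_log_sigma and an externally drawn epsilon, which is what lets gradients flow back to the encoder.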

Now, we can define the structure of the Encoder model.

# Specify dimensions for input/output and latent space layers
original_dim = 784 # number of neurons at the input layer (28 * 28 = 784)
latent_dim = 2 # latent space dimension

# ********** Create Encoder **********

#--- Input Layer 
visible = keras.Input(shape=(original_dim,), name='Encoder-Input-Layer')

#--- Hidden Layer
h_enc1 = Dense(units=64, activation='relu', name='Encoder-Hidden-Layer-1')(visible)
h_enc2 = Dense(units=16, activation='relu', name='Encoder-Hidden-Layer-2')(h_enc1)
h_enc3 = Dense(units=8, activation='relu', name='Encoder-Hidden-Layer-3')(h_enc2)

#--- Custom Latent Space Layer
z_mean = Dense(units=latent_dim, name='Z-Mean')(h_enc3) # Mean component
z_log_sigma = Dense(units=latent_dim, name='Z-Log-Sigma')(h_enc3) # Log of the standard deviation component
z = Lambda(sampling, name='Z-Sampling-Layer')([z_mean, z_log_sigma]) # Z sampling layer

#--- Create Encoder model
encoder = Model(visible, [z_mean, z_log_sigma, z], name='Encoder-Model')

# Display model diagram
plot_model(encoder, show_shapes=True, dpi=300)

The above code creates an encoder model and prints its structural diagram.

Diagram of the Encoder part of the VAE model. Image by author.

Note how we send the same outputs from the Encoder-Hidden-Layer-3 into Z-Mean and Z-Log-Sigma before recombining them inside a custom Lambda layer (Z-Sampling-Layer), which is used for sampling from the latent space.

Next, we create the Decoder model:

# ********** Create Decoder **********

#--- Input Layer 
latent_inputs = Input(shape=(latent_dim,), name='Input-Z-Sampling')

#--- Hidden Layer
h_dec = Dense(units=8, activation='relu', name='Decoder-Hidden-Layer-1')(latent_inputs)
h_dec2 = Dense(units=16, activation='relu', name='Decoder-Hidden-Layer-2')(h_dec)
h_dec3 = Dense(units=64, activation='relu', name='Decoder-Hidden-Layer-3')(h_dec2)

#--- Output Layer
outputs = Dense(original_dim, activation='sigmoid', name='Decoder-Output-Layer')(h_dec3)

#--- Create Decoder model
decoder = Model(latent_inputs, outputs, name='Decoder-Model')

# Display model diagram
plot_model(decoder, show_shapes=True, dpi=300)

The above code creates a decoder model and prints its structural diagram.

Diagram of the Decoder part of the VAE model. Image by author.

As you can see, the decoder is a pretty straightforward model that takes inputs from the latent space and passes them through a few hidden layers before generating values for the 784 output nodes.

Next, we combine the Encoder and Decoder models to form a Variational Autoencoder model (VAE).

# Define outputs from a VAE model by specifying how the encoder-decoder models are linked
outpt = decoder(encoder(visible)[2]) # note, outputs available from encoder model are z_mean, z_log_sigma and z. We take z by specifying [2]
# Instantiate a VAE model
vae = Model(inputs=visible, outputs=outpt, name='VAE-Model')

If you paid close attention to the latent space layers in the Encoder model, you would have noticed that the encoder generates three sets of outputs: Z-mean [0], Z-log-sigma [1] and Z [2].

The above code links the models by specifying that the Encoder takes original inputs named "visible". Then out of the three outputs generated by the Encoder [0], [1], [2], we take the third one (Z [2]) and pass it into a Decoder, which generates the outputs that we named "outpt".

Connecting Encoder and Decoder to construct a VAE model. Image by author.

Custom Loss function

Before training the VAE model, the final step is to create a custom loss function and compile the model.

As mentioned earlier in the article, we will use KL divergence to measure the loss between the latent space distribution and a reference standard Normal distribution. The "KL loss" is in addition to the standard reconstruction loss (in this case, MSE) used to ensure that input and output images remain close.
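
For reference, when the encoded distribution is a Normal with independent dimensions and the reference is a standard Normal, the KL divergence has a well-known closed form. The notation here is mine rather than the code's: μ_j and σ_j are the mean and standard deviation of the j-th latent dimension, and d is the number of latent dimensions.

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big) \,\big\|\, \mathcal{N}(0, I)\right)
= -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```

The expression is zero when μ_j = 0 and σ_j = 1 for every dimension and grows as the encoded distribution moves away from the standard Normal. If a network outputs log σ_j rather than log σ_j², then log σ_j² = 2 log σ_j and σ_j² = exp(2 log σ_j).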

# Reconstruction loss compares inputs and outputs and tries to minimise the difference
r_loss = original_dim * keras.losses.mse(visible, outpt)  # use MSE

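# KL divergence loss compares the encoded latent distribution Z with standard Normal distribution and penalizes if it's too different
# Note, z_log_sigma is the log of the standard deviation (see the sampling function), so the log-variance is 2 * z_log_sigma
kl_loss =  -0.5 * K.sum(1 + 2 * z_log_sigma - K.square(z_mean) - K.exp(2 * z_log_sigma), axis = 1)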

# The VAE loss is a combination of reconstruction loss and KL loss
vae_loss = K.mean(r_loss + kl_loss)

# Add loss to the model and compile it
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
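
To see what the two parts of the loss contribute, it can help to evaluate them by hand on tiny inputs. The sketch below is a plain NumPy stand-in, not the Keras code itself: the function name vae_loss_numpy and the toy arrays are made up, and it assumes z_log_sigma holds the log of the standard deviation, as in the sampling function.

```python
import numpy as np

def vae_loss_numpy(x, x_hat, z_mean, z_log_sigma):
    """Mean over the batch of (summed squared error + KL divergence).

    Assumes z_log_sigma holds log(standard deviation), so variance = exp(2 * z_log_sigma).
    """
    # Sum of squared errors per sample, i.e. number of pixels * per-pixel MSE
    r_loss = np.sum(np.square(x - x_hat), axis=1)
    # Closed-form KL between N(z_mean, sigma^2) and N(0, 1), summed over latent dimensions
    kl_loss = -0.5 * np.sum(
        1 + 2 * z_log_sigma - np.square(z_mean) - np.exp(2 * z_log_sigma), axis=1
    )
    return float(np.mean(r_loss + kl_loss))

x = np.full((2, 4), 0.5)   # two toy "images" with four pixels each
zeros = np.zeros((2, 2))   # z_mean = 0 and log(sigma) = 0, i.e. sigma = 1

print(vae_loss_numpy(x, x, zeros, zeros))            # 0.0
print(vae_loss_numpy(x, x, np.ones((2, 2)), zeros))  # 1.0
```

A perfect reconstruction with a standard Normal latent distribution gives zero loss, while shifting both latent means to 1 adds a KL penalty of 0.5 per dimension.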

VAE model training

With the Variational Autoencoder model assembled, let’s train it over 25 epochs and plot the loss chart.

# Train VAE model
history = vae.fit(X_train, X_train, epochs=25, batch_size=16, validation_data=(X_test, X_test))

# Plot a loss chart
fig, ax = plt.subplots(figsize=(16,9), dpi=300)
plt.title(label='Model Loss by Epoch', loc='center')

ax.plot(history.history['loss'], label='Training Data', color='black')
ax.plot(history.history['val_loss'], label='Test Data', color='red')
ax.set(xlabel='Epoch', ylabel='Loss')
plt.xticks(ticks=np.arange(len(history.history['loss']), step=1), labels=np.arange(1, len(history.history['loss'])+1, step=1))
plt.legend()
plt.show()
Variational Autoencoder model loss by epoch. Image by author.

Visualising latent space and generating new digits

Since our latent space is two-dimensional, we can visualise the neighbourhoods of different digits on the latent 2D plane:

# Use encoder model to encode inputs into a latent space
X_test_encoded = encoder.predict(X_test)

# Recall that our encoder returns 3 arrays: z-mean, z-log-sigma and z. We plot the values for z
# Create a scatter plot
fig = px.scatter(None, x=X_test_encoded[2][:,0], y=X_test_encoded[2][:,1], 
                 opacity=1, color=y_test.astype(str))

# Change chart background color
fig.update_layout(dict(plot_bgcolor = 'white'))

# Update axes lines
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='white', 
                 zeroline=True, zerolinewidth=1, zerolinecolor='white', 
                 showline=True, linewidth=1, linecolor='white',
                 title_font=dict(size=10), tickfont=dict(size=10))

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='white', 
                 zeroline=True, zerolinewidth=1, zerolinecolor='white', 
                 showline=True, linewidth=1, linecolor='white',
                 title_font=dict(size=10), tickfont=dict(size=10))

# Set figure title
fig.update_layout(title_text="MNIST digit representation in the 2D Latent Space")

# Update marker size
fig.update_traces(marker=dict(size=2))

fig.show()

Plotting the digit distribution in the latent space gives us the benefit of visually associating different regions with different digits.
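
If you would rather not read coordinates off the chart, one option is to compute the average latent position of each digit. The helper below is a hypothetical sketch (latent_centroids is not part of the original notebook) and is shown on toy data.

```python
import numpy as np

def latent_centroids(z, labels):
    """Map each label to the mean latent coordinates of the points carrying it."""
    z = np.asarray(z)
    labels = np.asarray(labels)
    return {int(label): z[labels == label].mean(axis=0) for label in np.unique(labels)}

# Toy stand-ins for encoded test images and their digit labels
z_toy = np.array([[0.0, 2.0], [0.0, 3.0], [-1.0, -1.0], [-3.0, -1.0]])
labels_toy = np.array([3, 3, 7, 7])

centroids = latent_centroids(z_toy, labels_toy)
print(centroids[3])  # mean of the two "3" points: x = 0.0, y = 2.5
print(centroids[7])  # mean of the two "7" points: x = -2.0, y = -1.0
```

With the variables from the code above, a call such as latent_centroids(X_test_encoded[0], y_test) would give one centroid per digit based on the z-mean outputs.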

Say we want to generate a new image of a digit 3. We know that 3’s are located in the top middle of the latent space. So let’s pick the coordinates of [0, 2.5] and generate an image associated with those inputs.

# Input latent space coordinates
z_sample_digit=[[0,2.5]]

# Decode latent inputs (i.e., generate new outputs)
digit_decoded = decoder.predict(z_sample_digit)

# Reshape and display the image
plt.matshow(digit_decoded.reshape(28,28))
plt.show()
New digit generated by the VAE model. Image by author.

As expected, we got an image of a shape closely resembling a digit 3 because we sampled a vector from a region in the latent space occupied by 3’s.

Let’s now generate 900 new digits from across the whole latent space.

# Display a 2D manifold of the digits
n = 30  # figure with 30x30 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# We will sample n points within [-1.5, 1.5] standard deviations
grid_x = np.linspace(-1.5, 1.5, n) # horizontal axis, left to right
grid_y = np.linspace(1.5, -1.5, n) # vertical axis, top to bottom

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        # Generate an image using a decoder model
        x_decoded = decoder.predict(z_sample)
        #x_decoded = np.clip(x_decoded, 0.25, 0.75) # we could use clipping to make digit edges thicker

        # Reshape from 784 to original digit size (28x28)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

# Plot figure
plt.figure(figsize=(18, 16))
plt.imshow(figure)
plt.show()
900 new digits generated using our VAE model. Image by author.
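
Calling the decoder once per grid point works, but it means 900 separate predict calls. As an optional alternative, the sketch below builds all grid coordinates up front and tiles a batch of decoded images in one step. The helper names latent_grid and tile_images are my own, and a fake "decoder output" is used so that the snippet runs without a trained model.

```python
import numpy as np

def latent_grid(n, limit=1.5):
    """Return an (n*n, 2) array of [x, y] latent points, row by row from the top-left."""
    grid_x = np.linspace(-limit, limit, n)   # left to right
    grid_y = np.linspace(limit, -limit, n)   # top to bottom
    xx, yy = np.meshgrid(grid_x, grid_y)
    return np.column_stack([xx.ravel(), yy.ravel()])

def tile_images(flat_images, n, digit_size):
    """Arrange n*n flattened images into one (n*digit_size, n*digit_size) mosaic."""
    tiles = flat_images.reshape(n, n, digit_size, digit_size)
    # Move the within-image row axis next to the grid row axis before merging
    return tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

z_grid = latent_grid(n=3)
print(z_grid.shape)  # (9, 2)

# Fake "decoder output": image k is filled with the value k
fake_decoded = np.repeat(np.arange(9.0).reshape(9, 1), 4, axis=1)
mosaic = tile_images(fake_decoded, n=3, digit_size=2)
print(mosaic.shape)  # (6, 6)
```

With the trained model, tile_images(decoder.predict(latent_grid(30)), 30, 28) should give the same layout as the nested loops, using a single predict call.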

The cool thing about generating many images from the entire latent space is that it lets us see the gradual transition between different shapes. This confirms that we were able to regularise our latent space successfully.

Final remarks

It is important to note that we can use Variational Autoencoders to encode and generate much more complex data than MNIST digits.

Hence, I would like to encourage you to take my simple tutorial to the next level by applying it to real-world data relevant to your area.

For your convenience, I have saved a Jupyter Notebook in my GitHub repository containing all of the above code.

Please do not hesitate to get in touch if you have any questions or suggestions!

Cheers! 🤓 Saul Dobilas

