
6 Different Ways of Implementing VAE with TensorFlow 2 and TensorFlow Probability

A deep dive into the model from a practical perspective

Since its introduction in 2013 through this paper, the variational autoencoder (VAE), a type of generative model, has taken the world of Bayesian deep learning by storm, with applications in a wide range of domains. The original paper by Kingma and Welling has over 10k citations; and since the model’s construction is not easy to digest at first glance, numerous great articles have explained its intuition, architecture, and other components.

However, the implementation of VAE usually comes as a complement to those articles, and the code itself is talked about less, especially in the context of a specific deep learning library (TensorFlow, PyTorch, etc.): the code is just put out there in a code block, without enough comments about how some arguments work, why a particular function was chosen over others, etc. Also, due to the flexibility of each of these popular libraries, these demo VAE implementations often look quite different from each other. Furthermore, some versions might be implemented incorrectly, even in one of TensorFlow’s own tutorials (I will address this tutorial later, mainly under Version a. of my implementation), and such mistakes might not be caught without comparison to other versions. Lastly, it’s of course nice to follow one version that works, but I have rarely come across a post that compares different ways of implementation side by side. All of these motivated me to write this post.


I assume the reader already has some understanding of how VAE works and wants to know more about the implementation side of things.

I plan to divide this post into two components, which together complete the title:

  1. Things to pay extra attention to when implementing VAE in general, using TF2 and TFP
  2. How to implement VAE in different ways, using TF2 and TFP

Before I start to dive into each part, I want to use the following two paragraphs to share my personal experience of studying this model:

I thought I had a good understanding of the model after reading the paper, until I started to implement it. A paper usually contains so many details about the model that it is easy to lose track of the parts that deserve extra attention. Sometimes a component of the model that the paper spent one short sentence on took me hours to fully grasp and to make work in the implementation. In my opinion, if you want a good understanding of how a model works, it’s definitely nice to go through the paper as thoroughly as possible, but it’s even better to try to implement it yourself.

When you start to implement the model, I’d suggest picking a benchmark dataset on which the community already knows how the model behaves, instead of coming up with your own dataset, especially if you are implementing the model for the first time. This is particularly important because you then have a very clear goal for what the results should look like: a reference, so to speak. Simply making the code run without errors and seeing the cost drop is far from proving that the implementation works; we need to at least check whether it behaves the way others have already observed on that particular dataset.


Part 0: Clarifying the Implementation Goal

Per the discussion in the above section, I will first introduce the dataset, as well as the particular tasks for the VAE model to accomplish. For this post, I used the MNIST dataset of handwritten digits, with images of shape (28,28,1). I preprocessed it by normalizing the pixel values to be between 0 and 1, and then discretized them to be either 0 or 1, using 0.5 as the threshold. The tasks for the VAE model to accomplish on this dataset are:
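The preprocessing step can be sketched in NumPy as follows. This is a minimal sketch rather than the original code: the helper name binarize_mnist is hypothetical, and mapping a value of exactly 0.5 to 0 (a strict > comparison) is my assumption, since the text only states that 0.5 is the threshold.

```python
import numpy as np


def binarize_mnist(images, threshold=0.5):
    """Hypothetical helper: scale uint8 MNIST images to [0, 1], binarize them at
    `threshold`, and add a channel axis so each image has shape (28, 28, 1)."""
    x = images.astype("float32") / 255.0        # normalize to [0, 1]
    x = (x > threshold).astype("float32")       # discretize to {0., 1.}
    return x[..., np.newaxis]                   # (N, 28, 28) -> (N, 28, 28, 1)


# Tiny synthetic "image" instead of real MNIST; with real data one would pass the
# uint8 arrays returned by tf.keras.datasets.mnist.load_data().
fake = np.array([[[0, 127, 128, 255]]], dtype=np.uint8)  # shape (1, 1, 4)
binary = binarize_mnist(fake)                            # shape (1, 1, 4, 1)
```

Here 127/255 falls just below 0.5 and 128/255 just above it, so the four pixels become 0, 0, 1, 1.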

(a) Reconstruction of the input digit images as close as possible

(b) Generation of new digit images that look realistic, using random samples from the prior distribution (rather than samples from the posterior, conditional on data) as the input to the decoder

Task (a) is relatively easy to accomplish; however, Task (b) is, in my opinion, the one that indicates whether the implementation works: that the model has learnt the multimodality of the different digits from the training data, and is able to produce new digit images as a generative model.

Now we are ready to dive into each component of this post.


Part I: Main Focus when Implementing VAE

Two things, both w.r.t. the loss function for each instance:

  • The loss function consists of two parts: the reverse KL divergence (between the posterior and the prior distribution of the latent variable z), and the expected negative log-likelihood (of the decoder distribution on the data to be reconstructed; also called the expected reconstruction error). Adding these two parts together, we get the negative ELBO, which is to be minimized. For the computation of each part, we need to pay extra attention to which operation (taking the sum vs. taking the mean) we apply along each dimension of the data.
  • The weight of KL divergence in the loss function is a hyperparameter we shouldn’t ignore at all: it adjusts the "distance" between prior and posterior distribution of z, and plays a decisive role in the performance of the model.

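Here is the loss function for the i-th instance:

$$\mathcal{L}\big(x^{(i)}\big) = \omega \, D_{KL}\big(q(z \mid x^{(i)}) \,\|\, p(z)\big) - \mathbb{E}_{q(z \mid x^{(i)})}\big[\log p(x^{(i)} \mid z)\big]$$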

The computation of this loss function can be done in various ways, and is prone to mistakes during implementation, especially regarding what’s mentioned in the first bullet point: the operations along each dimension of the data. I will elaborate on this in Part II, where each way of implementation is introduced separately.

Strictly speaking, this is the loss function of Beta-VAE, in which ω can take values other than 1. This hyperparameter is crucial, especially for Task (b) mentioned in Part 0: it decides how hard we want to penalize the difference between the prior and posterior distribution of z. Quite often, the reconstructed images look perfect, while the images generated from codes z sampled from the prior look bizarre. If ω is set too small, we are basically not regularizing the posterior at all, so after training it might be drastically different from the prior, in the sense that z sampled from the prior would often fall into areas of very low density under the posterior distribution. As a result, the decoder would not know what to do with such samples of z, since it was trained on z from the posterior distribution (note the distribution in the subscript of the expectation in the negative log-likelihood term of the loss function above). Here are the reconstructed and generated digits when ω=0.0001:

Notice that reconstructed digits can be clearly distinguished (and also match their labels), but the generated digits are basically dark and unrecognizable. This is an example of under-regularization.

On the other hand, if ω is too large, the posterior would be pulled too close to the prior, so no matter which image is input to the encoder, we’d end up with a z as if it were randomly sampled from the prior. Here are the reconstructed and generated digits when ω=20:

Notice that all digits look the same, whether they are reconstructed conditional on the input digit or generated as new digits. This is an example of over-regularization.

Thirdly, we have the scenario where ω is set to "just the right amount": the posterior is distinct enough from the prior to be flexible in conditioning on the input digit data, so the reconstruction looks great, while the high-density areas of the two distributions overlap enough that the z samples from the prior don’t look too unfamiliar to the decoder. Here are the reconstructed and generated digits when ω=3:

Notice that all of the reconstructed and most of the generated digits are recognizable, while some of the generated digits look less realistic.

In summary, while the autoencoder setup places an information bottleneck on the latent variable z, forcing it to keep only the most essential information needed to reconstruct x after dimension reduction, ω helps place z in a region of the latent space from which the decoder can generate new x from scratch that looks like real x, while maintaining the quality of reconstruction.


A little digression: TensorFlow 2 and TensorFlow Probability

You might notice that although I put these two libraries in my title, the post so far has focused purely on the VAE model and its application to the MNIST dataset. I originally wanted to write about TF2 and TFP as the third component when I structured the post, but then I decided to contextualize them, i.e. to talk about them while I go through the implementation details in Part I and Part II. However, it might seem too hasty to jump right into the code without giving a little background on these libraries. So here they come.

TF2

TensorFlow 2.0.0 was released in late September, 2019 – so it’s not even a year since its initial release. The latest version at the moment is 2.3.0, released in July, 2020. I started to work on a research project that had me pick up TF2 in November last year, thus I’ve used all four versions so far, each for a little while. At the time I was fully aware that PyTorch had been gaining a lot of traction in the research community; I chose TF2 mainly because it was the library where the project’s codebase was built upon.

My first impression of TF2: it was indeed much more convenient for developing deep learning models than TF1. Eager execution is one of the most distinctive features that separate TF2 from TF1: I no longer need to build the entire computational graph without seeing any intermediate results, which made debugging in TF1 a disaster, as there was no easy way to decompose each step and test its output. While eager execution might come at the cost of training speed, the decorator @tf.function helps restore, to some extent, the efficiency of graph mode. Furthermore, its integration with Keras has brought some great perks: the Sequential and Functional APIs offer different levels of flexibility when stacking up neural network (NN) layers (the Sequential API is much like torch.nn.Sequential in PyTorch, while the Functional API is closer to writing the forward pass of a torch.nn.Module by hand); also, as a high-level API, its interface looks deceptively simple (to the point that it can cause trouble during VAE implementation; check my discussion under implementation Version a. in Part II), as if I’m training a scikit-learn model.

Meanwhile, since it has been less than a year since its first release, it’s definitely not as stable as I’d hope from a popular library. I still remember getting stuck on a simple tensor indexing procedure: I implemented the same functionality in multiple ways, all of which led to the same error message; yet it was such a straightforward step that no one would expect it to be the source of the error. It turned out that after updating TF to version 2.1.0, released in early January this year, the model worked without changing any code. The most recent example was when adding a Dense layer after flattening a tensor: I got an error message regarding dimensions, which also went away after updating TF from 2.2.0 to 2.3.0.

Furthermore, its user and developer community has not been as active as I expected. I’ve posted multiple questions on the Keras GitHub Issues page, including the one mentioned above, and all of them ended with me answering my own question and closing the issue. Some issues received comments from other users after several months, but none was addressed by the TensorFlow/Keras team.

Lastly, some of its documentation is not straightforward or organized enough to follow. I’ve spent a lot of time trying to figure out how the decorator @tf.function actually works, and eventually concluded that the best approach for now is to imitate the working examples without worrying too much about the rationale. Some of its tutorials also appear sloppy or even misleading (by giving incorrect implementations); I will give examples later.

TFP

TensorFlow Probability was introduced in the first half of 2018, as a library developed specifically for probabilistic modeling. It implements the reparameterization trick under the hood, which enables backpropagation for training probabilistic models. You can find a good demonstration of the reparameterization trick in both the VAE paper and this paper that proposed the Bayes by Backprop algorithm: the former work has the hidden nodes for the latent variable z and the output nodes of the decoder being probabilistic, while the latter has the learnable parameters (weights and biases of each NN layer) being probabilistic.

With TFP, we no longer need to explicitly define the mean and variance parameters of the posterior distribution of z, nor implement the computation of the KL divergence, which greatly simplifies the code. In fact, it might make the implementation so simple that one could do it without a good understanding of VAE, because the major components that require a good grasp of the model are basically all abstracted away by TFP. As a result, mistakes can also easily be made when using TFP to implement VAE.

If it’s your first time implementing VAE and you want to understand well how the model works, I’d suggest you start by explicitly implementing the reparameterization trick and defining the KL term; I will start with this way of implementation as the first version in Part II.


Part II: Different Ways of Implementing VAE

Before diving into the details of each implementation version, I’d like to list all of them here first. This list is by no means exhaustive, but rather representative of the points I plan to make.

The various ways of implementing VAE arise from the different options we have for each of the following 3 modules:

  • the output layer of encoder
  • the output layer of decoder
  • the loss function (mainly the part that computes the expected reconstruction error)

Each module has options as follows:

the output layer of encoder

  1. tfpl.IndependentNormal
  2. tfkl.Dense that outputs (concatenated) mean and (raw) standard deviation of the posterior distribution of z

the output layer of decoder

  1. tfpl.IndependentBernoulli
  2. tfpl.IndependentBernoulli.mean
  3. tfkl.Conv2DTranspose that outputs the logits

the loss function

  1. negative_log_likelihood = lambda x, rv_x: -rv_x.log_prob(x)
  2. tf.nn.sigmoid_cross_entropy_with_logits
  3. tfk.losses.BinaryCrossentropy
  4. tf.nn.sigmoid_cross_entropy_with_logits + tfkl.Layer.add_loss
  5. PyTorch-esque explicit computation by epoch through with tf.GradientTape() as tape + tf.nn.sigmoid_cross_entropy_with_logits + tfkl.Layer.add_loss
  6. PyTorch-esque explicit computation by epoch through with tf.GradientTape() as tape + tf.nn.sigmoid_cross_entropy_with_logits

Side note:

We use the following module abbreviations:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

Also note that the NN architecture of the encoder and the decoder is of less importance: I use the same architecture as this TF tutorial for all versions of my implementation.

One might think that since the KL divergence can be computed in different ways (analytical solution vs. MC approximation), it should also be a module; but since the way to implement this term is more or less determined by the choice of the encoder output layer, and adding it as a fourth module might over-complicate the presentation, I choose to discuss it under one of the implementation versions (Version a.) instead.

The implementations are simply combinations of the different options from each of the 3 modules:

Now let’s dive into each version of implementation.


Version a.

This version is perhaps the most widely adopted one; as far as I know, all of my fellow researchers write it this way: the encoder outputs nodes that represent the mean and (some transformation, with range in ℝ, of) the standard deviation of the posterior distribution of the latent variable z. We then use the reparameterization trick to sample z as follows (from Equation (10) of the VAE paper):

$$z^{(i,l)} = \mu^{(i)} + \sigma^{(i)} \odot \epsilon^{(l)}, \quad \epsilon^{(l)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

in which ϵ is sampled from a standard multivariate Gaussian distribution, and the mean and the standard deviation of the posterior distribution q are output deterministically by the encoder. The obtained z is then used as the input to the decoder. Since our image data x is binary after preprocessing, it’s natural to assume a multivariate Bernoulli distribution (where all pixels are independent of each other) whose parameters are the output of the decoder; this is the scenario described in Appendix C.1, Bernoulli MLP as decoder, of the VAE paper. Therefore, the log-likelihood of the decoder distribution has the following form (Equation (11) of the VAE paper):

$$\log p(x \mid z) = \sum_{d=1}^{D} \big[ x_d \log y_d + (1 - x_d) \log (1 - y_d) \big]$$

in which y is the decoder output serving as the Bernoulli parameters of each pixel, and D is the number of pixels in each instance. z here represents a single sample from the encoder; since the expected negative log-likelihood term in the loss function cannot be computed analytically, we use an MC approximation as follows (from Equation (10) of the VAE paper):

$$-\mathbb{E}_{q(z \mid x^{(i)})}\big[\log p(x^{(i)} \mid z)\big] \approx -\frac{1}{L} \sum_{l=1}^{L} \log p\big(x^{(i)} \mid z^{(i,l)}\big)$$

in which the z term (with the exact same form of superscript as the one in the reparameterization trick equation above) represents the l-th MC sample for the i-th digit image instance. In practice, it’s usually enough to set L=1, i.e. only one MC sample is needed.

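The other term in the loss function, the reverse KL divergence, can also be approximated with MC samples. Since we assume the encoder distribution q to be multivariate Gaussian, we can directly plug the mean and variance output by the encoder into its density function:

$$D_{KL}\big(q(z \mid x^{(i)}) \,\|\, p(z)\big) \approx \frac{1}{L} \sum_{l=1}^{L} \Big[ \log f\big(z^{(i,l)}; \mu^{(i)}, \operatorname{diag}\big((\sigma^{(i)})^{2}\big)\big) - \log f\big(z^{(i,l)}; \mathbf{0}, \mathbf{I}\big) \Big]$$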

in which f represents the multivariate Gaussian density. Similarly, we can set L=1 as well in practice.

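Furthermore, if we let q be a multivariate Gaussian with a diagonal covariance matrix, the KL divergence term can be computed analytically, as shown in Appendix B of the VAE paper:

$$-D_{KL}\big(q(z \mid x^{(i)}) \,\|\, p(z)\big) = \frac{1}{2} \sum_{j=1}^{J} \Big( 1 + \log\big((\sigma_j^{(i)})^{2}\big) - \big(\mu_j^{(i)}\big)^{2} - \big(\sigma_j^{(i)}\big)^{2} \Big)$$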

in which J represents the number of dimensions of z.

In this version of implementation, I put the code that computes the cost of a mini-batch of images (note that the loss function is computed on a single instance; the cost function is the average of the losses over all instances) into the function vae_cost, and define the optimization step in the function train_step. Here’s how they are implemented in TF2:
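Since the original code block is not reproduced here, below is a minimal sketch of what vae_cost and train_step could look like. The assumptions are mine: small Dense encoder and decoder instead of the convolutional architecture from the TF tutorial, a hypothetical LATENT_DIM of 2, and softplus as the transformation that turns the encoder's raw output into a positive standard deviation. The structure follows the text: reparameterization with L = 1, per-instance sums over pixels and latent dimensions, and a single mean over the instances.

```python
import tensorflow as tf

tfk = tf.keras
tfkl = tf.keras.layers

LATENT_DIM = 2  # hypothetical latent size

# Assumed architecture: small Dense networks instead of the tutorial's conv layers
encoder = tfk.Sequential([
    tfk.Input(shape=(28, 28, 1)),
    tfkl.Flatten(),
    tfkl.Dense(256, activation="relu"),
    tfkl.Dense(2 * LATENT_DIM),  # concatenated mean and raw (unconstrained) sd
])
decoder = tfk.Sequential([
    tfk.Input(shape=(LATENT_DIM,)),
    tfkl.Dense(256, activation="relu"),
    tfkl.Dense(28 * 28),
    tfkl.Reshape((28, 28, 1)),   # logits of the independent Bernoulli pixels
])
optimizer = tfk.optimizers.Adam(1e-3)


def vae_cost(x, kl_weight=1.0, analytic_kl=True):
    mu, raw_sd = tf.split(encoder(x), num_or_size_splits=2, axis=-1)
    sd = tf.nn.softplus(raw_sd) + 1e-6       # assumed map from raw output to sd > 0
    eps = tf.random.normal(tf.shape(mu))
    z = mu + sd * eps                        # reparameterization trick with L = 1
    logits = decoder(z)

    # Expected reconstruction error: sum over all pixels -> one value per instance
    nll = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits), axis=[1, 2, 3])

    if analytic_kl:
        # Appendix B closed form, summed over the latent dimensions
        kl = -0.5 * tf.reduce_sum(
            1.0 + 2.0 * tf.math.log(sd) - tf.square(mu) - tf.square(sd), axis=-1)
    else:
        # One-sample MC estimate of log q(z|x) - log p(z); the log(2*pi) terms cancel
        log_q = tf.reduce_sum(-0.5 * tf.square(eps) - tf.math.log(sd), axis=-1)
        log_p = tf.reduce_sum(-0.5 * tf.square(z), axis=-1)
        kl = log_q - log_p

    # The only mean is over the instance (batch) dimension
    return tf.reduce_mean(nll + kl_weight * kl)


@tf.function
def train_step(x, kl_weight=1.0):
    with tf.GradientTape() as tape:
        cost = vae_cost(x, kl_weight)
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(cost, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return cost


# Usage on random binary images standing in for a preprocessed MNIST mini-batch
x_batch = tf.cast(tf.random.uniform((8, 28, 28, 1), seed=0) > 0.5, tf.float32)
cost = train_step(x_batch, kl_weight=3.0)
```

In a full training loop, train_step would be called once per mini-batch inside a loop over epochs.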

Several things need to be addressed here:

  • The computation of the expected reconstruction error

Since the negative log-likelihood of Bernoulli distribution is essentially where the cross-entropy loss comes from (if you can’t see it right away, this post gives a good review), we can use existing functions – in this implementation version I chose tf.nn.sigmoid_cross_entropy_with_logits: this function takes the binary data and logits of the Bernoulli parameters as arguments, so I didn’t apply Sigmoid activation to the output of the last decoder layer tfkl.Conv2DTranspose.

Note that this function keeps the original dimensions of its inputs: since each instance of both the digit images and the decoder outputs has shape (28,28,1), tf.nn.sigmoid_cross_entropy_with_logits outputs a tensor of shape (batch_size,28,28,1), with each element being the negative log-likelihood of the Bernoulli distribution for that specific pixel. Since we assume the pixels of each digit image to follow independent Bernoulli distributions, the negative log-likelihood of each instance is the sum of the negative log-likelihoods of all its pixels, hence the function tf.math.reduce_sum(..., axis=[1, 2, 3]) in the above code block.
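A quick sanity check of this reasoning, using random tensors in place of real images and decoder outputs (my own illustration): the per-pixel output keeps the input shape, and summing it over axes [1, 2, 3] reproduces the negative of the Bernoulli log-likelihood from Equation (11) with y = sigmoid(logits).

```python
import numpy as np
import tensorflow as tf

# Random stand-ins for 4 binary images and the matching decoder logits
x = tf.cast(tf.random.uniform((4, 28, 28, 1), seed=1) > 0.5, tf.float32)
logits = tf.random.normal((4, 28, 28, 1), seed=2)

per_pixel = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits)
per_instance = tf.reduce_sum(per_pixel, axis=[1, 2, 3])   # one NLL per image

# Negative of the Bernoulli log-likelihood, written out with y = sigmoid(logits)
y = tf.sigmoid(logits)
manual = -tf.reduce_sum(x * tf.math.log(y) + (1.0 - x) * tf.math.log(1.0 - y), axis=[1, 2, 3])

print(per_pixel.shape, per_instance.shape)   # (4, 28, 28, 1) (4,)
```

In practice the built-in function is preferable to the written-out formula, since it works on logits directly and stays numerically stable for large logits.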

This is where extra precaution needs to be taken if you plan to use the Keras API. One of its biggest perks is that it greatly simplifies the code for training a neural network, to as few as three lines: (1) build a tfk.Model object by defining the input and output of the network; (2) compile the model by specifying the loss function and optimizer; (3) train the model by calling the fit method, whose arguments include input and output data, mini-batch size, number of epochs, etc. For Step (2), the argument loss takes a function with exactly two arguments, and Keras computes the cost for the batch of data during training by averaging across all dimensions of the output of that function, no matter what shape it has. So in our case, you might be very tempted to write loss=tf.nn.sigmoid_cross_entropy_with_logits based on all of the Keras tutorials you’ve seen, but this is incorrect, since it averages the cross-entropy losses of all pixels of each instance instead of summing them up: the reconstruction term gets divided by the number of pixels (784), which silently changes its weight relative to the KL term, so the resulting cost is no longer the negative ELBO. But fear not, you can still combine tf.nn.sigmoid_cross_entropy_with_logits with the Keras API; under Version c. I will elaborate on how to do it.
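The size of this error can be made concrete. In the sketch below (random tensors, my own illustration), averaging the per-pixel cross-entropy over every axis, which is effectively what the default Keras reduction does with this function as the loss, gives exactly the correct per-instance cost divided by 28 * 28 = 784. If the KL term is not divided by the same factor, this amounts, up to an overall scale, to training with ω multiplied by 784, i.e. heavy over-regularization.

```python
import tensorflow as tf

# Random stand-ins for a batch of 4 binary images and the decoder's logits
x = tf.cast(tf.random.uniform((4, 28, 28, 1), seed=3) > 0.5, tf.float32)
logits = tf.random.normal((4, 28, 28, 1), seed=4)
per_pixel = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits)

# Averaging over every axis (batch, height, width, channel)
keras_style = tf.reduce_mean(per_pixel)
# What the negative ELBO needs: sum over pixels, then average over instances only
correct = tf.reduce_mean(tf.reduce_sum(per_pixel, axis=[1, 2, 3]))

ratio = float(correct / keras_style)   # 784 = 28 * 28, up to float rounding
```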

Remember that at the very beginning of the post, I mentioned that one of TensorFlow’s own tutorials had an incorrect implementation of VAE? Now it’s time to take a closer look. The first mistake it made is in the computation of the expected reconstruction error: it applied mse_loss_fn = tfk.losses.MeanSquaredError() as the loss function. First of all, the choice of mean squared error is already questionable to me. Not only does it implicitly assume the reconstruction distribution to be an independent Gaussian over real-valued data, while our binary image data after preprocessing makes an independent Bernoulli distribution the more natural choice; MSE also computes a scaled and shifted negative log-likelihood of the Gaussian distribution, meaning that you’d be using a Beta-VAE loss without realizing it. And (much) more importantly, tfk.losses.MeanSquaredError() without an explicitly set reduction argument also takes the mean across all dimensions. Again, the only dimension we need to average over is the instance dimension; for each instance, we need to take the sum over all pixels. If one wants to use the tfk.losses module, under implementation Version d. I will demonstrate how to use tfk.losses.BinaryCrossentropy, a more suitable choice for our case, to compute the expected reconstruction error.

  • The computation of KL divergence

Extra attention needs to be paid here for a similar reason as above: the operation we need to apply along each dimension of the data, specifically taking the mean or the sum.

Note that in the analytical solution of the KL divergence, we take the sum of the per-dimension terms over all J dimensions of z (and only afterwards average over instances). However, in that same TensorFlow tutorial, it’s computed in the following way:

kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )

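which is the second mistake in that guide, as it takes the mean across all dimensions, including the J latent dimensions, instead of summing over the latent dimensions and averaging only over the instances.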
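For comparison, here is the tutorial's expression next to a corrected version (a sketch with random z_mean and z_log_var of a hypothetical shape (64, 16)). Summing over the latent axis before averaging over the batch makes the corrected value exactly latent_dim times larger than the tutorial's, so the tutorial effectively shrinks the KL weight by that factor.

```python
import tensorflow as tf

batch_size, latent_dim = 64, 16   # hypothetical sizes
z_mean = tf.random.normal((batch_size, latent_dim), seed=5)
z_log_var = tf.random.normal((batch_size, latent_dim), seed=6)
terms = z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1

# Tutorial version: mean over the batch AND the latent dimensions
kl_tutorial = -0.5 * tf.reduce_mean(terms)
# Corrected version: sum over the latent dimensions, then mean over the instances
kl_corrected = -0.5 * tf.reduce_mean(tf.reduce_sum(terms, axis=-1))

ratio = float(kl_corrected / kl_tutorial)   # equals latent_dim up to float rounding
```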

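One can also use an MC approximation with one sample of z for the KL divergence computation; check the code when analytic_kl=False. If you want to verify that your implementation is correct, you might approximate the KL divergence in the following way, with L set to some large integer, say 10,000:

$$D_{KL}\big(q(z \mid x^{(i)}) \,\|\, p(z)\big) \approx \frac{1}{L} \sum_{l=1}^{L} \Big[ \log f\big(z^{(i,l)}; \mu^{(i)}, \operatorname{diag}\big((\sigma^{(i)})^{2}\big)\big) - \log f\big(z^{(i,l)}; \mathbf{0}, \mathbf{I}\big) \Big]$$

in which each z is sampled through

$$z^{(i,l)} = \mu^{(i)} + \sigma^{(i)} \odot \epsilon^{(l)}, \quad \epsilon^{(l)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

and see if the result is similar to the one you obtain with the analytical solution.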

Lastly, if you don’t want to manually compute the KL divergence, you may use the function from TensorFlow Probability library, which directly computes the KL divergence between two distributions:

prior_dist = tfd.MultivariateNormalDiag(loc=tf.zeros((batch_size, latent_dim)), scale_diag=tf.ones((batch_size, latent_dim)))
var_post_dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=sd)
kl_divergence = tfd.kl_divergence(distribution_a=var_post_dist, distribution_b=prior_dist)

Now that we just mentioned TFP, it’s a good time to jump into the next version of implementation, which is all about this library.


Version b.

By combining TFP with Keras API of TF2, the code looks much simpler than the one in Version a. It is actually my favorite version, and will be the one I use in the future due to its simplicity.

In this version, the outputs of both the encoder and the decoder are objects from the tensorflow_probability.distributions module, which have many of the methods you’d expect from a probability distribution: mean, mode, sample, prob, log_prob, etc. To obtain such outputs from the encoder and decoder, you only need to replace their output layers with tensorflow_probability.layers objects.

The implementation is as follows:

That’s it! 62 lines of code after code reformatting, including comments. One of the main reasons for such simplification is the sampling of z, as all of the steps of the reparameterization trick have been abstracted through the encoder output TFP layer tfpl.IndependentNormal. Furthermore, the KL divergence computation is done through the activity_regularizer argument in that probabilistic encoder output layer, where we specify the prior distribution to be the standard multivariate Gaussian distribution, as well as the KL divergence weight ω, to create a tfpl.KLDivergenceRegularizer object. In addition, the expected reconstruction error can be computed by simply calling the log_prob method of the decoder output, which is a tfp.distributions.Independent object – this is as neat as it can get.

One caveat: one might think that since the input to the decoder needs to be a tensor while z is a tfp.distributions.Independent object, we need to write z = encoder(x_input).sample() instead to explicitly sample z. Doing so is not only unnecessary, but also incorrect:

  • unnecessary because we have convert_to_tensor_fn set to tfd.Distribution.sample (which is actually the default value, but I wrote it out explicitly so you can see it): whenever the output of this layer is treated as a tf.Tensor object, as in our case where the decoder needs a sample from this distribution (outputs=decoder(z)), the layer calls the method specified by convert_to_tensor_fn, so it’s already doing outputs=decoder(z.sample()).
  • incorrect because with .sample() being explicitly called, the KL divergence, which is supposed to be computed by tfpl.KLDivergenceRegularizer, would not get picked up as part of the cost. It might be that after .sample() is called, z in the computational graph is no longer the tfp.distributions.Independent output of the layer that carries tfpl.KLDivergenceRegularizer as its activity_regularizer. Therefore, calling .sample() on the encoder output in this version would give us a loss function that only contains the expected reconstruction error: the trained VAE would no longer serve as a generative model, but as one that only reconstructs its input.

This version of implementation is very similar to this tutorial by TensorFlow, which does a much better job than the one shown under Version a.: both the reconstruction of existing digits and the generation of new digits work. There are only a couple of things I want to address for this post:

  • They applied tfpl.MultivariateNormalTriL instead of tfpl.IndependentNormal as the probabilistic encoder output layer, which essentially trains the non-zero elements of a lower triangular matrix, conceptually the Cholesky factor of a positive definite matrix. That positive definite matrix is the covariance matrix of the posterior distribution of z, and it can be any positive definite matrix, instead of just the diagonal matrix assumed in the VAE paper. This gives us a more flexible posterior distribution, but it also has more parameters to train, and the KL divergence is more complicated to compute.
  • They set the KL divergence weight as the default 1.0 in tfpl.KLDivergenceRegularizer, but like I discussed under Part I, this hyperparameter is crucial to the success of a VAE implementation, and usually needs to be explicitly tuned to optimize the model performance.
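The "more parameters to train" point is easy to quantify. The sketch below counts the encoder outputs each layer needs for a J-dimensional z: 2J for one mean and one scale per dimension, versus J + J(J+1)/2 for a mean plus a full lower triangular scale matrix. As far as I know these match what tfpl.IndependentNormal.params_size and tfpl.MultivariateNormalTriL.params_size return, but the functions below are plain-Python stand-ins.

```python
def independent_normal_params(j):
    """Outputs needed for a diagonal Gaussian over a J-dim z: J means + J scales."""
    return 2 * j


def multivariate_normal_tril_params(j):
    """Outputs needed for a full-covariance Gaussian: J means + J(J+1)/2 Cholesky entries."""
    return j + j * (j + 1) // 2


for j in (2, 16, 64):
    # e.g. J=16: 32 vs 152
    print(f"J={j}: IndependentNormal {independent_normal_params(j)}, "
          f"MultivariateNormalTriL {multivariate_normal_tril_params(j)}")
```

The gap grows quadratically in J, which matters mostly for larger latent spaces.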

Lastly for this implementation version, I want to show one perk of using a TFP layer as the decoder output: we are able to obtain more flexible predictions. In the original VAE, the decoder output is deterministic, so once z is sampled, the decoder output is fixed. However, with the output being a distribution, we can call the mean, mode, or sample method to output a tf.Tensor object as the digit image prediction. Here are the results when different methods are called, for both reconstruction and generation:

Note that for both reconstruction and generation, the mean appears blurrier than (not as sharp as) the mode, since the mean of a Bernoulli distribution is a value between 0 and 1, while the mode is either 0 (when the parameter is 0.5 or less) or 1 (otherwise). The sample shows more granularity than the other two while still looking sharp, since each pixel is a random draw that also takes either 0 or 1 as its value (unlike the mode, which after training basically knows what values all pixels should take as a whole to form a digit). This is the flexibility that a probability distribution can provide.

The following two versions of implementation focus on how to customize the loss function so that it computes the expected reconstruction error correctly while still using the Keras API to simplify the code; since the discussion in Version a. already laid the foundation for these two versions, I can focus mainly on demonstrating the code.


Version c.

In Version a. we talked about how directly using tf.nn.sigmoid_cross_entropy_with_logits as the loss argument when compiling the tfk.Model object results in an incorrect computation of the expected reconstruction error; a quick fix is to implement a custom loss function based on tf.nn.sigmoid_cross_entropy_with_logits as follows:
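A sketch of such a custom loss function: the function name is the one used in the compile call below, while the argument names and the dtype cast are my own additions. It returns one summed cross-entropy per instance, shape (batch_size,), so the only averaging Keras applies afterwards is over the instances.

```python
import tensorflow as tf


def custom_sigmoid_cross_entropy_loss_with_logits(x_true, x_recons_logits):
    """Bernoulli negative log-likelihood per instance, from decoder logits."""
    # labels and logits must share a dtype; cast defensively in case y_true differs
    x_true = tf.cast(x_true, x_recons_logits.dtype)
    raw_cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=x_true, logits=x_recons_logits)          # (batch_size, 28, 28, 1)
    # Sum over every pixel axis; Keras then averages the (batch_size,) result
    return tf.reduce_sum(raw_cross_entropy, axis=[1, 2, 3])


# Usage on random stand-in tensors
x = tf.cast(tf.random.uniform((4, 28, 28, 1), seed=11) > 0.5, tf.float32)
logits = tf.random.normal((4, 28, 28, 1), seed=12)
per_instance_loss = custom_sigmoid_cross_entropy_loss_with_logits(x, logits)   # shape (4,)
```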

in which we take the sum of the cross-entropy losses across all pixel dimensions for each instance. And when compiling the model, we write

model.compile(loss=custom_sigmoid_cross_entropy_loss_with_logits, optimizer=tfk.optimizers.Adam(learning_rate))

Note that the decoder is the same one as in Version a., which deterministically outputs the logits of parameters of the independent Bernoulli distribution. Meanwhile, we use tfpl.IndependentNormal as the encoder output layer just like in Version b., thus the KL divergence computation is taken care of by its argument activity_regularizer.

Similarly, we have the following implementation version:


Version d.

We build another custom loss function upon tfk.losses.BinaryCrossentropy as follows:

Note that unlike Version c., which uses a deterministic decoder output layer, this version applies a probabilistic layer just like Version b.; however, we need to take the mean of this distribution first (the first step of the custom loss function), because one of the arguments of the tfk.losses.BinaryCrossentropy object is the parameter of the Bernoulli distribution, which is the same as its mean. Also note the argument reduction when initializing the tfk.losses.BinaryCrossentropy object, which is set to tfk.losses.Reduction.NONE: this prevents the program from further reducing the resulting tensor, in which each element contains the cross-entropy loss of one specific pixel. (Strictly speaking, tfk.losses.BinaryCrossentropy still averages over the last axis, so the result has shape (batch_size,28,28); since that channel axis has size 1, the average changes nothing.) We then take the sum over the dimensions at the instance level, just like in the custom loss function of Version c.

It’s worth pointing out that any function we pass as the loss argument when compiling the model must take exactly two arguments, one being the data that the model tries to predict, and the other being the model output. Therefore, we run into a problem whenever we want to use the Keras API with a more flexible loss function. We are lucky to have the activity_regularizer argument in the TFP encoder output layer, which incorporates the KL divergence into the computation of the cost and gives us Version b.; but what if we didn’t? The remaining two versions introduce a way to implement a more flexible loss function in general, not just for the specific case of VAE, thanks to the add_loss method of the tfkl.Layer class.


Version e.

I will start by directly presenting the code for this version:

Note that I used exactly the same loss function as in Version c. when compiling the tfk.Model (Line 70). The weighted KL divergence is incorporated into the cost by calling the add_loss method (Line 49). The following is a direct quote from its documentation:

This method can be used inside a subclassed layer or model’s call function, in which case losses should be a Tensor or list of Tensors.

Since the only requirement on the format is being "a Tensor or list of Tensors", we can build a much more flexible loss function. The tricky part comes from the first half of that quote, which specifies where this method should be called: in my implementation, I called it inside the call method of the VAE_MNIST class, which inherits from the tfk.Model class. This differs from Version b., which doesn't use class inheritance.

For this version, I originally also wrote the VAE_MNIST class without inheritance and called the compile method of the tfk.Model object in the class method named build_vae_keras_model, just as in Version b. After the tfk.Model object model was compiled, I directly called model.add_loss(self.kl_weight * kl_divergence). My rationale was that since the output layer of our model, an object of class tfkl.Conv2DTranspose, inherits from tfkl.Layer, we should be able to perform such an operation. However, the training results were off: turning the KL weight down to around 0.01 gave "brighter" generated images, while increasing it gave almost completely dark ones. Overall, the generated images had poor quality, and the reconstructed images, although they looked great, did not change at all with different values of the KL weight. Only after I made VAE_MNIST inherit from the tfk.Model class and renamed the class method where add_loss gets called to call did the model train properly.
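Below is a minimal sketch of the pattern that finally worked: subclass tfk.Model and call add_loss inside call, while compile only sees a two-argument reconstruction loss. To keep it short, it uses a hypothetical fully connected TinyVAE on flattened 784-pixel inputs with a closed-form KL term, instead of the convolutional VAE_MNIST of this post, and random binary data instead of MNIST; layer sizes and names are assumptions.

```python
import numpy as np
import tensorflow as tf

tfk = tf.keras
tfkl = tf.keras.layers


class TinyVAE(tfk.Model):
    """Hypothetical fully connected VAE for flattened 28x28 binary images."""

    def __init__(self, latent_dim=2, kl_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.kl_weight = kl_weight
        self.encoder_hidden = tfkl.Dense(64, activation="relu")
        self.encoder_params = tfkl.Dense(2 * latent_dim)  # mean and log-variance
        self.decoder_hidden = tfkl.Dense(64, activation="relu")
        self.decoder_logits = tfkl.Dense(784)  # Bernoulli logits per pixel

    def call(self, x):
        params = self.encoder_params(self.encoder_hidden(x))
        mean, log_var = tf.split(params, 2, axis=-1)
        # Reparameterization trick: z = mean + sigma * eps
        z = mean + tf.exp(0.5 * log_var) * tf.random.normal(tf.shape(mean))
        # Closed-form KL(q(z|x) || N(0, I)), averaged over the mini-batch
        kl = 0.5 * tf.reduce_sum(tf.exp(log_var) + tf.square(mean) - 1.0 - log_var, axis=-1)
        # Registered inside `call`, so it is recomputed on every forward pass
        self.add_loss(self.kl_weight * tf.reduce_mean(kl))
        return self.decoder_logits(self.decoder_hidden(z))


def reconstruction_loss(x_true, logits):
    """Two-argument loss required by compile(): per-image summed cross-entropy."""
    return tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=x_true, logits=logits), axis=-1
    )


x_train = (np.random.default_rng(0).uniform(size=(256, 784)) > 0.5).astype("float32")
model = TinyVAE(latent_dim=2, kl_weight=1.0)
model.compile(optimizer=tfk.optimizers.Adam(1e-3), loss=reconstruction_loss)
history = model.fit(x_train, x_train, epochs=2, batch_size=64, verbose=0)
print(history.history["loss"])
```

During fit, Keras adds the tracked add_loss term to the value returned by reconstruction_loss, which is why the KL term never has to pass through the two-argument loss signature.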


Version f.

I implemented this version last, for practice, imagining that such a framework might come in handy if one day I need to implement a model that’s too complicated for the Keras API. Here’s the code:

I want to emphasize the importance of giving the class method that calls add_loss the name call: this method was originally named encode_and_decode (Line 70), and if I added the @tf.function decorator to the train_step method (Line 94) before renaming it, I got

TypeError: An op outside of the function building code is being passed a "Graph" tensor. It is possible to have Graph tensors leak out of the function building context by including a tf.init_scope in your function building code.

After I renamed the class method accordingly, the error disappeared. The @tf.function decorator matters because it significantly improves training speed: before renaming, when the decorator could not be used, the last epoch ran for 1041 seconds (and the time per epoch kept growing from epoch to epoch; the first epoch "only" took 139 seconds, which is another strange phenomenon); after renaming, every epoch took between 34.7 and 39.4 seconds.

Moreover, before renaming, I had to wrap the argument of add_loss in a zero-argument lambda (Line 78); without the lambda, I got the following error:

ValueError: Expected a symbolic Tensors or a callable for the loss value. Please wrap your loss computation in a zero argument lambda.

In short, when you plan to use the add_loss method to implement a more flexible loss function, a good practice is to subclass the tfk.Model class and call add_loss inside the class method named call.
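The same rule carries over to a custom training loop like Version f., where @tf.function can then wrap the training step. A minimal sketch, again with a hypothetical dense model (DenseVAE) and random binary data rather than the post's convolutional model: the KL term registered by add_loss inside call is read back through model.losses within the same traced step.

```python
import numpy as np
import tensorflow as tf

tfk = tf.keras
tfkl = tf.keras.layers


class DenseVAE(tfk.Model):
    """Hypothetical minimal VAE; the KL term is registered in `call`."""

    def __init__(self, latent_dim=2, kl_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.kl_weight = kl_weight
        self.encoder = tfkl.Dense(2 * latent_dim)  # mean and log-variance
        self.decoder = tfkl.Dense(784)  # Bernoulli logits per pixel

    def call(self, x):
        mean, log_var = tf.split(self.encoder(x), 2, axis=-1)
        z = mean + tf.exp(0.5 * log_var) * tf.random.normal(tf.shape(mean))
        kl = 0.5 * tf.reduce_sum(tf.exp(log_var) + tf.square(mean) - 1.0 - log_var, axis=-1)
        self.add_loss(self.kl_weight * tf.reduce_mean(kl))
        return self.decoder(z)


model = DenseVAE()
optimizer = tfk.optimizers.Adam(1e-3)


@tf.function  # trace the whole step into a TensorFlow graph
def train_step(x):
    with tf.GradientTape() as tape:
        logits = model(x)
        recon = tf.reduce_mean(tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits), axis=-1))
        # model.losses holds the KL term added by add_loss during this forward pass
        loss = recon + sum(model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


x_train = (np.random.default_rng(0).uniform(size=(256, 784)) > 0.5).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(64)  # 4 equal batches
epoch_losses = []
for epoch in range(2):
    batch_losses = [float(train_step(x)) for x in dataset]
    epoch_losses.append(float(np.mean(batch_losses)))
print(epoch_losses)
```

Keeping every batch the same shape means train_step is traced once, so the model and optimizer variables are created only on that first trace.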


Conclusion

In this post, I presented 6 different ways of implementing the VAE model that trains on the MNIST dataset, and gave a detailed review of VAE from a practical standpoint. In general, I would recommend that a first-time VAE implementer start with Version a., and then try the Keras API to simplify the code, using a custom loss function to compute the expected reconstruction error as in Version c. and Version d. After gaining a decent understanding of the model, it would be nice to adopt Version b., which is the most concise and IMHO the most elegant version. If one wants to prepare for implementing very flexible loss functions in the future, Version e. or Version f. is a good place to start.

References

[1] Diederik P. Kingma and Max Welling, Auto-Encoding Variational Bayes (2013), Proceedings of the 2nd International Conference on Learning Representations (ICLR)

[2] Charles Blundell, Julien Cornebise, Koray Kavukcuoglu and Daan Wierstra, Weight Uncertainty in Neural Networks (2015), Proceedings of the 32nd International Conference on Machine Learning (ICML)

[3] Making new Layers & Models via subclassing (2020), TensorFlow Guide

[4] Convolutional Variational Autoencoder (2020), TensorFlow Tutorial

[5] Ian Fischer, Alex Alemi, Joshua V. Dillon, and the TFP Team, Variational Autoencoders with Tensorflow Probability Layers (2019), TensorFlow on Medium
