Enter the j(r)VAE: divide, (rotate), and order… the cards

Thoughts and Theory

Introduction to joint (rotationally-invariant) VAEs that can perform unsupervised classification and disentangle relevant (continuous) factors of variation at the same time.

Maxim Ziatdinov¹ ² & Sergei V. Kalinin¹

¹ Center for Nanophase Materials Sciences and ² Computational Sciences and Engineering Division, Oak Ridge National Laboratory, Oak Ridge, TN 37831, United States

What are joint (rotationally-invariant) variational autoencoders, and why would we need them? The short answer to the first question is that the j(r)VAE is a version of a variational autoencoder in which one of the latent variables is discrete and the rest are continuous. Hence, the _j_VAE seeks both to cluster the data, by associating each observation with a specific discrete class, and to disentangle the data representations within each class via the continuous latent variables. The answer to the second question is that in the experimental world, we often do not know what we are dealing with, but very much want to find out.

In their careers, both authors of this note have focused on nanometer- and atomically-resolved imaging and spectroscopy. M.Z. started his career doing Scanning Tunneling Microscopy and first-principles calculations at the Tokyo Institute of Technology, Japan, exploring the electronic properties of graphene decorated with hydrogen and oxygen atoms. After moving to Oak Ridge, he gradually transitioned to the field of data analysis and applied deep/machine learning (ML/DL), with his first DL papers focusing on the analysis of electron microscopy datasets. S.V.K. first became familiar with Scanning Probe Microscopy as a visiting student at POSTECH, South Korea, in 1996, and has made microscopy a career since 1998, when he became a graduate student at the University of Pennsylvania working with Dawn Bonnell. In 25 years, he has amassed expertise working with variants of scanning probe microscopy, electron microscopy, and chemical imaging. The commonality of all these techniques is their capability to yield massive amounts of data, from polarization distributions in grains of polycrystalline thin films to the electronic structure of high-temperature superconductors to the structural distortions underpinning the physics of the class of Kitaev materials for quantum computing.

As microscopists by training and inclination, we seek to understand the physics and chemistry of materials from these data, either looking for a qualitative understanding of the materials’ structure and functionalities ("Are there dislocations? Do we see evidence of point defects?"), refining quantitative physical models ("How does polarization behave at interfaces? What is the value of the flexoelectric tensor?"), or discovering novel phenomena ("Hmmm…. This is weird…"). However, for many years such analysis was based on human analytics and comprehension: spending innumerable hours with the visualization software of the time, looking for specific objects while varying contrast, filtering, and color schemes in the analysis software. Clearly, this is not the best strategy even for 2D images (even though it used to be worse still in the times of Polaroid cameras connected to the microscopes). For the multidimensional spectroscopic data sets now common in electron and scanning probe microscopy, or for stacks of images such as movies of atomic evolution in electron microscopy, it is simply impossible.

Hence, the question is whether machine learning can help us discover "interesting" objects in data. If we know exactly what we are looking for, i.e. have labels, this becomes a classical supervised learning problem. In 2017, the authors pioneered the use of deep convolutional neural networks (DCNNs) for atom finding in electron microscopy images, enabling the analysis of massive data sets containing hundreds of images, the creation of libraries of atomic defects, the exploration of point defect reactions, and so on [1–3]. In this case, ML effectively performs a human-level task, only faster.

However, in many cases, we aim to look for new and unusual objects and behaviors in our data. For example, we may want to find the constituent building blocks forming the structure of amorphous solids. Alternatively, we may seek to understand how the electronic density patterns on the surface of high-temperature superconductors are organized, and whether there is any underlying order. As a third example, we may aim to discover ferroelectric domain structures from images of the atomic lattice. How should we go about it?

One possible strategy for such analysis is descriptor engineering with subsequent unsupervised learning. We subdivide the image into multiple sub-images and explore the latter. In the absence of any prior knowledge about the system, the sub-images can be picked on a rectangular grid, much as is done in the input convolutional layer of a DCNN. Alternatively, they can be centered on the objects of interest, introducing an inductive bias into the analysis. For atomically resolved data, it is natural to center the descriptors on atomic columns, whereas for ferroelectric domains we can position them on domain walls. Parenthetically, the choice of descriptor positions may become a very interesting research area going forward.

With the descriptors in hand, we may seek to find disparate categories of atomic structures (e.g. discover "molecules") or find the strain states of individual structural units. The first task is done by clustering. Using a suitable algorithm, such as k-means, a Gaussian mixture model, or clustering in a UMAP embedding, we try to separate our descriptors into dissimilar groups. The second can be realized by physics-based analysis if we know exactly what we are looking for. Sometimes we do; however, there is always a risk of finding exactly what you are looking for, whether it is there or not, and of missing novel discoveries. Hence, we would like to use unsupervised ML to find the relevant traits within the data, and that is exactly where variational autoencoders, with their capability to disentangle data representations and find a priori unknown factors of variability, can come in handy.
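As a minimal illustration of the clustering step, here is a bare-bones k-means on descriptor vectors in pure NumPy. This is a sketch, not the code from our notebooks; the two synthetic Gaussian "descriptor" clouds below stand in for flattened sub-images:

```python
import numpy as np

def kmeans(X, k, n_iter=100):
    """Bare-bones k-means on descriptor vectors X of shape (n_samples, n_features)."""
    # deterministic farthest-point initialization
    centers = [X[0]]
    for _ in range(1, k):
        d = np.min([((X - c) ** 2).sum(-1) for c in centers], axis=0)
        centers.append(X[np.argmax(d)])
    centers = np.array(centers)
    for _ in range(n_iter):
        # assign each descriptor to its nearest center
        labels = np.argmin(((X[:, None, :] - centers[None]) ** 2).sum(-1), axis=1)
        new_centers = np.array([X[labels == j].mean(0) if np.any(labels == j)
                                else centers[j] for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels, centers

# two synthetic "descriptor" clouds standing in for flattened sub-images
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.1, (50, 16)), rng.normal(5.0, 0.1, (50, 16))])
labels, _ = kmeans(X, k=2)
```

In practice one would reach for scikit-learn, but the point is that the clustering step itself is conceptually simple; the hard part is choosing descriptors for which Euclidean distance is meaningful.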

However, often we need to do both at the same time! For example, in observations of atomic motion in STEM, we can see not only structural units, such as the 5-, 6-, and 7-member rings of ideal graphene and its defects, but also their deformations under strain. Similarly, we can see beautiful domain structures in ferroelectric materials, and the human brain will naturally classify them into different types and elucidate the deformations within each type. Can VAEs do it?

Of course, one approach is to encode the full data set in a single latent space. In this case, each encoded object corresponds to a single latent vector, and we can simply cluster the points in the latent space. However, the representation of the data (understood as a smooth variation of relevant traits within the data set) will then not be disentangled well; in some sense, class-to-class variation dominates the process. Compare the disentangled data for the simple (r)VAE and the conditional (r)VAE from our previous post – clearly, when the classes are known, we can dig into the data structure much deeper, and the traits become uniform across the full latent space (rather than piecewise).

So, how can we proceed? One way is to make one of the latent directions discrete, an approach we refer to as _j_VAE ('j' for joint). In this case, during training the autoencoder both assigns a discrete class to each observation and finds the optimal continuous latent code, i.e. it performs classification and disentangles the relevant variables at the same time.

However, this is not simple, because sampling from a discrete (categorical) latent variable is not differentiable, so gradients cannot be backpropagated through such samples. One solution is to substitute the non-differentiable sample from a categorical distribution with a differentiable sample from a Gumbel-Softmax distribution [4, 5]. Another solution is to marginalize out the discrete latent variables via full enumeration. This approach is generally associated with high computational costs, although recent probabilistic programming libraries such as Pyro make the task easier by enabling parallel enumeration. We have experimented with both approaches and found that they generally produce similar results on our datasets. For this article, we utilized the approach based on the Gumbel-Softmax trick.
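For intuition, the Gumbel-Softmax trick itself fits in a few lines. The sketch below is a plain NumPy illustration, not our training code; in a real model one would use the built-in `torch.nn.functional.gumbel_softmax`, which also offers a straight-through `hard=True` mode:

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, rng=None):
    """Differentiable relaxation of sampling from a categorical distribution.
    Adds Gumbel(0, 1) noise to the logits and applies a temperature-scaled
    softmax; as tau -> 0 the samples approach one-hot vectors."""
    rng = rng or np.random.default_rng()
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    g = -np.log(-np.log(u))                    # Gumbel(0, 1) noise
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)      # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

# class probabilities (0.1, 0.6, 0.3) expressed as logits
logits = np.log(np.array([0.1, 0.6, 0.3]))
soft = gumbel_softmax(logits, tau=1.0, rng=np.random.default_rng(0))
```

At finite temperature the sample stays differentiable with respect to the logits, which is what lets the discrete latent variable be trained by ordinary backpropagation; the argmax of the noisy logits is nevertheless an exact categorical sample.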

Now, let’s look at examples. Rather than working with the standard MNIST data set, we are going to make our own data set of playing card suits, with monochrome clubs, spades, diamonds, and hearts. We speculate that galaxy images from GalaxyZoo and some biological objects (bacteria, diatoms, viruses) could all be great subjects to explore, both from aesthetic and scientific perspectives; however, cards offer an ideal toy model. Here, diamonds have a shape that is very distinct from the other three suits: their rotation by 90 degrees is equivalent to uniaxial compression and resizing, bringing an interesting degeneracy into the outcomes of possible affine transforms. Similarly, hearts and spades differ only in fairly small details, whereas spades have mirror symmetry and clubs (without the tail) have three-fold symmetry. So, despite their simplicity, card suits form an interesting collection of symbols.

Typical objects of the cards set with different affine transforms. On the left (a = 12, t = 0.1, s = 10), on the right (a = 120, t = 0.1, s = 10). Images by the authors.

First, we create a generator of card symbols with different parameters. We take the digitized symbols from Word as a starting point, transform them into images, and then apply a set of affine transforms including rotation, shear, and translation to prepare our data set. Here, the angles have a uniform distribution from –a to a, the translations have a uniform distribution from –t to t, and the shear has a uniform distribution from –s to s. Examples from two realizations of this data set are shown above. In other words, we have multiple well-defined classes of objects with several continuous traits (sizes and positions). Now, let’s assume that we, much like Jon Snow, know nothing about these objects, and see if machine learning can help to sort out what goes on here!
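A minimal sketch of such a generator is shown below (pure NumPy with nearest-neighbor resampling; the actual data set was produced from Word symbols, so the placeholder square here is purely illustrative):

```python
import numpy as np

def random_affine_params(rng, a, t, s):
    """Sample angle ~ U(-a, a) degrees, per-axis translation ~ U(-t, t)
    (as a fraction of the image size), and shear ~ U(-s, s) degrees."""
    return rng.uniform(-a, a), rng.uniform(-t, t), rng.uniform(-t, t), rng.uniform(-s, s)

def apply_affine(img, angle, ty, tx, shear):
    """Rotate, shear, and translate a 2D image about its center using
    inverse nearest-neighbor mapping (no interpolation)."""
    h, w = img.shape
    th, sh = np.deg2rad(angle), np.deg2rad(shear)
    R = np.array([[np.cos(th), -np.sin(th)], [np.sin(th), np.cos(th)]])  # rotation
    S = np.array([[1.0, np.tan(sh)], [0.0, 1.0]])                        # shear
    Minv = np.linalg.inv(R @ S)
    cy, cx = (h - 1) / 2, (w - 1) / 2
    ys, xs = np.mgrid[0:h, 0:w]
    # map each output pixel back to its source location
    rel = np.stack([ys - cy - ty * h, xs - cx - tx * w])
    src = np.tensordot(Minv, rel, axes=1)
    sy = np.rint(src[0] + cy).astype(int)
    sx = np.rint(src[1] + cx).astype(int)
    valid = (sy >= 0) & (sy < h) & (sx >= 0) & (sx < w)
    out = np.zeros_like(img)
    out[valid] = img[sy[valid], sx[valid]]
    return out

rng = np.random.default_rng(0)
params = random_affine_params(rng, a=120, t=0.1, s=10)
# `symbol` would be a binary image of a card suit; here a placeholder square
symbol = np.zeros((32, 32))
symbol[12:20, 12:20] = 1.0
transformed = apply_affine(symbol, *params)
```

In the accompanying notebook the same idea is realized with standard image libraries; the essential point is that every training image is a known symbol under a random rotation, shear, and translation.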

Principal component analysis on the cards dataset. Images by the authors.

As we pointed out in the previous blog post, we have recently almost stopped doing principal component analysis (PCA). Running a quick PCA on this data set illustrates why: the PCA components are very complex, and there are a lot of them. For sufficiently large data sets, there will be as many significant components as there are pixels in the image. So, not very useful.
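For reference, PCA itself is only a few lines of NumPy via the SVD; the diagnostic to look at is the explained-variance ratio, which for the card images decays slowly across many components (a sketch, with random data standing in for the flattened card images):

```python
import numpy as np

def pca(X, n_components):
    """PCA via SVD: X is (n_samples, n_features), e.g. flattened images."""
    Xc = X - X.mean(axis=0)                 # center each pixel/feature
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]          # principal directions (rows)
    explained_var = S ** 2 / (len(X) - 1)   # variance along each direction
    ratio = explained_var / explained_var.sum()
    return Xc @ components.T, components, ratio[:n_components]

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 64))              # stand-in for 200 flattened 8x8 images
scores, comps, ratio = pca(X, n_components=10)
```

When `ratio` has no sharp elbow, as happens here, a linear decomposition is not capturing the data's structure compactly, which is what motivates nonlinear methods such as VAEs.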

We can run the code from the previous post and quickly look at the VAE analysis on this data set. Shown below are corresponding latent spaces and latent representations.

VAE analysis on the cards data set for t = 0.1 and varying orientation and shear strength. From left to right: (a = 12, s = 1), (a = 12, s = 10), (a = 120, s = 1), and (a = 120, s = 10). The color scheme represents the suit. Images by the authors.

For small rotational disorder (a = 12), the latent space clearly forms well-separated clusters corresponding to each suit, with the rotation angles and shear components changing smoothly within the corresponding part of the latent space. At the same time, for large a = 120, the distribution in the latent space becomes very complex. We ascribe this behavior to the fact that the latent space representation must contain a "cowlick" defect with unphysical object shapes, because the unit circle (the SO(2) Lie group) cannot be faithfully represented on a flat surface. Given that we have several structures, the VAE tries to orient the objects by angle and, at the same time, separate them by shape and size, giving rise to a very complex structure.

rVAE analysis on the cards data set for t = 0.1 and varying orientation and shear strength. From left to right: (a = 12, s = 1), (a = 12, s = 10), (a = 120, s = 1), and (a = 120, s = 10). The color scheme represents the suit. Images by the authors.

In comparison, shown above is the _r_VAE analysis of the same data. In this case, the rotational angle is separated out as an independent variable by the construction of the latent vector. The emergent behaviors are truly remarkable. For small angle and shear disorder, a = 12, s = 1, the suits form four very distinct groups in the latent space. When the shear disorder is increased, a = 12, s = 10, these clusters elongate and form clearly visible 1D manifolds in the latent space. When we increase the orientation disorder to a = 120, we start to see more than one cluster per suit, e.g. there are clusters corresponding to orientations of clubs and spades rotated by 120 degrees. In other words, the _r_VAE with the chosen priors can compensate for small deviations of the rotational angle but tends to discover the degenerate minima. As a simple exercise to understand this behavior, consider the structural similarity (or correlation) between club images as they are rotated about the central axis. The correlation will have pronounced maxima at 120 and 240 degrees, where the main leaves overlap and only the position of the small tail differs. Here, for small shear distortion, s = 1, we have well-defined clusters corresponding to the rotational variants of the suits, and for large shear distortion, s = 10, these again become elongated manifolds. Note that in some sense we have disentangled the representation of the data, which in this case is the shear. Furthermore, it is curious to compare the shapes of the manifolds for diamonds and for the other suits: the former are more elongated and diffuse precisely because of the affine transform that is equivalent to a 90-degree rotation!
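This angular-correlation argument can be checked numerically with a toy model. Below, a crude angular-density "club" (three identical bumps 120 degrees apart plus a weaker tail) is correlated with rotated copies of itself; this is an illustrative stand-in, not the actual card images:

```python
import numpy as np

theta = np.linspace(0, 2 * np.pi, 360, endpoint=False)  # 1 bin = 1 degree

def blob(center, width=0.15):
    """Angular Gaussian bump at the given angle (radians)."""
    d = np.angle(np.exp(1j * (theta - center)))  # wrapped angular distance
    return np.exp(-d ** 2 / (2 * width ** 2))

# toy "club": three identical leaves 120 degrees apart plus a small tail
club = (blob(np.pi / 2) + blob(np.pi / 2 + 2 * np.pi / 3)
        + blob(np.pi / 2 + 4 * np.pi / 3) + 0.3 * blob(3 * np.pi / 2, width=0.08))

def correlation(shift_deg):
    """Normalized correlation between the shape and its rotated copy."""
    rotated = np.roll(club, shift_deg)
    return np.corrcoef(club, rotated)[0, 1]

corr = np.array([correlation(d) for d in range(360)])
```

The resulting curve has its global maximum at 0 degrees and pronounced secondary maxima near 120 and 240 degrees, where the leaves realign and only the tail differs, which is exactly the degeneracy the _r_VAE latches onto.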

jVAE analysis for a = 12, t = 0.1, s = 10. Images by the authors.

Now, enter the _j_VAE. You are welcome to play with the different model parameters in the accompanying notebook; here we are only going to show two examples. For small disorder, a = 12 and s = 10, the simple _j_VAE (without a rotational latent variable) does an excellent job of separating the objects in an unsupervised manner. Notice how the first latent manifold traversal captures information about the orientational disorder (plus shear deformation), whereas the second captures off-center displacements (plus shear deformation).

jVAE analysis for a = 120, t = 0.1, s = 10. Images by the authors.

For large disorder, the simple _j_VAE fails to separate the classes. As one can see from the figure above, each inferred "class" contains a mixture of diamonds, spades, clubs, and hearts. We note that, unlike the (r)VAE, the j(r)VAE tends to be quite sensitive to the initialization of the model parameters (i.e., the initialization of the weights in the encoder and decoder neural networks and the shuffling of the training set), although at a qualitative level the results remain largely unchanged. Feel free to experiment with different seeds in the notebook and see how they affect the results.

jrVAE analysis for a = 120, t = 0.1, s = 10. Images by the authors.

In comparison, the _jr_VAE, which has three "special" continuous latent variables designated to absorb the rotational and translational disorder, shows much better separation for the same disorder parameters. Here, both clubs and diamonds got their own classes, whereas hearts and spades share another (we would like to add that simply adding three or more standard continuous latent variables to the _j_VAE does not resolve its inability to separate classes at large disorder).

Encoded latent space for the jVAE (left) and jrVAE (right) for a = 120, t = 0.1, s = 10. Images by the authors.

Finally, let’s see how well our _jr_VAE classes correspond to the original ones. Note that the order of the inferred classes is random (it depends on the initialization) and does not correspond to the order of the ground-truth labels. We can see that, overall, it did quite a remarkable job of separating the different classes without any supervision.

Correlation matrix (inferred vs. ground truth) for jrVAE for a = 120, t = 0.1, s = 10. Image by the authors.
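Because the inferred class indices are arbitrary, a common step before computing such a matrix is to relabel the inferred classes by the permutation that best agrees with the ground truth. A sketch (brute force over permutations is fine for four classes; for many classes one would use the Hungarian algorithm, e.g. `scipy.optimize.linear_sum_assignment`):

```python
import numpy as np
from itertools import permutations

def best_relabeling(y_true, y_pred, n_classes):
    """Map arbitrary inferred class indices onto ground-truth ones by brute force."""
    # contingency[i, j] = number of samples with true class i and inferred class j
    contingency = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        contingency[t, p] += 1
    best = max(permutations(range(n_classes)),
               key=lambda perm: sum(contingency[i, perm[i]] for i in range(n_classes)))
    return {best[i]: i for i in range(n_classes)}  # inferred index -> true index

y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
y_pred = np.array([2, 2, 3, 3, 0, 0, 1, 1])   # same grouping, shuffled class indices
mapping = best_relabeling(y_true, y_pred, n_classes=4)
relabeled = np.array([mapping[p] for p in y_pred])
accuracy = (relabeled == y_true).mean()
```

After this relabeling, the matrix of inferred versus ground-truth classes becomes strongly diagonal whenever the unsupervised grouping matches the true one.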

Note that the performance can be fine-tuned further; however, in this case, we would be tuning it towards the "known" answer for this particular system, whereas unknown systems are, well, unknown.

This concludes the introduction of the _j_VAE and _jr_VAE. Feel free to play with the notebooks and apply them to your own data sets. The authors use the VAE and its extensions in their research on atomically resolved and mesoscopic imaging in scanning probe and electron microscopies, but these methods can be applied to a much broader variety of optical, chemical, and other imaging modalities, as well as in other computer science domains. Please also check out our AtomAI software package for applying this and other deep/machine learning tools to scientific imaging.

Finally, in the scientific world, we acknowledge the sponsor that funded this research. This effort was performed and supported at Oak Ridge National Laboratory’s Center for Nanophase Materials Sciences (CNMS), a U.S. Department of Energy, Office of Science User Facility. You can take a virtual walk through it using this link and tell us if you want to know more.

The executable Google Colab notebook is available here

References

  1. Ziatdinov, M.; Dyck, O.; Maksov, A.; Li, X. F.; San, X. H.; Xiao, K.; Unocic, R. R.; Vasudevan, R.; Jesse, S.; Kalinin, S. V., Deep Learning of Atomically Resolved Scanning Transmission Electron Microscopy Images: Chemical Identification and Tracking Local Transformations. ACS Nano 2017, 11 (12), 12742–12752.
  2. Maksov, A.; Dyck, O.; Wang, K.; Xiao, K.; Geohegan, D. B.; Sumpter, B. G.; Vasudevan, R. K.; Jesse, S.; Kalinin, S. V.; Ziatdinov, M., Deep learning analysis of defect and phase evolution during electron beam-induced transformations in WS2. npj Comput. Mater. 2019, 5, 12.
  3. Ziatdinov, M.; Dyck, O.; Li, X.; Sumpter, B. G.; Jesse, S.; Vasudevan, R. K.; Kalinin, S. V., Building and exploring libraries of atomic defects in graphene: Scanning transmission electron and scanning tunneling microscopy study. Sci. Adv. 2019, 5 (9), eaaw8989.
  4. Jang, E.; Gu, S.; Poole, B., Categorical reparameterization with Gumbel-Softmax. arXiv preprint arXiv:1611.01144 2016.
  5. Dupont, E., Learning Disentangled Continuous and Discrete Representations. arXiv preprint arXiv:1804.00104 2018.
