
Hierarchical Transformers – part 1

More efficient language models

Image from unsplash.com

In this article, we look at hierarchical transformers: what they are, how they work, how they differ from standard Transformers, and what their benefits are. Let’s get started.

What Are Hierarchical Transformers?

The "hierarchical Transformer" refers to a transformer architecture that operates on multiple scales or resolutions of the input sequence.

Why do we need hierarchical transformers?

Standard Transformers, as amazing as they are, are very time-consuming. The attention mechanism inside Transformers takes O(n²) time on an input sequence of n tokens, which makes them impractical for long sequences. One way to address this inefficiency is hierarchical transformers. Is it the only way? No! Another way is to improve the efficiency of the attention mechanism itself. But that is a topic for another day.

How does hierarchy in Transformers help?

Hierarchical Transformers enable the model to operate on different levels of the input, e.g. words, sentences, paragraphs, etc. This matches how humans process text too. Attention over these different hierarchies lets the model capture relationships between entities at different granularities.

There are many methods for hierarchical transformers; in this article, we strive to intuitively explain one of these methods.

The Hourglass Transformer

The Hourglass [1] network is joint work by OpenAI, Google Research, and the University of Warsaw. It is a hierarchical autoregressive Transformer that takes an input sequence and forms a hierarchy of sequences, from full resolution down to smaller and smaller scales; at each scale it processes the sequence at that resolution, and finally it expands the sequence back to full size. This makes the model more efficient, as the shorter sequences are cheaper to process.

Note that in autoregressive Transformers, at the very least the first and last layers have to operate at the full scale of the input. The first layer processes the raw input, so it has to operate at full scale; and since the model is autoregressive, the last layer produces an output for every position, so it has to operate at full scale again.

Let’s take a look at this architecture. The image below demonstrates Hourglass:

"Hourglass" architecture - image taken from [1]
"Hourglass" architecture – image taken from [1]

We describe it step by step, starting from the left of the image above, where the input tokens are depicted as a grey box.

  1. The model first processes the full input sequence using standard Transformer layers. So if the input sequence (shown as the grey box titled "input tokens") has L tokens, they are all passed through standard Transformer layers (depicted in blue and called pre-vanilla layers). These layers output L embedding vectors, one for each token.

The pre-vanilla layers are Transformer layers operating on the full token-level sequence before shortening.

So if the task is "language modeling on text", the input to pre-vanilla layers is a sequence of subword tokens representing the text. If the task is "image generation", the input would be pixel values flattened into a sequence.

Step 1 – hourglass architecture – Image from [1] modified by the author
  2. In the second step, the model shortens the sequence from L tokens to fewer tokens. This step is shown as an orange trapezoid, and the shortening factor is stated as sf = k₁. Note that sf stands for "shortening factor", and setting it to k₁ means every k₁ tokens are merged into 1 token. The shortening happens via some sort of pooling operation, such as average pooling, linear pooling, or attention pooling. We will talk about them soon. The output of this step is L/k₁ tokens. This step is also known as the down-sampling step.
Step 2 – hourglass architecture – Image from [1] modified by the author
  3. The shortened sequence is processed by more Transformer layers, called shortened layers. In the image they are shown as yellow boxes. These layers output updated embeddings for the tokens.
Step 3 – hourglass architecture – Image from [1] modified by the author
  4. If there is more shortening left to do, we simply repeat the process. In the image below, the second orange trapezoid shortens the sequence by a factor of sf = k₂. This merges every k₂ tokens into 1 token, and so outputs L/(k₁·k₂) tokens. The output tokens are passed through more shortened layers, shown in light yellow.
Step 4 – hourglass architecture – Image from [1] modified by the author

By now, we are in the middle of the architecture…

From here, the up-sampling starts!

  5. Up-sampling layers are used to expand the shortest sequence back to the original full resolution. Since we had two down-samplings (one from L tokens to L/k₁ tokens, and a second from L/k₁ tokens to L/(k₁·k₂) tokens), we perform two up-samplings to bring the number of tokens back to L. The first up-sampling is the following:
Step 5 (first upsampling) – hourglass architecture – Image from [1] modified by the author

And the second upsampling is the following:

Step 5 (second upsampling) – hourglass architecture – Image from [1] modified by the author

After every up-sampling operation, we pass the token embeddings through Transformer layers. In the image they are called either shortened layers or post-vanilla layers.

Step 5 – upsampling involves Transformer layers, either as shortened layers or as post-vanilla layers – Image from [1] modified by the author

The last up-sampling passes the embeddings through the post-vanilla layers, which output the embedding of the next predicted token.
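To make the data flow concrete, here is a minimal PyTorch sketch of one shorten-process-expand cycle. This is not the authors' implementation: the module names (pre_vanilla, shortened, post_vanilla), layer counts, and dimensions are illustrative; it uses average-pooling shortening with repeat up-sampling (both described in the sections below); and it omits the causal masking and input shifting that a real autoregressive model needs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HourglassBlock(nn.Module):
    """Sketch of one shorten -> process -> up-sample cycle with a residual connection."""

    def __init__(self, d_model: int = 256, n_heads: int = 4, sf: int = 3):
        super().__init__()
        make_layer = lambda: nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.pre_vanilla = nn.TransformerEncoder(make_layer(), num_layers=2)   # full resolution
        self.shortened = nn.TransformerEncoder(make_layer(), num_layers=4)     # reduced resolution
        self.post_vanilla = nn.TransformerEncoder(make_layer(), num_layers=2)  # full resolution
        self.sf = sf  # shortening factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:       # x: (batch, L, d_model)
        x = self.pre_vanilla(x)                                # process the full-length sequence
        residual = x                                           # saved for the residual connection
        L = x.size(1)
        pad = (-L) % self.sf                                   # pad so the length divides by sf
        x = F.pad(x, (0, 0, 0, pad))
        short = x.reshape(x.size(0), -1, self.sf, x.size(-1)).mean(dim=2)  # average-pooling shortening
        short = self.shortened(short)                          # cheaper: roughly L/sf tokens
        up = short.repeat_interleave(self.sf, dim=1)[:, :L]    # repeat up-sampling back to length L
        return self.post_vanilla(up + residual)                # residual connection + post-vanilla layers

block = HourglassBlock()
tokens = torch.randn(1, 10, 256)       # a hypothetical batch with 10 token embeddings
print(block(tokens).shape)             # torch.Size([1, 10, 256])
```

Stacking such cycles, or nesting one inside another at the middle, gives the multi-level hierarchy shown in the figures above.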

Down-sampling Step

The downsampling step (also known as the shortening step in the paper) shortens the input sequence to fewer tokens. This is done by merging tokens into groups using a pooling operation such as: 1) average pooling, 2) linear pooling, or 3) attention pooling.

1) Average pooling: On a high level, average pooling merges k adjacent token embeddings into a single embedding by taking their average. This method has two hyperparameters: "pool size" and "stride".

"Pool size" is the size of the window, and "stride" is how many steps the window size would move forward each time. For example in a sequence of "ABCDEF" and with pool size = stride = 2, first two tokens make up the first window and the window moves forward 2 tokens at a time. So the windows will be: [AB], [CD], [EF].

The paper sets both "pool size" and "stride" to the same number and calls it the "shortening factor (sf)". Let’s see this in an example:

If the input sequence is [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10] and the hyperparameters are pool size = stride = 3, then average pooling divides the sequence into chunks of size 3, i.e. [x1, x2, x3], [x4, x5, x6], [x7, x8, x9], [x10], and averages the token embeddings in each window to get a single embedding:

e1 = mean(x1, x2, x3)

e2 = mean(x4, x5, x6)

e3 = mean(x7, x8, x9)

e4 = x10

Therefore the shortened sequence will be [e1, e2, e3, e4]. Note the length of the shortened sequence is ⌈input length / sf⌉ = ⌈10/3⌉ = 4.
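Here is a minimal sketch of this average-pooling shortening in PyTorch; the function name is mine, and the trailing partial window is averaged on its own, matching the worked example above.

```python
import torch

def average_pool_shorten(x: torch.Tensor, sf: int) -> torch.Tensor:
    """Shorten (L, d) token embeddings to (ceil(L / sf), d) by averaging windows of sf tokens."""
    windows = x.split(sf, dim=0)                    # [x1..x3], [x4..x6], [x7..x9], [x10]
    return torch.stack([w.mean(dim=0) for w in windows])

x = torch.randn(10, 100)                            # 10 tokens, 100-dim embeddings
print(average_pool_shorten(x, sf=3).shape)          # torch.Size([4, 100])
```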

2) Linear pooling:

This method sets a stride k and divides the input sequence of L tokens into L/k windows. Each window has k tokens, where each token has an embedding vector of some dimension d. The method then flattens each window into a k·d-dimensional vector and forms the following matrix:

Linear pooling – part 1 – image by the author

Let’s say the input sequence is [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10] and stride = 3. Then we have the following windows: [x1, x2, x3], [x4, x5, x6], [x7, x8, x9], [x10], and if every token has a 100-dimensional embedding vector, the above matrix becomes:

Linear pooling – part 2 – image by the author

Note that we have now reduced the length of the sequence from 10 to 4, but each new token has dimensionality 300 instead of 100! To bring it back to the original dimensionality, linear pooling projects each flattened window to a 100-dimensional space using a learned linear transformation, i.e. a 300×100 matrix that is learned from the data.
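A minimal sketch of linear pooling under the same assumptions (PyTorch, d = 100, k = 3); the class name and the zero-padding of the last partial window are my own choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearPool(nn.Module):
    def __init__(self, d: int, k: int):
        super().__init__()
        self.k = k
        self.proj = nn.Linear(k * d, d)                    # learned (k*d) -> d map, i.e. 300 -> 100 here

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (L, d)
        L, d = x.shape
        x = F.pad(x, (0, 0, 0, (-L) % self.k))            # zero-pad so L divides by k
        windows = x.reshape(-1, self.k * d)                # flatten each window into a k*d vector
        return self.proj(windows)                          # (ceil(L/k), d)

pool = LinearPool(d=100, k=3)
print(pool(torch.randn(10, 100)).shape)                    # torch.Size([4, 100])
```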

3) Attention pooling:

This method starts similarly to the two methods above: the input sequence is divided into windows of size k, and attention is applied within each window, which allows tokens inside a window to attend to each other. The embeddings produced by the attention in each window are then summed together, and finally a feedforward layer is applied to the resulting chunk embeddings.

Attention pooling – image by the author
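The following is a minimal sketch of attention pooling as described above; the number of heads, the feed-forward width, and the zero-padding of the last partial window are my assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPool(nn.Module):
    def __init__(self, d: int, k: int, n_heads: int = 4):
        super().__init__()
        self.k = k
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:       # x: (L, d)
        L, d = x.shape
        x = F.pad(x, (0, 0, 0, (-L) % self.k))                 # zero-pad so L divides by k
        windows = x.reshape(-1, self.k, d)                     # (num_windows, k, d)
        attended, _ = self.attn(windows, windows, windows)     # attention only within each window
        pooled = attended.sum(dim=1)                           # sum the attended embeddings per window
        return self.ffn(pooled)                                # feedforward on the chunk embeddings

pool = AttentionPool(d=100, k=3)
print(pool(torch.randn(10, 100)).shape)                        # torch.Size([4, 100])
```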

Up-sampling Step

The up-sampling step expands the shortened sequence back to the original full length. There are two simple ways to do upsampling:

  1. Repeat expansion: simply copies each shortened embedding k times (for a shortening factor of k). This is computationally very efficient.
  2. Linear expansion: each embedding in the shortened sequence is linearly projected to a vector of size k * d, where d is the original embedding dimension, and that vector is then split into k embeddings of size d. For example, if the shortened sequence is [e1, e2, e3, e4] and sf = k = 3, each embedding yields 3 token embeddings, restoring the original length of 12. The projection weight matrix is learnable and is trained end-to-end along with the full model. A small sketch of both variants follows this list.
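Here is a minimal sketch of the two simple up-sampling variants, assuming PyTorch; the names are illustrative, and linear expansion is implemented as a learned d -> k*d projection whose output is split back into k token embeddings.

```python
import torch
import torch.nn as nn

def repeat_upsample(e: torch.Tensor, k: int) -> torch.Tensor:
    """Repeat expansion: copy each shortened embedding k times."""
    return e.repeat_interleave(k, dim=0)                  # (L/k, d) -> (L, d)

class LinearUpsample(nn.Module):
    """Linear expansion: project each embedding to k*d dims, then split into k tokens."""
    def __init__(self, d: int, k: int):
        super().__init__()
        self.proj = nn.Linear(d, k * d)                   # learnable, trained end-to-end

    def forward(self, e: torch.Tensor) -> torch.Tensor:   # e: (L/k, d)
        return self.proj(e).reshape(-1, e.size(-1))       # (L/k, k*d) -> (L, d)

e = torch.randn(4, 100)                                   # shortened sequence [e1..e4], d = 100
print(repeat_upsample(e, 3).shape)                        # torch.Size([12, 100])
print(LinearUpsample(100, 3)(e).shape)                    # torch.Size([12, 100])
```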

To maintain fidelity to the original input sequence, a residual connection (shown as a red dotted line) adds the input sequence from before shortening to the upsampled sequence. Think of this as a way to carry context through the multiple shorten-expand cycles.

Residual connections added for maintaining fidelity – image from [1] modified by the author

There is a more advanced upsampling method, called attention upsampling, which works as follows:

If the shortened sequence is [e1, e2, e3, e4] and sf = k = 3, then a linear or repeat upsampling is first applied to expand it to the original length. This gives [u1, u2, …, u12].

Let the embeddings before shortening be [x1, x2, …, x12]; these are added to the upsampled embeddings via the residual connection (the red dotted lines) to form [u1+x1, …, u12+x12]. Now the self-attention mechanism is applied over this sequence, where:

  • Queries (Q) come from the summed embeddings [u1+x1, …, u12+x12].
  • Keys (K) and Values (V) come from the upsampled embeddings [u1, u2, …, u12].

This attention updates the summed embeddings, and the result is the final output. Attending over the upsampled sequence helps amplify the relevant parts and combine them with the pre-shortened context.
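A minimal sketch of this attention up-sampling, assuming PyTorch, repeat up-sampling for the initial expansion, and a plain multi-head attention layer; the class name and the omitted feed-forward and normalization details are my simplifications.

```python
import torch
import torch.nn as nn

class AttentionUpsample(nn.Module):
    def __init__(self, d: int, k: int, n_heads: int = 4):
        super().__init__()
        self.k = k
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)

    def forward(self, short: torch.Tensor, pre_shorten: torch.Tensor) -> torch.Tensor:
        # short: (L/k, d) shortened embeddings; pre_shorten: (L, d) embeddings from before shortening
        up = short.repeat_interleave(self.k, dim=0)[: pre_shorten.size(0)]  # [u1 .. uL]
        summed = (up + pre_shorten).unsqueeze(0)       # residual connection: [u1+x1, ..., uL+xL]
        out, _ = self.attn(summed,                     # queries from the summed embeddings
                           up.unsqueeze(0),            # keys from the up-sampled embeddings
                           up.unsqueeze(0))            # values from the up-sampled embeddings
        return out.squeeze(0)                          # updated full-length embeddings

ups = AttentionUpsample(d=100, k=3)
print(ups(torch.randn(4, 100), torch.randn(12, 100)).shape)   # torch.Size([12, 100])
```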

Experiments

The authors [1] evaluated their model on language modeling using enwik8 and on image generation using ImageNet-32/64. They showed that perplexity on enwik8 improved by 10–15% over a Transformer-XL baseline, and they achieved a new state of the art for autoregressive Transformer models on ImageNet-32 image generation.

Parameters of the experiment – image by the author

This concludes the Hourglass network. In the next article, we will look into other hierarchical Transformer models.

Summary

In this article we reviewed a hierarchical architecture for Transformers that improves efficiency and reduces memory usage when processing long sequences. The architecture is called Hourglass [1] and consists of two major components: 1) shortening (downsampling) and 2) upsampling. Shortening is done by merging tokens into groups using pooling operations such as average pooling or linear pooling; the sequence length is reduced by a shortening factor k in the middle layers of the network. The upsampling component uses methods like linear upsampling or attention upsampling to expand the shortened sequences back to the original length. Hourglass models improve perplexity compared to baseline Transformers like Transformer-XL [1]. In fact, Hourglass achieves a new SOTA for Transformer models on the ImageNet-32 image generation task.


If you have any questions or suggestions, feel free to reach out to me: Email: [email protected] LinkedIn: https://www.linkedin.com/in/minaghashami/

References

  1. Hierarchical Transformers Are More Efficient Language Models
  2. An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification
