Understanding Google’s Switch Transformer

How Google created the world’s largest language model without the cost

Jonathan Davis
Towards Data Science


Photo by Jonathan on Unsplash

When OpenAI introduced GPT-3 in May 2020, the news spread like wildfire, not only within the AI community but also in the mainstream media, with headlines like “A robot wrote this article” and “Have you read something written by GPT-3?”. People were excited!

Before GPT-3, the largest language model was Turing-NLG with 17 billion parameters, released in February 2020. Later that year, OpenAI blew this out of the park with 175 billion parameters. Suddenly, there was a language model that could produce text that was often indistinguishable from human writing.

At the start of 2021, Google released a paper titled “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”. Although it received some press coverage outside the AI community, it got nowhere near the attention GPT-3 did, despite describing a model with almost ten times as many parameters!

This lack of coverage was not a surprise. Google’s 1.6 trillion parameter model was not state-of-the-art in terms of performance, nor did it have a neat API front end for demoing the future of AI. However, that does not mean that the motivation or results are not significant for future AI and machine learning research.

In this article we will try to understand the motivation behind the Switch Transformer, how it works, and why the results are so significant for the future of machine learning research and application.

Motivation

It has been shown empirically that the performance of language models (measured by test loss) improves as a power law in the number of parameters (model size), the dataset size and the computational budget.

However, as these increase, so does the financial cost of training. This has led to the increased popularity of open-source, pre-trained language models such as Google’s BERT, which can be fine-tuned on specific downstream language tasks such as classification, question answering or summarisation. This allows data science practitioners to benefit from the vast resources at the disposal of companies like Google, Facebook and OpenAI.

However, GPT-3 suggested that even large corporations and organisations may start to struggle if these models continue to grow. GPT-3 reportedly cost $12 million to train. This included training close to 5,000 versions of the model, using almost 10,000 days of GPU time.

It is also worth noting that the costs are not only economic. This training is estimated to have produced 78,000 pounds of CO2 emissions, similar to what an American adult produces in two years!

Photo by İsmail Enes Ayhan on Unsplash

Research by OpenAI suggests that the amount of compute used in the largest AI training runs has been doubling every 3.4 months, which produces worrying forecasts for both the economy and the environment.

This is the motivation behind the Switch Transformer: to create larger machine learning models without a corresponding increase in computational cost.

Switch Transformer

The Switch Transformer replaces the standard feed-forward neural network (FFN) layer in the transformer architecture with a switch FFN layer. The key difference is that instead of containing a single FFN, each switch layer contains multiple FFNs, known as experts.

When each token passes through this layer, it first passes through a router function, which then routes the token to a specific FFN expert.
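A minimal NumPy sketch of this routing step is shown below. It assumes a softmax router built from a single weight matrix, ReLU expert FFNs without biases, and top-1 selection, with the chosen expert's output scaled by its router probability as the paper describes. The names `switch_ffn`, `router_w` and `experts` are hypothetical, and expert capacity limits are ignored here.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def switch_ffn(tokens, router_w, experts):
    """Top-1 (switch) routing sketch.

    tokens:   (num_tokens, d_model) token representations
    router_w: (d_model, num_experts) hypothetical router weight matrix
    experts:  list of (w_in, w_out) weight pairs, one per expert FFN
    """
    probs = softmax(tokens @ router_w)                # router probabilities per token
    expert_idx = probs.argmax(axis=-1)                # each token picks exactly one expert
    gate = probs[np.arange(len(tokens)), expert_idx]  # probability of the chosen expert
    out = np.zeros_like(tokens)
    for e, (w_in, w_out) in enumerate(experts):
        mask = expert_idx == e
        if not mask.any():
            continue                                   # this expert received no tokens
        hidden = np.maximum(tokens[mask] @ w_in, 0.0)  # ReLU feed-forward expert
        # Scaling by the gate value is what lets a trainable router receive gradients.
        out[mask] = gate[mask][:, None] * (hidden @ w_out)
    return out, expert_idx


rng = np.random.default_rng(0)
d_model, d_ff, num_experts, num_tokens = 8, 32, 4, 10
tokens = rng.normal(size=(num_tokens, d_model))
router_w = rng.normal(size=(d_model, num_experts))
experts = [
    (rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model)))
    for _ in range(num_experts)
]
out, expert_idx = switch_ffn(tokens, router_w, experts)
print(out.shape, expert_idx)  # one expert index per token
```

Each token touches only one expert's weights, even though all four experts' weights exist in the layer.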

As each token only passes through one expert FFN, the number of floating-point operations (FLOPs) per token stays the same as for a single FFN, whilst the number of parameters increases with the number of experts.
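The trade-off can be checked with back-of-the-envelope arithmetic. The sketch below assumes each expert is a two-matrix FFN without biases, counts a matrix multiply as 2 × (inputs × outputs) FLOPs, and ignores the small router. `ffn_cost` is a hypothetical helper, and the 768/3072 dimensions are only illustrative (they match T5-Base's d_model and d_ff).

```python
def ffn_cost(d_model, d_ff, num_experts=1):
    """Return (parameters, FLOPs per token) for a layer of identical expert FFNs.

    Biases and the router are ignored to keep the arithmetic simple.
    """
    # Each expert owns its own up-projection and down-projection matrices.
    params = num_experts * 2 * d_model * d_ff
    # A token visits exactly one expert: two matmuls at 2 FLOPs per multiply-add.
    flops_per_token = 2 * (2 * d_model * d_ff)
    return params, flops_per_token


for num_experts in (1, 8, 64):
    params, flops = ffn_cost(768, 3072, num_experts)
    print(f"{num_experts:>2} experts: {params:>11,} params, {flops:,} FLOPs/token")
```

The parameter count grows linearly with the number of experts (64 experts hold 64 times the weights), while the FLOPs per token do not change.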

This creates a sparse model, where not every parameter is utilised for every token. In this respect, the Switch Transformer addresses the motivation above: increasing the number of model parameters without increasing the amount of computation, measured in FLOPs.

In a Switch Transformer feed-forward neural network layer, each token passes through a router function which sends it to a single feed-forward neural network, known as an expert. As each token only passes through a single FFN, the computation does not increase, but the number of parameters increases with the number of experts. Image from the original Switch Transformer paper.

Mixture-of-Experts

The concept of using experts to increase the number of model parameters was not novel to the Switch Transformer. A paper describing the Mixture-of-Experts layer was released in 2017, with an almost identical architecture to the Switch Transformer.

The key difference was that the router function would send each token to more than one FFN expert. The authors hypothesised that the router would not be able to learn how to route tokens unless it could compare at least two experts, i.e. route each token to k > 1 experts.
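The difference between k > 1 and k = 1 routing can be sketched in NumPy. This follows the softmax-over-the-top-k-logits gating of the 2017 Mixture-of-Experts layer but omits the noise term that paper adds to the router logits; `top_k_routing` is a hypothetical helper.

```python
import numpy as np


def top_k_routing(logits, k):
    """Pick the k highest-scoring experts per token and renormalise their gates.

    logits: (num_tokens, num_experts) router scores
    Returns (expert indices, gate weights), both of shape (num_tokens, k).
    """
    # Sort descending and keep the first k experts for every token.
    idx = np.argsort(-logits, axis=-1)[:, :k]
    chosen = np.take_along_axis(logits, idx, axis=-1)
    # Softmax over the k chosen logits only (the others are treated as -inf).
    chosen = chosen - chosen.max(axis=-1, keepdims=True)
    weights = np.exp(chosen)
    weights /= weights.sum(axis=-1, keepdims=True)
    return idx, weights


rng = np.random.default_rng(1)
logits = rng.normal(size=(6, 4))  # 6 tokens, 4 experts
for k in (1, 2):
    idx, weights = top_k_routing(logits, k)
    # Each token is dispatched k times, so k = 2 doubles the expert workload.
    print(f"k={k}: {idx.size} token-expert assignments")
```

With k = 1 the renormalised gate here is always 1.0. The Switch Transformer instead scales the expert output by that expert's probability from a softmax over all experts, so the router still receives a useful training signal.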

On the other hand, the Switch Transformer only uses k=1, providing three key benefits:

  1. Router computation is reduced as it is only routing to a single FFN expert.
  2. The batch size of each expert is at least halved (since k = 1 instead of k ≥ 2).
  3. The communication costs between devices (i.e. router to expert) are reduced.

Capacity Factor

The second of these benefits requires further analysis. The batch size of each expert, also known as the expert capacity, is the number of tokens that the expert can process in a given step.

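Ideally, the expert capacity would equal the number of tokens in the batch divided by the number of experts, which corresponds to a capacity factor of 1. This way, no expert capacity is wasted in a given step. More generally, expert capacity = (tokens per batch / number of experts) × capacity factor.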

However, this assumes that the router function allocates tokens equally across the experts. In practice, some experts will overflow, and the overflowing tokens are not processed by any expert in that step; instead, they are passed directly to the next layer through the residual connection.
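The capacity calculation, and the overflow it can cause, can be sketched with a toy routing assignment. The formula follows the paper (tokens per batch ÷ number of experts × capacity factor); rounding up to a whole number of slots and the helper names `expert_capacity` and `routing_stats` are assumptions for illustration.

```python
import math

import numpy as np


def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    # Slots per expert; rounding up to a whole token is an assumption here.
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)


def routing_stats(expert_idx, num_experts, capacity):
    """Return (dropped tokens, unused expert slots) for one routing step."""
    counts = np.bincount(expert_idx, minlength=num_experts)
    dropped = int(np.maximum(counts - capacity, 0).sum())  # overflowed tokens
    unused = int(np.maximum(capacity - counts, 0).sum())   # wasted slots
    return dropped, unused


# Toy, imbalanced routing of 8 tokens over 4 experts: counts are [3, 2, 1, 2].
expert_idx = np.array([0, 0, 0, 1, 1, 2, 3, 3])
for cf in (1.0, 1.25, 2.0):
    cap = expert_capacity(len(expert_idx), 4, cf)
    dropped, unused = routing_stats(expert_idx, 4, cap)
    print(f"capacity factor {cf}: capacity {cap}, dropped {dropped}, unused {unused}")
```

With a capacity factor of 1.0, one token overflows expert 0; raising it to 1.25 removes the overflow but leaves four slots unused, which is exactly the trade-off between dropped tokens and wasted computation.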

In order to avoid overflow, resulting in tokens not being processed by any FFN during that step, the capacity factor must be increased. This increases the computation and communication costs, so an auxiliary loss penalises unbalanced routing. Image from the original Switch Transformer paper.

The capacity factor can be increased, but this will result in some experts having unused capacity, increasing computation and communication costs.

As a compromise, the authors add an auxiliary loss to the overall loss function that penalises unbalanced routing of tokens by the router. They found empirically that a capacity factor of 1.25 gave the best performance.
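The auxiliary loss in the paper takes the form α · N · Σᵢ fᵢ · Pᵢ, where N is the number of experts, fᵢ is the fraction of tokens dispatched to expert i, Pᵢ is the fraction of router probability allocated to expert i, and α = 10⁻² by default. A minimal NumPy sketch, assuming the router probabilities for one batch are already available (the name `load_balancing_loss` is hypothetical):

```python
import numpy as np


def load_balancing_loss(router_probs, alpha=1e-2):
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * P_i.

    router_probs: (num_tokens, num_experts) softmax outputs of the router.
    """
    num_tokens, num_experts = router_probs.shape
    # f_i: fraction of tokens whose top-1 choice is expert i.
    f = np.bincount(router_probs.argmax(axis=-1), minlength=num_experts) / num_tokens
    # P_i: average router probability assigned to expert i.
    P = router_probs.mean(axis=0)
    return alpha * num_experts * float(np.sum(f * P))


# Balanced: each of 4 tokens prefers a different expert.
balanced = np.full((4, 4), 0.1)
np.fill_diagonal(balanced, 0.7)
# Skewed: every token prefers expert 0.
skewed = np.full((4, 4), 0.1)
skewed[:, 0] = 0.7

print(round(load_balancing_loss(balanced), 4))  # 0.01
print(round(load_balancing_loss(skewed), 4))    # 0.028
```

Uniform routing gives a loss equal to α, and the loss grows as tokens concentrate on fewer experts; multiplying by N keeps that baseline the same regardless of how many experts there are.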

Results

To measure the performance of the Switch Transformer, the authors trained several models on the Colossal Clean Crawled Corpus (C4), used the T5 language model as a baseline, and compared their negative log perplexity.
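Negative log perplexity is the negative of the average per-token cross-entropy, so higher values (closer to zero) are better. A minimal sketch, assuming natural logarithms and that the probability the model assigned to each correct token is already known (the helper name is hypothetical):

```python
import math


def neg_log_perplexity(token_probs):
    """Negative log perplexity from the probabilities of the correct tokens."""
    # Log perplexity is the mean negative log-likelihood per token (in nats).
    log_perplexity = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return -log_perplexity


confident = [0.9, 0.8, 0.7, 0.9]
uncertain = [0.3, 0.2, 0.4, 0.25]
print(neg_log_perplexity(confident))  # closer to zero: better
print(neg_log_perplexity(uncertain))  # more negative: worse
# Under the natural-log assumption, -1.495 corresponds to a perplexity of e**1.495.
print(math.exp(1.495))
```

Under that assumption, a negative log perplexity of -1.495 corresponds to a perplexity of roughly 4.46.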

Given the motivation of increasing the number of parameters (and hence performance) without increasing the computation required, the models they trained were FLOP-matched to T5 equivalents, i.e. the amount of computation per token was kept equal.

They found the following:

  • After 100,000 steps, the Switch Transformer model had a greater negative log perplexity than the FLOP-matched T5 equivalent.
  • The Switch Transformer reached a quality threshold (neg. log perplexity = -1.495) faster than the T5 equivalent. T5-Base never reached this threshold within its 100,000 training steps!
  • The Switch Transformer was able to process more examples per second.

Scaling Properties

As well as analysing overall performance, the authors also looked at scaling properties during pre-training. That is, given an effectively unlimited computational budget, what is the best way to scale the model?

Three different dimensions were analysed to understand the model's scaling properties: scaling with training steps, scaling with training time, and a comparison against a larger dense model.

Step scaling: For a fixed number of training steps, the Switch Transformer outperforms the FLOP-matched T5 model. Increasing the number of experts improves performance further, without increasing the number of FLOPs. The Switch Transformer model with 64 experts reaches in 60k steps the same performance that T5-Base reaches in 450k steps, a 7.5x speedup.

Step scaling of T5-base compared to FLOP-matched equivalent Switch Transformer models, with varying numbers of experts. Image from the original Switch Transformer paper.

Time scaling: Intuitively, time scaling should be equivalent to step scaling. However, the additional communication costs across devices and the computation of the router function mean that this needs to be analysed explicitly. The results showed that, for a fixed training time, the Switch Transformer outperformed the FLOP-matched T5-Base model, with a 7x speedup observed between T5-Base and the 64-expert Switch Transformer.

Larger dense model: Finally, the authors considered the scenario where T5-Large is used instead of T5-Base in an attempt to match the performance of the Switch Transformer. However, they showed that this results in a 3.5x increase in FLOPs per token, and a FLOP-matched Switch Transformer model would still outperform it.

These scaling results suggest that, for a given increase in computation, a larger Switch Transformer model will outperform a larger dense model. Considering the motivation to reduce the computational footprint of language models, this is an extremely significant result.

Downstream Results

To measure downstream performance, T5-Base (223M parameters and 124B FLOPs) and T5-Large (739M parameters and 425B FLOPs) were compared with FLOP-matched Switch-Base (7.4B parameters) and Switch-Large (26.3B parameters) models.

These models were compared on 11 different language tasks, spanning classification, question answering and summarisation, amongst others. The Switch Transformer models outperformed the FLOP-matched T5 models on all tasks except ARC.

Multilingual Learning

As a final test of performance, the authors measured the quality of the model whilst pre-training on 101 different languages. The Switch-Base model achieved a greater negative log perplexity than T5-Base in all languages, and an average training speedup of 5x was observed.

A Trillion Parameter Model

Towards the end of the paper, the authors address the design and training of two large Switch Transformer models, Switch-XXL and Switch-C, with 395 billion and 1.571 trillion parameters respectively.

The Switch-XXL model was FLOP-matched to the T5-XXL model. However, due to the size of the Switch-C model, its architecture (layer size, depth, number of attention heads, etc.), and therefore its FLOPs, were scaled down.

Photo by Kaffeebart on Unsplash

This might be part of the reason why the negative log perplexity of the Switch-XXL model was greater than that of Switch-C, suggesting that there are diminishing returns to increasing the number of experts, especially at the expense of other model dimensions. However, both models outperformed T5-XXL after pre-training.

These models were also measured on downstream tasks. However, none of these results were state-of-the-art.

This is perhaps the reason why the Switch Transformer did not get the same press exposure as GPT-3. Although a trillion-parameter model is impressive, it also needs to be backed by impressive performance.

However, this does not take away from the significance of the research. To date, increasing the size of dense language models has been the path of least resistance to producing state-of-the-art models; but it is clear that this is not economically or environmentally sustainable.

Google has shown that it is possible to create innovative model architectures that improve model performance without increasing computational costs, and this is something the data science and AI community will likely see more of in the near future.
