Build a Transformer in JAX from scratch

How to write and train your own models

Sergios Karagiannakos
Towards Data Science
7 min read · Mar 19, 2021

Image by Author

In this tutorial, we will explore how to develop a neural network (NN) with JAX. And what better model to choose than the Transformer? As JAX grows in popularity, more and more development teams are starting to experiment with it and incorporate it into their projects. Despite the fact that it lacks the maturity of TensorFlow…
