Linear Regression to GPT in Seven Steps

How the humble prediction method shows us the way to Generative AI

Devesh Rajadhyax
Towards Data Science

--

A great deal has been written about Generative AI: essays dedicated to its applications, its ethical and moral issues, and the risks it poses to human society. If you want to understand the technology itself, there is a range of material available, from the original research papers to introductory articles and videos. Depending on your current level and interest, you can find the right resources for study.

This article is written for a specific class of readers. These readers have studied machine learning, though not as a major subject. They are aware that prediction and classification are the two main use cases of ML that cover most of its applications. They have also studied common machine learning algorithms for prediction and classification, such as linear regression, logistic regression, support vector machines, decision trees, and a bit of neural networks. They might have coded a few small projects in Python using libraries such as scikit-learn, and even used some pre-trained TensorFlow models such as ResNet. I think a lot of students and professionals will be able to relate to this description.

For these readers it is natural to wonder: is generative AI a new type of ML use case? It certainly seems different from both prediction and classification. There is enough jargon going around to discourage anyone from venturing into generative AI. Terms such as transformers, multi-head attention, large language models, foundation models, sequence-to-sequence, and prompt engineering can easily persuade you that this is a very different world from the cozy prediction-classification one we used to know.

The message of this article is that generative AI is just a special case of prediction. If you fit the description of the ML enthusiast I gave earlier, then you can understand the basic working of generative AI in seven simple steps. I start with linear regression (LinReg), the ML technique that everyone knows. In this article I treat a particular branch of generative AI called Large Language Models (LLMs), largely because the wildly popular ChatGPT belongs to this branch.

Image by Rajashree Rajadhyax

Step 1: Prediction by Linear Regression

LinReg identifies the best line that represents the given data points. Once this line is found, it is used to predict the output for a new input.

Image by author

We can write the LinReg model as a mathematical function. Written in an easy-to-understand way, it looks like this:

new output = Line Function (new input)

We can also draw a schematic for it:

This is prediction at the most basic level. A LinReg model ‘learns’ the best line and uses it for prediction.
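As a quick illustration, here is a minimal sketch of this step using scikit-learn; the data points below are invented for the example:

# A minimal sketch of Step 1: fitting and using a line with scikit-learn.
# The data is made up for illustration.
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy single-input data that roughly follows the line: output = 2 * input + 1
X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])   # inputs
y = np.array([3.1, 4.9, 7.2, 9.0, 11.1])            # outputs

model = LinearRegression()
model.fit(X, y)                        # 'learns' the best line

new_input = np.array([[6.0]])
new_output = model.predict(new_input)  # new output = Line Function (new input)
print(new_output)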

Step 2: Prediction by Neural Networks

You can use LinReg only if you know that the data will fit a line. This is usually easy to check for single-input-single-output problems: we can simply draw a plot and inspect it visually. But most real-life problems have multiple inputs, and we cannot visualize such a plot.

In addition, real-world data does not always follow a linear path. Many times, the best-fitting shape is non-linear. See the plot below:

Image by author

Learning such a function from multi-dimensional data is not possible with simple methods like LinReg. This is where neural networks (NN) come in. NNs do not require us to decide which function they should learn. They find it themselves and then go on to learn the function, however complex it may be. Once an NN has learned the complex, multi-input function, it uses this function for prediction.

We can again write a similar equation, but with a change. Our inputs are now many, so we have to represent them with a vector. In fact, there can be many outputs too, and we will use a vector for them as well.

output vector = NN Function (input vector)

We will draw the schematic of this new, more powerful prediction:

Image by author
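For readers who like to see things in code, here is a minimal sketch of such a prediction using a small neural network from scikit-learn; the non-linear function and the data are invented for illustration:

# A minimal sketch of Step 2: a small neural network learning a non-linear,
# multi-input function. The target function and data are toy examples.
import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(500, 3))      # input vectors with 3 features each
y = np.sin(X[:, 0]) + X[:, 1] * X[:, 2]    # a non-linear function the NN must discover

model = MLPRegressor(hidden_layer_sizes=(32, 32), max_iter=2000, random_state=0)
model.fit(X, y)                            # the NN finds and learns the function itself

new_input = np.array([[0.5, -0.2, 0.7]])
print(model.predict(new_input))            # output = NN Function (input vector)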

Step 3: Prediction on a word

Now consider that we have a problem in which the input to the NN is a word from some language. Neural networks can only accept numbers and vectors. To suit this, words are converted into vectors. You can imagine them as residents of a many-dimensional space, where related words are close to each other. For example, the vector for ‘Java’ will be close to the vectors for other programming languages; but it will also be close to the vectors for places in the Far East, such as Sumatra.

A (very imaginary) word embedding, image by author

Such a set of vectors corresponding to the words in a language is called an ‘Embedding’. There are many methods of creating embeddings, Word2Vec and GloVe being two popular examples. Typical sizes of such embeddings are 256, 512 or 1024 dimensions.
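To make the idea concrete, here is a toy sketch with hand-written 4-dimensional vectors; real embeddings are learned from data and have hundreds of dimensions:

# A toy illustration of a word embedding: made-up 4-dimensional vectors.
# Real embeddings are learned, not hand-written, and are much larger.
import numpy as np

embedding = {
    "java":    np.array([0.9, 0.8, 0.1, 0.7]),
    "python":  np.array([0.8, 0.9, 0.0, 0.1]),
    "sumatra": np.array([0.7, 0.1, 0.9, 0.8]),
    "banana":  np.array([0.0, 0.1, 0.8, 0.0]),
}

def cosine_similarity(a, b):
    # Higher value = the two words sit closer together in the embedding space.
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# 'java' is close to both 'python' (programming) and 'sumatra' (geography)
print(cosine_similarity(embedding["java"], embedding["python"]))
print(cosine_similarity(embedding["java"], embedding["sumatra"]))
print(cosine_similarity(embedding["java"], embedding["banana"]))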

Once we have vectors for words, we can use the NN for prediction on them. But what can we achieve by prediction on words? We can do a lot of things. We can translate a word to another language, get a synonym for the word or find its past tense. The equation and schematic for this prediction will look very similar to those in Step 2.

output word embedding = NN Function (input word embedding)
Image by author
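Here is a toy sketch of this kind of prediction. The 'trained network' below is a hand-written stand-in and the vocabulary is made up; only the embedding-in, embedding-out shape of the computation matters:

# A sketch of Step 3: a function maps an input word embedding to an output word
# embedding (here, past tense), and we read off the nearest word in the vocabulary.
import numpy as np

vocab = {                       # tiny made-up 3-dimensional embeddings
    "go":   np.array([0.9, 0.1, 0.0]),
    "went": np.array([0.8, 0.2, 0.1]),
    "eat":  np.array([0.1, 0.9, 0.0]),
    "ate":  np.array([0.2, 0.8, 0.1]),
}

def nn_past_tense(input_embedding):
    # Stand-in for a trained NN that maps a verb embedding to its past-tense embedding.
    # Here we just nudge the vector the way a trained network might.
    return input_embedding + np.array([-0.1, 0.1, 0.1])

def nearest_word(vec):
    # Decode a predicted embedding back to the closest word in the vocabulary.
    sims = {w: float(vec @ e / (np.linalg.norm(vec) * np.linalg.norm(e)))
            for w, e in vocab.items()}
    return max(sims, key=sims.get)

print(nearest_word(nn_past_tense(vocab["go"])))   # 'went'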

Step 4: Prediction for Naive Translation

In a translation problem, the input is a sentence in one language and the output is a sentence in another language. How can we implement translation using what we already know about prediction on a word? Here we take a naive approach to translation: we convert each word of the input sentence to its equivalent in the other language. Of course, real translation does not work like this; but for this step, pretend that it does.

This time we will draw the schematic first:

Image by author

The equation for the first word will be:

NN Translation Function (Embedding for word no. 1 in input sentence)
= Embedding for word no. 1 in output sentence

We can similarly write the equations for the other words.

The neural network used here has learned the translation function by looking at many examples of word pairs, one from each language. We are using one such NN for each word.
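A minimal sketch of this naive scheme, with a hand-made word dictionary standing in for the per-word NN, looks like this:

# A sketch of the naive translation of Step 4: one lookup per word, standing in
# for 'one NN per word'. The dictionary is a hand-made toy, not a real model.
word_translation = {            # English -> Hindi (transliterated), illustrative only
    "Ashok": "Ashok",
    "sent": "bheja",
    "a": "ek",
    "letter": "khat",
    "to": "ko",
    "Sagar": "Sagar",
}

def naive_translate(sentence):
    # Translate each word independently, ignoring order, grammar and context.
    return " ".join(word_translation[word] for word in sentence.split())

print(naive_translate("Ashok sent a letter to Sagar"))
# 'Ashok bheja ek khat ko Sagar' -- the word order and grammar are wrong,
# which is exactly what Steps 5 and 6 address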

We thus have a translation system using prediction. I have already admitted that this is a naive approach to translation. What are the additions that will make it work in the real world? We will see that in the next two steps.

Step 5: Prediction with Context

The first problem with the naive approach is that the translation of one word depends on other words in the sentence. As an example, consider the following English to Hindi translation:

Input (English): ‘Ashok sent a letter to Sagar’

Output (Hindi): ‘Ashok ne Sagar ko khat bheja’.

The word ‘sent’ is translated as ‘bheja’ in the output. However, if the input sentence is:

Input (English): ‘Ashok sent sweets to Sagar’

Output (Hindi): ‘Ashok ne Sagar ko mithai bheji’.

Then the same word is translated as ‘bheji’.

Thus it is necessary to add the context from the other words in the sentence while predicting the output. We will draw the schematic only for one word:

Image by author

There are many methods to generate the context. The most powerful and state-of-the-art one is called ‘attention’. The neural networks that use attention for context generation are called ‘transformers’. BERT and GPT are examples of transformers.
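For the curious, here is a minimal sketch of the attention idea in NumPy (scaled dot-product self-attention with toy dimensions); real transformers add learned projections, multiple heads and much more:

# A minimal sketch of attention: each word's context is a weighted mix of the
# other words' vectors, with weights computed from vector similarity.
import numpy as np

def attention(queries, keys, values):
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)                 # how relevant each word is to each other word
    scores = scores - scores.max(axis=-1, keepdims=True)   # for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # softmax
    return weights @ values                                # one context vector per word

# 4 words in a sentence, each as an 8-dimensional vector (random stand-ins for embeddings)
rng = np.random.default_rng(0)
word_vectors = rng.normal(size=(4, 8))
context = attention(word_vectors, word_vectors, word_vectors)  # self-attention
print(context.shape)   # (4, 8): a context vector for every word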

We now have a kind of prediction that uses context. We can write the equation as:

NN Translation Function (Embedding for word no. 1 in input sentence
+ context from other words in input sentence)
= Embedding for word no. 1 in output sentence

Step 6: Prediction of Next Word

We will now handle the second problem in the naive approach to translation. Translation is not a one-to-one mapping of words. See the example from the previous step:

Input (English): ‘Ashok sent a letter to Sagar’

Output (Hindi): ‘Ashok ne Sagar ko khat bheja’.

You will notice that the order of the words is different, that the word ‘a’ in the input has no equivalent in the output, and that the word ‘ne’ in the output has no equivalent in the input. Our one-NN-per-word approach will not work in this case. In fact, it will not work in most cases.

Fortunately there is a better method available. After giving the input sentence to an NN, we ask it to predict just one word, the word that will be the first word of the output sentence. We can represent this as:

Image by author

In our letter sending example, we can write this as:

NN Translation Function (Embeddings for 'Ashok sent a letter to Sagar' 
+ context from input sentence)
= Embedding for 'Ashok'

To get the second word in output, we change our input to:

Input = Embeddings for input sentence + Embedding for first word in output

We have to also include this new input in the context:

Context = Context from input sentence + context from first word in output

The NN will then predict the next (second) word in the sentence:

This can be written as:

NN Translation Function 
(Embeddings for 'Ashok sent a letter to Sagar' + Embedding for 'Ashok',
Context for input sentence + context for 'Ashok')
= Embedding for 'ne'

We can continue this process till the NN predicts the embedding for ‘.’, in other words, till it signals that the output has ended.
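In code, the loop looks roughly like this; the prediction function below is a placeholder that replays a canned answer, so only the loop structure is real:

# A sketch of the 'predict the next word' loop of Step 6. 'predict_next_word'
# is a placeholder for a trained translation NN.
canned_output = ["Ashok", "ne", "Sagar", "ko", "khat", "bheja", "."]

def predict_next_word(input_sentence, output_so_far):
    # Stand-in for: NN Translation Function (embeddings + context) -> next word
    return canned_output[len(output_so_far)]

def translate(input_sentence):
    output = []
    while True:
        next_word = predict_next_word(input_sentence, output)
        if next_word == ".":       # the NN signals that the output has ended
            break
        output.append(next_word)   # the prediction is fed back in for the next step
    return " ".join(output)

print(translate("Ashok sent a letter to Sagar"))   # 'Ashok ne Sagar ko khat bheja'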

We have thus reduced the translation problem to the ‘predict the next word’ problem. In the next step we will see how this approach to translation leads us to a more generic Generative AI capability.

Step 7: Prediction for Generative AI

The ‘prediction of next word’ method is not limited to translation. The NN can be trained to predict the next word in such a way that the output is the answer to a question, or an essay on the topic you specify. Imagine that the input to such a generative NN is the sentence:

‘Write an article on the effect of global warming on the Himalayan glaciers’.

The input to the generative model is called a ‘prompt’. The generative NN predicts the first word of the article, and then goes on predicting further words till it has generated a complete essay. This is what Large Language Models such as ChatGPT do. As you can imagine, the internals of such models are much more complex than what I have described here. But they contain the same essential components that we have seen: embeddings, attention and a next-word-prediction NN.

Image by author
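If you want to try this yourself, a few lines with the Hugging Face transformers library and the small pre-trained GPT-2 model give the flavour; any text-generation model you have access to will work just as well:

# A sketch of next-word generation with a small pre-trained model.
# Assumes the Hugging Face 'transformers' library is installed.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
prompt = "The effect of global warming on the Himalayan glaciers is"
result = generator(prompt, max_new_tokens=40)   # the model keeps predicting the next word
print(result[0]["generated_text"])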

Apart from translation and content generation, LLMs can also answer questions, plan travel and do many other wonderful things. The basic method used in all these tasks remains prediction of the next word.

Summary

We started from the basic prediction technique using LinReg. We made the prediction problem more complex by adding vectors and then word embeddings. We learned how to apply prediction to language by solving the naive translation problem. While enhancing the naive method to handle real translation, we became familiar with the essential elements of LLMs: context and next-word prediction. We realized that Large Language Models are all about prediction on text sequences. Along the way we picked up important terms such as prompt, embeddings, attention and transformers.

The sequences do not really have to be text. They can be any sequences, such as images or sounds. The message of this article, which thus covers the whole gamut of generative AI, is:

Generative AI is prediction on sequences using neural networks.

--
