Why Is “Attention” in Transformers Special?
👷‍♂️ Software Architecture Series — Part 35.
We have seen a dramatic rise in the capabilities of language models from 2017 onwards, thanks in large part to the paper “Attention Is All You Need”. But what was so special about the proposed “attention mechanism” that it acted as a catalyst for the growth of language models ever since?
To understand the significance of the attention mechanism, which sits at the core of every LLM nowadays, let us take a use case. Say we want to develop a translation model that can translate text from one language to another. Language translation is not a straightforward exercise: one has to understand the context and frame the words accordingly, and a plain word-by-word translation can look rather awkward. For example:
French: Le battage médiatique quotidien autour des modèles d’IA me semble assez écrasant.
Word by Word Translation:
Le battage — the hype
Médiatique — media
Quotidien — daily
Autour — around
Des — some
Modèles — models
d’IA — AI
me — me
semble — seems
assez — enough
écrasant — crushing
Actual English Interpretation: The daily media hype around AI models seems pretty overwhelming to me.
To accomplish the task above, it is common to use a deep neural network with an encoder-decoder architecture: the encoder reads and processes the source text, and the decoder generates the translated text.
RNNs: A Brief Overview
Earlier, recurrent neural networks (RNNs) were the preferred choice for implementing language translation with an encoder-decoder architecture. However, RNN models had a shortcoming that inspired the design of the attention mechanism in the now famous Transformer architecture. Hence, it is worth taking a sneak peek into that shortcoming first, to build up some ‘context’.
RNN-based encoder-decoder models process a text sequence one step (word) at a time, maintaining a hidden state that gets updated at every step. The hidden state captures information from all previous steps, so the final hidden state serves as a compressed representation of the entire input sequence, often called the context vector, and it becomes the initial hidden state of the decoder. The decoder then generates the output sequence autoregressively, conditioning each prediction on the previous hidden state. In simpler terms, each predicted word is fed back into the decoder to generate the next word.
And herein lies the shortcoming we were talking about. The RNN model cannot directly access earlier hidden states from the encoder during the decoding phase; it has to rely solely on the current hidden state, which is supposed to encapsulate all relevant information. Since gradients diminish during backpropagation through time, there is a strong possibility of a loss of context, especially in complex sentences where the model has to learn long-range dependencies. As a result, important words appearing early in the sequence may be forgotten by the time the decoder needs them. Moreover, since processing is strictly sequential, taking O(n) steps for a sequence of length n, the advantage of parallel computation is off the table.
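For intuition, here is a minimal NumPy sketch of that bottleneck. It is not a trainable model; the weight matrices W_xh and W_hh and the tanh update are illustrative stand-ins for a learned RNN cell:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, seq_len = 8, 5

# Toy "embeddings" for a 5-word source sentence (random stand-ins).
x = rng.normal(size=(seq_len, d_model))

# Illustrative encoder parameters; a real RNN would learn these.
W_xh = rng.normal(scale=0.1, size=(d_model, d_model))
W_hh = rng.normal(scale=0.1, size=(d_model, d_model))

h = np.zeros(d_model)            # initial hidden state
for t in range(seq_len):         # strictly sequential: step t needs step t-1
    h = np.tanh(x[t] @ W_xh + h @ W_hh)

context = h                      # one vector must summarize the whole sentence
print(context.shape)             # (8,) -- the decoder's only view of the input
```

However long the input sentence is, everything the decoder gets to see is that single fixed-size vector.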
So how did the Transformer architecture overcome this limitation?
The real pain point was the single context vector, the final hidden state of the encoder, which acted as the sole entry point to the decoder. Transformers, through their attention mechanism, allow the decoder to access all encoder hidden states via a learned alignment. At each step, an attention weight is computed for each encoder hidden state, indicating how relevant that state is to the token currently being produced. When the same idea is applied within a single sequence, with each token attending to the other tokens of that sequence, it is called “self-attention”.
The self-attention mechanism allows each position in a text sequence to weigh the relevance of every other position. Let us consider the sentence:
“The cat chased the mouse because it was hungry.”
While processing this sentence, the model should be able to understand that the word ‘it’ refers to the cat and not the mouse. In a transformer-based architecture, it is the attention mechanism that makes that determination, quite unlike the earlier RNN-based architecture, which was unidirectional in flow. The attention mechanism lets the decoder retrieve encoder states from any position. But how does the model learn which states to pay attention to?
The attention mechanism involves two main steps:
1. Relevance scoring: Scoring the relevance of previous tokens with respect to the current token being processed. Higher scores indicate stronger semantic or syntactic relationships between tokens. For example, in the sentence “The cat chased the mouse because it was hungry”, processing “chased” yields high scores for “cat” (subject) and “mouse” (object), reflecting verb-argument dependencies.
One of the popular ways of computing the scores is by using the scaled dot-product attention mechanism. It computes contextual representations by dynamically weighting input tokens through three core components — queries (Q), keys (K), and values (V) — and a critical scaling operation.
A Query (Q) represents the token seeking contextual information, a Key (K) represents each token in the sequence that the Query can attend to, and a Value (V) contains the actual content of the tokens, which will be weighted and combined based on the attention scores.
Once the attention scores are calculated using dot-product similarity, they are scaled (divided by the square root of the key dimension) so that the variance of the dot products stays controlled and does not blow up into extremely large values. The scaled scores are then passed through a softmax function to convert them into probabilities. Softmax ensures that the attention weights for a token sum to 1, making it easy for the model to distribute its focus across multiple tokens.
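A minimal NumPy sketch of scaled dot-product attention is shown below. It assumes the tokens have already been projected into Q, K and V; the random matrices are placeholders for those learned projections:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # step 1: relevance scoring, scaled
    weights = softmax(scores)                  # each row sums to 1
    return weights @ V, weights                # step 2: weighted sum of Values

# Five tokens ("The cat chased the mouse") with toy 3-dimensional vectors.
rng = np.random.default_rng(42)
Q = K = V = rng.normal(size=(5, 3))            # self-attention: all from the same sequence
output, weights = scaled_dot_product_attention(Q, K, V)
print(weights[2].round(2))                     # how "chased" (token 3) attends to each word
```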
Let us simplify the above sentence and see this process in action. Input sentence: “The cat chased the mouse.”
Let us assume the raw attention scores between “chased” (the Query) and each word in the sentence (the Keys) are as follows: [0.58, 0.94, 1.85, 0.89, 1.45]. Now we scale these scores by the square root of the embedding dimensionality.
For a 3-dimensional embedding space, dividing by the square root of 3 gives the scaled scores [0.33, 0.54, 1.07, 0.51, 0.84]. To convert these scores into probabilities, we apply the softmax function. The resulting probabilities determine how much weight each Value vector (V) carries in the output representation of “chased”; note that these are attention weights over the input words, not predictions of the next word.
Once we apply the softmax function, the weights come out as follows: [0.14, 0.17, 0.29, 0.17, 0.23], which sum to 1. This means the word “chased” attends about 29% to itself, 23% to “mouse”, 17% to “cat”, and so on.
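The arithmetic above can be checked in a few lines; the raw scores are the illustrative ones assumed earlier:

```python
import numpy as np

scores = np.array([0.58, 0.94, 1.85, 0.89, 1.45])   # assumed raw scores for "chased"
scaled = scores / np.sqrt(3)                         # scale for a 3-dimensional space
weights = np.exp(scaled) / np.exp(scaled).sum()      # softmax
print(scaled.round(2))    # [0.33 0.54 1.07 0.51 0.84]
print(weights.round(2))   # [0.14 0.17 0.29 0.17 0.23]  -- sums to 1
```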
This process can be repeated independently and in parallel for every word in the sentence to form its new contextual representation.
2. Contextual Integration: The attention weights calculated above are then used to compute a weighted sum of the Value (V) vectors, which forms the output representation for the current token, incorporating relevant information from all input tokens.
The final resulting vector corresponding to the word “chased” will be a blend of information from the other words in the sentence, weighted by their importance. Instead of a static meaning, relationships are now being captured in the output vector.
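Continuing the toy example, with invented 3-dimensional Value vectors for the five words (purely for illustration), the new representation of “chased” is just this weighted blend:

```python
import numpy as np

weights = np.array([0.14, 0.17, 0.29, 0.17, 0.23])   # attention weights of "chased"
V = np.array([                                        # hypothetical Value vectors
    [0.1, 0.0, 0.2],   # The
    [0.9, 0.3, 0.1],   # cat
    [0.2, 0.8, 0.5],   # chased
    [0.1, 0.0, 0.2],   # the
    [0.4, 0.6, 0.9],   # mouse
])
new_chased = weights @ V     # weighted sum of the five Value vectors
print(new_chased.round(2))   # a blend dominated by "chased" and "mouse"
```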
No wonder the attention calculation, which grows quadratically with sequence length, is one of the most computationally expensive parts of an LLM, and numerous refinements keep being proposed to make it more efficient. The approach discussed so far computes only one set of attention weights, meaning it captures just one type of relationship in a sentence. Hence, to add diversity to what the model can capture, multi-head attention is used.
In the multi-head mechanism, several attention computations run in parallel, each free to capture a different kind of relationship, such as subject-verb links, object-action links, or long-distance dependencies. This makes the final output vector richer and more informative.
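Below is a minimal NumPy sketch of the idea; the head count, dimensions, and random projection matrices are illustrative assumptions rather than any real model’s weights:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention, as in the earlier sketch.
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def multi_head_attention(X, W_q, W_k, W_v, W_o):
    # One attention computation per head, each with its own projections;
    # the heads are independent and can run in parallel.
    heads = [attention(X @ Wq, X @ Wk, X @ Wv)
             for Wq, Wk, Wv in zip(W_q, W_k, W_v)]
    # Concatenate the heads and mix them with a final output projection.
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(1)
n_heads, d_model, seq_len = 2, 8, 5
d_head = d_model // n_heads
X = rng.normal(size=(seq_len, d_model))   # toy token representations
W_q, W_k, W_v = (rng.normal(size=(n_heads, d_model, d_head)) for _ in range(3))
W_o = rng.normal(size=(n_heads * d_head, d_model))
print(multi_head_attention(X, W_q, W_k, W_v, W_o).shape)   # (5, 8)
```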
The original Transformer model had only 6 encoder and 6 decoder layers, each containing a multi-head self-attention mechanism with 8 attention heads. Recent models stack far more, often approaching or exceeding a hundred layers, each with dozens of attention heads.