Transformer: Self-Attention
Recently, one of my friends used an LSTM with PPO to train a robot in simulation on a collection task. Having a basic understanding of RNNs and LSTMs (a gated variant of the RNN designed to retain information over longer sequences), I started wondering whether Transformer architectures could be integrated into reinforcement learning tasks.
In many cases, these tasks rely on storing sequences of observations to form trajectories, which are fundamental factors affecting the performance of the policy and the final model. Before diving into this idea, I would like to discuss the Transformer itself and break down its underlying mechanics.
Introduction to the Transformer
The Transformer [arXiv:1706.03762], introduced in the paper "Attention Is All You Need," revolutionized how natural language processing (NLP) tasks such as language translation are handled. Its most critical innovation is the self-attention mechanism, which allows the model to focus on the most relevant parts of the input sequence. Before it, models struggled with long sequences, such as those found in novels or long-duration games, because most were based on RNNs, which process data sequentially, one step at a time. This approach not only limited computational parallelism but also made it difficult to capture long-range dependencies within the sequence.
The Transformer comes to the rescue by addressing both problems: self-attention processes all positions of a sequence in parallel, and it lets every position access every other position directly, which tackles the "memorization" of long-range context. Self-attention had been studied before this paper was published, but it wasn't widely adopted until the Transformer architecture popularized it and demonstrated its effectiveness, sparking a major shift in the field. Let's look at self-attention in the next section.
Self-Attention
Imagine you are reading a sentence: "Erno Rubik invented the cube, he is intelligent and amazing."
To understand what "he" refers to, you look at the other words and identify "Erno Rubik" as the most relevant one. This is, intuitively, how self-attention works. The calculation proceeds in five steps:
- Get inputs (Query, Key, Value)
- Calculate attention scores
- Scale scores
- Apply softmax function
- Calculate outputs
Step 1: Get Inputs (Query, Key, Value)
Query (Q): represents the current element asking for information about its context. It's like asking "which other elements are relevant to me?"
Key (K): represents the label of an element that other elements can match their queries against. It's like saying "this is who I am."
Value (V): represents the actual content of an element. This is the information passed on if this element is deemed relevant.
All three are derived from the same input sequence: the embedding of each element is multiplied by three learned weight matrices.
$Q = XW_q$
$K = XW_k$
$V = XW_v$
where $X$ is the matrix of input embeddings (each row is one element's embedding) and $W_q$, $W_k$, $W_v$ are the learned weight matrices.
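A minimal NumPy sketch of Step 1, assuming a toy sequence of 3 elements with embedding size 4. The sizes (`d_model`, `d_k`, `d_v`) are arbitrary, and the weight matrices are random stand-ins for what a trained model would learn:

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model = 3, 4   # toy sizes, chosen for illustration only
d_k, d_v = 4, 4

# X: one row per element's embedding
X = rng.normal(size=(seq_len, d_model))

# Random stand-ins for the learned projection matrices W_q, W_k, W_v
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_v))

# Every element gets its own query, key, and value row
Q = X @ W_q
K = X @ W_k
V = X @ W_v

print(Q.shape, K.shape, V.shape)  # (3, 4) (3, 4) (3, 4)
```

Row $i$ of $Q$ is just row $i$ of $X$ times $W_q$, so at this stage each element's query, key, and value depend only on that element; mixing across elements happens later, through the attention weights.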
Step 2: Calculate Scores
This step determines how much attention each query should pay to each key. The most common choice is the dot product, and frameworks such as PyTorch provide optimized implementations of the whole attention computation (for example, torch.nn.functional.scaled_dot_product_attention in PyTorch 2.x).
For a single query $q_i$ and a key $k_j$, the score is $score_{ij} = q_i \cdot k_j$
Written in matrix form: Scores = $QK^T$
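To check that the matrix form matches the per-pair definition, the sketch below (using arbitrary random $Q$ and $K$) fills in each score with an explicit loop and compares the result with a single matrix product:

```python
import numpy as np

rng = np.random.default_rng(1)
Q = rng.normal(size=(3, 4))  # 3 queries, d_k = 4
K = rng.normal(size=(3, 4))  # 3 keys,    d_k = 4

# Per-pair definition: score_ij = q_i . k_j
loop_scores = np.zeros((Q.shape[0], K.shape[0]))
for i in range(Q.shape[0]):
    for j in range(K.shape[0]):
        loop_scores[i, j] = np.dot(Q[i], K[j])

# Matrix form: all scores at once
matrix_scores = Q @ K.T

print(np.allclose(loop_scores, matrix_scores))  # True
```

The matrix form is what makes attention parallel: all pairwise scores come out of one matrix multiplication instead of a step-by-step loop over the sequence.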
Step 3: Scale Scores
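As $d_k$ grows, dot products can become large in magnitude, which pushes the softmax into regions where it is nearly one-hot and its gradients are extremely small (a vanishing-gradient problem). To counteract this, the paper divides the scores by $\sqrt{d_k}$: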
Scaled_Scores = $\dfrac{QK^T}{\sqrt{d_k}}$
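Dividing by $\sqrt{d_k}$ keeps the variance of the scores around 1 regardless of $d_k$. To see why, assume (as the paper does) that the components of a query $q$ and a key $k$ are independent random variables with mean 0 and variance 1. Then each product $q_m k_m$ has mean 0 and variance 1, so the variance of the dot product grows linearly with $d_k$:
$\operatorname{Var}(q \cdot k) = \operatorname{Var}\left(\sum_{m=1}^{d_k} q_m k_m\right) = \sum_{m=1}^{d_k} \operatorname{Var}(q_m k_m) = d_k$
Dividing by $\sqrt{d_k}$ divides the variance by $d_k$, so $\operatorname{Var}\left(\dfrac{q \cdot k}{\sqrt{d_k}}\right) = 1$.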
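A quick Monte Carlo check of this argument, assuming (as above) that query and key components are independent standard normal samples. The sample size and the $d_k$ values are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 50_000

results = {}
for d_k in (4, 16, 64):
    # Independent components with mean 0 and variance 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))

    dots = np.sum(q * k, axis=1)               # one dot product per sample
    raw_var = dots.var()                       # grows roughly like d_k
    scaled_var = (dots / np.sqrt(d_k)).var()   # stays close to 1
    results[d_k] = (raw_var, scaled_var)
    print(f"d_k={d_k:3d}  Var(q.k)={raw_var:7.2f}  "
          f"Var(q.k/sqrt(d_k))={scaled_var:.3f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays near 1 for every $d_k$, which is exactly what the derivation predicts.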
Step 4: Apply Softmax
Softmax is applied to each row of the scaled scores, so that the weights each query assigns to all keys are non-negative and sum to 1:
Attention Weights = softmax $(\dfrac{QK^T}{\sqrt{d_k}})$
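The sketch below implements this row-wise softmax in NumPy. Subtracting each row's maximum before exponentiating is a standard numerical-stability trick rather than part of the formula: it leaves the result unchanged, because softmax is invariant to adding the same constant to every entry of a row, but it keeps `np.exp` from overflowing on large scores. The function name `softmax_rows` is just an illustrative choice:

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax: each row becomes non-negative and sums to 1."""
    # Shift each row so its largest entry is 0 (prevents overflow in exp)
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

scaled_scores = np.array([[1.0, 0.0, 1.0],
                          [0.0, 1.0, 1.0],
                          [1.0, 1.0, 2.0]])
weights = softmax_rows(scaled_scores)
print(weights.sum(axis=-1))  # [1. 1. 1.]

# np.exp(1000.0) alone overflows to inf, but the shifted version is fine
print(softmax_rows(np.array([[1000.0, 1000.0]])))  # [[0.5 0.5]]
```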
Step 5: Calculate Output
Finally, the output for each query is calculated as a weighted sum of all the value vectors, using the attention weights.
Attention(Q,K,V) = softmax $(\dfrac{QK^T}{\sqrt{d_k}})V$
```python
import numpy as np

def scaled_dot_product_attention(query, key, value):
    """
    Scaled Dot-Product Attention

    Args:
        query: (seq_len, d_k)
        key:   (seq_len, d_k)
        value: (seq_len, d_v)

    Returns:
        outputs: (seq_len, d_v)
        weights: (seq_len, seq_len)
    """
    d_k = query.shape[-1]
    print(f"The dimension of query: {query.shape}")

    # Steps 2 and 3: dot-product scores, scaled by sqrt(d_k)
    scores = np.matmul(query, key.T) / np.sqrt(d_k)
    print(f"Scores: \n {scores}")

    # Step 4: row-wise softmax; subtracting the row max avoids overflow
    # in np.exp without changing the result
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    print(f"Weights: \n {weights}")

    # Step 5: weighted sum of the value vectors
    outputs = np.matmul(weights, value)
    print(f"Outputs: \n {outputs}")
    return outputs, weights
```
Calling the function on a (3, 4) query matrix (the input matrices are not shown here) should print something like this:
```
The dimension of query: (3, 4)
Scores:
 [[1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 2.]]
Weights:
 [[0.4223188  0.1553624  0.4223188 ]
 [0.1553624  0.4223188  0.4223188 ]
 [0.21194156 0.21194156 0.57611688]]
Outputs:
 [[0.8446376  0.5776812  0.8446376  0.5776812 ]
 [0.5776812  0.8446376  0.5776812  0.8446376 ]
 [0.78805844 0.78805844 0.78805844 0.78805844]]
```
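The printed scores are consistent with, for example, $Q = K = V$ equal to the matrix `x` below. This is an inference from the output, not necessarily the original input: with $d_k = 4$, the raw dot products of its rows are [[2, 0, 2], [0, 2, 2], [2, 2, 4]], which become the printed scores after dividing by $\sqrt{4} = 2$. The self-contained sketch below uses a print-free copy of the function (the name `attention` is hypothetical) to reproduce the weights and outputs:

```python
import numpy as np

def attention(query, key, value):
    """Print-free scaled dot-product attention; returns (outputs, weights)."""
    scores = query @ key.T / np.sqrt(query.shape[-1])
    exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    return weights @ value, weights

# Assumed input, inferred from the printed scores (not given in the post)
x = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 1.0]])

outputs, weights = attention(x, x, x)
print(np.round(weights, 4))
print(np.round(outputs, 4))
```

Each output row is a weighted average of the rows of $V$. For instance, the last query puts weight 0.212 on each of the first two rows and 0.576 on the third, so every column of its output is 0.212 + 0.576, or about 0.788.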
Conclusion
TL;DR: Self-attention calculates how relevant each input element is to every other element (including itself), and then creates a new representation for each element as a weighted combination of all elements, where the weights reflect these relevance scores. Its strength lies in breaking away from the sequential nature of classical RNNs: all positions are processed at once, and any two positions are connected directly. It is a completely new approach, almost like discovering a new mathematical kernel in a higher-dimensional space. Every major breakthrough seems to require stepping outside of existing frameworks and embracing entirely new ways of thinking. That's what makes research so exciting.
In the following, let's think about how to apply this to our RL task.