How Positional Embeddings work in Self-Attention (code in PyTorch)
If you are reading transformer papers, you may have noticed Positional Embeddings (PE). At first glance they seem reasonable. However, when you try to implement them, things become really confusing!
The reason to care is simple: if you want to implement transformer-related papers, it is very important to get a good grasp of positional embeddings.
It turns out that sinusoidal positional encodings are not enough for computer vision problems. Images are highly structured, and we want to incorporate a strong sense of position (order) inside the multi-head self-attention (MHSA) block.
To this end, I will introduce some theory as well as my re-implementation of positional embeddings.
The code contains einsum operations. Read my past article if you are not comfortable with them.
Positional encodings vs positional embeddings
In the vanilla transformer, positional encodings are added before the first MHSA block. Let's start by clarifying this: positional embeddings are not the same as the sinusoidal positional encodings. They are highly similar to word or patch embeddings, except that here we embed the position.
Each position of the sequence is mapped to a vector of size dim.
Unlike the encodings, which are fixed, positional embeddings are trainable.
Here is a rough illustration of how this works:
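```python
import torch
max_seq_tokens, dim, current_seq_tokens = 512, 64, 10  # illustrative sizes
# stand-in for the transformer (assumption): a single PyTorch encoder layer
transformer = torch.nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
# initialization: one trainable vector of size dim per position
pos_emb1D = torch.nn.Parameter(torch.randn(max_seq_tokens, dim))
# during forward pass
input_embedding = torch.randn(1, current_seq_tokens, dim)  # (batch, tokens, dim)
input_to_transformer_mhsa = input_embedding + pos_emb1D[:current_seq_tokens, :]
out = transformer(input_to_transformer_mhsa)
```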
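For contrast, here is a minimal sketch that keeps a fixed sinusoidal encoding and a trainable embedding side by side. It assumes the sine/cosine formula of the vanilla transformer for the fixed part and uses `nn.Embedding` (a lookup table indexed by position ids) for the trainable part; the names `sinusoidal_encoding` and `PositionTables` and all sizes are illustrative, not taken from any particular codebase.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_encoding(max_seq_tokens: int, dim: int) -> torch.Tensor:
    """Fixed sine/cosine encodings (vanilla transformer formula); dim must be even."""
    position = torch.arange(max_seq_tokens, dtype=torch.float32).unsqueeze(1)  # (T, 1)
    # frequencies 1 / 10000^(2i/dim) for i = 0 .. dim/2 - 1
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(max_seq_tokens, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe


class PositionTables(nn.Module):
    """Holds a fixed encoding (buffer) and a trainable embedding (parameter)."""

    def __init__(self, max_seq_tokens: int, dim: int):
        super().__init__()
        # Fixed: a buffer moves with .to(device) but is not returned by .parameters()
        self.register_buffer("pos_enc", sinusoidal_encoding(max_seq_tokens, dim))
        # Trainable: one learnable vector of size dim per position, like a word embedding
        self.pos_emb = nn.Embedding(max_seq_tokens, dim)

    def forward(self, x: torch.Tensor, learned: bool = True) -> torch.Tensor:
        # x: (batch, tokens, dim)
        n = x.shape[1]
        if learned:
            positions = torch.arange(n, device=x.device)  # position ids 0..n-1
            return x + self.pos_emb(positions)  # (n, dim) broadcasts over the batch
        return x + self.pos_enc[:n]


tables = PositionTables(max_seq_tokens=128, dim=64)
x = torch.randn(2, 10, 64)
print(tables(x).shape)  # torch.Size([2, 10, 64])
print([name for name, _ in tables.named_parameters()])  # ['pos_emb.weight']
```

Registering the fixed encoding with `register_buffer` keeps it out of `.parameters()`, so an optimizer never updates it. The `nn.Embedding` lookup with ids `0..n-1` returns the same rows as slicing `pos_emb1D[:n, :]` in the snippet above; it simply reads more like a word-embedding lookup.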
By now you are probably wondering what PE learn. Me too!
Here is a beautiful illustration of the positional embeddings learned by different NLP models, from Wang and Chen 2020 [1]:
Position-wise similarity of multiple position embeddings. Image from Wang and Chen 2020
In short, they visualized the position-wise similarity of different position embeddings. Brighter regions denote higher similarity. Note that larger models such as GPT-2 process more tokens (horizontal and vertical axes).
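A plot like this can be produced from any position-embedding table by comparing every pair of position vectors. The sketch below assumes cosine similarity as the comparison (the exact measure used by the authors may differ) and uses a random table as a stand-in for a trained one; `position_similarity` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def position_similarity(pos_emb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between all position vectors.

    pos_emb: (num_positions, dim) -> returns (num_positions, num_positions)
    """
    normed = F.normalize(pos_emb, dim=-1)  # scale each position vector to unit length
    return normed @ normed.T  # entry (i, j) = cosine similarity of positions i and j


# Stand-in for a trained table (assumption): 64 positions, 32 channels
pos_emb1D = torch.randn(64, 32)
sim = position_similarity(pos_emb1D)
print(sim.shape)  # torch.Size([64, 64])
# to visualize, e.g. with matplotlib: plt.imshow(sim.numpy())
```

With a trained `nn.Parameter` such as `pos_emb1D` from the earlier snippet, pass `pos_emb1D.detach()` so the result is not tracked by autograd. A random table only gives a bright diagonal with no particular structure elsewhere; patterns like those in the figure come from training.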
However, there are good reasons to inject positional information inside the MHSA block itself, rather than only at the input.
How Positional Embeddings emerged inside MHSA
If the PE are not inside the MHSA block, they have to be added to the input representation, as we saw. The main concern is that the positional information is then available only once, at the very beginning of the network.
The standard MHSA mechanism encodes no positional information, which makes it permutation equivariant: reordering the input tokens simply reorders the outputs. This property limits its representational power for computer vision tasks.
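Permutation equivariance can be checked numerically: reversing the order of the input tokens of a self-attention layer reverses its outputs in exactly the same way, so the layer alone cannot tell which token came first. The sketch below uses PyTorch's `nn.MultiheadAttention` as a stand-in for an MHSA block (an assumption; any dot-product self-attention behaves the same), with arbitrary sizes and seed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tokens, dim = 6, 16
# Stand-in MHSA block (assumption): PyTorch's built-in multi-head attention
mhsa = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
mhsa.eval()  # disable dropout so outputs are deterministic


def self_attend(inp: torch.Tensor) -> torch.Tensor:
    out, _ = mhsa(inp, inp, inp)  # query = key = value -> self-attention
    return out


x = torch.randn(1, tokens, dim)          # (batch, tokens, dim) token features
perm = torch.arange(tokens - 1, -1, -1)  # reverse the token order
pos_emb = torch.randn(1, tokens, dim)    # one vector per *position*, not per token

with torch.no_grad():
    # Without PE: attend-then-permute equals permute-then-attend
    equivariant = torch.allclose(self_attend(x)[:, perm], self_attend(x[:, perm]), atol=1e-5)
    # With PE: positions stay put while tokens move, so the two paths differ
    equivariant_with_pe = torch.allclose(
        self_attend(x + pos_emb)[:, perm], self_attend(x[:, perm] + pos_emb), atol=1e-5
    )

print(equivariant)  # True
print(equivariant_with_pe)
```

For generic inputs the second check should come out `False`: once each position carries its own vector, reordering the tokens changes what the layer computes, which is exactly the information the plain MHSA block lacks.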
Why?
Because images are highly structured data.
So it would make more sense to come up with...