Transformer Architecture Explained

NLP
Deep Learning
Author

Imad Dabbura

Published

February 14, 2023

Modified

February 14, 2023


Figure 1: The architecture of the vanilla Transformer model (source)

Introduction

The Transformer architecture was first introduced in the Attention Is All You Need paper in 2017 and quickly outperformed RNN-based models on most NLP tasks. It has an encoder-decoder architecture that is used in tasks such as Neural Machine Translation. The most common examples of models that use the Transformer architecture are BERT, which uses an encoder-only architecture, and GPT, which uses a decoder-only architecture.

The main motivation behind creating the Transformer architecture is to overcome issues that RNN-based models have:

  1. It is hard to learn long-distance dependencies due to vanishing/exploding gradients. For example, if the last word of the sequence depends on early words, the hidden state will have retained little information about those early words by the time it reaches the last word, especially as the sequence gets longer. Such models also assume a strictly linear order of words, which is not the right way to think about language.
  2. They are sequential models: we can only start processing \(w_t\) once we finish \(w_{t - 1}\), because the computation depends on the previous hidden state \(h_{t - 1}\), which is itself computed from \(w_{t - 1}\). Therefore, training is not parallelizable across the sequence.

In this post, we will implement and explain the main building blocks of the transformer architecture (see figure 1). By the end of this post, we should be able to:

  1. Implement vanilla transformer from scratch, including full encoder-decoder architecture, encoder-only architecture, and decoder-only architecture.
  2. Understand the role of each block.

Building Blocks

Embedding Layer

After the input sequence is tokenized and numericalized, we need to project each token into a lower-dimensional space. Such a projection is called an embedding, and it captures a semantic representation of each token based on the contexts in which the token typically occurs.

The attention operation is permutation equivariant: if we permute the input, the corresponding output is permuted in exactly the same way. In other words, the attention mechanism is not aware of the relative ordering of the tokens. Therefore, we need some way to encode the positions of the tokens in each sequence. This is where positional encoding comes into play. There are a few common types of encodings:

  • Absolute Positional Encoding: Uses each token's absolute position. It can rely on either static patterns, such as sinusoidal functions, or learned parameters (a sketch of the sinusoidal variant follows this list).
  • Relative Positional Encoding: Encodes the relative position between tokens. This requires adjusting the attention mechanism itself by adding new terms to the dot-products that capture the relative distance between tokens, up to some maximum relative position.
  • Rotary Encoding: Combines both absolute and relative positions of tokens to achieve great results. This is done by encoding absolute positions with a rotation matrix that is multiplied with the query and key matrices of each attention layer, injecting relative position information at every layer.
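To make the absolute case concrete, here is a minimal sketch of the static sinusoidal encoding used in the original Transformer paper (the sinusoidal_positions helper name is mine; the Embeddings module below uses learned absolute encodings instead):

import torch

def sinusoidal_positions(seq_len: int, embed_dim: int) -> torch.Tensor:
    # Static absolute positional encodings: each position gets a fixed
    # pattern of sine/cosine values at different frequencies.
    pos = torch.arange(seq_len).unsqueeze(1)       # T x 1
    i = torch.arange(0, embed_dim, 2)              # even dimension indices
    angles = pos / (10000 ** (i / embed_dim))      # T x embed_dim/2
    pe = torch.zeros(seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(angles)                # even dims use sine
    pe[:, 1::2] = torch.cos(angles)                # odd dims use cosine
    return pe                                      # T x embed_dim

pe = sinusoidal_positions(8, 768)   # e.g. 8 positions, 768-dim embeddings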
import torch.nn as nn
import torch.nn.functional as F
import torch
from types import SimpleNamespace

# Wrap the config in a SimpleNamespace so the modules below can use
# attribute access (e.g. config.embed_dim)
config = SimpleNamespace(
    vocab_sz=1000,
    block_sz=8,
    intermediate_sz=4 * 768,  # 4x embed_dim
    hidden_dropout_prob=0.2,
    num_attention_heads=12,
    hidden_sz=64,  # embed_dim / num_attention_heads = 768 / 12 = 64
    num_hidden_layers=6,
    embed_dim=768,
    num_classes=2,
    layer_norm_eps=1e-12,
)
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_sz, config.embed_dim)
        self.position_embedding = nn.Embedding(
            config.block_sz, config.embed_dim
        )
        self.layer_norm = nn.LayerNorm(
            config.embed_dim, eps=config.layer_norm_eps
        )
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        # X:                   B x T
        # token_embeddings:    B x T x embed_dim
        # position_embeddings: T x embed_dim
        embeddings = self.token_embedding(x) + self.position_embedding(
            torch.arange(x.shape[1], device=x.device)
        )
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)
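As a quick sanity check of the shapes, using the config above (a batch of 2 sequences of block_sz tokens):

emb = Embeddings(config)
tokens = torch.randint(0, config.vocab_sz, (2, config.block_sz))  # B x T
print(emb(tokens).shape)  # torch.Size([2, 8, 768])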

Attention

Attention is a communication mechanism that a neural network uses to learn to make predictions by attending to some tokens in the context window (only the current and previous tokens in a decoder-only architecture). The attention weights, which are learned, are used to construct a weighted average of all the tokens attended to by each token. This helps each token focus on what is important in its context. As a reminder, attention has no notion of space: it operates on a set of vectors. This is why we need positional encoding for the tokens.

The results of the attention layer are contextualized embeddings, whereas the output of the embedding layer is contextless embeddings. This is very useful because the meaning of a word changes according to its context, while the embedding-layer vector for a token is the same regardless of context. For example, the word “bear” has the same embedding vector whether it appears in “teddy bear” or “to bear”.

Self-attention is a type of attention mechanism where the keys and values come from the same source as the queries, which is the input \(x\). Whereas in cross-attention, the queries still get produced from the input \(x\), but the keys and values come from some other external source (encoder module in the case of encoder-decoder architecture).


Figure 2: Scaled Dot-Product Attention (source)

For self-attention, we have:

  • Query matrix \(Q\) (embed_dim x head_dim): what each token is looking for
  • Key matrix \(K\) (embed_dim x head_dim): what each token contains
  • Value matrix \(V\) (embed_dim x head_dim): what each token communicates to the tokens that attend to it

Then,

  • The dot-product of each query with all the keys gives us the affinities. The dot-product is just one way of computing similarity; other forms of attention, such as additive attention, exist.
    • If a query and a key vector are aligned -> very high value -> that token contributes more to the output than other tokens
    • All the tokens in all positions of the B x T matrix produce query/key/value vectors in parallel and independently of each other; no communication happens at this stage
    • Then all queries are dot-multiplied with all the keys
    • We scale the attention scores by dividing them by \(\sqrt{head\_dim}\). This way, when the inputs \(Q\) and \(K\) have unit variance, the scores also have unit variance, so the softmax stays diffuse and does not saturate (see the quick check after the equation below)
  • Finally, we multiply the attention weights with the value matrix \(V\) to get the contextualized embeddings

In equations: \[\mathrm{attn}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
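A quick numerical check of the scaling claim above: with unit-variance queries and keys, the raw dot-products have variance roughly head_dim, while the scaled scores stay close to unit variance, which keeps the softmax from saturating.

torch.manual_seed(42)
head_dim = 64
q = torch.randn(1000, head_dim)               # unit-variance queries
k = torch.randn(1000, head_dim)               # unit-variance keys
raw = q @ k.T
scaled = raw / head_dim**0.5
print(raw.var().item(), scaled.var().item())  # ~64 vs ~1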

class AttentionHead(nn.Module):
    def __init__(self, config, head_dim, is_decoder=False) -> None:
        super().__init__()
        self.k = nn.Linear(config.embed_dim, head_dim, bias=False)
        self.q = nn.Linear(config.embed_dim, head_dim, bias=False)
        self.v = nn.Linear(config.embed_dim, head_dim, bias=False)
        self.is_decoder = is_decoder
        if self.is_decoder:
            # Lower-triangular mask: position i can only attend to positions <= i
            self.register_buffer(
                "mask", torch.tril(torch.ones(config.block_sz, config.block_sz))
            )

    def forward(self, query, key, value):
        # query, key, value are each B x T x embed_dim
        q = self.q(query)
        k = self.k(key)
        v = self.v(value)
        # w is B x T x T
        w = q @ k.transpose(2, 1) / (k.shape[-1] ** 0.5)
        if self.is_decoder:
            T = w.shape[-1]
            w = w.masked_fill(self.mask[:T, :T] == 0, -float("inf"))
        w = F.softmax(w, dim=-1)
        # output is B x T x head_dim
        return w @ v
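To see the causal mask at work, we can check that perturbing the last token of the input does not change the outputs at earlier positions (a sketch using the config above):

head_dim = config.embed_dim // config.num_attention_heads
head = AttentionHead(config, head_dim, is_decoder=True)

x = torch.randn(2, config.block_sz, config.embed_dim)   # B x T x embed_dim
out = head(x, x, x)                                      # B x T x head_dim

# Change only the last token: earlier positions never attend to it,
# so their outputs stay exactly the same.
x2 = x.clone()
x2[:, -1, :] += 1.0
out2 = head(x2, x2, x2)
print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True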

Multi-Head Attention

What we described in the previous section was the self-attention mechanism with a single head. Since each attention head tends to focus on one specific characteristic of the data, such as subject-verb interaction, additional heads are needed to focus on other aspects, such as adjectives. We can also think of multiple heads as each head focusing on one or a few other tokens. Remember that all of this is done in parallel and there is no communication between heads: each head has no idea what the other heads are doing.


Figure 3: Multi-Head Attention with several attention layers running in parallel (source)

In a multi-head layer, we typically set head_dim to the hidden_sz (or the embed_dim if it is the first layer) divided by the number of heads.

Once we get all contextualized embeddings from all heads, we concatenate them. Then we pass the output through a projection layer with the same dimension as the input.

class MultiHeadAttention(nn.Module):
    def __init__(self, config, is_decoder=False) -> None:
        super().__init__()
        head_dim = config.embed_dim // config.num_attention_heads
        self.heads = nn.ModuleList(
            [
                AttentionHead(config, head_dim, is_decoder)
                for _ in range(config.num_attention_heads)
            ]
        )
        self.output = nn.Linear(config.embed_dim, config.embed_dim)

    def forward(self, x):
        # Self-attention: queries, keys, and values all come from x
        x = torch.cat([head(x, x, x) for head in self.heads], dim=-1)
        return self.output(x)
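Shape check: with the config above, concatenating 12 heads of size 64 recovers the 768-dimensional embedding, which the output projection then mixes:

mha = MultiHeadAttention(config)
x = torch.randn(2, config.block_sz, config.embed_dim)
print(mha(x).shape)  # torch.Size([2, 8, 768])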

Feed-Forward Layer

Because there are no element-wise nonlinearities involved in the attention calculation, stacking multiple attention layers would not help much: the output would still be a linear transformation of the input. Therefore, a feed-forward NN is added to post-process each output vector from the attention layer with such nonlinearities. Each embedding vector in the batched sequence is processed independently, which is why this is called a position-wise feed-forward layer.

We typically first project the output vector into a new space 4x the embed_dim. Most of the capacity and memorization is expected to happen in this first layer, which is what gets scaled up when the model is scaled up. Then we project it back to the original dimension. We use GELU (Gaussian Error Linear Unit) as the activation function.

class FeedForwardNN(nn.Module):
    def __init__(self, config):
        super().__init__()
        # intermediate_sz is typically 4 x embed_dim
        self.l1 = nn.Linear(config.embed_dim, config.intermediate_sz)
        self.l2 = nn.Linear(config.intermediate_sz, config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        return self.dropout(self.l2(F.gelu(self.l1(x))))

Layer Normalization

Layer normalization was introduced in this paper to overcome the main challenges of batch normalization, namely 1) how to handle batches with one or a few examples, where the variance estimate blows up and training becomes unstable, and 2) how to handle RNNs. The main differences with batch normalization are 1) there are no moving averages/standard deviations and 2) we average over the hidden dimension(s), so it is independent of the batch size. It has two learnable parameters, \(\gamma\) and \(\beta\) (see the equation below):

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

It is used as a trick to train complex models, such as Transformer, faster. In our case, we would normalize the hidden vectors to zero mean and unit standard deviation. This trick helps maintain consistent distribution of signals by cutting down uninformative variations in hidden vector values.
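A minimal sketch comparing the equation above, computed by hand over the hidden dimension, with nn.LayerNorm (whose \(\gamma\) is initialized to ones and \(\beta\) to zeros):

x = torch.randn(2, 8, 768)

# Statistics are computed over the hidden dimension, independently for
# every token, so the result does not depend on the batch size.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-12)

layer_norm = nn.LayerNorm(768, eps=1e-12)
print(torch.allclose(manual, layer_norm(x), atol=1e-5))  # True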

There are two arrangements for the layer normalization as illustrated in Figure-4:


Figure 4: Different LayerNorm arrangement (source)

  • Prelayer normalization: Places the layer normalization within the span of skip connections. This arrangement is easier to train.
  • Postlayer normalization: Places the layer normalization in between skip connections. This arrangement is used in the Transformer paper.

Skip Connections

Skip connections help train deeper and more complex models faster and avoid the vanishing-gradient issue that deeper networks face by providing paths for the gradient to flow back to the input. In our case we use additive skip connections: we take a copy of the input and add it to the output of a block (which involves some computation). If \(y = x + F(x)\), then it is as if we are asking the block to predict \(y - x\). During backpropagation, the identity branch passes the gradient of \(y\) through unchanged (multiplied by one), so earlier layers retain its value.
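A tiny sketch of the gradient argument: with \(y = x + F(x)\), the identity path contributes a gradient of one at the input, so some gradient always flows back even if the block's own gradient is small.

x = torch.randn(4, requires_grad=True)
block = nn.Linear(4, 4)   # stands in for an attention or feed-forward block F
y = x + block(x)          # skip connection: y = x + F(x)
y.sum().backward()
# x.grad is 1 (from the identity path) plus the block's contribution
print(x.grad)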

Skip connections also help smooth out the loss landscape (see Figure 5) and make it easier for gradients to flow back, since the addition operator distributes the gradient equally to both branches. This means that small changes in the input can still find their way to the output. Additionally, the skip connection preserves the original input sequence, so there is no way for the current word to “forget” its own position, because we always add it back.


Figure 5: The loss surfaces of ResNet-56 with/without skip connections (source)

Dropout


Figure 6: Dropout Neural Net Model. Left: A standard neural net with 2 hidden layers. Right: An example of a thinned net produced by applying dropout to the network on the left. Crossed units have been dropped. (source)

Dropout is a regularization technique that was introduced by Geoffrey Hinton et al. in this paper. On each iteration, we randomly shut down some outputs from the previous layer and don't use those outputs in either forward propagation or back-propagation. Since the outputs to be dropped are chosen randomly on each iteration, the learning algorithm has no idea which neurons will be shut down, and is therefore forced to spread out the weights rather than rely on a few specific features. Moreover, dropout helps improve generalization error because:

  • Since we drop some units on each iteration, we effectively train a smaller and therefore simpler network (regularization).
  • It can be seen as an approximation to bagging. Each iteration can be viewed as a different model, since we randomly drop different units in each layer, so the final error is like an average of errors from many different models. Averaging errors across different models reduces the overall error, especially when those errors are uncorrelated. In the worst case, where errors are perfectly correlated, averaging doesn't help at all; in practice, however, errors have some degree of decorrelation, so generalization improves.

Dropout is used in the Transformer in the embeddings layer, after adding the token and positional embeddings, as well as after each multi-head attention and feed-forward layer in both the encoder and decoder layers.
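A quick sketch of how nn.Dropout behaves: in training mode it zeroes a random subset of activations and rescales the survivors by 1/(1-p); in eval mode it is the identity.

drop = nn.Dropout(p=0.2)
x = torch.ones(8)

drop.train()
print(drop(x))   # some entries are 0, the rest are scaled to 1 / (1 - 0.2) = 1.25

drop.eval()
print(drop(x))   # identity: all ones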

For more information on dropout, check out my previous post.

Transformer Components

Encoder-only Architecture

Encoder-only architectures are well suited for classification tasks. The most common models that use the encoder-only branch of the Transformer architecture are BERT and all its variants, such as RoBERTa. In this architecture, we have:

  • body: Stacked encoder layers. The output is B x T x embed_dim.
  • head: A classification head, which consists of a linear layer that projects embed_dim into num_classes. We take the hidden vector of the first token, which is the special [CLS] token in the case of BERT (it marks the beginning of the sequence), and pass it through the linear layer to get the logits.
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.layer_norm_1 = nn.LayerNorm(config.embed_dim)
        self.layer_norm_2 = nn.LayerNorm(config.embed_dim)
        self.ff = FeedForwardNN(config)

    def forward(self, x):
        # There are two arrangements for layer_norm:
        # Prelayer normalization & Postlayer normalization
        # we are using postlayer normalization arrangement
        x = self.layer_norm_1(x + self.attn(x))
        x = self.layer_norm_2(x + self.ff(x))
        # Prelayer normalization would instead be:
        # hidden_state = self.layer_norm_1(x)
        # x = x + self.attn(hidden_state)
        # x = x + self.ff(self.layer_norm_2(x))
        return x
class TransformerEncoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.embeddings = Embeddings(config)
        self.encoder_blocks = nn.Sequential(
            *[EncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x):
        x = self.embeddings(x)
        return self.encoder_blocks(x)
class TransformerForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.embed_dim, config.num_classes)

    def forward(self, x):
        # We take the hidden state of the [CLS] token as
        # input to the classifier
        x = self.encoder(x)[:, 0, :]
        x = self.dropout(x)
        return self.classifier(x)
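Putting the encoder-only model together on a dummy batch (using the config defined earlier):

model = TransformerForSequenceClassification(config)
tokens = torch.randint(0, config.vocab_sz, (4, config.block_sz))  # B x T
print(model(tokens).shape)  # torch.Size([4, 2]): one logit per class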

Decoder-only Architecture

These models are typically used as language models, such as GPT and all its variants. In this architecture, as opposed to the encoder-only architecture, a token can only see past tokens and not future ones, since seeing the future would be a kind of cheating when we are trying to predict the next token. Therefore, we mask all future tokens in the attention layer. In this architecture, we have:

  • body: Stacked decoder layers. The output is B x T x embed_dim.
  • head: A language-modeling head, which consists of a linear layer that projects embed_dim into vocab_sz. The logits can be passed through a softmax to get a probability distribution over all tokens in the vocabulary. During training, these logits are compared with the shifted target tokens (e.g. with a cross-entropy loss).

At inference, we can use many sampling algorithms, such as greedy decoding or top-k sampling, on the probability distribution obtained from the language-modeling head (see the greedy-decoding sketch after the GPT implementation below).

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config, is_decoder=True)
        self.layer_norm_1 = nn.LayerNorm(config.embed_dim)
        self.layer_norm_2 = nn.LayerNorm(config.embed_dim)
        self.ff = FeedForwardNN(config)

    def forward(self, x):
        x = self.layer_norm_1(x + self.attn(x))
        x = self.layer_norm_2(x + self.ff(x))
        return x
class TransformerDecoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.embeddings = Embeddings(config)
        self.decoder_blocks = nn.Sequential(
            *[DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x):
        x = self.embeddings(x)
        return self.decoder_blocks(x)
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.decoder = TransformerDecoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_sz)

    def forward(self, x):
        x = self.decoder(x)
        x = self.dropout(x)
        return self.lm_head(x)
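With the language-modeling head in place, greedy decoding is a small loop: feed the last block_sz tokens, take the argmax of the final position's logits, append it, and repeat. A minimal sketch (the generate_greedy helper is mine, just to illustrate the sampling loop):

@torch.no_grad()
def generate_greedy(model, tokens, max_new_tokens=5):
    # tokens: B x T tensor of token ids
    model.eval()
    for _ in range(max_new_tokens):
        context = tokens[:, -config.block_sz:]   # crop to the context window
        logits = model(context)                  # B x T x vocab_sz
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens

gpt = GPT(config)
prompt = torch.randint(0, config.vocab_sz, (1, 4))
print(generate_greedy(gpt, prompt).shape)  # torch.Size([1, 9])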

Encoder-Decoder Architecture

The encoder-decoder architecture is the first Transformer architecture used in the Attention Is All You Need paper for Neural Machine Translation Task. It is typically used for tasks that have both their input and output as text such as summarization. T5 and BART are the most common models that use such architecture.

For each decoder layer, we add a cross-attention (encoder-decoder attention) layer in the middle that 1) takes the hidden state from the last encoder layer to compute the keys and values and 2) takes the decoder's hidden state (after layer norm) to compute the queries. This middle multi-head attention layer therefore attends to all the tokens in the input sequence. It is exactly the cross-attention we defined earlier, where the keys and values come from one source (the encoded input sequence) while the queries come from another (the decoder).
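With the AttentionHead defined above, the cross-attention call pattern for a single head looks like this (just a sketch of the wiring; extending MultiHeadAttention and DecoderLayer to take the encoder output is the exercise below):

head_dim = config.embed_dim // config.num_attention_heads
cross_head = AttentionHead(config, head_dim)   # no causal mask for cross-attention

encoder_out = torch.randn(2, config.block_sz, config.embed_dim)  # from the encoder
decoder_x = torch.randn(2, config.block_sz, config.embed_dim)    # from the decoder

# Queries come from the decoder; keys and values come from the encoder output
out = cross_head(decoder_x, encoder_out, encoder_out)
print(out.shape)  # torch.Size([2, 8, 64])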

It is very easy to extend or modify our implementation of DecoderLayer to use it in this architecture, so I will leave it for you as an exercise!

Conclusion

In this post, we started with a brief introduction to the Transformer architecture and the motivation behind it, namely overcoming the limitations of RNN-based models. We then covered the main building blocks of the Transformer architecture, including the attention mechanism, and briefly went over a few tricks that help train complex models faster, such as skip connections and layer normalization. Along the way, we implemented the main sublayers used in the architecture. We concluded with the different branches of the Transformer architecture that can be used separately, such as encoder-only and decoder-only models.

I hope that you found this post helpful and that it provided a good background on the Transformer architecture.

Credits/Resources