Coding GPT2/3 (124M) From Scratch

NLP
Deep Learning
Author

Imad Dabbura

Published

April 10, 2024

Modified

June 1, 2025

Introduction

There’s an old saying in engineering: “You don’t really understand something until you can build it.” This has never been more true than in the era of LLMs. While we’ve previously explored the foundational concepts in my post on the Transformer architecture (explained here), true understanding comes from implementation. That’s why today, we’re building a GPT-style model (the 124M variant) from scratch in PyTorch.

This project has a different focus than my last “from scratch” endeavor, where I built an entire deep learning framework to grasp the low-level mechanics of autograd and tensor ops. Here, we’ll leverage PyTorch’s battle-tested primitives to focus on what makes GPT special: multi-head attention, positional encodings, and the specific architectural decisions that enable language understanding.

This hands-on process reveals challenges you can’t appreciate from diagrams alone. You’ll watch your GPU memory overflow, see training grind to a halt from inefficient data loading, and learn firsthand why techniques like mixed-precision training, gradient accumulation, and activation checkpointing are necessities, not just optimizations. It’s in facing these hurdles that you truly appreciate the engineering craft required to build and scale transformers efficiently.

GPTs

GPT (Generative Pre-trained Transformer) models, developed by OpenAI, represent a breakthrough in natural language processing. GPT-2, released in 2019, demonstrated that a transformer-based model trained on vast amounts of text could generate remarkably coherent and contextually relevant content. GPT-3, its successor, scaled this approach to 175 billion parameters, showcasing emergent capabilities like few-shot learning and complex reasoning. Both models share the same fundamental architecture: stacked transformer decoder blocks that predict the next token in a sequence, trained on the simple objective of minimizing prediction error across massive text corpora. The 124M parameter version we’ll be building captures the essential architecture while remaining computationally tractable for individual developers—though even at this “small” scale, you’ll quickly discover why the ML community spends so much time optimizing both training efficiency and model performance.

By the end of this journey, you won’t just know how transformers work—you’ll have built the critical components with your own hands, optimized the training loop, and watched your model evolve from random noise to coherent text generation. Let’s begin.

Implementation

Throughout this implementation, every piece of code will be thoroughly annotated with explanations of not just what we’re doing, but why we’re doing it. More importantly, we’ll use a few optimizations that make a big difference in computational efficiency; a short standalone sketch of the main switches follows just before the full code:

  • TensorFloat32 (TF32): NVIDIA’s matmul format that keeps FP32’s 8-bit exponent but truncates the mantissa from 23 bits to 10 (19 bits total), providing up to 8x matrix-multiply throughput on A100 GPUs while maintaining model quality. We’ll see how a single line of code can dramatically accelerate matrix multiplications.

  • BFloat16 with Autocast: Mixed precision training using brain floating-point format, which maintains the same exponent range as FP32 but reduces mantissa precision. Combined with automatic mixed precision (AMP), this cuts memory usage in half and speeds up training significantly.

  • torch.compile: PyTorch 2.0’s just-in-time compilation that fuses operations and generates optimized kernels. We’ll explore how graph compilation can provide 10-30% speedups with minimal code changes.

  • Flash Attention and Online Softmax: An algorithmic improvement that computes attention without materializing the full attention matrix, reducing memory complexity from O(n²) to O(n).

  • Fused AdamW: A single-kernel implementation of the AdamW optimizer that reduces memory reads/writes by computing all parameter updates in one pass, providing up to 2x optimizer step speedup.

  • Annealed Learning Rate: Starting with a warmup phase followed by cosine decay, we’ll implement the learning rate schedule that has become standard for training transformers, understanding why stable training requires careful lr management.

  • Weight Decay Only on Matrices: A subtle but crucial detail—applying weight decay only to weight matrices in Linear and Embedding layers while excluding biases and layer normalization parameters, which improves model performance.

  • Distributed Data Parallelism (DDP): Scaling training across multiple GPUs using PyTorch’s DDP, including gradient synchronization, proper data loading, and the intricacies of maintaining consistent model states across devices.

Finally, since the GPT-2 paper omits certain architectural details and hyperparameter specifications, we’ll refer to the GPT-3 paper to fill these gaps—fortunately, the core architecture remains consistent between the two models, making the GPT-3 paper a reliable source for these missing implementation details.
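
To make these switches concrete before diving into the full implementation, here is a minimal, standalone sketch. It assumes a single CUDA GPU and uses a stand-in `model` and `batch` purely for illustration; it is not part of the training script below.

Code
import torch

# TF32: run FP32 matmuls on Tensor Cores with a truncated (10-bit) mantissa
torch.set_float32_matmul_precision("high")

model = torch.nn.Linear(1024, 1024).cuda()  # stand-in for the real model
batch = torch.randn(8, 1024, device="cuda")

# torch.compile: fuse ops and generate optimized kernels
model = torch.compile(model)

# Fused AdamW: single-kernel parameter updates
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

# BF16 autocast: run most ops in bfloat16, keep precision-sensitive ones in FP32
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(batch).square().mean()
loss.backward()
optimizer.step()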

Code
import inspect
import math
import os
import time
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, Iterable

import tiktoken
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
Code
def listify(obj):
    if obj is None:
        return []
    elif isinstance(obj, str):
        return [obj]
    elif isinstance(obj, list):
        return obj
    elif isinstance(obj, Iterable):
        return list(obj)
    else:
        return [obj]
Code
def annealer(func: Callable):
    # Partially apply `func` so that e.g. `lin_sched(start, end)` returns a
    # scheduler that only takes `pos` (the fraction of training completed)
    @wraps(func)
    def annealer_wrapper(*args, **kwargs):
        return partial(func, *args, **kwargs)

    return annealer_wrapper


@annealer
def lin_sched(start, end, pos):
    """Linear scheduler."""
    return start + (end - start) * pos


@annealer
def cos_sched(start, end, pos):
    """Cosine scheduler."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def combine_scheds(pcts, scheds):
    """
    Combine multiple schedulers, each run for a given percentage of the
    training process.
    """
    assert len(pcts) == len(scheds), "Each scheduler should have its `pct`."
    assert sum(pcts) == 1.0, "Sum of the `pcts` should be equal to 1."
    pcts = torch.tensor([0] + listify(pcts))
    assert (pcts >= 0).all(), "All percentages should be non-negative."
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return scheds[idx](actual_pos)

    return _inner
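# Quick sanity check of the schedulers (illustrative values, not the training
# config): a 10% linear warmup from 1e-5 to 3e-4 followed by cosine decay back
# down to 1e-5
_sched = combine_scheds([0.1, 0.9], [lin_sched(1e-5, 3e-4), cos_sched(3e-4, 1e-5)])
# _sched(0.0) -> 1e-5, _sched(0.1) -> 3e-4, _sched(0.99) -> ~1e-5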
@dataclass
class GPTConfig:
    block_sz: int = 1024 # Sequence length
    vocab_sz: int = (
        50257  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
        # (index 50256), which delimits different documents.
        # Note that 50257 is not a multiple of 64; we can improve efficiency
        # and performance (through better occupancy) by rounding it up to the
        # closest multiple of 64, which is 50304 (done at model creation below)
    )
    n_layer: int = 12 # Number of layers
    n_embd: int = 768 # Embedding dimension
    n_head: int = 12 # Number of attention heads
    lr: float = 3e-4  # A reasonable default for AdamW; the scheduler below overrides it
    batch_sz: int = 4
    dropout: float = 0.0
    bias: bool = True
class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        # Position-wise feed-forward network that applies a non-linearity
        # to every token independently: THERE IS NO INTERACTION BETWEEN TOKENS.
        # A large share of the model's capacity and its non-linearity
        # come from here, especially with the 4 x n_embd expansion
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)
        # GELU avoids ReLU's flat zero-gradient region; the tanh approximation
        # matches the original GPT-2 implementation
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # NOTE: Mask is not needed when we use Pytorch's Flash attention
        # self.register_buffer(
        #     "mask",
        #     torch.tril(torch.ones(config.block_sz, config.block_sz)).view(
        #         config.block_sz, config.block_sz
        #     ),
        # )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        # q/k/v is B x T x n_embd each
        q, k, v = torch.split(qkv, self.n_embd, dim=-1)
        # Reshape q/k/v to B x n_head x T x (n_embd / n_head)
        # so that each head can learn a different kind of
        # relationship
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # attn would be B x n_head x T x T
        # attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
        # # Mask out future tokens
        # attn = attn.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        # attn = self.attn_dropout(F.softmax(attn, dim=-1))
        # # y is B x T x n_embd
        # y = attn @ v
        # Use Flash Attention, which never materializes the attention matrix
        # for each head; it is aware of the memory hierarchy and trades extra
        # FLOPs for fewer reads/writes -> a speedup since attention is memory-bound
        y = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=True
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.attn = CausalSelfAttention(config)

    def forward(self, x):
        # Use pre-layer normalization, which deviates from the original
        # Transformer paper's post-layer normalization. This helps
        # stabilize training
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
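# Quick shape check (illustrative): a Block maps (B, T, n_embd) -> (B, T, n_embd)
_blk = Block(GPTConfig())
_x = torch.randn(2, 8, 768)
assert _blk(_x).shape == (2, 8, 768)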
class GPT2(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_sz, config.n_embd),
                # The attention operation is permutation-equivariant: if we
                # permute the input tokens, the output is permuted in exactly
                # the same way. In other words, attention is not aware of the
                # ordering of the tokens, so we need some way to encode each
                # token's position in the sequence. This is where positional
                # encoding comes into play. Here we use learned positional
                # embeddings: a simple embedding of the token's position in
                # the sequence.
                wpe=nn.Embedding(config.block_sz, config.n_embd),
                h=nn.ModuleList(
                    [Block(config) for _ in range(config.n_layer)]
                ),
                # Final layer norm after all transformer layers
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_sz, bias=False)

        # Weight sharing (tying) between the token embedding layer and the
        # last linear layer (LM head classifier). The rationale is that
        # tokens that are semantically similar, and hence close to each other
        # in embedding space, should get similar probabilities from the
        # softmax of the LM head layer.
        # Also, these are among the biggest matrices in the model, so for a
        # model like GPT-2 we save roughly 30% of the parameters by sharing
        # them: (50257 * 768) / 124M ≈ 31%
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # The following initialization comes from gpt2 src code
        # NOTE: Because token embedding and classifier weights are shared,
        # our initialization logic will initialize the weight matrix twice
        # but shouldn't be an issue since they're being initialized with the
        # same std and mean
        if isinstance(module, nn.Linear):
            std = 0.02
            # We scale down std for the residual projections because the
            # residual stream's variance grows with every layer; the scaling
            # keeps the overall std close to 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                # `2` here because every layer has two blocks:
                #   - Attention block
                #   - MLP block
                # `N` is the number of layers in the model (n_layer)
                # Since they are independent, variance of the sum of the two
                # blocks is the sum of the variances
                std *= (2 * self.config.n_layer) ** -0.5
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # We're initializing the token and positional embeddings
            # with the same std but the paper initialized the positional
            # embedding with std = 0.01
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, x, targets=None):
        T = x.shape[-1]
        assert (
            T <= self.config.block_sz
        ), f"Sequence length must be <= {self.config.block_sz}, got {T}"
        pos_emb = self.transformer.wpe(
            torch.arange(0, T, dtype=torch.long, device=x.device)
        )
        tok_emb = self.transformer.wte(x)
        x = pos_emb + tok_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        # logits is B x T x vocab_sz
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # F.cross_entropy expects the 2nd dimension to be class logits,
            # so we flatten B x T into one dimension
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_sz), targets.view(-1)
            )
        return logits, loss

    def configure_optimizer(self, weight_decay, lr, device):
        params_dict = {
            pn: p for pn, p in self.named_parameters() if p.requires_grad
        }
        # We don't apply weight decay to biases, layer norm parameters, or
        # any other 1D parameters. Therefore, we ONLY apply weight decay to
        # the weight matrices in Embedding and Linear layers
        decay_params = [p for p in params_dict.values() if p.ndim >= 2]
        nondecay_params = [p for p in params_dict.values() if p.ndim < 2]
        params_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nondecay_params, "weight_decay": 0.0},
        ]
        # Fused AdamW is available for PyTorch 2.0+
        fused_available = "fused" in inspect.signature(opt.AdamW).parameters
        use_fused = fused_available and "cuda" in device
        return opt.AdamW(
            params_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
        )

    @torch.no_grad
    def generate(self, idxs: torch.Tensor, max_tokens: int = 5):
        for _ in range(max_tokens):
            # Logits would be B x T x vocab_sz. At most we can feed block_sz
            # tokens since we use a fixed block_sz for the positional
            # embedding, so crop the context to the last block_sz tokens
            idxs = idxs[:, -self.config.block_sz :]
            logits, _ = self(idxs)
            # Get probs for last token to predict next token
            # This would be B x vocab_sz
            logits = logits[:, -1, :]
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # Pick the top 50 probs -> we never pick tokens with very small
            # probs (the long tail) -> B x 50
            # probs/idxs are sorted in descending order
            topk_probs, topk_idxs = torch.topk(probs, 50, dim=-1)
            # Sample 1 token from the top 50 tokens -> idx is B x 1
            idx = torch.multinomial(topk_probs, 1)
            # Map back to vocab indices since `multinomial` returns indices
            # into the given array
            idx = torch.gather(topk_idxs, -1, idx)
            # TODO: We should check for end_of_text token and break out of
            # the loop (stop generation) even if we have not reached max_tokens
            idxs = torch.cat([idxs, idx], dim=1)
        return idxs
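# Two quick checks (illustrative; not part of the training script below):
# 1) With weight tying, the default config comes out to ~124.4M parameters.
# 2) `generate` runs end-to-end even on an untrained model (random text).
# _m = GPT2(GPTConfig())
# print(f"{sum(p.numel() for p in _m.parameters()) / 1e6:.1f}M parameters")
# _enc = tiktoken.get_encoding("gpt2")
# _prompt = torch.tensor(_enc.encode("Hello, I am"), dtype=torch.long)[None, :]
# print(_enc.decode(_m.generate(_prompt, max_tokens=8)[0].tolist()))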
class DataLoaderLight:
    def __init__(
        self,
        file_path: str,
        batch_sz: int,
        block_sz: int,
        process_rank: int = 0,
        number_processes: int = 1,
    ) -> None:
        self.batch_sz = batch_sz
        self.block_sz = block_sz
        self.process_rank = process_rank
        self.number_processes = number_processes
        with open(file_path, "r") as f:
            text = f.read()
        encoder = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(encoder.encode(text), dtype=torch.long)
        # Truncate the tokens to be a multiple of batch_sz x block_sz
        # x number_processes. This is useful for multi-process training and
        # mimics the behavior of DataLoader's `drop_last` parameter.
        tokens_per_iter = (
            self.batch_sz * self.block_sz * self.number_processes
        )
        self.tokens = self.tokens[
            : (len(self.tokens) // tokens_per_iter) * tokens_per_iter
        ]
        print(f"Loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self)} batches")
        self.current_pos = batch_sz * block_sz * process_rank

    def __len__(self):
        return len(self.tokens) // (self.batch_sz * self.block_sz)

    def next_batch(self):
        buf = self.tokens[
            self.current_pos : self.current_pos
            + self.batch_sz * self.block_sz
            + 1
        ]
        x = buf[:-1].view(self.batch_sz, self.block_sz)
        y = buf[1:].view(self.batch_sz, self.block_sz)
        # Each process will process batch_sz x block_sz tokens in each
        # iteration -> with number_processes processes, total tokens processed
        # in each iteration is batch_sz x block_sz x number_processes. In the
        # case of one process, total tokens would be batch_sz x block_sz
        self.current_pos += (
            self.batch_sz * self.block_sz * self.number_processes
        )
        # Similar to DataLoader's `drop_last` parameter, we wrap around once
        # the next read would run past the end of the tokens (note the `+ 1`
        # for the shifted targets), resetting to this process's own offset
        if (
            self.current_pos
            + self.batch_sz * self.block_sz * self.number_processes
            + 1
        ) > len(self.tokens):
            self.current_pos = (
                self.batch_sz * self.block_sz * self.process_rank
            )
        return x, y
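# Illustrative usage (single process): every call to `next_batch` returns
# (x, y), both of shape (batch_sz, block_sz), where y is x shifted one token
# to the right:
# _dl = DataLoaderLight("tinyshakespeare.txt", batch_sz=4, block_sz=32)
# _xb, _yb = _dl.next_batch()
# assert _xb.shape == _yb.shape == (4, 32) and (_xb[0, 1:] == _yb[0, :-1]).all()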
###########
# Distributed Data Parallel
###########
# Distributed Data Parallel lets us run the same model (a replica) on each GPU,
# where each GPU works on a different slice of the data. After the backward
# pass, the gradients are averaged across all processes (GPUs) with an
# allReduce op, so every replica applies the same update and the parameters
# stay in sync across all devices.
# Each process runs the same code from top to bottom, unaware that other
# processes are running the same thing on other devices
#
# torchrun sets the following environment variables:
# RANK: Id of the process in the process group; an int in [0, WORLD_SIZE)
# LOCAL_RANK: In multi-node training, LOCAL_RANK is the id of the process
#             within its own node. Example: on a node with 4 GPUs the first
#             process has LOCAL_RANK=0, but its RANK may not be 0 if we are
#             running on multiple nodes. We use LOCAL_RANK to pin each
#             process to a specific GPU on its node.
# WORLD_SIZE: Total number of processes
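#
# Example launch (the script filename is an assumption; adjust the GPU count
# to your machine):
#   torchrun --standalone --nproc_per_node=4 train_gpt2.py
# A plain `python train_gpt2.py` run falls through to the non-DDP branch below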
ddp = int(os.getenv("RANK", -1)) != -1  # Check if it is a ddp run
if ddp:
    # DDP requires CUDA so we need to set the device for each process
    # so only one process can run per device
    assert torch.cuda.is_available(), "DDP requires CUDA"
    init_process_group(backend="nccl")
    ddp_rank = int(os.getenv("RANK"))
    ddp_local_rank = int(os.getenv("LOCAL_RANK"))
    ddp_world_size = int(os.getenv("WORLD_SIZE"))
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    # Master process will do more things such as checkpointing and logging
    # while other processes would assist in the computations.
    # It always has RANK=0
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_built():  # Apple Silicon
        device = "mps"
        torch.mps.manual_seed(1337)
    else:
        device = "cpu"
print(device)
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)
##########
# Initialize model and optimizer
##########
# GPU kernels work best when sizes divide nicely into powers of 2 (e.g. for
# tiling), so try to keep matrix dimensions divisible by large powers of 2 to
# improve use of:
# • Tensor Cores
# • Memory coalescing
# • Shared memory bank alignment
# • Warp scheduling
# Here we round vocab_sz up from 50257 to 50304, the nearest multiple of 64.
# This adds a small space overhead but speeds up computation
model = GPT2(GPTConfig(vocab_sz=50304)).to(device)
# Speed up the model by compiling it into an optimized static graph:
# torch.compile analyzes all ops and fuses some of them to avoid unnecessary
# round trips to memory
# model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model
max_lr = 3e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50
sched = combine_scheds(
    [warmup_steps / max_steps, 1 - (warmup_steps / max_steps)],
    [lin_sched(min_lr, max_lr), cos_sched(max_lr, min_lr)],
)
optimizer = raw_model.configure_optimizer(
    weight_decay=0.1, lr=max_lr, device=device
)

##########
# Run training loop
#########
# NOTE: In order to process ~0.5M tokens (from the GPT-3 paper) per optimizer
# step, we need to use gradient accumulation because a batch that big can't
# fit on almost any commodity GPU -> we only call optimizer.step() after we
# have looped through ~0.5M tokens' worth of micro-batches.
total_batch_sz = 2**19  # 524,288: the power of 2 closest to 0.5M
assert (
    total_batch_sz % (GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size)
    == 0
), "total batch size must be divisible by micro batch_sz x block_sz x ddp_world_size"
grad_accum_steps = total_batch_sz // (
    GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size
)
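# e.g. with the defaults above (batch_sz=4, block_sz=1024) on a single GPU:
# 2**19 / (4 * 1024 * 1) = 128 micro-steps per optimizer step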
if master_process:
    print(f"Total desired batch size: {total_batch_sz}")
    print(f"Calculated gradient accumulation steps: {grad_accum_steps}")

train_dl = DataLoaderLight(
    "tinyshakespeare.txt",
    batch_sz=GPTConfig.batch_sz,
    block_sz=GPTConfig.block_sz,
    process_rank=ddp_rank,
    number_processes=ddp_world_size
)
# PyTorch will run matmuls in TensorFloat32 if available, else in FP32.
# The weights are still stored in full FP32; only the matmul operations are
# executed in TF32, which keeps FP32's range but truncates the mantissa to
# 10 bits (instead of 23)
torch.set_float32_matmul_precision("high")

for step in range(max_steps):
    start = time.time()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        # Fetch a fresh micro-batch for every accumulation step; otherwise
        # we would keep accumulating gradients over the same tokens
        x, y = train_dl.next_batch()
        x = x.to(device)
        y = y.to(device)
        if "cuda" in device:
            # Tensors that are sensitive to reduced precision, such as the
            # loss and layernorms, stay in FP32 while others, such as the
            # attention matmuls, run in BF16
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, loss = model(x, y)
        else:
            logits, loss = model(x, y)
        # Accumulating gradients alone would sum the objectives, but we want
        # the mean, so we weight each loss by 1/grad_accum_steps
        loss /= grad_accum_steps
        loss_accum += loss.detach()
        # To avoid syncing gradients between processes after every micro
        # step, we disable the sync and only allow it on the last
        # accumulation step in each process
        if ddp:
            model.require_backward_grad_sync = (
                micro_step == grad_accum_steps - 1
            )
        loss.backward()
    # Each process has its own loss_accum tensor, so we all-reduce with an
    # AVG op to get the average loss across all processes (for logging)
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    # Clip gradients to a global norm. This is very useful to avoid the
    # occasional (bad) batch with a very high loss, which would otherwise
    # lead to huge gradients and huge parameter updates
    # NOTE: At the beginning of training it is normal to see high norms since
    # the model is initialized randomly
    norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # TODO: Use ParamScheduler from `cmn_ai`
    lr = sched(step / max_steps)
    for pg in optimizer.param_groups:
        pg["lr"] = lr
    optimizer.step()
    if "cuda" in device:
        # Wait for queued GPU work to finish so the timing below is accurate
        torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    token_per_sec = (
        GPTConfig.batch_sz
        * GPTConfig.block_sz
        * grad_accum_steps
        * ddp_world_size
    ) / (elapsed_time)
    print(
        f"step {step}, loss: {loss_accum.item():.4f}, lr: {lr:.4e}, "
        f"norm: {norm:.2f}, time: {elapsed_time:.2f}s, tok/sec: {token_per_sec:.2f}"
    )

if ddp:
    # Kills all processes
    destroy_process_group()
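
The script above never samples from the trained weights. As a closing sketch, reusing the `raw_model` and `device` names from above (the prompt is arbitrary and sampling is restricted to the master process by assumption), generation could look like this:

Code
if master_process:
    enc = tiktoken.get_encoding("gpt2")
    prompt = torch.tensor(
        enc.encode("First Citizen:"), dtype=torch.long, device=device
    )[None, :]
    raw_model.eval()
    out = raw_model.generate(prompt, max_tokens=30)
    print(enc.decode(out[0].tolist()))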

Conclusion

We’ve come a long way in this journey—from implementing the core transformer architecture with multi-head attention and positional encodings, to building an efficient training pipeline complete with modern optimizations like flash attention, mixed precision training, and distributed parallelism. We’ve debugged exploding gradients, optimized memory usage, and watched our model evolve from producing random gibberish to generating coherent text. Along the way, we’ve gained deep insights into why each component exists and how they work together to create these remarkable language models.

I hope this deep dive has been as illuminating for you as it has been for me. Writing this implementation forced me to confront gaps in my own understanding and solidified concepts that previously felt abstract. There’s something uniquely satisfying about seeing your hand-built transformer successfully predict its first coherent sentence—a moment where theory truly becomes understanding.

If you’ve made it this far, thank you for joining me on this journey. I’d love to hear about your experiences implementing transformers, any bugs you’ve encountered, optimizations you’ve discovered, or questions this post might have raised. Feel free to reach out with feedback, corrections, or insights—the best part of sharing these implementations is learning from the community’s collective wisdom. Happy building!

Resources