There’s an old saying in engineering: “You don’t really understand something until you can build it.” This has never been more true than in the era of LLMs. While we’ve previously explored the foundational concepts in my earlier post explaining the Transformer architecture, true understanding comes from implementation. That’s why today, we’re building a GPT-style model (the 124M variant) from scratch in PyTorch.
This project has a different focus than my last “from scratch” endeavor, where I built an entire deep learning framework to grasp the low-level mechanics of autograd and tensor ops. Here, we’ll leverage PyTorch’s battle-tested primitives to focus on what makes GPT special: multi-head attention, positional encodings, and the specific architectural decisions that enable language understanding.
This hands-on process reveals challenges you can’t appreciate from diagrams alone. You’ll watch your GPU memory overflow, see training grind to a halt from inefficient data loading, and learn firsthand why techniques like mixed-precision training, gradient accumulation, and activation checkpointing are necessities, not just optimizations. It’s in facing these hurdles that you truly appreciate the engineering craft required to build and scale transformers efficiently.
GPTs
GPT (Generative Pre-trained Transformer) models, developed by OpenAI, represent a breakthrough in natural language processing. GPT-2, released in 2019, demonstrated that a transformer-based model trained on vast amounts of text could generate remarkably coherent and contextually relevant content. GPT-3, its successor, scaled this approach to 175 billion parameters, showcasing emergent capabilities like few-shot learning and complex reasoning. Both models share the same fundamental architecture: stacked transformer decoder blocks that predict the next token in a sequence, trained on the simple objective of minimizing prediction error across massive text corpora. The 124M parameter version we’ll be building captures the essential architecture while remaining computationally tractable for individual developers—though even at this “small” scale, you’ll quickly discover why the ML community spends so much time optimizing both training efficiency and model performance.
By the end of this journey, you won’t just know how transformers work—you’ll have built the critical components with your own hands, optimized the training loop, and watched your model evolve from random noise to coherent text generation. Let’s begin.
Implementation
Throughout this implementation, every piece of code is thoroughly annotated with explanations of not just what we’re doing, but why we’re doing it. More importantly, we’ll use a few optimizations that make a big difference in computational efficiency (a short sketch of the one-line toggles follows the list):
TensorFloat32 (TF32): NVIDIA’s precision format that keeps FP32’s 8-bit exponent but truncates the mantissa from 23 bits to 10 (19 bits total), providing up to 8x matmul speedup on A100 GPUs while maintaining model quality. We’ll see how a single line of code can dramatically accelerate matrix multiplications.
BFloat16 with Autocast: Mixed precision training using brain floating-point format, which maintains the same exponent range as FP32 but reduces mantissa precision. Combined with automatic mixed precision (AMP), this cuts memory usage in half and speeds up training significantly.
torch.compile: PyTorch 2.0’s just-in-time compilation that fuses operations and generates optimized kernels. We’ll explore how graph compilation can provide 10-30% speedups with minimal code changes.
Flash Attention and Online Softmax: An algorithmic improvement that computes attention without materializing the full attention matrix, reducing memory complexity from O(n²) to O(n).
Fused AdamW: A single-kernel implementation of the AdamW optimizer that reduces memory reads/writes by computing all parameter updates in one pass, providing up to 2x optimizer step speedup.
Annealed Learning Rate: Starting with a warmup phase followed by cosine decay, we’ll implement the learning rate schedule that has become standard for training transformers, understanding why stable training requires careful lr management.
Weight Decay Only on Matrices: A subtle but crucial detail—applying weight decay only to weight matrices in Linear and Embedding layers while excluding biases and layer normalization parameters, which improves model performance.
Distributed Data Parallelism (DDP): Scaling training across multiple GPUs using PyTorch’s DDP, including gradient synchronization, proper data loading, and the intricacies of maintaining consistent model states across devices.
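To make the list above concrete, here is a minimal sketch of how most of these optimizations are switched on in PyTorch. The model here is just a stand-in Linear layer, the sketch assumes a CUDA GPU with BF16/TF32 support (Ampere or newer), and each of these flags reappears in context in the full code below.

import torch
import torch.nn as nn

model = nn.Linear(768, 768).to("cuda")  # stand-in for the GPT model
x = torch.randn(8, 768, device="cuda")

# TF32: run FP32 matmuls on tensor cores with a 10-bit mantissa
torch.set_float32_matmul_precision("high")

# torch.compile: JIT-compile the model into fused kernels (PyTorch 2.0+)
model = torch.compile(model)

# Fused AdamW: single-kernel parameter updates (CUDA only)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

# BF16 autocast: run most ops in bfloat16, keep sensitive ops in FP32
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)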
Finally, since the GPT-2 paper omits certain architectural details and hyperparameter specifications, we’ll refer to the GPT-3 paper to fill these gaps—fortunately, the core architecture remains consistent between the two models, making the GPT-3 paper a reliable source for these missing implementation details.
Code
import inspect
import math
import os
import time
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, Iterable

import tiktoken
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
def annealer(func: Callable):
    @wraps(func)
    def annealer_wrapper(*args, **kwargs):
        return partial(func, *args, **kwargs)

    return annealer_wrapper


@annealer
def lin_sched(start, end, pos):
    """Linear scheduler."""
    return start + (end - start) * pos


@annealer
def cos_sched(start, end, pos):
    """Cosine scheduler."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def combine_scheds(pcts, scheds):
    """
    Combine multiple schedulers, each run for a given percentage of the
    training process.
    """
    assert len(pcts) == len(scheds), "Each scheduler should have its `pct`."
    assert sum(pcts) == 1.0, "Sum of the `pcts` should be equal to 1."
    pcts = torch.tensor([0] + list(pcts))
    assert (pcts >= 0).all(), "All percentages should be non-negative."
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return scheds[idx](actual_pos)

    return _inner
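As a quick sanity check on the schedulers (my own example, the actual hyperparameters come later in the training script), a combined schedule with 20% linear warmup followed by cosine decay behaves as expected:

# Hypothetical usage: 20% linear warmup from 3e-5 to 3e-4, then cosine decay back
sched = combine_scheds(
    [0.2, 0.8],
    [lin_sched(3e-5, 3e-4), cos_sched(3e-4, 3e-5)],
)
print(sched(0.0))   # ~3e-5 (start of warmup)
print(sched(0.2))   # ~3e-4 (end of warmup / start of decay)
print(sched(0.99))  # approaches 3e-5 again near the end of training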
@dataclass
class GPTConfig:
    block_sz: int = 1024  # Sequence length
    vocab_sz: int = (
        50257  # Originally 50,000 BPE merges + 256 byte tokens + 1 for the
        # <|endoftext|> token, which delimits different documents. This
        # token's index is 50256. However, 50257 is not a multiple of 64,
        # and we can improve efficiency and performance (through better
        # occupancy) if we round up to the closest multiple of 64,
        # which is 50304.
    )
    n_layer: int = 12  # Number of layers
    n_embd: int = 768  # Embedding dimension
    n_head: int = 12  # Number of attention heads
    lr: float = 3e-4  # Good for big models
    batch_sz: int = 4
    dropout: float = 0.0
    bias: bool = True
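As a sanity check on the “124M” name, here is a rough parameter count implied by this config (my own back-of-the-envelope estimate that ignores biases and layer norm parameters, and assumes the token embedding is tied to the LM head as in the model below):

n_embd, n_layer, vocab_sz, block_sz = 768, 12, 50257, 1024

tok_emb = vocab_sz * n_embd        # ~38.6M (shared with the LM head)
pos_emb = block_sz * n_embd        # ~0.8M
attn = 4 * n_embd * n_embd         # qkv projection (3x) + output projection
mlp = 2 * 4 * n_embd * n_embd      # up-projection to 4*n_embd and back down
per_block = attn + mlp             # ~7.1M per transformer block

total = tok_emb + pos_emb + n_layer * per_block
print(f"{total / 1e6:.1f}M parameters")  # ~124.3M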
class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        # Point-wise feed-forward network that applies a non-linearity to
        # every token separately. THERE IS NO INTERACTION BETWEEN TOKENS.
        # This is where almost all the capacity and non-linearity of the
        # model comes from, especially since we project up to 4 x n_embd
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)
        # GELU was found to be better than ReLU in terms of gradient saturation
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # NOTE: Mask is not needed when we use PyTorch's Flash attention
        # self.register_buffer(
        #     "mask",
        #     torch.tril(torch.ones(config.block_sz, config.block_sz)).view(
        #         config.block_sz, config.block_sz
        #     ),
        # )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        # q/k/v is B x T x n_embd each
        q, k, v = torch.split(qkv, self.n_embd, dim=-1)
        # Reshape q/k/v to B x n_head x T x (n_embd / n_head)
        # so each head would be learning different kinds of relationships
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # attn is B x T x T
        # attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
        # # Mask out future tokens
        # attn = attn.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        # attn = self.attn_dropout(F.softmax(attn, dim=-1))
        # # y is B x T x n_embd
        # y = attn @ v
        # Uses Flash attention, which never materializes the attention matrix
        # for each head and is aware of the memory hierarchy; it trades more
        # FLOPs for fewer reads/writes -> speedup since we're memory bound
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
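To convince yourself that the commented-out “naive” attention and F.scaled_dot_product_attention compute the same thing, here is a small standalone check with toy shapes (my own sketch; on CPU the call falls back to the math kernel, but the numerics match what the flash kernel computes on GPU):

# Check that the fused causal attention path matches naive masked-softmax attention
import math
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

attn = (q @ k.transpose(-2, -1)) / math.sqrt(D)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
attn = attn.masked_fill(~mask, float("-inf"))
y_naive = F.softmax(attn, dim=-1) @ v

y_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_naive, y_sdpa, atol=1e-5))  # True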
class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.attn = CausalSelfAttention(config)

    def forward(self, x):
        # Use pre-layer normalization, which deviates from the original
        # transformer paper that uses post-layer normalization.
        # This should help stabilize training
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class GPT2(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_sz, config.n_embd),
                # The attention operation is permutation equivariant: if we
                # permute the input, the corresponding output is permuted in
                # exactly the same way. In other words, the attention
                # mechanism is not aware of the relative ordering of the
                # tokens. Therefore, we need some way to encode the position
                # of each token in the sequence. This is where positional
                # encoding comes into play. Here we use a simple learned
                # embedding of the token's position in the sequence.
                wpe=nn.Embedding(config.block_sz, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                # Final layer norm after all transformer layers
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_sz, bias=False)
        # Weight sharing between the token embedding layer and the last
        # linear layer (LM head classifier). The rationale is that tokens
        # that are semantically similar to each other in the embedding space
        # should have similar probabilities in the softmax of the LM head.
        # Also, these matrices are among the biggest in the model, which
        # means that for a model like GPT-2 we save almost 30% of the
        # parameters by sharing the weight matrix: (50257 * 768) / 124M = ~31%
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # The following initialization comes from the GPT-2 source code.
        # NOTE: Because token embedding and classifier weights are shared,
        # our initialization logic will initialize the weight matrix twice,
        # but this shouldn't be an issue since both passes use the same
        # std and mean
        if isinstance(module, nn.Linear):
            std = 0.02
            # We're changing std because the residual path increases the std
            # on every layer, so we need to adjust it to keep the overall
            # std at 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                # `2` here because every layer has two blocks:
                # - Attention block
                # - MLP block
                # `n_layer` is the number of layers in the model.
                # Since they are independent, the variance of the sum of the
                # two blocks is the sum of the variances
                std *= (2 * self.config.n_layer) ** -0.5
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # We're initializing the token and positional embeddings with
            # the same std, but the paper initialized the positional
            # embedding with std = 0.01
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, x, targets=None):
        T = x.shape[-1]
        assert (
            T <= self.config.block_sz
        ), f"Sequence length must be <= {self.config.block_sz}, got {T}"
        pos_emb = self.transformer.wpe(
            torch.arange(0, T, dtype=torch.long, device=x.device)
        )
        tok_emb = self.transformer.wte(x)
        x = pos_emb + tok_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        # logits is B x T x vocab_sz
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # F.cross_entropy expects the 2nd dimension to be the classes
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_sz), targets.view(-1)
            )
        return logits, loss

    def configure_optimizer(self, weight_decay, lr, device):
        params_dict = {
            pn: p for pn, p in self.named_parameters() if p.requires_grad
        }
        # We're not applying weight decay to biases, layer norm parameters,
        # or any other 1D parameters. Therefore, we are ONLY applying weight
        # decay to the weight matrices in Embedding and Linear layers
        decay_params = [p for p in params_dict.values() if p.ndim >= 2]
        nondecay_params = [p for p in params_dict.values() if p.ndim < 2]
        params_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nondecay_params, "weight_decay": 0.0},
        ]
        # Fused AdamW is available for PyTorch 2.0+
        fused_available = "fused" in inspect.signature(opt.AdamW).parameters
        use_fused = fused_available and "cuda" in device
        return opt.AdamW(
            params_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
        )

    @torch.no_grad()
    def generate(self, idxs: torch.Tensor, max_tokens: int = 5):
        for _ in range(max_tokens):
            # At most we can feed block_sz tokens, since we're using a fixed
            # block_sz for the positional embedding
            idxs = idxs[:, -self.config.block_sz :]
            # logits is B x T x vocab_sz
            logits, _ = self(idxs)
            # Get the logits of the last position to predict the next token
            # This would be B x vocab_sz
            logits = logits[:, -1, :]
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # Pick the top 50 probabilities -> we would never pick tokens
            # with very small probs (right tail) -> B x 50
            # probs/idxs are sorted in descending order
            topk_probs, topk_idxs = torch.topk(probs, 50, dim=-1)
            # Sample 1 token from the top 50 tokens -> idx is B x 1
            idx = torch.multinomial(topk_probs, 1)
            # Map back to the vocab index, as `multinomial` only returns
            # indices into the given array
            idx = torch.gather(topk_idxs, -1, idx)
            # TODO: We should check for the end_of_text token and break out
            # of the loop (stop generation) even if we have not reached
            # max_tokens
            idxs = torch.cat([idxs, idx], dim=1)
        return idxs
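The training script below never actually calls generate, so here is a hedged example of how you might sample from the model once it has been trained. It assumes `model` is a trained (unwrapped) GPT2 instance living on `device`; the prompt and token count are arbitrary choices of mine.

# Hypothetical sampling example after training
enc = tiktoken.get_encoding("gpt2")
model.eval()
prompt = torch.tensor(
    enc.encode("Hello, I'm a language model,"), dtype=torch.long, device=device
).unsqueeze(0)  # 1 x T batch of token ids
out = model.generate(prompt, max_tokens=30)
print(enc.decode(out[0].tolist()))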
class DataLoaderLight:
    def __init__(
        self,
        file_path: str,
        batch_sz: int,
        block_sz: int,
        process_rank: int = 0,
        number_processes: int = 1,
    ) -> None:
        self.batch_sz = batch_sz
        self.block_sz = block_sz
        self.process_rank = process_rank
        self.number_processes = number_processes
        with open(file_path, "r") as f:
            text = f.read()
        encoder = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(encoder.encode(text), dtype=torch.long)
        # We truncate the tokens to a multiple of batch_sz x block_sz x
        # number_processes. This is useful for multi-node training and mimics
        # the behavior of DataLoader's `drop_last` parameter.
        self.tokens = self.tokens[
            : len(self.tokens)
            // (self.batch_sz * self.block_sz * self.number_processes)
            * self.batch_sz
            * self.block_sz
            * self.number_processes
        ]
        print(f"Loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self)} batches")
        self.current_pos = batch_sz * block_sz * process_rank

    def __len__(self):
        return len(self.tokens) // (self.batch_sz * self.block_sz)

    def next_batch(self):
        buf = self.tokens[
            self.current_pos : self.current_pos + self.batch_sz * self.block_sz + 1
        ]
        x = buf[:-1].view(self.batch_sz, self.block_sz)
        y = buf[1:].view(self.batch_sz, self.block_sz)
        # Each process consumes batch_sz x block_sz tokens per iteration, so
        # with number_processes processes the total tokens processed per
        # iteration is batch_sz x block_sz x number_processes. In the case of
        # one process, the total is simply batch_sz x block_sz.
        self.current_pos += (
            self.batch_sz * self.block_sz * self.number_processes
        )
        # Similar to DataLoader's `drop_last` parameter: wrap around to this
        # process's starting offset when the next read would run past the end
        # of the token stream (otherwise the final partial read would break
        # the view() above).
        if (
            self.current_pos
            + self.batch_sz * self.block_sz * self.number_processes
            + 1
            > len(self.tokens)
        ):
            self.current_pos = self.batch_sz * self.block_sz * self.process_rank
        return x, y
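A quick shape check on the loader (my own example; it assumes tinyshakespeare.txt is present in the working directory) shows that the targets are simply the inputs shifted left by one token:

# Quick sanity check of DataLoaderLight
dl = DataLoaderLight("tinyshakespeare.txt", batch_sz=4, block_sz=32)
x, y = dl.next_batch()
print(x.shape, y.shape)               # torch.Size([4, 32]) torch.Size([4, 32])
print((x[:, 1:] == y[:, :-1]).all())  # True: y is x shifted by one position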
##########
# Distributed Data Parallel
##########
# Distributed Data Parallel lets us run the same model (replica) on different
# GPUs, where each GPU works on a different slice of the data. After the
# backward pass, we average the gradients across all processes (GPUs) and
# synchronize the parameters across all devices. We use the allReduce op to do
# this and communicate the updates to all processes.
# Each process goes through the same code from top to bottom, unaware that
# there are other processes running the same thing on other devices.
#
# The torchrun command sets the following environment variables:
# RANK:       Id of the process in the process group. It is an int in
#             [0, WORLD_SIZE).
# LOCAL_RANK: In the multi-node case, LOCAL_RANK is the id of the process
#             within its own node. Example: if we have a node with 4 GPUs,
#             the first process on that node has LOCAL_RANK=0, but its RANK
#             may not be 0 if we are running on multiple nodes. This is
#             useful when we want to pin each process to a different GPU in
#             the same node: we set the device to the LOCAL_RANK-th GPU.
# WORLD_SIZE: Total number of processes
ddp = int(os.getenv("RANK", -1)) != -1  # Check if this is a ddp run
if ddp:
    # DDP requires CUDA, and we need to set the device for each process
    # so only one process runs per device
    assert torch.cuda.is_available(), "DDP requires CUDA"
    init_process_group(backend="nccl")
    ddp_rank = int(os.getenv("RANK"))
    ddp_local_rank = int(os.getenv("LOCAL_RANK"))
    ddp_world_size = int(os.getenv("WORLD_SIZE"))
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    # The master process does extra work such as checkpointing and logging,
    # while the other processes only assist in the computation.
    # It always has RANK=0
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_built():  # Apple Silicon
        device = "mps"
        torch.mps.manual_seed(1337)
    else:
        device = "cpu"
    print(device)

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

##########
# Initialize model and optimizer
##########
# Almost everything on a GPU (e.g., tiling ops) works in powers of 2, so try
# to make matrix dimensions nice multiples of powers of 2 to improve use of:
# • Tensor Cores
# • Memory coalescing
# • Shared memory bank alignment
# • Warp scheduling
# Here we round vocab_sz up to the closest multiple of 64 (50304). This adds
# a small space overhead but speeds up computation.
model = GPT2(GPTConfig(vocab_sz=50304)).to(device)
# Speed up the model by building a static graph that analyzes all ops and
# optimizes them, e.g., fusing some of them to avoid unnecessary trips to
# memory
# model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model

max_lr = 3e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50
sched = combine_scheds(
    [warmup_steps / max_steps, 1 - (warmup_steps / max_steps)],
    [lin_sched(min_lr, max_lr), cos_sched(max_lr, min_lr)],
)
optimizer = raw_model.configure_optimizer(
    weight_decay=0.1, lr=max_lr, device=device
)

##########
# Run training loop
##########
# NOTE: In order to process 0.5M tokens (from the GPT-3 paper) per fwd/bwd
# iteration, we need gradient accumulation because we can't fit that many on
# almost any commodity GPU -> we only call backward after we loop through
# ~0.5M tokens.
total_batch_sz = 2**19  # closest power of 2 to 0.5M
assert (
    total_batch_sz % (GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size)
    == 0
), "total batch size must be divisible by micro batch_sz x block_sz x ddp_world_size"
grad_accum_steps = total_batch_sz // (
    GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size
)
if master_process:
    print(f"Total desired batch size: {total_batch_sz}")
    print(f"Calculated gradient accumulation steps: {grad_accum_steps}")

train_dl = DataLoaderLight(
    "tinyshakespeare.txt",
    batch_sz=GPTConfig.batch_sz,
    block_sz=GPTConfig.block_sz,
    process_rank=ddp_rank,
    number_processes=ddp_world_size,
)

# PyTorch will use TensorFloat32 for matmuls if available, else FP32.
# The weights are still stored in FP32; only the operations are executed in
# TF32 (10 bits of mantissa instead of 23) when the hardware supports it.
torch.set_float32_matmul_precision("high")

for step in range(max_steps):
    start = time.time()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        # Fetch a fresh micro-batch for every accumulation step so that each
        # optimizer step really sees ~0.5M distinct tokens
        x, y = train_dl.next_batch()
        x = x.to(device)
        y = y.to(device)
        if device.startswith("cuda"):
            # Tensors that are very sensitive to reduced precision, such as
            # the loss and layer norms, stay in FP32, while others such as
            # the attention weights run in BF16
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, loss = model(x, y)
        else:
            logits, loss = model(x, y)
        # Accumulating gradients alone sums the objective, but we want the
        # mean, so we weight each loss by 1/grad_accum_steps
        loss /= grad_accum_steps
        loss_accum += loss.detach()
        # To avoid syncing gradients between processes after every micro
        # step, we disable the sync and only allow it after we finish all
        # gradient accumulation steps in each process
        if ddp:
            model.require_backward_grad_sync = (
                micro_step == grad_accum_steps - 1
            )
        loss.backward()
    # Each process has its own loss_accum tensor, so to get the average loss
    # across all processes, we average loss_accum over all of them
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    # Clip gradients to a global norm. This is very useful to avoid a few
    # (bad) batches with very high loss producing very high gradients and
    # huge updates.
    # NOTE: At the beginning of training it is normal to see high norms,
    # since the model is initialized randomly
    norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # TODO: Use ParamScheduler from `cmn_ai`
    lr = sched(step / max_steps)
    for pg in optimizer.param_groups:
        pg["lr"] = lr
    optimizer.step()
    end = time.time()
    elapsed_time = end - start
    token_per_sec = (
        GPTConfig.batch_sz
        * GPTConfig.block_sz
        * grad_accum_steps
        * ddp_world_size
    ) / elapsed_time
    if master_process:
        print(
            f"step {step}, loss: {loss_accum.item():.6f}, lr {lr:.4e}, "
            f"norm: {norm:.2f}, time: {elapsed_time:.2f}s, "
            f"tok/sec: {token_per_sec:.2f}"
        )

if ddp:
    # Cleanly shut down the process group
    destroy_process_group()
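Two quick notes on running this script. Under DDP it is meant to be launched with torchrun, e.g. `torchrun --standalone --nproc_per_node=8 train_gpt2.py` for a single node with 8 GPUs (the file name here is my placeholder); for a single GPU, a plain `python` invocation is enough. Also, a sanity check I find useful (my own, not part of the original code): at initialization the model is roughly uniform over the vocabulary, so the first reported losses should hover around ln(50257) ≈ 10.8, and a value far from that usually indicates a wiring bug.

# Expected loss at initialization: roughly uniform probability over the vocabulary
import math
print(math.log(50257))  # ≈ 10.82; the first few training losses should be near this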
Conclusion
We’ve come a long way in this journey—from implementing the core transformer architecture with multi-head attention and positional encodings, to building an efficient training pipeline complete with modern optimizations like flash attention, mixed precision training, and distributed parallelism. We’ve debugged exploding gradients, optimized memory usage, and watched our model evolve from producing random gibberish to generating coherent text. Along the way, we’ve gained deep insights into why each component exists and how they work together to create these remarkable language models.
I hope this deep dive has been as illuminating for you as it has been for me. Writing this implementation forced me to confront gaps in my own understanding and solidified concepts that previously felt abstract. There’s something uniquely satisfying about seeing your hand-built transformer successfully predict its first coherent sentence—a moment where theory truly becomes understanding.
If you’ve made it this far, thank you for joining me on this journey. I’d love to hear about your experiences implementing transformers, any bugs you’ve encountered, optimizations you’ve discovered, or questions this post might have raised. Feel free to reach out with feedback, corrections, or insights—the best part of sharing these implementations is learning from the community’s collective wisdom. Happy building!