In the first iteration of training, we want all tokens to have roughly the same probability, so the loss on each one should be about the same. Because the probability distribution is diffuse, each token's probability is about \(1/vocab\_sz\), so \(loss \approx -\log(1/vocab\_sz) = \log(vocab\_sz)\). For GPT-2, this is \(\log(50257) \approx 10.8\).
A learning rate of 3e-4 is a good default for the AdamW optimizer when debugging.
As a sanity check, the model should be able to overfit a single batch; this confirms that it is running correctly.
We share weights between the token embedding and the final linear layer (also called the classifier or LM head) because we want semantically similar tokens to have similar probabilities when predicting the next token.
This also greatly improves parameter efficiency, because these matrices are very large. For GPT-2, each one has \(50257 \times 768 \approx 38.6M\) parameters, which is about 31% (roughly \(1/3\)) of the 124M-parameter GPT-2 model.
As a result of weight sharing, the gradient of the shared matrix is the sum of the gradients from the two branches: the classifier and the token embedding.
For tokens that don’t appear in the training data, we want the model to learn to assign them probabilities very close to zero.
The CPU can keep running even if CUDA kernels are not done, because the CPU only schedules kernels on the GPU and doesn’t wait for them to finish. To get accurate timings, call torch.cuda.synchronize(), which makes the CPU wait until all scheduled kernels have finished executing before it resumes.
Implementation
Code
import bisect
import inspect
import itertools
import math
import os
import sys
import time
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable

import tiktoken
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
Code
def annealer(func: Callable):
    """Turn a schedule ``func(start, end, pos)`` into a schedule factory.

    Calling the decorated function with its leading arguments (typically
    ``start`` and ``end``) returns a one-argument callable ``f(pos)``.
    """

    # `@wraps` (not a bare `wraps(func)` call, whose result would be
    # discarded) so the factory keeps the wrapped function's name/docstring.
    @wraps(func)
    def annealer_wrapper(*args, **kwargs):
        return partial(func, *args, **kwargs)

    return annealer_wrapper


@annealer
def lin_sched(start, end, pos):
    """Linear scheduler: ``start`` at ``pos=0`` to ``end`` at ``pos=1``."""
    return start + (end - start) * pos


@annealer
def cos_sched(start, end, pos):
    """Cosine scheduler: ``start`` at ``pos=0`` to ``end`` at ``pos=1``."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def combine_scheds(pcts, scheds):
    """Combine multiple schedulers, each run for a given fraction of training.

    Args:
        pcts: Non-negative fractions of the training run, summing to 1, one
            per scheduler and in the order the schedulers should run.
        scheds: One-argument schedulers (e.g. ``lin_sched(a, b)``). Each is
            called with its own local position in ``[0, 1]``.

    Returns:
        A callable ``f(pos)`` mapping the global training position ``pos``
        (in ``[0, 1]``) to the value of the scheduler active at ``pos``.
    """
    pcts = list(pcts)
    assert len(pcts) == len(scheds), "Each scheduler should have its `pct`."
    # Tolerant comparison: e.g. sum([0.1, 0.2, 0.7]) is not exactly 1.0.
    assert math.isclose(sum(pcts), 1.0), "Sum of the `pcts` should be equal to 1."
    assert all(p >= 0 for p in pcts), "All percentages should be non-negative."
    # Phase i covers [boundaries[i], boundaries[i + 1]); plain Python floats
    # (float64) avoid float32 rounding of the boundaries.
    boundaries = list(itertools.accumulate(pcts, initial=0.0))
    last = len(scheds) - 1

    def _inner(pos):
        # Last phase whose start is <= pos (matches "max index with
        # pos >= boundary"); clamped so pos == 1.0 stays in the final phase
        # instead of indexing past the end of `boundaries`.
        idx = min(max(bisect.bisect_right(boundaries, pos) - 1, 0), last)
        start, end = boundaries[idx], boundaries[idx + 1]
        width = end - start
        # A zero-width final phase can only be reached through the clamp;
        # treat it as finished rather than dividing by zero.
        actual_pos = (pos - start) / width if width > 0 else 1.0
        return scheds[idx](actual_pos)

    return _inner
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT-2 model and the training script."""

    block_sz: int = 1024  # maximum context length (number of positions)
    vocab_sz: int = (
        50257
        # 50000 BPE merges + 256 byte tokens + 1 for the <|endoftext|> token,
        # which delimits different documents. This token's index is 50256.
    )
    n_layer: int = 12
    n_embd: int = 768
    n_head: int = 12
    lr: float = 3e-4
    batch_sz: int = 4


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU(tanh) -> Linear."""

    def __init__(self, config: GPTConfig):
        # Point-wise feed-forward network that applies a non-linearity to
        # every token separately. THERE IS NO INTERACTION BETWEEN TOKENS.
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)
        # Flag read by GPT2._init_weights to scale down this layer's init std
        # (it projects back into the residual stream).
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention via F.scaled_dot_product_attention."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # One projection producing q, k and v for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # NOTE: No causal-mask buffer (tril of ones, block_sz x block_sz) is
        # needed because PyTorch's attention kernel applies the mask itself
        # (`is_causal=True` in forward).

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        # q/k/v are B x T x n_embd each
        q, k, v = torch.split(qkv, self.n_embd, dim=-1)
        # Reshape q/k/v to B x n_head x T x (n_embd / n_head) so each head
        # can learn a different kind of relationship between tokens.
        head_sz = C // self.n_head
        q = q.view(B, T, self.n_head, head_sz).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_sz).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_sz).transpose(1, 2)
        # Manual equivalent, which materializes a T x T matrix per head:
        #   attn = (q @ k.transpose(-2, -1)) / sqrt(head_sz)
        #   attn = attn.masked_fill(<future positions>, -inf).softmax(dim=-1)
        #   y = attn @ v
        # F.scaled_dot_product_attention can dispatch to Flash attention,
        # which never materializes the attention matrices and is aware of
        # the memory hierarchy: it trades extra FLOPs for fewer memory
        # reads/writes -> speed-up, since attention is memory bound.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Put the heads back side by side: B x T x n_embd
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each residual."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.attn = CausalSelfAttention(config)

    def forward(self, x):
        # Pre-layer normalization deviates from the original transformer
        # paper (post-layer normalization); it helps stabilize training.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT2(nn.Module):
    """GPT-2 language model: token + position embeddings, blocks, LM head."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_sz, config.n_embd),
                wpe=nn.Embedding(config.block_sz, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                # Final layer norm after all transformer layers
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_sz, bias=False)
        # Weight sharing between the token embedding layer and the last
        # linear layer (LM head classifier). The rationale is that tokens
        # that are semantically similar in the embedding space should have
        # similar probabilities in the softmax of the LM head. These are also
        # among the biggest matrices in the model: for GPT-2, sharing saves
        # (50257 * 768) / 124M = ~31% of the parameters.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights as in the GPT-2 source: N(0, 0.02), zero biases."""
        # NOTE: Because the token embedding and classifier weights are
        # shared, this initializes that matrix twice, which is harmless since
        # both passes use the same mean and std.
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                # Each residual addition increases the std of the residual
                # stream, so scale down layers that write into it to keep
                # std ~0.02. `2` because every layer adds two branches:
                # attention and MLP.
                std *= (2 * self.config.n_layer) ** -0.5
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Token and positional embeddings both use std = 0.02, whereas
            # the paper initialized the positional embedding with std = 0.01.
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, x, targets=None):
        """Return ``(logits, loss)``; ``loss`` is None when no targets given.

        ``x`` holds token ids of shape B x T (T <= block_sz); ``logits`` is
        B x T x vocab_sz.
        """
        T = x.shape[-1]
        assert (
            T <= self.config.block_sz
        ), f"Sequence length must be <= {self.config.block_sz}, got {T}"
        pos_emb = self.transformer.wpe(
            torch.arange(0, T, dtype=torch.long, device=x.device)
        )
        tok_emb = self.transformer.wte(x)
        x = pos_emb + tok_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        # logits is B x T x vocab_sz
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # F.cross_entropy expects (N, C) logits (unnormalized) and (N,)
            # class indices, so flatten the batch and time dimensions.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_sz), targets.view(-1)
            )
        return logits, loss

    def configure_optimizer(self, weight_decay, lr, device):
        """Build AdamW, weight-decaying only parameters with ndim >= 2.

        Biases and LayerNorm scales/shifts (1-D tensors) get no decay. The
        fused implementation is used when available and ``device`` is CUDA.
        """
        params_dict = {
            pn: p for pn, p in self.named_parameters() if p.requires_grad
        }
        decay_params = [p for p in params_dict.values() if p.ndim >= 2]
        nondecay_params = [p for p in params_dict.values() if p.ndim < 2]
        params_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nondecay_params, "weight_decay": 0.0},
        ]
        fused_available = "fused" in inspect.signature(opt.AdamW).parameters
        use_fused = fused_available and "cuda" in str(device)
        # Only pass `fused` when this PyTorch version accepts it.
        extra_args = {"fused": use_fused} if fused_available else {}
        return opt.AdamW(
            params_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, **extra_args
        )

    @torch.no_grad()
    def generate(self, idxs: torch.Tensor, max_tokens: int = 5, top_k: int = 50):
        """Append ``max_tokens`` sampled tokens to ``idxs`` (B x T) via top-k.

        Returns the full B x (T + max_tokens) sequence, including the prompt.
        """
        for _ in range(max_tokens):
            # Condition on at most the last block_sz tokens, but keep the
            # full sequence in `idxs` so the prompt is not dropped.
            idx_cond = idxs[:, -self.config.block_sz :]
            logits, _ = self(idx_cond)
            # Logits for the last position predict the next token: B x vocab_sz
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Each is B x k; k is clamped so small vocabularies still work.
            k = min(top_k, probs.size(-1))
            topk_probs, topk_idxs = torch.topk(probs, k, dim=-1)
            # Position within the top-k (B x 1); multinomial accepts
            # unnormalized non-negative weights.
            idx = torch.multinomial(topk_probs, 1)
            # Map back to vocabulary ids
            idx = torch.gather(topk_idxs, -1, idx)
            idxs = torch.cat([idxs, idx], dim=1)
        return idxs


class DataLoaderLight:
    """Sequential loader over the GPT-2-tokenized contents of ``input.txt``.

    Each process (rank) reads its own batch_sz x block_sz chunk of every
    round of ``number_processes`` consecutive chunks, so ranks see disjoint
    data; all ranks wrap back to the start together.
    """

    def __init__(
        self,
        batch_sz: int,
        block_sz: int,
        process_rank: int = 0,
        number_processes: int = 1,
    ) -> None:
        self.batch_sz = batch_sz
        self.block_sz = block_sz
        self.process_rank = process_rank
        self.number_processes = number_processes
        with open("input.txt", "r", encoding="utf-8") as f:
            text = f.read()
        encoder = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(encoder.encode(text), dtype=torch.long)
        # Each rank starts at its own chunk offset.
        self.current_pos = batch_sz * block_sz * process_rank

    def __len__(self):
        # Number of batch_sz x block_sz chunks in the token stream.
        return len(self.tokens) // (self.batch_sz * self.block_sz)

    def next_batch(self):
        """Return ``(x, y)``, each batch_sz x block_sz; ``y`` is ``x`` shifted by 1."""
        B, T = self.batch_sz, self.block_sz
        # One extra token so the targets can be the inputs shifted left.
        buf = self.tokens[self.current_pos : self.current_pos + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)
        # Each process handles B x T tokens per iteration, so with
        # number_processes processes a full round covers
        # B x T x number_processes tokens.
        stride = B * T * self.number_processes
        self.current_pos += stride
        # Wrap when the next full round (plus the extra target token) would
        # run past the end of the TOKENS (len(self) counts chunks, not
        # tokens). Checking the round start keeps all ranks wrapping on the
        # same call, and resetting to this rank's offset keeps them disjoint.
        round_start = self.current_pos - B * T * self.process_rank
        if round_start + stride + 1 > len(self.tokens):
            self.current_pos = B * T * self.process_rank
        return x, y


if __name__ == "__main__":
    ############
    # Distributed Data Parallel
    ############
    # The torchrun command sets the following environment variables:
    # RANK: id of the process in the process group, an int in [0, WORLD_SIZE)
    # LOCAL_RANK: in the multi-node case, the id of the process within its node
    # WORLD_SIZE: total number of processes
    ddp = int(os.getenv("RANK", -1)) != -1  # Check if it is a ddp run
    if ddp:
        # DDP requires CUDA; each process is pinned to its own device so
        # only one process runs per device.
        assert torch.cuda.is_available(), "DDP requires CUDA"
        init_process_group(backend="nccl")
        ddp_rank = int(os.getenv("RANK"))
        ddp_local_rank = int(os.getenv("LOCAL_RANK"))
        ddp_world_size = int(os.getenv("WORLD_SIZE"))
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        # The master process does extra work such as checkpointing and
        # logging, while the other processes assist with the computation.
        master_process = ddp_rank == 0
    else:
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            # is_available (not is_built): MPS must be usable at runtime.
            device = "mps"
            torch.mps.manual_seed(1337)
        else:
            device = "cpu"
    if master_process:
        print(device)
    # autocast/synchronize need the device *type*, not "cuda:<local_rank>".
    device_type = "cuda" if device.startswith("cuda") else device

    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    ###########
    # Initialize model and optimizer
    ###########
    # GPU kernels work on power-of-2 sized tiles, so matrix dimensions that
    # are divisible by large powers of 2 run faster. Round vocab_sz up to
    # 50304 = 128 * 393: a few unused rows cost a bit of memory but speed up
    # the computation (the model learns to give those tokens ~0 probability).
    model = GPT2(GPTConfig(vocab_sz=50304)).to(device)
    # torch.compile builds a static graph of all ops and optimizes it, e.g.
    # by fusing ops to avoid unnecessary round trips to memory.
    # model = torch.compile(model)
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    # Unwrapped model, for methods DDP does not expose (configure_optimizer).
    raw_model = model.module if ddp else model

    max_lr = 3e-4
    min_lr = max_lr * 0.1
    warmup_steps = 10
    max_steps = 50
    # Linear warmup from min_lr to max_lr, then cosine decay back to min_lr.
    sched = combine_scheds(
        [warmup_steps / max_steps, 1 - (warmup_steps / max_steps)],
        [lin_sched(min_lr, max_lr), cos_sched(max_lr, min_lr)],
    )
    # Biases and other 1-D tensors (e.g. layer norm scales/shifts) are not
    # weight-decayed; see GPT2.configure_optimizer.
    optimizer = raw_model.configure_optimizer(
        weight_decay=0.1, lr=max_lr, device=device
    )

    ###########
    # Run training loop
    ###########
    # NOTE: To process ~0.5M tokens per optimizer step we need gradient
    # accumulation, because such a batch doesn't fit in almost any commodity
    # GPU -> we only step the optimizer after going through ~0.5M tokens.
    total_batch_sz = 2**19  # closest power of 2 to 0.5M
    tokens_per_micro_step = GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size
    # Condition and message must not be wrapped together in parentheses: a
    # non-empty tuple is always truthy, so such an assert can never fire.
    assert total_batch_sz % tokens_per_micro_step == 0, (
        "total batch size must be divisible by micro batch_sz x block_sz x ddp_world_size"
    )
    grad_accum_steps = total_batch_sz // tokens_per_micro_step
    if master_process:
        print(f"Total desired batch size: {total_batch_sz}")
        print(f"Calculated gradient accumulation steps: {grad_accum_steps}")

    # Each rank must read its own slice of the data.
    train_dl = DataLoaderLight(
        batch_sz=GPTConfig.batch_sz,
        block_sz=GPTConfig.block_sz,
        process_rank=ddp_rank,
        number_processes=ddp_world_size,
    )

    # PyTorch will use TensorFloat32 for matmuls if available, else FP32.
    # Weights are still stored as FP32; only the matmuls run in TF32.
    torch.set_float32_matmul_precision("high")

    for step in range(max_steps):
        start = time.time()
        optimizer.zero_grad()
        loss_accum = 0.0
        for micro_step in range(grad_accum_steps):
            # A fresh micro-batch every micro step; reusing one batch would
            # only rescale the same gradient.
            x, y = train_dl.next_batch()
            x, y = x.to(device), y.to(device)
            if device_type == "cuda":
                # Ops that are sensitive to reduced precision (such as the
                # loss and layernorm) are kept in FP32 by autocast.
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    _, loss = model(x, y)
            else:
                _, loss = model(x, y)
            # Accumulating gradients sums the micro-batch objectives, but we
            # want their mean, so weight each loss by 1/grad_accum_steps.
            loss = loss / grad_accum_steps
            loss_accum += loss.detach()
            if ddp:
                # Only all-reduce gradients across processes on the last
                # micro step, instead of after every backward pass.
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            loss.backward()
        if ddp:
            # Each process has its own loss_accum; average across processes.
            dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
        # Clip the gradient's global norm so an unlucky batch with a very
        # high loss (and huge gradients) can't cause a huge update. High
        # norms early in training are normal, as the model starts random.
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # TODO: Use ParamScheduler from `cmn_ai`
        # float() so the param groups always hold a plain Python number.
        lr = float(sched(step / max_steps))
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        optimizer.step()
        # Kernels run asynchronously; wait for them before reading the clock.
        if device_type == "cuda":
            torch.cuda.synchronize()
        elapsed_time = time.time() - start
        token_per_sec = (
            GPTConfig.batch_sz * GPTConfig.block_sz * grad_accum_steps * ddp_world_size
        ) / elapsed_time
        if master_process:
            print(
                f"step {step}, loss: {loss_accum.item()}, lr {lr:.4e}, norm: {norm:.2f}, time: {elapsed_time:.2f}s, tok/sec: {token_per_sec:.2f}"
            )
    if ddp:
        # Kills all processes
        destroy_process_group()