Coding GPT2/3 (124M) From Scratch

NLP
Deep Learning
Author

Imad Dabbura

Published

April 10, 2024

Introduction

TODO

  • At initialization (the first training iteration), we want all tokens to have roughly the same probability, so the loss on each one would be about the same. The probability of each token would be \(1/vocab\_sz\) -> \(loss \approx -\log(1/vocab\_sz)\) because the predicted distribution is diffuse. For GPT2's vocabulary of 50,257 tokens, that is roughly 10.82 (see the sanity-check sketch after this list).
  • A learning rate of 3e-4 is a good default for the AdamW optimizer when debugging.
  • We want the model to overfit a single batch first to make sure the training loop is implemented correctly (also covered in the sketch after this list).
  • We share weights between the token embedding and the final linear layer (also called the classifier or LM head) because we want tokens that are semantically similar to have similar probabilities when predicting the next token.
    • This also has a huge advantage in efficiency because those matrices have a lot of parameters. For GPT2, each one has \(50257 * 768 \approx 38.5M\) parameters, which is roughly \(1/3\) of the 124M-parameter model.
    • As a result of weight sharing, the gradient of the shared matrix is the sum of the contributions from the two branches: the classifier and the token embedding (see the small sketch after this list).
  • For tokens that never appear in the training data, we want the predicted probabilities to be very close to zero.
  • The CPU can keep running even if CUDA kernels are not done, because the CPU merely schedules the kernels on the GPU and doesn't wait for them to finish -> use torch.cuda.synchronize() so the CPU only resumes once the scheduled kernels finish executing, which gives accurate timings (see the timing sketch after this list).
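
As a quick sanity check of the first points (and the overfit-one-batch point), here is a minimal sketch; it assumes the GPT2, GPTConfig, and DataLoaderLight classes from the Implementation section below, plus the input.txt file that DataLoaderLight reads. The initial loss should come out near \(-\log(1/50257) \approx 10.82\), and a few hundred optimization steps on the same batch should drive the loss well below that.

import math

import torch

config = GPTConfig()
model = GPT2(config)
train_dl = DataLoaderLight(batch_sz=config.batch_sz, block_sz=config.block_sz)
x, y = train_dl.next_batch()

# 1) The initial loss should be roughly -log(1/vocab_sz)
_, loss = model(x, y)
expected = -math.log(1 / config.vocab_sz)
print(f"initial loss: {loss.item():.2f}, expected: {expected:.2f}")

# 2) Overfit a single batch: the loss should drop sharply
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for _ in range(200):
    optimizer.zero_grad()
    _, loss = model(x, y)
    loss.backward()
    optimizer.step()
print(f"loss after overfitting one batch: {loss.item():.4f}")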
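
To see the weight-sharing point concretely, here is a tiny standalone sketch with made-up sizes: once two modules point to the same weight tensor, a single backward pass accumulates gradients from both branches into that one tensor.

import torch
import torch.nn as nn

vocab_sz, n_embd = 8, 4
wte = nn.Embedding(vocab_sz, n_embd)
lm_head = nn.Linear(n_embd, vocab_sz, bias=False)
# Tie the weights: both modules now share the same (vocab_sz, n_embd) tensor
wte.weight = lm_head.weight

idxs = torch.tensor([[1, 2, 3]])
logits = lm_head(wte(idxs))  # both branches use the shared matrix
logits.sum().backward()
# The gradient on the shared tensor is the sum of the gradients coming from
# the embedding branch and from the classifier branch
print(lm_head.weight.grad.shape)  # torch.Size([8, 4])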
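
Finally, a small sketch of the timing point, assuming a machine with a CUDA GPU: without torch.cuda.synchronize() the timer mostly measures how long the CPU takes to enqueue the kernel, not how long the GPU takes to execute it.

import time

import torch

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
_ = a @ b  # warm up (cuBLAS init, memory allocation)
torch.cuda.synchronize()

start = time.time()
c = a @ b
print(f"without sync: {time.time() - start:.6f}s")  # ~kernel launch time only

torch.cuda.synchronize()
start = time.time()
c = a @ b
torch.cuda.synchronize()  # wait for the scheduled kernel to actually finish
print(f"with sync:    {time.time() - start:.6f}s")  # actual GPU execution time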

Implementation

import inspect
import math
import os
import sys
import time
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable

import tiktoken
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP


def annealer(func: Callable):
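    """Turn `func(start, end, pos)` into a factory: calling the decorated
    function with `start` and `end` returns a scheduler of `pos` only."""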
    @wraps(func)

    def annealer_wrapper(*args, **kwargs):
        return partial(func, *args, **kwargs)

    return annealer_wrapper


@annealer
def lin_sched(start, end, pos):
    """Linear scheduler."""
    return start + (end - start) * pos


@annealer
def cos_sched(start, end, pos):
    """Cosine scheduler."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def combine_scheds(pcts, scheds):
    """
    Combine multiple schedulers, each run for a given percentage of the
    training process.
    """
    assert len(pcts) == len(scheds), "Each scheduler should have its `pct`."
    assert sum(pcts) == 1.0, "Sum of the `pcts` should be equal to 1."
    pcts = torch.tensor([0] + list(pcts))
    assert (pcts >= 0).all(), "All percentages should be non-negative."
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return scheds[idx](actual_pos)

    return _inner


@dataclass
class GPTConfig:
    block_sz: int = 1024
    vocab_sz: int = (
        50257  # 50000 BPE merges + 256 byte tokens + 1 for <|endoftext|> token
        # which delimits different documents. This token's index is 50256
    )
    n_layer: int = 12
    n_embd: int = 768
    n_head: int = 12
    lr: float = 3e-4
    batch_sz: int = 4


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        # Point-wise feed-forward network that applies a non-linearity
        # to every token separately. THERE IS NO INTERACTION BETWEEN TOKENS
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # NOTE: The causal-mask buffer (named `bias` in the original GPT2
        # code) is not needed when we use PyTorch's Flash attention
        # self.register_buffer(
        #     "bias",
        #     torch.tril(torch.ones(config.block_sz, config.block_sz)).view(
        #         1, 1, config.block_sz, config.block_sz
        #     ),
        # )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        # q/k/v is B x T x n_embd each
        q, k, v = torch.split(qkv, self.n_embd, dim=-1)
        # Reshape q/k/v to B x n_head x T x (n_embd / n_head)
        # So each head would be learning different kind of
        # relationships
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # attn is B x n_head x T x T
        # attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
        # # Mask out future tokens
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        # attn = F.softmax(attn, dim=-1)
        # # y is B x n_head x T x (n_embd / n_head)
        # y = attn @ v
        # Uses Flash attention that never materialize attention matrices for
        # each head and is aware of the memory hierarchy and tries to reduce
        # read/writes with more FLOPs -> Speed up since we're memory bound
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.attn = CausalSelfAttention(config)

    def forward(self, x):
        # Use pre-layer normalization, which deviates from the original
        # Transformer paper that uses post-layer normalization.
        # This helps stabilize training
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT2(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_sz, config.n_embd),
                wpe=nn.Embedding(config.block_sz, config.n_embd),
                h=nn.ModuleList(
                    [Block(config) for _ in range(config.n_layer)]
                ),
                # Final layer norm after all transformer layers
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_sz, bias=False)

        # Weight sharing between the token embedding layer and the
        # last linear layer (LM head classifier). The rationale is
        # that tokens that are semantically similar to each other in
        # the embedding space should have similar probabilities in the
        # softmax of the LM head layer.
        # Also, these matrices are among the biggest in the model:
        # for a model like GPT2, we save almost 30% of the parameters
        # by sharing them, since (50257 * 768) / 124M ~= 31%
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # The following initialization comes from gpt2 src code
        # NOTE: Because token embedding and classifier weights are shared,
        # our initialization logic will initialize that weight matrix twice,
        # but this shouldn't be an issue since both passes use the same
        # mean and std
        if isinstance(module, nn.Linear):
            std = 0.02
            # We scale down std for residual projections because the
            # residual stream's std grows with every layer, and we want
            # activations to keep a std of roughly 0.02
            if hasattr(module, "NANOGPT_SCALE_INIT"):
                # `2` here because every layer adds to the residual stream
                # twice: once in the attention block and once in the MLP block
                std *= (2 * self.config.n_layer) ** -0.5
            nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # We're initializing the token and positional embeddings
            # with the same std but the paper initialized the positional
            # embedding with std = 0.01
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, x, targets=None):
        T = x.shape[-1]
        assert (
            T <= self.config.block_sz
        ), f"Sequence length must be <= {self.config.block_sz}, got {T}"
        pos_emb = self.transformer.wpe(
            torch.arange(0, T, dtype=torch.long, device=x.device)
        )
        tok_emb = self.transformer.wte(x)
        x = pos_emb + tok_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        # logits is B x T x vocab_sz
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # F.cross_entropy expects class logits along the 2nd dimension,
            # so we flatten B x T into a single dimension
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_sz), targets.view(-1)
            )
        return logits, loss

    def configure_optimizer(self, weight_decay, lr, device):
        params_dict = {
            pn: p for pn, p in self.named_parameters() if p.requires_grad
        }
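        # Apply weight decay only to matrices (linear and embedding weights);
        # 1D tensors such as biases and layernorm scales are not decayed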
        decay_params = [p for p in params_dict.values() if p.ndim >= 2]
        nondecay_params = [p for p in params_dict.values() if p.ndim < 2]
        params_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nondecay_params, "weight_decay": 0.0},
        ]
        fused_available = "fused" in inspect.signature(opt.AdamW).parameters
        use_fused = fused_available and "cuda" in device
        return opt.AdamW(
            params_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
        )

    @torch.no_grad()
    def generate(self, idxs: torch.Tensor, max_tokens: int = 5):
        for _ in range(max_tokens):
            # Crop the context to the last block_sz tokens
            idxs = idxs[:, -self.config.block_sz :]
            # logits is B x T x vocab_sz
            logits, _ = self(idxs)
            # Get probs for last token to predict next token
            # This would be B x vocab_sz
            logits = logits[:, -1, :]
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # Each would be B x 50
            topk_probs, topk_idxs = torch.topk(probs, 50, dim=-1)
            # idx is B x 1
            idx = torch.multinomial(topk_probs, 1)
            idx = torch.gather(topk_idxs, -1, idx)
            idxs = torch.cat([idxs, idx], dim=1)
        return idxs


class DataLoaderLight:
    def __init__(
        self,
        batch_sz: int,
        block_sz: int,
        process_rank: int = 0,
        number_processes: int = 1,
    ) -> None:
        self.batch_sz = batch_sz
        self.block_sz = block_sz
        self.process_rank = process_rank
        self.number_processes = number_processes
        with open("input.txt", "r") as f:
            text = f.read()
        encoder = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(encoder.encode(text), dtype=torch.long)
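        # Each process starts at a different offset so that processes read
        # disjoint chunks of the token stream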
        self.current_pos = batch_sz * block_sz * process_rank

    def __len__(self):
        return len(self.tokens) // (self.batch_sz * self.block_sz)

    def next_batch(self):
        buf = self.tokens[
            self.current_pos : self.current_pos
            + self.batch_sz * self.block_sz
            + 1
        ]
        x = buf[:-1].view(self.batch_sz, self.block_sz)
        y = buf[1:].view(self.batch_sz, self.block_sz)
        # Each process will process batch_sz x block_sz tokens in each
        # iteration -> with number_processes processes, total tokens processed
        # in each iteration is batch_sz x block_sz x number_processes. In the
        # case of one process, total tokens would be batch_sz x block_sz
        self.current_pos += (
            self.batch_sz * self.block_sz * self.number_processes
        )
        if self.current_pos + (
            self.batch_sz * self.block_sz * self.number_processes + 1
        ) > len(self.tokens):
            # Reset to this process's starting offset for the next epoch
            self.current_pos = (
                self.batch_sz * self.block_sz * self.process_rank
            )
        return x, y


if __name__ == "__main__":
    ###########
    # Distributed Data Parallel
    ###########
    # torchrun command sets the following environment variables:
    # RANK: Id of the process in the process group. It is an int in
    #       [0, WORLD_SIZE - 1]
    # LOCAL_RANK: In the multi-node case, the id of the process within
    #             its own node
    # WORLD_SIZE: Total number of processes
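    # Example launch on a single node with 8 GPUs (the script name below is
    # whatever this file is saved as):
    #   torchrun --standalone --nproc_per_node=8 train_gpt2.py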
    ddp = int(os.getenv("RANK", -1)) != -1  # Check if it is a ddp run
    if ddp:
        # DDP requires CUDA so we need to set the device for each process
        # so only one process can run per device
        assert torch.cuda.is_available(), "DDP requires CUDA"
        init_process_group(backend="nccl")
        ddp_rank = int(os.getenv("RANK"))
        ddp_local_rank = int(os.getenv("LOCAL_RANK"))
        ddp_world_size = int(os.getenv("WORLD_SIZE"))
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        # master process will do more things such as checkpointing and logging
        # while other processes would assist in the computations
        master_process = ddp_rank == 0
    else:
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
            torch.mps.manual_seed(1337)
        else:
            device = "cpu"
    print(device)
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
    ##########
    # Initialize model and optimizer
    ##########
    # GPU kernels (e.g., tiling ops) work best with dimensions that have
    # many powers of 2 as factors. Here we round vocab_sz up from 50257 to
    # 50304, the nearest multiple of 128 (50304 = 128 * 393). This adds a
    # small space overhead but speeds up computations
    model = GPT2(GPTConfig(vocab_sz=50304)).to(device)
    # Speed up the model by building a static graph that analyzes all ops
    # and optimizes them, e.g., by fusing kernels to avoid unnecessary
    # round trips to memory
    # model = torch.compile(model)
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module if ddp else model
    max_lr = 3e-4
    min_lr = max_lr * 0.1
    warmup_steps = 10
    max_steps = 50
    sched = combine_scheds(
        [warmup_steps / max_steps, 1 - (warmup_steps / max_steps)],
        [lin_sched(min_lr, max_lr), cos_sched(max_lr, min_lr)],
    )
    # Don't decay biases and 1D tensors such as layernorm scales and biases;
    # only the weight matrices of linear layers and embeddings get weight
    # decay (handled inside `configure_optimizer`)
    # optimizer = opt.AdamW(
    #     model.parameters(), lr=GPTConfig.lr, betas=(0.9, 0.95), eps=1e-8
    # )
    optimizer = raw_model.configure_optimizer(
        weight_decay=0.1, lr=max_lr, device=device
    )

    ##########
    # Run training loop
    #########
    # NOTE: In order to process ~0.5M tokens per optimization step, we need
    # gradient accumulation because a batch that large doesn't fit on almost
    # any commodity GPU -> we only update the weights after looping through
    # ~0.5M tokens
    total_batch_sz = 2**19  # closest number to 0.5M
    assert (
        total_batch_sz
        % (GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size)
        == 0
    ), (
        "total batch size must be divisible by "
        "micro batch_sz x block_sz x ddp_world_size"
    )
    grad_accum_steps = total_batch_sz // (
        GPTConfig.batch_sz * GPTConfig.block_sz * ddp_world_size
    )
    if master_process:
        print(f"Total desired batch size: {total_batch_sz}")
        print(f"Calculated gradient accumulation steps: {grad_accum_steps}")

    train_dl = DataLoaderLight(
        batch_sz=GPTConfig.batch_sz, block_sz=GPTConfig.block_sz
    )
    # Pytorch will use TensorFloat32 for matmuls if available, else FP32.
    # The weights are still stored in FP32; only the matmul operations
    # execute in TF32 when available
    torch.set_float32_matmul_precision("high")

    for step in range(max_steps):
        start = time.time()
        optimizer.zero_grad()
        loss_accum = 0.0
        for micro_step in range(grad_accum_steps):
            # Fetch a new micro batch on every accumulation step so the
            # accumulated gradient covers total_batch_sz distinct tokens
            x, y = train_dl.next_batch()
            x = x.to(device)
            y = y.to(device)
            if "cuda" in device:
                # Tensors that are sensitive to reduced precision, such as
                # the loss and layernorm activations, are kept in FP32
                with torch.autocast(
                    device_type="cuda", dtype=torch.bfloat16
                ):
                    logits, loss = model(x, y)
            else:
                logits, loss = model(x, y)
            # Accumulating gradients alone sums the objective over micro
            # steps, but we want the mean, so we scale each loss by
            # 1/grad_accum_steps
            loss /= grad_accum_steps
            loss_accum += loss.detach()
            # To avoid syncing gradients between processes after every
            # micro step, we only allow the sync on the last accumulation
            # step in each process
            if ddp:
                model.require_backward_grad_sync = (
                    micro_step == grad_accum_steps - 1
                )
            loss.backward()
        # Each process has its own loss_accum tensor, so we all-reduce
        # across processes to get the average loss_accum over all of them
        if ddp:
            dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
        # Clip gradients to a maximum global norm. This is useful because an
        # unlucky batch can produce a very high loss, which would lead to
        # large gradients and huge parameter updates.
        # At the beginning of training it is normal to see high norms since
        # the model is randomly initialized
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # TODO: Use ParamScheduler from `cmn_ai`
        lr = sched(step / max_steps)
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        optimizer.step()
        if "cuda" in device:
            # Wait for all scheduled CUDA kernels to finish so the timing
            # below reflects actual GPU execution, not just kernel launches
            torch.cuda.synchronize()
        end = time.time()
        elapsed_time = end - start
        token_per_sec = (
            GPTConfig.batch_sz
            * GPTConfig.block_sz
            * grad_accum_steps
            * ddp_world_size
        ) / (elapsed_time)
        if master_process:
            print(
                f"step {step}, loss: {loss_accum.item():.6f}, lr {lr:.4e}, "
                f"norm: {norm:.2f}, time: {elapsed_time:.2f}s, "
                f"tok/sec: {token_per_sec:.2f}"
            )

    if ddp:
        # Clean up the distributed process group
        destroy_process_group()
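
The generate method defined above is never called in the training script. A minimal usage sketch (the prompt and the number of sampled tokens are arbitrary) that could be appended after the training loop:

if master_process:
    encoder = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I'm a language model,"
    idxs = torch.tensor(
        encoder.encode(prompt), dtype=torch.long, device=device
    ).unsqueeze(0)
    raw_model.eval()
    out = raw_model.generate(idxs, max_tokens=30)
    print(encoder.decode(out[0].tolist()))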

Conclusion

TODO

Resources