<?xml version="1.0" encoding="UTF-8"?>
<rss xmlns:atom="http://www.w3.org/2005/Atom"
      xmlns:media="http://search.yahoo.com/mrss/" 
      xmlns:content="http://purl.org/rss/1.0/modules/content/" 
      xmlns:dc="http://purl.org/dc/elements/1.1/" 
      version="2.0">
<channel>
<title>Imad Dabbura</title>
<link>https://imaddabbura.github.io/posts.html</link>
<atom:link href="https://imaddabbura.github.io/posts.xml" rel="self" type="application/rss+xml"/>
<description>Deep science. Built from scratch. Shared openly.</description>
<image>
<url>https://imaddabbura.github.io/images/profile-pic.png</url>
<title>Imad Dabbura</title>
<link>https://imaddabbura.github.io/posts.html</link>
<height>152</height>
<width>144</width>
</image>
<generator>quarto-1.4.553</generator>
<lastBuildDate>Sun, 21 Sep 2025 05:00:00 GMT</lastBuildDate>
<item>
  <title>Make ML Systems Ship Again</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/mlsys/improving-mlsys-theory-of-constraint.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<p><a href="images/network-anomaly-detection-toc.jpg" class="lightbox" data-gallery="quarto-lightbox-gallery-1"><img src="https://imaddabbura.github.io/posts/mlsys/images/network-anomaly-detection-toc.jpg" class="img-fluid"></a></p>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>You burn six months “optimizing.” Swap in transformers. Squeeze another +0.5% accuracy. Rewrite the feature pipeline. Add a shiny GPU cluster. And still: alert fatigue, missed incidents, and latency that kills real-time response.</p>
<p>That’s optimization theater.</p>
<p>This pattern shows up everywhere in production ML. Fraud teams add transaction features that never reduce false positives. Recommendation engines get fancier models that don’t move click-through rates. Forecasting pipelines gain complexity without improving planning accuracy. Parts get optimized. Systems don’t.</p>
<p>This post gives you a systematic method to break out of the cycle. It’s based on the Theory of Constraints — originally developed for manufacturing, but a natural fit for ML systems. We’ll use a network anomaly detection system as our running example, but the playbook works for any ML system in production.</p>
<section id="roadmap" class="level3">
<h3 class="anchored" data-anchor-id="roadmap">Roadmap</h3>
<table class="table">
<colgroup>
<col style="width: 19%">
<col style="width: 46%">
<col style="width: 34%">
</colgroup>
<thead>
<tr class="header">
<th>Section</th>
<th>What You’ll Learn / Do</th>
<th>Why It Matters</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>The Theory of Constraints</strong></td>
<td>The core idea and why single-bottleneck focus works</td>
<td>Gives you the mental model that makes the steps principled, not arbitrary</td>
</tr>
<tr class="even">
<td><strong>1. Goal &amp; Constraint</strong></td>
<td>Set SLOs, then build a constraint ledger to find the bottleneck</td>
<td>Defines success and focuses effort on the one thing that governs throughput</td>
</tr>
<tr class="odd">
<td><strong>2. Understand Why It’s Stuck</strong></td>
<td>Root-cause analysis</td>
<td>Prevents solving symptoms</td>
</tr>
<tr class="even">
<td><strong>3. See the Hidden Tradeoff</strong></td>
<td>Map the conflict</td>
<td>Reveals why simple fixes haven’t worked</td>
</tr>
<tr class="odd">
<td><strong>4. Break the Tradeoff</strong></td>
<td>Challenge assumptions, then innovate</td>
<td>Achieves step-function improvement</td>
</tr>
<tr class="even">
<td><strong>5. Prove It Works</strong></td>
<td>Minimum Viable Experiment</td>
<td>Validates before full investment</td>
</tr>
</tbody>
</table>
</section>
</section>
<section id="the-theory-of-constraints-in-5-minutes" class="level2">
<h2 class="anchored" data-anchor-id="the-theory-of-constraints-in-5-minutes">The Theory of Constraints in 5 Minutes</h2>
<p>The traditional approach to improving ML systems is based on a seemingly logical but flawed assumption: if you improve each component, the whole system improves. It doesn’t. <em>The sum of all local improvements doesn’t give you a system improvement.</em></p>
<p>The breakthrough insight, from Eli Goldratt’s <em>The Goal</em> (1984) and made operational by Alan Barnard’s pairing method, is simple: every system has exactly one constraint at any given moment — the single resource or stage you don’t have enough of. That constraint sets the ceiling for the entire system. Improving anything else delivers diminishing-to-zero returns.</p>
<p>A factory line can only produce as fast as its slowest machine. If the paint booth takes 10 minutes per car while everything else takes 2 minutes, buying faster welding robots changes nothing. You have to speed up the paint booth — or the line will forever produce one car every 10 minutes.</p>
<p>In a serial pipeline — which is what most ML systems are — this is even starker than it sounds. Throughput equals the throughput of the slowest stage. If your feature extraction handles 30K records/sec and everything else handles 100K, the system does 30K. Making inference 10x faster? Still 30K. Doubling ingest capacity? Still 30K. Only improving the bottleneck stage moves the number. Every other “optimization” is buying faster welding robots while the paint booth sets the pace.</p>
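<p>The arithmetic is worth making concrete. A minimal sketch, using the example capacities from the text (stage names and numbers are illustrative, not measurements):</p>

```python
# A serial pipeline runs at the pace of its slowest stage.
stages = {
    "ingest": 100_000,       # records/sec
    "features": 30_000,      # the paint booth
    "inference": 100_000,
    "alerting": 100_000,
}

def system_throughput(stages: dict) -> int:
    """System throughput = throughput of the slowest stage."""
    return min(stages.values())

baseline = system_throughput(stages)

# "Optimizing" a non-bottleneck stage changes nothing:
stages["inference"] *= 10
assert system_throughput(stages) == baseline

# Only improving the constraint moves the number:
stages["features"] *= 3
assert system_throughput(stages) == 90_000
```

<p>Run the same check against your own stage capacities before committing to any optimization: if the number you plan to improve isn’t the minimum, the system number won’t move.</p>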
<p><a href="images/amdahl-law-optimization.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-2"><img src="https://imaddabbura.github.io/posts/mlsys/images/amdahl-law-optimization.svg" class="img-fluid"></a></p>
<p>Barnard turns this into a strict chain of focused pairings. Each pairing links a WHAT (what you need) to a HOW (how to get it), maintaining a one-to-one relationship that keeps focus razor-sharp:</p>
<ol type="1">
<li><strong>Goal → Constraint</strong>: WHAT do I want? More of the Goal. HOW? By getting more of the Constraint — the single resource I don’t have enough of.</li>
<li><strong>Constraint → Problem</strong>: WHAT limits the Constraint? The one Problem causing at least 50% of the gap.</li>
<li><strong>Problem → Conflict</strong>: WHY hasn’t the Problem been solved? Because it’s an unresolved Conflict between two necessary-but-competing approaches.</li>
<li><strong>Conflict → Innovation</strong>: HOW do I resolve it? With an Innovation that captures the Pros of <em>both</em> the current approach and the new idea. The aim is all the Pros — but some tradeoffs may remain. The key is they’re deliberate and tolerable, not the paralyzing either/or you started with.</li>
<li><strong>Innovation → Experiment</strong>: HOW do I know it works? With a Minimally Viable Experiment — before building anything.</li>
</ol>
<p>The five how-to steps below translate these pairings into ML-systems language. Step 1 defines the Goal (SLOs) and finds the Constraint (bottleneck). Step 2 uncovers the Problem (root cause). Step 3 maps the Conflict (hidden tradeoff). Step 4 designs the Innovation. Step 5 runs the Experiment.</p>
<p>So why does this matter for ML specifically? Because ML pipelines are textbook flow systems: ingest → features → inference → action. They have measurable stages with capacity limits. And they accumulate complexity over time — teams add features, models, and infrastructure without ever removing anything. This makes them natural candidates for constraint-based thinking. But ML teams rarely think this way, because they’re trained to optimize <em>models</em>, not <em>systems</em>.</p>
<p><a href="images/ml-pipeline-weak-constraint.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-3"><img src="https://imaddabbura.github.io/posts/mlsys/images/ml-pipeline-weak-constraint.svg" class="img-fluid"></a></p>
</section>
<section id="step-1-define-your-goal-and-find-the-bottleneck" class="level2">
<h2 class="anchored" data-anchor-id="step-1-define-your-goal-and-find-the-bottleneck">Step 1: Define Your Goal and Find the Bottleneck</h2>
<p>Before you can find the bottleneck, you need to define what success actually means — in numbers, not aspirations. And before you can fix the bottleneck, you need to know which stage is actually holding the system back. This step covers both.</p>
<section id="set-your-slos" class="level3">
<h3 class="anchored" data-anchor-id="set-your-slos">Set Your SLOs</h3>
<p>“Detect anomalies,” “reduce fraud,” and “improve recommendations” aren’t goals. They’re wishes. Without measurable targets, every team member optimizes for a different thing, and you can’t tell whether you’re constrained by latency, precision, coverage, or something else entirely.</p>
<p>The fix is Service Level Objectives — specific, measurable thresholds tied to business outcomes:</p>
<table class="table">
<colgroup>
<col style="width: 34%">
<col style="width: 43%">
<col style="width: 21%">
</colgroup>
<thead>
<tr class="header">
<th>SLO Dimension</th>
<th>What It Measures</th>
<th>Fill In</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Time-to-Decision (TTD)</strong></td>
<td>How fast the system produces an actionable output</td>
<td>p95 ≤ ___</td>
</tr>
<tr class="even">
<td><strong>Decision Budget</strong></td>
<td>How many outputs a human can realistically handle</td>
<td>≤ ___ per day</td>
</tr>
<tr class="odd">
<td><strong>Outcome-Weighted Performance</strong></td>
<td>Accuracy weighted by business impact, not volume</td>
<td>≥ ___%</td>
</tr>
<tr class="even">
<td><strong>Coverage</strong></td>
<td>Fraction of relevant events actually processed</td>
<td>≥ ___%</td>
</tr>
<tr class="odd">
<td><strong>Data Loss</strong></td>
<td>Events dropped or degraded in transit</td>
<td>≤ ___%</td>
</tr>
</tbody>
</table>
<p>These five dimensions force hard conversations. A model with 99% accuracy but 30-minute detection latency fails the TTD target. A model with perfect precision but 500 daily alerts fails the decision budget. The SLOs define the feasible region — and crucially, reveal <em>what’s blocking you</em> from reaching it.</p>
<p>Here’s how our network anomaly detection system instantiated these:</p>
<ul>
<li><strong>TTD</strong>: p95 ≤ 5 minutes from event to alert</li>
<li><strong>Alert Budget</strong>: ≤ 10 analyst-actionable alerts/day</li>
<li><strong>Incident-Weighted Recall</strong>: ≥ 90%</li>
</ul>
<p>The same template applies to other domains. A fraud detection team might set TTD ≤ 200ms with ≤ 50 manual reviews/day. A recommendation system might target TTD ≤ 100ms with CTR-weighted precision ≥ X%.</p>
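<p>The SLO check itself can be mechanical. A sketch for the anomaly-detection targets above — field names and thresholds mirror the text; adapt them to your domain:</p>

```python
from dataclasses import dataclass

@dataclass
class SLOReport:
    ttd_p95_minutes: float          # time-to-decision, p95
    alerts_per_day: int             # decision budget
    incident_weighted_recall: float

def violations(r: SLOReport) -> list:
    """Each failed SLO dimension points at a different class of bottleneck."""
    failed = []
    if r.ttd_p95_minutes > 5:
        failed.append("TTD: bottleneck is in the latency path")
    if r.alerts_per_day > 10:
        failed.append("Alert budget: bottleneck is precision or triage capacity")
    if r.incident_weighted_recall < 0.90:
        failed.append("Recall: bottleneck is coverage or model quality")
    return failed

# A model that looks great on its own scorecard can still fail the system:
report = SLOReport(ttd_p95_minutes=30.0, alerts_per_day=500,
                   incident_weighted_recall=0.99)
```

<p>Here <code>violations(report)</code> flags TTD and the alert budget — exactly the two failures the vanity-metrics pitfall below describes.</p>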
<p><a href="images/slo-hierarchy.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-4"><img src="https://imaddabbura.github.io/posts/mlsys/images/slo-hierarchy.svg" class="img-fluid"></a></p>
<p>When defining SLOs, involve the people who <em>use</em> the system’s outputs — not just the team that builds it. Security analysts, operations teams, business stakeholders. When they disagree (and they will — security wants recall, ops wants fewer alerts), the SLOs make the tradeoff explicit rather than hiding it inside model thresholds.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Pitfall: Vanity Metrics Over Business Outcomes
</div>
</div>
<div class="callout-body-container callout-body">
<p>Teams optimize metrics that sound impressive but don’t connect to business value. “99.9% precision” means nothing if you’re missing 90% of incidents. “Processing 1M events/second” is irrelevant if decisions take 30 minutes. In our case, we celebrated achieving 99% detection rate on port scans — which the SOC ignored anyway — while missing lateral movement using legitimate credentials. Define SLOs tied to outcomes, not to model scorecards.</p>
</div>
</div>
<p>SLOs don’t just measure success — they <em>reveal</em> what’s blocking it. If you can’t meet your TTD target, the bottleneck is somewhere in your latency path. If you can’t meet your alert budget, the bottleneck is in precision or triage capacity. Now let’s find exactly where.</p>
</section>
<section id="find-the-bottleneck" class="level3">
<h3 class="anchored" data-anchor-id="find-the-bottleneck">Find the Bottleneck</h3>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Mental Model
</div>
</div>
<div class="callout-body-container callout-body">
<p>Your ML pipeline is a series of stages, each with a capacity ceiling. The stage with the lowest effective capacity is your constraint — it sets the ceiling for the entire system. Everything upstream queues up; everything downstream sits idle. Barnard’s memorable shortcut: <em>“Check what you’re waiting for. Where’s the backlog?”</em></p>
</div>
</div>
<p>With SLOs defined, you can systematically measure where the system breaks down. Build a <strong>constraint ledger</strong> — a table measuring capacity, utilization, latency, queue depth, and top failure mode at each pipeline stage:</p>
<table class="table">
<colgroup>
<col style="width: 8%">
<col style="width: 22%">
<col style="width: 15%">
<col style="width: 15%">
<col style="width: 15%">
<col style="width: 21%">
</colgroup>
<thead>
<tr class="header">
<th>Stage</th>
<th style="text-align: right;">Capacity (rec/s)</th>
<th style="text-align: right;">Utilization</th>
<th style="text-align: right;">p95 Latency</th>
<th style="text-align: right;">Queue Depth</th>
<th>Top Failure Mode</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Ingest</td>
<td style="text-align: right;">100K</td>
<td style="text-align: right;">60%</td>
<td style="text-align: right;">2ms</td>
<td style="text-align: right;">0</td>
<td>burst loss</td>
</tr>
<tr class="even">
<td>Feature-Tier1</td>
<td style="text-align: right;">100K</td>
<td style="text-align: right;">65%</td>
<td style="text-align: right;">5ms</td>
<td style="text-align: right;">0</td>
<td>cache miss</td>
</tr>
<tr class="odd">
<td><strong>Feature-Tier2</strong></td>
<td style="text-align: right;"><strong>30K</strong></td>
<td style="text-align: right;"><strong>95%</strong></td>
<td style="text-align: right;"><strong>50ms</strong></td>
<td style="text-align: right;"><strong>1.2K</strong></td>
<td><strong>window skew</strong></td>
</tr>
<tr class="even">
<td>Feature-Tier3</td>
<td style="text-align: right;">10K</td>
<td style="text-align: right;">20%</td>
<td style="text-align: right;">200ms</td>
<td style="text-align: right;">0</td>
<td>cold start</td>
</tr>
<tr class="odd">
<td>Inference</td>
<td style="text-align: right;">50K</td>
<td style="text-align: right;">40%</td>
<td style="text-align: right;">10ms</td>
<td style="text-align: right;">0</td>
<td>batch sizing</td>
</tr>
<tr class="even">
<td>Alerting</td>
<td style="text-align: right;">1K</td>
<td style="text-align: right;">10%</td>
<td style="text-align: right;">100ms</td>
<td style="text-align: right;">0</td>
<td>dedup thrash</td>
</tr>
</tbody>
</table>
<p>The diagnostic pattern is simple: <strong>high utilization + growing queue = bottleneck</strong>. Feature-Tier2 jumps out — 95% utilization with a queue of 1.2K while other stages sit at 10–65%. During peak periods, the system is forced to either sample traffic (missing attacks), queue records (violating TTD), or drop features (hurting accuracy). The model never sees complete feature representations because feature extraction can’t keep pace.</p>
<p>Build this table for your own system. The constraint is almost always obvious once you measure.</p>
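<p>The diagnostic pattern is easy to encode. A sketch over the example ledger above (numbers taken from the table; the 85% utilization threshold is an assumption, tune it for your system):</p>

```python
ledger = [
    # (stage, capacity rec/s, utilization, queue depth)
    ("Ingest",        100_000, 0.60,    0),
    ("Feature-Tier1", 100_000, 0.65,    0),
    ("Feature-Tier2",  30_000, 0.95, 1200),
    ("Feature-Tier3",  10_000, 0.20,    0),   # lower nominal capacity,
    ("Inference",      50_000, 0.40,    0),   # but only sees escalated traffic
    ("Alerting",        1_000, 0.10,    0),
]

def find_constraint(ledger, util_threshold=0.85):
    """High utilization + growing queue = bottleneck."""
    return [stage for stage, _, util, queue in ledger
            if util >= util_threshold and queue > 0]

assert find_constraint(ledger) == ["Feature-Tier2"]
```

<p>Note that raw capacity alone misleads: Feature-Tier3 and Alerting have lower nominal capacity than Feature-Tier2, but neither is saturated or queueing, so neither is the constraint.</p>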
<blockquote class="blockquote">
<p><strong>Capacity conversion</strong>: 10 Gbps network traffic ≈ 100K flows/sec.&nbsp;1M daily e-commerce orders ≈ 12/sec average, 50/sec peak. 10K IoT sensors at 1Hz ≈ 10K records/sec.</p>
</blockquote>
</section>
<section id="validate-before-you-invest" class="level3">
<h3 class="anchored" data-anchor-id="validate-before-you-invest">Validate Before You Invest</h3>
<p>Before building anything, run a 24-hour experiment: temporarily throw 3x resources at your suspected bottleneck. If system-level metrics improve dramatically, you’ve found the right constraint. If not, look elsewhere. This experiment costs a day; building the wrong solution costs months.</p>
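<p>The decision rule behind the experiment is simple to state precisely. A sketch — the 50% improvement threshold is an assumption; the point is that a real constraint yields a roughly capacity-proportional jump, not a marginal one:</p>

```python
def constraint_confirmed(before: dict, after: dict,
                         min_improvement: float = 0.5) -> bool:
    """If 3x resources at the suspect stage barely moves system
    throughput, the constraint is elsewhere."""
    gain = (after["throughput"] - before["throughput"]) / before["throughput"]
    return gain >= min_improvement

# Our Feature-Tier2 test: 30K -> 90K records/sec, a near-linear gain.
assert constraint_confirmed({"throughput": 30_000}, {"throughput": 90_000})

# A marginal gain means you guessed wrong — look elsewhere:
assert not constraint_confirmed({"throughput": 30_000}, {"throughput": 33_000})
```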
<p>We provisioned 3x compute for Feature-Tier2, enabling 90K records/sec.&nbsp;The results were dramatic: detection time dropped, false positives decreased (the model makes better decisions with complete feature sets), and we nearly met our SLOs. No other improvement — not model accuracy, not infrastructure, not threshold tuning — would have achieved this.</p>
<p><a href="images/before-after-dashboard.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-5"><img src="https://imaddabbura.github.io/posts/mlsys/images/before-after-dashboard.svg" class="img-fluid"></a></p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Pitfall: Premature Model Optimization
</div>
</div>
<div class="callout-body-container callout-body">
<p>Teams spend months improving model accuracy while system-level metrics stagnate. The pipeline logic explains why: if the constraint isn’t in the model, then making the model infinitely better has zero impact on system throughput. We spent three months experimenting with transformer architectures for 2% accuracy improvement — while 70% of traffic was never analyzed due to feature extraction bottlenecks. The transformer detected sophisticated attacks brilliantly, on the 30% of traffic it actually saw. Always validate the constraint before optimizing.</p>
</div>
</div>
<p><strong>Key takeaway:</strong> The constraint is the only thing worth optimizing right now. Everything else is rearranging deck chairs.</p>
</section>
</section>
<section id="step-2-understand-why-its-stuck" class="level2">
<h2 class="anchored" data-anchor-id="step-2-understand-why-its-stuck">Step 2: Understand Why It’s Stuck</h2>
<p>You’ve found the bottleneck. Now resist the urge to fix the surface symptom. “Feature extraction is slow” is a temperature reading, not a diagnosis. You need the underlying cause — because the cause determines the cure.</p>
<section id="five-whys-with-evidence" class="level3">
<h3 class="anchored" data-anchor-id="five-whys-with-evidence">Five Whys — With Evidence</h3>
<p>The Five Whys technique is simple: ask “why” repeatedly until you reach a root cause, but <em>validate each answer with evidence</em> before proceeding to the next. Unvalidated whys lead to plausible-sounding but wrong root causes.</p>
<p><a href="images/five-whys.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-6"><img src="https://imaddabbura.github.io/posts/mlsys/images/five-whys.svg" class="img-fluid"></a></p>
<p>Here’s how this played out for our Feature-Tier2 bottleneck:</p>
<ol type="1">
<li><strong>Why</strong> is Feature-Tier2 at 95% utilization? → It computes 47 features per record. <em>(Validated: profiling shows 89% of computation in 12% of features)</em></li>
<li><strong>Why</strong> so many features? → Designed for offline research with unlimited compute. <em>(Validated: 31 features contribute &lt;0.1% to decisions)</em></li>
<li><strong>Why</strong> no production constraints in the design? → Development was disconnected from deployment. <em>(Validated: git history shows features added without removal)</em></li>
<li><strong>Why</strong> disconnected? → ML team and platform team operate in silos. <em>(Validated: team interviews confirm no shared requirements)</em></li>
<li><strong>Why</strong> silos? → No ownership of end-to-end system performance.</li>
</ol>
<p>Notice where we ended up: the root cause isn’t technical — it’s organizational.</p>
<p><a href="images/iceberg-visible-constraint.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-7"><img src="https://imaddabbura.github.io/posts/mlsys/images/iceberg-visible-constraint.svg" class="img-fluid"></a></p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Common Root Causes in ML Systems — Check Which Applies
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Feature Explosion</strong>: Teams extract every conceivable signal because “it might help.” Features grow monotonically — each has an advocate, none has a removal date. Most provide redundant information.</li>
<li><strong>Multi-granularity Overhead</strong>: Computing signals at every timescale (seconds, minutes, hours, days) when most decisions only need one. Common in anomaly detection, fraud, and demand forecasting.</li>
<li><strong>Stale Reference Data</strong>: Maintaining expensive rolling statistics (baselines, embeddings, aggregates) for thousands of entities, even though most change negligibly between updates. The recomputation cost dwarfs the information gained.</li>
</ul>
</div>
</div>
<p>If your Five Whys keep ending at technical causes, go one more level. The technical problem often has an organizational parent — siloed teams, misaligned incentives, no end-to-end ownership. These patterns aren’t unique to our system. Fraud detection, recommendations, and forecasting all exhibit the same failure modes.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Pitfall: Feature Creep Without Cost Analysis
</div>
</div>
<div class="callout-body-container callout-body">
<p>Feature counts grow monotonically because each has an advocate who remembers when it caught something. Our system grew from 50 to 247 features over two years. Analysis showed 180 contributed &lt;0.1% to decisions but consumed 60% of computation. Track a <strong>feature value score</strong> — importance divided by computational cost — and require cost-benefit analysis for new features.</p>
</div>
</div>
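<p>The feature value score from the pitfall above takes a few lines to operationalize. A sketch — feature names, importances, and costs here are invented for illustration:</p>

```python
def value_score(importance: float, cost_ms: float) -> float:
    """Model importance per unit of compute cost."""
    return importance / cost_ms

features = {
    # name: (model importance, compute cost in ms/record)
    "bytes_per_flow":     (0.30, 0.05),
    "7d_rolling_entropy": (0.001, 12.0),   # expensive, near-zero value
    "dst_port_rarity":    (0.15, 0.10),
}

ranked = sorted(features, key=lambda f: value_score(*features[f]),
                reverse=True)
# Removal candidates sit at the bottom of the ranking:
assert ranked[-1] == "7d_rolling_entropy"
```

<p>Reviewing this ranking on a schedule — and requiring a score estimate before any new feature ships — is what keeps feature counts from growing monotonically.</p>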
<p><strong>Key takeaway:</strong> Root causes are usually organizational, not algorithmic. If you fix the technical symptom without fixing the organizational cause, the symptom will return.</p>
</section>
</section>
<section id="step-3-see-the-hidden-tradeoff" class="level2">
<h2 class="anchored" data-anchor-id="step-3-see-the-hidden-tradeoff">Step 3: See the Hidden Tradeoff</h2>
<p>You know the root cause. So why hasn’t anyone fixed it? Almost always, it’s because the problem is an unresolved conflict — and people are stuck choosing between two approaches that both seem necessary.</p>
<p>Barnard puts it precisely: <em>any problem can be defined as an unresolved conflict.</em> In our case, the Feature-Tier2 bottleneck persists because of a fundamental tension:</p>
<p><a href="images/feature-extraction-conflict.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-8"><img src="https://imaddabbura.github.io/posts/mlsys/images/feature-extraction-conflict.svg" class="img-fluid"></a></p>
<p>We need <strong>rich feature analysis</strong> for accurate detection of sophisticated attacks. We <em>also</em> need <strong>efficient processing</strong> for real-time response and cost control. These seem to contradict each other, so the team oscillates — add features after a missed attack, remove features after a performance degradation. Two years later, they’re exactly where they started.</p>
<section id="why-teams-get-stuck" class="level3">
<h3 class="anchored" data-anchor-id="why-teams-get-stuck">Why Teams Get Stuck</h3>
<p>Barnard identifies two failure modes that keep teams trapped in these oscillations:</p>
<ul>
<li><strong>Getting stuck / procrastinating</strong>: Exaggerated fears — fear of losing what the current approach does well, or fear of the effort and risk required to change. (“If we remove features, we’ll miss attacks.”)</li>
<li><strong>Overreacting / jumping to conclusions</strong>: Exaggerated frustration with the current approach’s downsides, or exaggerated expectations of a new solution. (“Let’s just throw out all the expensive features and rely on the model.”)</li>
</ul>
<p>Most ML teams alternate between these two modes without recognizing the pattern.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Pitfall: Alert Budget Myopia
</div>
</div>
<div class="callout-body-container callout-body">
<p>A textbook case of oscillation: facing missed incidents, teams lower thresholds (overreacting). This floods analysts with alerts, who start ignoring them, leading to <em>more</em> missed incidents — which triggers another round of threshold lowering. Little’s Law makes the math concrete: L = λW — if analysts can investigate 50 alerts/day and each takes 45 minutes, that’s the hard capacity ceiling. No threshold change can overcome it. This is the precision/coverage conflict manifesting as a vicious cycle.</p>
</div>
</div>
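<p>Little’s Law makes the ceiling computable before any threshold debate starts. A sketch — analyst count and shift length are assumptions for illustration:</p>

```python
def alert_capacity_per_day(analysts: int, shift_hours: float,
                           minutes_per_alert: float) -> int:
    """Little's Law (L = lambda * W) rearranged: the arrival rate the
    team can sustain is bounded by available time / time per alert."""
    available_minutes = analysts * shift_hours * 60
    return int(available_minutes / minutes_per_alert)

# 5 analysts on 8-hour shifts, 45 minutes per investigation:
ceiling = alert_capacity_per_day(5, 8, 45)
# Alerts beyond this ceiling are not investigated faster --
# they are ignored, which is how the vicious cycle starts.
```

<p>Here <code>ceiling</code> comes out to 53 alerts/day. Any alerting configuration that emits more than that is choosing, implicitly, which alerts get ignored.</p>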
<p><a href="images/alert-fatigue.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-9"><img src="https://imaddabbura.github.io/posts/mlsys/images/alert-fatigue.svg" class="img-fluid"></a></p>
</section>
<section id="map-your-conflict" class="level3">
<h3 class="anchored" data-anchor-id="map-your-conflict">Map Your Conflict</h3>
<p>The breakthrough comes from asking: <em>what assumptions make this conflict seem unresolvable?</em> To find them, map the conflict explicitly:</p>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode md code-with-copy"><code class="sourceCode markdown"><span id="cb1-1">We need <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">[</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">rich feature analysis</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">]</span> to achieve <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">[</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">accurate detection</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">]</span>.</span>
<span id="cb1-2">We need <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">[</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">efficient processing</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">]</span> to achieve <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">[</span><span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">real-time response</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">]</span>.</span>
<span id="cb1-3">These conflict because we assume:</span>
<span id="cb1-4"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  1. </span>All records need the same analysis depth</span>
<span id="cb1-5"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  2. </span>Features must be computed synchronously</span>
<span id="cb1-6"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  3. </span>One model handles all decisions</span>
<span id="cb1-7"></span>
<span id="cb1-8">Challenge each:</span>
<span id="cb1-9"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  - </span>Is assumption 1 always true? No — routine DNS queries</span>
<span id="cb1-10">    don't need the same scrutiny as connections to unknown IPs.</span>
<span id="cb1-11"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  - </span>Is assumption 2 always true? No — historical comparisons</span>
<span id="cb1-12">    could be asynchronous.</span>
<span id="cb1-13"><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">  - </span>Is assumption 3 always true? No — different attack types</span>
<span id="cb1-14">    could use specialized models.</span></code></pre></div>
<p>This template works for any ML conflict. A fraud detection team might write: “We need comprehensive transaction analysis AND sub-200ms decisions. Hidden assumption: every transaction needs the same analysis depth.” A recommendation team: “We need deep personalization AND instant page load. Hidden assumption: personalization must happen at request time.”</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Common ML Conflicts
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Accuracy vs Latency</strong>: Complex models are more accurate but slower</li>
<li><strong>Precision vs Coverage</strong>: Tight thresholds reduce false positives but miss edge cases</li>
<li><strong>Real-time vs Historical Context</strong>: Immediate response vs rich contextual analysis</li>
<li><strong>Generic vs Specific Models</strong>: Broad coverage vs environment-specific accuracy</li>
</ul>
</div>
</div>
<p><strong>Key takeaway:</strong> The tradeoff that’s blocking you is almost never fundamental. It persists because of hidden assumptions. Find the assumption. Challenge it. The conflict evaporates.</p>
</section>
</section>
<section id="step-4-break-the-tradeoff" class="level2">
<h2 class="anchored" data-anchor-id="step-4-break-the-tradeoff">Step 4: Break the Tradeoff</h2>
<p>You’ve identified the assumptions propping up the conflict. Now comes the payoff: designing a solution that captures the Pros of both the current approach and the alternative.</p>
<p>The goal is to get as many Pros from both sides as possible. Sometimes you genuinely get all of them. More often, some tradeoffs remain — added complexity, operational overhead, calibration effort. The difference from compromise is that these residual cons are <em>deliberate and manageable</em>, not the paralyzing either/or that kept the team stuck. You’re not splitting the difference. You’re changing the game so the remaining tradeoffs feel trivial compared to where you started.</p>
<section id="the-thinking-process" class="level3">
<h3 class="anchored" data-anchor-id="the-thinking-process">The Thinking Process</h3>
<p>After mapping your conflict and challenging assumptions (Step 3), work through each challenged assumption systematically:</p>
<ol type="1">
<li><p><strong>Sketch the system without the assumption.</strong> If you challenged “all records need the same analysis depth,” draw the pipeline where they don’t. What would variable-depth processing look like? What decides the depth?</p></li>
<li><p><strong>Look for the four reusable patterns.</strong> Most ML system innovations are combinations of these:</p>
<ul>
<li><strong>Cascade filtering</strong> — cheap check first, expensive check only when needed. Applicable whenever most inputs are routine. (Fraud: score transactions with simple rules before running the full model. Recs: serve cached recommendations before running personalization.)</li>
<li><strong>Async enrichment</strong> — decide now, enrich later. Useful whenever decision speed and decision quality have different time horizons. (Generate an alert with basic info immediately; add forensic context over the next 30 seconds.)</li>
<li><strong>Confidence-based routing</strong> — let the model decide how much compute each input deserves. Turns a fixed-cost pipeline into an adaptive one. (High-confidence benign traffic exits at Tier-1; uncertain traffic escalates.)</li>
<li><strong>Feature caching</strong> — never compute the same thing twice across pipeline stages. Obvious but rarely implemented. (Features from early triage stages are reused in deep analysis — we achieved 84% cache hit rates.)</li>
</ul></li>
<li><p><strong>Check for async opportunities</strong> — what’s being computed <em>before</em> the decision that could move to <em>after</em>?</p></li>
<li><p><strong>Check for caching opportunities</strong> — what’s being computed repeatedly across stages, records, or time windows?</p></li>
</ol>
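<p>To make patterns 1 and 3 concrete, cascade filtering plus confidence-based routing compose into a few lines. This is a minimal sketch, not our production system: the stub models, the <code>packet_rate</code> rule, and the thresholds are all illustrative.</p>

```python
# Cascade filtering with confidence-based routing: run tiers in order,
# exit as soon as one tier is confident enough.
def route(record, tiers):
    """Each tier is (model, threshold); a model returns (prediction, confidence)."""
    for model, threshold in tiers:
        prediction, confidence = model(record)
        if confidence >= threshold:
            return prediction  # cheap, confident exit
    return prediction          # the final tier always decides

# Stubs stand in for trained models.
tiers = [
    # Tier-1: cheap triage on a single feature (illustrative rule)
    (lambda r: ("benign", 0.99) if r["packet_rate"] < 100 else ("benign", 0.40), 0.95),
    # Tier-2: fast analysis (illustrative stub)
    (lambda r: ("suspicious", 0.80), 0.75),
]

print(route({"packet_rate": 50}, tiers))   # exits at Tier-1
print(route({"packet_rate": 500}, tiers))  # escalates to Tier-2
```

<p>In a real pipeline each stub would be a trained model, and the thresholds would come from calibration on a holdout set rather than being hand-picked.</p>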
</section>
<section id="a-worked-example-progressive-analysis" class="level3">
<h3 class="anchored" data-anchor-id="a-worked-example-progressive-analysis">A Worked Example: Progressive Analysis</h3>
<p>In our case, the most load-bearing assumption was: <em>“All records need the same analysis depth.”</em> Once you challenge it, the architecture follows from the patterns above — cascade filtering with confidence-based routing between tiers:</p>
<table class="table">
<colgroup>
<col style="width: 13%">
<col style="width: 22%">
<col style="width: 15%">
<col style="width: 28%">
<col style="width: 20%">
</colgroup>
<thead>
<tr class="header">
<th>Tier</th>
<th>Features</th>
<th>Model</th>
<th>Traffic Seen</th>
<th>Latency</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Tier-1</strong>: Wire-speed Triage</td>
<td>5 cheap features</td>
<td>Logistic regression</td>
<td>100% (68% exits)</td>
<td>~3ms</td>
</tr>
<tr class="even">
<td><strong>Tier-2</strong>: Fast Analysis</td>
<td>25 features</td>
<td>Moderate</td>
<td>~32%</td>
<td>~15ms</td>
</tr>
<tr class="odd">
<td><strong>Tier-3</strong>: Deep Analysis</td>
<td>100 features</td>
<td>Complex</td>
<td>~4%</td>
<td>~100ms</td>
</tr>
<tr class="even">
<td><strong>Forensic</strong>: Full Analysis</td>
<td>All features</td>
<td>Exhaustive</td>
<td>&lt;1%</td>
<td>~500ms</td>
</tr>
</tbody>
</table>
<p><a href="images/progressive-analysis-architecture.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-10"><img src="https://imaddabbura.github.io/posts/mlsys/images/progressive-analysis-architecture.svg" class="img-fluid"></a></p>
<p>Each stage outputs a prediction <em>and</em> a confidence score. High-confidence benign traffic exits immediately. Low confidence escalates. When stages experience backlog, confidence thresholds adjust dynamically — low-risk records defer to async processing during congestion, ensuring high-risk traffic always gets full analysis. And alerts are generated immediately with basic info, then progressively enriched over 30 seconds with connection context, historical patterns, and full forensics.</p>
<p><a href="images/feature-dependency.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-11"><img src="https://imaddabbura.github.io/posts/mlsys/images/feature-dependency.svg" class="img-fluid"></a></p>
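<p>The backlog-aware thresholds described above can be sketched in a few lines. This is an illustrative model of the idea, assuming each stage can report its queue depth; the linear scaling rule and the numbers are made up for the example.</p>

```python
# Scale a stage's exit-confidence threshold down as its queue fills, so more
# low-risk records exit early (or defer to async) under congestion.
def effective_threshold(base, queue_depth, capacity, floor=0.80):
    load = min(queue_depth / capacity, 1.0)   # 0.0 = idle, 1.0 = saturated
    return max(base - (base - floor) * load, floor)

print(round(effective_threshold(0.95, queue_depth=0, capacity=1000), 2))     # no congestion
print(round(effective_threshold(0.95, queue_depth=1000, capacity=1000), 2))  # full backlog
```

<p>The floor matters: it caps how permissive a congested stage can become, so high-risk traffic never loses its full analysis.</p>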
</section>
<section id="fix-the-organization-too" class="level3">
<h3 class="anchored" data-anchor-id="fix-the-organization-too">Fix the Organization Too</h3>
<p>Remember: Step 2 told us the root cause was organizational — siloed teams, no end-to-end ownership, features added without production constraints. The progressive architecture only sticks if the organizational structure changes with it. We restructured so ML and platform teams share SLOs, and adding features now requires cross-team cost-benefit approval. Without this, the feature explosion that caused the original bottleneck would have returned within a year.</p>
<p><strong>Key takeaway:</strong> The innovation doesn’t have to be novel to the field. It has to be novel to <em>your</em> system. Progressive analysis is a known pattern — applying it to our specific bottleneck was the breakthrough. But the technical fix and the organizational fix are a package deal.</p>
</section>
</section>
<section id="step-5-prove-it-works" class="level2">
<h2 class="anchored" data-anchor-id="step-5-prove-it-works">Step 5: Prove It Works</h2>
<p>You’ve designed a solution on paper. Before you spend three months building it, spend two weeks proving the riskiest assumption.</p>
<p>An important distinction: a Minimally Viable Experiment (MVE) comes <em>before</em> a Minimally Viable Product (MVP). An MVP builds the smallest usable product. An MVE is smaller — it tests whether the core assumption behind the innovation is even valid. Don’t build anything until you’ve validated the assumption.</p>
<section id="identify-the-riskiest-assumption" class="level3">
<h3 class="anchored" data-anchor-id="identify-the-riskiest-assumption">Identify the Riskiest Assumption</h3>
<p>Ask: <em>what’s the single assumption that, if wrong, kills the entire approach?</em> For our progressive architecture, it was: “Can Tier-1 triage accurately identify benign traffic without missing attacks?” If lightweight features can’t reliably separate benign from suspicious, the whole cascade fails.</p>
<p>Design the smallest test that answers this question. We trained a logistic regression on 5 cheap features and tested on realistic data with known attacks:</p>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">from sklearn.linear_model import LogisticRegression

tier1_features = [
    'src_reputation_score',  # Pre-computed reputation
    'dst_port',              # Destination port number
    'protocol',              # TCP/UDP/ICMP
    'packet_rate',           # Packets per second
    'byte_rate'              # Bytes per second
]

# X_train / y_train_benign: training slice with a binary benign-vs-not label
tier1_model = LogisticRegression(C=1.0)
tier1_model.fit(X_train[tier1_features], y_train_benign)</code></pre></div>
<p><a href="images/mve-experiment-timeline.svg" class="lightbox" data-gallery="quarto-lightbox-gallery-12"><img src="https://imaddabbura.github.io/posts/mlsys/images/mve-experiment-timeline.svg" class="img-fluid"></a></p>
</section>
<section id="results" class="level3">
<h3 class="anchored" data-anchor-id="results">Results</h3>
<table class="table">
<thead>
<tr class="header">
<th>Metric</th>
<th>Result</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Triage rate</td>
<td>68% identified as benign at Tier-1</td>
</tr>
<tr class="even">
<td>False negative rate</td>
<td>0% — no attacks missed</td>
</tr>
<tr class="odd">
<td>Throughput</td>
<td>95K records/sec</td>
</tr>
<tr class="even">
<td>p95 latency</td>
<td>3ms per record</td>
</tr>
<tr class="odd">
<td>Cache hit rate</td>
<td>84% across stages</td>
</tr>
</tbody>
</table>
</section>
<section id="what-the-iterations-taught-us" class="level3">
<h3 class="anchored" data-anchor-id="what-the-iterations-taught-us">What the Iterations Taught Us</h3>
<p>The MVE revealed things we couldn’t have predicted from design alone:</p>
<ol type="1">
<li><p><strong>Confidence calibration</strong>: Initial triage was too conservative — 32% of traffic passed to Tier-2 unnecessarily. The model lacked confidence on legitimate-but-unusual ports. Retraining with expanded examples achieved a 71% triage rate without missing attacks.</p></li>
<li><p><strong>Dynamic resource allocation</strong>: Fixed compute allocation caused bottlenecks when traffic patterns shifted. We let overloaded stages borrow compute from idle ones, which smoothed throughput across load profiles.</p></li>
<li><p><strong>Feature pruning</strong>: 15 Tier-3 features never influenced decisions in production. Removing them increased throughput 30% without affecting detection. Track a feature value score — importance / computational_cost — and prune ruthlessly.</p></li>
</ol>
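<p>The feature value score from the third lesson can be made concrete. A minimal sketch, with features and numbers made up for the example: in practice, importance might come from permutation importance or SHAP, and cost from per-feature profiling.</p>

```python
# value score = importance / computational cost; prune low-value features.
features = {
    "src_reputation_score": {"importance": 0.30, "cost_ms": 0.1},
    "dns_entropy":          {"importance": 0.02, "cost_ms": 4.0},
    "tls_ja3_match":        {"importance": 0.15, "cost_ms": 0.5},
}

def value_score(f):
    return f["importance"] / f["cost_ms"]

# Keep only features whose value score clears a chosen threshold.
kept = [name for name, f in features.items() if value_score(f) >= 0.1]
print(kept)  # dns_entropy scores 0.005 and is pruned
```

<p>The threshold is a policy decision, not a constant; what matters is that the score makes the importance-versus-cost tradeoff explicit and auditable.</p>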
</section>
<section id="production-rollout-checklist" class="level3">
<h3 class="anchored" data-anchor-id="production-rollout-checklist">Production Rollout Checklist</h3>
<ul>
<li><strong>Shadow mode</strong>: Run progressive pipeline parallel to existing system. Compare decisions, measure divergence. Success: no P1 incidents missed for one week.</li>
<li><strong>Canary (10%)</strong>: Route 10% of traffic through progressive pipeline. A/B test alert quality with analysts. Success: SLOs maintained, analyst preference ≥ baseline.</li>
<li><strong>Gradual expansion</strong>: 10% → 25% → 50% → 75%, holding each level for 48 hours. Automated rollback on any SLO violation.</li>
<li><strong>Full production</strong>: 100% with old system as instant fallback. Document runbooks, train operations team. Success: one week at 100% with all SLOs met.</li>
</ul>
</section>
<section id="gono-go-criteria" class="level3">
<h3 class="anchored" data-anchor-id="gono-go-criteria">Go/No-Go Criteria</h3>
<p>After the MVE, the decision is straightforward: does the riskiest assumption hold? If yes — the cascade correctly separates benign from suspicious — proceed to shadow mode. If the assumption fails, you haven’t wasted months; you’ve spent two weeks learning that you need a different innovation. Go back to Step 4 and challenge a different assumption.</p>
<p><strong>Key takeaway:</strong> A two-week MVE teaches more than two years of production experience. Test your riskiest assumption first.</p>
</section>
</section>
<section id="the-cycle-continues" class="level2">
<h2 class="anchored" data-anchor-id="the-cycle-continues">The Cycle Continues</h2>
<p>Here’s the part that surprises people: solving one constraint doesn’t fix the system forever. It reveals the <em>next</em> constraint. And that’s a feature, not a bug — because you always know exactly what to work on.</p>
<p>With Feature-Tier2 no longer the bottleneck, a new one emerged in our system: alert investigation. SOC analysts averaged 45 minutes per Tier-3 alert. This limited how many sophisticated attacks could be properly investigated. Applying the framework again:</p>
<ul>
<li><strong>Goal → Constraint</strong>: Reduce investigation time to 15 minutes while maintaining decision quality</li>
<li><strong>Constraint → Problem</strong>: Analysts manually correlate across multiple tools and data sources</li>
<li><strong>Problem → Conflict</strong>: Automated enrichment vs human judgment</li>
<li><strong>Conflict → Innovation</strong>: AI-assisted investigation that augments rather than replaces analysts</li>
<li><strong>Innovation → Experiment</strong>: Test on historical alerts with analyst feedback</li>
</ul>
<p>Each cycle makes the system more capable. Here’s where our NTA system ended up after one full pass through the framework:</p>
<table class="table">
<thead>
<tr class="header">
<th>Metric</th>
<th>Before</th>
<th>After</th>
<th>Change</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Detection time (p95)</td>
<td>47 min</td>
<td>3.2 min</td>
<td>15x faster</td>
</tr>
<tr class="even">
<td>Daily analyst alerts</td>
<td>847</td>
<td>11</td>
<td>98.7% reduction</td>
</tr>
<tr class="odd">
<td>Incidents missed/month</td>
<td>23</td>
<td>0</td>
<td>Eliminated</td>
</tr>
<tr class="even">
<td>Traffic coverage</td>
<td>30%</td>
<td>98%</td>
<td>Full visibility</td>
</tr>
<tr class="odd">
<td>Feature-Tier2 utilization</td>
<td>95%</td>
<td>42%</td>
<td>Headroom restored</td>
</tr>
</tbody>
</table>
<p>The constraint moved — from feature extraction to investigation to response automation — and each move represents the next opportunity for breakthrough improvement.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A["&lt;b&gt;Goal &amp; Constraint&lt;/b&gt;&lt;br/&gt;SLOs + ledger"] --&gt; B["&lt;b&gt;Understand Why&lt;/b&gt;&lt;br/&gt;Root cause"]
    B --&gt; C["&lt;b&gt;Map Conflict&lt;/b&gt;&lt;br/&gt;Challenge assumptions"]
    C --&gt; D["&lt;b&gt;Innovate&lt;/b&gt;&lt;br/&gt;Best of both sides"]
    D --&gt; E["&lt;b&gt;Experiment&lt;/b&gt;&lt;br/&gt;MVE → rollout"]
    E --&gt; |"Constraint moves"| A

    style A fill:#e8f5e9,stroke:#333
    style B fill:#e3f2fd,stroke:#333
    style C fill:#fff3e0,stroke:#333
    style D fill:#fce4ec,stroke:#333
    style E fill:#f3e5f5,stroke:#333
</pre>
</div>
<p></p><figcaption> The Theory of Constraints is a cycle, not a line</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>The Theory of Constraints isn’t another optimization technique. It’s an operating system for continuous improvement. The constraint keeps moving, but so do you.</p>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>Most ML teams are stuck in optimization theater — tuning components that don’t govern system performance. The Theory of Constraints gives you a way out: find the one bottleneck that sets the ceiling, understand why it’s stuck, and design an innovation that breaks the tradeoff instead of compromising on it.</p>
<p>The method is five steps, but the discipline is one idea: <em>at any moment, only one thing limits your system.</em> Find it. Fix it. Then find the next one.</p>
<p>If you take one thing from this post, make it this: before your next “optimization” sprint, build the constraint ledger. Measure every stage. Find the row with high utilization and a growing queue. That’s where your effort belongs — and nowhere else.</p>
</section>
<section id="references" class="level2">
<h2 class="anchored" data-anchor-id="references">References</h2>
<ul>
<li>Goldratt, E. M. (1984). “The Goal: A Process of Ongoing Improvement.” North River Press.</li>
<li>Sculley, D., et al.&nbsp;(2015). “Hidden Technical Debt in Machine Learning Systems.” NIPS 2015.</li>
<li>Paleyes, A., et al.&nbsp;(2022). “Challenges in Deploying Machine Learning: A Survey of Case Studies.” ACM Computing Surveys.</li>
<li>Kleppmann, M. (2017). “Designing Data-Intensive Applications.” O’Reilly Media.</li>
<li>Polyzotis, N., et al.&nbsp;(2018). “Data Lifecycle Challenges in Production Machine Learning.” SIGMOD Record.</li>
<li>Chandola, V., et al.&nbsp;(2009). “Anomaly Detection: A Survey.” ACM Computing Surveys.</li>
<li>Ahmed, M., et al.&nbsp;(2016). “A Survey of Network Anomaly Detection Techniques.” Journal of Network and Computer Applications.</li>
<li>Sommer, R., &amp; Paxson, V. (2010). “Outside the Closed World: On Using Machine Learning for Network Intrusion Detection.” IEEE S&amp;P.</li>
</ul>


</section>

]]></description>
  <category>ML Systems</category>
  <guid>https://imaddabbura.github.io/posts/mlsys/improving-mlsys-theory-of-constraint.html</guid>
  <pubDate>Sun, 21 Sep 2025 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/mlsys/images/ml-toc.png" medium="image" type="image/png" height="96" width="144"/>
</item>
<item>
  <title>Hard-Learned Lessons in Shipping Software (AI/ML) Projects</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/product-management/shipping-software-projects.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="why-ml-projects-fail-to-ship" class="level2">
<h2 class="anchored" data-anchor-id="why-ml-projects-fail-to-ship">Why ML Projects Fail to Ship</h2>
<p>Some ML projects I’ve worked on shipped six months late. Others shipped and quietly died in production. A few never shipped at all — and those are the ones I keep coming back to. I’ve been through this as an individual contributor and as the person leading the team. For a long time I blamed the usual suspects: unclear requirements, technical debt, underestimating complexity. The real cause, I eventually realized, was more structural — and it looked the same from both seats.</p>
<p>A web feature has a clear definition of done: the button appears, the form submits, the data is saved. An ML feature doesn’t. “Improve recommendation accuracy” could mean another week of training runs, another architecture experiment, another round of feature engineering — indefinitely. Unlike traditional software, where the solution space is bounded by the spec, ML projects have an effectively unbounded search space. Every model can be made larger, every feature set more complete, every training run longer.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    subgraph trad ["Traditional Software"]
        direction LR
        t1["Define Feature"] --&gt; t2["Build"] --&gt; t3["Ship ✓"]
    end
    subgraph ml ["ML Without Constraints"]
        direction LR
        m1["Vague Goal"] --&gt; m2["Experiment #1"]
        m2 --&gt;|"+0.5% accuracy"| m3["Experiment #2"]
        m3 --&gt;|"+0.2% accuracy"| m4["Experiment #3"]
        m4 -.-&gt;|"one more try..."| m2
    end
</pre>
</div>
<p></p><figcaption> Traditional software has a bounded end state. ML projects without defined constraints loop indefinitely — each experiment looks like progress.</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>This produces a predictable failure mode: a project that <em>looks</em> like it’s making progress — models training, experiments running, metrics moving — but never ships.</p>
<p>The root cause, I’ve come to believe, is a category error. ML projects sit uncomfortably between research and engineering. Research is unbounded by design — you keep going until you understand something. Engineering is bounded by design — you keep going until it ships. The teams that deliver consistently have made a deliberate choice about which one they’re doing. The ones that don’t, haven’t — and so they run a research process inside an engineering context, indefinitely.</p>
<p>What follows is what I’ve learned — sometimes from my own failures, sometimes from watching teams I was leading repeat patterns I’d already lived through — about how to actually make that choice.</p>
</section>
<section id="define-the-target-before-writing-a-line-of-code" class="level2">
<h2 class="anchored" data-anchor-id="define-the-target-before-writing-a-line-of-code">Define the Target Before Writing a Line of Code</h2>
<p>The most consistent mistake I’ve made — and watched others make — is starting before the goal is properly defined. It doesn’t feel like a mistake at the time. There’s energy, there’s a general direction, there’s a team ready to move. But you can’t constrain an undefined goal. The first structural requirement for shipping is a precise definition of version one — not the roadmap, not the vision, but version one.</p>
<p>Write it as a falsifiable criterion: <em>“A model that identifies churn risk 14 days in advance with precision ≥ 70% on the holdout set, deployable via the existing prediction service.”</em> That’s a definition. “Improve churn prediction” is not — it’s a direction, and directions don’t ship.</p>
<p>Before writing code, three questions force the definition:</p>
<p><strong>Who is this for, specifically?</strong> A customer-facing recommendation system for mobile users has different input distributions, latency constraints, and acceptable error modes than an internal analyst tool. “Users in general” means nobody in particular, and a system designed for nobody in particular gets spec’d indefinitely.</p>
<p><strong>What does version one accomplish — and what does it explicitly not do?</strong> The second half matters as much as the first. Scope creep in ML is insidious because experiments feel like progress: an additional feature, a new architecture variant, a cleaned edge case — each looks like forward motion. The out-of-scope list is what makes the in-scope list real.</p>
<p><strong>What are the success criteria, written down and falsifiable?</strong> Precision ≥ 0.70. Latency ≤ 100ms at p99. Deployable on the existing serving infrastructure. Criteria that can be verified make it possible to call the project done. Criteria that can’t — “good enough,” “production-ready,” “performs well” — guarantee the project never ends.</p>
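<p>Falsifiable criteria can literally be code. A hypothetical release gate, with names and thresholds purely illustrative, might look like:</p>

```python
# Version one ships only when every written-down criterion passes.
def meets_v1_criteria(metrics):
    criteria = {
        "precision":      lambda v: v >= 0.70,  # on the holdout set
        "p99_latency_ms": lambda v: v <= 100,   # at serving time
        "deployable":     lambda v: v is True,  # existing serving infrastructure
    }
    return all(check(metrics[name]) for name, check in criteria.items())

print(meets_v1_criteria({"precision": 0.73, "p99_latency_ms": 82, "deployable": True}))  # True
```

<p>The point is not the code but the property: anyone on the team can run the gate and get the same answer, which is what makes "done" callable.</p>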
<p>Working backwards from these answers also produces the project structure. Once you know what version one must accomplish, you can enumerate prerequisite questions: <em>What training data do we need? What does the evaluation harness look like? How does it plug into production?</em> Each answer either gets scheduled or gets cut. Vague goals don’t allow this decomposition — they keep the surface area perpetually open.</p>
</section>
<section id="make-time-the-constraint-not-scope" class="level2">
<h2 class="anchored" data-anchor-id="make-time-the-constraint-not-scope">Make Time the Constraint, Not Scope</h2>
<p>The natural instinct is to treat scope as fixed and deadline as flexible. This is exactly backwards.</p>
<p>Scope in an ML project is not fixed — it’s infinitely expandable. There’s always a better architecture to try, a cleaner way to handle edge cases, a feature that might help. Teams treat scope as the constraint because it feels <em>owned</em>: the team wrote the requirements, agreed on them, and changing them feels like abandoning a commitment. Deadlines, by contrast, feel externally imposed and therefore more negotiable — something to push when the work “isn’t ready yet.”</p>
<p>I’ve been in this meeting many times — sometimes as the engineer watching the deadline move, more often as the person responsible for it. The deadline slips, then slips again, then becomes a standing item on the weekly status call.</p>
<p>Flip the constraint. Treat the deadline as fixed and scope as the variable. This changes the question from <em>“when will we be done with everything we planned?”</em> to <em>“what’s the most important thing we can ship by this date?”</em> The second question forces real prioritization. The model that trains in four hours ships; the one that takes 24 hours doesn’t. The feature built on existing infrastructure stays; the one that requires a new data pipeline gets cut.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Deadlines as a Design Tool
</div>
</div>
<div class="callout-body-container callout-body">
<p>A deadline doesn’t dictate quality — it dictates scope. The discipline is specifically about protecting the deadline from <em>scope expansion</em>, not accelerating the work. When a new requirement surfaces mid-project, the question isn’t <em>“can we fit it in?”</em> — it’s <em>“what does it displace?”</em></p>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Version One Is Supposed to Be Small
</div>
</div>
<div class="callout-body-container callout-body">
<p>Version one of most production models is smaller, faster, and more constrained than anything the team initially imagined. Good. The goal of version one isn’t to build the best possible system — it’s to establish the deployment path, validate production integration, and generate real usage data. The best possible system comes later, built on what version one teaches you.</p>
</div>
</div>
</section>
<section id="decompose-until-done-is-unambiguous" class="level2">
<h2 class="anchored" data-anchor-id="decompose-until-done-is-unambiguous">Decompose Until Done Is Unambiguous</h2>
<p>Once you have a target and a deadline, break the project into tasks — not work items, not epics, tasks. The distinction matters: a task has an unambiguous definition of done. A project doesn’t.</p>
<p>“Train a production NLP model” is a project. Tasks are:</p>
<ul>
<li><em>“Label 500 training examples from the January logs”</em> — done or not done.</li>
<li><em>“Achieve F1 ≥ 0.82 on the validation split”</em> — done or not done.</li>
<li><em>“Write the endpoint that accepts raw text and returns a classification”</em> — done or not done.</li>
</ul>
<p>If you can’t tell whether a piece of work is finished without discussion, break it down further.</p>
<p>With tasks in hand, ruthlessly prioritize against the version one criteria. Not everything is equally important, and pretending otherwise is how projects stall:</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Category</th>
<th>Description</th>
<th>Rule</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Must-have</strong></td>
<td>System cannot ship without this</td>
<td>Do first, never cut</td>
</tr>
<tr class="even">
<td><strong>Should-have</strong></td>
<td>Meaningfully improves the product</td>
<td>Include if time allows</td>
</tr>
<tr class="odd">
<td><strong>Nice-to-have</strong></td>
<td>Incremental gain, no blocker</td>
<td>Version two</td>
</tr>
<tr class="even">
<td><strong>Gold-plating</strong></td>
<td>No clear user benefit</td>
<td>Cut immediately</td>
</tr>
</tbody>
</table>
<p>The failure mode is treating “should-haves” as “must-haves.” It happens because, deep down, the team doesn’t believe version two is coming. If this feels like the only shot, every improvement feels essential. But that’s exactly backwards: version two only exists because version one shipped. Holding version one hostage to version two’s requirements is how you guarantee neither does.</p>
</section>
<section id="validate-the-core-assumption-before-building-the-system" class="level2">
<h2 class="anchored" data-anchor-id="validate-the-core-assumption-before-building-the-system">Validate the Core Assumption Before Building the System</h2>
<p>Every ML project rests on a single load-bearing assumption: <em>“Does a model trained on this data actually produce useful predictions for this problem?”</em> Everything else — the serving infrastructure, the feature pipeline, the retraining loop, the monitoring dashboard — only matters if the answer is yes.</p>
<p>I’ve fallen into this trap early in my career, and led teams into it later. The pattern is always the same: stand up a feature store, design a training pipeline, architect a serving layer — then train the model and discover the data doesn’t support the prediction task, or the signal is too weak, or the problem is better solved without ML at all. Months of infrastructure work, none of it applicable to the revised approach. The infrastructure trap is just as easy to fall into when you’re the one setting the direction as when you’re the one doing the building.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    subgraph wrong ["❌ Common Mistake"]
        direction LR
        w1["Feature Store"] --&gt; w2["Training Pipeline"] --&gt; w3["Model Registry"] --&gt; w4["Model"] --&gt; w5["Works?"]
    end
    subgraph right ["✓ Correct Order"]
        direction LR
        r1["Validate\nApproach"] --&gt; r2["Establish\nDeploy Path"] --&gt; r3["Build\nInfrastructure"] --&gt; r4["Ship ✓"]
    end
</pre>
</div>
<p></p><figcaption> The infrastructure trap: building the full system before validating the approach. The correct order validates cheaply first, then invests in infrastructure.</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>The minimum viable experiment: train a simple baseline on a slice of the data, evaluate it against a manually-labeled holdout, and show the results to at least one person who’d actually use the output. Logistic regression, a small neural net, a fine-tuned pretrained model — whatever takes days, not months. If the results are promising, the infrastructure investment is justified. If not, you’ve learned the most important thing about the project for the cost of two weeks.</p>
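<p>The shape of such a baseline, sketched with synthetic data standing in for the real slice (every name and number here is illustrative):</p>

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for "a slice of the data": 5 features, binary label.
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=2000) > 0).astype(int)

X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.25, random_state=0
)
baseline = LogisticRegression().fit(X_train, y_train)
precision = precision_score(y_holdout, baseline.predict(X_holdout))

# If a days-not-months baseline clears the bar, infrastructure is justified.
print(f"holdout precision: {precision:.2f}")
```

<p>Swap in the real data slice and the real success metric; the structure — simple model, held-out evaluation, explicit bar — is the whole experiment.</p>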
<p>This also determines tool choice during validation. <code>scikit-learn</code>, <code>PyTorch</code>, pre-trained transformers from HuggingFace — these represent thousands of engineering hours and are battle-tested at scale. Custom architectures and bespoke training loops are justified when profiling data shows standard tools can’t meet your requirements. That data doesn’t exist before validation. Building custom infrastructure before validating the approach is the fastest way to spend six months on something nobody uses.</p>
</section>
<section id="ship-then-compound" class="level2">
<h2 class="anchored" data-anchor-id="ship-then-compound">Ship, Then Compound</h2>
<p>Once version one meets the criteria, ship it — even if it’s not perfect.</p>
<p>Every model I’ve shipped has surprised me in production. Not because the evaluation was wrong, but because it was measuring the wrong things. Your holdout set measures what you measured. Real users do things you didn’t anticipate — edge cases you didn’t label, inputs from distributions you didn’t sample, and above all, they surface which errors actually matter. A model that’s 92% accurate on the evaluation set might be systematically wrong on the 8% of inputs that are disproportionately important to users. You won’t know that until the model is deployed.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A["Ship\nImperfect v1"] --&gt; B["Real\nUsage Data"]
    B --&gt; C["Discover\nActual Failures"]
    C --&gt; D["Targeted\nFixes"]
    D --&gt; E["Ship\nBetter v2"]
    E --&gt; B
</pre>
</div>
<p></p><figcaption> The iteration flywheel: each shipped version surfaces real failures that targeted improvements address, compounding over time.</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>Ship versions that meet the bar, not versions that approach some imagined ceiling. Version one will be wrong in ways you didn’t anticipate — I’ve never shipped one that wasn’t, and I’ve never led a team that did. One of the harder things about leading engineers through this is convincing them that shipping something imperfect isn’t a compromise — it’s the whole point. The failures you discover in production are the ones that matter. Find them early, when fixing them is fast, not late, when the system is load-bearing and everything is entangled.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Each Version Enables the Next
</div>
</div>
<div class="callout-body-container callout-body">
<p>Real usage reveals the failures that matter — not the ones you hypothesized in the design doc, but the ones users actually encounter. Their feedback tells you which improvements are worth making. Infrastructure built for version one scales to version two. The teams that ship consistently aren’t the ones with better planning processes — they’re the ones who’ve completed more cycles of this loop.</p>
</div>
</div>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways">Key Takeaways</h2>
<ol type="1">
<li><p><strong>Decide whether you’re doing research or engineering before you start.</strong> ML projects that don’t make this distinction run research processes in engineering contexts — indefinitely.</p></li>
<li><p><strong>Define version one as a falsifiable criterion.</strong> Precision ≥ X. Latency ≤ Y. Deployable on Z. Criteria that can’t be verified guarantee the project never ends.</p></li>
<li><p><strong>Treat deadline as fixed, scope as variable.</strong> The question is always: <em>“What’s the most important thing we can ship by this date?”</em></p></li>
<li><p><strong>Decompose until done is unambiguous.</strong> If you can’t tell whether a task is finished without discussion, it’s not a task — it’s a project.</p></li>
<li><p><strong>Validate the core assumption before building infrastructure.</strong> Does the model work on this data? Answer that first, with the simplest possible tools. Everything else comes after.</p></li>
<li><p><strong>Ship the imperfect version.</strong> Offline evaluation measures what you measured. Real usage reveals what you missed. Each shipped version enables the next.</p></li>
</ol>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Machine Learning</category>
  <category>Deep Learning</category>
  <category>Software Engineering</category>
  <guid>https://imaddabbura.github.io/posts/product-management/shipping-software-projects.html</guid>
  <pubDate>Sun, 05 Jan 2025 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/product-management/shipping-software-projects.png" medium="image" type="image/png" height="84" width="144"/>
</item>
<item>
  <title>From Forgetting to Fluency: How to Learn Smarter, Not Harder</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/personal-growth/notes-on-learning.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>In today’s fast-paced world, the ability to learn effectively and retain information has become more crucial than ever. Whether you’re a student preparing for exams, a professional mastering new skills, or simply someone seeking personal growth, finding ways to optimize learning and improve memory recall can make a significant difference in achieving your goals.</p>
<p>Fortunately, research in neuroscience and cognitive psychology has shed light on strategies that enhance how we absorb and retain knowledge. There are numerous methods to make learning not only more efficient but also more enjoyable. This article explores a variety of evidence-based approaches to optimize learning and strengthen recall and hopefully you’ll find actionable insights to elevate your learning journey.</p>
</section>
<section id="learning-techniques" class="level2">
<h2 class="anchored" data-anchor-id="learning-techniques">Learning Techniques</h2>
<ul>
<li>Effective instruction should match the content, not the learner’s preferred style. For example, cooking instruction should use hands-on practice even if the student is a visual learner.</li>
<li>Learning means that a change has been made to long-term memory.</li>
<li>Human memory is not as precise or reliable as computer memory. It behaves as <code>read-and-update</code>: recalling information strengthens and modifies what is fetched, especially if it was learned recently.</li>
<li>Information is stored in interconnected neural pathways. When we try to access a target piece of information, we activate a pathway of neurons, which spreads activation to other connected pathways that may be unrelated to the target. This spreading activation leaves related pathways primed for hours. As a result:
<ul>
<li>Spreading activation causes related but imprecise information to be conflated with the target information, which makes recall unreliable.</li>
<li>Because pathways stay primed for hours, stepping away to work on something else, go for a walk, or take a shower can help with problem solving: the two seemingly unrelated areas connect in the middle.</li>
</ul></li>
<li>There are two types of memory:
<ul>
<li><strong>Long-term memory</strong>, where information is stored permanently and capacity is effectively unlimited. It is analogous to disk storage.</li>
<li><strong>Working (short-term) memory</strong> is used to solve problems. It is analogous to CPU’s registers. The bigger the working memory, the faster we can learn. It is roughly fixed at birth.</li>
</ul></li>
<li><strong>Chunking</strong> is when we relate information together as one piece. The more we combine information as one piece (chunk), the easier it is to reason about and solve problems related to it. This is due to the fact that we can store pointers to such chunks in the working memory and access such chunks in long-term memory if needed. Therefore, it is critical to decompose difficult tasks into smaller pieces (chunks) when learning, which later will be chunked together as we practice.</li>
<li>The difference between experts and beginners is that experts remember and recognize patterns to help them solve problems. However, beginners read code line by line to understand what it is doing or how to approach solving problems. Therefore, to achieve proficiency in programming, you need to read/write and work with a lot of code to be exposed to more patterns as well as programming using different programming paradigms/languages.</li>
<li>To understand a concept, we may need to go from the abstract to a diverse set of concrete examples and back to the abstract. This helps with chunking: we treat all the concrete examples as different views of the abstract concept. Once we understand the concrete examples, we can connect them back to the abstract concept.</li>
<li><strong>Spacing</strong> and <strong>repetition</strong> are key to learning. We learn problem-solving concepts best by spacing practice across multiple sessions, multiple days, and ideally multiple weeks. Practice helps us connect the text of a problem to the underlying concept and apply that concept to solve the problem.</li>
<li>Concentrating for more than <strong>90 minutes</strong> is hard because of shifts in the brain’s neurochemical balance. It is best to rest, sleep, or walk after those 90 minutes so the information gets consolidated; don’t switch to other tasks, talk to others, or browse the internet.</li>
<li>Even though we can look information up on the internet, it is worth understanding the concepts we deal with frequently so the brain can form connections that support deeper understanding. It is also much better to try to recall information from long-term memory than to search for it, especially if we are not yet experts.</li>
<li>Problem-solving is not a generic skill; it is domain-specific. A good chess player may not be a good problem-solver in programming or other domains. To get better at programming problem-solving, practice solving programming problems.</li>
<li>There is no clear predictor of programming ability other than experience.</li>
<li>A growth mindset, learning to overcome setbacks and failures, and practice are what you need to succeed in your career. Keep evaluating your learning strategies to get the best outcomes.</li>
</ul>
</section>
<section id="recommendations" class="level2">
<h2 class="anchored" data-anchor-id="recommendations">Recommendations</h2>
<ul>
<li>For recruiting:
<ul>
<li>There are no good proxies for programming ability; look at candidates’ previous work or test them on authentic programming tasks.</li>
<li>At least among young developers, years of experience may not be a very reliable measure of ability.</li>
</ul></li>
<li>For learning and training:
<ul>
<li>Reading a lot of code will help you become a more efficient programmer.</li>
<li>Experts are not always the best at training beginners.</li>
<li>Learning takes time, including time between learning sessions. Intense cramming is not effective, but spaced repetition is.</li>
<li>Similarly, spending time away from a problem can help to solve it.</li>
<li>Just because you can find it through an Internet search or generative AI tool does not mean learning has become obsolete.</li>
<li>Use examples to go between abstract concepts and concrete learnable facts.</li>
<li>Seeking to succeed (rather than avoid failure) and believing that ability is changeable are important factors in resilience and learning.</li>
</ul></li>
</ul>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>As we’ve explored various approaches to optimize learning and enhance recall, it’s clear that the journey to effective learning is both dynamic and multifaceted. By incorporating evidence-based strategies such as spaced repetition and active engagement, individuals can significantly improve their ability to absorb and retain information.</p>
<p>Embracing a growth mindset and being open to adapting your strategies will further empower you to navigate the complexities of acquiring knowledge. In conclusion, by implementing these innovative approaches, you can transform your learning experience into a more productive and rewarding endeavor. Start applying these insights today and watch as your ability to learn and recall information flourishes!</p>
</section>
<section id="further-reading" class="level2">
<h2 class="anchored" data-anchor-id="further-reading">Further Reading</h2>
<ul>
<li><em>Why Don’t Students Like School?</em> by Daniel T. Willingham provides a short and readable explanation of many of the principles of memory and how the brain works.</li>
<li><em>The Programmer’s Brain</em> by Felienne Hermans relates these concepts to programming and describes how techniques for learning and revision that are used at school can still apply to professional development.</li>
<li><em>How Learning Happens: Seminal Works in Educational Psychology and What They Mean in Practice</em> by Paul A. Kirschner and Carl Hendrick provides a tour through influential papers, explaining them in plain language and the implications and linkages between them.</li>
<li><a href="https://cacm.acm.org/magazines/2024/1/278891-10-things-software-developers-should-learn-about-learning/fulltext"><em>10 Things Software Developers Should Learn about Learning</em></a> by Neil C. C. Brown, Felienne F. J. Hermans, and Lauren E. Margulieux.</li>
</ul>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Problem Solving</category>
  <guid>https://imaddabbura.github.io/posts/personal-growth/notes-on-learning.html</guid>
  <pubDate>Fri, 13 Sep 2024 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/personal-growth/learning.svg" medium="image" type="image/svg+xml"/>
</item>
<item>
  <title>Why Your Final Layer Shouldn’t Have Softmax</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/dl/why-not-softmax.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="a-common-mistake-thats-hard-to-see" class="level2">
<h2 class="anchored" data-anchor-id="a-common-mistake-thats-hard-to-see">A Common Mistake That’s Hard to See</h2>
<p>If you’ve built a classifier in PyTorch, you’ve probably seen <code>nn.Softmax</code> and <code>nn.CrossEntropyLoss</code> in the same codebase. You may have even used them together — softmax at the end of the model, cross-entropy as the loss. The code runs, the loss decreases, the model converges. Everything looks fine.</p>
<p>But something is wrong. <code>nn.CrossEntropyLoss</code> already applies softmax internally. Adding softmax as the model’s final layer means it is applied twice, once by the model and once inside the loss, and backpropagation then computes the gradients of the wrong function. The model still learns, just more slowly, less stably, and to a worse optimum.</p>
<p>This post unpacks <em>why</em> — starting with what softmax actually does, then working through the numerical stability mechanism that motivates keeping raw logits, and finishing with a clear picture of when softmax belongs and when it doesn’t.</p>
</section>
<section id="what-softmax-does" class="level2">
<h2 class="anchored" data-anchor-id="what-softmax-does">What Softmax Does</h2>
<p>The softmax function takes a vector of raw scores — <strong>logits</strong> — and squashes them into a probability distribution:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bsoftmax%7D(z_i)%20=%20%5Cfrac%7Be%5E%7Bz_i%7D%7D%7B%5Csum_j%20e%5E%7Bz_j%7D%7D"></p>
<p>The outputs are in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D"> and sum to 1. For a ten-class classifier, softmax turns a vector like <img src="https://latex.codecogs.com/png.latex?%5B2.1,%5C%20-0.3,%5C%200.8,%5C%20%5Cldots%5D"> into a proper probability distribution over the ten classes. This seems like exactly the right thing to do before computing a loss that expects probabilities.</p>
<p>The problem isn’t what softmax does. It’s <em>where</em> you do it — and whether the operation downstream already does it better.</p>
</section>
<section id="the-log-sum-exp-trick" class="level2">
<h2 class="anchored" data-anchor-id="the-log-sum-exp-trick">The Log-Sum-Exp Trick</h2>
<p>To understand why <code>CrossEntropyLoss</code> wants raw logits, we need to look at what it computes. Cross-entropy loss for the true class <img src="https://latex.codecogs.com/png.latex?y"> is:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BL%7D%20=%20-%5Clog%5Cleft(%5Cfrac%7Be%5E%7Bz_y%7D%7D%7B%5Csum_j%20e%5E%7Bz_j%7D%7D%5Cright)%20=%20-z_y%20+%20%5Clog%5Csum_j%20e%5E%7Bz_j%7D"></p>
<p>The second term — <img src="https://latex.codecogs.com/png.latex?%5Clog%5Csum_j%20e%5E%7Bz_j%7D"> — is the <strong>log-sum-exp (LSE)</strong>, and it’s numerically dangerous. If any logit is large, the exponent overflows to <code>inf</code> before the log can bring it back down:</p>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-2">z <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor([<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1001.0</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1002.0</span>])</span>
<span id="cb1-3">torch.softmax(z, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># → tensor([nan, nan, nan])</span></span></code></pre></div>
<p>The standard fix is the <strong>log-sum-exp trick</strong>: subtract the maximum logit before exponentiating.</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Clog%5Csum_j%20e%5E%7Bz_j%7D%20=%20c%20+%20%5Clog%5Csum_j%20e%5E%7Bz_j%20-%20c%7D,%20%5Cquad%20c%20=%20%5Cmax_j%20z_j"></p>
<p>Subtracting <img src="https://latex.codecogs.com/png.latex?c%20=%20%5Cmax_j%20z_j"> keeps all terms in <img src="https://latex.codecogs.com/png.latex?%5Be%5E%7B-%5Cinfty%7D,%5C%201%5D"> — never overflowing, never underflowing. The mathematical result is identical; the numerical result is stable.</p>
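<p>A minimal sketch makes the difference concrete, using the same overflow-prone logits as above. The shifted computation agrees with PyTorch’s built-in <code>torch.logsumexp</code>:</p>

```python
import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive log-sum-exp: exp(1000) overflows float32 to inf, so the result is inf.
naive = torch.log(torch.exp(z).sum())

# Shifted version: subtract the max first, so every exponent is <= 0.
c = z.max()
stable = c + torch.log(torch.exp(z - c).sum())

print(naive)   # tensor(inf)
print(stable)  # tensor(1002.4076), matches torch.logsumexp(z, dim=0)
```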
<p>This is exactly what <code>nn.CrossEntropyLoss</code> does internally. It doesn’t apply softmax and then compute cross-entropy — it <strong>fuses both operations</strong> into one numerically stable pass using the LSE trick. Passing raw logits is what makes this possible.</p>
<p>If you apply softmax first, the loss function receives <img src="https://latex.codecogs.com/png.latex?p_i%20=%20e%5E%7Bz_i%7D/%5Csum%20e%5E%7Bz_j%7D"> instead of raw logits and then applies its own log-softmax to those values — effectively computing <img src="https://latex.codecogs.com/png.latex?%5Clog(%5Ctext%7Bsoftmax%7D(%5Ctext%7Bsoftmax%7D(z)))">. The numbers are wrong and the gradients are wrong.</p>
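<p>A quick sanity check, here sketched with random logits, shows the mismatch directly: feeding pre-softmaxed values to <code>nn.CrossEntropyLoss</code> yields a different loss (and therefore different gradients) than feeding raw logits:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10)           # batch of 4, 10 classes
targets = torch.tensor([0, 3, 7, 2])

loss_fn = nn.CrossEntropyLoss()
correct = loss_fn(logits, targets)                         # raw logits: correct
double = loss_fn(torch.softmax(logits, dim=-1), targets)   # softmax applied twice

# The two losses differ: the second optimizes the wrong objective.
```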
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    subgraph bad ["❌ Double Application"]
        direction LR
        z1["Logits z"] --&gt; s1["nn.Softmax"] --&gt; p1["Probs p"] --&gt; ce1["CrossEntropyLoss\nlog-softmax(p)"]
    end
    subgraph good ["✅ Correct"]
        direction LR
        z2["Logits z"] --&gt; ce2["CrossEntropyLoss\nlog-softmax(z) — fused, stable"]
    end
</pre>
</div>
<p></p><figcaption> Left: pre-applying softmax breaks the fused computation, producing gradients of the wrong function. Right: raw logits let CrossEntropyLoss apply the log-sum-exp trick internally.</figcaption> </figure><p></p>
</div>
</div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Practical Rule
</div>
</div>
<div class="callout-body-container callout-body">
<p><code>nn.CrossEntropyLoss</code> (PyTorch) and <code>tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)</code> (TensorFlow) both expect <strong>raw logits</strong>. The loss handles the stable, fused computation internally. Don’t apply softmax to the final layer of a classifier.</p>
</div>
</div>
</section>
<section id="multi-label-classification-the-wrong-prior" class="level2">
<h2 class="anchored" data-anchor-id="multi-label-classification-the-wrong-prior">Multi-Label Classification: The Wrong Prior</h2>
<p>Softmax enforces <strong>competition</strong> between classes: increasing one class’s probability necessarily decreases the others. This is the correct structure for single-label tasks — exactly one class is true — and entirely the wrong structure for multi-label tasks, where multiple classes can be true simultaneously.</p>
<p>Consider a document classifier that assigns topics like “machine learning,” “software engineering,” and “career advice.” A document can belong to all three. Softmax forces these to compete: pushing “machine learning” up automatically pushes the others down. The model is fighting its own output structure.</p>
<p>There’s a deeper problem. Because softmax outputs always sum to 1, <strong>the model is structurally forced to predict high confidence for exactly one class</strong> — regardless of the input. If an image contains no objects from the training categories, softmax still redistributes its probability mass across the classes and picks a winner. If an image contains three objects, softmax still collapses to one. It has no way to say “multiple things are present” or “nothing relevant is here” — the sum-to-one constraint makes both answers impossible.</p>
<p>For multi-label classification, the correct output is <strong>sigmoid</strong>, applied independently per class:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Csigma(z_i)%20=%20%5Cfrac%7B1%7D%7B1%20+%20e%5E%7B-z_i%7D%7D"></p>
<p>Each output is an independent probability in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D"> with no constraint that they sum to 1. Use <code>nn.BCEWithLogitsLoss</code> — which applies sigmoid internally with the same kind of numerical stability fusion — rather than sigmoid in the model followed by <code>nn.BCELoss</code>.</p>
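<p>As a minimal sketch of the multi-label setup (the three-topic labels here are illustrative), the model emits raw logits, the loss fuses the sigmoid internally, and thresholding happens per class at inference:</p>

```python
import torch
import torch.nn as nn

# Hypothetical 3-topic document: "ML" and "career advice" true, "SWE" false.
logits = torch.tensor([[2.0, -1.0, 0.5]])    # raw logits, no sigmoid in the model
targets = torch.tensor([[1.0, 0.0, 1.0]])

loss_fn = nn.BCEWithLogitsLoss()             # fuses sigmoid + BCE stably
loss = loss_fn(logits, targets)

# At inference, apply sigmoid per class and threshold each label independently.
probs = torch.sigmoid(logits)
predicted = probs > 0.5                      # [[True, False, True]]
```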
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Choosing the Right Output Layer
</div>
</div>
<div class="callout-body-container callout-body">
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Task</th>
<th>Loss function</th>
<th>Notes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Single-label classification</td>
<td><code>nn.CrossEntropyLoss</code></td>
<td>Expects raw logits; applies log-softmax internally</td>
</tr>
<tr class="even">
<td>Multi-label classification</td>
<td><code>nn.BCEWithLogitsLoss</code></td>
<td>Expects raw logits; applies sigmoid internally</td>
</tr>
<tr class="odd">
<td>Binary classification</td>
<td><code>nn.BCEWithLogitsLoss</code></td>
<td>Same as above</td>
</tr>
<tr class="even">
<td>Probabilities at inference</td>
<td>Apply <code>softmax</code> <em>after</em> training</td>
<td>Not during training</td>
</tr>
</tbody>
</table>
</div>
</div>
</section>
<section id="softmax-and-overconfidence" class="level2">
<h2 class="anchored" data-anchor-id="softmax-and-overconfidence">Softmax and Overconfidence</h2>
<p>Softmax is sensitive to the <strong>scale</strong> of the logits, not just their relative ordering. Logits <img src="https://latex.codecogs.com/png.latex?%5B3,%5C%201,%5C%200%5D"> and <img src="https://latex.codecogs.com/png.latex?%5B300,%5C%20100,%5C%200%5D"> produce the same ranking but very different softmax outputs — the scaled version concentrates nearly all probability mass on the top class. As training progresses, logit magnitudes tend to grow, and softmax increasingly exaggerates these differences.</p>
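<p>The scale sensitivity is easy to see directly; the two logit vectors below have the same ranking but differ in scale by 100x:</p>

```python
import torch

a = torch.tensor([3.0, 1.0, 0.0])
b = torch.tensor([300.0, 100.0, 0.0])   # same ranking, 100x the scale

pa = torch.softmax(a, dim=0)   # ~[0.84, 0.11, 0.04]: moderate confidence
pb = torch.softmax(b, dim=0)   # ~[1.00, 0.00, 0.00]: near-certain top class
```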
<p>The result is systematic overconfidence: a model that outputs near-100% probability on examples it gets wrong. The <a href="https://arxiv.org/abs/1706.04599">Guo et al.&nbsp;2017 calibration paper</a> showed this is a consistent property of modern neural networks, not a training artifact.</p>
<p>The standard fix is <strong>temperature scaling</strong>: divide logits by a learned scalar <img src="https://latex.codecogs.com/png.latex?T%20%3E%201"> before applying softmax at inference time.</p>
<p><img src="https://latex.codecogs.com/png.latex?p_i%20=%20%5Ctext%7Bsoftmax%7D(z_i%20/%20T)"></p>
<p><img src="https://latex.codecogs.com/png.latex?T%20%3E%201"> flattens the distribution (less confident); <img src="https://latex.codecogs.com/png.latex?T%20%3C%201"> sharpens it. <img src="https://latex.codecogs.com/png.latex?T"> is fit on a held-out validation set after training finishes. Crucially, this only works if the model was trained on raw logits — the scale information that temperature scaling adjusts is preserved through training and only consumed at inference.</p>
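<p>Below is a minimal sketch of fitting the temperature on held-out data by minimizing negative log-likelihood. The synthetic setup is an assumption for illustration; Guo et al.&nbsp;fit with L-BFGS, but the optimizer choice is incidental:</p>

```python
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_labels, steps=300, lr=0.05):
    """Fit a single scalar temperature by minimizing validation NLL."""
    log_t = torch.zeros(1, requires_grad=True)   # parameterize T = exp(log_t) > 0
    opt = torch.optim.Adam([log_t], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        F.cross_entropy(val_logits / log_t.exp(), val_labels).backward()
        opt.step()
    return log_t.exp().item()

# Synthetic check: labels are drawn from softmax(z), but the "model" reports
# 5 * z, i.e. it is overconfident by a factor of 5. T should land near 5.
torch.manual_seed(0)
z = torch.randn(2000, 10)
labels = torch.distributions.Categorical(logits=z).sample()
T = fit_temperature(5.0 * z, labels)
```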
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Calibration Connection
</div>
</div>
<div class="callout-body-container callout-body">
<p>Post-hoc calibration methods (temperature scaling, Platt scaling, isotonic regression) all operate on the raw logit magnitudes that accumulate through training. If your output layer applies softmax during training, the scale information is destroyed before calibration is attempted — the calibration methods have nothing useful to fit.</p>
</div>
</div>
</section>
<section id="when-softmax-belongs" class="level2">
<h2 class="anchored" data-anchor-id="when-softmax-belongs">When Softmax Belongs</h2>
<p>Removing softmax from the final classification layer doesn’t mean it’s always wrong — it means the structure it imposes (mutual exclusivity, sum-to-one) has to match what the computation actually needs.</p>
<p><strong>Attention mechanisms.</strong> The scaled dot-product attention in Transformers applies softmax to produce a distribution over positions. This is exactly right: each query should distribute its weight across keys, and the competition structure is intentional. There’s no fused loss downstream computing log-softmax again.</p>
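<p>For contrast, here is a minimal sketch of scaled dot-product attention, where the softmax is exactly right: each query’s weights over key positions sum to 1 by design, and no downstream loss re-applies it:</p>

```python
import math
import torch

def attention(q, k, v):
    """Scaled dot-product attention: softmax over key positions per query."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    weights = torch.softmax(scores, dim=-1)   # each query's weights sum to 1
    return weights @ v, weights

torch.manual_seed(0)
q = torch.randn(2, 4, 8)   # (batch, queries, dim)
k = torch.randn(2, 6, 8)   # (batch, keys, dim)
v = torch.randn(2, 6, 8)
out, w = attention(q, k, v)  # out: (2, 4, 8); w rows sum to 1
```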
<p><strong>Contrastive learning.</strong> Methods like CLIP apply softmax across the batch as part of the contrastive loss. The within-batch competition is the learning signal.</p>
<p><strong>Inference-time probabilities.</strong> If downstream code requires calibrated probabilities — confidence thresholds, ensemble averaging, displaying to users — apply softmax to the final logits after the forward pass, outside the model:</p>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">with</span> torch.no_grad():</span>
<span id="cb2-2">    logits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x)</span>
<span id="cb2-3">    probs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.softmax(logits, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span></code></pre></div>
<p>The pattern: softmax belongs when the distribution semantics genuinely fit the computation, and when nothing downstream is already computing a fused version of it.</p>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways">Key Takeaways</h2>
<ol type="1">
<li><p><strong>Don’t apply softmax in your model’s final layer for classification.</strong> <code>nn.CrossEntropyLoss</code> expects raw logits and applies a fused, numerically stable log-softmax internally using the log-sum-exp trick. Pre-applying softmax computes gradients of the wrong function.</p></li>
<li><p><strong>The numerical instability is real and silent.</strong> Large logits overflow naive softmax — you get <code>nan</code> losses and corrupted gradients, often without a clear error. The fused implementation avoids this entirely.</p></li>
<li><p><strong>Multi-label tasks need sigmoid, not softmax.</strong> Softmax enforces mutual exclusivity. For tasks where multiple labels are simultaneously valid, use <code>nn.BCEWithLogitsLoss</code> with raw logits.</p></li>
<li><p><strong>Overconfidence is a logit scale problem.</strong> Softmax exaggerates differences as magnitudes grow through training. Temperature scaling is the standard fix — but only if raw logit scale is preserved through training.</p></li>
<li><p><strong>Softmax has legitimate uses.</strong> Attention weights, contrastive losses, and inference-time probability outputs are correct applications. The question is always whether competition semantics fit the problem, and whether a fused stable implementation already handles the math downstream.</p></li>
</ol>
</section>
<section id="resources" class="level2">
<h2 class="anchored" data-anchor-id="resources">Resources</h2>
<ol type="1">
<li><a href="https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html"><strong>PyTorch Documentation — CrossEntropyLoss</strong></a> — Documents why raw logits are expected and how log-softmax is fused internally.</li>
<li><a href="https://arxiv.org/abs/1706.04599"><strong>On Calibration of Modern Neural Networks</strong></a> — Guo et al.&nbsp;on systematic softmax overconfidence and temperature scaling as the practical fix.</li>
<li><a href="https://www.deeplearningbook.org/contents/mlp.html"><strong>Deep Learning Book — Chapter 6</strong></a> — Goodfellow et al.&nbsp;on output units and loss function design for classification.</li>
</ol>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Deep Learning</category>
  <guid>https://imaddabbura.github.io/posts/dl/why-not-softmax.html</guid>
  <pubDate>Sun, 09 Jun 2024 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/dl/images/softmax-img.png" medium="image" type="image/png" height="79" width="144"/>
</item>
<item>
  <title>Cutting the Fat: A Practical Guide to Neural Network Pruning</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/efficient-ml/pruning-dl-models.html</link>
  <description><![CDATA[ 





<p>Neural network pruning is a critical optimization technique used to enhance the efficiency of deep learning models by systematically removing unnecessary parameters, such as weights or neurons, while maintaining model performance. This technique is particularly important because memory access and movement are extremely expensive operations in terms of both latency and energy consumption.</p>
<section id="why-do-we-need-pruning" class="level2">
<h2 class="anchored" data-anchor-id="why-do-we-need-pruning">Why Do We Need Pruning?</h2>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="pruning-optimization.png" class="lightbox" data-gallery="quarto-lightbox-gallery-1"><img src="https://imaddabbura.github.io/posts/efficient-ml/pruning-optimization.png" class="quarto-figure quarto-figure-center figure-img" width="800" height="500"></a></p>
</figure>
</div>
<p>The primary objective of pruning can be formalized as minimizing a loss function <img src="https://latex.codecogs.com/png.latex?L(W_P)">, where <img src="https://latex.codecogs.com/png.latex?W"> represents the original weights and <img src="https://latex.codecogs.com/png.latex?W_P"> the pruned weights, subject to the constraint that the number of non-zero weights (<img src="https://latex.codecogs.com/png.latex?%7C%7CW_P%7C%7C_0">) stays below a threshold <img src="https://latex.codecogs.com/png.latex?N">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Carg%5Cmin_%7BW_P%7D%20L(x,%20W_P)"></p>
<p>subject to <img src="https://latex.codecogs.com/png.latex?%7C%7CW_P%7C%7C_0%20%3C%20N"></p>
<p>This optimization leads to sparse weight matrices, which can significantly reduce:</p>
<ul>
<li>Model size</li>
<li>Memory footprint</li>
<li>Computational complexity</li>
<li>Energy consumption</li>
</ul>
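<p>As an illustration of the simplest instance of this objective, magnitude-based (L1) pruning zeroes the smallest-magnitude weights; PyTorch ships a utility for this in <code>torch.nn.utils.prune</code>. A minimal sketch on a throwaway layer:</p>

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(128, 64)   # throwaway layer for illustration

# Magnitude (L1) pruning: zero out the 50% of weights with smallest |w|.
prune.l1_unstructured(layer, name="weight", amount=0.5)

sparsity = (layer.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.0%}")   # → sparsity: 50%
```

<p>The mask can be made permanent with <code>prune.remove(layer, "weight")</code>, which folds the zeros into the weight tensor itself.</p>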
</section>
<section id="types-of-pruning" class="level2">
<h2 class="anchored" data-anchor-id="types-of-pruning">Types of Pruning</h2>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="pruning-weight-vs-activation.png" class="lightbox" data-gallery="quarto-lightbox-gallery-2"><img src="https://imaddabbura.github.io/posts/efficient-ml/pruning-weight-vs-activation.png" class="quarto-figure quarto-figure-center figure-img" width="800" height="500"></a></p>
</figure>
</div>
<ol type="1">
<li><strong>Weight Pruning</strong>:</li>
</ol>
<ul>
<li>Focuses on removing connections between neurons</li>
<li>Reduces model size and computational complexity</li>
<li>Based on weight importance metrics</li>
<li>May require fine-tuning after pruning</li>
<li>Results in sparse weight matrices <a href="https://www.researchgate.net/publication/318471114_Activation_Pruning_of_Deep_Convolutional_Neural_Networks#:~:text=Activation%20Pruning%20of">[1]</a></li>
</ul>
<ol start="2" type="1">
<li><strong>Activation Pruning</strong>:</li>
</ol>
<ul>
<li>Removes entire neurons or channels</li>
<li>Reduces computational cost more directly</li>
<li>Based on activation importance</li>
<li>Can even reduce misclassification error relative to the unpruned network <a href="https://www.researchgate.net/publication/318471114_Activation_Pruning_of_Deep_Convolutional_Neural_Networks#:~:text=Activation%20Pruning%20of">[1]</a></li>
</ul>
</section>
<section id="pruning-approaches" class="level2">
<h2 class="anchored" data-anchor-id="pruning-approaches">Pruning Approaches</h2>
<section id="only-pruning" class="level3">
<h3 class="anchored" data-anchor-id="only-pruning">1. Only Pruning</h3>
<ul>
<li>Simplest approach: directly remove weights without additional steps</li>
<li>Often results in significant accuracy drop</li>
<li>Not recommended for production systems</li>
</ul>
</section>
<section id="pruning-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="pruning-fine-tuning">2. Pruning + Fine-tuning</h3>
<ul>
<li>Prune weights followed by model fine-tuning</li>
<li>Helps recover accuracy lost during pruning</li>
<li>More effective than pruning alone <a href="https://openreview.net/pdf?id=Cb54AMqHQFP#:~:text=In%20this%20paper%2C%20we,is%20fine%2Dtuning%2C%20which%20aims">[2]</a></li>
</ul>
</section>
<section id="iterative-pruning-fine-tuning" class="level3">
<h3 class="anchored" data-anchor-id="iterative-pruning-fine-tuning">3. Iterative Pruning + Fine-tuning</h3>
<ul>
<li>Gradually removes weights over multiple steps</li>
<li>Each step is less aggressive than the previous</li>
<li>Includes fine-tuning between pruning steps</li>
<li>Achieves best accuracy with high pruning ratios (90%+ range)</li>
<li>More computationally expensive but yields better results <a href="https://arxiv.org/html/2409.19727v1#:~:text=allows%20the%20network%20to,a%20large%20portion%20of">[3]</a></li>
</ul>
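<p>One common way to make each pruning step less aggressive than the last is a polynomial (here cubic) sparsity ramp, in the spirit of gradual-pruning schedules; the function name is illustrative:</p>

```python
def gradual_sparsity_schedule(final_sparsity, n_steps):
    """Cubic ramp from 0 toward final_sparsity: early steps prune a lot,
    later steps prune progressively less, leaving room for fine-tuning."""
    return [
        final_sparsity * (1.0 - (1.0 - (t + 1) / n_steps) ** 3)
        for t in range(n_steps)
    ]

# Per-step sparsity targets for reaching 90% sparsity in 3 prune/fine-tune rounds
targets = gradual_sparsity_schedule(0.9, n_steps=3)
```

Between consecutive targets the model is fine-tuned, which is what lets iterative pruning reach the 90%+ ratios mentioned above without collapsing accuracy.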
</section>
</section>
<section id="pruning-granularities" class="level2">
<h2 class="anchored" data-anchor-id="pruning-granularities">Pruning Granularities</h2>
<section id="fine-grained-unstructured-pruning" class="level3">
<h3 class="anchored" data-anchor-id="fine-grained-unstructured-pruning">1. Fine-grained (Unstructured) Pruning</h3>
<p><strong>Pros:</strong></p>
<ul>
<li>Maximum flexibility in weight selection</li>
<li>Highest possible compression ratio</li>
<li>Minimal accuracy loss <a href="https://www.researchgate.net/publication/381997604_Research_on_pruning_optimization_techniques_for_neural_networks#:~:text=ratio.%20Fine%2Dgrained%20pruning%20has,because%20it%20is%20unstructured">[4]</a></li>
</ul>
<p><strong>Cons:</strong></p>
<ul>
<li>Irregular weight indices</li>
<li>Difficult to accelerate on hardware</li>
<li>Requires specialized implementations <a href="https://www.researchgate.net/publication/381997604_Research_on_pruning_optimization_techniques_for_neural_networks#:~:text=the%20pruned%20model%20cannot,usually%20requires%20additional%20hardware">[5]</a></li>
</ul>
</section>
<section id="coarse-grained-structured-pruning" class="level3">
<h3 class="anchored" data-anchor-id="coarse-grained-structured-pruning">2. Coarse-grained (Structured) Pruning</h3>
<p><strong>Pros:</strong></p>
<ul>
<li>Hardware-friendly</li>
<li>Easier to implement</li>
<li>Maintains dense matrix operations</li>
</ul>
<p><strong>Cons:</strong></p>
<ul>
<li>Less flexible than fine-grained pruning</li>
<li>Limited to row/column pruning</li>
<li>May result in lower compression ratios</li>
</ul>
</section>
<section id="pattern-based-pruning-nm-sparsity" class="level3">
<h3 class="anchored" data-anchor-id="pattern-based-pruning-nm-sparsity">3. Pattern-based Pruning (N:M Sparsity)</h3>
<ul>
<li>For every M contiguous elements, N elements must be pruned</li>
<li>Common pattern is 2:4 sparsity (50%)</li>
<li>Uses compressed matrix format:
<ul>
<li>One matrix for non-zero values</li>
<li>One bitmask matrix encoding the positions of the non-zero values</li>
</ul></li>
<li>Some hardware architectures support this scheme natively</li>
</ul>
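<p>A minimal sketch of enforcing the pattern (helper name hypothetical): in every block of M contiguous weights, zero out the N smallest-magnitude entries, so 2:4 yields exactly 50% sparsity per block:</p>

```python
def enforce_n_m_sparsity(weights, n=2, m=4):
    """In every block of m contiguous weights, zero out the n smallest-
    magnitude entries (2:4 gives the 50% pattern some GPUs accelerate)."""
    out = list(weights)
    for start in range(0, len(out), m):
        block = list(range(start, min(start + m, len(out))))
        # indices of the n smallest-magnitude entries in this block
        for i in sorted(block, key=lambda j: abs(out[j]))[:n]:
            out[i] = 0.0
    return out

sparse = enforce_n_m_sparsity([1.0, -3.0, 0.5, 2.0, -0.1, 4.0, -2.0, 0.3])
```

Because every block has the same number of survivors, the non-zero values and their positions compress into the fixed-size value/bitmask matrices described above.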
</section>
<section id="channel-based-pruning" class="level3">
<h3 class="anchored" data-anchor-id="channel-based-pruning">4. Channel-based Pruning</h3>
<p><strong>Pros:</strong></p>
<ul>
<li>Most regular structure</li>
<li>Highest potential speedup</li>
<li>Straightforward implementation <a href="https://www.sciencedirect.com/science/article/abs/pii/S0952197625009200">[6]</a></li>
</ul>
<p><strong>Cons:</strong></p>
<ul>
<li>Least flexible</li>
<li>Lower compression ratio</li>
<li>Can have uniform or varying sparsity across layers</li>
</ul>
</section>
</section>
<section id="pruning-criteria" class="level2">
<h2 class="anchored" data-anchor-id="pruning-criteria">Pruning Criteria</h2>
<section id="magnitude-based-pruning" class="level3">
<h3 class="anchored" data-anchor-id="magnitude-based-pruning">1. Magnitude-based Pruning</h3>
<ul>
<li>Removes weights with smallest magnitude</li>
<li>Uses L1 or L2 norm for measurement</li>
<li>Can be applied row-wise for improved regularity</li>
</ul>
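<p>The row-wise variant mentioned above can be sketched as follows (illustrative helper): score each row, e.g. an output neuron's incoming weights, by its L1 norm, and flag the lowest-scoring rows for removal, which keeps the sparsity pattern regular:</p>

```python
def rows_to_prune_by_l1(weight_matrix, k):
    """Score each row by its L1 norm and return the indices of the
    k rows with the smallest norms (candidates for removal)."""
    norms = [sum(abs(w) for w in row) for row in weight_matrix]
    return sorted(range(len(norms)), key=lambda r: norms[r])[:k]

W = [[0.1, -0.2, 0.0],   # L1 norm = 0.3
     [1.0, 2.0, -1.5],   # L1 norm = 4.5
     [0.0, 0.3, -0.1]]   # L1 norm = 0.4
pruned_rows = rows_to_prune_by_l1(W, k=1)
```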
</section>
<section id="scaling-based-pruning" class="level3">
<h3 class="anchored" data-anchor-id="scaling-based-pruning">2. Scaling-based Pruning</h3>
<ul>
<li>Associates learnable scaling factors with output channels</li>
<li>Prunes channels with small scaling factors</li>
<li>More adaptive than simple magnitude-based methods <a href="https://arxiv.org/pdf/1912.04845#:~:text=The%20strategy%20that%20has,results%20by%20pruning%20weights">[7]</a></li>
</ul>
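<p>A minimal sketch of this criterion, assuming the scaling factors are something like BatchNorm gammas and using a hypothetical threshold: channels whose learned scale has drifted near zero contribute little and are marked for pruning:</p>

```python
def channels_below_scale_threshold(scales, threshold=0.01):
    """Scaling-based criterion: channels whose learnable scale factor
    (e.g. a BatchNorm gamma) is near zero are marked for pruning."""
    return [i for i, s in enumerate(scales) if abs(s) < threshold]

# Four channels' learned scales; two have collapsed toward zero
to_prune = channels_below_scale_threshold([0.8, 0.003, 1.2, 0.0005], threshold=0.01)
```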
</section>
<section id="percentage-of-zero-based-pruning" class="level3">
<h3 class="anchored" data-anchor-id="percentage-of-zero-based-pruning">3. Percentage-of-Zero-Based Pruning</h3>
<ul>
<li>Focuses on activation patterns</li>
<li>Removes channels with highest percentage of zeros</li>
<li>Requires analysis of activation patterns during inference</li>
<li>Dynamic approach compared to static weight pruning</li>
</ul>
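<p>An illustrative sketch of this criterion (function name hypothetical): record each channel's activations over a calibration set and rank channels by their fraction of exact zeros, pruning the most-often-zero channels first:</p>

```python
def rank_channels_by_zero_fraction(activations):
    """Percentage-of-zeros criterion: channels whose observed activations
    are most often exactly zero come first in the pruning order."""
    zero_frac = [
        sum(1 for a in acts if a == 0.0) / len(acts) for acts in activations
    ]
    return sorted(range(len(zero_frac)), key=lambda c: zero_frac[c], reverse=True)

# Three channels' post-ReLU activations collected during inference
ranking = rank_channels_by_zero_fraction([[0.0, 1.2, 0.0, 2.1],
                                          [1.0, 2.0, 3.0, 4.0],
                                          [0.0, 0.0, 0.0, 1.5]])
```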
</section>
<section id="regression-based-pruning" class="level3">
<h3 class="anchored" data-anchor-id="regression-based-pruning">4. Regression-based Pruning</h3>
<ul>
<li>Minimizes reconstruction error of layer outputs</li>
<li>Avoids full backpropagation</li>
<li>Particularly effective for Large Language Models</li>
<li>More sophisticated approach with better accuracy retention</li>
</ul>
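<p>The idea can be sketched for a single linear layer (all names here are illustrative): compare the layer's outputs before and after zeroing a weight on calibration inputs, and drop the weight whose removal increases the reconstruction error the least, no backpropagation required:</p>

```python
def layer_output_error(X, w, w_pruned):
    """Sum of squared differences between the layer outputs X @ w and
    X @ w_pruned over a batch of calibration inputs X."""
    err = 0.0
    for row in X:
        y = sum(x * wi for x, wi in zip(row, w))
        y_p = sum(x * wi for x, wi in zip(row, w_pruned))
        err += (y - y_p) ** 2
    return err

def weight_to_drop(X, w):
    """Regression-based criterion: drop the single weight whose removal
    increases the output reconstruction error the least."""
    def error_if_dropped(i):
        w_p = [0.0 if j == i else wj for j, wj in enumerate(w)]
        return layer_output_error(X, w, w_p)
    return min(range(len(w)), key=error_if_dropped)

# Dropping the small weight barely changes the outputs, so it goes first
idx = weight_to_drop(X=[[1.0, 0.0], [0.0, 1.0]], w=[3.0, 0.1])
```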
</section>
</section>
<section id="important-considerations" class="level2">
<h2 class="anchored" data-anchor-id="important-considerations">Important Considerations</h2>
<ol type="1">
<li><p><strong>Large vs.&nbsp;Small Models</strong>: It’s generally better to prune a large model than train a smaller model from scratch. Over-parameterization helps avoid local minima by providing more dimensions to escape saddle points.</p></li>
<li><p><strong>Hardware Considerations</strong>: The choice of pruning granularity should consider the target hardware architecture. Structured pruning may be preferred for standard hardware, while specialized hardware might better handle unstructured pruning.</p></li>
<li><p><strong>Layer-wise Pruning</strong>: Different layers may have different levels of redundancy, making uniform pruning across all layers suboptimal. Adaptive approaches that consider layer-specific characteristics often yield better results.</p></li>
</ol>
<p>Understanding these pruning techniques enables machine learning engineers and data scientists to make informed decisions when optimizing deep learning models for specific applications and hardware constraints.</p>


</section>

 ]]></description>
  <category>Deep Learning</category>
  <guid>https://imaddabbura.github.io/posts/efficient-ml/pruning-dl-models.html</guid>
  <pubDate>Fri, 03 May 2024 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/efficient-ml/pruning.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Building GPT(2/3) from Scratch: Turning Theory into a Working Transformer</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/GPT2-From-Scratch.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>There’s an old saying in engineering: “You don’t really understand something until you can build it.” This has never been more true than in the era of LLMs. While we’ve previously explored the foundational concepts in my post on the Transformer architecture <a href="https://imaddabbura.github.io/posts/nlp/Transformer-Architecture-Explained.html">explained here</a>, true understanding comes from implementation. That’s why today, we’re building a GPT-style model (the 124M variant) from scratch in PyTorch.</p>
<p>This project has a different focus than my last “from scratch” endeavor, where I built an <a href="https://github.com/ImadDabbura/tiny-pytorch">entire deep learning framework</a> to grasp the low-level mechanics of autograd and tensor ops. Here, we’ll leverage PyTorch’s battle-tested primitives to focus on what makes GPT special: multi-head attention, positional encodings, and the specific architectural decisions that enable language understanding.</p>
<p>This hands-on process reveals challenges you can’t appreciate from diagrams alone. You’ll watch your GPU memory overflow, see training grind to a halt from inefficient data loading, and learn firsthand why techniques like mixed-precision training, gradient accumulation, and activation checkpointing are necessities, not just optimizations. It’s in facing these hurdles that you truly appreciate the engineering craft required to build and scale transformers efficiently.</p>
</section>
<section id="gpts" class="level2">
<h2 class="anchored" data-anchor-id="gpts">GPTs</h2>
<p>GPT (Generative Pre-trained Transformer) models, developed by OpenAI, represent a breakthrough in natural language processing. GPT-2, released in 2019, demonstrated that a transformer-based model trained on vast amounts of text could generate remarkably coherent and contextually relevant content. GPT-3, its successor, scaled this approach to 175 billion parameters, showcasing emergent capabilities like few-shot learning and complex reasoning. Both models share the same fundamental architecture: stacked transformer decoder blocks that predict the next token in a sequence, trained on the simple objective of minimizing prediction error across massive text corpora. The 124M parameter version we’ll be building captures the essential architecture while remaining computationally tractable for individual developers—though even at this “small” scale, you’ll quickly discover why the ML community spends so much time optimizing both training efficiency and model performance.</p>
<p>By the end of this journey, you won’t just know how transformers work—you’ll have built the critical components with your own hands, optimized the training loop, and watched your model evolve from random noise to coherent text generation. Let’s begin.</p>
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation">Implementation</h2>
<p>Throughout this implementation, every piece of code will be thoroughly annotated with explanations of not just what we’re doing, but why we’re doing it. More importantly, we’ll apply a few optimizations that make a big difference in computational efficiency:</p>
<ul>
<li><p><strong>TensorFloat32 (TF32)</strong>: NVIDIA’s 19-bit precision format that keeps FP32’s 8-bit exponent range but truncates the mantissa from 23 bits to 10, providing up to 8x matrix-multiplication throughput on A100 GPUs while maintaining model quality. We’ll see how a single line of code can dramatically accelerate matrix multiplications.</p></li>
<li><p><strong>BFloat16 with Autocast</strong>: Mixed precision training using brain floating-point format, which maintains the same exponent range as FP32 but reduces mantissa precision. Combined with automatic mixed precision (AMP), this cuts memory usage in half and speeds up training significantly.</p></li>
<li><p><strong>torch.compile</strong>: PyTorch 2.0’s just-in-time compilation that fuses operations and generates optimized kernels. We’ll explore how graph compilation can provide 10-30% speedups with minimal code changes.</p></li>
<li><p><strong>Flash Attention and Online Softmax</strong>: An algorithmic improvement that computes attention without materializing the full attention matrix, reducing memory complexity from O(n²) to O(n).</p></li>
<li><p><strong>Fused AdamW</strong>: A single-kernel implementation of the AdamW optimizer that reduces memory reads/writes by computing all parameter updates in one pass, providing up to 2x optimizer step speedup.</p></li>
<li><p><strong>Annealed Learning Rate</strong>: Starting with a warmup phase followed by cosine decay, we’ll implement the learning rate schedule that has become standard for training transformers, understanding why stable training requires careful lr management.</p></li>
<li><p><strong>Weight Decay Only on Matrices</strong>: A subtle but crucial detail—applying weight decay only to weight matrices in Linear and Embedding layers while excluding biases and layer normalization parameters, which improves model performance.</p></li>
<li><p><strong>Distributed Data Parallelism (DDP)</strong>: Scaling training across multiple GPUs using PyTorch’s DDP, including gradient synchronization, proper data loading, and the intricacies of maintaining consistent model states across devices.</p></li>
</ul>
<p>Finally, since the GPT-2 paper omits certain architectural details and hyperparameter specifications, we’ll refer to the GPT-3 paper to fill these gaps—fortunately, the core architecture remains consistent between the two models, making the GPT-3 paper a reliable source for these missing implementation details.</p>
<div id="660e14ce-2270-4cf4-9078-36759f24db29" class="cell">
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## | code-fold: true</span></span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> inspect</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> math</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> os</span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> time</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> dataclasses <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> dataclass</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> functools <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> partial, wraps</span>
<span id="cb1-8"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> typing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Callable, Iterable</span>
<span id="cb1-9"></span>
<span id="cb1-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> tiktoken</span>
<span id="cb1-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-12"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.distributed <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> dist</span>
<span id="cb1-13"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> nn</span>
<span id="cb1-14"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn.functional <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> F</span>
<span id="cb1-15"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.optim <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> opt</span>
<span id="cb1-16"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.distributed <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> destroy_process_group, init_process_group</span>
<span id="cb1-17"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> torch.nn.parallel <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> DistributedDataParallel <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> DDP</span></code></pre></div>
</div>
<div id="f585c521" class="cell">
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## | code-fold: true</span></span>
<span id="cb2-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> listify(obj):</span>
<span id="cb2-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> obj <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb2-4">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> []</span>
<span id="cb2-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(obj, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>):</span>
<span id="cb2-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> [obj]</span>
<span id="cb2-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(obj, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>):</span>
<span id="cb2-8">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> obj</span>
<span id="cb2-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(obj, Iterable):</span>
<span id="cb2-10">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(obj)</span>
<span id="cb2-11">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb2-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> [obj]</span></code></pre></div>
</div>
<div id="4fd6e093-b414-4fb9-a165-04ae9e312c03" class="cell">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## | code-fold: true</span></span>
<span id="cb3-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> annealer(func: Callable):</span>
<span id="cb3-3">    wraps(func)</span>
<span id="cb3-4"></span>
<span id="cb3-5">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> annealer_wrapper(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>args, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>kwargs):</span>
<span id="cb3-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> partial(func, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>args, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>kwargs)</span>
<span id="cb3-7"></span>
<span id="cb3-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> annealer_wrapper</span>
<span id="cb3-9"></span>
<span id="cb3-10"></span>
<span id="cb3-11"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@annealer</span></span>
<span id="cb3-12"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> lin_sched(start, end, pos):</span>
<span id="cb3-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Linear scheduler."""</span></span>
<span id="cb3-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> start <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (end <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> start) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> pos</span>
<span id="cb3-15"></span>
<span id="cb3-16"></span>
<span id="cb3-17"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@annealer</span></span>
<span id="cb3-18"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> cos_sched(start, end, pos):</span>
<span id="cb3-19">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Cosine scheduler."""</span></span>
<span id="cb3-20">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> start <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> math.cos(math.pi <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> pos))) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (end <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> start) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb3-21"></span>
<span id="cb3-22"></span>
<span id="cb3-23"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> combine_scheds(pcts, scheds):</span>
<span id="cb3-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb3-25"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Combine multiple schedulers, each run for a given percentage of the</span></span>
<span id="cb3-26"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    training process.</span></span>
<span id="cb3-27"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb3-28">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(pcts) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(scheds), <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Each scheduler should have its `pct`."</span></span>
<span id="cb3-29">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(pcts) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Sum of the `pcts` should be equal to 1."</span></span>
<span id="cb3-30">    pcts <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> listify(pcts))</span>
<span id="cb3-31">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> (pcts <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">all</span>(), <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"All percentages should be non-negative."</span></span>
<span id="cb3-32">    pcts <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cumsum(pcts, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb3-33"></span>
<span id="cb3-34">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> _inner(pos):</span>
<span id="cb3-35">        idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> pcts).nonzero().<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>()</span>
<span id="cb3-36">        actual_pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> pcts[idx]) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (pcts[idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> pcts[idx])</span>
<span id="cb3-37">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> scheds[idx](actual_pos)</span>
<span id="cb3-38"></span>
<span id="cb3-39">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> _inner</span></code></pre></div>
</div>
<div id="96f98324-58a9-49f1-9ad1-789c1a6b20e4" class="cell">
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@dataclass</span></span>
<span id="cb4-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> GPTConfig:</span>
<span id="cb4-3">    block_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1024</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Sequence length</span></span>
<span id="cb4-4">    vocab_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (</span>
<span id="cb4-5">        <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50257</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Originally 50000 BPE merges + 256 byte tokens + 1 for &lt;|endoftext|&gt; token</span></span>
<span id="cb4-6">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## which will delimits different documents. This token's index is 50256</span></span>
<span id="cb4-7">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## However, we found that using 50257 as the vocab size is not a multiple of 64 and we</span></span>
<span id="cb4-8">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## could improve efficiency and performance (through better occupancy) if we round up</span></span>
<span id="cb4-9">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## to the closest multiple of 64, which is 5304.</span></span>
<span id="cb4-10">    )</span>
<span id="cb4-11">    n_layer: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Number of layers</span></span>
<span id="cb4-12">    n_embd: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">768</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Embedding dimension</span></span>
<span id="cb4-13">    n_head: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span> <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Number of attention heads</span></span>
<span id="cb4-14">    lr: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">3e-4</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Good for big models</span></span>
<span id="cb4-15">    batch_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span></span>
<span id="cb4-16">    dropout: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span></span>
<span id="cb4-17">    bias: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bool</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span></code></pre></div>
</div>
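<p>The vocab-size rounding described in the config comment can be sketched in a couple of lines (my own helper, not part of the post's code):</p>

```python
def round_up_to_multiple(n: int, m: int) -> int:
    """Round n up to the nearest multiple of m."""
    return ((n + m - 1) // m) * m

# GPT-2's raw vocab size: 50000 BPE merges + 256 byte tokens + 1 special token
raw_vocab = 50257
padded_vocab = round_up_to_multiple(raw_vocab, 64)
print(padded_vocab)  # 50304
```

<p>Padding the vocab adds a few embedding rows that are never used by the tokenizer, but it gives the matmul kernels nicely divisible dimensions.</p>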
<div id="f41875c6" class="cell">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> MLP(nn.Module):</span>
<span id="cb5-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config: GPTConfig):</span>
<span id="cb5-3">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Point-wise feed-forward network that applies non-linearity</span></span>
<span id="cb5-4">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## on every token separately. THERE IS NO INTERACTION BETWEEN TOKENS</span></span>
<span id="cb5-5">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This is where almost all the capacity and non-linearities of the </span></span>
<span id="cb5-6">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## model come from, especially when we project it to 4 x n_embd</span></span>
<span id="cb5-7">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb5-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_fc <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.n_embd, config.n_embd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>)</span>
<span id="cb5-9">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Found to work better than ReLU (GELU avoids ReLU's zero-gradient region)</span></span>
<span id="cb5-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.gelu <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.GELU(approximate<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"tanh"</span>)</span>
<span id="cb5-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.n_embd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, config.n_embd)</span>
<span id="cb5-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.dropout)</span>
<span id="cb5-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj.NANOGPT_SCALE_INIT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb5-14"></span>
<span id="cb5-15">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb5-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.gelu(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_fc(x))))</span></code></pre></div>
</div>
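<p>As a back-of-the-envelope check on the MLP above, its parameter count follows directly from the 4x expansion (plain arithmetic using the config's <code>n_embd = 768</code>, biases included):</p>

```python
n_embd = 768            # embedding dimension from GPTConfig
hidden = 4 * n_embd     # the MLP expands each token to 4 x n_embd

# c_fc maps n_embd to hidden, c_proj maps hidden back; each has a bias
c_fc_params = n_embd * hidden + hidden
c_proj_params = hidden * n_embd + n_embd
total = c_fc_params + c_proj_params
print(total)  # 4722432, i.e. roughly 4.7M parameters per MLP
```

<p>Multiplied by 12 layers, the MLPs alone account for a large share of the model's non-embedding parameters.</p>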
<div id="a14dee4c" class="cell">
<div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> CausalSelfAttention(nn.Module):</span>
<span id="cb6-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config: GPTConfig):</span>
<span id="cb6-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb6-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> config.n_head</span>
<span id="cb6-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_embd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> config.n_embd</span>
<span id="cb6-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_attn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.n_embd, config.n_embd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>config.bias)</span>
<span id="cb6-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.n_embd, config.n_embd, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>config.bias)</span>
<span id="cb6-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn_dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.dropout)</span>
<span id="cb6-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.resid_dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.dropout)</span>
<span id="cb6-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> config.dropout</span>
<span id="cb6-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj.NANOGPT_SCALE_INIT <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb6-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">NOTE</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: Mask is not needed when we use PyTorch's Flash attention</span></span>
<span id="cb6-13">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## self.register_buffer(</span></span>
<span id="cb6-14">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##     "mask",</span></span>
<span id="cb6-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##     torch.tril(torch.ones(config.block_sz, config.block_sz)).view(</span></span>
<span id="cb6-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##         config.block_sz, config.block_sz</span></span>
<span id="cb6-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##     ),</span></span>
<span id="cb6-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## )</span></span>
<span id="cb6-19"></span>
<span id="cb6-20">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb6-21">        B, T, C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape</span>
<span id="cb6-22">        qkv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_attn(x)</span>
<span id="cb6-23">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## q/k/v is B x T x n_embd each</span></span>
<span id="cb6-24">        q, k, v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.split(qkv, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_embd, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-25">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Reshape q/k/v to B x n_head x T x (n_embd / n_head)</span></span>
<span id="cb6-26">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## So each head would be learning different kinds of</span></span>
<span id="cb6-27">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## relationships</span></span>
<span id="cb6-28">        q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> q.view(B, T, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head, C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head).transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb6-29">        k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> k.view(B, T, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head, C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head).transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb6-30">        v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> v.view(B, T, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head, C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_head).transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb6-31">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## attn is B x n_head x T x T</span></span>
<span id="cb6-32">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))</span></span>
<span id="cb6-33">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## ## Mask out future tokens</span></span>
<span id="cb6-34">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## attn = attn.masked_fill(self.mask[:T, :T] == 0, float("-inf"))</span></span>
<span id="cb6-35">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## attn = self.attn_dropout(F.softmax(attn, dim=-1))</span></span>
<span id="cb6-36">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## ## y is B x n_head x T x (n_embd / n_head)</span></span>
<span id="cb6-37">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## y = attn @ v</span></span>
<span id="cb6-38">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Uses Flash attention, which never materializes the attention matrix for</span></span>
<span id="cb6-39">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## each head and is aware of the memory hierarchy and tries to reduce</span></span>
<span id="cb6-40">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## read/writes with more FLOPs -&gt; Speed up since we're memory bound</span></span>
<span id="cb6-41">        y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.scaled_dot_product_attention(q, k, v, is_causal<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-42">        y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> y.transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>).contiguous().view(B, T, C)</span>
<span id="cb6-43">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.resid_dropout(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.c_proj(y))</span></code></pre></div>
</div>
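<p>To see why the Flash-attention path matters, here are rough numbers for the score matrix the commented-out naive path would materialize (my own estimate, assuming fp32 activations and the batch/config values above):</p>

```python
B, T, n_head, n_embd = 4, 1024, 12, 768  # batch_sz, block_sz, and config values
head_dim = n_embd // n_head              # each head attends in a 64-dim subspace

# Naive attention materializes a (B, n_head, T, T) score matrix
bytes_per_float = 4                      # assuming fp32
attn_bytes = B * n_head * T * T * bytes_per_float
print(head_dim, attn_bytes // 2**20)     # 64 192  (192 MiB per attention layer)
```

<p>Flash attention computes the same result in tiles that stay in on-chip SRAM, trading extra FLOPs for far fewer reads/writes of that matrix to HBM.</p>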
<div id="85779fd0" class="cell">
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> Block(nn.Module):</span>
<span id="cb7-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config: GPTConfig):</span>
<span id="cb7-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb7-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ln_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.n_embd)</span>
<span id="cb7-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ln_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.n_embd)</span>
<span id="cb7-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.mlp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MLP(config)</span>
<span id="cb7-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> CausalSelfAttention(config)</span>
<span id="cb7-8"></span>
<span id="cb7-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb7-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Use pre-layer normalization, which deviates from the</span></span>
<span id="cb7-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## original transformer paper, which uses post-layer normalization.</span></span>
<span id="cb7-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This should help stabilize training</span></span>
<span id="cb7-13">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ln_1(x))</span>
<span id="cb7-14">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.mlp(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ln_2(x))</span>
<span id="cb7-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x</span></code></pre></div>
</div>
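<p>Each <code>Block</code> adds two residual contributions (attention and MLP), so across <code>n_layer</code> layers the residual stream accumulates <code>2 * n_layer</code> roughly independent additions and its std grows like <code>sqrt(2 * n_layer)</code>. A quick pure-Python simulation of that growth (my own sketch, not the post's code):</p>

```python
import math
import random

random.seed(0)
n_layer, std0, trials = 12, 0.02, 20_000

# Sum 2 * n_layer independent N(0, std0^2) contributions, mimicking the
# attn + mlp residual additions, and measure the std of the sum
samples = [sum(random.gauss(0.0, std0) for _ in range(2 * n_layer))
           for _ in range(trials)]
mean = sum(samples) / trials
emp_std = math.sqrt(sum((s - mean) ** 2 for s in samples) / trials)

print(round(std0 * math.sqrt(2 * n_layer), 3))  # 0.098 (predicted growth)
# emp_std lands close to that prediction
```

<p>This sqrt growth is exactly what the scaled initialization in <code>_init_weights</code> compensates for by multiplying std by <code>(2 * n_layer) ** -0.5</code>.</p>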
<div id="20a4b2f4" class="cell">
<div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> GPT2(nn.Module):</span>
<span id="cb8-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config: GPTConfig):</span>
<span id="cb8-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb8-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.config <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> config</span>
<span id="cb8-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.ModuleDict(</span>
<span id="cb8-6">            <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">dict</span>(</span>
<span id="cb8-7">                wte<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>nn.Embedding(config.vocab_sz, config.n_embd),</span>
<span id="cb8-8">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Attention operation is a permutation equivariant, this means that</span></span>
<span id="cb8-9">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## if we permute the input then the corresponding output will be</span></span>
<span id="cb8-10">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## permuted in exactly the same way. In other words, attention mechanism</span></span>
<span id="cb8-11">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## is not aware of the relative ordering of the tokens. Therefore, we</span></span>
<span id="cb8-12">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## need some way to encode the positions of the tokens in each sequence.</span></span>
<span id="cb8-13">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This is where positional encoding comes into play.</span></span>
<span id="cb8-14">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Here we use a simple positional encoding that is a simple</span></span>
<span id="cb8-15">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## embedding of the position of the token in the sequence.</span></span>
<span id="cb8-16">                wpe<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>nn.Embedding(config.block_sz, config.n_embd),</span>
<span id="cb8-17">                h<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>nn.ModuleList(</span>
<span id="cb8-18">                    [Block(config) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.n_layer)]</span>
<span id="cb8-19">                ),</span>
<span id="cb8-20">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Final layer norm after all transformer layers</span></span>
<span id="cb8-21">                ln_f<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>nn.LayerNorm(config.n_embd),</span>
<span id="cb8-22">            )</span>
<span id="cb8-23">        )</span>
<span id="cb8-24">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.n_embd, config.vocab_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb8-25"></span>
<span id="cb8-26">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Weight sharing between the token embedding layer and</span></span>
<span id="cb8-27">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## last linear layer (LM head classifier). The rationale is</span></span>
<span id="cb8-28">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## that tokens that are semantically similar to each other in</span></span>
<span id="cb8-29">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## the embedding space should have similar probabilities in the</span></span>
<span id="cb8-30">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## softmax of the LM head layer</span></span>
<span id="cb8-31">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Also, these are among the biggest weight matrices in the model.</span></span>
<span id="cb8-32">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This means, for a model like GPT-2, we save ~30% of the parameters</span></span>
<span id="cb8-33">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## by sharing the weight matrices (50257 * 768) / 124M = ~31%</span></span>
<span id="cb8-34">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer.wte.weight <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head.weight</span>
<span id="cb8-35">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">apply</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._init_weights)</span>
<span id="cb8-36"></span>
<span id="cb8-37">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> _init_weights(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, module):</span>
<span id="cb8-38">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## The following initialization comes from gpt2 src code</span></span>
<span id="cb8-39">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">NOTE</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: Because token embedding and classifier weights are shared,</span></span>
<span id="cb8-40">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## our initialization logic will initialize the weight matrix twice</span></span>
<span id="cb8-41">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## but that shouldn't be an issue since both passes use the</span></span>
<span id="cb8-42">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## same std and mean</span></span>
<span id="cb8-43">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.Linear):</span>
<span id="cb8-44">            std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span></span>
<span id="cb8-45">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We're changing std because the residual path affects std</span></span>
<span id="cb8-46">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## by increasing it on every layer so we need to adjust</span></span>
<span id="cb8-47">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## it so we still have the same std = 0.02</span></span>
<span id="cb8-48">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">hasattr</span>(module, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"NANOGPT_SCALE_INIT"</span>):</span>
<span id="cb8-49">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## `2` here because every layer has two blocks:</span></span>
<span id="cb8-50">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##   - Attention block</span></span>
<span id="cb8-51">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##   - MLP block</span></span>
<span id="cb8-52">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## `N` is the number of layers in the model (n_layer)</span></span>
<span id="cb8-53">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Since they are independent, the variance of the sum of the two</span></span>
<span id="cb8-54">                <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## blocks is the sum of the variances</span></span>
<span id="cb8-55">                std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.config.n_layer) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span></span>
<span id="cb8-56">            nn.init.normal_(module.weight, std<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>std)</span>
<span id="cb8-57">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> module.bias <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb8-58">                nn.init.zeros_(module.bias)</span>
<span id="cb8-59">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">elif</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">isinstance</span>(module, nn.Embedding):</span>
<span id="cb8-60">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We're initializing the token and positional embeddings</span></span>
<span id="cb8-61">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## with the same std, whereas the paper initialized the positional</span></span>
<span id="cb8-62">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## embedding with std = 0.01</span></span>
<span id="cb8-63">            nn.init.normal_(module.weight, std<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span>)</span>
<span id="cb8-64"></span>
<span id="cb8-65">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x, targets<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb8-66">        T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb8-67">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> (</span>
<span id="cb8-68">            T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.config.block_sz</span>
<span id="cb8-69">        ), <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Sequence length must be &lt;= </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>config<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>block_sz<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, got </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>T<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb8-70">        pos_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer.wpe(</span>
<span id="cb8-71">            torch.arange(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, T, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">long</span>, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>x.device)</span>
<span id="cb8-72">        )</span>
<span id="cb8-73">        tok_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer.wte(x)</span>
<span id="cb8-74">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pos_emb <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> tok_emb</span>
<span id="cb8-75">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> block <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer.h:</span>
<span id="cb8-76">            x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> block(x)</span>
<span id="cb8-77">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.transformer.ln_f(x)</span>
<span id="cb8-78">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## logits is B x T x vocab_sz</span></span>
<span id="cb8-79">        logits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head(x)</span>
<span id="cb8-80">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb8-81">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> targets <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">is</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb8-82">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## F.cross_entropy expects class logits (not probabilities), so we flatten to (B*T, vocab_sz)</span></span>
<span id="cb8-83">            loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.cross_entropy(</span>
<span id="cb8-84">                logits.view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.config.vocab_sz), targets.view(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-85">            )</span>
<span id="cb8-86">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> logits, loss</span>
<span id="cb8-87"></span>
<span id="cb8-88">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> configure_optimizer(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, weight_decay, lr, device):</span>
<span id="cb8-89">        params_dict <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {</span>
<span id="cb8-90">            pn: p <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> pn, p <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.named_parameters() <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> p.requires_grad</span>
<span id="cb8-91">        }</span>
<span id="cb8-92">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We're not applying weight decay to bias and layer norm parameters</span></span>
<span id="cb8-93">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## or any other 1D parameters. Therefore, we ONLY apply weight decay</span></span>
<span id="cb8-94">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## to the weight matrices in Embedding and Linear layers</span></span>
<span id="cb8-95">        decay_params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [p <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> params_dict.values() <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> p.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]</span>
<span id="cb8-96">        nondecay_params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [p <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> p <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> params_dict.values() <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> p.ndim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]</span>
<span id="cb8-97">        params_groups <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> [</span>
<span id="cb8-98">            {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"params"</span>: decay_params, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"weight_decay"</span>: weight_decay},</span>
<span id="cb8-99">            {<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"params"</span>: nondecay_params, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"weight_decay"</span>: <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span>},</span>
<span id="cb8-100">        ]</span>
<span id="cb8-101">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Fused AdamW is available for PyTorch 2.0+</span></span>
<span id="cb8-102">        fused_available <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"fused"</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> inspect.signature(opt.AdamW).parameters</span>
<span id="cb8-103">        use_fused <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> fused_available <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">and</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> device</span>
<span id="cb8-104">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> opt.AdamW(</span>
<span id="cb8-105">            params_groups, lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>lr, betas<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.9</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.95</span>), eps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-8</span>, fused<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>use_fused</span>
<span id="cb8-106">        )</span>
<span id="cb8-107"></span>
<span id="cb8-108">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@torch.no_grad</span></span>
<span id="cb8-109">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> generate(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, idxs: torch.tensor, max_tokens: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>):</span>
<span id="cb8-110">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(max_tokens):</span>
<span id="cb8-111">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Crop the context to the last block_sz tokens (at most we can have</span></span>
<span id="cb8-112">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## block_sz tokens since we're using a fixed block_sz for the</span></span>
<span id="cb8-113">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## positional embedding)</span></span>
<span id="cb8-114">            idxs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> idxs[:, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.config.block_sz :]</span>
<span id="cb8-115">            logits, _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>(idxs)</span>
<span id="cb8-116">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Get probs for last token to predict next token</span></span>
<span id="cb8-117">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This would be B x vocab_sz</span></span>
<span id="cb8-118">            logits <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> logits[:, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, :]</span>
<span id="cb8-119">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Apply softmax to get probabilities</span></span>
<span id="cb8-120">            probs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.softmax(logits, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-121">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Keep the top 50 probs -&gt; we never sample tokens with</span></span>
<span id="cb8-122">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## very small probs (the long tail) -&gt; B x 50</span></span>
<span id="cb8-123">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## probs/idxs are sorted in descending order</span></span>
<span id="cb8-124">            topk_probs, topk_idxs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.topk(probs, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span>, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-125">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Sample 1 token from the top 50 tokens -&gt; idx is B x 1</span></span>
<span id="cb8-126">            idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.multinomial(topk_probs, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-127">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Map back to vocab ids since `multinomial` returns indices</span></span>
<span id="cb8-128">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## into the given array</span></span>
<span id="cb8-129">            idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.gather(topk_idxs, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, idx)</span>
<span id="cb8-130">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">TODO</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: We should check for end_of_text token and break out of</span></span>
<span id="cb8-131">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## the loop (stop generation) even if we have not reached max_tokens</span></span>
<span id="cb8-132">            idxs <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cat([idxs, idx], dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-133">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> idxs</span></code></pre></div>
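<p>The init comments above claim that each residual addition inflates the std of the residual stream, which is why branches tagged with <code>NANOGPT_SCALE_INIT</code> get their std multiplied by <code>(2 * n_layer) ** -0.5</code>. A quick standalone sanity check of that arithmetic (my own sketch, using NumPy rather than the model code; all names are illustrative):</p>

```python
import numpy as np

# Illustrative check (not from the post): each residual addition contributes
# independent variance, so after 2 * n_layer additions the stream's std grows
# by sqrt(2 * n_layer) unless projection weights are scaled by (2*n_layer)**-0.5.
rng = np.random.default_rng(0)
n_layer, d, n = 12, 256, 2048

def stream_std(scale: float) -> float:
    x = np.zeros((n, d))
    for _ in range(2 * n_layer):  # two residual adds per layer: attention + MLP
        w = rng.standard_normal((d, d)) * 0.02 * scale  # std = 0.02 * scale
        x = x + rng.standard_normal((n, d)) @ w         # independent branch output
    return float(x.std())

unscaled = stream_std(1.0)                  # grows ~ 0.02*sqrt(d)*sqrt(2*n_layer) ≈ 1.57
scaled = stream_std((2 * n_layer) ** -0.5)  # stays ~ 0.02*sqrt(d) ≈ 0.32
```

<p>With scaling, the final std matches a single branch's contribution instead of growing with depth, which is exactly the behavior the comments describe.</p>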
</div>
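<p>The <code>topk</code> / <code>multinomial</code> / <code>gather</code> dance inside <code>generate</code> is easy to get wrong, so here is a minimal isolated sketch of that sampling step (my own example with made-up shapes, not part of the model):</p>

```python
import torch
import torch.nn.functional as F

# Isolated sketch of the top-k sampling step: `topk` keeps the 50 most likely
# tokens, `multinomial` samples a *position* within that top-k array, and
# `gather` maps the position back to an actual vocabulary id.
torch.manual_seed(0)
logits = torch.randn(2, 100)                # pretend B=2, vocab_sz=100
probs = F.softmax(logits, dim=-1)
topk_probs, topk_idxs = torch.topk(probs, 50, dim=-1)  # each row sorted descending
pos = torch.multinomial(topk_probs, 1)      # index into the top-50 array, B x 1
token = torch.gather(topk_idxs, -1, pos)    # actual vocabulary id, B x 1
```

<p>Note that <code>multinomial</code> renormalizes the top-k probabilities internally, so mass outside the top 50 is simply redistributed rather than wasted.</p>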
<div id="b10a62ed" class="cell">
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> DataLoaderLight:</span>
<span id="cb9-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(</span>
<span id="cb9-3">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>,</span>
<span id="cb9-4">        file_path: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>,</span>
<span id="cb9-5">        batch_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>,</span>
<span id="cb9-6">        block_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>,</span>
<span id="cb9-7">        process_rank: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>,</span>
<span id="cb9-8">        number_processes: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>,</span>
<span id="cb9-9">    ) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb9-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_sz</span>
<span id="cb9-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> block_sz</span>
<span id="cb9-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.process_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> process_rank</span>
<span id="cb9-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.number_processes <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> number_processes</span>
<span id="cb9-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">with</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">open</span>(file_path, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"r"</span>) <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> f:</span>
<span id="cb9-15">            text <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> f.read()</span>
<span id="cb9-16">        encoder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tiktoken.get_encoding(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"gpt2"</span>)</span>
<span id="cb9-17">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tensor(encoder.encode(text), dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">long</span>)</span>
<span id="cb9-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We can truncate the tokens to be a multiple of batch_sz x block_sz</span></span>
<span id="cb9-19">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x number_processes. This is useful for multi-node training and mimics</span></span>
<span id="cb9-20">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## the behavior of DataLoader's `drop_last` parameter.</span></span>
<span id="cb9-21">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens[: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> (</span>
<span id="cb9-22">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.number_processes</span>
<span id="cb9-23">        )</span>
<span id="cb9-24">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz</span>
<span id="cb9-25">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz</span>
<span id="cb9-26">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.number_processes]</span>
<span id="cb9-27">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Loaded </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens)<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> tokens"</span>)</span>
<span id="cb9-28">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"1 epoch = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>)<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> batches"</span>)</span>
<span id="cb9-29">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> process_rank</span>
<span id="cb9-30"></span>
<span id="cb9-31">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__len__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>):</span>
<span id="cb9-32">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> (<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz)</span>
<span id="cb9-33"></span>
<span id="cb9-34">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> next_batch(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>):</span>
<span id="cb9-35">        buf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens[</span>
<span id="cb9-36">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos : <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos</span>
<span id="cb9-37">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz</span>
<span id="cb9-38">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb9-39">        ]</span>
<span id="cb9-40">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> buf[:<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].view(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz)</span>
<span id="cb9-41">        y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> buf[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:].view(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz)</span>
<span id="cb9-42">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Each process will process batch_sz x block_sz tokens in each</span></span>
<span id="cb9-43">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## iteration -&gt; with number_processes processes, total tokens processed</span></span>
<span id="cb9-44">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## in each iteration is batch_sz x block_sz x number_processes. In the</span></span>
<span id="cb9-45">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## case of one process, total tokens would be batch_sz x block_sz</span></span>
<span id="cb9-46">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> (</span>
<span id="cb9-47">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.number_processes</span>
<span id="cb9-48">        )</span>
<span id="cb9-49">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Similar to DataLoader's `drop_last` parameter, we drop the last</span></span>
<span id="cb9-50">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## batch if it's not a multiple of batch_sz x block_sz x number_processes</span></span>
<span id="cb9-51">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## if self.current_pos + (</span></span>
<span id="cb9-52">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##     self.batch_sz * self.block_sz * self.number_processes</span></span>
<span id="cb9-53">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## ) + self.number_processes &gt; len(self):</span></span>
<span id="cb9-54">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.tokens):</span>
<span id="cb9-55">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.current_pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb9-56">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x, y</span></code></pre></div>
</div>
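A quick sanity check of the striding scheme above: each process starts at offset `batch_sz * block_sz * process_rank` and advances by `batch_sz * block_sz * number_processes`, so the slices read by different processes interleave without overlap. The `positions` helper below is hypothetical, written only to illustrate the arithmetic (it ignores the `+1` target-token overlap and the exact wrap condition in `next_batch`).

```python
## Hypothetical helper mimicking DataLoaderLight's position arithmetic
def positions(process_rank, number_processes, batch_sz, block_sz, n_tokens, steps):
    stride = batch_sz * block_sz          # tokens consumed per batch per process
    pos = stride * process_rank           # each rank starts at its own offset
    out = []
    for _ in range(steps):
        out.append(pos)
        pos += stride * number_processes  # skip over the other ranks' slices
        if pos >= n_tokens:               # wrap around like the loader does
            pos = 0
    return out

## Two processes, batch_sz=2, block_sz=4 -> each batch covers 8 tokens
p0 = positions(0, 2, 2, 4, n_tokens=64, steps=4)
p1 = positions(1, 2, 2, 4, n_tokens=64, steps=4)
print(p0)  # [0, 16, 32, 48]
print(p1)  # [8, 24, 40, 56]
```

The two ranks read disjoint, interleaved windows of the token stream, which is exactly why `current_pos` must advance by the per-iteration token count of *all* processes, not just one.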
<div id="7913deb5" class="cell">
<div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">###########</span></span>
<span id="cb10-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Distributed Data Parallel</span></span>
<span id="cb10-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">###########</span></span>
<span id="cb10-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Distributed Data Parallel lets us run the same model (replica) on different GPUs,</span></span>
<span id="cb10-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## where each GPU would work on a different slice of data. After we do the backward</span></span>
<span id="cb10-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## pass, we average the gradients across all processes (GPUs) and synchronize all</span></span>
<span id="cb10-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## parameters across all devices. We use the all-reduce op to do this and communicate the</span></span>
<span id="cb10-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## updates with all processes.</span></span>
<span id="cb10-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Each process runs the same code from top to bottom, unaware that</span></span>
<span id="cb10-10"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## other processes are running the same thing on other devices</span></span>
<span id="cb10-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#</span></span>
<span id="cb10-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## torchrun command sets the following environment variables:</span></span>
<span id="cb10-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## RANK: Id of the process in the process group. It is an int in [0, WORLD_SIZE)</span></span>
<span id="cb10-14"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## LOCAL_RANK: In the case of multi-nodes, LOCAL_RANK is the id of</span></span>
<span id="cb10-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             the process in the same node. Example: If we have a node</span></span>
<span id="cb10-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             with 4 GPUs, the first process will have LOCAL_RANK=0</span></span>
<span id="cb10-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             but RANK of this process may not be 0 if we are running</span></span>
<span id="cb10-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             on multiple nodes.</span></span>
<span id="cb10-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             This is useful when we have multiple nodes and we want to</span></span>
<span id="cb10-20"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             run the processes on different GPUs in the same node.</span></span>
<span id="cb10-21"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             In this case, we can set the LOCAL_RANK to the GPU id in the</span></span>
<span id="cb10-22"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##             node.</span></span>
<span id="cb10-23"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## WORLD_SIZE: Total number of processes</span></span>
<span id="cb10-24">ddp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(os.getenv(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"RANK"</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!=</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Check if it is a ddp run</span></span>
<span id="cb10-25"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp:</span>
<span id="cb10-26">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## DDP requires CUDA so we need to set the device for each process</span></span>
<span id="cb10-27">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## so only one process can run per device</span></span>
<span id="cb10-28">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> torch.cuda.is_available(), <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"DDP requires CUDA"</span></span>
<span id="cb10-29">    init_process_group(backend<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"nccl"</span>)</span>
<span id="cb10-30">    ddp_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(os.getenv(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"RANK"</span>))</span>
<span id="cb10-31">    ddp_local_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(os.getenv(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"LOCAL_RANK"</span>))</span>
<span id="cb10-32">    ddp_world_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>(os.getenv(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"WORLD_SIZE"</span>))</span>
<span id="cb10-33">    device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"cuda:</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>ddp_local_rank<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb10-34">    torch.cuda.set_device(device)</span>
<span id="cb10-35">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Master process will do more things such as checkpointing and logging</span></span>
<span id="cb10-36">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## while other processes would assist in the computations.</span></span>
<span id="cb10-37">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## It always has RANK=0</span></span>
<span id="cb10-38">    master_process <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ddp_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb10-39"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb10-40">    ddp_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb10-41">    ddp_local_rank <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb10-42">    ddp_world_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb10-43">    master_process <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span></span>
<span id="cb10-44">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> torch.cuda.is_available():</span>
<span id="cb10-45">        device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span></span>
<span id="cb10-46">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">elif</span> torch.backends.mps.is_built():  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Apple Silicon</span></span>
<span id="cb10-47">        device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"mps"</span></span>
<span id="cb10-48">        torch.mps.manual_seed(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1337</span>)</span>
<span id="cb10-49">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb10-50">        device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cpu"</span></span>
<span id="cb10-51"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(device)</span>
<span id="cb10-52">torch.manual_seed(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1337</span>)</span>
<span id="cb10-53"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> torch.cuda.is_available():</span>
<span id="cb10-54">    torch.cuda.manual_seed(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1337</span>)</span>
<span id="cb10-55"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##########</span></span>
<span id="cb10-56"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Initialize model and optimizer</span></span>
<span id="cb10-57"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##########</span></span>
<span id="cb10-58"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## GPU hardware is organized in powers of 2 (e.g., tiling ops),</span></span>
<span id="cb10-59"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## so try to always have matrix dimensions be powers of 2 to improve use of:</span></span>
<span id="cb10-60"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## •    Tensor Cores</span></span>
<span id="cb10-61"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## •    Memory coalescing</span></span>
<span id="cb10-62"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## •    Shared memory bank alignment</span></span>
<span id="cb10-63"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## •    Warp scheduling</span></span>
<span id="cb10-64"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Here we change the vocab_sz by rounding it up to the closest</span></span>
<span id="cb10-65"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## number divisible by a high power of 2 (50304 = 128 x 393). This will</span></span>
<span id="cb10-66"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## increase space overhead but speed up computations</span></span>
<span id="cb10-67">model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> GPT2(GPTConfig(vocab_sz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50304</span>)).to(device)</span>
<span id="cb10-68"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Speed up model by building static graph that analyzes all ops</span></span>
<span id="cb10-69"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## and optimizes them such as fusing some of them to avoid unnecessary</span></span>
<span id="cb10-70"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## trips to memory</span></span>
<span id="cb10-71"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## model = torch.compile(model)</span></span>
<span id="cb10-72"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp:</span>
<span id="cb10-73">    model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DDP(model, device_ids<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[ddp_local_rank])</span>
<span id="cb10-74">raw_model <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model.module <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span> model</span>
<span id="cb10-75">max_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">3e-4</span></span>
<span id="cb10-76">min_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> max_lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span></span>
<span id="cb10-77">warmup_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span></span>
<span id="cb10-78">max_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span></span>
<span id="cb10-79">sched <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> combine_scheds(</span>
<span id="cb10-80">    [warmup_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> max_steps, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> (warmup_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> max_steps)],</span>
<span id="cb10-81">    [lin_sched(min_lr, max_lr), cos_sched(max_lr, min_lr)],</span>
<span id="cb10-82">)</span>
<span id="cb10-83">optimizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> raw_model.configure_optimizer(</span>
<span id="cb10-84">    weight_decay<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, lr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>max_lr, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device</span>
<span id="cb10-85">)</span>
<span id="cb10-86"></span>
<span id="cb10-87"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##########</span></span>
<span id="cb10-88"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Run training loop</span></span>
<span id="cb10-89"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#########</span></span>
<span id="cb10-90"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">NOTE</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: In order to run 0.5M (from GPT3 paper) tokens per fwd/bwd iteration,</span></span>
<span id="cb10-91"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## we need to use gradient accumulation because we can't fit it in almost</span></span>
<span id="cb10-92"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## any commodity GPU -&gt; we only do backward after we loop through ~0.5M tokens.</span></span>
<span id="cb10-93">total_batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">19</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## closest number to 0.5M</span></span>
<span id="cb10-94"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> total_batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> (GPTConfig.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> GPTConfig.block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ddp_world_size) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"total batch size must be divisible by micro batch_sz x block_sz x ddp_world_size"</span></span>
<span id="cb10-95">grad_accum_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> total_batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> (</span>
<span id="cb10-96">    GPTConfig.batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> GPTConfig.block_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ddp_world_size</span>
<span id="cb10-97">)</span>
<span id="cb10-98"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> master_process:</span>
<span id="cb10-99">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Total desired batch size: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>total_batch_sz<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb10-100">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Calculated gradient accumulation steps: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>grad_accum_steps<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb10-101"></span>
<span id="cb10-102">train_dl <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DataLoaderLight(</span>
<span id="cb10-103">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"tinyshakespeare.txt"</span>,</span>
<span id="cb10-104">    batch_sz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>GPTConfig.batch_sz,</span>
<span id="cb10-105">    block_sz<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>GPTConfig.block_sz,</span>
<span id="cb10-106">    process_rank<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ddp_rank,</span>
<span id="cb10-107">    number_processes<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ddp_world_size</span>
<span id="cb10-108">)</span>
<span id="cb10-109"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## PyTorch will use TensorFloat32 (TF32) for matmuls if available, else FP32.</span></span>
<span id="cb10-110"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## The weights are still stored in full FP32; only the matmul operands</span></span>
<span id="cb10-111"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## are rounded to TF32 precision (10 mantissa bits instead of 23)</span></span>
<span id="cb10-112"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## while the operations execute</span></span>
<span id="cb10-113">torch.set_float32_matmul_precision(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"high"</span>)</span>
<span id="cb10-114"></span>
<span id="cb10-115"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> step <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(max_steps):</span>
<span id="cb10-116">    start <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> time.time()</span>
<span id="cb10-117">    x, y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> train_dl.next_batch()</span>
<span id="cb10-118">    x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.to(device)</span>
<span id="cb10-119">    y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> y.to(device)</span>
<span id="cb10-120">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## code.interact(local=locals())</span></span>
<span id="cb10-121">    optimizer.zero_grad()</span>
<span id="cb10-122">    loss_accum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.0</span></span>
<span id="cb10-123">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> macro_step <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(grad_accum_steps):</span>
<span id="cb10-124">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> device <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"cuda"</span>:</span>
<span id="cb10-125">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Tensors that will be greatly affected by less precision such</span></span>
<span id="cb10-126">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## as loss, layernorm would still be in FP32 while others such</span></span>
<span id="cb10-127">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## as attention weights would be in BF16</span></span>
<span id="cb10-128">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">with</span> torch.autocast(device_type<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>device, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.bfloat16):</span>
<span id="cb10-129">                logits, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x, y)</span>
<span id="cb10-130">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb10-131">            logits, loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> model(x, y)</span>
<span id="cb10-132">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Accumulating gradients alone sums the objective, but we want</span></span>
<span id="cb10-133">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## the mean, so we scale each loss by 1/grad_accum_steps</span></span>
<span id="cb10-134">        loss <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/=</span> grad_accum_steps</span>
<span id="cb10-135">        loss_accum <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> loss.detach()</span>
<span id="cb10-136">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## To avoid syncing gradients between processes after every macro</span></span>
<span id="cb10-137">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## step, we disable the sync and only allow it once all gradient</span></span>
<span id="cb10-138">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## accumulation steps have finished in each process, i.e. on the</span></span>
<span id="cb10-139">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## last macro step</span></span>
<span id="cb10-140">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp:</span>
<span id="cb10-141">            model.require_backward_grad_sync <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (</span>
<span id="cb10-142">                macro_step <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> grad_accum_steps <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb10-143">            )</span>
<span id="cb10-144">        loss.backward()</span>
<span id="cb10-145">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Each process has its own loss_accum tensor; to report a single</span></span>
<span id="cb10-146">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## value, we average loss_accum across all processes with an</span></span>
<span id="cb10-147">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## all-reduce</span></span>
<span id="cb10-148">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp:</span>
<span id="cb10-149">        dist.all_reduce(loss_accum, op<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>dist.ReduceOp.AVG)</span>
<span id="cb10-150">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Clips gradients to a global norm. This avoids huge updates when</span></span>
<span id="cb10-151">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## some (bad) batch(es) produce a very high loss, which would lead</span></span>
<span id="cb10-152">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## to very high gradients</span></span>
<span id="cb10-153">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">NOTE</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: At the beginning of training it is normal to have high norms</span></span>
<span id="cb10-154">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## since the model is initialized randomly</span></span>
<span id="cb10-155">    norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.utils.clip_grad_norm_(model.parameters(), <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.0</span>)</span>
<span id="cb10-156"></span>
<span id="cb10-157">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## </span><span class="al" style="color: #AD0000;
background-color: null;
font-style: inherit;">TODO</span><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">: Use ParamScheduler from `cmn_ai`</span></span>
<span id="cb10-158">    lr <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sched(step <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> max_steps)</span>
<span id="cb10-159">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> pg <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> optimizer.param_groups:</span>
<span id="cb10-160">        pg[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"lr"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lr</span>
<span id="cb10-161">    optimizer.step()</span>
<span id="cb10-162">    end <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> time.time()</span>
<span id="cb10-163">    elapsed_time <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> end <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> start</span>
<span id="cb10-164">    token_per_sec <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (</span>
<span id="cb10-165">        GPTConfig.batch_sz</span>
<span id="cb10-166">        <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> GPTConfig.block_sz</span>
<span id="cb10-167">        <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> grad_accum_steps</span>
<span id="cb10-168">        <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ddp_world_size</span>
<span id="cb10-169">    ) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (elapsed_time)</span>
<span id="cb10-170">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(</span>
<span id="cb10-171">        <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"step </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>step<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, loss: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>loss<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>item()<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, lr </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>lr<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4e}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, norm: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>norm<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">, time: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>elapsed_time<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">s, tok/sec: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>token_per_sec<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.2f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb10-172">    )</span>
<span id="cb10-173"></span>
<span id="cb10-174"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> ddp:</span>
<span id="cb10-175">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Cleans up the distributed process group</span></span>
<span id="cb10-176">    destroy_process_group()</span></code></pre></div>
</div>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>We’ve come a long way in this journey—from implementing the core transformer architecture with multi-head attention and positional encodings, to building an efficient training pipeline complete with modern optimizations like flash attention, mixed precision training, and distributed parallelism. We’ve debugged exploding gradients, optimized memory usage, and watched our model evolve from producing random gibberish to generating coherent text. Along the way, we’ve gained deep insights into why each component exists and how they work together to create these remarkable language models.</p>
<p>I hope this deep dive has been as illuminating for you as it has been for me. Writing this implementation forced me to confront gaps in my own understanding and solidified concepts that previously felt abstract. There’s something uniquely satisfying about seeing your hand-built transformer successfully predict its first coherent sentence—a moment where theory truly becomes understanding.</p>
<p>If you’ve made it this far, thank you for joining me on this journey. I’d love to hear about your experiences implementing transformers, any bugs you’ve encountered, optimizations you’ve discovered, or questions this post might have raised. Feel free to reach out with feedback, corrections, or insights—the best part of sharing these implementations is learning from the community’s collective wisdom. Happy building!</p>
</section>
<section id="resources" class="level2">
<h2 class="anchored" data-anchor-id="resources">Resources</h2>
<ul>
<li><a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">GPT2: Language Models are Unsupervised Multitask Learners</a></li>
<li><a href="http://arxiv.org/abs/2005.14165">GPT3: Language Models are Few-Shot Learners</a></li>
<li><a href="https://arxiv.org/abs/1706.03762">Attention is All You Need</a></li>
<li><a href="https://arxiv.org/abs/1608.05859v3">Using the Output Embedding to Improve Language Models</a></li>
<li><a href="https://arxiv.org/abs/2205.14135">FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness</a></li>
</ul>


</section>

]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/GPT2-From-Scratch.html</guid>
  <pubDate>Wed, 10 Apr 2024 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/gpt2.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Byte Pair Encoding from Scratch</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/BPE-Tokenizer.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="why-tokenization-matters" class="level2">
<h2 class="anchored" data-anchor-id="why-tokenization-matters">Why Tokenization Matters</h2>
<p>When you type “unhappiness” into ChatGPT, the model doesn’t see the word “unhappiness.” It sees something like <code>["un", "happ", "iness"]</code> — three <strong>tokens</strong> that were chosen by an algorithm months before the model was even trained. That algorithm decided, based on statistics from a massive training corpus, that these three pieces are the right granularity. Not individual characters (too many tokens, too little meaning per token). Not whole words (too many unique words, no way to handle words never seen in training). Subwords — the sweet spot.</p>
<p>This isn’t a minor preprocessing detail. Tokenization defines <strong>what the model can see</strong>. Consider three strategies on the same sentence:</p>
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Strategy</th>
<th>“The cat sat unhappily” becomes</th>
<th>Tokens</th>
<th>Vocab Size</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Character-level</strong></td>
<td><code>["T","h","e"," ","c","a","t"," ","s","a","t"," ","u","n","h","a","p","p","i","l","y"]</code></td>
<td>21</td>
<td>~256</td>
</tr>
<tr class="even">
<td><strong>Word-level</strong></td>
<td><code>["The", "cat", "sat", "unhappily"]</code></td>
<td>4</td>
<td>100,000+</td>
</tr>
<tr class="odd">
<td><strong>Subword (BPE)</strong></td>
<td><code>["The", " cat", " sat", " un", "happ", "ily"]</code></td>
<td>6</td>
<td>~50,000</td>
</tr>
</tbody>
</table>
<p>With characters, a fixed context window of 2048 tokens covers ~400 words. With subwords, the same window covers ~1500 words — nearly 4× more context for the model to reason over. Word-level is compact but brittle: “unhappily” might never appear in training data, making it an <code>&lt;UNK&gt;</code> token the model is completely blind to. But “un”, “happ”, and “ily” almost certainly do appear — the model can compose meaning from pieces it knows.</p>
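<p>The context-window numbers above come from simple division (a sketch; the per-word averages are assumptions for typical English text, not measurements):</p>

```python
## Rough context-window arithmetic. Averages below are assumptions:
## ~5.1 characters per word including the space, ~1.35 BPE tokens per word.
context_window = 2048
chars_per_word = 5.1
tokens_per_word_bpe = 1.35

print(f"character-level: ~{context_window / chars_per_word:.0f} words")
print(f"BPE subwords:    ~{context_window / tokens_per_word_bpe:.0f} words")
```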
<p>The algorithm that learns these splits is <strong>Byte Pair Encoding (BPE)</strong> — originally a data compression technique (<a href="https://www.derczynski.com/papers/archive/BPE_Gage.pdf">Gage, 1994</a>), adapted for NLP by <a href="https://arxiv.org/abs/1508.07909">Sennrich et al.&nbsp;(2016)</a>, and now used in GPT-2, GPT-3/4, LLaMA, and most modern language models. In this post, we’ll understand how it works, implement it from scratch, and see how GPT-2 refined the basic algorithm for production.</p>
</section>
<section id="how-bpe-works" class="level2">
<h2 class="anchored" data-anchor-id="how-bpe-works">How BPE Works</h2>
<p>The core insight is simple: <strong>if two symbols frequently appear next to each other, they probably belong together.</strong> Merge them into a single token, then look for the next most frequent pair, and repeat. It’s exactly how you’d compress a text file — find repeated patterns and replace them with shorter symbols. Frequent patterns get absorbed into single tokens; rare patterns stay as smaller pieces.</p>
<p>Think of it like learning abbreviations. If you keep writing “machine learning” in your notes, you’d eventually start writing “ML.” Then if “ML model” keeps appearing, maybe you’d abbreviate that too. BPE does the same thing, but systematically and bottom-up — starting from the smallest units (bytes) and building up to subwords.</p>
<section id="seeing-it-in-action" class="level3">
<h3 class="anchored" data-anchor-id="seeing-it-in-action">Seeing It in Action</h3>
<p>Before formalizing the algorithm, let’s watch it work on a real example. Consider a tiny training corpus containing the words <code>"low lower lowest"</code>:</p>
<table class="table">
<thead>
<tr class="header">
<th>Step</th>
<th>Token Sequence</th>
<th>Most Frequent Pair</th>
<th>Frequency</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Start</td>
<td><code>l o w _ l o w e r _ l o w e s t</code></td>
<td>—</td>
<td>—</td>
</tr>
<tr class="even">
<td>Merge 1</td>
<td><code>lo w _ lo w e r _ lo w e s t</code></td>
<td><code>(l, o)</code> → <code>lo</code></td>
<td>3×</td>
</tr>
<tr class="odd">
<td>Merge 2</td>
<td><code>low _ low e r _ low e s t</code></td>
<td><code>(lo, w)</code> → <code>low</code></td>
<td>3×</td>
</tr>
<tr class="even">
<td>Merge 3</td>
<td><code>low _ lowe r _ lowe s t</code></td>
<td><code>(low, e)</code> → <code>lowe</code></td>
<td>2×</td>
</tr>
</tbody>
</table>
<p>BPE discovered that <code>l</code> and <code>o</code> always appear together, then that <code>lo</code> and <code>w</code> always appear together, building up <code>low</code> as a token — effectively learning the word stem. Then it found <code>lowe</code> as a shared prefix of “lower” and “lowest.” Without any linguistic rules, purely from frequency, BPE learned morphological structure.</p>
<p>Notice what happened in merge 2: the algorithm merged <code>lo</code> with <code>w</code>, where <code>lo</code> itself was created in merge 1. BPE builds tokens <strong>hierarchically</strong> — later merges compose earlier ones. We can visualize this as a tree:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph BT
    l["l (byte)"] --&gt; lo["lo (merge 1)"]
    o["o (byte)"] --&gt; lo
    lo --&gt; low["low (merge 2)"]
    w["w (byte)"] --&gt; low
    low --&gt; lowe["lowe (merge 3)"]
    e["e (byte)"] --&gt; lowe

    style l fill:#f0f0f0,stroke:#999
    style o fill:#f0f0f0,stroke:#999
    style w fill:#f0f0f0,stroke:#999
    style e fill:#f0f0f0,stroke:#999
    style lo fill:#d4e6f1,stroke:#2980b9
    style low fill:#aed6f1,stroke:#2471a3
    style lowe fill:#85c1e9,stroke:#1a5276
</pre>
</div>
<p></p><figcaption> BPE merges build tokens bottom-up. Each merge composes two existing tokens into a new one, forming a hierarchy from bytes to subwords.</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>Each level in the tree depends on the levels below it. This is why merge order matters during encoding — you can’t build <code>low</code> until <code>lo</code> exists.</p>
</section>
<section id="the-algorithm" class="level3">
<h3 class="anchored" data-anchor-id="the-algorithm">The Algorithm</h3>
<p>With the intuition in place, here’s the formal procedure:</p>
<ol type="1">
<li><p><strong>Initialize</strong>: Start with a base vocabulary of all 256 byte values (0–255). Every string can be represented as bytes, so this guarantees full coverage — no <code>&lt;UNK&gt;</code> tokens, ever.</p></li>
<li><p><strong>Count pairs</strong>: Scan the corpus and count every adjacent pair of tokens.</p></li>
<li><p><strong>Merge the most frequent pair</strong>: Create a new token for it and replace all occurrences in the corpus.</p></li>
<li><p><strong>Repeat</strong> steps 2–3 until you’ve done <code>vocab_size - 256</code> merges.</p></li>
</ol>
<p>The output is a <strong>merge table</strong>: an ordered list of pair → token mappings. This table <em>is</em> the tokenizer.</p>
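<p>The loop above is short enough to sketch end to end. Below is a minimal, standalone version (a sketch of the algorithm just described; the helper names are illustrative, and ties in pair frequency are broken by whichever pair <code>max</code> happens to see first):</p>

```python
from collections import Counter

def get_stats(ids):
    """Count adjacent token pairs."""
    return Counter(zip(ids, ids[1:]))

def merge(ids, pair, new_id):
    """Replace every occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, vocab_sz):
    """Learn `vocab_sz - 256` merges from raw text."""
    ids = list(text.encode("utf-8"))
    merges = {}                      # (left, right) -> new id, in learned order
    for new_id in range(256, vocab_sz):
        stats = get_stats(ids)
        if not stats:
            break                    # nothing left to merge
        pair = max(stats, key=stats.get)
        merges[pair] = new_id
        ids = merge(ids, pair, new_id)
    return merges, ids

## Reproduce the first two merges from the example corpus
merges, ids = train_bpe("low lower lowest", vocab_sz=256 + 2)
print(merges)   # → {(108, 111): 256, (256, 119): 257}, i.e. l+o then lo+w
```

<p>Note that on this toy corpus the <em>third</em> merge is a tie (<code>" low"</code> and <code>"lowe"</code> both appear twice), so which pair wins depends on tie-breaking; the table above shows the <code>(low, e)</code> choice.</p>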
</section>
<section id="training-vs.-encoding-a-subtle-difference" class="level3">
<h3 class="anchored" data-anchor-id="training-vs.-encoding-a-subtle-difference">Training vs.&nbsp;Encoding: A Subtle Difference</h3>
<p>There’s an important asymmetry between how BPE <em>learns</em> merges (training) and how it <em>applies</em> them to new text (encoding).</p>
<p>During <strong>training</strong>, we always merge the globally most <em>frequent</em> pair — that’s how we decide which merges to learn. But during <strong>encoding</strong>, we apply merges in the <em>order they were learned</em>, not by their frequency in the new text.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Encoding Replays Merges — It Doesn’t Re-learn Them
</div>
</div>
<div class="callout-body-container callout-body">
<p>A common misconception is that encoding finds the most frequent pair in the new text and merges it. It doesn’t. Encoding applies the <em>training-time</em> merges in their original order. The token <code>low</code> only exists after <code>lo</code> has been created (merge 1). If we tried to merge <code>(lo, w)</code> before creating <code>lo</code>, we’d never find the pair. In the implementation, this shows up as <code>min(stats, key=lambda p: self.merges.get(p, float("inf")))</code> — picking the pair with the <em>lowest merge index</em>, not the highest frequency.</p>
</div>
</div>
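<p>That <code>min</code> trick is easier to see in a small standalone sketch (assuming a <code>merges</code> dict mapping pair → merge index, as built during training; this is an illustration, not the post's full implementation):</p>

```python
def encode(text, merges):
    """Replay learned merges on new text, in training order."""
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # Pick the pair learned EARLIEST, not the most frequent one here
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break                      # no learned merge applies anymore
        new_id, out, i = merges[pair], [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids

## Merge table from the "low lower lowest" example: l+o, lo+w, low+e
merges = {(108, 111): 256, (256, 119): 257, (257, 101): 258}
print(encode("lowest", merges))   # → [258, 115, 116]  ("lowe", "s", "t")
```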
</section>
<section id="why-bytes-not-characters" class="level3">
<h3 class="anchored" data-anchor-id="why-bytes-not-characters">Why Bytes, Not Characters?</h3>
<p>Starting from bytes (0–255) rather than Unicode code points is a practical decision. Unicode has over 150,000 code points — that’s an impractically large base vocabulary. By working at the byte level, we start with just 256 symbols and can represent <em>any</em> string in <em>any</em> language or script. BPE merges then learn to compose bytes into characters, characters into subwords, and subwords into common words — all driven by frequency in the training data.</p>
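<p>The byte-level view is easy to inspect directly: UTF-8 reduces any string to values 0–255, with ASCII taking one byte per character and other scripts two or more.</p>

```python
## UTF-8 turns any string into bytes in 0-255; ASCII is one byte per
## character, while other scripts take several bytes per character.
print(list("low".encode("utf-8")))    # → [108, 111, 119]          (1 byte/char)
print(list("café".encode("utf-8")))   # → [99, 97, 102, 195, 169]  ('é' = 2 bytes)
print(list("猫".encode("utf-8")))      # → [231, 140, 171]          (1 char = 3 bytes)
```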
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Multilingual Tax
</div>
</div>
<div class="callout-body-container callout-body">
<p>Languages underrepresented in training data pay a tokenization tax. English “hello” might be one token, but the same greeting in a low-resource language could take 3–4 tokens because the byte sequences were never frequent enough to merge. This means the model burns more of its context window on the same content — a well-documented source of multilingual inefficiency (<a href="https://arxiv.org/abs/2311.09071">Petrov et al., 2023</a>). It also means the model takes more compute per word for these languages, making inference more expensive.</p>
</div>
</div>
</section>
<section id="vocabulary-size-a-key-hyperparameter" class="level3">
<h3 class="anchored" data-anchor-id="vocabulary-size-a-key-hyperparameter">Vocabulary Size: A Key Hyperparameter</h3>
<p>How many merges should we do? This is the vocabulary size, and it’s a meaningful trade-off:</p>
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Vocab Size</th>
<th>Tokens per Text</th>
<th>Embedding Table</th>
<th>Characteristics</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Small</strong> (~1k)</td>
<td>Many — close to character-level</td>
<td>Tiny</td>
<td>Better generalization on rare words, but sequences are long and training is slow</td>
</tr>
<tr class="even">
<td><strong>Medium</strong> (~32k–50k)</td>
<td>Moderate — good compression</td>
<td>Manageable</td>
<td>The sweet spot for most models (GPT-2: 50k, LLaMA: 32k)</td>
</tr>
<tr class="odd">
<td><strong>Large</strong> (~100k+)</td>
<td>Few — common phrases become single tokens</td>
<td>Very large</td>
<td>Risk of overfitting to training distribution; rare tokens get poorly trained embeddings</td>
</tr>
</tbody>
</table>
<p>Larger vocabularies mean each token carries more information, so sequences are shorter and the model sees more context per forward pass. But each token also needs an embedding vector, so the embedding table grows linearly. And tokens that appear rarely in training will have poorly learned embeddings — they’ve simply not been seen enough times.</p>
<p>Most modern models settle in the 32k–100k range. GPT-2 uses ~50k tokens. LLaMA uses 32k. GPT-4 reportedly uses ~100k. The right size depends on the training data, the target languages, and the compute budget.</p>
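<p>The embedding-table cost in that trade-off is simple arithmetic: parameters = vocabulary size × embedding dimension. A quick check using GPT-2 small's published sizes (50,257 tokens, 768-dimensional embeddings); the other vocabulary sizes are illustrative:</p>

```python
## Embedding parameters grow linearly with vocabulary size.
d_model = 768   # GPT-2 small's embedding dimension
for vocab, name in [(1_000, "tiny vocab"), (50_257, "GPT-2"), (100_000, "large vocab")]:
    params = vocab * d_model
    print(f"{name:11s}: {params / 1e6:5.1f}M embedding parameters")
```

<p>For GPT-2 small that works out to ~38.6M parameters for the embedding table alone, a sizeable fraction of its ~124M total.</p>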
</section>
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation">Implementation</h2>
<p>Let’s turn the algorithm into code. The <code>BPETokenizer</code> class below has four core methods, each mapping directly to a step we’ve discussed:</p>
<ul>
<li><strong><code>train</code></strong>: The learning loop — encode the corpus to bytes, then greedily merge the most frequent pair <code>vocab_size - 256</code> times. Each merge is recorded in <code>self.merges</code> as a <code>(pair) → index</code> mapping. This ordered dictionary <em>is</em> the tokenizer.</li>
<li><strong><code>encode</code></strong>: The encoding step — convert new text to bytes, then apply merges in <em>learned order</em> (earliest first, using the <code>min</code> trick we discussed). This is where training-order matters: we pick the pair with the smallest merge index, not the most frequent.</li>
<li><strong><code>decode</code></strong>: The inverse — look up each token ID in the vocabulary to get its byte sequence, concatenate, and decode back to a string.</li>
<li><strong><code>_get_stats</code> / <code>_merge</code></strong>: Helpers that count adjacent pairs and replace a specific pair with its merged token throughout a sequence.</li>
</ul>
<p>One implementation detail: <code>_build_vocab</code> relies on Python 3.7+ dictionary insertion order. Since merges are inserted chronologically, iterating <code>self.merges</code> replays them in order — each merged token is the byte-concatenation of its two parents, which must already exist in the vocabulary.</p>
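<p>Under that assumption, the rebuild can be sketched as follows (the function body here is illustrative, not copied from the class; only the insertion-order reliance comes from the text):</p>

```python
def build_vocab(merges):
    """Rebuild the id -> bytes mapping by replaying merges in insertion order."""
    vocab = {i: bytes([i]) for i in range(256)}     # base: the 256 raw bytes
    for (left, right), idx in merges.items():       # chronological (dicts keep order)
        vocab[idx] = vocab[left] + vocab[right]     # both parents already exist
    return vocab

## Merge table from the "low lower lowest" example: l+o, lo+w, low+e
merges = {(108, 111): 256, (256, 119): 257, (257, 101): 258}
vocab = build_vocab(merges)
print(vocab[258])   # → b'lowe'
```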
<div id="1517b8b6-f809-44a3-a368-2ecf7073c0d5" class="cell" data-execution_count="23">
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## | code-fold: true</span></span>
<span id="cb2-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> typing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> Iterable</span>
<span id="cb2-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> requests</span></code></pre></div>
</div>
<div id="a672ce35-a11c-4452-92fa-09b54198aa31" class="cell" data-execution_count="24">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> BPETokenizer:</span>
<span id="cb3-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Byte-pair encoder."""</span></span>
<span id="cb3-3"></span>
<span id="cb3-4">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, vocab_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb3-5">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""</span></span>
<span id="cb3-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        Args:</span></span>
<span id="cb3-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">            vocab_sz (int): Vocabulary size.</span></span>
<span id="cb3-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        """</span></span>
<span id="cb3-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vocab_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vocab_sz</span>
<span id="cb3-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb3-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb3-12"></span>
<span id="cb3-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> train(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, text: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>]):</span>
<span id="cb3-14">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Train Byte-pair encoder."""</span></span>
<span id="cb3-15">        ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(text.encode(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"utf-8"</span>))</span>
<span id="cb3-16">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> idx <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vocab_sz):</span>
<span id="cb3-17">            stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._get_stats(ids)</span>
<span id="cb3-18">            pair <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(stats, key<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>stats.get)</span>
<span id="cb3-19">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges[pair] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> idx</span>
<span id="cb3-20">            ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._merge(ids, pair, idx)</span>
<span id="cb3-21">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._build_vocab(ids)</span>
<span id="cb3-22"></span>
<span id="cb3-23">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> encode(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, text):</span>
<span id="cb3-24">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Encode string to bytes using vocabulary built during training."""</span></span>
<span id="cb3-25">        ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(text.encode(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"utf-8"</span>))</span>
<span id="cb3-26"></span>
<span id="cb3-27">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## If ids has fewer than two elements, there is nothing left to merge</span></span>
<span id="cb3-28">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">while</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(ids) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>:</span>
<span id="cb3-29">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## stats is used only for getting pairs next to each other</span></span>
<span id="cb3-30">            stats <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._get_stats(ids)</span>
<span id="cb3-31">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Because vocab (and merges) were built bottom-up, we must apply</span></span>
<span id="cb3-32">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## merges starting from the lowest merge index, since later merges</span></span>
<span id="cb3-33">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## depend on pairs merged earlier</span></span>
<span id="cb3-34">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Pairs never seen in training map to inf below and are skipped</span></span>
<span id="cb3-35">            pair <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(stats, key<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">lambda</span> p: <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges.get(p, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"inf"</span>)))</span>
<span id="cb3-36">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> pair <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges:</span>
<span id="cb3-37">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">break</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## No more pairs to merge</span></span>
<span id="cb3-38">            idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges[pair]</span>
<span id="cb3-39">            ids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>._merge(ids, pair, idx)</span>
<span id="cb3-40">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> ids</span>
<span id="cb3-41"></span>
<span id="cb3-42">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> decode(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, tokens: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>]):</span>
<span id="cb3-43">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Decode tokens into string using the vocabulary built during training."""</span></span>
<span id="cb3-44">        tokens <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">b""</span>.join(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.vocab[idx] <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> idx <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> tokens)</span>
<span id="cb3-45">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Replace invalid UTF-8 byte sequences with the Unicode replacement</span></span>
<span id="cb3-46">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## character; otherwise decoding would raise UnicodeDecodeError</span></span>
<span id="cb3-47">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> tokens.decode(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"utf-8"</span>, errors<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"replace"</span>)</span>
<span id="cb3-48"></span>
<span id="cb3-49">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> _get_stats(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, ids: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>]):</span>
<span id="cb3-50">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Get pair counts."""</span></span>
<span id="cb3-51">        counts <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {}</span>
<span id="cb3-52">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> pair <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">zip</span>(ids, ids[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>:]):</span>
<span id="cb3-53">            counts[pair] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> counts.get(pair, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb3-54">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> counts</span>
<span id="cb3-55"></span>
<span id="cb3-56">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> _merge(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, ids: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], pair: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>], idx: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>):</span>
<span id="cb3-57">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Merge pairs that match `pair` with new index `idx`."""</span></span>
<span id="cb3-58">        newids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb3-59">        i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb3-60">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">while</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(ids):</span>
<span id="cb3-61">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(ids) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">and</span> pair[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> ids[i] <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">and</span> pair[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> ids[i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]:</span>
<span id="cb3-62">                newids.append(idx)</span>
<span id="cb3-63">                i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb3-64">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb3-65">                newids.append(ids[i])</span>
<span id="cb3-66">                i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span></span>
<span id="cb3-67">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> newids</span>
<span id="cb3-68"></span>
<span id="cb3-69">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> _build_vocab(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, ids: Iterable[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span>]):</span>
<span id="cb3-70">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Build vocabulary from 0-255 bytes and merges."""</span></span>
<span id="cb3-71">        vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {idx: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">bytes</span>([idx]) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> idx <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">256</span>)}</span>
<span id="cb3-72">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We rely on dict items being returned in insertion order,</span></span>
<span id="cb3-73">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## which is guaranteed in Python 3.7+</span></span>
<span id="cb3-74">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> (p0, p1), idx <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.merges.items():</span>
<span id="cb3-75">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## This would be a concatenation of the bytes</span></span>
<span id="cb3-76">            vocab[idx] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> vocab[p0] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> vocab[p1]</span>
<span id="cb3-77">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> vocab</span></code></pre></div>
</div>
<div id="a26b9c11-cf65-45ca-a03a-7f2acef56f13" class="cell" data-execution_count="25">
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">text <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> requests.get(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"https://docs.python.org/3/library/stdtypes.html#bytes.decode"</span>).text</span></code></pre></div>
</div>
<div id="4368c372-793e-4ca7-be14-c8e62b9c9ca9" class="cell" data-execution_count="26">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">tokenizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> BPETokenizer(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">300</span>)</span></code></pre></div>
</div>
<div id="6a4e7562-ca4c-4e81-907c-70479b2448ce" class="cell" data-execution_count="27">
<div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">tokenizer.train(text)</span></code></pre></div>
</div>
<div id="3ba82c99-8c22-4c9e-be22-717b6be33ab3" class="cell" data-execution_count="28">
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">tokenizer.decode(tokenizer.encode(text)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> text</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="28">
<pre><code>True</code></pre>
</div>
</div>
</section>
<section id="from-vanilla-bpe-to-gpt-2s-tokenizer" class="level2">
<h2 class="anchored" data-anchor-id="from-vanilla-bpe-to-gpt-2s-tokenizer">From Vanilla BPE to GPT-2’s Tokenizer</h2>
<p>The implementation above is vanilla byte-level BPE — it works, but it has a practical problem. Because merges are purely frequency-driven, the algorithm doesn’t respect word boundaries. The word “play” might appear in the corpus as “play.”, “play!”, “play,”, and “play” — and BPE will learn separate tokens for each variant, wasting vocabulary slots on what is essentially the same word with different punctuation.</p>
<p>GPT-2 (<a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">Radford et al., 2019</a>) introduced a key refinement: <strong>pre-tokenization with a regex pattern</strong> that splits text into chunks <em>before</em> BPE runs. The regex prevents merges from crossing certain boundaries — letters can’t merge with digits, punctuation stays separate from words, and spaces attach to the <em>beginning</em> of words rather than the end.</p>
<p>The GPT-2 regex pattern:</p>
<pre><code>'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+</code></pre>
<p>This ensures that:</p>
<ul>
<li><strong>Contractions</strong> are split cleanly: “don’t” → <code>["don", "'t"]</code></li>
<li><strong>Spaces attach to the next word</strong>: “ hello” stays together, preserving word boundaries</li>
<li><strong>Punctuation stays isolated</strong>: “play!” → <code>["play", "!"]</code> instead of learning “play!” as one token</li>
<li><strong>Digits don’t merge with letters</strong>: “h3llo” → <code>["h", "3", "llo"]</code></li>
</ul>
<p>BPE then runs <em>within</em> each chunk independently. The result: a much cleaner vocabulary where tokens correspond to linguistically meaningful units rather than artifacts of adjacent punctuation.</p>
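<p>A minimal sketch of this pre-tokenization step. Python’s standard <code>re</code> module does not support <code>\p{L}</code>/<code>\p{N}</code> (the actual GPT-2 tokenizer uses the third-party <code>regex</code> package), so the letter and digit classes below are ASCII approximations:</p>

```python
import re

# ASCII approximation of the GPT-2 pre-tokenization pattern.
# [A-Za-z] and [0-9] stand in for \p{L} and \p{N}.
GPT2_SPLIT_ASCII = re.compile(
    r"'(?:[sdmt]|ll|ve|re)"   # contractions: 's 'd 'm 't 'll 've 're
    r"| ?[A-Za-z]+"           # optional leading space + letters
    r"| ?[0-9]+"              # optional leading space + digits
    r"| ?[^\sA-Za-z0-9]+"     # optional leading space + punctuation
    r"|\s+(?!\S)"             # whitespace not followed by a non-space
    r"|\s+"                   # remaining whitespace
)

print(GPT2_SPLIT_ASCII.findall("don't play! h3llo"))
# → ["don", "'t", " play", "!", " h", "3", "llo"]
```

Note how the chunks reproduce the bullet-point examples above: the contraction splits off, spaces attach to the following word, punctuation stays isolated, and digits separate from letters.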
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Try It Yourself
</div>
</div>
<div class="callout-body-container callout-body">
<p>Use <a href="https://tiktokenizer.vercel.app">Tiktokenizer</a> to see how GPT-2 and GPT-4 tokenize arbitrary text. Try pasting the same sentence in English and another language — you’ll immediately see the multilingual tokenization tax in action: the non-English version will use significantly more tokens for the same meaning.</p>
</div>
</div>
<p>This pre-tokenization pattern has been refined in later models. GPT-4 uses a <a href="https://github.com/openai/tiktoken">more sophisticated pattern</a> that handles apostrophes, numbers, and whitespace more carefully, and also limits the length of digit sequences to avoid learning overly specific number tokens. The core idea remains the same: constrain where merges can happen to produce a more useful vocabulary.</p>
</section>
<section id="references-resources" class="level2">
<h2 class="anchored" data-anchor-id="references-resources">References &amp; Resources</h2>
<ul>
<li><strong>Gage, P.</strong> (1994). <a href="https://www.derczynski.com/papers/archive/BPE_Gage.pdf">A New Algorithm for Data Compression</a>. <em>The C Users Journal</em>. The original BPE paper — a compression algorithm that found new life in NLP.</li>
<li><strong>Sennrich, R. et al.</strong> (2016). <a href="https://arxiv.org/abs/1508.07909">Neural Machine Translation of Rare Words with Subword Units</a>. <em>ACL 2016</em>. The paper that adapted BPE for NLP tokenization.</li>
<li><strong>Radford, A. et al.</strong> (2019). <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">Language Models are Unsupervised Multitask Learners</a>. The GPT-2 paper that introduced byte-level BPE with regex pre-tokenization.</li>
<li><strong>Karpathy, A.</strong> (2024). <a href="https://www.youtube.com/watch?v=zduSFxRajkE">Let’s build the GPT Tokenizer</a>. Excellent video walkthrough of building a BPE tokenizer from scratch.</li>
<li><a href="https://www.reedbeta.com/blog/programmers-intro-to-unicode/">A Programmer’s Introduction to Unicode</a> — why bytes vs.&nbsp;code points matters.</li>
<li><a href="https://utf8everywhere.org/">UTF-8 Everywhere</a> — the case for UTF-8 as the universal encoding.</li>
<li><a href="https://tiktokenizer.vercel.app">Tiktokenizer</a> — interactive web app to visualize how different tokenizers split text.</li>
</ul>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/BPE-Tokenizer.html</guid>
  <pubDate>Wed, 10 Apr 2024 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/bpe-tokenizer.jpg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>The RAG Optimization Playbook</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/improving-rag.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>RAG-based applications have so many components and moving parts that it can seem impossible to optimize them or even know where to start. Add to that how fast the field changes, which makes it hard to keep up. So I’ve gathered, over time, a few ideas for improving RAG-based applications from reading research papers and from implementations I’ve deployed in the past.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>The list will keep changing as I learn and implement new things.</p>
</div>
</div>
</section>
<section id="ideas" class="level2">
<h2 class="anchored" data-anchor-id="ideas">Ideas</h2>
<ul>
<li>Metadata filtering is key for good RAG apps</li>
<li>For an MVP, it is a good idea to combine bi-encoders with full-text search such as TF-IDF or BM25</li>
<li>ColBERT reranker is great and less sensitive to chunking</li>
<li>Keep at least 50 characters of overlap between chunks when splitting so context isn’t cut off</li>
<li>If you have data and compute, always fine-tune both encoders if you can</li>
<li>Use <code>sentence-transformers</code> to fine-tune embedding models
<ul>
<li>We typically use a triplet loss where each query has positive and negative examples. We want the negatives to be hard negatives -&gt; very close to the positives, so the model learns to differentiate between them</li>
</ul></li>
<li>Large/new LLMs are not necessarily good embedding models and may not be worth the latency. Models with ~1B parameters are enough for most cases</li>
<li>Challenges with embedding models:
<ul>
<li>May not transfer to your domain</li>
<li>Fixed vocabulary set when the model was trained</li>
<li>Because a chunk/doc is represented by one vector that combines all of its tokens, the output vector may dilute the meaning, especially for long texts -&gt; be careful about your chunking strategy</li>
</ul></li>
<li>Always start with a baseline such as BM25 (Best Match 25)</li>
<li>Build your own gold dataset and check its correlation with synthetic dataset generated from LLMs</li>
<li>Chunking beyond 256 tokens hurts high-precision search: it dilutes the vector representation, since many embedding models (e.g., BERT-based encoders) were not trained on long contexts</li>
<li>User feedback on the app is key to guiding where we should focus our efforts to improve it
<ul>
<li>Satisfaction ratings such as “How did we do today” or “Did we answer your question”</li>
</ul></li>
<li>Monitor <code>cosine</code> similarity between query and retrieved-doc embeddings, along with the reranking scores returned by the reranker (e.g., Cohere)</li>
<li>Cluster questions into topics using tools such as LDA or BERTopic, and focus on the largest topics (by count) that have the lowest mean cosine similarity and feedback scores</li>
<li>We have two kinds of topics:
<ul>
<li>Content topics: Topics for which we don’t have an inventory of documents</li>
<li>Capability topics: Topics the reader will never be able to answer unless we capture them in our docs and doc metadata and include them in the prompt. For example, “Who last updated the pricing document” is asking about the last modified date/person</li>
</ul></li>
<li>Build a classifier to classify questions in real time for better observability and a faster reaction to sudden changes in usage</li>
<li>Generate synthetic data (questions) for topics we’re not doing great job at and evaluate new improvements on the generated questions
<ul>
<li>This can be done by giving a decent LLM a random chunk from docs belonging to the topics we’re trying to improve and asking it to generate questions</li>
</ul></li>
<li>We can use an LLM to extract metadata about docs/objects</li>
<li>LanceDB is a good vector database for small-scale workloads</li>
<li>BM25 (full-text search) outperforms similarity search when questions are just looking up file names; otherwise its performance is comparable to a similarity-search baseline.
<ul>
<li>It is always helpful to include BM25</li>
</ul></li>
<li>We can do citation through prompting and attaching IDs to chunks</li>
<li>Fine-tuning the embedding model is key for domain-specific RAGs
<ul>
<li>With the recent increase in context window sizes, a chunk size of 800 with 30% overlap is recommended</li>
</ul></li>
</ul>
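<p>On combining bi-encoder and BM25 results: the two ranked lists still need to be merged somehow. A simple, score-free option is reciprocal rank fusion (RRF). Below is a minimal pure-Python sketch; the document IDs are invented for illustration:</p>

```python
def reciprocal_rank_fusion(rankings, k=60):
    """Fuse several ranked lists of doc ids into one list.

    Each doc scores sum(1 / (k + rank)) over the lists it appears in;
    k=60 is the constant used in the original RRF paper.
    """
    scores = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Hypothetical results from the two retrievers:
bm25_hits = ["doc3", "doc1", "doc7"]     # full-text search ranking
vector_hits = ["doc1", "doc5", "doc3"]   # bi-encoder similarity ranking

print(reciprocal_rank_fusion([bm25_hits, vector_hits]))
# → ['doc1', 'doc3', 'doc5', 'doc7']
```

Docs that appear high in both lists (here <code>doc1</code> and <code>doc3</code>) bubble to the top without having to normalize BM25 scores against cosine similarities.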
</section>
<section id="resources" class="level2">
<h2 class="anchored" data-anchor-id="resources">Resources</h2>
<ul>
<li><a href="https://arxiv.org/abs/2309.10621">Large language models can accurately predict searcher preferences</a></li>
<li><a href="https://arxiv.org/abs/2406.06519">UMBRELA: UMbrela is the (Open-Source Reproduction of the) Bing RELevance Assessor</a></li>
</ul>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/improving-rag.html</guid>
  <pubDate>Tue, 05 Mar 2024 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/rag.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Inside Python’s Modules and Packages: The Machinery Behind import</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/python/Modules-And-Packages.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>Python’s simplicity and versatility are largely attributed to its extensive ecosystem of modules and packages. These essential components enable developers to write clean, reusable, and efficient code, whether for simple scripts or complex applications.</p>
<p>This article aims to deepen our understanding of Python’s modules and packages and the machinery involved, helping us become more effective Python programmers. We will explore their structure and functionality, covering everything from the basics of importing to creating custom packages and managing dependencies. By unpacking the underlying machinery of how modules/packages get imported and what they really are, we’ll gain insights that will enhance our coding practices and project organization.</p>
<ul>
<li>Python has only one type of module object, regardless of the language the module was implemented in (C/Python/…)</li>
<li>A package provides a naming hierarchy to organize modules (analogous to a directory in a Unix file system):
<ul>
<li>All packages are modules but not all modules are packages</li>
<li>A package is a module that has a <code>__path__</code> attribute</li>
<li>A package can include subpackages (sub-directories)</li>
</ul></li>
<li>There are two types of packages:
<ul>
<li><strong>Regular packages</strong> are directories that have <code>__init__.py</code>. When importing a package/subpackage -&gt; implicitly executes all <code>__init__.py</code> files on the path and binds objects to names in the package’s namespace
<ul>
<li>When import machinery is looking for the package, it stops once it finds it</li>
</ul></li>
<li><strong>Namespace packages</strong> are directories that don’t have <code>__init__.py</code>
<ul>
<li>When the import machinery is looking for the package, it does not stop at the first match: a regular package may still exist on some other path in <code>sys.path</code>, so it keeps a record of all namespace packages found during the search. If it finds a regular package with that name -&gt; discard all the namespace packages and import the regular package. If it doesn’t -&gt; combine the paths of all the namespace packages it found into <code>namespace_path</code>, so that importing a subpackage or module checks all the paths in <code>namespace_path</code> (which is a list)</li>
<li>There can be multiple packages of the same name (under different directories) -&gt; They are all combined together and <code>namespace_path</code> holds the paths of all of them. Therefore, the same package name can refer to completely different modules in different directories</li>
<li>Python first scans the whole sys.path before deciding that the package is a namespace -&gt; If any name is found with <code>__init__.py</code> in it, it will give this priority and don’t continue.</li>
</ul></li>
</ul></li>
<li>When importing subpackages such as <code>foo.bar.baz</code>, Python first imports <code>foo</code>, then <code>foo.bar</code>, then <code>foo.bar.baz</code>
<ul>
<li>Each of these will be cached in <code>sys.modules</code></li>
</ul></li>
<li><code>__init__.py</code> makes a directory a regular Python package
<ul>
<li>We can use it to import useful objects from different modules/subpackages so they are available to users at the package top level</li>
<li>When importing an object, it has <code>__module__</code> attribute which determines the global environment for the object</li>
<li>We can define <code>__all__</code> in <code>__init__</code> as concatenation of all <code>__all__</code> in modules
<ul>
<li>Example: <code>__all__ = foo.__all__ + bar.__all__</code>, but we need to first import <code>foo</code> and <code>bar</code> (or something from them) so they are defined, e.g. <code>from .foo import *</code> or <code>from .foo import y</code></li>
</ul></li>
<li><code>__init__.py</code> can also be used to initialize state or monkeypatch other modules</li>
</ul></li>
<li>Each package has <code>__path__</code> attribute that helps when searching for subpackages. This will be given to path finder when loading subpackages
<ul>
<li>It is a list, similar to <code>sys.path</code>, so we can modify it, though that is not recommended</li>
<li>Example (using a package such as <code>email</code>; plain modules like <code>math</code> have no <code>__path__</code>): <code>import email; email.__path__.append("~/Documents")</code></li>
</ul></li>
<li>Relative imports are preferred to absolute imports inside packages, to avoid breakage if the package is renamed</li>
<li>Packages get loaded once even if we import them multiple times</li>
<li>We can in theory alias or replace a package in the cache like this:
<ul>
<li><code>sys.modules[new_name] = module_object</code></li>
</ul></li>
<li>If we use <code>python -m package.module</code>, it executes the module as the main program and <strong>relative imports</strong> work. Otherwise, relative imports won’t work.
<ul>
<li><code>m</code> stands for module</li>
</ul></li>
<li>The <code>__main__</code> module is a special case relative to Python’s import system. The <code>__main__</code> module is directly initialized at interpreter startup, much like <code>sys</code> and <code>builtins</code>. The manner in which <code>__main__</code> is initialized depends on the flags and other options with which the interpreter is invoked</li>
<li><code>__main__.py</code> designates main for a package/subpackage and also allows package directory to be executable -&gt; explicitly marks the entry point. Examples:
<ul>
<li><code>python package</code> would look for <code>__main__.py</code> to execute it</li>
<li><code>python -m package.subpackage</code> would look for <code>__main__.py</code> inside package/subpackage to execute</li>
<li><code>__package__</code> is set so the relative imports still work</li>
<li>A lot of programming tools use this to their benefit: <code>python -m profile script.py</code> or <code>python -m pdb script.py</code></li>
<li>NOTE THAT <code>__init__.py</code> files on the path will still be executed</li>
</ul></li>
<li>Depending on how <code>__main__</code> is initialized, <code>__main__.__spec__</code> gets set appropriately or to None.
<ul>
<li>When Python is started with the -m option, <code>__spec__</code> is set to the module spec of the corresponding module or package. <code>__spec__</code> is also populated when the <code>__main__</code> module is loaded as part of executing a directory, zipfile or other sys.path entry.</li>
<li>Otherwise, it will be set to None</li>
</ul></li>
</ul>
<p><img src="https://imaddabbura.github.io/posts/python/images/executable-submodules.png" width="400px"></p>
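<p>To make the namespace-package merging above concrete, here is a minimal, self-contained sketch; the package name <code>nspkg</code> and the temporary directory layout are made up for illustration:</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import os, sys, tempfile

## Two sys.path entries, each holding an "nspkg" directory WITHOUT __init__.py
root = tempfile.mkdtemp()
for part in ("a", "b"):
    d = os.path.join(root, part, "nspkg")
    os.makedirs(d)
    with open(os.path.join(d, "mod_%s.py" % part), "w") as f:
        f.write("value = %r\n" % part)
    sys.path.insert(0, os.path.join(root, part))

import nspkg                      ## no regular package found: namespace package
print(len(list(nspkg.__path__)))  ## both directories are merged
from nspkg import mod_a, mod_b    ## submodules live in different directories
print(mod_a.value, mod_b.value)</code></pre></div>
<p>Both directories end up in the combined <code>__path__</code>, so <code>mod_a</code> and <code>mod_b</code> are imported from different <code>sys.path</code> entries under a single package name.</p>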
</section>
<section id="sys.path" class="level2">
<h2 class="anchored" data-anchor-id="sys.path">Sys.path</h2>
<ul>
<li><code>importlib</code> has a rich API for interacting with the import system. It is preferred over <code>__import__()</code></li>
<li><code>__import__</code> only does module search and creation, without the name binding</li>
<li><code>import</code> does everything: module search, creation, and name binding. It calls <code>__import__</code> under the hood</li>
<li><code>.egg</code> files are just directories or .zip files with extra metadata for package managers</li>
<li><code>sys.path</code> is the last place Python looks when searching for a module/package we try to import. It traverses the list from start to end
<ul>
<li>It can contain directories, .zip files, and .egg files</li>
<li>First match wins</li>
<li>If the module can’t be found there, it cannot be imported</li>
</ul></li>
<li><code>sys.prefix</code> is where Python’s library is installed (<strong>os.py is the landmark</strong>) and <code>sys.exec_prefix</code> is where compiled binaries are stored (<strong>lib-dynload is the landmark</strong>)
<ul>
<li>With virtual environments, each environment has its own <code>sys.prefix</code></li>
<li><code>sys.path</code> is constructed from <code>sys.prefix</code>, <code>PYTHONPATH</code>, and <code>site.py</code>. Setting <code>PYTHONHOME</code> overrides <code>sys.prefix</code> and <code>sys.exec_prefix</code></li>
<li>Python looks for its libraries starting from the interpreter’s location and keeps going up toward the root of the file system, using <code>os.py</code> as a landmark</li>
<li><code>python -S</code> skips site.py</li>
<li><code>python -vv</code> to see what python tries to do with every statement</li>
<li>Setting PYTHONPATH to some directories will insert them into the beginning of sys.path. Example:
<ul>
<li><code>env PYTHONPATH="/Users/imad/Documents" python</code> runs Python with Documents inserted at the beginning of <code>sys.path</code></li>
</ul></li>
<li><code>site.py</code> appends the path to third-party libraries. This is where installed packages get stored. Example: <code>/usr/local/lib/python3.4/site-packages</code></li>
</ul></li>
<li>Python now has built-in virtual environments, created using the <code>venv</code> module
<ul>
<li><code>python -m venv env_name</code> will create a new environment called <em>env_name</em></li>
<li>This environment will include a few directories such as include, lib, site-packages, bin, and a pyvenv.cfg file</li>
<li>This new environment has no third-party libraries or any system-wide libraries such as those in /usr/local</li>
<li>All third-party libraries will be installed in the site-packages directory</li>
<li>The Python binary refers to the original Python installation from which the environment was created</li>
<li>We can use <code>source path_to_env_name/bin/activate</code> to activate the environment and <code>deactivate</code> to deactivate it. Finally, <code>rm -r path_to_env_name</code> removes it (or <code>poetry env remove</code> if we created it with <strong>poetry</strong>)</li>
</ul></li>
<li>Files with <code>.pth</code> extension in site-packages directory get added to the sys.path. We can list directories in those files that will be added to sys.path for any new instance of Python
<ul>
<li>Package managers and other third-party packages use this kind of hack to add paths to the sys.path</li>
</ul></li>
<li>The <code>sitecustomize</code> and <code>usercustomize</code> modules can also be used to add entries to <code>sys.path</code></li>
<li>The current working directory (or the script’s directory) is placed at the beginning of the path</li>
</ul>
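<p>A quick way to inspect the locations described above from a running interpreter (the output is environment-specific):</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import sys, site

print(sys.path[:3])            ## searched start to end; first match wins
print(sys.prefix)              ## base installation (per-venv when one is active)
print(sys.exec_prefix)         ## where compiled components live
print(site.getsitepackages())  ## where third-party packages are installed</code></pre></div>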
</section>
<section id="modules" class="level2">
<h2 class="anchored" data-anchor-id="modules">Modules</h2>
<ul>
<li>Modules are just objects of type ModuleType. They act like a dictionary holding references to the objects they contain: <code>module.__dict__</code>
<ul>
<li>When importing a module, it executes the module from top to bottom before returning to the caller</li>
<li>Module can be namespace, py file, execution environment for statements or container of global variables</li>
<li>We can set/delete attributes. <code>module.x = 10</code> is the same as <code>module.__dict__['x'] = 10</code></li>
<li>The dictionary has preset attributes such as <code>__path__</code>, <code>__loader__</code> …</li>
<li>Main attributes:
<ul>
<li><code>__name__</code> : ## Module name</li>
<li><code>__file__</code> : ## Associated source file (if any)</li>
<li><code>__doc__</code> : ## Doc string</li>
<li><code>__path__</code> : ## Package path. It is used to look for package subcomponents</li>
<li><code>__package__</code> : ## The module’s <code>__package__</code> attribute must be set. Its value must be a string, but it can be the same value as its <code>__name__</code>. When the module is a package, its <code>__package__</code> value should be set to its <code>__name__</code>. When the module is not a package, <code>__package__</code> should be set to the empty string for top-level modules, or for submodules, to the parent package’s name.</li>
<li><code>__spec__</code> : ## Module spec</li>
</ul></li>
</ul></li>
<li>The main difference between plain modules and packages is that packages have <code>__path__</code> defined and <code>__package__</code> set to their own <code>__name__</code></li>
<li><code>sys.modules</code> serves as a cache for all imported modules/packages
<ul>
<li>It is a dictionary so we can delete/set keys</li>
<li>If we delete a module, it will force Python to import it when we reimport it</li>
<li>If we set a module’s key to None, importing it results in an <code>ImportError</code></li>
</ul></li>
<li>Even if we import only one object from a module/package, the whole module/package is cached in <code>sys.modules</code>, but it is not bound in the global namespace</li>
<li>The module created during loading and passed to exec_module() may not be the one returned at the end of the import
<ul>
<li>This can happen if the imported module sets <code>sys.modules[__name__]</code> to some other module</li>
</ul></li>
<li>The module’s attributes are set after creation and before execution</li>
<li>Execution of the module is what populates the module’s <code>__dict__</code> (namespace of the module). This is done by the loader</li>
<li>When a submodule is loaded using any mechanism, a binding is placed in the parent module’s namespace to the submodule object. For example, if we have a package called spam that has a submodule foo and it imports any of its objects like <code>from .foo import x</code>, after importing spam, spam will have an attribute foo which is bound to the submodule -&gt; We can now use <code>spam.foo</code></li>
<li>Relative imports use leading dots. A single leading dot indicates a relative import, starting with the current package. Two or more leading dots indicate a relative import to the parent(s) of the current package, one level per dot after the first.
<ul>
<li>Relative imports can only use this form of import: <code>from &lt;&gt; import &lt;&gt;</code></li>
<li>It can’t use <code>import .&lt;&gt;</code> because this is not a valid expression</li>
</ul></li>
<li>Absolute imports have to start from the top level package and go downward to refer to the module:
<ul>
<li><code>from package.subpackage import module</code></li>
<li>Not recommended, because renaming the package means changing every import statement; relative imports are more robust and don’t depend on the package name</li>
</ul></li>
<li>Process when importing a module/package (after locating it):
<ol type="1">
<li>First checks if it is cached. If not, continue</li>
<li>It creates a ModuleType object with that name</li>
<li>Cache the module in sys.modules</li>
<li>Executes the source code inside the module (the source file, module name plus .py, is assigned to <code>__file__</code>)
<ul>
<li>In the case of a package/subpackage, <code>__file__</code> is its <code>__init__.py</code> file</li>
<li>It also executes all the <code>__init__.py</code> files on the path</li>
</ul></li>
<li>Binds a name to the module object</li>
</ol></li>
</ul>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys, types</span>
<span id="cb1-2"></span>
<span id="cb1-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> import_module(modname):</span>
<span id="cb1-4">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Check if it is in the cache first</span></span>
<span id="cb1-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> modname <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> sys.modules:</span>
<span id="cb1-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> sys.modules[modname]</span>
<span id="cb1-7">    </span>
<span id="cb1-8">    sourcepath <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> modname <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'.py'</span></span>
<span id="cb1-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">with</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">open</span>(sourcepath, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>) <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> f:</span>
<span id="cb1-10">        sourcecode <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> f.read()</span>
<span id="cb1-11">    mod <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> types.ModuleType(modname)</span>
<span id="cb1-12">    mod.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__file__</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sourcepath</span>
<span id="cb1-13">    </span>
<span id="cb1-14">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Cache the module</span></span>
<span id="cb1-15">    sys.modules[modname] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> mod</span>
<span id="cb1-16">    </span>
<span id="cb1-17">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Convert it to Python ByteCode</span></span>
<span id="cb1-18">    code <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">compile</span>(sourcecode, sourcepath, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'exec'</span>)</span>
<span id="cb1-19">    </span>
<span id="cb1-20">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Execute the code in the module from top to bottom</span></span>
<span id="cb1-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## And update the state (globals) in the module's dictionary</span></span>
<span id="cb1-22">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">exec</span>(code, mod.__dict__)</span>
<span id="cb1-23">    </span>
<span id="cb1-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## We return the cached one in case there is some patching inside the module</span></span>
<span id="cb1-25">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> sys.modules[modname]</span></code></pre></div>
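<p>The caching behavior described above can be checked directly; the name <code>blocked_mod</code> is a hypothetical module used only to trigger the error:</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import sys
import json, math

assert "json" in sys.modules          ## cached on first import
assert hasattr(json, "__path__")      ## json is a package
assert not hasattr(math, "__path__")  ## math is a plain module

del sys.modules["json"]               ## forces a fresh import next time
import json                           ## re-executes the module

sys.modules["blocked_mod"] = None     ## a None entry blocks the import
try:
    import blocked_mod
except ImportError as exc:
    print("import blocked:", exc)</code></pre></div>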
</section>
<section id="module-compilation" class="level2">
<h2 class="anchored" data-anchor-id="module-compilation">Module Compilation</h2>
<p><img src="https://imaddabbura.github.io/posts/python/images/module-compilation.png" width="400px"></p>
<ul>
<li>Python puts a lock on a module while importing it, so multiple threads don’t import the same module at the same time</li>
<li><code>__import__</code> is the machinery behind <code>import</code> statement</li>
<li>We can use <code>importlib.import_module(module)</code> which is the same thing as <code>__import__</code>
<ul>
<li><code>importlib.import_module('spam')</code> is the same as <code>import spam</code></li>
<li><code>importlib.import_module('.spam', __package__)</code> is the same as <code>from . import spam</code></li>
<li>We can track all imports as follows:</li>
</ul></li>
</ul>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> builtins</span>
<span id="cb2-2"></span>
<span id="cb2-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> imp_mod(modname, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>args, imp<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">__import__</span>):</span>
<span id="cb2-4">    <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"Importing </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>modname<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>)</span>
<span id="cb2-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> imp(modname, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>args)</span>
<span id="cb2-6"></span>
<span id="cb2-7">builtins.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">__import__</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> imp_mod</span></code></pre></div>
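<p>And a quick check that <code>importlib.import_module</code> behaves like the <code>import</code> statement (the relative-form comment uses a hypothetical sibling module):</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import importlib

math = importlib.import_module("math")   ## same as: import math
assert math.sqrt(9) == 3.0

## Inside a package, the relative form would be (sibling is hypothetical):
## sibling = importlib.import_module(".sibling", __package__)</code></pre></div>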
<ul>
<li><strong>Module Reloading</strong>:
<ul>
<li>It is not a good idea to reload a module because it creates zombies. Basically Python doesn’t try to clean up the dictionary from the old module, but instead exec() the new state of the module using the old <code>module.__dict__</code>. This means stuff from previous load may still exist and we end up having weird cases. This is how Python reloads a module:</li>
</ul>
<div class="sourceCode" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1">code <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">open</span>(module.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__file__</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'rb'</span>).<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">read</span>()</span>
<span id="cb3-2"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">exec</span>(code, module.__dict__)</span></code></pre></div>
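<p>The zombie effect is easy to reproduce; the module name <code>zombie_demo</code> and its temporary source file are made up for illustration:</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import importlib, os, sys, tempfile

sys.dont_write_bytecode = True        ## keep the demo free of stale .pyc files
d = tempfile.mkdtemp()
src = os.path.join(d, "zombie_demo.py")
with open(src, "w") as f:
    f.write("x = 1\ny = 2\n")
sys.path.insert(0, d)

import zombie_demo
with open(src, "w") as f:
    f.write("x = 10\n")               ## the new source no longer defines y

importlib.reload(zombie_demo)
print(zombie_demo.x)                  ## updated to 10
print(zombie_demo.y)                  ## zombie: 2 survives in the old __dict__</code></pre></div>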
<ul>
<li>Also, submodules that are loaded in the module/package don’t get reloaded; they keep their old version. Example: if the module has <code>import pandas as pd</code>, reloading the module doesn’t reload pandas.</li>
<li>Also, if instances were created from the old version of the module and we then reload it, new instances of the same class will refer to different code than the old ones. Even though they appear to come from the same class, old and new instances will have different types</li>
</ul></li>
<li><code>sys.path</code> is only a small part of the import machinery</li>
<li>Imports are actually controlled by <code>sys.meta_path</code>
<ul>
<li>It is a list of importers</li>
</ul>
<div class="sourceCode" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">[_frozen_importlib.BuiltinImporter,</span>
<span id="cb4-2">_frozen_importlib.FrozenImporter,</span>
<span id="cb4-3">_frozen_importlib_external.PathFinder,</span>
<span id="cb4-4"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>six._SixMetaPathImporter at <span class="bn" style="color: #AD0000;
background-color: null;
font-style: inherit;">0x10c8769b0</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span>,</span>
<span id="cb4-5"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>pkg_resources.extern.VendorImporter at <span class="bn" style="color: #AD0000;
background-color: null;
font-style: inherit;">0x10dbf9300</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span>]</span></code></pre></div>
<ul>
<li>Python’s default sys.meta_path has three meta path finders: one that knows how to import built-in modules, one that knows how to import frozen modules, and one that knows how to import modules from an import path</li>
<li>For every import statement, Python walks sys.meta_path from start to end, asking each finder whether it knows how to import the module</li>
</ul></li>
</ul>
<div class="sourceCode" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys</span>
<span id="cb5-2"></span>
<span id="cb5-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> find_spec(modname):</span>
<span id="cb5-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> imp <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> sys.meta_path:</span>
<span id="cb5-5">        spec <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> imp.find_spec(modname)</span>
<span id="cb5-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> spec:</span>
<span id="cb5-7">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> spec</span>
<span id="cb5-8">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span></code></pre></div>
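<p>The same search is exposed through <code>importlib.util.find_spec</code>, which returns <code>None</code> for names no finder recognizes:</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import importlib.util

spec = importlib.util.find_spec("json")
print(spec.name, spec.origin)                       ## full name, source file
assert spec.submodule_search_locations is not None  ## json is a package

## An unknown top-level name simply yields None
assert importlib.util.find_spec("no_such_module_xyz") is None</code></pre></div>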
<ul>
<li><p>ModuleSpec of a module is its metadata that the loader uses to load it. We can also use <code>importlib.util.find_spec()</code> to get the module spec of any loaded package. If the package/module is not found -&gt; returns None. Example of pandas module spec:</p>
<div class="sourceCode" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">  ModuleSpec(name<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'pandas'</span>, loader<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=&lt;</span>_frozen_importlib_external.SourceFileLoader <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">object</span> at <span class="bn" style="color: #AD0000;
background-color: null;
font-style: inherit;">0x10e609f90</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span>, origin<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas/__init__.py'</span>, submodule_search_locations<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas'</span>])</span></code></pre></div>
<ul>
<li>Module Spec main info:
<ul>
<li>spec.name : ## Full module name</li>
<li>spec.parent : ## Enclosing package</li>
<li>spec.submodule_search_locations : ## Package <strong>path</strong></li>
<li>spec.has_location : ## Has external location</li>
<li>spec.origin : ## Source file location</li>
<li>spec.cached : ## Cached location</li>
<li>spec.loader : ## Loader object</li>
</ul></li>
<li>We can use the <code>loader</code> from the module spec to get the source code without importing it. Loaders are what actually create the imported module:</li>
</ul>
<div class="sourceCode" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">module <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.loader.create_module(spec)</span>
<span id="cb7-2"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> module:</span>
<span id="cb7-3">    module <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> types.ModuleType(spec.name)</span>
<span id="cb7-4">    module.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__file__</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.origin</span>
<span id="cb7-5">    module.__loader__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.loader</span>
<span id="cb7-6">    module.__package__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.parent</span>
<span id="cb7-7">    module.__path__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.submodule_search_locations</span>
<span id="cb7-8">    module.__spec__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec</span></code></pre></div>
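<p>The supported shortcut for the manual sketch above is <code>importlib.util.module_from_spec</code>, which creates the module without executing it:</p>
<div class="sourceCode" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python">import importlib.util, sys

spec = importlib.util.find_spec("json")
module = importlib.util.module_from_spec(spec)  ## created, not yet executed
sys.modules[spec.name] = module                 ## cache before executing
spec.loader.exec_module(module)                 ## populates module.__dict__
print(module.dumps({"ok": True}))</code></pre></div>
<p>Caching before <code>exec_module</code> matters: imports inside the module can then see the partially initialized module, just as the real machinery allows.</p>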
<ul>
<li>We can create a module from a spec with <code>importlib.util.module_from_spec</code>. This DOES NOT LOAD THE MODULE, it only creates it. To load it, the module must be cached with <code>sys.modules[spec.name] = module</code> and then executed with <code>spec.loader.exec_module(module)</code>. <code>exec_module</code> is what populates the module’s <code>__dict__</code>.</li>
</ul></li>
<li><p>We can execute modules lazily on first access. Implementation example:</p></li>
</ul>
<div class="sourceCode" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys, types</span>
<span id="cb8-2"></span>
<span id="cb8-3"></span>
<span id="cb8-4"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> _Module(types.ModuleType):</span>
<span id="cb8-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">pass</span></span>
<span id="cb8-6"></span>
<span id="cb8-7"></span>
<span id="cb8-8"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> _LazyModule(_Module):</span>
<span id="cb8-9"></span>
<span id="cb8-10">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, spec):</span>
<span id="cb8-11">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(spec.name) </span>
<span id="cb8-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__file__</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.origin</span>
<span id="cb8-13">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__package__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.parent </span>
<span id="cb8-14">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__loader__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.loader</span>
<span id="cb8-15">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__path__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec.submodule_search_locations </span>
<span id="cb8-16">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__spec__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> spec</span>
<span id="cb8-17"></span>
<span id="cb8-18">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__getattr__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, name):</span>
<span id="cb8-19">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__class__ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> _Module</span>
<span id="cb8-20">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.__spec__.loader.exec_module(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>)</span>
<span id="cb8-21">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">assert</span> sys.modules[<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span></span>
<span id="cb8-22">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, name)</span></code></pre></div>
<div class="sourceCode" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> importlib.util, sys</span>
<span id="cb9-2"></span>
<span id="cb9-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> lazy_import(name):</span>
<span id="cb9-4">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## If already loaded, return the module</span></span>
<span id="cb9-5">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> sys.modules:</span>
<span id="cb9-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> sys.modules[name]</span>
<span id="cb9-7">    </span>
<span id="cb9-8">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Not loaded. Find the spec</span></span>
<span id="cb9-9">    spec <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> importlib.util.find_spec(name)</span>
<span id="cb9-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> spec:</span>
<span id="cb9-11">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">ImportError</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'No module </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>name<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:</span>r<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb9-12">    </span>
<span id="cb9-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Check for compatibility</span></span>
<span id="cb9-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">hasattr</span>(spec.loader, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'exec_module'</span>):</span>
<span id="cb9-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">ImportError</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Not supported'</span>)</span>
<span id="cb9-16"></span>
<span id="cb9-17">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Perform the lazy import</span></span>
<span id="cb9-18">    module <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sys.modules[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> _LazyModule(spec)</span>
<span id="cb9-19">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> module</span></code></pre></div>
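<p>The standard library offers a supported version of this pattern: <code>importlib.util.LazyLoader</code> wraps a real loader so that module execution is deferred until the first attribute access. A minimal sketch (using <code>json</code> purely as a demo module):</p>

```python
import importlib.util
import sys

def lazy_import(name):
    """Return a module whose body is not executed until first attribute access."""
    if name in sys.modules:
        return sys.modules[name]
    spec = importlib.util.find_spec(name)
    if spec is None:
        raise ImportError(f'No module {name!r}')
    # LazyLoader wraps the real loader and postpones exec_module()
    spec.loader = importlib.util.LazyLoader(spec.loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)  # registers the lazy module; body still not run
    return module

j = lazy_import("json")
print(j.dumps({"a": 1}))  # attribute access here triggers the actual import
```

<p>This achieves the same effect as the hand-rolled <code>_LazyModule</code> above, without subclassing <code>ModuleType</code> ourselves.</p>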
<ul>
<li>Therefore, module creation and loading have been decoupled in recent versions of Python</li>
<li>We can insert a finder into <code>sys.meta_path</code> that changes the behavior of imports
<ul>
<li>If it is placed at the beginning, it supersedes all other finders, and we can do crazy things</li>
</ul>
<div class="sourceCode" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys</span>
<span id="cb10-2"></span>
<span id="cb10-3"></span>
<span id="cb10-4"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> Watcher(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">object</span>):</span>
<span id="cb10-5"></span>
<span id="cb10-6">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@classmethod</span></span>
<span id="cb10-7">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> find_spec(cls, name, path, target<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb10-8">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Importing'</span>, name, path, target)</span>
<span id="cb10-9">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb10-10"></span>
<span id="cb10-11"></span>
<span id="cb10-12">sys.meta_path.insert(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, Watcher)</span></code></pre></div></li>
<li>We can also use this idea to add logic such as auto-installing missing packages with pip. We insert the installer at the end of <code>sys.meta_path</code> so it only runs once every other finder has failed</li>
</ul>
<div class="sourceCode" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys</span>
<span id="cb11-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> subprocess</span>
<span id="cb11-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> importlib.util</span>
<span id="cb11-4"></span>
<span id="cb11-5"></span>
<span id="cb11-6"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> AutoInstall(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">object</span>):</span>
<span id="cb11-7">    _loaded <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>()</span>
<span id="cb11-8"></span>
<span id="cb11-9">    <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@classmethod</span></span>
<span id="cb11-10">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> find_spec(cls, name, path, target<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>):</span>
<span id="cb11-11">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> path <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">is</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">and</span> name <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">not</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> cls._loaded: </span>
<span id="cb11-12">            cls._loaded.add(name)</span>
<span id="cb11-13">            <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Installing"</span>, name)</span>
<span id="cb11-14">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">try</span>:</span>
<span id="cb11-15">                out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> subprocess.check_output(</span>
<span id="cb11-16">                          [sys.executable, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'-m'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'pip'</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'install'</span>, name])</span>
<span id="cb11-17">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> importlib.util.find_spec(name) </span>
<span id="cb11-18">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">except</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">Exception</span> <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> e:</span>
<span id="cb11-19">                <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Failed"</span>)</span>
<span id="cb11-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span></span>
<span id="cb11-21">sys.meta_path.append(AutoInstall)</span></code></pre></div>
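<p>Before installing anything, it is worth seeing a meta-path finder in action in isolation. The self-contained sketch below (class and attribute names are illustrative) records every import attempt and then defers to the normal machinery by returning <code>None</code>:</p>

```python
import sys

class ImportLogger:
    """Meta-path finder that records import attempts without handling them."""
    seen = []

    @classmethod
    def find_spec(cls, name, path=None, target=None):
        cls.seen.append(name)
        return None  # defer to the remaining finders on sys.meta_path

sys.meta_path.insert(0, ImportLogger)
try:
    sys.modules.pop("colorsys", None)  # force a real import, not a cache hit
    import colorsys
finally:
    sys.meta_path.remove(ImportLogger)  # always undo global state changes

print("colorsys" in ImportLogger.seen)  # True
```

<p>Note the cache-bust: if the module is already in <code>sys.modules</code>, the import statement never consults <code>sys.meta_path</code> at all.</p>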
<ul>
<li><p>We can also import packages that are not installed on the system from other sources, such as Redis</p></li>
<li><p><code>sys.path_hooks</code> is a list of callables that build the finder responsible for locating and loading modules/packages from a given kind of path entry</p>
<ul>
<li>Each entry in <code>sys.path</code> is tested against the list of <strong>path hooks</strong> to associate a module finder with that path entry</li>
<li>Path finders are used to locate modules and return a module spec along with a loader</li>
<li>Path finders get cached in <code>sys.path_importer_cache</code></li>
</ul></li>
<li><p>Finders (both meta-path finders and path-entry finders) expose <code>find_spec()</code>, which returns the <strong>spec</strong> of a module if they know how to find it. Otherwise, they return <code>None</code></p></li>
<li><p>What happens during import:</p></li>
</ul>
<div class="sourceCode" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1">modname <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'somemodulename'</span></span>
<span id="cb12-2"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> entry <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> sys.path:</span>
<span id="cb12-3">    finder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sys.path_importer_cache[entry]</span>
<span id="cb12-4">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> finder:</span>
<span id="cb12-5">        spec <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> finder.find_spec(modname)</span>
<span id="cb12-6">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> spec:</span>
<span id="cb12-7">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">break</span></span>
<span id="cb12-8"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb12-9">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">raise</span> <span class="pp" style="color: #AD0000;
background-color: null;
font-style: inherit;">ImportError</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'No such module'</span>)</span>
<span id="cb12-10">...</span>
<span id="cb12-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Load module from the spec</span></span>
<span id="cb12-12">...</span></code></pre></div>
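<p>The pseudocode above can be exercised for real through <code>importlib.machinery.PathFinder</code>, the <code>sys.meta_path</code> entry that implements path-based finding (it walks <code>sys.path</code> and consults <code>sys.path_importer_cache</code> internally):</p>

```python
from importlib.machinery import PathFinder

# PathFinder searches sys.path for us and returns a ModuleSpec on success
spec = PathFinder.find_spec("json")
print(spec.name, spec.origin)

# ...and returns None when no path entry knows the module
print(PathFinder.find_spec("no_such_module_xyz"))  # None
```
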
</section>
<section id="experiments" class="level2">
<h2 class="anchored" data-anchor-id="experiments">Experiments</h2>
<div id="4efe21f4-70e0-4876-a98d-51fbed950551" class="cell" data-execution_count="2">
<div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1">sys.path.append(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"/Users/imad/Desktop/"</span>)</span></code></pre></div>
</div>
<div id="3a689c5c-b33a-44c8-a65d-4f109f396781" class="cell" data-execution_count="3">
<div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pck.mod <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> X</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>pck.mod</code></pre>
</div>
</div>
<div id="8c327a3a-78f7-444f-aee2-df36dee0bb7d" class="cell" data-execution_count="4">
<div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1">X</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="4">
<pre><code>100</code></pre>
</div>
</div>
<div id="420680d8-7c17-4db5-8b4c-dd234f31d620" class="cell" data-execution_count="5">
<div class="sourceCode cell-code" id="cb18" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pck.test <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> X</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>pck.test</code></pre>
</div>
</div>
<div id="099cd685-593f-474e-b20d-936c57883a19" class="cell" data-execution_count="6">
<div class="sourceCode cell-code" id="cb20" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb20-1">sys.modules[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"pck"</span>].__path__</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="6">
<pre><code>_NamespacePath(['/Users/imad/Documents/python-materials/modules-and-packages/pck', '/Users/imad/Documents/python-materials/modules-and-packages/pck', '/Users/imad/Desktop/pck'])</code></pre>
</div>
</div>
<div id="26ae3e7c-313a-4388-bc49-9fc9e2d7a162" class="cell" data-execution_count="3">
<div class="sourceCode cell-code" id="cb22" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1">foo.__package__, foo.__path__</span></code></pre></div>
<div class="cell-output cell-output-error">
<pre><code>AttributeError: module 'package.foo' has no attribute '__path__'</code></pre>
</div>
</div>
<div id="f204330c-1eff-4da4-a9fd-167e7a74c177" class="cell" data-execution_count="18">
<div class="sourceCode cell-code" id="cb24" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb24-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">globals</span>()[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"foo"</span>]</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="18">
<pre><code>&lt;module 'package.foo' from '/Users/imad/Documents/python-materials/modules-and-packages/package/foo.py'&gt;</code></pre>
</div>
</div>
<div id="596a075a-f58c-42d7-9d54-f1c0bbcef01b" class="cell" data-execution_count="21">
<div class="sourceCode cell-code" id="cb26" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> f():</span>
<span id="cb26-2">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">pass</span></span></code></pre></div>
</div>
<div id="ce368771-aeb3-420b-847d-ef6418735c45" class="cell" data-execution_count="23">
<div class="sourceCode cell-code" id="cb27" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb27-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> read_csv</span></code></pre></div>
</div>
<div id="38e03cfd-d29b-4564-bc74-e5b0d9bcaa2c" class="cell" data-execution_count="2">
<div class="sourceCode cell-code" id="cb28" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1">sys.path_hooks</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="2">
<pre><code>[zipimport.zipimporter,
 &lt;function _frozen_importlib_external.FileFinder.path_hook.&lt;locals&gt;.path_hook_for_FileFinder(path)&gt;]</code></pre>
</div>
</div>
<div id="86da7a73-13f4-4c85-83c6-7c209370a0f0" class="cell" data-execution_count="17">
<div class="sourceCode cell-code" id="cb30" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(sys.path_importer_cache.keys())[:<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>]</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="17">
<pre><code>['/Users/imad/anaconda3/envs/python-exp/lib/python310.zip',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/encodings',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/importlib',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/lib-dynload',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/PyYAML-6.0-py3.10-macosx-10.9-x86_64.egg',
 '/Users/imad/Documents/python-materials/modules-and-packages',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/ipykernel',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/json']</code></pre>
</div>
</div>
<div id="e4ece638-93da-4b93-a594-52d3984b93c1" class="cell" data-execution_count="10">
<div class="sourceCode cell-code" id="cb32" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> importlib.util <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> find_spec</span></code></pre></div>
</div>
<div id="9e70eb1d-aee1-4c99-8222-0f1a114b78ba" class="cell" data-execution_count="14">
<div class="sourceCode cell-code" id="cb33" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb33-1">m <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> find_spec(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"mod"</span>)</span>
<span id="cb33-2">m.loader.get_source(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"mod"</span>)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="14">
<pre><code>'y = 200\nprint(y)\n\nclass A:\n    print("A")\n'</code></pre>
</div>
</div>
<div id="83c9c028-470a-49b3-a6cc-d9f967425469" class="cell" data-execution_count="8">
<div class="sourceCode cell-code" id="cb35" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb35-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> sys</span>
<span id="cb35-2">sys.meta_path</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="8">
<pre><code>[_frozen_importlib.BuiltinImporter,
 _frozen_importlib.FrozenImporter,
 _frozen_importlib_external.PathFinder,
 &lt;six._SixMetaPathImporter at 0x10c8769b0&gt;,
 &lt;pkg_resources.extern.VendorImporter at 0x10dbf9300&gt;]</code></pre>
</div>
</div>
<div id="89c67452-6262-4f27-8ec2-df6628331245" class="cell" data-execution_count="1">
<div class="sourceCode cell-code" id="cb37" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb37-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> mod</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>200
A</code></pre>
</div>
</div>
<div id="f02003c4-f556-4db3-92d3-949103c9f7a5" class="cell" data-execution_count="2">
<div class="sourceCode cell-code" id="cb39" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb39-1">a <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> mod.A()</span></code></pre></div>
</div>
<div id="bde71e6e-5ca3-49e0-8c90-6d268ca6ec55" class="cell" data-execution_count="4">
<div class="sourceCode cell-code" id="cb40" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> importlib <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">reload</span></span></code></pre></div>
</div>
<div id="35796286-737b-4a22-a13c-d02fece9e169" class="cell" data-execution_count="5">
<div class="sourceCode cell-code" id="cb41" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb41-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">reload</span>(mod)</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>200
A</code></pre>
</div>
<div class="cell-output cell-output-display" data-execution_count="5">
<pre><code>&lt;module 'mod' from '/Users/imad/Documents/python-materials/modules-and-packages/mod.py'&gt;</code></pre>
</div>
</div>
<div id="41431941-be17-499a-b2af-699511a131c8" class="cell" data-execution_count="3">
<div class="sourceCode cell-code" id="cb44" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb44-1">b <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> mod.A()</span></code></pre></div>
</div>
<div id="de332317-b541-45cf-9da2-991928a725eb" class="cell" data-execution_count="4">
<div class="sourceCode cell-code" id="cb45" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb45-1">a.__class__, b.__class__, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span>(a) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span>(b)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="4">
<pre><code>(mod.A, mod.A, True)</code></pre>
</div>
</div>
<div id="345de28d-d954-4a18-b2c4-580bf0a3b6a5" class="cell" data-execution_count="5">
<div class="sourceCode cell-code" id="cb47" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb47-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> importlib.util <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> find_spec</span></code></pre></div>
</div>
<div id="291d72dd-a4eb-45ca-83d5-dddc6856ddc3" class="cell" data-execution_count="7">
<div class="sourceCode cell-code" id="cb48" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb48-1">find_spec(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"sys"</span>)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="7">
<pre><code>ModuleSpec(name='sys', loader=&lt;class '_frozen_importlib.BuiltinImporter'&gt;, origin='built-in')</code></pre>
</div>
</div>
<div id="08e6b24b-7546-4545-a621-73a9a63fa2ec" class="cell" data-execution_count="6">
<div class="sourceCode cell-code" id="cb50" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb50-1">find_spec(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"pandas"</span>)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="6">
<pre><code>ModuleSpec(name='pandas', loader=&lt;_frozen_importlib_external.SourceFileLoader object at 0x10e609f90&gt;, origin='/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas/__init__.py', submodule_search_locations=['/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas'])</code></pre>
</div>
</div>
<div id="4c91a469-0147-495b-8e41-92232fbede57" class="cell" data-execution_count="21">
<div class="sourceCode cell-code" id="cb52" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb52-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> importlib</span></code></pre></div>
</div>
<div id="22a22d5e-b7bd-431d-99b4-cdcea27de573" class="cell">
<div class="sourceCode cell-code" id="cb53" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb53-1">pd <span class="op" style="color: #5E5E5E; background-color: null; font-style: inherit;">=</span> importlib.import_module(<span class="st" style="color: #20794D; background-color: null; font-style: inherit;">"pandas"</span>)</span></code></pre></div>
</div>
<div id="76d1435d-6edd-4834-967c-98800b78b12c" class="cell" data-execution_count="11">
<div class="sourceCode cell-code" id="cb54" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb54-1">pd.__path__, pd.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__name__</span>, pd.__package__, pd.<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">__file__</span>, pd.__doc__</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="11">
<pre><code>(['/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas'],
 'pandas',
 'pandas',
 '/Users/imad/anaconda3/envs/python-exp/lib/python3.10/site-packages/pandas/__init__.py',
 '\npandas - a powerful data analysis and manipulation library for Python\n=====================================================================\n\n**pandas** is a Python package providing fast, flexible, and expressive data\nstructures designed to make working with "relational" or "labeled" data both\neasy and intuitive. It aims to be the fundamental high-level building block for\ndoing practical, **real world** data analysis in Python. Additionally, it has\nthe broader goal of becoming **the most powerful and flexible open source data\nanalysis / manipulation tool available in any language**. It is already well on\nits way toward this goal.\n\nMain Features\n-------------\nHere are just a few of the things that pandas does well:\n\n  - Easy handling of missing data in floating point as well as non-floating\n    point data.\n  - Size mutability: columns can be inserted and deleted from DataFrame and\n    higher dimensional objects\n  - Automatic and explicit data alignment: objects can be explicitly aligned\n    to a set of labels, or the user can simply ignore the labels and let\n    `Series`, `DataFrame`, etc. 
automatically align the data for you in\n    computations.\n  - Powerful, flexible group by functionality to perform split-apply-combine\n    operations on data sets, for both aggregating and transforming data.\n  - Make it easy to convert ragged, differently-indexed data in other Python\n    and NumPy data structures into DataFrame objects.\n  - Intelligent label-based slicing, fancy indexing, and subsetting of large\n    data sets.\n  - Intuitive merging and joining data sets.\n  - Flexible reshaping and pivoting of data sets.\n  - Hierarchical labeling of axes (possible to have multiple labels per tick).\n  - Robust IO tools for loading data from flat files (CSV and delimited),\n    Excel files, databases, and saving/loading data from the ultrafast HDF5\n    format.\n  - Time series-specific functionality: date range generation and frequency\n    conversion, moving window statistics, date shifting and lagging.\n')</code></pre>
</div>
</div>


</section>

]]></description>
  <category>Software Engineering</category>
  <guid>https://imaddabbura.github.io/posts/python/Modules-And-Packages.html</guid>
  <pubDate>Fri, 09 Feb 2024 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/python/images/modules-packages-image.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Automatic Differentiation Demystified</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/mlsys/automatic-differentiation.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="the-derivative-engine-behind-every-loss.backward" class="level2">
<h2 class="anchored" data-anchor-id="the-derivative-engine-behind-every-loss.backward">The Derivative Engine Behind Every <code>loss.backward()</code></h2>
<p>Every time you train a neural network, something computes exact derivatives through millions of operations automatically. You call <code>loss.backward()</code> and gradients appear — but <em>how</em>? And why does training a 7B-parameter LLM consume 5x more GPU memory than running inference on it?</p>
<p>The answer to both questions is <strong>Automatic Differentiation (AD)</strong>: a family of techniques for computing exact derivatives through arbitrary code, efficiently. Understanding it changes how you reason about memory budgets, gradient flow failures, and why certain training tricks (gradient checkpointing, mixed precision) exist at all.</p>
<p>There are two fundamentally different approaches — <strong>forward mode</strong> and <strong>reverse mode</strong> — and the choice between them explains why deep learning frameworks are built the way they are.</p>
</section>
<section id="why-not-just-use-calculus-or-finite-differences" class="level2">
<h2 class="anchored" data-anchor-id="why-not-just-use-calculus-or-finite-differences">Why Not Just Use Calculus or Finite Differences?</h2>
<p>Before getting to AD, it helps to understand what it replaced.</p>
<p><strong>Numerical differentiation</strong> approximates the derivative using finite differences: <img src="https://latex.codecogs.com/png.latex?f'(x)%20%5Capprox%20%5Cfrac%7Bf(x+h)%20-%20f(x)%7D%7Bh%7D"> for some small <img src="https://latex.codecogs.com/png.latex?h">. It’s dead simple but has two fatal flaws: it requires one extra forward pass <em>per parameter</em> (catastrophic for millions of parameters), and floating-point subtraction of nearly-equal numbers amplifies numerical error badly.</p>
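<p>A quick sketch of the finite-difference approach (my own illustration, with a hypothetical <code>f</code> — not code from this post). Note that each probed input costs one extra evaluation of <code>f</code>, and the result is close to, but never exactly, the true derivative:</p>

```python
def f(x):
    return x * x  # true derivative: 2x

def finite_diff(f, x, h=1e-7):
    """Forward-difference approximation of f'(x) — one extra f-evaluation per input."""
    return (f(x + h) - f(x)) / h

approx = finite_diff(f, 3.0)  # close to the true value 6.0, but only approximate
```

For a model with millions of parameters, this one-evaluation-per-parameter cost is exactly the "one extra forward pass per parameter" problem described above.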
<p><strong>Symbolic differentiation</strong> (what a computer algebra system does) applies calculus rules to produce a closed-form derivative expression. It’s exact, but the resulting expressions grow exponentially with computation depth — a 100-layer network would produce a gradient expression no machine could reasonably evaluate.</p>
<p>AD is neither. It applies the chain rule mechanically at each elementary operation, accumulating intermediate values rather than symbolic expressions. The result is exact (to floating-point precision) and efficient — no expression explosion, no extra passes per parameter.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Three Ways to Differentiate Code
</div>
</div>
<div class="callout-body-container callout-body">
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Method</th>
<th>Accuracy</th>
<th>Cost</th>
<th>Practical for ML?</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Numerical (finite diff)</td>
<td>Approximate</td>
<td>1 extra pass per input</td>
<td>❌ Too slow</td>
</tr>
<tr class="even">
<td>Symbolic</td>
<td>Exact</td>
<td>Expression explosion</td>
<td>❌ Intractable</td>
</tr>
<tr class="odd">
<td>AD — forward mode</td>
<td>Exact</td>
<td>1 pass per input</td>
<td>⚠️ Only if few inputs</td>
</tr>
<tr class="even">
<td>AD — reverse mode</td>
<td>Exact</td>
<td>1 pass per output</td>
<td>✅ Standard choice</td>
</tr>
</tbody>
</table>
</div>
</div>
</section>
<section id="forward-mode-ad-sensitivity-flowing-downstream" class="level2">
<h2 class="anchored" data-anchor-id="forward-mode-ad-sensitivity-flowing-downstream">Forward Mode AD: Sensitivity Flowing Downstream</h2>
<p>Forward mode AD propagates <strong>derivatives alongside values</strong> as computation flows from inputs to outputs. At each operation, it tracks not just the result but how sensitive that result is to a chosen input.</p>
<p>The elegant implementation uses <strong>dual numbers</strong>: instead of a scalar <img src="https://latex.codecogs.com/png.latex?x">, carry a pair <img src="https://latex.codecogs.com/png.latex?(x,%5C%20%5Cdot%7Bx%7D)"> where <img src="https://latex.codecogs.com/png.latex?%5Cdot%7Bx%7D"> represents the derivative of <img src="https://latex.codecogs.com/png.latex?x"> with respect to some chosen input <img src="https://latex.codecogs.com/png.latex?x_i">. Operations on dual numbers automatically propagate the derivative via the chain rule — you never write it explicitly:</p>
<p><img src="https://latex.codecogs.com/png.latex?f(a%20+%20b%5Cvarepsilon)%20%5Capprox%20f(a)%20+%20f'(a)%5Ccdot%20b%5Cvarepsilon%20%5Cqquad%20(%5Cvarepsilon%5E2%20=%200)"></p>
<p>The <img src="https://latex.codecogs.com/png.latex?%5Cvarepsilon"> coefficient carries the derivative forward through every arithmetic operation.</p>
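<p>A minimal dual-number sketch (my own illustration, not this post's code) that reproduces the computation L = x₁·x₂ + x₃ with the seed on x₁ — operator overloading applies the chain rule without it ever being written down:</p>

```python
class Dual:
    """A (value, derivative) pair; the rule ε² = 0 is baked into __mul__."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        # sum rule: derivatives add
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        # product rule: (a + ȧε)(b + ḃε) = ab + (aḃ + ȧb)ε
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)

# Seed selects x₁: its derivative component is 1, all others 0
x1, x2, x3 = Dual(2.0, 1.0), Dual(3.0), Dual(5.0)
L = x1 * x2 + x3
# L.val == 11.0 and L.dot == 3.0, i.e. ∂L/∂x₁ = x₂
```

Getting ∂L/∂x₂ as well would require a second pass with the seed moved to x₂ — the limitation discussed next.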
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    x1["x₁\n(x₁, ẋ₁=1)"] --&gt; mul["×"]
    x2["x₂\n(x₂, ẋ₂=0)"] --&gt; mul
    mul --&gt;|"(x₁x₂, x₂·1)"| add["+"]
    x3["x₃\n(x₃, ẋ₃=0)"] --&gt; add
    add --&gt;|"(x₁x₂+x₃, x₂)"| L["L\n∂L/∂x₁ = x₂"]
</pre>
</div>
<p></p><figcaption> Forward mode propagates (value, derivative) pairs from inputs to output. The derivative component tracks sensitivity w.r.t. one chosen input. Here, the seed is set for x₁, so x₂’s dot is 0.</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>The critical limitation: the initial <strong>seed vector</strong> — the <img src="https://latex.codecogs.com/png.latex?(0,%5Cldots,1,%5Cldots,0)"> that selects which input you’re differentiating with respect to — means one forward pass gives you the sensitivity with respect to <em>one</em> input. Getting gradients for all <img src="https://latex.codecogs.com/png.latex?n"> inputs requires <img src="https://latex.codecogs.com/png.latex?n"> passes.</p>
<p>For a 7B-parameter LLM, that’s 7 billion passes to compute a single gradient update. Forward mode is not the answer for ML.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
When Forward Mode Wins
</div>
</div>
<div class="callout-body-container callout-body">
<p>Forward mode is efficient when <strong>outputs greatly outnumber inputs</strong> — the opposite of ML. It shines in scientific computing: a simulation with 3 input parameters and 10,000 output metrics needs only 3 forward passes, not 10,000. In ML the ratio is reversed: millions of inputs (parameters), one output (scalar loss). Reverse mode exists to handle exactly this case.</p>
</div>
</div>
</section>
<section id="reverse-mode-ad-tracing-blame-upstream" class="level2">
<h2 class="anchored" data-anchor-id="reverse-mode-ad-tracing-blame-upstream">Reverse Mode AD: Tracing Blame Upstream</h2>
<p>Reverse mode flips the direction. Instead of asking “how does changing this input affect the output?”, it asks “how much did each node contribute to this output?”</p>
<p>The key insight: for a scalar output (a loss function), <strong>one backward pass distributes gradient credit back to every node in the graph simultaneously</strong>. One pass. All gradients.</p>
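<p>The mechanism can be sketched in a few lines of Python (a micrograd-style toy of my own, not any framework's actual implementation). Each operation records its parents and the local derivative toward each; <code>backward()</code> then distributes credit from the output back through the graph:</p>

```python
class Var:
    """Scalar with toy reverse-mode autodiff."""
    def __init__(self, val, parents=()):
        # parents: (parent_node, local_derivative) pairs recorded at build time
        self.val, self.grad, self._parents = val, 0.0, parents

    def __mul__(self, other):
        # d(a*b)/da = b,  d(a*b)/db = a  -- note the backward pass needs the
        # forward values (other.val, self.val): this is the stored-activation cost
        return Var(self.val * other.val,
                   parents=((self, other.val), (other, self.val)))

    def __add__(self, other):
        return Var(self.val + other.val, parents=((self, 1.0), (other, 1.0)))

    def backward(self):
        # Simple traversal; correct for this tree-shaped graph. Real frameworks
        # topologically sort so shared nodes accumulate all credit first.
        self.grad = 1.0
        stack = [self]
        while stack:
            node = stack.pop()
            for parent, local_grad in node._parents:
                parent.grad += local_grad * node.grad
                stack.append(parent)

w, x, b = Var(3.0), Var(2.0), Var(1.0)
L = w * x + b
L.backward()
# one backward pass fills w.grad, x.grad, and b.grad simultaneously
```

<p>This mirrors the figure below: <code>L = w·x + b</code>, and a single call to <code>backward()</code> yields ∂L/∂w = x, ∂L/∂x = w, ∂L/∂b = 1 at once.</p>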
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    subgraph fwd ["① Forward Pass — compute and store"]
        direction LR
        x["x"] --&gt; mul["mul"] --&gt; add["add"] --&gt; L["L (scalar)"]
        w["w"] --&gt; mul
        b["b"] --&gt; add
    end
    subgraph bwd ["② Backward Pass — propagate gradients"]
        direction RL
        dL["∂L/∂L = 1"] --&gt; dadd["∂L/∂add"] --&gt; dmul["∂L/∂mul"]
        dmul --&gt; dx["∂L/∂x"]
        dmul --&gt; dw["∂L/∂w"]
        dadd --&gt; db["∂L/∂b"]
    end
    fwd --&gt; bwd
</pre>
</div>
<p></p><figcaption> Reverse mode runs two phases: a forward pass that computes and stores all intermediate values, then a backward pass that propagates ∂L/∂· back to every node.</figcaption> </figure><p></p>
</div>
</div>
</div>
<section id="the-unavoidable-memory-cost" class="level3">
<h3 class="anchored" data-anchor-id="the-unavoidable-memory-cost">The Unavoidable Memory Cost</h3>
<p>Here’s the catch. To compute gradients during the backward pass, each operation needs its <strong>inputs from the forward pass</strong>. For a <code>mul</code> node computing <img src="https://latex.codecogs.com/png.latex?z%20=%20w%20%5Ccdot%20x">, the backward step needs both <img src="https://latex.codecogs.com/png.latex?w"> and <img src="https://latex.codecogs.com/png.latex?x"> to distribute credit:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20w%7D%20=%20x%20%5Ccdot%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z%7D,%20%5Cqquad%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20x%7D%20=%20w%20%5Ccdot%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z%7D"></p>
<p>So the framework must <strong>keep every intermediate tensor alive</strong> until the backward pass consumes it. The consequence:</p>
<ul>
<li><strong>Inference</strong>: each layer’s activations can be discarded once the next layer is computed → memory is roughly <img src="https://latex.codecogs.com/png.latex?O(1)"> in depth</li>
<li><strong>Training</strong>: all activations must survive until their gradient is computed → memory is <img src="https://latex.codecogs.com/png.latex?O(N)"> in depth</li>
</ul>
<p>This is why training a transformer consumes so much more memory than running inference on it. At large batch sizes, forward activations alone can dwarf the parameter memory.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why Your GPU OOMs During Training But Not Inference
</div>
</div>
<div class="callout-body-container callout-body">
<p>During inference, each layer’s output overwrites the previous buffer — memory stays roughly constant regardless of model depth. During training, every layer’s output must survive until the backward pass reaches it. A 24-layer transformer holds 24 layers of activations simultaneously. Scale batch size by 4x and activation memory scales 4x too — parameters don’t budge, activations do. This is the first thing to check when you hit an OOM that doesn’t happen at inference time.</p>
</div>
</div>
</section>
<section id="gradient-checkpointing-buying-memory-back-with-compute" class="level3">
<h3 class="anchored" data-anchor-id="gradient-checkpointing-buying-memory-back-with-compute">Gradient Checkpointing: Buying Memory Back with Compute</h3>
<p>The standard solution to activation memory pressure is <strong>gradient checkpointing</strong> (also called activation recomputation): don’t store all activations during the forward pass. Store only at segment boundaries — <strong>checkpoints</strong> — and recompute intermediate activations on-the-fly during the backward pass when they’re needed.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    subgraph s1 ["Segment 1"]
        L1["Layer 1"] --&gt; L2["Layer 2"] --&gt; L3["Layer 3"]
    end
    subgraph s2 ["Segment 2"]
        L4["Layer 4"] --&gt; L5["Layer 5"] --&gt; L6["Layer 6"]
    end
    s1 --&gt;|"✓ checkpoint"| s2
    style L1 fill:#e8f5e9
    style L3 fill:#e8f5e9
    style L4 fill:#e8f5e9
    style L6 fill:#e8f5e9
</pre>
</div>
<p></p><figcaption> Checkpointing stores activations only at segment boundaries (green). During backward, each segment re-runs its forward pass to recover the discarded intermediates.</figcaption> </figure><p></p>
</div>
</div>
</div>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Strategy</th>
<th>Activation memory</th>
<th>Compute overhead</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>No checkpointing</td>
<td><img src="https://latex.codecogs.com/png.latex?O(N)"> layers</td>
<td>None</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BN%7D"> checkpoints</td>
<td><img src="https://latex.codecogs.com/png.latex?O(%5Csqrt%7BN%7D)"> layers</td>
<td>~1 extra forward pass</td>
</tr>
<tr class="odd">
<td>Recompute everything</td>
<td><img src="https://latex.codecogs.com/png.latex?O(1)"></td>
<td>Up to <img src="https://latex.codecogs.com/png.latex?N"> extra forward passes</td>
</tr>
</tbody>
</table>
<p>The sweet spot for most LLM training is <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BN%7D"> checkpoints — roughly one extra forward pass in exchange for a meaningful memory reduction. This is what <code>torch.utils.checkpoint.checkpoint_sequential</code> implements.</p>
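<p>The trade-off can be made concrete with a pure-Python sketch (my own illustration of the recomputation idea, not PyTorch's implementation) that counts layer evaluations and stored activations:</p>

```python
def forward_backward_with_checkpointing(layers, x, segment_size, counter):
    """Sketch of activation recomputation.

    Forward: apply every layer but keep only segment-boundary activations.
    Backward: re-run each segment's forward to recover the discarded
    intermediates that the gradient computation would need.
    """
    def apply(layer, v):
        counter["layer_evals"] += 1
        return layer(v)

    # Forward pass: store checkpoints only
    checkpoints, v = [x], x
    for i, layer in enumerate(layers):
        v = apply(layer, v)
        if (i + 1) % segment_size == 0:
            checkpoints.append(v)

    # Backward pass: recompute each segment's intermediates on demand
    for s in reversed(range(len(checkpoints) - 1)):
        u = checkpoints[s]
        for layer in layers[s * segment_size:(s + 1) * segment_size]:
            u = apply(layer, u)  # recomputation; gradient math would go here
    return v, len(checkpoints)

layers = [lambda t: t + 1] * 9          # 9 dummy "layers"
counter = {"layer_evals": 0}
out, stored = forward_backward_with_checkpointing(layers, 0, 3, counter)
# 9 layers, 3 checkpoints: 18 evaluations total (one extra forward pass)
# while holding only 4 activations alive instead of 9
```

With 9 layers split into √9 = 3 segments, total compute is 2N layer evaluations — the "roughly one extra forward pass" in the table — while stored activations drop from N to about √N.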
</section>
</section>
<section id="the-trade-off-stated-clearly" class="level2">
<h2 class="anchored" data-anchor-id="the-trade-off-stated-clearly">The Trade-off, Stated Clearly</h2>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Forward Mode</th>
<th>Reverse Mode</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Passes needed</strong></td>
<td>1 per input variable</td>
<td>1 per output variable</td>
</tr>
<tr class="even">
<td><strong>Best for</strong></td>
<td>Few inputs, many outputs</td>
<td>Many inputs, few outputs (ML)</td>
</tr>
<tr class="odd">
<td><strong>Memory overhead</strong></td>
<td>Low — no stored intermediates</td>
<td>High — all intermediates stored</td>
</tr>
<tr class="even">
<td><strong>What frameworks use</strong></td>
<td>Occasionally for Jacobians</td>
<td>Always for gradient-based training</td>
</tr>
</tbody>
</table>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Jacobian Perspective
</div>
</div>
<div class="callout-body-container callout-body">
<p>Forward mode naturally computes a <strong>Jacobian-vector product (JVP)</strong> — the full Jacobian multiplied by a chosen input direction. Reverse mode naturally computes a <strong>vector-Jacobian product (VJP)</strong> — a chosen output direction multiplied by the full Jacobian. For a scalar loss, the VJP with direction <img src="https://latex.codecogs.com/png.latex?%5B1%5D"> gives you the complete gradient vector in one pass. This is the mathematical reason reverse mode dominates ML training.</p>
</div>
</div>
</section>
<section id="what-breaks-in-practice" class="level2">
<h2 class="anchored" data-anchor-id="what-breaks-in-practice">What Breaks in Practice</h2>
<p><strong>Gradient flow failures.</strong> In reverse mode, gradients are products of local Jacobians chained across all layers. If any factor is consistently small (saturating activations, poor initialization) or large (unbounded weights), the gradient signal degrades before reaching early layers. This is the vanishing/exploding gradient problem — it’s not specific to RNNs, it’s a structural property of deep reverse-mode computation.</p>
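<p>The product-of-Jacobians structure makes the failure mode easy to see numerically. The sigmoid's derivative is at most 0.25, so (as a back-of-envelope illustration, not code from this post) a chain of 100 saturating sigmoid layers scales the upstream gradient by at most 0.25¹⁰⁰:</p>

```python
# Upper bound on gradient magnitude through 100 sigmoid layers:
# sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) <= 0.25
scale = 1.0
for _ in range(100):
    scale *= 0.25
# scale is on the order of 1e-61: the signal vanishes long before layer 1
```

The same chaining with factors above 1 produces the exploding variant — which is why initialization schemes and activation choices target per-layer Jacobians with norms near 1.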
<p><strong>Silent NaN propagation.</strong> A NaN anywhere in the forward pass propagates silently through the computation graph. During backward, every gradient flowing through the affected node becomes NaN, and the weight update corrupts the entire model. Use <code>torch.autograd.set_detect_anomaly(True)</code> to get a traceback pointing to the originating operation — invaluable for tracking these down.</p>
<p><strong>In-place operations on tensors with gradients.</strong> In-place ops (e.g., <code>x += 1</code>) can modify a tensor that the backward pass expects to find unchanged. PyTorch raises a runtime error when it detects this, but the error message can be confusing. The fix is simple: avoid in-place ops on any tensor that requires gradients, or clone before modifying.</p>
</section>
<section id="key-takeaways" class="level2">
<h2 class="anchored" data-anchor-id="key-takeaways">Key Takeaways</h2>
<ol type="1">
<li><p><strong>AD is not numerical or symbolic differentiation.</strong> It applies the chain rule exactly at each elementary operation — no approximation, no expression explosion.</p></li>
<li><p><strong>Forward mode needs one pass per input; reverse mode needs one pass per output.</strong> For ML — scalar loss, millions of parameters — reverse mode wins unconditionally.</p></li>
<li><p><strong>The cost of reverse mode is memory.</strong> Every intermediate tensor from the forward pass must stay alive for the backward pass. This is the root cause of training using far more memory than inference.</p></li>
<li><p><strong>Gradient checkpointing trades compute for memory.</strong> Store only at segment boundaries, recompute the rest during backward. Expect roughly one extra forward pass overhead for a meaningful memory reduction.</p></li>
<li><p><strong>Most gradient problems are reverse-mode problems.</strong> Vanishing/exploding gradients, NaN propagation, and in-place op errors all stem from how reverse-mode AD chains local Jacobians through the computation graph. Understanding the mechanism is the fastest path to diagnosing them.</p></li>
</ol>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>ML Systems</category>
  <guid>https://imaddabbura.github.io/posts/mlsys/automatic-differentiation.html</guid>
  <pubDate>Sat, 03 Feb 2024 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/mlsys/images/automatic-differentiation-image.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Git from the Inside Out</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/swe/Advanced-Git.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p><strong>Git</strong> is a distributed version control system that thinks of and stores its data as a series of snapshots (not deltas). Each commit is a snapshot of the state of the project at the time of the commit. For files that haven’t changed, <em>Git</em> doesn’t store the file again but keeps a pointer to the identical file it already stored. It also lets us perform almost all operations locally.</p>
<p>Everything in <em>Git</em> is checksummed with the <code>SHA-1</code> hash before it is stored in its object store. A <code>SHA-1</code> hash is 40 hexadecimal characters. All objects are referred to by their checksums because <em>Git</em> is a content-addressable filesystem. This means Git notices any change to the files it tracks by comparing the checksum of the stored version with that of the current version.</p>
<p>All actions in <em>Git</em> only add data to the object store (the Git database). Therefore, it is almost impossible to permanently lose committed work, especially if we regularly push our Git database to another repository such as <em>GitHub</em>.</p>
<p>Git has three states:</p>
<ul>
<li><strong>Modified</strong>: the file has changed but is not yet committed.</li>
<li><strong>Staged</strong>: the changed file is marked to go into the next commit snapshot. The staging area is a single file, typically called the “index”, which stores information about what will go into the next commit snapshot. When we run <code>git add file</code>, Git does the following:
<ul>
<li>Computes the <code>SHA-1</code> checksum of the file’s contents</li>
<li>Compresses the contents of the file and stores it in the <code>.git</code> directory under <code>objects</code>, where the first two characters of the checksum name the subdirectory and the remaining 38 characters name the file</li>
<li>Adds the checksum to the index file (staging area)</li>
</ul></li>
<li><strong>Committed</strong>: the data (snapshot) is stored in the database. The snapshot is represented as a tree for the root directory of the Git project. When we run <code>git commit</code>, Git does the following:
<ul>
<li>Computes the checksum of each subdirectory, working up to the root directory.</li>
<li>Stores them as tree objects in the Git repository</li>
<li>Finally, Git creates a commit object and stores it in the Git repository with the following metadata:
<ul>
<li>Date</li>
<li>Author name</li>
<li>Committer name</li>
<li>Commit message</li>
<li>Parent commit(s). The first commit has no parent; subsequent commits have one parent, or multiple parents in the case of merges</li>
<li>Pointer to the root project tree</li>
</ul></li>
</ul></li>
</ul>
<p>The <code>.git</code> directory, located at the root of the project, holds all the metadata for the Git project, including the database (object store).</p>
<p>Files can be in two states:</p>
<ul>
<li><strong>Untracked</strong>: files that Git doesn’t know about. They are neither in any snapshot nor in the staging area, so they don’t have modified/unmodified states.</li>
<li><strong>Tracked</strong>: files that were in the last snapshot or are in the staging area. They can be in any of the states mentioned above.</li>
</ul>
</section>
<section id="git-object-model" class="level2">
<h2 class="anchored" data-anchor-id="git-object-model">Git Object Model</h2>
<ul>
<li>Git stores everything in the <strong>.git</strong> directory, so deleting this directory deletes the entire history, which cannot be recovered.</li>
<li>Git stores all of its representations in the <strong>objects</strong> directory. An object can be a blob, a tree, or a commit.</li>
<li>Git uses <strong>sha1sum</strong> to compute the hash value of each object, which is 40 hexadecimal characters (160 bits).
<ul>
<li>Git uses the first two characters of the hash as the directory name and the remaining 38 characters as the file name of the object.</li>
<li>Git stores objects based on their hash values (content-addressable storage).</li>
<li>Git compresses the contents using zlib.</li>
</ul>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb1-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">// a file is a bunch of bytes</span></span>
<span id="cb1-2">type object <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> blob <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> commit</span>
<span id="cb1-3">objects <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> map<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>sha1sum<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span>object<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">),</span> object<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span></span></code></pre></div>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> store(obj):</span>
<span id="cb2-2">  <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sha1sum(obj)</span>
<span id="cb2-3">  objects[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> obj</span>
<span id="cb2-4">  <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span></span>
<span id="cb2-5"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> a directory contains named files <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">and</span> directories</span>
<span id="cb2-6"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">type</span> tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">map</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>string, tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">file</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span></span>
<span id="cb2-7"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> load(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span>):</span>
<span id="cb2-8">  <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> objects[<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span>]</span></code></pre></div></li>
</ul>
<div id="cell-5" class="cell" data-execution_count="7">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>ls <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>al ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>total 48
drwxr-xr-x  15 imad  staff   480 Nov  5 09:08 .
drwxr-xr-x@ 12 imad  staff   384 Nov  5 06:56 ..
-rw-r--r--   1 imad  staff    15 Mar 17  2020 COMMIT_EDITMSG
-rw-r--r--   1 imad  staff    23 Feb 12  2020 HEAD
drwxr-xr-x   2 imad  staff    64 Feb 12  2020 branches
-rw-r--r--   1 imad  staff   455 Mar 17  2020 config
-rw-r--r--   1 imad  staff    73 Feb 12  2020 description
drwxr-xr-x  13 imad  staff   416 Feb 12  2020 hooks
-rw-r--r--   1 imad  staff  3913 Nov  5 09:08 index
drwxr-xr-x   3 imad  staff    96 Feb 12  2020 info
drwxr-xr-x   4 imad  staff   128 Feb 12  2020 logs
drwxr-xr-x   3 imad  staff    96 Mar 17  2020 modules
drwxr-xr-x  94 imad  staff  3008 Mar 17  2020 objects
-rw-r--r--   1 imad  staff   114 Feb 12  2020 packed-refs
drwxr-xr-x   5 imad  staff   160 Feb 12  2020 refs</code></pre>
</div>
</div>
<div id="cell-6" class="cell" data-execution_count="27">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>ls <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>a ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>objects</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>.    0c   1a   26   37   43   50   62   72   8b   9e   b1   c5   d1   e0   fc
..   0f   1d   29   38   45   52   63   73   8e   a2   b2   c7   d2   e1   ff
00   12   1f   2a   39   49   58   64   75   91   a5   b3   ca   d3   eb   info
03   13   20   2b   3f   4a   5a   67   77   96   a9   b4   cc   d8   ee   pack
07   17   21   2c   40   4b   5f   68   7f   98   aa   b6   ce   dc   f0
08   19   22   33   42   4d   61   6f   8a   9d   b0   b8   d0   dd   f8</code></pre>
</div>
</div>
<div id="cell-7" class="cell" data-execution_count="21">
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>ls <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>Ral ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>objects<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>ee</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>total 8
drwxr-xr-x   3 imad  staff    96 Mar  4  2020 .
drwxr-xr-x  94 imad  staff  3008 Mar 17  2020 ..
-r--r--r--   1 imad  staff   166 Mar  4  2020 5941ab3c125a3a669370d96cd5cb8496f8acde</code></pre>
</div>
</div>
<div id="cell-8" class="cell" data-execution_count="26">
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>git cat<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">file</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>p ee5941ab3c125a3a669370d96cd5cb8496f8acde</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>100644 blob b6e47617de110dea7ca47e087ff1347cc2646eda    .gitignore
100644 blob 261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64    LICENSE
100644 blob 58da9de606d62625c379fea5ca020d19d958fb18    README.md
040000 tree 6fea6c3802fd1cf83bf19bfc2302da6b79638ab5    missing-cs-semester</code></pre>
</div>
</div>
<section id="blobs" class="level3">
<h3 class="anchored" data-anchor-id="blobs">Blobs</h3>
<ul>
<li><code>blobs</code> are binary large objects that store only the contents of a file, not its name (an array of bytes).</li>
</ul>
<div class="sourceCode" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb11-1">type blob <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> array<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>byte<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span></span></code></pre></div>
<ul>
<li>The object type (“blob”), the length of the content, a separator character, and the actual content are passed to sha1sum to get the hash value.</li>
<li>Since Git does not store the name of the file or any of its metadata, if two files have the same content, Git stores that content only once.</li>
</ul>
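<p>The hashing scheme above can be sketched in a few lines of Python (a minimal illustration of the object format, not Git’s actual implementation):</p>

```python
import hashlib

def git_blob_hash(content: bytes) -> str:
    # Git hashes a blob as: sha1("blob <size-in-bytes>\0" + content)
    header = f"blob {len(content)}\0".encode()
    return hashlib.sha1(header + content).hexdigest()

# Matches what `git hash-object` computes for the same content:
print(git_blob_hash(b"test\n"))  # 9daeafb9864cf43055ae93beb0afd6c7d144bfa4
```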
<div id="cell-11" class="cell" data-execution_count="30">
<div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>git cat<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">file</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>p <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">58</span><span class="er" style="color: #AD0000;
background-color: null;
font-style: inherit;">da9de606d62625c379fea5ca020d19d958fb18</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code># software-engineering
Materials for software engineering.</code></pre>
</div>
</div>
<div id="cell-12" class="cell" data-execution_count="31">
<div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>wc ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>README.md</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>       2       6      59 ../../README.md</code></pre>
</div>
</div>
<div id="cell-13" class="cell" data-execution_count="68">
<div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb16-2">cat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>(echo <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>e <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"blob 60</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\0</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>) ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>README.md</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>blob 60
# software-engineering
Materials for software engineering.</code></pre>
</div>
</div>
<div id="cell-14" class="cell" data-execution_count="75">
<div class="sourceCode cell-code" id="cb18" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb18-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># # Since todo.md and todo2.md are identical, Git saves ONLY one copy</span></span>
<span id="cb18-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 100644 blob b3dfa8b0b7c73f2c7156dfc69c737d05f2f900c3    file.txt</span></span>
<span id="cb18-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 100644 blob c1ee9d5404109b66f21fa193da635aa8c4f04c47    todo.md</span></span>
<span id="cb18-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># 100644 blob c1ee9d5404109b66f21fa193da635aa8c4f04c47    todo2.md</span></span></code></pre></div>
</div>
<div id="cell-15" class="cell" data-execution_count="66">
<div class="sourceCode cell-code" id="cb19" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb19-2">cat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>(echo <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>e <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"blob 58</span><span class="ch" style="color: #20794D;
background-color: null;
font-style: inherit;">\0</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>) ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>README.md <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> shasum</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>5ee782528aa7bc3d388c33339962f6fce514b39e  -</code></pre>
</div>
</div>
</section>
<section id="trees" class="level3">
<h3 class="anchored" data-anchor-id="trees">Trees</h3>
<ul>
<li>A tree is a recursive data structure that holds a list of pointers to other trees and blobs; in this context, a tree is a directory. The root tree corresponds to the top-level directory, the one that has <code>.git</code> as a subdirectory. Each line in a tree object contains a pointer (the object’s hash) to one such object (tree or blob), along with its mode, object type, and a name for the file or directory.</li>
</ul>
<div class="sourceCode" id="cb21" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb21-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">// a directory contains named files and directories</span></span>
<span id="cb21-2">type tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> map<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>sha1sum<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> file<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">),</span> tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">|</span> file<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;;</span></span></code></pre></div>
<ul>
<li>It maps strings (hash values) to objects. Because an empty directory has no entries to map, Git does not track it, not even as an untracked change, until we add a file or a directory to it. To track an empty directory, the common workaround is to add a placeholder file such as <code>.gitkeep</code>.</li>
<li>A tree’s hash is computed over its list of entries (modes, names, and child hashes), not over the contents of the files themselves.</li>
<li>Tree objects themselves do not have names, much like blobs. Parent trees associate names with their subtrees, and the root tree of a repository in fact has no name at all. This has two fun characteristics:
<ul>
<li>The repo doesn’t care what you call it. You can rename your local directory that contains your repository to anything you’d like. Git is blissfully unaware of the name of the directory that contains the .git repo directory.</li>
<li>We can rename subtrees as much as we want, and only the parent tree needs to be updated. The subtree object itself and everything below it remain untouched.</li>
</ul></li>
<li>Trees summary:
<ul>
<li>Trees list out the contents of a directory (blobs and subtrees)</li>
<li>For each object, the mode, permissions, type, hash, and name are listed</li>
<li>Tree objects must contain at least one blob or tree; otherwise, it won’t be tracked</li>
<li>Trees can be nested to any depth</li>
<li>Trees, like blobs, don’t store names; names are stored in parent trees. Renaming a subtree therefore only changes an entry in its parent tree, and since the root tree has no parent, renaming the repository’s root directory has no effect on Git</li>
<li>Trees are named and stored in the objects directory by hashing their contents (the list of objects described above)</li>
</ul></li>
</ul>
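<p>Tree hashing follows the same pattern as blobs: serialize the entries, prepend a typed header, and hash. The sketch below illustrates the idea (entry ordering is simplified here; Git’s actual sort treats directory names specially):</p>

```python
import hashlib

def git_tree_hash(entries) -> str:
    # entries: list of (mode, name, sha1_hex) tuples, one per blob/subtree.
    # Serialize each entry as "<mode> <name>\0" followed by the raw 20-byte hash.
    body = b""
    for mode, name, sha_hex in sorted(entries, key=lambda e: e[1]):
        body += f"{mode} {name}\0".encode() + bytes.fromhex(sha_hex)
    # Hashed like every Git object: sha1("tree <size>\0" + body)
    return hashlib.sha1(f"tree {len(body)}\0".encode() + body).hexdigest()

# The empty tree hashes to Git's well-known constant:
print(git_tree_hash([]))  # 4b825dc642cb6eb9a060e54bf8d69288fbee4904
```

Note how the entries contain only the children’s hashes, which is why renaming a file changes the parent tree’s hash but not the blob’s.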
<div id="cell-18" class="cell" data-execution_count="78">
<div class="sourceCode cell-code" id="cb22" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb22-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># master is a branch that points to a commit which also points to a tree</span></span>
<span id="cb22-2"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>git ls<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>tree master</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>100644 blob ced9612a7d927cdb23d0ba2de47679504b0c9fc3    Command-Line-Environment.ipynb
100644 blob 68fe60e479d316b7f40f194a8d3400e7f5c8af60    Data-Wrangling.ipynb
100644 blob dcee12ed0ae6096339dc40a70c5ff67a2afdec31    Debugging-And-Profiling.ipynb
100644 blob 40582a6d5d23028acc251c54f6e124ce9f2ec5ba    Petpouri.ipynb
100644 blob 6405b8d2b26e17020b12f98639148615d6c9baea    Plan.ipynb
100644 blob f0aca55804924d43eb3d687ebc9a780f3b8baff3    Security-And-Cryptography.ipynb
100644 blob 035f52d11f6126cac575b899f6ff0011060aeddd    Shell-Scripting.ipynb
100644 blob a2235c5c9a32920e7f15c3bf63afde41e60c4e52    Version-Control(Git).ipynb
100644 blob 395d086c29d15560f0b3eee28c0489afe1b6de8e    Vim-Tutor-Summaries.ipynb
100644 blob 9d756b15f398735d1bd414fd97afa5f709db06f5    basic.png
100644 blob 50daf1bb695821f251fbc44880413ca17ca8a8b6    commit_history.png
100644 blob 43d8b8f1031c6c81f0153e0616be7576de6401f9    pycallgraph.png
100644 blob 5a0fa1a6cb3918e9c2d316433edfd5ccd275bd59    vim-tutorial.md</code></pre>
</div>
</div>
</section>
<section id="commits" class="level3">
<h3 class="anchored" data-anchor-id="commits">Commits</h3>
<ul>
<li><code>commits</code> contain pointers to parent commit(s), a message, the author, the committer, and the current tree. A commit is therefore a file like any other object.</li>
</ul>
<div class="sourceCode" id="cb24" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb24-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">// a commit has parents, metadata, and the top-level tree</span></span>
<span id="cb24-2">type commit <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">struct</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span></span>
<span id="cb24-3">    parent<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:</span> array<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>commit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;;</span></span>
<span id="cb24-4">    author<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:</span> string</span>
<span id="cb24-5">    message<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:</span> string</span>
<span id="cb24-6">    snapshot<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:</span> tree</span>
<span id="cb24-7"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span></span></code></pre></div>
<ul>
<li>It’s worth noting that the commit object only contains a single reference to the top-level tree; Git doesn’t store diffs. When diffing between two commits, it compares the trees of the two commits, computing the diff on demand.</li>
</ul>
<p>Git stores full snapshots, not deltas between commits. However, blobs and trees that have not changed since the previous commit are not stored again: an unchanged file hashes to the same value, so its address in the object store is the same and the new tree simply keeps the same pointer.</p>
<p><img src="https://imaddabbura.github.io/posts/swe/images/git-objects-simple.png" height="400px" width="300px" align="center"></p>
<p><strong>References</strong> are nothing but pointers to commits. They are stored as files under the <code>.git/refs</code> directory, where each file contains the hash value of some commit. Since it is a hassle to always refer to objects by their 40-character hexadecimal strings, we can use references instead. Contrary to objects, references are mutable. For example, <code>master</code> always refers to the latest commit on the main branch, and <code>HEAD</code> refers to where we currently are in the history: when creating a new snapshot, Git makes the commit <code>HEAD</code> points to the parent of the new commit, and then updates <code>HEAD</code>.</p>
<div class="sourceCode" id="cb25" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb25-1">references <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">map</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span>string, commit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;;</span></span>
<span id="cb25-2"></span>
<span id="cb25-3"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> update_reference(name, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span>):</span>
<span id="cb25-4">    references[name] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">id</span></span>
<span id="cb25-5"></span>
<span id="cb25-6"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> read_reference(name):</span>
<span id="cb25-7">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> references[name]</span>
<span id="cb25-8"></span>
<span id="cb25-9"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> load_reference(name_or_id):</span>
<span id="cb25-10">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> name_or_id <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> references:</span>
<span id="cb25-11">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> load(references[name_or_id])</span>
<span id="cb25-12">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span>:</span>
<span id="cb25-13">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> load(name_or_id)</span></code></pre></div>
<div id="cell-24" class="cell" data-execution_count="88">
<div class="sourceCode cell-code" id="cb26" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb26-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>L <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>../../.git
├── COMMIT_EDITMSG
├── HEAD
├── branches
├── config
├── description
├── hooks
├── index
├── info
├── logs
├── modules
├── objects
├── packed-refs
└── refs

7 directories, 6 files</code></pre>
</div>
</div>
<div id="cell-25" class="cell" data-execution_count="89">
<div class="sourceCode cell-code" id="cb28" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb28-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>L <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>../../.git/refs/
├── heads
├── remotes
└── tags

3 directories, 0 files</code></pre>
</div>
</div>
<div id="cell-26" class="cell" data-execution_count="90">
<div class="sourceCode cell-code" id="cb30" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb30-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>L <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>heads<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>../../.git/refs/heads/
└── master

0 directories, 1 file</code></pre>
</div>
</div>
<div id="cell-27" class="cell" data-execution_count="91">
<div class="sourceCode cell-code" id="cb32" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb32-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>cat ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>heads<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>master</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>e1d95ecb4c02e0c30a8635a37631c523d2041299</code></pre>
</div>
</div>
<div id="cell-28" class="cell" data-execution_count="93">
<div class="sourceCode cell-code" id="cb34" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb34-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>git cat<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">file</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>t e1d95ecb4c02e0c30a8635a37631c523d2041299</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>commit</code></pre>
</div>
</div>
<div id="cell-29" class="cell" data-execution_count="94">
<div class="sourceCode cell-code" id="cb36" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb36-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>git cat<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">file</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>p e1d95ecb4c02e0c30a8635a37631c523d2041299</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>tree 75407c234b245d258c809de234e030f57dd98148
parent 429137bbf1334dcea2719458bcc3a323cd829ecd
author Imad &lt;imad.dabbura@hotmail.com&gt; 1584462760 -0500
committer Imad &lt;imad.dabbura@hotmail.com&gt; 1584462760 -0500

Review all nbs</code></pre>
</div>
</div>
<p><code>HEAD</code>, unlike the other objects we’ve discussed, is a singleton, meaning that there is only ever one HEAD. It identifies the currently checked out object. Typically, this is a branch (with that branch pointing to a commit), but it is possible to check out a commit directly, in which case HEAD would be pointing at that commit.</p>
<p>HEAD is a file just like our branch objects. It lives at the root of the .git directory and its contents are similarly simple.</p>
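<p>Resolving <code>HEAD</code> can be sketched as a couple of file reads (a simplified illustration; real Git handles more cases, such as packed refs):</p>

```python
from pathlib import Path

def resolve_head(git_dir: str) -> str:
    # HEAD is either symbolic ("ref: refs/heads/master") or,
    # in detached-HEAD state, a raw commit hash.
    head = (Path(git_dir) / "HEAD").read_text().strip()
    if head.startswith("ref: "):
        # Follow the symbolic ref to the branch file holding the hash
        return (Path(git_dir) / head[len("ref: "):]).read_text().strip()
    return head  # detached HEAD: the file holds the commit hash itself
```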
<div id="cell-31" class="cell" data-execution_count="114">
<div class="sourceCode cell-code" id="cb38" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb38-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>cat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Users<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>imad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>HEAD</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>ref: refs/heads/master</code></pre>
</div>
</div>
<div id="cell-32" class="cell" data-execution_count="115">
<div class="sourceCode cell-code" id="cb40" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb40-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb40-2">cd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">~/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span>
<span id="cb40-3">git graph2</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>* 6e9c688    (HEAD -&gt; master, tag: v.0.1, test) Renaming (Imad)
* ff9e667    Add test1 dir (Imad)
* e503b6e    Add test dir (Imad)
* 8610c78    Add copied file (Imad)
* ed738c9    (feature) Rebased all commits (Imad)
* 2050b90    Add host to file (Imad)
* 91eacf5    Add host (Imad)
* ed27259    patch commit (Imad)
* ff2d260    third commit (Imad)
* 6dd0c14    Change second commit (Imad)
* c2b7166    first commit (Imad)</code></pre>
</div>
</div>
<div id="cell-33" class="cell" data-execution_count="116">
<div class="sourceCode cell-code" id="cb42" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb42-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb42-2">cd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">~/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span>
<span id="cb42-3">git checkout <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8610</span><span class="er" style="color: #AD0000;
background-color: null;
font-style: inherit;">c78</span></span></code></pre></div>
<div class="cell-output cell-output-stderr">
<pre><code>Note: checking out '8610c78'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by performing another checkout.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -b with the checkout command again. Example:

  git checkout -b &lt;new-branch-name&gt;

HEAD is now at 8610c78 Add copied file</code></pre>
</div>
</div>
<div id="cell-34" class="cell" data-execution_count="118">
<div class="sourceCode cell-code" id="cb44" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb44-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>cat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Users<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>imad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>HEAD</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>8610c78113fe423b20a9f84d485b49af5ad089b0</code></pre>
</div>
</div>
<div id="cell-35" class="cell" data-execution_count="117">
<div class="sourceCode cell-code" id="cb46" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb46-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb46-2">cd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">~/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span>
<span id="cb46-3">git graph2</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>* 6e9c688    (tag: v.0.1, test, master) Renaming (Imad)
* ff9e667    Add test1 dir (Imad)
* e503b6e    Add test dir (Imad)
* 8610c78    (HEAD) Add copied file (Imad)
* ed738c9    (feature) Rebased all commits (Imad)
* 2050b90    Add host to file (Imad)
* 91eacf5    Add host (Imad)
* ed27259    patch commit (Imad)
* ff2d260    third commit (Imad)
* 6dd0c14    Change second commit (Imad)
* c2b7166    first commit (Imad)</code></pre>
</div>
</div>
</section>
<section id="summary" class="level3">
<h3 class="anchored" data-anchor-id="summary">Summary</h3>
<ul>
<li><strong>Objects</strong>: blobs, trees, and commits</li>
<li><strong>Refs</strong>: branches, tags, and remote branches</li>
<li><strong>HEAD</strong>: The single pointer to rule them all</li>
</ul>
<p><img src="https://imaddabbura.github.io/posts/swe/images/git-objects.png" width="500"></p>
</section>
</section>
<section id="branches" class="level2">
<h2 class="anchored" data-anchor-id="branches">Branches</h2>
<ul>
<li><code>heads</code>, aka branches (the directory is a collection of the <code>HEAD</code>s of each branch in the repo), are nothing but pointers to commits. They are very simple objects: each contains only the hash value of the commit it points to. Creating a branch is therefore just creating a file in <code>refs/heads</code>, named after the branch, that contains the hash of the branch’s tip commit. Initially, this file holds the commit that <code>HEAD</code> pointed to on the branch you were on when you created the new branch.</li>
<li>When you switch branches, Git resets your working directory to look like it did the last time you committed on that branch. It adds, removes, and modifies files automatically to make sure your working copy is what the branch looked like on your last commit.</li>
<li>Merging:
<ul>
<li>If we are merging a feature branch into master and the feature branch is directly ahead of master, that is, master’s tip is reachable by following the feature branch’s commit history, Git does a fast-forward merge: it simply moves the master pointer forward.</li>
<li>Otherwise, if the tip of master isn’t a direct ancestor of the feature branch, Git does a three-way merge using three commits:
<ul>
<li>Common ancestor commit</li>
<li>The last commits of the master and feature branches</li>
<li>Git then creates a new snapshot with a new commit object (the <strong>merge commit</strong>) that points to two parents: the last commit of master and the last commit of the feature branch</li>
</ul></li>
<li>If we have a merge conflict, we can either abort the merge or resolve the conflict ourselves. Once we resolve the conflicts in all files, we stage those files and then commit the changes; this becomes the merge commit. We can also use a mergetool such as <code>vimdiff</code> to resolve conflicts.</li>
</ul></li>
<li><code>git branch -v</code> shows the last commit of every branch</li>
<li><code>git branch --merged</code> shows all branches that have been merged into the current branch; <code>git branch --merged master</code> shows all branches that have been merged into master.</li>
<li><code>git branch --no-merged</code> does the opposite.</li>
<li>We can’t delete a branch that has work not yet merged into master. We can force-delete it with the <code>-D</code> flag.</li>
<li>We can rename a branch, but we should do it both locally and on the remote server. It is recommended to avoid renaming the master branch because it would break integrations, scripts, etc., and requires a lot more work.
<ul>
<li>Locally:
<ul>
<li><code>git branch --move oldname newname</code></li>
</ul></li>
<li>Remote:
<ul>
<li><code>git push --set-upstream origin newname</code></li>
<li><code>git push origin -d oldname</code></li>
</ul></li>
</ul></li>
<li><code>heads</code> are for local branches.</li>
</ul>
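<p>The two merge behaviors above can be reproduced in a throwaway repo. A minimal sketch (paths, names, and messages are illustrative; assumes a reasonably recent Git that supports <code>git init -b</code>):</p>

```shell
#!/bin/sh
# Fast-forward vs. three-way merge, demonstrated in a scratch repo.
set -e
rm -rf /tmp/merge-demo && mkdir -p /tmp/merge-demo && cd /tmp/merge-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo base > file.txt && git add file.txt && git commit -qm "base"

# Case 1: feature is directly ahead of master -> fast-forward.
git checkout -qb feature
echo feature >> file.txt && git commit -qam "feature work"
git checkout -q master
git merge --ff-only feature            # just moves the master pointer forward

# Case 2: branches have diverged -> three-way merge with a merge commit.
git checkout -qb topic
echo topic > topic.txt && git add topic.txt && git commit -qm "topic work"
git checkout -q master
echo more > other.txt && git add other.txt && git commit -qm "master moves on"
git merge -q --no-edit topic           # creates a new commit with two parents
git rev-list --count --merges HEAD     # prints 1: only the three-way merge made one
```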
</section>
<section id="remote-branches" class="level2">
<h2 class="anchored" data-anchor-id="remote-branches">Remote Branches</h2>
<p><strong>Remote Branches</strong> are the same as local branches. They are again files that point to commits.</p>
<ul>
<li>We can have multiple remotes, each with its own branches. <code>origin</code> is the default name, typically used for the upstream (we can choose another name when cloning a repo, e.g. <code>git clone URL -o anothername</code>). <code>git remote -v</code> lists all the remotes for the repository. We can add a remote with <code>git remote add remote_name remote_url</code>.</li>
<li>All the remote branches under <code>remotes/origin/</code> will be updated <strong>ONLY</strong> when communicating with the remote server. Such branches act more as bookmarks and can’t be changed by any Git commands to point to different commits directly.</li>
<li>Local branch is called <strong>Tracking Branch</strong> if it tracks a remote branch (called <strong>Upstream Branch</strong>)
<ul>
<li><code>git checkout branchname</code> creates a tracking branch for <code>remotename/branchname</code> if <code>branchname</code> doesn’t exist locally and exactly matches the name of one upstream branch.</li>
<li>We can have local branches track branches from different remotes: <code>git checkout --track remotename/remotebranch</code> creates a local branch named remotebranch that tracks remotebranch on the remotename server. We can pick a different local name with <code>git checkout -b localbranchname remotename/remotebranch</code>.</li>
<li>If we already have a local branch, we can use <code>git branch --set-upstream-to=remotename/remotebranch</code> to make current branch track remotebranch on remotename server</li>
<li>If I am on a tracking branch and run <code>git pull</code>, it knows which server to fetch from and which branch to merge in</li>
</ul></li>
<li><code>git fetch</code> downloads the changes on all branches from the remote to the local repository without merging them. We do the merge ourselves, e.g. <code>git merge origin/branchname</code>.</li>
<li><code>git pull</code> downloads and merges the changes from the remote branch into the local branch.</li>
<li><code>git remote show remote_name</code> will show everything in details about the <code>remote_name</code> such as URL, local/remote branches, etc.</li>
<li>We can rename/delete remotes with <code>git remote rename/remove remote_name</code>. Deleting a remote deletes all configuration and settings related to it; renaming one also renames its remote-tracking branches.</li>
<li>Remote references are read-only, which means we will never update them using <code>git commit</code> but <em>Git</em> manages them as bookmarks.</li>
<li>By default, <em>Git</em> fetches all branch references (heads) from the remote into <code>refs/remotes</code>. We can override this on the command line with an explicit refspec: <code>git fetch remote_name remote_branch:refs/remotes/remote_name/branch_name</code></li>
<li>Pushing local branch to remote can be done in different forms:
<ul>
<li><code>git push origin branchname</code></li>
<li><code>git push origin localbranchname:remotebranchname</code> which lets us have a different name on the remote server for our local branch</li>
</ul></li>
<li>We can delete remote branch <code>git push origin --delete branchname</code></li>
</ul>
<div id="cell-42" class="cell" data-execution_count="108">
<div class="sourceCode cell-code" id="cb48" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb48-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>C ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>remotes<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>../../.git/refs/remotes/
└── origin
    ├── HEAD
    └── master

1 directory, 2 files</code></pre>
</div>
</div>
<div id="cell-43" class="cell" data-execution_count="109">
<div class="sourceCode cell-code" id="cb50" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb50-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>cat ..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>..<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>remotes<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>origin<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/*</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>ref: refs/remotes/origin/master
e1d95ecb4c02e0c30a8635a37631c523d2041299</code></pre>
</div>
</div>
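<p>Remote-tracking refs and tracking branches can be seen end to end with a local bare repository standing in for the server. A minimal sketch (all paths and names are illustrative):</p>

```shell
#!/bin/sh
# A local bare repo plays the "server"; cloning it sets up a tracking branch.
set -e
rm -rf /tmp/remote-demo && mkdir -p /tmp/remote-demo && cd /tmp/remote-demo
git init -q -b master upstream-src && cd upstream-src
git config user.email demo@example.com && git config user.name Demo
echo hello > readme.txt && git add readme.txt && git commit -qm "initial"
cd ..
git clone -q --bare upstream-src upstream.git   # the "remote server"
git clone -q upstream.git clone && cd clone

git remote -v                                   # origin -> /tmp/remote-demo/upstream.git
git branch -vv                                  # master tracks [origin/master]
git rev-parse --abbrev-ref "master@{upstream}"  # prints: origin/master
```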
</section>
<section id="tags" class="level2">
<h2 class="anchored" data-anchor-id="tags">Tags</h2>
<ul>
<li><strong>Tags</strong> are like branches: they too point to a commit and are stored under the <code>.git/refs</code> directory, in the <code>tags</code> subdirectory. A lightweight tag is just a file containing the hash of the commit it points to.</li>
<li>Tags can be created simply with <code>git tag version_no</code>. We can also create more complex tags with an annotation, a PGP signature, and other metadata. In that case, a tag <em>object</em> is stored in the object database (<code>.git/objects</code>), and the file under <code>refs/tags</code> contains the hash of that tag object (which in turn contains the hash of the tagged commit).
<ul>
<li>Annotated tags, however, are stored as full objects in the Git database. They’re checksummed; contain the tagger name, email, and date; have a tagging message; and can be signed and verified with GNU Privacy Guard (GPG). It’s generally recommended that you create annotated tags so you can have all this information.</li>
</ul></li>
<li>We can also tag previous commits by specifying their hash abbreviation: <code>git tag -a v1.0 ca21323</code></li>
<li><code>git push</code> doesn’t transfer tags to the remote server; we have to push tags explicitly: <code>git push origin v1.0</code> (or <code>git push origin --tags</code> for all of them)</li>
<li><code>git tag</code> to list all tags</li>
<li><code>git tag -l pattern</code> to list tags that match a specific pattern</li>
<li>We can check out a tag to inspect the files from that version: <code>git checkout tagname</code> (this leaves us in a detached-HEAD state). Any changes committed there wouldn’t belong to any branch and would be unreachable except by exact commit hash. Therefore, to fix issues, create a new branch from the tag and commit the changes there.</li>
<li>The <strong>difference</strong> between tags and branches is that branches evolve over time; however, tags point to fixed commit in repo’s history.</li>
<li>We can delete tags:
<ul>
<li>locally: <code>git tag -d tagname</code></li>
<li>remote: <code>git push remote_name -d tagname</code> OR <code>git push remote_name :refs/tags/tagname</code></li>
</ul></li>
</ul>
<div id="cell-46" class="cell" data-execution_count="96">
<div class="sourceCode cell-code" id="cb52" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb52-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>tree <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>C <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Users<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>imad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>/Users/imad/Desktop/git-repo/.git/refs/
├── heads
│&nbsp;&nbsp; ├── feature
│&nbsp;&nbsp; ├── master
│&nbsp;&nbsp; └── test
└── tags
    └── v.0.1

2 directories, 4 files</code></pre>
</div>
</div>
<div id="cell-47" class="cell" data-execution_count="97">
<div class="sourceCode cell-code" id="cb54" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb54-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">!</span>cat <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Users<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>imad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>.git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>refs<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>tags<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/*</span></span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>6e9c6886a2180fdfde291a130aa9e10a52bac679</code></pre>
</div>
</div>
<div id="cell-48" class="cell" data-execution_count="100">
<div class="sourceCode cell-code" id="cb56" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb56-1"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%%</span>bash</span>
<span id="cb56-2">cd <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">~/</span>Desktop<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span>git<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>repo<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span></span>
<span id="cb56-3">git graph2</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>* 6e9c688    (HEAD -&gt; master, tag: v.0.1, test) Renaming (Imad)
* ff9e667    Add test1 dir (Imad)
* e503b6e    Add test dir (Imad)
* 8610c78    Add copied file (Imad)
* ed738c9    (feature) Rebased all commits (Imad)
* 2050b90    Add host to file (Imad)
* 91eacf5    Add host (Imad)
* ed27259    patch commit (Imad)
* ff2d260    third commit (Imad)
* 6dd0c14    Change second commit (Imad)
* c2b7166    first commit (Imad)</code></pre>
</div>
</div>
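<p>The lightweight/annotated distinction is easy to verify with <code>git cat-file</code>. A minimal sketch (tag names and messages are illustrative):</p>

```shell
#!/bin/sh
# Lightweight tags point straight at a commit; annotated tags are tag objects.
set -e
rm -rf /tmp/tag-demo && mkdir -p /tmp/tag-demo && cd /tmp/tag-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo v1 > app.txt && git add app.txt && git commit -qm "release work"

git tag v1.0-light                          # lightweight: ref holds the commit hash
git tag -a v1.0 -m "First stable release"   # annotated: ref holds a tag object hash

git cat-file -t v1.0-light   # prints: commit
git cat-file -t v1.0         # prints: tag
git tag -l 'v1.0*'           # pattern listing: v1.0 and v1.0-light
```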
</section>
<section id="cloning-repository" class="level2">
<h2 class="anchored" data-anchor-id="cloning-repository">Cloning Repository</h2>
<p><code>git clone https://github.com/UserName/RepoName</code> would do the following:</p>
<ul>
<li>Create a directory called <code>RepoName</code></li>
<li>Create a directory called <code>.git</code> inside <code>RepoName</code></li>
<li>Pull down all versions for every file for the history of the project</li>
<li>Check out the latest version</li>
</ul>
<p>As a result, if a huge file was committed early on but deleted years ago, cloning still pulls it down even though it is never needed again. Therefore, if a project has a long history, we may not need to clone all of it: a shallow clone restricts history to the last <code>N</code> commits (<code>--depth N</code>) or to commits after a date (<code>--shallow-since</code>).</p>
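<p>A shallow clone can be sketched against a local repository (the <code>file://</code> URL matters because plain local-path clones ignore <code>--depth</code>; paths are illustrative):</p>

```shell
#!/bin/sh
# Shallow clone: only the most recent commit is fetched.
set -e
rm -rf /tmp/shallow-demo && mkdir -p /tmp/shallow-demo && cd /tmp/shallow-demo
git init -q -b master src && cd src
git config user.email demo@example.com && git config user.name Demo
echo 0 > n.txt && git add n.txt && git commit -qm "commit 0"
for i in 1 2 3 4; do echo "$i" > n.txt && git commit -qam "commit $i"; done
cd ..
git clone -q --depth 1 file:///tmp/shallow-demo/src shallow && cd shallow
git rev-list --count HEAD   # prints 1, although src has 5 commits
```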
</section>
<section id="ignoring-files" class="level2">
<h2 class="anchored" data-anchor-id="ignoring-files">Ignoring Files</h2>
<p><code>.gitignore</code> hosts all patterns that Git should ignore and not track. It is typically located at the root directory of the project and applies recursively to all subdirectories; however, we can have a <code>.gitignore</code> in a subdirectory that applies only to that subdirectory.</p>
<p>The rules for the patterns you can put in the .gitignore file are as follows:</p>
<ul>
<li>Blank lines or lines starting with <code>#</code> are ignored.</li>
<li>Standard glob patterns work, and will be applied recursively throughout the entire working tree. Example:
<ul>
<li><code>*.log</code> ignores all files that end with <code>.log</code>, recursively</li>
<li><code>doc/*.txt</code> ignores all <code>.txt</code> files under <code>doc</code></li>
<li><code>doc/**/*.pdf</code> ignores all PDF files in the <code>doc</code> directory and all its subdirectories</li>
</ul></li>
<li>We can start patterns with a forward slash (/) to avoid recursivity. Example: <code>/TODO</code> ignores TODO in the current directory.</li>
<li>You can end patterns with a forward slash (/) to specify a directory. Example: <code>build/</code> ignores all files under <code>build</code> in all directories.</li>
<li>You can negate a pattern by starting it with an exclamation point (!). Example: <code>!test.log</code> tracks <code>test.log</code></li>
</ul>
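<p>The patterns above can be checked without committing anything, using <code>git check-ignore</code> (the file names below are illustrative; the files don’t even need to exist):</p>

```shell
#!/bin/sh
# Verifying .gitignore patterns with git check-ignore.
set -e
rm -rf /tmp/ignore-demo && mkdir -p /tmp/ignore-demo && cd /tmp/ignore-demo
git init -q -b master
cat > .gitignore <<'EOF'
*.log
doc/**/*.pdf
/TODO
build/
!important.log
EOF
git check-ignore -v app.log            # matched by *.log
git check-ignore -v doc/a/b/spec.pdf   # matched by doc/**/*.pdf
git check-ignore -v build/out.o        # matched by build/
git check-ignore -q important.log && echo ignored || echo "tracked (negated)"
```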
</section>
<section id="general" class="level2">
<h2 class="anchored" data-anchor-id="general">General</h2>
<ul>
<li><p><code>git rm</code> removes a file from the working tree and stages the removal. If we staged a file by mistake, or don’t want Git to track it, <code>git rm --cached filename</code> removes it from the staging area while keeping it on the hard drive.</p></li>
<li><p><code>git mv</code> both renames the file and stages the rename</p></li>
<li><p><code>git reflog</code> records all actions taken in a repository (even intermediate steps such as creating branches, clones, pulls, etc.), not just commits. It is local to your copy of the repository; others with the same repository have their own reflog, and it starts empty after a fresh clone. It is therefore more like <code>shell</code> history: we can always get back to an earlier state.</p>
<ul>
<li><code>git reflog show branchname</code> shows all the actions that updated that reference.</li>
</ul></li>
<li><p><code>git diff</code> to see the changes in the working tree compared to the index</p></li>
<li><p><code>git diff --cached</code> to see the staged changes compared to the last commit</p></li>
<li><p><code>git difftool</code> shows the changes in external tools such as <em>vimdiff</em></p></li>
<li><p><code>git add -i</code> for interactive staging to control which files to stage and which parts of the files (patches) to stage all interactively. This is very helpful if we have done a lot of work on many files without staging anything. We can <code>add/checkout/restore/stash</code> patches (parts of files) by adding <code>--patch</code> or <code>-p</code> flag to their corresponding git command.</p></li>
<li><p><code>git commit -m "test"</code>: a commit involves at least a change to one blob. This leads to the creation of a new tree reflecting the current state of the code. Git then creates a commit object that points to the new tree. Finally, it updates the current branch to point to the newly created commit.</p></li>
<li><p><code>git merge --ff-only branch-name</code>: this kind of merge creates no objects; it just updates the current branch to point to a different commit.</p></li>
<li><p><code>git merge branch-name</code>: in contrast to the fast-forward merge, Git creates a new tree by doing its best to combine the two divergent branches. It then creates a new commit that points to the new tree and has two parents: the latest commit from each branch. This is what <em>merging via pull requests</em> does on GitHub. It may not be preferable because the code actually changes during the merge when Git combines the branches, and we may end up with conflicts.</p></li>
</ul>
</section>
<section id="stashing" class="level2">
<h2 class="anchored" data-anchor-id="stashing">Stashing</h2>
<p>It is very helpful when we have staged work and/or modified tracked files and want to jump to a different branch to work on something else. By default, Git stashes only modified and staged tracked files, not untracked files; add <code>-u</code> to include untracked files as well.</p>
<ul>
<li>We can run <code>git stash</code> to stash the work</li>
<li><code>git stash list</code> to list all the stashes</li>
<li><code>git stash apply</code> to apply last stash OR <code>git stash apply stashname</code>. This keeps the stash on the stack</li>
<li><code>git stash drop</code> to remove a stash</li>
<li><code>git stash pop</code> to apply and remove last stash in one command</li>
</ul>
<p>We can apply stashes from one branch on another branches.</p>
<p>To avoid issues/merge conflicts when applying stashes, it may be helpful to apply the stash on a fresh branch: <code>git stash branch newbranchname</code>. This creates a new branch starting from the commit you were on when you stashed, checks it out, applies the stash, and then drops the stash.</p>
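<p>A minimal stash round-trip (file names and the stash message are illustrative):</p>

```shell
#!/bin/sh
# Stash uncommitted work, verify the tree is clean, then restore it.
set -e
rm -rf /tmp/stash-demo && mkdir -p /tmp/stash-demo && cd /tmp/stash-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo one > notes.txt && git add notes.txt && git commit -qm "initial"

echo two >> notes.txt              # a change we are not ready to commit
git stash push -q -m "wip notes"
git status --porcelain             # prints nothing: the tree is clean again
git stash list                     # stash@{0}: On master: wip notes
git stash pop -q                   # apply and drop in one step
grep two notes.txt                 # the change is back
```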
</section>
<section id="managing-history" class="level2">
<h2 class="anchored" data-anchor-id="managing-history">Managing History</h2>
<ul>
<li><code>git log</code> has all the information we need to the repo’s history.
<ul>
<li><code>git log --all --decorate --graph --oneline</code> is great to get an overview and see the divergence of branches</li>
<li><code>git log -n 5</code> (or <code>git log -5</code>) limits the log to the most recent 5 commits</li>
<li><code>git log --oneline file</code> is useful to get an overview of the log of one file</li>
<li><code>git log --pretty=format:'%C(yellow)%h%C(reset) - %an [%C(green)%ar%C(reset)] %s'</code> to change the format of the log output</li>
<li><code>git log -E -i --grep regexp</code> searches the commit messages with an extended, case-insensitive regexp</li>
<li><code>git log -S term</code> will search for changes related to that term in the code base (addition/deletion). Check this <a href="https://thoughtbot.com/blog/code-sleuthing-with-git">post</a></li>
<li><code>git log -G regexp</code> searches for changes matching the regexp in the code base; unlike <code>-S</code>, it matches patterns rather than a literal string.</li>
</ul></li>
<li><code>git show commit</code> will show everything that happened with that commit including <strong>diff</strong></li>
<li><code>git blame file</code> is useful to know who changed what in the file and when, especially when tracing who introduced a bug or some logic to the codebase. Use <code>-L</code> to restrict to specific lines. Use <code>-C</code> to detect whether blocks of lines were copied from other files in the same commit.</li>
</ul>
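<p>The pickaxe search (<code>-S</code>) can be illustrated with a tiny history (file content and messages below are made up):</p>

```shell
#!/bin/sh
# git log -S finds the commit(s) where a term's number of occurrences changed.
set -e
rm -rf /tmp/log-demo && mkdir -p /tmp/log-demo && cd /tmp/log-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo "def main(): pass" > app.py && git add app.py && git commit -qm "add main"
echo "def helper(): pass" >> app.py && git commit -qam "add helper"
echo "# docs" >> app.py && git commit -qam "add docs"

git log --oneline -S helper   # shows only the commit that introduced 'helper'
```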
</section>
<section id="bisect" class="level2">
<h2 class="anchored" data-anchor-id="bisect">Bisect</h2>
<p>Git bisect is useful for tracing when a bug was introduced, i.e. finding the commit that introduced it, especially if that commit is far back in the history. It does a binary search between a commit you believe was good (no bug) and the current commit (or any commit known to have the bug). Below is a typical workflow:</p>
<ul>
<li><code>git bisect start</code> to start the binary search</li>
<li><code>git bisect bad</code> which means current <code>HEAD</code> is the bad commit which would be last commit in the range of commits of the binary search</li>
<li><code>git bisect good commit</code> which tells Git that the provided commit didn’t have the bug and would be the first commit in the range of commits of the binary search</li>
<li>We can combine the steps above into one command: <code>git bisect start badcommit goodcommit</code></li>
<li>From here, we interactively run either <code>git bisect good</code> to tell Git the given commit is good so it does binary search from next commit to the last commit OR <code>git bisect bad</code> to tell Git that the given commit is bad and the next binary search stops at the commit before it. We keep doing this until we arrive at the commit that introduced the bug.</li>
</ul>
<p>We can also use a script that runs tests for us to check whether a commit is good or bad and automate the whole process:</p>
<ul>
<li><code>git bisect start badcommit goodcommit</code></li>
<li><code>git bisect run test-script.sh</code> OR <code>git bisect run make rule</code> OR <code>git bisect run pytest</code>. For each commit, <code>git bisect</code> runs the script or command on the checked out commit. If it returns 0 -&gt; good; otherwise, bad.</li>
</ul>
</section>
<section id="submodules" class="level2">
<h2 class="anchored" data-anchor-id="submodules">Submodules</h2>
<p>Git submodules are Git repositories nested inside a Git repository, which lets us track them while keeping their commit histories separate. Each submodule lives in its own directory inside the project repository.</p>
<ul>
<li>We can add submodule by <code>git submodule add URL</code>. This will create a directory with the name of the Git repository (we can have different names using <code>git submodule add URL name</code>). If we run <code>git status</code>, we see that Git added the directory as special type of file as well as add a file named <code>.gitmodules</code> that has the <em>path</em> and the <em>URL</em> for each submodule. We need to commit those two files to include them in our main project history.</li>
<li>If we clone a project that has submodules, we can either pass <code>--recurse-submodules</code> to initialize and pull the contents of all submodules, OR go into each submodule directory and run <code>git submodule update --init</code> (add <code>--recursive</code> if there are any nested submodules).</li>
<li>To pull in changes made to submodules, run <code>git submodule update --remote submodule_name</code></li>
<li><code>git diff --submodule</code> to get a nice diff for submodules</li>
</ul>
</section>
<section id="hooks" class="level2">
<h2 class="anchored" data-anchor-id="hooks">Hooks</h2>
<p>Git <a href="https://thoughtbot.com/blog/use-git-hooks-to-automate-annoying-tasks">hooks</a> are scripts that run either on the client side, for operations such as committing/merging, or on the server side, for network operations such as receiving pushed commits. All hooks are stored in the <code>.git/hooks</code> directory. Git prepopulates any new repository with example hooks that end in <code>.sample</code>; to use one, remove the extension. We can write hooks in many languages such as Python, but they have to be executable and must have no extension. Also, client-side hooks aren’t copied when the repository is cloned.</p>
<p>Below are the most common client-side hooks:</p>
<ul>
<li><code>pre-commit</code>: Runs before we type the commit message and abort if the return code is not zero. This can be used to run tests, check code style, check for documentation or whitespaces, etc.</li>
<li><code>prepare-commit-msg</code>: Runs before the commit message editor but after the default message is created.</li>
<li><code>commit-msg</code>: Typically used to check if a commit message conforms to some predefined patterns.</li>
<li><code>post-commit</code>: Runs after the commit process is completed.</li>
</ul>
<p>There are other client-side hooks such as <code>pre-rebase</code>, <code>pre-merge</code>, <code>post-merge</code>, etc.</p>
<p>Below are the most common server-side hooks:</p>
<ul>
<li><code>pre-receive</code>: Runs when handling a push from client. It can be used to check for things such as rejecting non-fast-forwards or access control.</li>
<li><code>update</code>: Similar to <code>pre-receive</code> but runs once for each branch the pusher is trying to update.</li>
<li><code>post-receive</code>: Runs after the entire push process is completed. It can be used to notify users or update services.</li>
</ul>
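<p>A tiny <code>pre-commit</code> hook along these lines (the "no WIP markers" policy is a made-up example):</p>

```shell
#!/bin/sh
# A pre-commit hook that rejects commits whose staged additions contain "WIP".
set -e
rm -rf /tmp/hook-demo && mkdir -p /tmp/hook-demo && cd /tmp/hook-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
cat > .git/hooks/pre-commit <<'EOF'
#!/bin/sh
if git diff --cached | grep -q '^+.*WIP'; then
  echo "pre-commit: refusing to commit WIP markers" >&2
  exit 1    # a nonzero exit aborts the commit
fi
EOF
chmod +x .git/hooks/pre-commit

echo "WIP: fix later" > code.txt && git add code.txt
git commit -qm "try" && echo committed || echo blocked    # prints: blocked

echo "done" > code.txt && git add code.txt
git commit -qm "clean commit" && echo committed           # prints: committed
```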
</section>
<section id="resetting" class="level2">
<h2 class="anchored" data-anchor-id="resetting">Resetting</h2>
<p><code>git reset HEAD|commit</code> command allows us to:</p>
<ul>
<li>Move what the branch <code>HEAD</code> points to (stops here if <code>--soft</code>; everything stays staged).</li>
<li>Make the index look like <code>HEAD</code> (stops here unless <code>--hard</code>; this is the default, <code>--mixed</code>)</li>
<li>Make the working directory look like the index</li>
</ul>
<p>If we provide a path such as <code>git reset filepath</code>, it is a shorthand for <code>git reset --mixed HEAD filepath</code> and does the following:</p>
<ul>
<li>Move what the branch <code>HEAD</code> points to (skipped)</li>
<li>Make the index look like <code>HEAD</code>; i.e.&nbsp;has the effect of unstaging the file</li>
</ul>
<p>If we run <code>git reset commit -- filepath</code>, it acts as if we reverted the file’s content to what was in that commit and then ran <code>git add</code> on it, without touching the working directory. <code>HEAD</code> and the working directory keep their current version of the file while the index holds the old one. Running <code>git commit</code> therefore records the file as it was in that commit.</p>
<p><code>git checkout</code> without paths is similar to <code>git reset</code> with two differences:</p>
<ul>
<li><code>reset</code> moves the branch <code>HEAD</code> points to while <code>checkout</code> moves <code>HEAD</code> itself. For example, <code>git checkout branch</code> would change what <code>HEAD</code> is pointing to while <code>git reset commit</code> would change what branch points to.</li>
<li><code>checkout</code> is working-directory safe where it tries to do a trivial merge but <code>reset --hard</code> will overwrite working-directory.</li>
</ul>
<p><code>git checkout filepath</code> is the equivalent of <code>git reset --hard filepath</code> (if <code>reset --hard</code> accepted paths): it overwrites the file in the working directory.</p>
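<p>The three stopping points map directly to <code>--soft</code>, the default <code>--mixed</code>, and <code>--hard</code>. A minimal sketch (file names are illustrative):</p>

```shell
#!/bin/sh
# Where each reset mode stops: branch pointer, index, working directory.
set -e
rm -rf /tmp/reset-demo && mkdir -p /tmp/reset-demo && cd /tmp/reset-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo a > f.txt && git add f.txt && git commit -qm "c1"
echo b >> f.txt && git commit -qam "c2"

git reset -q --soft HEAD~1   # branch moves back; change from c2 is still staged
git status --porcelain       # prints: M  f.txt
git reset -q                 # --mixed (default): index matches HEAD, change unstaged
git status --porcelain       # prints:  M f.txt
git reset -q --hard          # working directory matches HEAD too
git status --porcelain       # prints nothing
cat f.txt                    # prints: a
```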
</section>
<section id="inspecting-commit-ranges" class="level2">
<h2 class="anchored" data-anchor-id="inspecting-commit-ranges">Inspecting Commit Ranges</h2>
<ul>
<li><code>^</code> refers to the parent. <code>HEAD^</code> means the parent of the last commit on the current branch; <code>HEAD^2</code> means its <em>second</em> parent (only merge commits have one).</li>
<li><code>~</code> also refers to the first parent, so <code>HEAD~</code> and <code>HEAD^</code> are identical. They differ when followed by a number: <code>HEAD~2</code> means the first parent’s first parent (the grandparent), while <code>HEAD^2</code> selects the second parent of a merge commit.</li>
<li><code>HEAD~5</code> is equivalent to <code>HEAD^^^^^</code>.</li>
<li>Double dots (<code>..</code>): If we want to see the commits that are reachable from target branch (commit) but not the source branch (commit), we use <code>git log sourcecommit..targetcommit</code>.</li>
<li>Triple dots (<code>...</code>): If we want to see the commits that are reachable from either of the branches (commits) but not from both, we use <code>git log sourcecommit...targetcommit</code>. This returns the commits unique to each side, excluding the ones they have in common.</li>
<li>Multiple points: We can exclude commits reachable from a ref by prefixing it with <code>^</code>. For example, <code>git log refA refB ^refC</code> shows commits reachable from refA or refB but not from refC. Therefore:
<ul>
<li><code>git log refA..refB</code> is equivalent to <code>git log refB ^refA</code></li>
</ul></li>
</ul>
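<p>The dot notations can be checked on a small diverged history (branch names are illustrative; the counts follow from the setup):</p>

```shell
#!/bin/sh
# Double-dot, triple-dot, and ^exclusion on two diverged branches.
set -e
rm -rf /tmp/range-demo && mkdir -p /tmp/range-demo && cd /tmp/range-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo base > f.txt && git add f.txt && git commit -qm "base"
git checkout -qb feature
echo f1 > f1.txt && git add f1.txt && git commit -qm "feature 1"
echo f2 > f2.txt && git add f2.txt && git commit -qm "feature 2"
git checkout -q master
echo m1 > m1.txt && git add m1.txt && git commit -qm "master 1"

git rev-list --count master..feature    # prints 2: commits only on feature
git rev-list --count master...feature   # prints 3: commits unique to either side
git rev-list --count feature ^master    # prints 2: same as master..feature
```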
</section>
<section id="grep" class="level2">
<h2 class="anchored" data-anchor-id="grep">Grep</h2>
<p>Git <code>grep</code> allows us to search for a pattern in working directory, index, and committed tree. We can also search in older versions of the code such as using old tags/commits, which <code>grep/ack</code> tools can’t.</p>
<p>The most useful combination of flags is <code>git grep -n -p --break --heading pattern optional_path optional_commit</code>.</p>
</section>
<section id="undoing" class="level2">
<h2 class="anchored" data-anchor-id="undoing">Undoing</h2>
<p><strong>Commits are immutable. This means that even though we can fix some things about a commit, we can’t change the commit itself; the original will still be in the history.</strong> Therefore, anything that is committed in Git can almost always be recovered. Even commits that were on deleted branches or that were replaced with an <code>--amend</code> commit can be recovered. However, anything you lose that was never committed is likely never to be seen again.</p>
<ul>
<li><code>git commit --amend</code> will open an editor to write a new commit message to the already committed changes.
<ul>
<li><code>git commit --amend -m "message"</code> is a shorthand</li>
<li><code>git commit --amend --no-edit</code> will add new files to the last commit; in case we forgot to add some files to that belong to the same commit</li>
</ul></li>
<li><code>git reset HEAD file</code> OR <code>git restore --staged file</code> will undo the staging of the file. This is helpful if we staged a file and then we need to change some things before committing.</li>
<li><code>git checkout -- file</code> OR <code>git restore file</code> discards all the changes made to a file in the working directory. The discarded changes can never be recovered.</li>
<li><code>git reset --soft HEAD~2</code> removes the last two commits from the branch history and points HEAD at their grandparent. <code>--soft</code> keeps the changes in the working directory and index, so running <code>git commit</code> commits the combined changes with the grandparent as parent (<strong>Squashing Commits</strong>).</li>
<li>To cancel the commit while writing the message, exit vim with <code>:cquit</code>, which exits with an error; Git sees the error and doesn’t create the commit.</li>
</ul>
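<p>A minimal <code>--amend</code> round-trip, folding a forgotten file into the last commit (file names are illustrative):</p>

```shell
#!/bin/sh
# Amend the last commit to include a file we forgot to stage.
set -e
rm -rf /tmp/amend-demo && mkdir -p /tmp/amend-demo && cd /tmp/amend-demo
git init -q -b master
git config user.email demo@example.com && git config user.name Demo
echo code > app.txt && git add app.txt && git commit -qm "add feature"

echo docs > docs.txt && git add docs.txt   # the forgotten file
git commit -q --amend --no-edit            # same message, new commit object
git rev-list --count HEAD                  # prints 1: history still has one commit
git show --name-only --format= HEAD        # lists app.txt and docs.txt
```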
</section>
<section id="rebasing-history" class="level2">
<h2 class="anchored" data-anchor-id="rebasing-history">Rebasing History</h2>
<ul>
<li><code>git add file</code> or <code>git add --all</code> or <code>git add directory</code>. This will add all changes made to a specific file/directory.</li>
<li><code>git add --patch</code> Allows us to cherry pick the changes that we want to stage. This is useful if we want to split the changes we made to a specific file into different commits. When we run the command, we will interactively choose what we want to stage using shortcuts.</li>
<li><code>git diff/log HEAD~2..HEAD</code> gives us the diff/log for the range between two commits in history. We can use either commit hashes or references such as <code>HEAD</code>/<code>master</code>.</li>
<li><code>git reset --hard HEAD~1</code> will make HEAD point to its parent and remove the last commit from the log history. Note that the last commit is not completely removed; we can still see it with <code>git reflog</code>.</li>
<li><code>git cherry-pick origin/master..master</code> will replay the commits within this range onto another branch. This is useful when we commit to the wrong branch and want to reapply those commits on another branch: we check out the correct branch and run the command above. To remove the commits from the branch we first committed to, we can use <code>git reset --hard</code> (even though the removed commits are still in our history).</li>
<li><code>git rebase master</code> We want to take the work we’ve done on our feature branch, and reapply it as if it were done on top of the additional commits in our master branch. When performing the rebase, Git finds the commits unique to our branch and computes the diff of the changes they introduced, then moves to the target branch, master in this case, and one by one applies the diffs, creating new commits that reuse the commit messages from our branch. Once done, it updates our branch to point at the newest of these commits created by reapplying the diffs.</li>
<li>While we would never revise published history, specifically the master branch, we almost always revise our commits on feature branches before merging them in. We value a clean history, and the majority of the time, the commits in a feature branch contain many rounds of refactoring and PR reviews which we don’t want in the permanent history. Instead, we want the most direct and concise form of the history that fully captures the change we settled on in our feature branch after completing any refactoring or updates. Running <code>git rebase -i master</code> allows us to do just that.
<ul>
<li>We can remove, reorder, squash, edit, and split commits using interactive rebase.</li>
<li>Git applies and rewrites the changed commits and all the commits that follow them.</li>
<li>It is highly recommended not to change history that has already been pushed to the remote server, unless we’re working on a feature branch and are cleaning up history before merging and closing the pull request.</li>
<li>Reordering commits is done by simply reordering the commit lines shown in the editor.</li>
<li>Be careful: commits are listed in reverse order compared to <code>git log</code>. This means the most recent commit will be last.</li>
</ul></li>
</ul>
</section>
<section id="packfiles" class="level2">
<h2 class="anchored" data-anchor-id="packfiles">Packfiles</h2>
<p>Packfiles are how <em>Git</em> combines objects into a single file to save space: instead of storing every version of a file in full, Git stores the original version plus deltas, and a pack index file stores offsets that point to each object in the packfile. <em>Git</em> runs this automatically when we have too many loose objects, when we run the <code>git gc</code> command, or when we push to a remote server.</p>
</section>
<section id="github-and-remotes" class="level2">
<h2 class="anchored" data-anchor-id="github-and-remotes">Github and Remotes</h2>
<ul>
<li><strong>Hub</strong> and Github CLI tool <strong>gh</strong> make it easy to interact with Github from the command line and integrate well with Git. Useful commands are <code>compare</code>, <code>browse</code>, and <code>pull-request</code>.</li>
<li>To share the code on a given branch using a URL that always points to the same code, we can press <code>y</code> while viewing a file on Github to replace the branch name in the URL with its commit hash, which will always point to the same version of the code even if we make changes to the branch. We can also select lines in the file that will be highlighted when we open the URL.</li>
<li>If we are creating a new branch locally and want to have an upstream version for that branch:
<ul>
<li><code>git branch -u origin/new-branch-name</code> will set the upstream tracking branch for the current branch so we can easily push (the remote branch must already exist)</li>
<li><code>git push -u origin new-branch-name</code> will create the new branch while pushing to Github</li>
<li>If we want the upstream branch to have a different name than the local branch name: <code>git push -u origin local-branch-name:upstream-branch-name</code></li>
</ul></li>
<li>If we want to delete a branch:
<ul>
<li>Locally: <code>git branch -d branch-name</code></li>
<li>Upstream: <code>git push origin --delete branch-name</code></li>
</ul></li>
<li>We can force-push a local branch to another existing upstream branch. This is risky and rarely needed: <code>git push --force origin local-branch:upstream-branch</code></li>
</ul>
</section>
<section id="typical-workflow" class="level2">
<h2 class="anchored" data-anchor-id="typical-workflow">Typical Workflow</h2>
<ul>
<li>Always start by creating a new branch for new features. Almost always strive not to commit directly to the master branch, even for small changes. The workflow is:</li>
</ul>
<blockquote class="blockquote">
<p>create new branch <strong>-&gt;</strong> make small changes <strong>-&gt;</strong> create pull request <strong>-&gt;</strong> pass code reviews and other stuff like CI/CD <strong>-&gt;</strong> Rebase master into feature branch <strong>-&gt;</strong> Interactive rebase to squash all commits from feature branch into one commit message <strong>-&gt;</strong> Fast forward merge with master <strong>-&gt;</strong> push master <strong>-&gt;</strong> delete feature branch locally and on upstream.</p>
</blockquote>
<ul>
<li>Commit small changes often instead of waiting to commit one large change. Large commits make it harder to figure out what changed and more difficult for code reviewers to understand. We can always refine commits with interactive rebase.</li>
<li><strong>Pull Requests:</strong>
<ul>
<li>We first need to push the feature branch into Github using <code>git push -u origin feature-branch</code></li>
<li>We then have two choices to open PRs: either through the Github UI or through command-line tools like <code>hub</code> and <code>gh</code>. The advantage of the Github UI is that it lets you review the code one more time in the compare view before submitting.</li>
<li>Provide as much context and useful detail as possible when drafting your PR description. Answering the following questions is a great start:
<ul>
<li>Why is this change needed?</li>
<li>Were other solutions considered?</li>
<li>Were any assumptions made?</li>
</ul></li>
<li>For work that can’t be broken down into small changes, we can use Github task lists to show all the items that still need work and the planned approach, so that reviewers know not to do in-depth code reviews yet. Every time we push changes, we check off the items that are done.</li>
<li>Code reviews resources:
<ul>
<li><a href="http://confreaks.tv/videos/railsconf2015-implementing-a-strong-code-review-culture">Derek Prior’s talk on Code Review Culture</a></li>
<li><a href="https://github.com/thoughtbot/guides/tree/master/code-review">thoughtbot guide to code review</a></li>
</ul></li>
<li>After getting feedback from the team on the code review, as well as the CI results, we can incorporate the changes the team recommended. Then push the new commits to the feature branch, and they will automatically be included in the PR.</li>
<li>We prefer a clean history built using fast-forward merges. In order to ensure this, before merging our PR we always pull master and rebase our feature branch onto master to ensure that our commits are ahead of master. One nice helper for this is the mup alias which checks out master, pulls, then checks back out our feature branch: <code>mup = !git checkout master &amp;&amp; git pull &amp;&amp; git checkout -</code>. Finally, <code>git rebase master</code>. If we’ve done any rebase, we need to force push changes to remote <code>git push -f</code></li>
<li>Once we’re ahead of master, we can perform an interactive rebase to revise our commits and craft our history. In particular, we can use this time to squash down cleanup and WIP commits, ensuring that each commit we keep is useful and has a solid commit message.</li>
<li>This is the time to ensure that we’ve captured as much context as possible in our commit message to describe the “why” of the change. Two great resources on this topic are:
<ul>
<li><a href="https://robots.thoughtbot.com/5-useful-tips-for-a-better-commit-message">Five Rules for A Good Git Commit Message</a></li>
<li><a href="http://rakeroutes.com/blog/deliberate-git">Stephen Ball’s Deliberate Git talk</a></li>
</ul></li>
<li>If we’ve performed any form of rebase, then we’ll have created new commits and will want to push those up to GitHub in order to get everything in sync. To do this we can force push (<code>git push -f</code>) our branch.</li>
<li>Final steps:
<ul>
<li>If we’ve force pushed after rebasing as described above, we should be all set, but it never hurts to run one last <code>git push</code> to confirm that our local and remote feature branches are in sync.</li>
<li>Fast-forward merge: <code>git checkout master</code> &amp;&amp; <code>git merge - --ff-only</code> (where <code>-</code> refers to the previously checked-out feature branch)</li>
<li>Push master: now that master contains our feature branch’s commits, we can push it up to GitHub with <code>git push</code>. As a reminder, with a fast-forward merge we are simply moving our master branch pointer to point at our feature branch’s tip commit, not actually creating any new commits. This is one of the main benefits of using fast-forward merges: all commits are created and can be reviewed on our feature branch before merging into master. With “Big Green Button on GitHub” merges and other non-fast-forward merges, the merge commit is created directly on master based on Git’s merging algorithm.</li>
<li>Delete local branch: <code>git branch -d decks-ordering</code></li>
<li>Delete remote branch: <code>git push origin --delete &lt;branchName&gt;</code>. We can also delete the branch via the GitHub PR page, and then git pull on master, letting the fetch prune setting automatically clean up our local reference to the remote branch.</li>
<li>Pull request auto-closing: assuming we’ve performed the steps outlined above, GitHub will have automatically closed the PR based on the fact that master now contains our branch’s commits.</li>
</ul></li>
</ul></li>
</ul>
</section>
<section id="configuration" class="level2">
<h2 class="anchored" data-anchor-id="configuration">Configuration</h2>
<p>Git looks for configurations in the following places:</p>
<ul>
<li>First, <code>/etc/gitconfig</code>. Any time we use <code>git config --system</code>, it reads/writes this file</li>
<li>Second, <code>~/.gitconfig</code> for each user. Any time we use <code>git config --global</code>, it reads/writes this file</li>
<li>Finally, the <code>config</code> file inside the repository’s <code>.git</code> directory. Any time we use <code>git config --local</code>, it reads/writes this file</li>
</ul>
<p>The <code>gitconfig</code> file is read automatically before any Git command is run. That turns out to be very handy, as it means we never have to reload or worry about out-of-sync configuration. Additionally, Git writes to it automatically when we run commands like <code>git config --global alias.ga</code>.</p>
<ul>
<li>The config file is split into sections such as <em>color, alias, core, push, etc.</em> For example:</li>
</ul>
<div class="sourceCode" id="cb58" style="background: #f1f3f5;"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb58-1"><span class="ex" style="color: null;
background-color: null;
font-style: inherit;">[push]</span></span>
<span id="cb58-2">    <span class="ex" style="color: null;
background-color: null;
font-style: inherit;">default</span> = upstream</span></code></pre></div>
<p>is the same as <code>git config --global push.default upstream</code>.</p>
<p>Few useful configurations:</p>
<ul>
<li><code>push.default upstream</code> this instructs Git how to respond when you run git push with no arguments. With the upstream configuration, it will push the configured upstream tracking branch (set up with <code>git push -u</code>).</li>
<li><code>merge.ff only</code> this configuration tells Git to reject merges that are non-fast-forward. With fast-forward merges, no new commits are created; instead, the merging branch (typically master) is simply moved to point at the commits of the target branch (typically our feature branch).</li>
<li><code>fetch.prune true</code> this instructs Git to clear local references to remote branches which have been deleted when you pull.</li>
</ul>
<p>By default, we can only execute one git command when aliasing. To execute more than one command, we can start the command with <code>!</code> and then we can execute multiple shell commands using pipes, &amp;&amp;, and ||. For example, <code>!git checkout master &amp;&amp; git pull &amp;&amp; git checkout -</code>.</p>
<p><strong>Git subcommands</strong> allow us to write scripts in any language we want, not necessarily bash, and have Git execute them. The subcommand script has to be:</p>
<ul>
<li>On our <code>$PATH</code></li>
<li>Marked as executable</li>
<li>Named with the prefix <code>git-</code> followed by the name of the command, e.g., <code>git-subcommand-name</code>. Actually, all git commands are files that meet these criteria, such as <code>git-add</code>. Below is an example of a subcommand:</li>
</ul>
<div class="sourceCode" id="cb59" style="background: #f1f3f5;"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb59-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#!/bin/bash</span></span>
<span id="cb59-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">#</span></span>
<span id="cb59-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Small wrapper around git commit. Bare 'cm' will enter normal git commit</span></span>
<span id="cb59-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># editor, but with args it will do a direct `commit -m`</span></span>
<span id="cb59-5"></span>
<span id="cb59-6"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">[[</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">$#</span> <span class="ot" style="color: #003B4F;
background-color: null;
font-style: inherit;">&gt;</span> 0 <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">]];</span> <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">then</span></span>
<span id="cb59-7">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">git</span> commit <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">-m</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">$@</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span></span>
<span id="cb59-8"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span></span>
<span id="cb59-9">    <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">git</span> commit <span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">-v</span></span>
<span id="cb59-10"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">fi</span></span></code></pre></div>
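fi</span>">
<p>Since subcommands can be written in any language, here is the same wrapper sketched in Python. The <code>git-cm</code> name and the helper function are my own illustration, not part of the course; the argument handling is factored into a function so the mapping can be checked without actually committing:</p>

```python
#!/usr/bin/env python3
# Hypothetical `git-cm` subcommand: `git cm fix typo` runs `git commit -m "fix typo"`,
# while a bare `git cm` opens the normal commit editor, mirroring the bash wrapper above.
import subprocess
import sys


def build_commit_argv(args):
    """Map subcommand arguments to the git command to execute."""
    if args:
        # Join all words into a single commit message.
        return ["git", "commit", "-m", " ".join(args)]
    return ["git", "commit", "-v"]


def main():
    return subprocess.call(build_commit_argv(sys.argv[1:]))

# To install: save as an executable file named `git-cm` somewhere on $PATH
# and end the file with `sys.exit(main())`.
```
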
</section>
<section id="resources" class="level2">
<h2 class="anchored" data-anchor-id="resources">Resources</h2>
<ul>
<li><a href="http://gitready.com/">Git Ready</a>: Practical how-to pages on topics like “get a file from a specific revision.”</li>
<li><a href="https://progit.org/">Pro Git</a>: A great in-depth resource I find myself continually coming back to.</li>
<li><a href="https://github.com/pluralsight/git-internals-pdf">Git Internals</a>: A deep dive into the Git object model, with more detail and nuance than we could cover in this course’s video on the topic</li>
<li><a href="https://github.com/thoughtbot/guides/tree/main/">Thoughtbot Guides</a></li>
<li><a href="https://cli.github.com/">Github CLI</a></li>
<li>Add this to Vim <code>autocmd Filetype gitcommit setlocal spell textwidth=72</code></li>
<li><a href="https://github.com/tpope/vim-fugitive">Fugitive Plugin</a>
<ul>
<li><a href="http://vimcasts.org/blog/2011/05/the-fugitive-series/">five part Fugitive series on Vimcasts</a></li>
</ul></li>
<li><a href="https://github.com/christoomey/vim-conflicted">Conflicted</a> Optimizing Fugitive for merge and rebase conflicts</li>
</ul>


</section>

]]></description>
  <category>Software Engineering</category>
  <guid>https://imaddabbura.github.io/posts/swe/Advanced-Git.html</guid>
  <pubDate>Fri, 22 Dec 2023 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/swe/images/git.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>I Built My Own PyTorch (Tiny Version) — Here’s Everything I Learned</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/mlsys/dl-systems.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="why-build-a-deep-learning-framework-from-scratch" class="level2">
<h2 class="anchored" data-anchor-id="why-build-a-deep-learning-framework-from-scratch">Why Build a Deep Learning Framework from Scratch?</h2>
<p>Every deep learning practitioner eventually runs <code>loss.backward()</code> and watches gradients flow. But what <em>actually</em> happens inside that call? Where do the intermediate tensors live? Why does your GPU run out of memory on a model that “should” fit? And why does reshaping a tensor sometimes silently copy gigabytes of data?</p>
<p>I built <a href="https://github.com/ImadDabbura/tiny-pytorch"><code>tiny_pytorch</code></a> to answer these questions for myself. Along the way, I encountered nearly every foundational design decision that real frameworks like PyTorch, TensorFlow, and Caffe had to make — and learned <em>why</em> they made them.</p>
<p>This post distills everything I learned into a coherent narrative. We’ll start from the framework-level design philosophy, work our way down to how bytes are laid out in memory, and then zoom back out to distributed training across multiple GPUs. The goal is <strong>intuition</strong>: mental models you can carry with you when debugging real systems.</p>
<section id="roadmap" class="level3">
<h3 class="anchored" data-anchor-id="roadmap">Roadmap</h3>
<p>Here’s what we’ll cover and why it matters:</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Section</th>
<th>What You’ll Learn</th>
<th>Why It Matters</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Framework Design</strong></td>
<td>Static vs.&nbsp;dynamic graphs, and the Caffe → TF → PyTorch arc</td>
<td>Understand trade-offs you inherit from your framework</td>
</tr>
<tr class="even">
<td><strong>Automatic Differentiation</strong></td>
<td>Forward vs.&nbsp;reverse mode AD, what gets saved</td>
<td>Know <em>why</em> backward passes consume so much memory</td>
</tr>
<tr class="odd">
<td><strong>Memory Layout</strong></td>
<td>Shapes, strides, views, and when copies happen</td>
<td>Stop guessing about tensor memory behavior</td>
</tr>
<tr class="even">
<td><strong>Hardware Acceleration</strong></td>
<td>Alignment, parallelism, BLAS, im2col</td>
<td>Understand the layer between your code and silicon</td>
</tr>
<tr class="odd">
<td><strong>Initialization &amp; Normalization</strong></td>
<td>Why init persists, and how norms fix training</td>
<td>Debug training instabilities at their root</td>
</tr>
<tr class="even">
<td><strong>Regularization</strong></td>
<td>Implicit vs.&nbsp;explicit, dropout mechanics</td>
<td>Apply regularization correctly (L2 ≠ weight decay!)</td>
</tr>
<tr class="odd">
<td><strong>Scaling Up</strong></td>
<td>Checkpointing, data/model/pipeline parallelism</td>
<td>Train models that don’t fit in memory</td>
</tr>
<tr class="even">
<td><strong>Neural Network Architectures</strong></td>
<td>CNN, RNN, LSTM, Transformer, GAN design choices</td>
<td>See architectures through a <em>systems</em> lens</td>
</tr>
</tbody>
</table>
<hr>
</section>
</section>
<section id="the-evolution-of-dl-frameworks" class="level2">
<h2 class="anchored" data-anchor-id="the-evolution-of-dl-frameworks">The Evolution of DL Frameworks</h2>
<p>Before writing a single line of code, it helps to understand the three philosophies that shaped modern deep learning frameworks. Each solved a real problem — and introduced new ones.</p>
<section id="caffe-layers-all-the-way-down" class="level3">
<h3 class="anchored" data-anchor-id="caffe-layers-all-the-way-down">Caffe: Layers All the Way Down</h3>
<p>Caffe (C++ only) was beautifully simple. You defined your computation as a stack of <strong>layers</strong>, each implementing a <code>forward()</code> and <code>backward()</code> method. The backward pass was a direct implementation of the backpropagation algorithm from Hinton’s seminal work — each layer knew how to compute its own gradients, and updates happened in-place.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Mental Model
</div>
</div>
<div class="callout-body-container callout-body">
<p>Think of Caffe layers like a stack of Lego bricks. Each brick knows its own shape (forward) and how to “unstick” itself (backward). Simple, intuitive, but rigid — you can’t easily build non-linear architectures.</p>
</div>
</div>
</section>
<section id="tensorflow-1.x-the-static-graph" class="level3">
<h3 class="anchored" data-anchor-id="tensorflow-1.x-the-static-graph">TensorFlow 1.x: The Static Graph</h3>
<p>TensorFlow introduced a powerful idea: <strong>construct a static computation graph first</strong>, then execute it. This separation of <em>definition</em> and <em>execution</em> unlocked serious optimizations — the compiler could fuse operations, reuse memory, and skip unnecessary computations at run-time.</p>
<p>The cost? Debugging was painful. You couldn’t just print a tensor mid-computation. The graph had its own “programming language” that felt alien to Python developers. Experimentation slowed down because every change required rebuilding the graph.</p>
</section>
<section id="pytorch-define-by-run" class="level3">
<h3 class="anchored" data-anchor-id="pytorch-define-by-run">PyTorch: Define by Run</h3>
<p>PyTorch flipped the script with <strong>dynamic computation graphs</strong> — the graph is built on-the-fly as you execute operations. This is called <em>define by run</em>. You can mix Python control flow (if/else, loops) directly with tensor operations, set breakpoints anywhere, and inspect intermediate values trivially.</p>
<p>The trade-off? Dynamic graphs are typically harder to optimize ahead of time. You lose the global view that static compilation provides. Modern PyTorch addresses this with <code>torch.compile()</code> and JIT compilation, getting closer to static-graph performance while keeping the dynamic-graph developer experience.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Trade-off Triangle
</div>
</div>
<div class="callout-body-container callout-body">
<p>Every DL framework navigates three competing goals: <strong>ease of debugging</strong>, <strong>optimization potential</strong>, and <strong>flexibility</strong>. Caffe optimized for simplicity, TensorFlow for optimization, and PyTorch for flexibility. No framework gets all three for free.</p>
</div>
</div>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    A["&lt;b&gt;Caffe&lt;/b&gt;&lt;br/&gt;Layers with forward/backward&lt;br/&gt;In-place updates&lt;br/&gt;C++ only"] --&gt; B["&lt;b&gt;TensorFlow 1.x&lt;/b&gt;&lt;br/&gt;Static graph&lt;br/&gt;Compile-then-run&lt;br/&gt;Hard to debug"]
    B --&gt; C["&lt;b&gt;PyTorch&lt;/b&gt;&lt;br/&gt;Dynamic graph&lt;br/&gt;Define-by-run&lt;br/&gt;Python-native"]
    C --&gt; D["&lt;b&gt;Modern PyTorch&lt;/b&gt;&lt;br/&gt;torch.compile / JIT&lt;br/&gt;Best of both worlds"]

    style A fill:#f9f,stroke:#333
    style B fill:#bbf,stroke:#333
    style C fill:#fbb,stroke:#333
    style D fill:#bfb,stroke:#333
</pre>
</div>
<p></p><figcaption> The evolution of DL framework design philosophies</figcaption> </figure><p></p>
</div>
</div>
</div>
<p><strong>Key takeaway:</strong> Framework design is fundamentally about <em>when</em> the computation graph is known. Know it early (static) and you can optimize aggressively. Know it late (dynamic) and you can iterate fast. Modern systems try to give you both.</p>
<hr>
</section>
</section>
<section id="automatic-differentiation-the-engine-room" class="level2">
<h2 class="anchored" data-anchor-id="automatic-differentiation-the-engine-room">Automatic Differentiation: The Engine Room</h2>
<p>Automatic differentiation (AD) is the core engine of every deep learning framework. It’s what makes <code>loss.backward()</code> work. But there are two fundamentally different approaches, and understanding <em>why</em> we use one over the other is essential.</p>
<section id="forward-mode-ad" class="level3">
<h3 class="anchored" data-anchor-id="forward-mode-ad">Forward Mode AD</h3>
<p>In forward mode, we walk from <strong>inputs to outputs</strong>. At each node, we compute the partial derivative of that node with respect to a <em>single</em> input variable. This means:</p>
<ul>
<li>For <strong>each input variable</strong>, we need a <em>full forward pass</em> through the graph.</li>
<li>If we have <img src="https://latex.codecogs.com/png.latex?n"> inputs, we need <img src="https://latex.codecogs.com/png.latex?n"> forward AD passes.</li>
</ul>
<p>For a typical deep learning loss function — a scalar output with millions of input parameters — this is catastrophically inefficient. We’d need millions of passes just to get one gradient update.</p>
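<p>A quick way to see forward mode in action is with dual numbers, where every value carries its derivative with respect to one seeded input. The <code>Dual</code> class below is my own minimal illustration, not tiny_pytorch code:</p>

```python
class Dual:
    """A value paired with its derivative w.r.t. one chosen input (forward-mode AD)."""
    def __init__(self, val, dot):
        self.val = val   # primal value
        self.dot = dot   # derivative of this value w.r.t. the seeded input

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)


def f(x, y):
    return x * x + x * y  # f(x, y) = x^2 + xy

# Seed x with dot=1 to get df/dx; y is held constant for this pass (dot=0).
out = f(Dual(3.0, 1.0), Dual(2.0, 0.0))
# df/dx = 2x + y = 8 at (3, 2). Getting df/dy would require a SECOND full pass.
```

One pass per input variable: this is exactly why forward mode scales with the number of inputs, not outputs.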
</section>
<section id="reverse-mode-ad-backpropagation" class="level3">
<h3 class="anchored" data-anchor-id="reverse-mode-ad-backpropagation">Reverse Mode AD (Backpropagation)</h3>
<p>Reverse mode flips the direction. We walk from the <strong>output back to inputs</strong>, computing the gradient of the scalar output with respect to <em>all</em> input nodes in a <strong>single backward pass</strong>. This is why it’s the standard for deep learning: one output, millions of inputs, one pass.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    subgraph forward["Forward Mode (one pass per input)"]
        direction LR
        x1f["x₁"] --&gt; |"∂a/∂x₁"| af["a"] --&gt; |"∂b/∂x₁"| bf["b"] --&gt; |"∂L/∂x₁"| Lf["L"]
    end

    subgraph reverse["Reverse Mode (one pass for ALL inputs)"]
        direction RL
        Lr["L"] --&gt; |"∂L/∂b"| br["b"] --&gt; |"∂L/∂a"| ar["a"] --&gt; |"∂L/∂x₁&lt;br/&gt;∂L/∂x₂&lt;br/&gt;∂L/∂x₃"| xr["x₁, x₂, x₃"]
    end
</pre>
</div>
<p></p><figcaption> Forward vs.&nbsp;reverse mode AD — reverse mode computes all gradients in a single backward pass</figcaption> </figure><p></p>
</div>
</div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Memory Cost of Reverse Mode
</div>
</div>
<div class="callout-body-container callout-body">
<p>Reverse mode has a catch: to compute gradients during the backward pass, we need the <strong>intermediate values from the forward pass</strong>. For each operation, we must store the input tensors and remember which operation created them. This is why training uses far more memory than inference — all those “saved tensors” accumulate on the graph.</p>
</div>
</div>
<p>Here’s what the autograd system actually tracks:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    x["Input x&lt;br/&gt;&lt;i&gt;leaf tensor&lt;/i&gt;"] --&gt; mul["Mul"]
    w["Weight W&lt;br/&gt;&lt;i&gt;leaf tensor&lt;/i&gt;"] --&gt; mul
    mul --&gt; |"z = W·x&lt;br/&gt;&lt;b&gt;saved: W, x&lt;/b&gt;"| act["ReLU"]
    act --&gt; |"a = relu(z)&lt;br/&gt;&lt;b&gt;saved: z&lt;/b&gt;"| loss_fn["MSELoss"]
    y["Target y"] --&gt; loss_fn
    loss_fn --&gt; |"L = loss(a, y)&lt;br/&gt;&lt;b&gt;saved: a, y&lt;/b&gt;"| L["Scalar Loss L"]

    L -.-&gt; |"backward()"| loss_fn
    loss_fn -.-&gt; act
    act -.-&gt; mul
    mul -.-&gt; x
    mul -.-&gt; w

    style x fill:#e8f5e9
    style w fill:#e8f5e9
    style L fill:#ffcdd2
</pre>
</div>
<p></p><figcaption> What the autograd engine saves during a forward pass — every intermediate result and its creator must be retained for backward</figcaption> </figure><p></p>
</div>
</div>
</div>
<p>The dashed arrows show the backward pass, which retraces the forward graph in reverse. At each node, the saved tensors are consumed to compute local gradients.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Gradients as Directional Information
</div>
</div>
<div class="callout-body-container callout-body">
<p>The gradient at each node tells you: <em>“In which direction would changing this value increase the loss most steeply?”</em> It points toward steepest <strong>ascent</strong> — the direction of maximum loss increase. To decrease the loss, we move in the <strong>negative</strong> gradient direction. This is why gradient descent subtracts the gradient from the parameters: <img src="https://latex.codecogs.com/png.latex?%5Ctheta%20%5Cleftarrow%20%5Ctheta%20-%20%5Calpha%20%5Cnabla_%5Ctheta%20L">.</p>
</div>
</div>
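<p>The update rule is easy to verify numerically. Here is a tiny sketch of my own (not framework code) minimizing a toy loss:</p>

```python
# Minimize L(theta) = (theta - 3)**2, whose gradient is 2 * (theta - 3).
theta, lr = 0.0, 0.1
for _ in range(100):
    grad = 2 * (theta - 3)     # gradient points toward steepest ASCENT of L
    theta = theta - lr * grad  # so we step in the NEGATIVE gradient direction
# theta converges to 3.0, the minimizer of L
```
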
<p>One powerful consequence: the backward pass itself builds a computation graph for the gradients. This means you can compute <strong>gradients of gradients</strong> simply by adding more operations — which is exactly what second-order methods and some meta-learning approaches do.</p>
<p><strong>Key takeaway:</strong> Reverse mode AD gives us all gradients in one pass, but the price is memory — every intermediate tensor from the forward pass must be kept alive until it’s consumed by the backward pass.</p>
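<p>To make the saved-tensor bookkeeping concrete, here is a minimal scalar reverse-mode sketch of my own (illustrative only, not the tiny_pytorch API): each operation records its inputs, and <code>backward()</code> replays the graph in reverse topological order.</p>

```python
class Scalar:
    """Minimal reverse-mode node: value, gradient, saved inputs, and a local rule."""
    def __init__(self, val, parents=()):
        self.val = val
        self.grad = 0.0
        self.parents = parents        # the "saved tensors" of the op that made this
        self.backward_rule = None

    def __add__(self, other):
        out = Scalar(self.val + other.val, (self, other))
        def rule():
            self.grad += out.grad     # d(a+b)/da = 1
            other.grad += out.grad    # d(a+b)/db = 1
        out.backward_rule = rule
        return out

    def __mul__(self, other):
        out = Scalar(self.val * other.val, (self, other))
        def rule():
            self.grad += other.val * out.grad  # d(ab)/da = b (needs saved b!)
            other.grad += self.val * out.grad  # d(ab)/db = a (needs saved a!)
        out.backward_rule = rule
        return out

    def backward(self):
        # Topologically sort the graph, then apply rules from output back to inputs.
        order, seen = [], set()
        def visit(node):
            if node not in seen:
                seen.add(node)
                for p in node.parents:
                    visit(p)
                order.append(node)
        visit(self)
        self.grad = 1.0               # dL/dL = 1
        for node in reversed(order):
            if node.backward_rule:
                node.backward_rule()


w, x, b = Scalar(2.0), Scalar(3.0), Scalar(1.0)
loss = w * x + b                      # L = wx + b
loss.backward()
# One backward pass fills in ALL gradients: dL/dw = x = 3, dL/dx = w = 2, dL/db = 1
```

Notice that the multiplication rule reads the saved operand values: drop them after the forward pass and the backward pass cannot run. That is the memory cost in miniature.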
<hr>
</section>
</section>
<section id="memory-layout-shapes-strides-and-the-viewcopy-divide" class="level2">
<h2 class="anchored" data-anchor-id="memory-layout-shapes-strides-and-the-viewcopy-divide">Memory Layout: Shapes, Strides, and the View/Copy Divide</h2>
<p>This is where the rubber meets the road. Understanding how tensors are stored in memory explains a surprising number of performance issues and subtle bugs.</p>
<section id="the-flat-array-reality" class="level3">
<h3 class="anchored" data-anchor-id="the-flat-array-reality">The Flat Array Reality</h3>
<p>Whether you’re on CPU or GPU, the hardware gives you a <strong>flat, contiguous block of memory</strong>. There are no “dimensions” at the hardware level — just consecutive slots. To create the <em>illusion</em> of an N-dimensional array, we need three pieces of metadata:</p>
<ul>
<li><strong>Shape</strong>: The logical dimensions (e.g., <code>[3, 4]</code> for a 3×4 matrix)</li>
<li><strong>Stride</strong>: How many elements to skip in the flat array to move one step along each dimension</li>
<li><strong>Offset</strong>: Where the data starts within the flat array</li>
</ul>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Row-Major vs.&nbsp;Column-Major via Strides
</div>
</div>
<div class="callout-body-container callout-body">
<p>For a 2D array <code>A</code> with shape <code>[R, C]</code>:</p>
<ul>
<li><strong>Row-major</strong> (C/NumPy/PyTorch default): <code>stride = [C, 1]</code> — rows are contiguous</li>
<li><strong>Column-major</strong> (Fortran/BLAS): <code>stride = [1, R]</code> — columns are contiguous</li>
</ul>
<p>Most BLAS libraries (the workhorses of linear algebra) are implemented in Fortran and expect column-major layout. This is why you sometimes see frameworks internally transposing data before calling into BLAS routines.</p>
</div>
</div>
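<p>You can check the stride arithmetic above directly in NumPy (one caveat: NumPy reports strides in <em>bytes</em>, so divide by the item size to get element strides):</p>

```python
import numpy as np

# Row-major vs. column-major strides for a 3x4 array (R=3, C=4).
A = np.zeros((3, 4))                  # C order (row-major) is the default
F = np.asfortranarray(A)              # Fortran order (column-major) copy

elem = A.itemsize                     # NumPy strides are in bytes
assert tuple(s // elem for s in A.strides) == (4, 1)   # stride = [C, 1]
assert tuple(s // elem for s in F.strides) == (1, 3)   # stride = [1, R]
```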
</section>
<section id="views-same-memory-different-perspective" class="level3">
<h3 class="anchored" data-anchor-id="views-same-memory-different-perspective">Views: Same Memory, Different Perspective</h3>
<p>The stride mechanism enables something powerful: multiple tensor objects can <strong>share the same underlying memory</strong> with different shapes, strides, and offsets. These are called <em>views</em>. Slicing, transposing, and broadcasting always create views, never copies; reshape creates a view only when the memory layout allows it:</p>
<table class="table">
<thead>
<tr class="header">
<th>Operation</th>
<th>What Changes</th>
<th>Memory Cost</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Slice</strong></td>
<td>Offset + shape + stride</td>
<td>Zero (view)</td>
</tr>
<tr class="even">
<td><strong>Transpose</strong></td>
<td>Strides are swapped, shape changes</td>
<td>Zero (view)</td>
</tr>
<tr class="odd">
<td><strong>Broadcast</strong></td>
<td>Stride set to 0 along new dims</td>
<td>Zero (view)</td>
</tr>
<tr class="even">
<td><strong>Reshape/View</strong></td>
<td>Shape + stride (if compatible)</td>
<td>Zero <em>or</em> copy</td>
</tr>
</tbody>
</table>
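<p>The view behavior in the table is easy to verify in NumPy, which uses the same shape/stride machinery:</p>

```python
import numpy as np

A = np.arange(12, dtype=np.float32).reshape(3, 4)

s = A[0:2, 1:3]                       # slice: new offset/shape/stride
t = A.T                               # transpose: strides swapped
b = np.broadcast_to(A[0], (5, 4))     # broadcast: stride 0 on the new dim

# All three are views — no element was copied.
assert all(np.shares_memory(A, v) for v in (s, t, b))

A[0, 1] = 99.0                        # writes through the base...
assert s[0, 0] == 99.0                # ...are visible in the slice
```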
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
When Reshape Becomes a Copy
</div>
</div>
<div class="callout-body-container callout-body">
<p><code>reshape</code> / <code>view</code> can create a view <em>only</em> when the new shape is compatible with existing strides (i.e., the data is already contiguous in the right order). If the tensor has been transposed or sliced in a way that makes the data non-contiguous, <code>reshape</code> must <strong>copy</strong> the data into a new contiguous block. This can silently allocate gigabytes of memory.</p>
<p><strong>How to detect it:</strong> In PyTorch, call <code>tensor.is_contiguous()</code> before reshaping. If it returns <code>False</code>, the reshape will trigger a copy. Use <code>tensor.contiguous()</code> explicitly to make the copy intentional and visible.</p>
</div>
</div>
</section>
<section id="the-contiguity-problem" class="level3">
<h3 class="anchored" data-anchor-id="the-contiguity-problem">The Contiguity Problem</h3>
<p>After operations like slicing or transposing, the logical tensor and the physical memory layout can diverge. The tensor is no longer <em>compact</em> — meaning the offset isn’t 0 or the strides don’t correspond to row-major order.</p>
<p>This matters because many operations (especially matrix multiplication) require contiguous data for efficient memory access. The framework typically handles this by checking compactness before an operation and creating a contiguous copy if needed. But this implicit copy is a hidden performance cost.</p>
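<p>A NumPy sketch of the same contiguity check (the PyTorch equivalents are <code>is_contiguous()</code> and <code>contiguous()</code>): after a transpose, flattening can no longer be expressed with strides, so reshape silently copies.</p>

```python
import numpy as np

A = np.arange(12).reshape(3, 4)

v = A.reshape(-1)                     # contiguous: reshape stays a view
assert np.shares_memory(A, v)

t = A.T                               # transposed: no longer C-contiguous
assert not t.flags['C_CONTIGUOUS']

c = t.reshape(-1)                     # data order changed: forces a copy
assert not np.shares_memory(A, c)
```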
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    flat["Flat memory: [a b c d e f g h i j k l]"] --&gt; orig["Tensor A&lt;br/&gt;shape=[3,4], stride=[4,1], offset=0"]
    flat --&gt; slice["Slice A[0:2, 1:3]&lt;br/&gt;shape=[2,2], stride=[4,1], offset=1&lt;br/&gt;&lt;b&gt;VIEW (shared memory)&lt;/b&gt;"]
    flat --&gt; trans["A.T&lt;br/&gt;shape=[4,3], stride=[1,4], offset=0&lt;br/&gt;&lt;b&gt;VIEW (shared memory)&lt;/b&gt;"]

    trans --&gt; |"reshape(-1) on&lt;br/&gt;non-contiguous tensor"| copy["New flat memory&lt;br/&gt;&lt;b&gt;COPY (new allocation)&lt;/b&gt;"]

    style flat fill:#fff3e0
    style slice fill:#e8f5e9
    style trans fill:#e8f5e9
    style copy fill:#ffcdd2
</pre>
</div>
<p></p><figcaption> View operations share memory; some operations force a copy when data is non-contiguous</figcaption> </figure><p></p>
</div>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Rule of Thumb
</div>
</div>
<div class="callout-body-container callout-body">
<p>If you chain <code>transpose</code> + <code>reshape</code>, you’re almost certainly triggering a copy. If you’re in a hot loop or a custom kernel, this matters. Profile with <code>torch.cuda.memory_allocated()</code> to catch surprise allocations.</p>
</div>
</div>
<p><strong>Key takeaway:</strong> Tensors are flat arrays dressed up with metadata. Operations that only change metadata (slice, transpose, broadcast) are free. Operations that need physically contiguous data may silently copy. Know which is which.</p>
<hr>
</section>
</section>
<section id="broadcasting-and-its-gradient-implications" class="level2">
<h2 class="anchored" data-anchor-id="broadcasting-and-its-gradient-implications">Broadcasting and Its Gradient Implications</h2>
<p>Broadcasting is one of the most convenient features in numerical computing — and one of the most misunderstood when it comes to gradients.</p>
<section id="the-forward-pass-implicit-repetition" class="level3">
<h3 class="anchored" data-anchor-id="the-forward-pass-implicit-repetition">The Forward Pass: Implicit Repetition</h3>
<p>When you add a bias vector <code>b</code> of shape <code>[1, C]</code> to an activation matrix <code>A</code> of shape <code>[N, C]</code>, broadcasting logically <em>repeats</em> <code>b</code> along the batch dimension <code>N</code> times. But crucially, <strong>no data is copied</strong>. The framework simply sets the stride to 0 along the broadcast dimension, so the same values are read repeatedly.</p>
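<p>You can see the stride-0 trick directly — a broadcast tensor reports a zero stride along the repeated dimension and still shares memory with its base:</p>

```python
import numpy as np

b = np.array([[0.5, -0.3]])           # bias, shape [1, C]
B = np.broadcast_to(b, (4, 2))        # logically repeated over N=4 rows

assert B.strides[0] == 0              # stride 0 on the batch dim: no copy
assert np.shares_memory(b, B)         # the same two floats, read repeatedly
```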
</section>
<section id="the-backward-pass-sum-reduce" class="level3">
<h3 class="anchored" data-anchor-id="the-backward-pass-sum-reduce">The Backward Pass: Sum-Reduce</h3>
<p>Here’s the subtle part. During the backward pass, if a value was broadcast (repeated) across a dimension, the gradients must be <strong>summed along that dimension</strong>. Why? Because the same parameter contributed to multiple outputs — its total influence is the sum of all its partial effects.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    subgraph fwd["Forward: broadcast adds"]
        direction TB
        A_fwd["A: shape [N, C]"] --&gt; plus["+ (broadcast)"]
        b_fwd["b: shape [1, C]&lt;br/&gt;(stride 0 on dim 0)"] --&gt; plus
        plus --&gt; out_fwd["Output: shape [N, C]"]
    end

    subgraph bwd["Backward: sum-reduce"]
        direction TB
        grad_out["∂L/∂Output: shape [N, C]"] --&gt; sum_op["sum(dim=0)"]
        sum_op --&gt; grad_b["∂L/∂b: shape [1, C]"]
        grad_out --&gt; grad_A["∂L/∂A: shape [N, C]&lt;br/&gt;(passed through directly)"]
    end

    fwd --&gt; |"backward()"| bwd
</pre>
</div>
<p></p><figcaption> Broadcasting repeats values in the forward pass; gradients must sum-reduce along broadcast dimensions in the backward pass</figcaption> </figure><p></p>
</div>
</div>
</div>
<p><strong>Worked example:</strong></p>
<p>Suppose <code>A</code> has shape <code>[3, 2]</code> and <code>b</code> has shape <code>[1, 2]</code> with values <code>[0.5, -0.3]</code>. After broadcasting, every row of <code>A</code> gets the same bias added. If the upstream gradient <code>∂L/∂Output</code> is:</p>
<pre><code>[[1.0, 2.0],
 [0.5, 1.5],
 [0.3, 0.7]]</code></pre>
<p>Then <code>∂L/∂b = sum along dim 0 = [1.8, 4.2]</code>, because <code>b</code> influenced all three rows.</p>
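<p>The worked example above, checked numerically:</p>

```python
import numpy as np

grad_out = np.array([[1.0, 2.0],
                     [0.5, 1.5],
                     [0.3, 0.7]])          # upstream dL/dOutput, shape [3, 2]

grad_b = grad_out.sum(axis=0, keepdims=True)  # sum-reduce the broadcast dim
grad_A = grad_out                             # passes through unchanged

assert np.allclose(grad_b, [[1.8, 4.2]])
```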
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
General Rule
</div>
</div>
<div class="callout-body-container callout-body">
<p>For any operation in autograd: <strong>the gradient of a broadcast is a reduction, and the gradient of a reduction is a broadcast.</strong> This duality shows up everywhere — in loss functions, in normalization layers, and in attention mechanisms.</p>
</div>
</div>
<p><strong>Key takeaway:</strong> Broadcasting doesn’t copy data (strides handle it), but gradients must sum-reduce along every dimension that was broadcast. Forgetting this is a common source of shape mismatch bugs in custom autograd functions.</p>
<hr>
</section>
</section>
<section id="hardware-acceleration-from-strides-to-silicon" class="level2">
<h2 class="anchored" data-anchor-id="hardware-acceleration-from-strides-to-silicon">Hardware Acceleration: From Strides to Silicon</h2>
<p>Understanding the hardware layer helps you write code that runs fast <em>by default</em> instead of fighting the machine.</p>
<section id="memory-alignment" class="level3">
<h3 class="anchored" data-anchor-id="memory-alignment">Memory Alignment</h3>
<p>Hardware loads data into caches in fixed-size chunks called <strong>cache lines</strong> (typically 64 bytes). If your data is aligned to cache line boundaries, a single load brings in exactly what you need. If it’s misaligned, you need <em>two</em> loads for data that spans a boundary — doubling the memory traffic for that access.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Practical Impact
</div>
</div>
<div class="callout-body-container callout-body">
<p>Memory alignment mostly matters for custom kernels and low-level code. High-level frameworks handle this for you. But if you’re writing CUDA kernels or using <code>ctypes</code> to interface with C libraries, ensure your allocations are aligned.</p>
</div>
</div>
</section>
<section id="parallelization-with-openmp" class="level3">
<h3 class="anchored" data-anchor-id="parallelization-with-openmp">Parallelization with OpenMP</h3>
<p>On CPU, the simplest form of parallelism is loop parallelization. Tools like <strong>OpenMP</strong> let you annotate a loop with <code>#pragma omp parallel for</code>, and the runtime splits iterations across CPU cores automatically.</p>
<p>This is the basis for CPU-accelerated tensor operations. Each core processes a different slice of the tensor, and the results are combined. The bottleneck shifts from compute to <strong>memory bandwidth</strong> — reading and writing large tensors becomes the limiting factor, not arithmetic.</p>
</section>
<section id="the-im2col-trick-convolution-as-matrix-multiplication" class="level3">
<h3 class="anchored" data-anchor-id="the-im2col-trick-convolution-as-matrix-multiplication">The im2col Trick: Convolution as Matrix Multiplication</h3>
<p>Convolution is the most compute-intensive operation in CNNs. The <strong>im2col</strong> (image-to-column) trick converts convolution into matrix multiplication, which lets us use heavily optimized BLAS routines.</p>
<p>The process for a batch of images (<code>N × H × W × Cᵢₙ</code>) with filters (<code>K × K × Cᵢₙ × Cₒᵤₜ</code>):</p>
<ol type="1">
<li>Create a 6D strided view: <code>N × H_out × W_out × K × K × Cᵢₙ</code></li>
<li>Reshape to a 2D im2col matrix: <code>(N·H_out·W_out) × (K·K·Cᵢₙ)</code></li>
<li>Reshape weights to 2D: <code>(K·K·Cᵢₙ) × Cₒᵤₜ</code></li>
<li>Matrix multiply: <code>im2col @ weights</code></li>
<li>Reshape result: <code>N × H_out × W_out × Cₒᵤₜ</code></li>
</ol>
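<p>A minimal NumPy sketch of those five steps, for the simplest case (stride-1, no padding). Step 1 is zero-copy thanks to <code>sliding_window_view</code>; step 2 is where the unavoidable copy happens.</p>

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_im2col(x, w):
    """Valid stride-1 convolution of x: [N,H,W,Cin] with w: [K,K,Cin,Cout]."""
    N, H, W, Cin = x.shape
    K, _, _, Cout = w.shape
    Ho, Wo = H - K + 1, W - K + 1
    # 1) 6D strided view over K x K patches — zero-copy
    patches = sliding_window_view(x, (K, K), axis=(1, 2))  # N,Ho,Wo,Cin,K,K
    patches = patches.transpose(0, 1, 2, 4, 5, 3)          # N,Ho,Wo,K,K,Cin
    # 2) flatten patches into the im2col matrix — this reshape must COPY,
    #    because overlapping patches are not contiguous in memory
    cols = patches.reshape(N * Ho * Wo, K * K * Cin)
    # 3-4) reshape weights and hit BLAS with a single GEMM
    out = cols @ w.reshape(K * K * Cin, Cout)
    # 5) restore the spatial layout
    return out.reshape(N, Ho, Wo, Cout)
```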
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
im2col Memory Overhead
</div>
</div>
<div class="callout-body-container callout-body">
<p>The im2col matrix is typically <strong>much larger</strong> than the original image tensor because filter patches overlap. Each input pixel appears in multiple rows of the im2col matrix. The reshape from the 6D strided view to 2D <em>cannot</em> be done as a view (the data isn’t contiguous in the right order), so it triggers a <strong>full copy</strong>. This is a significant memory cost — for large images with many channels, the im2col matrix can be several times the size of the input.</p>
<p><strong>When it helps:</strong> When your BLAS library is highly optimized (which it usually is). The speedup from using GEMM far outweighs the memory copy cost.</p>
<p><strong>When it hurts:</strong> When you’re memory-constrained. Alternative approaches like FFT-based convolution or Winograd transforms can reduce memory usage at the cost of implementation complexity.</p>
</div>
</div>
<p><strong>Key takeaway:</strong> The gap between “logical operations on tensors” and “what the hardware actually does” is large. Frameworks bridge it with tricks like im2col, cache-aware memory layout, and loop parallelization. When performance matters, understanding this layer is essential.</p>
<hr>
</section>
</section>
<section id="weight-initialization-the-effects-that-persist" class="level2">
<h2 class="anchored" data-anchor-id="weight-initialization-the-effects-that-persist">Weight Initialization: The Effects That Persist</h2>
<p>Weight initialization might seem like a minor detail — just pick some random numbers and start training. But the evidence tells a more nuanced story.</p>
<section id="why-initialization-matters-more-than-you-think" class="level3">
<h3 class="anchored" data-anchor-id="why-initialization-matters-more-than-you-think">Why Initialization Matters More Than You Think</h3>
<p>Two observations that changed how I think about initialization:</p>
<ol type="1">
<li><p><strong>The effect of initialization persists throughout training.</strong> Bad initialization affects the relative norms of activations and gradients <em>at every step</em>. If you don’t initialize appropriately (e.g., using a standard deviation of <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7B%5Cfrac%7B2%7D%7Bn%7D%7D">, where <code>n</code> is the layer’s fan-in — known as He initialization, for ReLU networks), the L2-norm of activations or gradients will drift — leading to vanishing signals or exploding values.</p></li>
<li><p><strong>Weights don’t move far from their initial values.</strong> This is surprising. If you plot the variance of weights before and after training for each layer, you’ll see remarkably similar values. The weights shift in certain directions, but relative to their initial magnitude, the change is small — especially for deep networks.</p></li>
</ol>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Implication
</div>
</div>
<div class="callout-body-container callout-body">
<p>Together, these observations mean initialization isn’t just “where you start” — it effectively defines the <em>neighborhood</em> of weight space you’ll explore during training. Proper initialization puts you in a good neighborhood. Bad initialization puts you somewhere the optimizer can’t easily escape.</p>
</div>
</div>
</section>
<section id="how-to-diagnose-initialization-problems" class="level3">
<h3 class="anchored" data-anchor-id="how-to-diagnose-initialization-problems">How to Diagnose Initialization Problems</h3>
<p><strong>Monitor two metrics across layers over all training iterations:</strong></p>
<ul>
<li><strong>Norm of weights</strong> per layer</li>
<li><strong>Norm of gradients</strong> per layer</li>
</ul>
<p>If the weight norms explode or collapse across layers, or if gradient norms vary by orders of magnitude between early and late layers, your initialization is likely wrong. Proper initialization keeps these norms roughly stable across layers.</p>
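<p>A toy experiment makes the drift visible. This sketch pushes a random batch through a stack of ReLU layers and records activation norms under He initialization versus an arbitrarily small standard deviation (the shapes and depth are illustrative, not from the text):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def activation_norms(std, depth=30, n=256, batch=64):
    """Record the activation L2-norm after each of `depth` ReLU layers."""
    h = rng.standard_normal((batch, n))
    norms = []
    for _ in range(depth):
        W = rng.standard_normal((n, n)) * std
        h = np.maximum(h @ W, 0.0)           # linear layer + ReLU
        norms.append(np.linalg.norm(h))
    return norms

he = activation_norms(np.sqrt(2.0 / 256))    # He init: norms stay stable
bad = activation_norms(0.01)                 # too small: signal vanishes
```

<p>With He initialization the norms stay within the same order of magnitude across all 30 layers; with std 0.01 they collapse toward zero within a handful of layers.</p>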
<p><strong>Key takeaway:</strong> Proper weight initialization speeds up training and leads to lower final error rates. It defines the effective search region for your optimizer, and its influence doesn’t fade — it persists throughout training.</p>
<hr>
</section>
</section>
<section id="normalization-fixing-what-initialization-cant" class="level2">
<h2 class="anchored" data-anchor-id="normalization-fixing-what-initialization-cant">Normalization: Fixing What Initialization Can’t</h2>
<p>If we know that activation norms can drift during training (due to imperfect initialization or the dynamics of optimization itself), why not just <em>force</em> them to be well-behaved? That’s the idea behind normalization layers.</p>
<section id="batch-normalization" class="level3">
<h3 class="anchored" data-anchor-id="batch-normalization">Batch Normalization</h3>
<p>Batch Normalization normalizes activations <strong>across the batch dimension</strong> for each feature independently. For a given feature, it computes the mean and variance across all examples in the batch, then normalizes to zero mean and unit variance.</p>
<p><strong>When it helps:</strong></p>
<ul>
<li>Dramatically speeds up training by maintaining stable activation norms</li>
<li>Preserves the discriminative information <em>between features</em> within each layer (because normalization is per-feature, not per-example)</li>
</ul>
<p><strong>When it hurts:</strong></p>
<ul>
<li>Creates <strong>dependency between samples</strong> in a batch — each example’s normalized activation depends on the other examples in the batch</li>
<li><strong>Unstable with small batches</strong> — statistics become noisy, and with a batch of 1, the variance is undefined</li>
<li><strong>Doesn’t work well with RNNs</strong> — the hidden state has temporal dependencies across time steps, and computing batch statistics independently at each time step ignores this structure</li>
</ul>
</section>
<section id="layer-normalization" class="level3">
<h3 class="anchored" data-anchor-id="layer-normalization">Layer Normalization</h3>
<p>Layer Normalization normalizes <strong>across all features</strong> for each sample independently. No dependency on other samples in the batch.</p>
<p><strong>When it helps:</strong></p>
<ul>
<li>Works with <strong>any batch size</strong>, including batch size 1</li>
<li><strong>Perfect for RNNs and Transformers</strong> — it normalizes across the embedding dimension for each token in each example, respecting temporal structure</li>
<li>This is why it’s the standard in Transformer architectures</li>
</ul>
<p><strong>When it hurts:</strong></p>
<ul>
<li>For fully connected networks, forcing zero mean and unit variance <em>across features</em> can destroy the relative magnitude differences between activations for different examples. These magnitude differences can be an important discriminative signal.</li>
<li>This makes it harder to drive loss low on tasks where inter-example feature magnitude differences matter</li>
</ul>
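<p>The whole BatchNorm-vs-LayerNorm distinction is just a choice of axis. A minimal sketch (omitting the learnable scale/shift parameters both layers also carry):</p>

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((8, 16))   # [batch, features]
eps = 1e-5

# BatchNorm: per-feature statistics across the batch (axis 0)
bn = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# LayerNorm: per-sample statistics across the features (axis 1)
mu = x.mean(axis=1, keepdims=True)
var = x.var(axis=1, keepdims=True)
ln = (x - mu) / np.sqrt(var + eps)

assert np.allclose(bn.mean(axis=0), 0.0, atol=1e-6)   # each feature centered
assert np.allclose(ln.mean(axis=1), 0.0, atol=1e-6)   # each sample centered
```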
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Choosing Between Them
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Use BatchNorm</strong> for CNNs with reasonably large batches (≥32). <strong>Use LayerNorm</strong> for Transformers, RNNs, and any setting where batch size is small or variable.</p>
</div>
</div>
<p><strong>Key takeaway:</strong> Normalization layers fix the activation drift that initialization can only partially prevent. BatchNorm and LayerNorm make different trade-offs about <em>what to normalize over</em>, and the right choice depends on your architecture and batch size.</p>
<hr>
</section>
</section>
<section id="regularization-controlling-complexity" class="level2">
<h2 class="anchored" data-anchor-id="regularization-controlling-complexity">Regularization: Controlling Complexity</h2>
<p>Regularization prevents models from memorizing the training data, forcing them to learn patterns that generalize to unseen examples.</p>
<section id="implicit-regularization" class="level3">
<h3 class="anchored" data-anchor-id="implicit-regularization">Implicit Regularization</h3>
<p>Before you add <em>any</em> explicit regularization, your training procedure already constrains the model. <strong>SGD with a particular initialization</strong> only explores a subset of all possible neural networks. The initialization defines the starting point, and the optimizer’s dynamics (step size, momentum, batch sampling) determine the trajectory through weight space.</p>
<p>This is called <em>implicit regularization</em>, and it’s powerful. The fact that SGD-trained networks generalize well — even when they have enough capacity to memorize the training set — is partly due to these implicit biases of the optimization procedure.</p>
</section>
<section id="explicit-regularization" class="level3">
<h3 class="anchored" data-anchor-id="explicit-regularization">Explicit Regularization</h3>
<p>Explicit regularization directly limits the functions the model can learn:</p>
<p><strong>L2 Regularization</strong> adds a penalty proportional to the squared magnitude of the weights. The premise: smoother functions (which don’t change dramatically for small input changes) tend to have smaller weights. By penalizing large weights, we encourage smoother, simpler functions.</p>
<p><strong>Dropout</strong> randomly zeroes out activations with probability <img src="https://latex.codecogs.com/png.latex?p"> during training. A useful mental model: dropout is a <em>stochastic approximation</em> of each layer’s activations, similar to how SGD approximates the full gradient with a mini-batch sample. To keep the expected value consistent, the surviving activations are scaled by <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7B1-p%7D"> during training (the standard “inverted dropout” formulation) — or, equivalently, activations are multiplied by <img src="https://latex.codecogs.com/png.latex?1-p"> at inference.</p>
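<p>A minimal inverted-dropout sketch — because survivors are rescaled at train time, inference needs no adjustment at all:</p>

```python
import numpy as np

def dropout(a, p, rng, train=True):
    """Inverted dropout: zero with probability p, scale survivors by
    1/(1-p) at train time so that E[output] == a (inference is a no-op)."""
    if not train:
        return a
    mask = (rng.random(a.shape) >= p).astype(a.dtype) / (1.0 - p)
    return a * mask

rng = np.random.default_rng(0)
a = np.ones((1000, 100))
out = dropout(a, p=0.5, rng=rng)    # roughly half zeros, survivors doubled
```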
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
L2 Regularization ≠ Weight Decay (for Adam!)
</div>
</div>
<div class="callout-body-container callout-body">
<p>For vanilla SGD, L2 regularization and weight decay are mathematically equivalent. But for adaptive optimizers like <strong>Adam</strong>, they are <em>not</em> the same.</p>
<p>Why? Adam computes first and second moments of the gradients. If you add the L2 penalty to the gradient (L2 regularization), the penalty gets scaled by Adam’s adaptive learning rate, making it <strong>less effective</strong> than intended. Weight decay, which adds the penalty directly to the parameter update step <em>without</em> modifying the gradient, avoids this issue.</p>
<p>This distinction — first identified in the “Decoupled Weight Decay” paper (AdamW) — is why AdamW is preferred over Adam + L2 regularization in practice.</p>
</div>
</div>
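<p>The difference is easiest to see on a single scalar parameter. This sketch is a stripped-down Adam step (bias correction omitted, not the full optimizer): <code>l2</code> folds the penalty into the gradient, while <code>wd</code> applies decoupled, AdamW-style decay directly to the parameter.</p>

```python
import math

def adam_step(theta, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8,
              l2=0.0, wd=0.0):
    """One simplified single-scalar Adam step (no bias correction)."""
    g = g + l2 * theta                    # L2 penalty enters the moments...
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    theta = theta - lr * m / (math.sqrt(v) + eps)  # ...and gets rescaled by 1/sqrt(v)
    theta = theta - lr * wd * theta       # decoupled decay: untouched by 1/sqrt(v)
    return theta, m, v

# With zero data gradient, the L2-in-the-gradient shrinkage is nearly
# independent of |theta| (Adam normalizes it away); decoupled decay shrinks
# theta in proportion to its magnitude, as intended.
t_l2, _, _ = adam_step(10.0, 0.0, 0.0, 0.0, l2=0.01)
t_wd, _, _ = adam_step(10.0, 0.0, 0.0, 0.0, wd=0.01)
```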
<p><strong>Key takeaway:</strong> Regularization operates at two levels: the implicit biases of SGD and initialization, and explicit penalties like L2/weight decay and dropout. For Adam-family optimizers, always use weight decay (AdamW), not L2 regularization.</p>
<hr>
</section>
</section>
<section id="scaling-up-when-one-gpu-isnt-enough" class="level2">
<h2 class="anchored" data-anchor-id="scaling-up-when-one-gpu-isnt-enough">Scaling Up: When One GPU Isn’t Enough</h2>
<p>Large datasets demand large models, and large models push hardware to its limits. Here’s how the systems community addresses this.</p>
<section id="the-memory-bottleneck" class="level3">
<h3 class="anchored" data-anchor-id="the-memory-bottleneck">The Memory Bottleneck</h3>
<p>The memory hierarchy tells the story:</p>
<ul>
<li><strong>Shared memory per core (GPU):</strong> ~64 KB — fast, tiny</li>
<li><strong>Global GPU memory:</strong> 10–80 GB depending on the device — this is the typical bottleneck</li>
<li><strong>CPU RAM:</strong> 64–512 GB — large but slow to access from GPU</li>
</ul>
<p>Most large models can’t fit entirely in GPU global memory during training, because we need to store: model parameters, optimizer state (2x or 3x model size for Adam), activations (saved for backward), and gradients.</p>
</section>
<section id="memory-saving-techniques" class="level3">
<h3 class="anchored" data-anchor-id="memory-saving-techniques">Memory-Saving Techniques</h3>
<section id="inference-buffer-reuse" class="level4">
<h4 class="anchored" data-anchor-id="inference-buffer-reuse">Inference: Buffer Reuse</h4>
<p>During inference, we don’t need to keep activations for backward. We can reuse a small set of buffers (2 or 3) across layers, writing each layer’s output into a buffer that a previous layer no longer needs. This reduces memory from <code>O(N)</code> to <code>O(1)</code> in the number of layers.</p>
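<p>A two-buffer ping-pong sketch (Python references stand in for preallocated device buffers that a real runtime would overwrite in place):</p>

```python
def run_inference(layers, x):
    """Ping-pong between two buffers: activation memory is O(1) in depth."""
    bufs = [x, None]
    cur = 0
    for f in layers:
        bufs[1 - cur] = f(bufs[cur])   # write into the buffer we no longer need
        cur = 1 - cur
    return bufs[cur]

doubled = run_inference([lambda h: 2 * h] * 5, 1)   # five doubling "layers"
```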
</section>
<section id="training-activation-checkpointing" class="level4">
<h4 class="anchored" data-anchor-id="training-activation-checkpointing">Training: Activation Checkpointing</h4>
<p>During training, we normally keep <em>all</em> activations for the backward pass. Checkpointing trades memory for compute:</p>
<ol type="1">
<li>Divide the network into <strong>segments</strong> of roughly <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BN%7D"> layers</li>
<li>Only store activations at <strong>segment boundaries</strong> (checkpoints)</li>
<li>During the backward pass, <strong>recompute</strong> the forward pass within each segment to recover the needed activations</li>
</ol>
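<p>The three steps above can be sketched with toy scalar “layers” — the forward pass saves only boundary activations, and the backward-pass helper replays the segment to recover anything in between:</p>

```python
import math

def forward_with_checkpoints(layers, x, every):
    """Forward pass that saves activations only at segment boundaries."""
    saved, h = {0: x}, x
    for i, f in enumerate(layers, start=1):
        h = f(h)
        if i % every == 0:
            saved[i] = h
    return h, saved

def recover_activation(layers, saved, i, every):
    """Backward-pass helper: replay forward from the nearest checkpoint."""
    start = (i // every) * every
    h = saved[start]
    for f in layers[start:i]:
        h = f(h)
    return h

layers = [lambda h, k=k: 2 * h + k for k in range(9)]   # 9 toy layers
every = int(math.isqrt(len(layers)))                    # ~sqrt(N) segments
out, saved = forward_with_checkpoints(layers, 1.0, every)
```

<p>PyTorch packages the same trade-off as <code>torch.utils.checkpoint</code>; the sketch just makes the save/recompute bookkeeping explicit.</p>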
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart LR
    subgraph seg1["Segment 1"]
        L1["Layer 1"] --&gt; L2["Layer 2"] --&gt; L3["Layer 3"]
    end
    subgraph seg2["Segment 2"]
        L4["Layer 4"] --&gt; L5["Layer 5"] --&gt; L6["Layer 6"]
    end
    subgraph seg3["Segment 3"]
        L7["Layer 7"] --&gt; L8["Layer 8"] --&gt; L9["Layer 9"]
    end

    seg1 --&gt; |"✓ checkpoint"| seg2
    seg2 --&gt; |"✓ checkpoint"| seg3

    style L1 fill:#e8f5e9,stroke:#333
    style L3 fill:#e8f5e9,stroke:#333
    style L4 fill:#e8f5e9,stroke:#333
    style L6 fill:#e8f5e9,stroke:#333
    style L7 fill:#e8f5e9,stroke:#333
    style L9 fill:#e8f5e9,stroke:#333
</pre>
</div>
<p></p><figcaption> Activation checkpointing: store only segment boundaries, recompute the rest during backward</figcaption> </figure><p></p>
</div>
</div>
</div>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Approach</th>
<th>Memory</th>
<th>Compute Overhead</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>No checkpointing</td>
<td><code>O(N)</code> activations</td>
<td>None</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BN%7D"> checkpoints</td>
<td><code>O(√N)</code> activations</td>
<td>~1 extra forward pass</td>
</tr>
<tr class="odd">
<td>Aggressive checkpointing</td>
<td><code>O(1)</code> activations</td>
<td>Up to <code>N</code> extra forward passes</td>
</tr>
</tbody>
</table>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Smart Checkpoint Placement
</div>
</div>
<div class="callout-body-container callout-body">
<p>Choose checkpoints at layers with <strong>cheap recomputation</strong>. ReLU activations are trivial to recompute (just check sign). Convolution or attention layers are expensive. Checkpoint <em>after</em> cheap layers to minimize the recomputation cost.</p>
</div>
</div>
</section>
</section>
<section id="distributed-training-data-and-model-parallelism" class="level3">
<h3 class="anchored" data-anchor-id="distributed-training-data-and-model-parallelism">Distributed Training: Data and Model Parallelism</h3>
<p>When one GPU isn’t enough, we spread the work across multiple devices. There are two fundamental strategies:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">flowchart TD
    DT["Distributed Training"] --&gt; DP["&lt;b&gt;Data Parallelism&lt;/b&gt;&lt;br/&gt;Same model, different data"]
    DT --&gt; MP["&lt;b&gt;Model Parallelism&lt;/b&gt;&lt;br/&gt;Different parts of model"]

    DP --&gt; PS["Parameter Server&lt;br/&gt;Central coordinator"]
    DP --&gt; AR["AllReduce&lt;br/&gt;Peer-to-peer"]

    MP --&gt; TP["Tensor Parallelism&lt;br/&gt;Split layers across devices"]
    MP --&gt; PP["Pipeline Parallelism&lt;br/&gt;Different layers on different devices"]

    style DT fill:#fff3e0
    style DP fill:#e3f2fd
    style MP fill:#fce4ec
</pre>
</div>
<p></p><figcaption> Taxonomy of distributed training approaches</figcaption> </figure><p></p>
</div>
</div>
</div>
<section id="data-parallelism" class="level4">
<h4 class="anchored" data-anchor-id="data-parallelism">Data Parallelism</h4>
<p>Every worker runs a <strong>full replica of the model</strong> on a different micro-batch. Since gradients are additive (they’re independent across examples), we just need to sum them across workers before performing the weight update.</p>
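<p>The additivity claim can be checked on a toy model. This sketch simulates workers sequentially (real systems run them on separate devices and sum with an all-reduce), using a 1-D least-squares model <code>y_hat = w * x</code> chosen purely for illustration:</p>

```python
def data_parallel_step(w, micro_batches, lr=0.1):
    """One data-parallel SGD step: per-worker gradients on separate
    micro-batches, summed (the all-reduce) before a single weight update.
    Gradients are additive across examples, so the sum equals the
    full-batch gradient exactly."""
    worker_grads = [sum(2.0 * (w * x - y) * x for x, y in batch)
                    for batch in micro_batches]     # one replica per worker
    total = sum(worker_grads)                       # all-reduce (sum)
    n = sum(len(b) for b in micro_batches)
    return w - lr * total / n

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w_two_workers = data_parallel_step(0.0, [data[:2], data[2:]])
w_one_worker = data_parallel_step(0.0, [data])      # identical result
```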
<p>Two coordination strategies:</p>
<p><strong>Parameter Server:</strong> A central server collects gradients from all workers, sums them, performs the update, and broadcasts the new weights. Workers can start sending gradients as soon as they’re computed (layer by layer), overlapping communication with computation.</p>
<ul>
<li><strong>Bottleneck:</strong> The parameter server becomes a communication bottleneck as the number of workers grows. All traffic flows through one node.</li>
</ul>
<p><strong>AllReduce:</strong> A peer-to-peer approach where all workers collectively sum their gradients and each receives the result. No central bottleneck — communication scales more gracefully. Algorithms like Ring-AllReduce distribute the bandwidth load evenly.</p>
<ul>
<li><strong>Bottleneck:</strong> Total communication volume still grows with model size. Network bandwidth between nodes becomes the limiting factor.</li>
</ul>
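<p>To make the AllReduce idea concrete, here is a minimal Python simulation of Ring-AllReduce. This is a sketch, not any library’s API: plain lists stand in for the network, and per-step sends are buffered to mimic the simultaneous exchanges of a real ring.</p>

```python
import numpy as np

def ring_allreduce(grads):
    """Simulate Ring-AllReduce over a list of per-worker gradient vectors.

    Each vector is split into n chunks (n = number of workers).  In the
    reduce-scatter phase every worker ends up owning the complete sum of
    one chunk; the all-gather phase then circulates those finished chunks
    until every worker holds the full summed vector.
    """
    n = len(grads)
    chunks = [list(np.array_split(np.asarray(g, dtype=float).copy(), n))
              for g in grads]

    # Reduce-scatter: at each step, worker w sends chunk (w - step) mod n
    # to its right neighbor, which accumulates it.  Sends are buffered
    # first so a step behaves as one simultaneous exchange.
    for step in range(n - 1):
        sends = [(w, (w - step) % n, chunks[w][(w - step) % n].copy())
                 for w in range(n)]
        for w, c, data in sends:
            chunks[(w + 1) % n][c] += data

    # All-gather: worker w now owns the finished chunk (w + 1) mod n;
    # n - 1 more steps pass the finished chunks around the ring.
    for step in range(n - 1):
        sends = [(w, (w + 1 - step) % n, chunks[w][(w + 1 - step) % n].copy())
                 for w in range(n)]
        for w, c, data in sends:
            chunks[(w + 1) % n][c] = data

    return [np.concatenate(c) for c in chunks]

# Three workers, each holding the gradient of the same 6-parameter model:
grads = [np.arange(6.0) + w for w in range(3)]
reduced = ring_allreduce(grads)
print(reduced[0])   # every worker ends with the elementwise sum
```

<p>Note that each worker sends and receives only <code>1/n</code> of the gradient per step, which is why the bandwidth load stays evenly distributed as workers are added.</p>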
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
When Communication Dominates
</div>
</div>
<div class="callout-body-container callout-body">
<p>Communication overhead dominates training time when:</p>
<ul>
<li><strong>Model is large</strong> relative to batch computation time (small compute-to-communication ratio)</li>
<li><strong>Network bandwidth is low</strong> (especially across nodes vs.&nbsp;within a node with NVLink)</li>
<li><strong>Gradient compression</strong> isn’t used</li>
</ul>
<p>Rule of thumb: if your per-step compute time is less than 3x the gradient synchronization time, communication is your bottleneck. Scale batch size or use gradient compression/accumulation to amortize the cost.</p>
</div>
</div>
</section>
<section id="model-parallelism-pipeline-parallelism" class="level4">
<h4 class="anchored" data-anchor-id="model-parallelism-pipeline-parallelism">Model Parallelism (Pipeline Parallelism)</h4>
<p>When the model itself doesn’t fit on one device, we split the computation graph across devices. Each device handles a different set of layers, and they <strong>pipeline</strong> the computation: while device 2 processes micro-batch 1, device 1 can start on micro-batch 2.</p>
<p>Communication happens at layer boundaries via <code>send</code>/<code>recv</code> operations. The challenge is minimizing <strong>pipeline bubbles</strong> — idle time when a device is waiting for input from the previous stage.</p>
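<p>The cost of those bubbles is easy to quantify for an idealized GPipe-style schedule (assumed here: uniform stage times, one pass per micro-batch per stage):</p>

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    """Idle fraction of an idealized GPipe-style pipeline: each device
    occupies micro_batches + stages - 1 time slots, of which
    stages - 1 are bubble (idle) slots."""
    return (stages - 1) / (micro_batches + stages - 1)

# More micro-batches amortize the same bubble:
print(bubble_fraction(4, 4))    # ~0.43: almost half the device time is idle
print(bubble_fraction(4, 32))   # ~0.09: the standard fix is more micro-batches
```

<p>This is why pipeline-parallel setups push the number of micro-batches well above the number of stages: the bubble is fixed, so more micro-batches shrink its relative cost.</p>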
<p><strong>Key takeaway:</strong> Scaling from one GPU to many introduces a new bottleneck: communication. Data parallelism is simpler and scales well when the model fits on one device. Model/pipeline parallelism is necessary when it doesn’t, but introduces pipeline bubbles and more complex communication patterns.</p>
<hr>
</section>
</section>
</section>
<section id="neural-network-architectures-through-a-systems-lens" class="level2">
<h2 class="anchored" data-anchor-id="neural-network-architectures-through-a-systems-lens">Neural Network Architectures Through a Systems Lens</h2>
<p>The remaining sections cover architectures not as algorithmic curiosities, but as <em>systems design decisions</em> — what problem does each one solve, and what trade-off does it introduce?</p>
<section id="convolutional-neural-networks-cnns" class="level3">
<h3 class="anchored" data-anchor-id="convolutional-neural-networks-cnns">Convolutional Neural Networks (CNNs)</h3>
<p>CNNs exploit three structural priors about spatial data:</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Property</th>
<th>What It Means</th>
<th>Systems Benefit</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Parameter sharing</strong></td>
<td>Same filter everywhere in the image</td>
<td>Massive reduction in parameters</td>
</tr>
<tr class="even">
<td><strong>Sparse connectivity</strong></td>
<td>Each output depends only on a local receptive field</td>
<td>Few computations per output pixel</td>
</tr>
<tr class="odd">
<td><strong>Translation equivariance</strong></td>
<td>Shifting input shifts output the same way</td>
<td>No need to learn position-specific detectors</td>
</tr>
</tbody>
</table>
<p><strong>Dilation</strong> increases the receptive field without increasing parameters — each filter element is spread out by a dilation factor, giving access to a larger spatial area. This is particularly useful for temporal problems where context matters.</p>
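<p>The receptive-field arithmetic is worth sketching (the helper below is illustrative): each layer adds <code>(kernel_size - 1) * dilation</code> input positions, so doubling the dilation per layer grows the receptive field exponentially with depth at constant parameter count.</p>

```python
def receptive_field(kernel_size: int, dilations) -> int:
    """Receptive field of stacked 1-D convolutions, one layer per entry
    in `dilations`: each layer adds (kernel_size - 1) * dilation inputs."""
    rf = 1
    for d in dilations:
        rf += (kernel_size - 1) * d
    return rf

# Four kernel-3 layers with doubling dilation (WaveNet-style) vs. undilated:
print(receptive_field(3, [1, 2, 4, 8]))   # 31 input positions seen
print(receptive_field(3, [1, 1, 1, 1]))   # only 9, with the same parameters
```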
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Convolution as Matrix Multiplication
</div>
</div>
<div class="callout-body-container callout-body">
<p>We can express convolution as a matrix multiplication where the weight matrix has a specific sparsity pattern (filled with actual weights and zeros reflecting the filter structure). We don’t actually construct this matrix — it would be enormous — but this view explains why the backward pass of a convolution is a convolution with a flipped filter: multiplying by the transpose of the convolution matrix is equivalent to convolving with the spatially flipped kernel.</p>
</div>
</div>
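<p>The equivalence is easy to verify numerically on a tiny 1-D example. A sketch using NumPy: deep-learning “convolution” is implemented as cross-correlation, hence <code>np.correlate</code> for the forward pass, while the transpose reproduces a full convolution with the flipped kernel.</p>

```python
import numpy as np

def conv_matrix(w, n):
    """Dense matrix implementing a valid 1-D cross-correlation
    (DL-style 'convolution') of an n-vector with kernel w."""
    k = len(w)
    m = n - k + 1
    A = np.zeros((m, n))
    for i in range(m):
        A[i, i:i + k] = w   # each row is the kernel, shifted by one
    return A

x = np.array([1., 2., 3., 4., 5.])
w = np.array([1., 0., -1.])
A = conv_matrix(w, len(x))

# Forward pass: the matrix form agrees with direct correlation.
assert np.allclose(A @ x, np.correlate(x, w, mode="valid"))

# Backward pass: multiplying an upstream gradient by A.T is a *full*
# convolution with w, i.e. correlation with the spatially flipped kernel.
g = np.array([1., -1., 2.])
assert np.allclose(A.T @ g, np.convolve(g, w, mode="full"))
```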
</section>
<section id="recurrent-neural-networks-rnns" class="level3">
<h3 class="anchored" data-anchor-id="recurrent-neural-networks-rnns">Recurrent Neural Networks (RNNs)</h3>
<p>RNNs address temporal dependencies by maintaining a <strong>hidden state</strong> that gets updated at each time step as a function of the current input and the previous hidden state. In theory, the last hidden state captures the entire input history.</p>
<p>In practice, the hidden state is a bottleneck. The entire past is <em>compacted</em> into a single vector, and information from early time steps (<img src="https://latex.codecogs.com/png.latex?x_1">) gets diluted compared to recent ones (<img src="https://latex.codecogs.com/png.latex?x_t">).</p>
<p><strong>Backpropagation Through Time (BPTT):</strong> Because weights are shared across time steps, gradients must flow through the entire unrolled sequence — which means repeated multiplication by the same recurrent weight matrix. If its spectral radius (largest eigenvalue magnitude) is less than 1, gradients <strong>vanish</strong> exponentially with sequence length. Greater than 1, they <strong>explode</strong>.</p>
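<p>We can watch this happen numerically. The helper below is illustrative only; it uses a symmetric weight matrix so the spectral radius equals the operator norm, making the decay and growth rates exact.</p>

```python
import numpy as np

def grad_norm_through_time(spectral_radius, steps=50, dim=8, seed=0):
    """Norm of a gradient after `steps` backward multiplications by the
    same recurrent weight matrix, rescaled to the given spectral radius."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((dim, dim))
    W = (A + A.T) / 2                 # symmetric: spectral radius = 2-norm
    W *= spectral_radius / np.abs(np.linalg.eigvalsh(W)).max()
    g = np.ones(dim)                  # stand-in for dL/dh at the final step
    for _ in range(steps):
        g = W.T @ g                   # one BPTT step through shared weights
    return float(np.linalg.norm(g))

print(grad_norm_through_time(0.9))    # shrinks on the order of 0.9**50
print(grad_norm_through_time(1.1))    # grows on the order of 1.1**50
```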
</section>
<section id="lstm-gating-the-information-flow" class="level3">
<h3 class="anchored" data-anchor-id="lstm-gating-the-information-flow">LSTM: Gating the Information Flow</h3>
<p>LSTMs address vanishing gradients by separating the hidden state into two components:</p>
<ul>
<li><strong>Cell state</strong>: A “highway” for long-range information flow</li>
<li><strong>Hidden state</strong>: The working memory exposed to the next layer</li>
</ul>
<p>Four gates (learned transformations) control information flow at each step:</p>
<ol type="1">
<li><strong>Forget gate</strong>: What information from the cell state to discard</li>
<li><strong>Input gate</strong>: What new information to add to the cell state</li>
<li><strong>Cell update</strong>: The candidate new information</li>
<li><strong>Output gate</strong>: What to expose as the hidden state</li>
</ol>
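<p>The four gates above fit in a few lines of NumPy. This is a minimal sketch with random weights; the gate names <code>f</code>/<code>i</code>/<code>g</code>/<code>o</code> are labels chosen here, not a framework convention.</p>

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, params):
    """One LSTM time step.  `params` maps gate name -> (W, U, b):
    f = forget, i = input, g = candidate cell update, o = output."""
    a = {name: W @ x + U @ h + b for name, (W, U, b) in params.items()}
    f, i, o = sigmoid(a["f"]), sigmoid(a["i"]), sigmoid(a["o"])
    g = np.tanh(a["g"])
    c_new = f * c + i * g             # cell state: gated "highway" update
    h_new = o * np.tanh(c_new)        # hidden state: gated working memory
    return h_new, c_new

rng = np.random.default_rng(0)
dim = 4
params = {name: (0.1 * rng.standard_normal((dim, dim)),
                 0.1 * rng.standard_normal((dim, dim)),
                 np.zeros(dim))
          for name in "figo"}

h = c = np.zeros(dim)
for t in range(3):                    # roll the cell over a short sequence
    h, c = lstm_step(rng.standard_normal(dim), h, c, params)
```

<p>Note the structure of <code>c_new</code>: when the forget gate saturates near 1 and the input gate near 0, the cell state passes through unchanged — the additive “highway” that lets gradients survive many steps.</p>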
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
LSTMs Don’t Fully Solve Long-Range Dependencies
</div>
</div>
<div class="callout-body-container callout-body">
<p>Despite the gating mechanism, both RNNs and LSTMs struggle with information far in the past. Recent tokens have a much more direct connection to the current hidden state. The cell state highway helps, but it’s not a complete solution for very long sequences. This is the fundamental motivation for attention mechanisms.</p>
</div>
</div>
</section>
<section id="transformers-global-receptive-field-via-attention" class="level3">
<h3 class="anchored" data-anchor-id="transformers-global-receptive-field-via-attention">Transformers: Global Receptive Field via Attention</h3>
<p>Transformers replace recurrence with <strong>attention</strong>, which gives every position direct access to every other position — a global receptive field.</p>
<p>However, the attention mechanism is inherently <strong>permutation-equivariant</strong>: permuting the input tokens merely permutes the outputs in the same way. There’s no notion of “first” or “last.” This is why <strong>positional encodings</strong> are essential — they inject order information that attention alone cannot capture.</p>
<p>For <strong>autoregressive tasks</strong> (language modeling, text generation), a causal mask restricts each position to attend only to current and previous positions, preserving the left-to-right generation constraint.</p>
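<p>Both properties are easy to demonstrate. The sketch below uses identity Q/K/V projections to isolate the mixing behavior (real layers learn these projections); the function and names are illustrative.</p>

```python
import numpy as np

def attention(X, mask=None):
    """Single-head self-attention with identity Q/K/V projections."""
    scores = X @ X.T / np.sqrt(X.shape[1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)   # blocked positions ~ -inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))    # 4 tokens, embedding dim 3
perm = np.array([2, 0, 3, 1])

# Permutation equivariance: shuffling the tokens just shuffles the outputs,
# so without positional encodings the model cannot tell token order.
assert np.allclose(attention(X)[perm], attention(X[perm]))

# A causal mask (lower triangle) lets position t see only positions <= t;
# position 0 can attend only to itself, so its output is exactly token 0.
causal = np.tril(np.ones((4, 4), dtype=bool))
assert np.allclose(attention(X, mask=causal)[0], X[0])
```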
</section>
<section id="gans-adversarial-generation" class="level3">
<h3 class="anchored" data-anchor-id="gans-adversarial-generation">GANs: Adversarial Generation</h3>
<p>GANs learn to generate data by pitting two networks against each other:</p>
<ul>
<li><strong>Generator</strong>: Takes a random noise vector and tries to produce realistic images. Its objective is to <em>maximize</em> the discriminator’s error — make the discriminator believe the fake images are real.</li>
<li><strong>Discriminator</strong>: Receives both real and generated images and tries to classify them correctly. It <em>minimizes</em> its classification loss.</li>
</ul>
<p>The discriminator acts as a learned loss function that guides the generator toward producing increasingly realistic outputs. The “adversarial” part is the opposing objectives: the generator improves precisely by exploiting whatever distributional differences the discriminator can still detect, including ones imperceptible to humans.</p>
<p><strong>Conv2dTranspose (Deconvolution):</strong> The generator typically needs to upsample from a small latent vector to a full-resolution image. Transposed convolution reverses the spatial dimension change of convolution — taking a small spatial input and producing a larger spatial output.</p>
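<p>The shape arithmetic makes the “reversal” concrete: the transposed convolution’s output-size formula is the inverse of the convolution’s. A sketch (helper names are mine):</p>

```python
def conv_out(n, k, s=1, p=0):
    """Spatial output size of a convolution."""
    return (n + 2 * p - k) // s + 1

def conv_transpose_out(n, k, s=1, p=0):
    """Spatial output size of a transposed convolution — the inverse
    shape mapping of conv_out."""
    return (n - 1) * s - 2 * p + k

# A DCGAN-style generator upsamples 4 -> 8 -> 16 -> 32 with k=4, s=2, p=1:
sizes = [4]
for _ in range(3):
    sizes.append(conv_transpose_out(sizes[-1], k=4, s=2, p=1))
print(sizes)  # → [4, 8, 16, 32]

# And the corresponding convolution maps the shapes back down:
assert conv_out(32, k=4, s=2, p=1) == 16
```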
<p><strong>Key takeaway:</strong> Each architecture encodes different assumptions about data structure. CNNs assume spatial locality. RNNs assume temporal ordering. Transformers assume that global relationships matter and let attention learn what to focus on. GANs assume that the best loss function is a learned one.</p>
<hr>
</section>
</section>
<section id="model-deployment-considerations" class="level2">
<h2 class="anchored" data-anchor-id="model-deployment-considerations">Model Deployment Considerations</h2>
<p>Training a model is only half the battle. Deploying it introduces a different set of constraints:</p>
<ul>
<li><strong>Application environment restrictions</strong>: Model size limits, no Python runtime available (embedded/mobile)</li>
<li><strong>Hardware acceleration</strong>: Leveraging mobile GPUs, NPUs, or specialized CPU instructions (AVX, NEON)</li>
<li><strong>Integration</strong>: Fitting into existing application architectures and serving infrastructure</li>
</ul>
<p>These constraints often drive post-training optimizations like quantization, pruning, distillation, and conversion to inference-specific formats (ONNX, TensorRT, Core ML).</p>
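<p>As a flavor of what quantization does, here is a sketch of symmetric per-tensor int8 quantization — illustrative only, not any framework’s implementation:</p>

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor int8 quantization: a single scale maps
    [-max|w|, max|w|] onto [-127, 127]."""
    scale = float(np.abs(w).max()) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal(1000).astype(np.float32)
q, scale = quantize_int8(w)

# 4x smaller storage (int8 vs float32); the reconstruction error is
# bounded by half a quantization step.
max_err = float(np.abs(dequantize(q, scale) - w).max())
assert max_err <= scale / 2 + 1e-6
```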
<hr>
</section>
<section id="tying-it-all-together" class="level2">
<h2 class="anchored" data-anchor-id="tying-it-all-together">Tying It All Together</h2>
<p>If you’ve made it this far, you’ve traced the full stack of a deep learning system:</p>
<ol type="1">
<li><strong>Framework design</strong> determines your development experience and optimization ceiling</li>
<li><strong>Autograd</strong> gives you gradients but demands memory for saved tensors</li>
<li><strong>Memory layout</strong> (strides, views, contiguity) determines whether operations are free or expensive</li>
<li><strong>Hardware acceleration</strong> turns logical operations into physical memory accesses and arithmetic</li>
<li><strong>Initialization and normalization</strong> keep training stable from start to finish</li>
<li><strong>Regularization</strong> prevents overfitting at both implicit and explicit levels</li>
<li><strong>Scaling</strong> trades communication overhead for the ability to train larger models</li>
<li><strong>Architecture choices</strong> encode structural assumptions about your data</li>
</ol>
<p>These layers interact. Autograd’s saved tensors create memory pressure, which motivates checkpointing, which trades memory for recomputation. Initialization determines activation norms, which normalization layers can stabilize, which affects gradient flow, which determines whether training converges. Strides determine memory access patterns, which determine kernel performance, which determines whether you’re compute-bound or memory-bound.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Systems Thinking Payoff
</div>
</div>
<div class="callout-body-container callout-body">
<p>The next time training is slow, memory is exploding, or loss isn’t decreasing — you’ll have a mental model of the full stack to reason about where the problem might be. That’s the real value of building a framework from scratch.</p>
</div>
</div>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>ML Systems</category>
  <guid>https://imaddabbura.github.io/posts/mlsys/dl-systems.html</guid>
  <pubDate>Wed, 20 Dec 2023 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/mlsys/images/dl-system-image.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Breaking Text Apart (The Smart Way)</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/Tokenization-Strategies.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>Tokenization sits at the foundation of every NLP system — and it’s where more bugs, performance failures, and cross-lingual headaches originate than most practitioners expect.</p>
<p>The core problem: neural networks can’t consume raw text. They need numbers. Tokenization is the bridge — converting a string into a sequence of integer IDs that the model can embed and process. But <em>how</em> you make that conversion has enormous downstream consequences: for vocabulary size, sequence length, out-of-vocabulary handling, and multilingual generalization.</p>
<p>There are three fundamental strategies, sitting on a spectrum from fine-grained to coarse:</p>
<ul>
<li><strong>Character tokenization</strong>: split at every character — maximum granularity, minimum vocabulary</li>
<li><strong>Word tokenization</strong>: split at word boundaries — minimum granularity, maximum vocabulary</li>
<li><strong>Subword tokenization</strong>: split rules learned from corpus statistics — the practical sweet spot used by every modern LLM</li>
</ul>
<p>We’ll work through each in turn with concrete code, then zoom in on the two subword algorithms that dominate modern NLP: <strong>WordPiece</strong> (BERT, DistilBERT) and <strong>BPE via SentencePiece</strong> (XLM-R, LLaMA, GPT-family models).</p>
</section>
<section id="tokenization-process" class="level2">
<h2 class="anchored" data-anchor-id="tokenization-process">Tokenization Process</h2>
<p><a href="images/tokenization-pipeline.png" class="lightbox" data-gallery="quarto-lightbox-gallery-1"><img src="https://imaddabbura.github.io/posts/nlp/images/tokenization-pipeline.png" class="img-fluid"></a></p>
<p>The tokenization pipeline has four stages, each with a distinct job:</p>
<ul>
<li><p><strong>Normalization</strong>: Clean the raw text before any splitting. Common operations include Unicode normalization (collapsing different byte representations of the same character), lowercasing, and accent stripping. Critically, what gets normalized here is permanent — the model never sees the original form.</p></li>
<li><p><strong>Pretokenization</strong>: Split the normalized text into coarse units, typically words or word-like chunks. For English and German, splitting on whitespace and punctuation works well. For languages like Japanese or Chinese — which have no whitespace — language-specific rules or character-level splits are used instead.</p></li>
<li><p><strong>Tokenizer model</strong>: Apply the learned subword splitting algorithm (WordPiece, BPE, Unigram, etc.) to each pretokenized chunk. This is the only <em>trained</em> stage — everything else is rule-based. The vocabulary and merge rules come from the pretraining corpus.</p></li>
<li><p><strong>Postprocessing</strong>: Wrap the token sequence with any model-specific special tokens. BERT prepends <code>[CLS]</code> and inserts <code>[SEP]</code> between sequences. XLM-R uses <code>&lt;s&gt;</code> and <code>&lt;/s&gt;</code>. These tokens have specific learned representations and must be consistent between pretraining and fine-tuning.</p></li>
</ul>
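<p>The four stages can be sketched as plain functions. Everything here is a toy stand-in — in particular the <code>model</code> stage uses greedy longest-match against a hand-written vocabulary rather than a trained one, and the <code>##</code> convention is borrowed from WordPiece:</p>

```python
import unicodedata

def normalize(text):
    """Stage 1: Unicode-normalize, lowercase, strip accents."""
    text = unicodedata.normalize("NFD", text.lower())
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

def pretokenize(text):
    """Stage 2: coarse whitespace split (fine for English)."""
    return text.split()

def model(word, vocab):
    """Stage 3: toy subword model — greedy longest-match against a
    fixed vocabulary, '##' marking continuation pieces."""
    pieces, start = [], 0
    while start < len(word):
        for end in range(len(word), start, -1):
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in vocab:
                pieces.append(piece)
                start = end
                break
        else:                       # no piece matched: bail out to [UNK]
            return ["[UNK]"]
    return pieces

def postprocess(tokens):
    """Stage 4: add BERT-style special tokens."""
    return ["[CLS]", *tokens, "[SEP]"]

vocab = {"token", "##ization", "rocks"}
text = "Tokenization rocks"
tokens = [p for w in pretokenize(normalize(text)) for p in model(w, vocab)]
print(postprocess(tokens))
# → ['[CLS]', 'token', '##ization', 'rocks', '[SEP]']
```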
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Pipeline Is Framework-Agnostic
</div>
</div>
<div class="callout-body-container callout-body">
<p>This four-stage structure underpins Hugging Face <code>tokenizers</code>, SentencePiece, and most production tokenizer implementations. Most unexpected token outputs trace back to either normalization (e.g., surprise lowercasing or accent stripping) or postprocessing (missing or double-added special tokens).</p>
</div>
</div>
</section>
<section id="tokenization-strategies" class="level2">
<h2 class="anchored" data-anchor-id="tokenization-strategies">Tokenization Strategies</h2>
<p>There are three core tokenization schemes. Before diving in, here’s a preview of the trade-offs that motivate the progression from characters to subwords:</p>
<table class="table">
<colgroup>
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
</colgroup>
<thead>
<tr class="header">
<th>Strategy</th>
<th>Vocab size</th>
<th>Sequence length</th>
<th>OOV handling</th>
<th>Multilingual</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Character</td>
<td>Tiny (~100s)</td>
<td>Very long</td>
<td>✅ None</td>
<td>✅ Natural</td>
</tr>
<tr class="even">
<td>Word</td>
<td>Huge (millions)</td>
<td>Short</td>
<td>❌ UNK collapse</td>
<td>⚠️ Poor</td>
</tr>
<tr class="odd">
<td>Subword</td>
<td>Medium (10K–100K)</td>
<td>Medium</td>
<td>✅ Decompose</td>
<td>✅ Good</td>
</tr>
</tbody>
</table>
<p>The pattern is clear: characters and words are opposite extremes, each with a disqualifying flaw. Subword tokenization is the engineered middle ground — and why every modern LLM uses it.</p>
<section id="character-tokenization" class="level3">
<h3 class="anchored" data-anchor-id="character-tokenization">Character Tokenization</h3>
<p>Character tokenization is the simplest possible approach: split the input string into individual characters and treat each one as a token. No learned vocabulary, no language-specific rules — just <code>list(text)</code>. It’s the floor of the granularity spectrum.</p>
<div id="92ba1e07-4dff-4f3a-9b31-0daf88028379" class="cell" data-execution_count="3">
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1">text <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"I love NLP!"</span></span>
<span id="cb1-2"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(text)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="3">
<pre><code>['I', ' ', 'l', 'o', 'v', 'e', ' ', 'N', 'L', 'P', '!']</code></pre>
</div>
</div>
<p>From here, it is easy to convert each character into an integer that can be fed to the model. This step is called <em>numericalization</em>. We can numericalize the above text by first building the vocabulary and then mapping each character to its corresponding index:</p>
<div id="88aa86ee-95da-4173-82e0-20dd5f5c3fc6" class="cell" data-execution_count="4">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1">vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {char: idx <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> idx, char <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sorted</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(text)))}</span>
<span id="cb3-2"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(vocab)</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>{' ': 0, '!': 1, 'I': 2, 'L': 3, 'N': 4, 'P': 5, 'e': 6, 'l': 7, 'o': 8, 'v': 9}</code></pre>
</div>
</div>
<p>Now we can simply map each token (character in this case) to its own corresponding index:</p>
<div id="cabb3953-a7e5-4bab-9ae6-5f3f3c319e55" class="cell" data-execution_count="5">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">[vocab[char] <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> char <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> text]</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="5">
<pre><code>[2, 0, 7, 8, 9, 6, 0, 4, 3, 5, 1]</code></pre>
</div>
</div>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why Character Tokenization Is Appealing
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>No out-of-vocabulary problem</strong>: every possible input — misspellings, code, emojis, neologisms — is representable from the same small fixed alphabet</li>
<li><strong>Tiny vocabulary</strong>: ~100 characters for English. The embedding matrix and output projection stay small, which reduces parameter count and memory</li>
</ul>
</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why Character Tokenization Fails in Practice
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Sequences become extremely long</strong>: “I love NLP!” becomes 11 tokens. A typical 512-word document becomes several thousand characters. For Transformers with quadratic attention cost, this is prohibitively expensive</li>
<li><strong>No free linguistic priors</strong>: the model has no prior knowledge that <code>l</code>, <code>o</code>, <code>v</code>, <code>e</code> together constitute a meaningful unit. Recovering word-level and phrase-level structure from raw characters requires far more data, compute, and model depth than most tasks justify</li>
<li><strong>Context window exhaustion</strong>: with fixed-length context windows, very long character sequences mean the model can attend to only a small slice of a document at a time, losing long-range dependencies that often carry the most important signal</li>
</ul>
</div>
</div>
</section>
<section id="word-tokenization" class="level3">
<h3 class="anchored" data-anchor-id="word-tokenization">Word Tokenization</h3>
<p>Word tokenization takes the opposite approach: split on whitespace (and often punctuation) and treat each word as an atomic token. Sequences stay short and tokens carry recognizable meaning — but the vocabulary problem quickly becomes unmanageable at scale.</p>
<div id="1b6e7b4b-99d2-4ce4-a3d7-3176cdf65250" class="cell" data-execution_count="6">
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1">text.split()</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="6">
<pre><code>['I', 'love', 'NLP!']</code></pre>
</div>
</div>
<div id="c5daf75d-a025-42d4-a21a-35b3b96391b3" class="cell" data-execution_count="8">
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">vocab <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> {char: idx <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> idx, char <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sorted</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">set</span>(text.split())))}</span>
<span id="cb9-2"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(vocab)</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>{'I': 0, 'NLP!': 1, 'love': 2}</code></pre>
</div>
</div>
<div id="7af9c548-5537-4a74-aee4-904ddfd45184" class="cell" data-execution_count="9">
<div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1">[vocab[word] <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> word <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> text.split()]</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="9">
<pre><code>[0, 2, 1]</code></pre>
</div>
</div>
<p>Most production word tokenizers go beyond whitespace splitting and include language-specific heuristics — for example, separating contractions like “doesn’t” into “does” and “n’t”, or splitting punctuation from adjacent words. These rules improve coverage but don’t solve the fundamental vocabulary size and OOV problems.</p>
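<p>A sketch of what such heuristics look like — two illustrative rules (contractions, then punctuation), nowhere near the coverage of a real rule-based tokenizer:</p>

```python
import re

def word_tokenize(text):
    """Toy rule-based word tokenizer: split off common English
    contractions, then separate punctuation from adjacent words."""
    text = re.sub(r"n't\b", " n't", text)                 # doesn't -> does n't
    text = re.sub(r"'(s|re|ve|ll|d|m)\b", r" '\1", text)  # NLP's  -> NLP 's
    return re.findall(r"n't|'\w+|\w+|[^\w\s]", text)

print(word_tokenize("She doesn't love NLP!"))
# → ['She', 'does', "n't", 'love', 'NLP', '!']
```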
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why Word Tokenization Seems Appealing
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Short sequences</strong>: “I love NLP!” is 3 tokens. The model attends to far more context within the same fixed-length window</li>
<li><strong>Tokens carry meaning directly</strong>: each token maps to a recognizable linguistic unit, giving the model useful priors without learning from scratch</li>
</ul>
</div>
</div>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why Word Tokenization Breaks Down
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li><strong>Vocabulary explosion</strong>: a large corpus contains millions of distinct word forms — declensions, conjugations, misspellings, punctuation variants, domain-specific terms. An embedding table with 1M entries at dimension 512 requires ~500M parameters for the embedding layer alone. Truncating to the top-N words forces everything else to <code>[UNK]</code>, which silently destroys information — the model has no way to recover what word was there</li>
<li><strong>Under-trained embeddings</strong>: rare words appear too infrequently to accumulate meaningful gradient signal. They occupy slots in the vocabulary without learning useful representations — wasted capacity</li>
<li><strong>Language boundary failures</strong>: languages without clear word boundaries (Japanese, Chinese, Thai) have no natural whitespace to split on. Word tokenization either silently fails or requires expensive language-specific preprocessing at training and inference time</li>
</ul>
</div>
</div>
</section>
<section id="subword-tokenization" class="level3">
<h3 class="anchored" data-anchor-id="subword-tokenization">Subword Tokenization</h3>
<p>Subword tokenization is the engineered middle ground between the two extremes. The core insight: <strong>most words in any language are built from a small set of recurring morphemes</strong> — prefixes, roots, suffixes. “tokenization”, “tokenizer”, “tokenized” all share the root “token”. Word tokenization throws that structure away by treating each form as an unrelated atomic entry. Character tokenization preserves the raw signal but forces the model to discover linguistic structure from scratch, without any priors.</p>
<p>Subword algorithms exploit this structure directly. They learn a vocabulary of high-frequency subword units from a large pretraining corpus. Common words like “love” stay as single tokens. Rare or novel words get decomposed into familiar pieces: “tokenization” → <code>["token", "##ization"]</code> in WordPiece, or <code>["▁token", "ization"]</code> in SentencePiece. The model has seen “token” thousands of times and has a rich representation for it — that representation is now available even when encountering “detokenization” for the first time.</p>
<p>This also handles misspellings and out-of-domain terms gracefully. “GPT-4o” doesn’t need to be in the vocabulary — it gets decomposed into known subwords rather than collapsing to <code>[UNK]</code>.</p>
<p>Two algorithms dominate modern NLP: <strong>WordPiece</strong> (BERT, DistilBERT) and <strong>BPE via SentencePiece</strong> (XLM-R, LLaMA, GPT-family models). Both learn subword vocabularies from corpus statistics, but they use different objectives and produce different tokenization behavior — differences that matter when debugging cross-lingual failures or unexpected token splits.</p>
<section id="wordpiece" class="level4">
<h4 class="anchored" data-anchor-id="wordpiece">WordPiece</h4>
<p><a href="https://arxiv.org/abs/1609.08144v2">WordPiece</a> is the subword algorithm behind BERT and DistilBERT. Like BPE, it starts with a character-level vocabulary and iteratively merges pairs — but the key difference is in <em>how</em> it chooses which pair to merge next.</p>
<p>BPE picks the most frequent pair. WordPiece picks the pair that <strong>maximizes the likelihood of the training corpus</strong> when merged. Concretely, for a candidate pair <img src="https://latex.codecogs.com/png.latex?(u,%20v)">, it evaluates:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bscore%7D(u,%20v)%20=%20%5Cfrac%7B%5Ctext%7Bcount%7D(uv)%7D%7B%5Ctext%7Bcount%7D(u)%20%5Ctimes%20%5Ctext%7Bcount%7D(v)%7D"></p>
<p>This is a pointwise mutual information criterion: it rewards pairs that appear together more than their individual frequencies would predict. Merging “##iz” with “##ation” scores high not just because the bigram is frequent, but because seeing “##iz” almost always predicts “##ation” — the merge buys maximum information.</p>
<p>The training process:</p>
<ol type="1">
<li>Initialize the vocabulary with all characters in the corpus, prepending <code>##</code> to all characters that don’t start a word</li>
<li>Score every adjacent pair using the PMI formula above</li>
<li>Merge the highest-scoring pair and add it to the vocabulary</li>
<li>Repeat until the vocabulary reaches the target size (BERT uses 30,000)</li>
</ol>
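<p>To make the selection criterion concrete, here is a toy sketch of the pair-scoring step (the counts are invented for illustration; real WordPiece training operates over a full corpus):</p>

```python
from collections import Counter

# Invented counts for two candidate pairs and their constituent units.
pair_counts = Counter({("##iz", "##ation"): 80, ("t", "##he"): 500})
unit_counts = Counter({"##iz": 90, "##ation": 120, "t": 4000, "##he": 900})

def score(pair):
    """PMI-style WordPiece score: count(uv) / (count(u) * count(v))."""
    u, v = pair
    return pair_counts[pair] / (unit_counts[u] * unit_counts[v])

best = max(pair_counts, key=score)
# ("##iz", "##ation") wins even though ("t", "##he") is far more frequent:
# 80 / (90 * 120) ≈ 0.0074  vs  500 / (4000 * 900) ≈ 0.00014
print(best)
```

Note how the normalization by individual unit counts penalizes pairs whose parts are simply common everywhere, which is exactly what distinguishes WordPiece's criterion from BPE's raw frequency.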
<p>The <code>##</code> prefix is the signature of WordPiece. It marks continuation subwords — pieces that are <em>not</em> at the start of a word boundary. So <code>["nl", "##p"]</code> means: “nl” starts a word, “##p” continues it. Reconstructing the original word means stripping <code>##</code> and concatenating.</p>
<div id="99b30f6a-2857-4850-9b6d-7ccfbd2ec75b" class="cell" data-execution_count="8">
<div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> transformers <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> DistilBertTokenizer</span>
<span id="cb13-2"></span>
<span id="cb13-3">tokenizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> DistilBertTokenizer.from_pretrained(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"distilbert-base-uncased"</span>)</span>
<span id="cb13-4">encoded_text <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tokenizer(text)</span>
<span id="cb13-5">encoded_text</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="8">
<pre><code>{'input_ids': [101, 1045, 2293, 17953, 2361, 999, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}</code></pre>
</div>
</div>
<div id="f0dbfc78-fdec-485e-84ac-795cb9ea3be3" class="cell" data-execution_count="9">
<div class="sourceCode cell-code" id="cb15" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1">tokenizer.convert_ids_to_tokens(encoded_text[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"input_ids"</span>])</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="9">
<pre><code>['[CLS]', 'i', 'love', 'nl', '##p', '!', '[SEP]']</code></pre>
</div>
</div>
<p>Reading the DistilBERT output token by token:</p>
<ul>
<li><code>[CLS]</code> — a special classification token prepended to every sequence. Its final hidden state is used as the aggregate sequence representation for classification tasks</li>
<li><code>i</code> — “I” was lowercased (DistilBERT uses the <em>uncased</em> checkpoint, <code>distilbert-base-uncased</code>)</li>
<li><code>love</code> — a common English word; gets its own token</li>
<li><code>nl</code> — the first subword of “NLP”. “NLP” is rare enough in BERT’s training corpus that it was never merged into a single token</li>
<li><code>##p</code> — continues from “nl”. The <code>##</code> prefix signals “this piece is not at a word boundary — attach it to the previous token”</li>
<li><code>!</code> — punctuation gets its own token</li>
<li><code>[SEP]</code> — marks the end of a sequence (or the boundary between two sequences in sentence-pair tasks)</li>
</ul>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Decoding the <code>##</code> Prefix
</div>
</div>
<div class="callout-body-container callout-body">
<p>When you see <code>##</code> in WordPiece output, it means: strip the <code>##</code> and concatenate directly to the previous token. <code>["nl", "##p"]</code> → <code>"nlp"</code>. <code>["un", "##believ", "##able"]</code> → <code>"unbelievable"</code>. The <code>##</code> is how WordPiece encodes which subwords are word-internal vs.&nbsp;word-initial — critical for reconstructing the original string.</p>
</div>
</div>
<div id="211d4098-2821-4fd3-9a0c-c404f4ac3ec9" class="cell" data-execution_count="12">
<div class="sourceCode cell-code" id="cb17" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb17-1">tokenizer.convert_tokens_to_string(</span>
<span id="cb17-2">    tokenizer.convert_ids_to_tokens(encoded_text[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"input_ids"</span>])</span>
<span id="cb17-3">)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="12">
<pre><code>'[CLS] i love nlp ! [SEP]'</code></pre>
</div>
</div>
</section>
<section id="sentencepiece" class="level4">
<h4 class="anchored" data-anchor-id="sentencepiece">SentencePiece</h4>
<p><a href="https://arxiv.org/abs/1808.06226">SentencePiece</a> is a language-agnostic tokenization library that implements both BPE and unigram language model algorithms. Two properties make it the dominant choice for multilingual models.</p>
<p><strong>First: it treats the input as a raw Unicode character stream</strong> — no language-specific pretokenization required. It never assumes whitespace marks word boundaries, which means it works equally well on English, Chinese, Japanese, Arabic, and any language mixture. This is why XLM-R, mT5, and LLaMA all use SentencePiece.</p>
<p><strong>Second: it uses <code>▁</code> (U+2581, lower one-eighth block) to encode the start of a new word.</strong> Rather than marking continuation pieces like WordPiece does with <code>##</code>, SentencePiece marks word-<em>starts</em>. A <code>▁</code> at the beginning of a token means “there was a space before this character in the original text.” Absence of <code>▁</code> means “this token is a continuation.”</p>
<p>The BPE algorithm it implements:</p>
<ol type="1">
<li>Initialize the vocabulary with individual Unicode characters plus an end-of-word marker</li>
<li>Count all adjacent character pairs across the corpus</li>
<li>Merge the most frequent pair into a new subword unit</li>
<li>Repeat until the vocabulary reaches the target size</li>
</ol>
<p>Unlike WordPiece’s PMI-based selection, BPE uses raw frequency. It’s simpler but produces similar results in practice — both algorithms converge on vocabularies dominated by common morphemes.</p>
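<p>The merge loop above fits in a few lines. A toy sketch (illustrative only — real implementations such as SentencePiece are heavily optimized and work on raw character streams; the word frequencies here are invented):</p>

```python
from collections import Counter

def bpe_train(words, num_merges):
    """Toy BPE trainer: words is {word: frequency}; returns the learned merges."""
    # Represent each word as a tuple of symbols, with an end-of-word marker.
    corpus = {tuple(w) + ("</w>",): f for w, f in words.items()}
    merges = []
    for _ in range(num_merges):
        # Count all adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # raw frequency, unlike WordPiece's PMI
        merges.append(best)
        # Apply the merge everywhere it occurs.
        merged = {}
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1]); i += 2
                else:
                    out.append(symbols[i]); i += 1
            merged[tuple(out)] = freq
        corpus = merged
    return merges

print(bpe_train({"low": 5, "lower": 2, "lowest": 3}, 3))
```

On this tiny corpus the first merges build up the shared stem "low" — common morphemes get merged first because they dominate the pair counts.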
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
BPE vs.&nbsp;Unigram in SentencePiece
</div>
</div>
<div class="callout-body-container callout-body">
<p>SentencePiece supports two algorithms. BPE builds the vocabulary bottom-up by merging. Unigram starts with a large candidate vocabulary and prunes it by removing tokens that minimally reduce the likelihood of the training corpus — a top-down approach. Unigram is used by XLNet and some multilingual models; BPE is more common. Both are interchangeable in the SentencePiece API.</p>
</div>
</div>
<div id="12596fb8-0a5a-439c-aa79-5bc0ab181f9c" class="cell" data-execution_count="13">
<div class="sourceCode cell-code" id="cb19" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb19-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> transformers <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> XLMRobertaTokenizer</span>
<span id="cb19-2"></span>
<span id="cb19-3">tokenizer <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> XLMRobertaTokenizer.from_pretrained(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"xlm-roberta-base"</span>)</span>
<span id="cb19-4">encoded_text <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tokenizer(text)</span>
<span id="cb19-5">encoded_text</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="13">
<pre><code>{'input_ids': [0, 87, 5161, 541, 37352, 38, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}</code></pre>
</div>
</div>
<div id="d1d33cbc-139d-49fd-9b80-d1815ff7d60e" class="cell" data-execution_count="14">
<div class="sourceCode cell-code" id="cb21" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb21-1">tokenizer.convert_ids_to_tokens(encoded_text[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"input_ids"</span>])</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="14">
<pre><code>['&lt;s&gt;', '▁I', '▁love', '▁N', 'LP', '!', '&lt;/s&gt;']</code></pre>
</div>
</div>
<p>Reading the XLM-R output token by token:</p>
<ul>
<li><code>&lt;s&gt;</code> — sequence start token (XLM-R’s equivalent of <code>[CLS]</code>)</li>
<li><code>▁I</code> — the <code>▁</code> prefix means “there was a space before this character.” Since “I” starts the sentence (treated as if preceded by whitespace), it gets <code>▁</code></li>
<li><code>▁love</code> — common word, single token; <code>▁</code> marks it as word-initial</li>
<li><code>▁N</code> — “NLP” is split; <code>▁N</code> is the word-initial piece</li>
<li><code>LP</code> — continues from <code>▁N</code>, no <code>▁</code> prefix (it’s a word-internal continuation)</li>
<li><code>!</code> — punctuation token</li>
<li><code>&lt;/s&gt;</code> — sequence end token (XLM-R’s equivalent of <code>[SEP]</code>)</li>
</ul>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
WordPiece <code>##</code> vs.&nbsp;SentencePiece <code>▁</code> — Two Sides of the Same Coin
</div>
</div>
<div class="callout-body-container callout-body">
<p>These two prefixes encode word boundary information in opposite ways:</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th>Tokenizer</th>
<th>Marker</th>
<th>Meaning</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>WordPiece (BERT)</td>
<td><code>##token</code></td>
<td>This piece continues the previous word</td>
</tr>
<tr class="even">
<td>SentencePiece (XLM-R, LLaMA)</td>
<td><code>▁token</code></td>
<td>A space preceded this character — new word starts here</td>
</tr>
</tbody>
</table>
<p>Both fully encode the original whitespace and allow perfect string reconstruction. The difference is convention, not capability. But you need to know which convention a tokenizer uses when writing postprocessing code to detokenize outputs.</p>
</div>
</div>
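<p>As a sketch of what such postprocessing code looks like for both conventions (a hypothetical helper, not part of any library — real tokenizers also handle punctuation spacing, which this ignores):</p>

```python
def detokenize(tokens, marker):
    """Reassemble a string from subword tokens.

    marker="##": WordPiece — strip the prefix and glue to the previous token.
    marker="▁":  SentencePiece — the prefix means a space precedes the token.
    """
    if marker == "##":
        out = []
        for tok in tokens:
            if tok.startswith("##"):
                out.append(tok[2:])       # continuation: concatenate directly
            else:
                out.append(" " + tok)     # word-initial: insert a space
        return "".join(out).strip()
    elif marker == "▁":
        # Every ▁ becomes a space; everything else concatenates as-is.
        return "".join(tokens).replace("▁", " ").strip()
    raise ValueError(f"unknown marker: {marker}")

print(detokenize(["un", "##believ", "##able"], "##"))        # 'unbelievable'
print(detokenize(["▁I", "▁love", "▁N", "LP", "!"], "▁"))     # 'I love NLP!'
```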
<div id="75c7acd6-1fc1-4b12-9e38-a39ca5ceeca8" class="cell" data-execution_count="15">
<div class="sourceCode cell-code" id="cb23" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb23-1">tokenizer.convert_tokens_to_string(</span>
<span id="cb23-2">    tokenizer.convert_ids_to_tokens(encoded_text[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"input_ids"</span>])</span>
<span id="cb23-3">)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="15">
<pre><code>'&lt;s&gt; I love NLP!&lt;/s&gt;'</code></pre>
</div>
</div>
</section>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>The three tokenization strategies form a clear hierarchy in practice:</p>
<ul>
<li><p><strong>Character tokenization</strong> is essentially unused in production NLP. Sequence lengths become prohibitively long for Transformer attention, and the model must learn linguistic structure entirely from scratch. It survives in niche applications: character-level language models, certain byte-level models (GPT-2 uses byte-level BPE as a starting point), and as a fallback for extremely small vocabularies.</p></li>
<li><p><strong>Word tokenization</strong> appears in legacy systems and simple bag-of-words pipelines, but fails at scale. Vocabulary explosion, <code>[UNK]</code> collapse, and multilingual brittleness make it unsuitable for anything pretrained on broad corpora.</p></li>
<li><p><strong>Subword tokenization</strong> is the universal standard for pretrained language models. WordPiece and SentencePiece BPE both solve the core trade-offs: bounded vocabulary, graceful OOV handling, multilingual coverage, and sequences short enough for Transformer attention.</p></li>
</ul>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Always Use the Tokenizer the Model Was Trained With
</div>
</div>
<div class="callout-body-container callout-body">
<p>When fine-tuning a pretrained model, you must use the <strong>exact same tokenizer</strong> — not just the same algorithm, but the same vocabulary file. The model’s embedding matrix maps token ID 1045 to a learned vector for the word “i” (in DistilBERT). Swap in a different tokenizer and ID 1045 now refers to something else entirely. The embeddings become noise, the model is unrecoverable, and fine-tuning won’t fix it. This applies to vocabulary size, normalization rules, and special token placements — all of it must match pretraining exactly.</p>
</div>
</div>
<p>Most practical work doesn’t require building tokenizers from scratch — Hugging Face <code>tokenizers</code> and SentencePiece handle it. What matters operationally is understanding the output: recognizing <code>##</code> vs <code>▁</code> markers, knowing which special tokens a model expects and in what order, and catching normalization surprises (casing, accent stripping) before they cause silent failures downstream.</p>
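<p>As a sketch of the kind of normalization surprise to watch for: uncased BERT-style preprocessing lowercases and strips accents, so distinct source words can silently collapse together. A pure-Python approximation using Unicode NFD decomposition (a simplification — real tokenizers also handle CJK characters, control characters, etc.):</p>

```python
import unicodedata

def bert_uncased_normalize(text):
    """Approximate the normalization of uncased BERT-style tokenizers:
    lowercase, then drop combining accent marks after NFD decomposition."""
    text = text.lower()
    decomposed = unicodedata.normalize("NFD", text)
    # Category "Mn" = nonspacing combining marks (accents, diaereses, ...)
    return "".join(c for c in decomposed if unicodedata.category(c) != "Mn")

print(bert_uncased_normalize("Résumé"))  # 'resume' — now identical to the English word
```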


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/Tokenization-Strategies.html</guid>
  <pubDate>Sat, 14 Jan 2023 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/tokenization.png" medium="image" type="image/png" height="65" width="144"/>
</item>
<item>
  <title>C Program Startup</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/c/program-startup-notes.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>In this post, I will try to write down the steps of C program execution on x86. I used to believe that all C programs start execution at <code>main</code> (at least that was my understanding from various books and courses) until my best friend, the <code>gdb</code> debugger, showed me the symbol for <code>_start</code>. That made me curious enough to get to the bottom of it. Below are the notes I took along the way.</p>
</section>
<section id="execution-steps" class="level2">
<h2 class="anchored" data-anchor-id="execution-steps">Execution Steps</h2>
<ol type="1">
<li>The linker injects <code>_start</code>, which is invoked in the process of loading.
<ul>
<li>It is written in assembly language</li>
<li>Always placed at the beginning of the <code>.text</code> section -&gt; Always guaranteed to run before anything else</li>
<li>It sets up some registers and arguments and then calls <code>__libc_start_main</code></li>
</ul></li>
<li><code>__libc_start_main</code> is a C function that:
<ul>
<li>function prototype:</li>
</ul>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb1-1">__libc_start_main <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span><span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(*</span>main<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">)</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span><span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span> <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**,</span> <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**),</span></span>
<span id="cb1-2">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span> argc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span></span>
<span id="cb1-3">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>argv<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span></span>
<span id="cb1-4">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span>  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(*</span>init<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">)</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span><span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span> <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**,</span> <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**),</span></span>
<span id="cb1-5">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">void</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(*</span>fini<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">)</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span><span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">void</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">),</span></span>
<span id="cb1-6">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">void</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(*</span>rtld_fini<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">)</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span><span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">void</span><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">),</span></span>
<span id="cb1-7">                   <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">void</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>stack_end</span>
<span id="cb1-8">                  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">)</span></span></code></pre></div>
<ul>
<li>Define the <code>environ</code> global variable using <code>ps_strings</code>: <code>environ = ps_strings-&gt;ps_envstr</code>
<ul>
<li>Below are some details about <code>ps_strings</code> structure:</li>
</ul></li>
</ul>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/*</span></span>
<span id="cb2-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * The following structure is found at the top of the user stack of each</span></span>
<span id="cb2-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * user process. The ps program uses it to locate argv and environment</span></span>
<span id="cb2-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * strings. Programs that wish ps to display other information may modify</span></span>
<span id="cb2-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * it; normally ps_argvstr points to argv[0], and ps_nargvstr is the same</span></span>
<span id="cb2-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * as the program's argc. The fields ps_envstr and ps_nenvstr are the</span></span>
<span id="cb2-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> * equivalent for the environment.</span></span>
<span id="cb2-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"> */</span></span>
<span id="cb2-9"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">struct</span> ps_strings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span></span>
<span id="cb2-10">    <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span>    <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>ps_argvstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span>       <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/* first of 0 or more argument strings */</span></span>
<span id="cb2-11">    <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span>       ps_nargvstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span>      <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/* the number of argument strings */</span></span>
<span id="cb2-12">    <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">char</span>    <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span>ps_envstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span>        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/* first of 0 or more environment strings */</span></span>
<span id="cb2-13">    <span class="dt" style="color: #AD0000;
background-color: null;
font-style: inherit;">int</span>       ps_nenvstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span>       <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/* the number of environment strings */</span></span>
<span id="cb2-14"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">};</span></span></code></pre></div>
<ul>
<li><code>envp</code> is typically computed as <code>char **envp = &amp;argv[argc + 1]</code> in <code>__libc_init_first</code></li>
<li>It also registers cleanup and exit handlers</li>
<li>It defines <code>init</code> &amp; <code>fini</code>, which implement the function prologue and epilogue, i.e.&nbsp;what happens when entering a function and when returning from it. They also align the stack to a multiple of 16 bytes, which is more efficient and cache friendly. They are written in assembly language</li>
<li>It sets %rbp to zero because <code>main</code> would be the outermost frame</li>
<li>Finally it calls:</li>
</ul>
<div class="sourceCode" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode c code-with-copy"><code class="sourceCode c"><span id="cb3-1">    exit<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span>main<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">(</span>ps_strings<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span>ps_nargvstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span> ps_strings<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span>ps_argvstr<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">,</span> environ<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">));</span></span></code></pre></div>
<ul>
<li>After the NULL terminator of <code>envp</code>, there is the ELF auxiliary vector, which the loader uses to provide information to the process such as the user id, page size, etc.</li>
<li>Therefore, <code>__libc_start_main</code> in general does the following:
<ul>
<li>Set up argv and envp</li>
<li>Initialize the thread local storage by calling <code>__pthread_initialize_minimal</code> (which only calls <code>__libc_setup_tls</code>). <code>__libc_setup_tls</code> will initialize Thread Control Block and Dynamic Thread Vector.</li>
<li>Set up the thread stack guard</li>
<li>Register the destructor (i.e.&nbsp;the rtld_fini argument passed to <code>__libc_start_main</code>) of the dynamic linker (by calling <code>__cxa_atexit</code>) if there is any</li>
<li>Initialize Glibc itself by calling <code>__libc_init_first</code></li>
<li>Register <code>__libc_csu_fini</code> (i.e.&nbsp;the fini argument passed to <code>__libc_start_main</code>) using <code>__cxa_atexit</code></li>
<li>Call <code>__libc_csu_init</code> (i.e.&nbsp;the init argument passed to <code>__libc_start_main</code>). <code>__libc_csu_init</code> executes initializers in the following order:
<ul>
<li>Function pointers in .preinit_array section</li>
<li>Functions marked as <code>__attribute__ ((constructor))</code>, via <code>_init</code></li>
<li>Function pointers in <code>.init_array</code> section</li>
</ul></li>
<li>Set up data structures needed for thread unwinding/cancellation</li>
<li>Call main of user’s program.</li>
<li>Call <code>exit</code>
<ul>
<li>In reverse order, functions registered via <code>atexit</code> or <code>on_exit</code></li>
<li>Function pointers in <code>.fini_array</code> section, via <code>__libc_csu_fini</code></li>
<li>Functions marked as <code>__attribute__ ((destructor))</code>, via <code>__libc_csu_fini</code> (which calls <code>_fini</code> after Step 2)</li>
<li>stdio cleanup functions</li>
<li>The <code>.fini_array</code> section must also contain function pointers with the same prototype as a destructor, i.e.&nbsp;taking no arguments and returning void. If the program exits normally, this teardown is driven by the <code>exit</code> function (Glibc source file <code>stdlib/exit.c</code>)</li>
</ul></li>
</ul></li>
</ul></li>
</ol>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>So starting a program calls <code>execve</code>, which runs the loader, which at some point passes control to <code>_start</code>; <code>_start</code> calls <code>__libc_start_main</code>, which calls <code>__libc_csu_init</code>, which calls <code>_init</code>.</p>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Software Engineering</category>
  <guid>https://imaddabbura.github.io/posts/c/program-startup-notes.html</guid>
  <pubDate>Fri, 21 Oct 2022 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/c/linux-prog-startup.png" medium="image" type="image/png" height="123" width="144"/>
</item>
<item>
  <title>The Transformer Architecture: A Deep Dive</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/Transformer-Architecture-Explained.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>If you’ve called <code>from transformers import BertModel</code> or prompted GPT-4, you’ve used a Transformer. But <em>what actually happens</em> when it processes text? Why does attention use three separate projections — Q, K, and V? Why does the decoder need a causal mask?</p>
<p>The Transformer displaced a decade of sequence modelling research not because it was more complex, but because it was more general: the same architecture, with minimal changes, now handles text, images, protein structures, and audio. Understanding <em>why</em> it generalises is what separates someone who can use these models from someone who can reason about them.</p>
<p>This post builds one from scratch — understanding the motivation behind every design choice before any code. By the end, you will be able to:</p>
<ol type="1">
<li>Explain <em>why</em> each component exists, not just what it does</li>
<li>Trace a forward pass through the full encoder-decoder architecture, step by step</li>
<li>Understand the three architecture variants (encoder-only, decoder-only, encoder-decoder) and when to use each</li>
<li>Read modern Transformer papers and recognise the improvements they describe</li>
</ol>
<p>We start with the problem that motivated the Transformer (sequential bottlenecks in RNNs), build the attention mechanism from scratch, implement each component in PyTorch with annotated shapes, and assemble all three architecture variants — using the architecture diagram below as our map throughout.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/transformer-arch.png" class="lightbox" data-glightbox="description: .lightbox-desc-1" data-gallery="quarto-lightbox-gallery-1" title="Figure 1: The encoder-decoder Transformer (Vaswani et al., 2017) — the architecture we’ll build in this post. Left stack (Encoder): reads the full source sequence; every token attends to every other token with no masking. Right stack (Decoder): generates the target sequence one token at a time; each layer has three sublayers: ① masked self-attention (tokens attend only to past positions), ② cross-attention (Q from decoder, K and V from the encoder output — the arrow connecting the two stacks), and ③ a feed-forward network. “N×” means the layer block repeats N times (typically 6–12). Add &amp; Norm is a residual connection followed by LayerNorm. The Linear + Softmax at the top projects the decoder’s final representation to a probability distribution over the vocabulary. Every component labelled here has its own section below. (source)"><img src="https://imaddabbura.github.io/posts/nlp/images/transformer-arch.png" class="img-fluid quarto-figure quarto-figure-center figure-img" width="600" alt="Figure 1: The encoder-decoder Transformer (Vaswani et al., 2017) — the architecture we’ll build in this post. Left stack (Encoder): reads the full source sequence; every token attends to every other token with no masking. Right stack (Decoder): generates the target sequence one token at a time; each layer has three sublayers: ① masked self-attention (tokens attend only to past positions), ② cross-attention (Q from decoder, K and V from the encoder output — the arrow connecting the two stacks), and ③ a feed-forward network. “N×” means the layer block repeats N times (typically 6–12). Add &amp; Norm is a residual connection followed by LayerNorm. The Linear + Softmax at the top projects the decoder’s final representation to a probability distribution over the vocabulary. Every component labelled here has its own section below. 
(source)"></a></p>
</figure>
</div>
<figcaption><strong>Figure 1:</strong> The encoder-decoder Transformer (Vaswani et al., 2017) — the architecture we’ll build in this post. <strong>Left stack (Encoder):</strong> reads the full source sequence; every token attends to every other token with no masking. <strong>Right stack (Decoder):</strong> generates the target sequence one token at a time; each layer has three sublayers: ① masked self-attention (tokens attend only to past positions), ② <strong>cross-attention</strong> (Q from decoder, K and V from the encoder output — the arrow connecting the two stacks), and ③ a feed-forward network. <strong>“N×”</strong> means the layer block repeats N times (typically 6–12). <strong>Add &amp; Norm</strong> is a residual connection followed by LayerNorm. The <strong>Linear + Softmax</strong> at the top projects the decoder’s final representation to a probability distribution over the vocabulary. Every component labelled here has its own section below. (<a href="https://arxiv.org/abs/1706.03762">source</a>)</figcaption>
</figure>
</div>
</section>
<section id="the-problem-why-not-rnns" class="level2">
<h2 class="anchored" data-anchor-id="the-problem-why-not-rnns">1. The Problem: Why Not RNNs?</h2>
<p>To understand <em>why</em> the Transformer is designed the way it is, you first need to understand what it replaced — and what was fundamentally broken about it.</p>
<section id="the-rnn-mental-model" class="level3">
<h3 class="anchored" data-anchor-id="the-rnn-mental-model">1.1 The RNN Mental Model</h3>
<p>A Recurrent Neural Network processes a sequence one token at a time. After seeing each token <img src="https://latex.codecogs.com/png.latex?w_t">, it updates a fixed-size <strong>hidden state</strong> <img src="https://latex.codecogs.com/png.latex?h_t"> that is supposed to summarize everything the model has seen so far:</p>
<p><img src="https://latex.codecogs.com/png.latex?h_t%20=%20f(h_%7Bt-1%7D,%5C,%20w_t)"></p>
<p>The hidden state is then passed to the next step. Think of it as a single notepad that a reader carries through a book, rewriting one paragraph of notes after each page. By the time they reach page 500, that notepad contains almost nothing from page 1 — there simply wasn’t room to preserve it through 499 rewrites.</p>
<p>This is not a metaphor for an occasional edge case; it is the fundamental architectural constraint. The RNN must compress all prior context into a fixed-size vector, and that compression is lossy by design.</p>
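A minimal sketch of this update rule (toy dimensions, random weights, and a plain tanh cell rather than an LSTM) makes the constraint concrete: the "notepad" never grows, no matter how long the sequence is.

```python
import torch

torch.manual_seed(0)
hidden_dim, embed_dim, T = 16, 8, 500            # toy sizes
W_h = torch.randn(hidden_dim, hidden_dim) * 0.1  # recurrent weights
W_x = torch.randn(hidden_dim, embed_dim) * 0.1   # input weights

def rnn_step(h_prev, w_t):
    # h_t = f(h_{t-1}, w_t): the entire past is squeezed into h_prev
    return torch.tanh(h_prev @ W_h.T + w_t @ W_x.T)

h = torch.zeros(hidden_dim)  # the "notepad": fixed size from step 0
for t in range(T):           # one lossy rewrite per token
    h = rnn_step(h, torch.randn(embed_dim))

print(h.shape)  # torch.Size([16]): same capacity after 500 tokens as after 1
```

Whatever page 1 contributed to `h` has passed through 499 saturating rewrites by the end.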
</section>
<section id="the-long-range-dependency-problem" class="level3">
<h3 class="anchored" data-anchor-id="the-long-range-dependency-problem">1.2 The Long-Range Dependency Problem</h3>
<p>Language is full of dependencies that span many tokens. Consider:</p>
<blockquote class="blockquote">
<p><em>“The trophy didn’t fit in the suitcase because <strong>it</strong> was too large.”</em></p>
</blockquote>
<p>To resolve what “it” refers to, a model must connect a pronoun near the end of the sentence back to a noun near the beginning. In an RNN, that connection must survive through every intermediate hidden state update. Each update potentially overwrites or dilutes the earlier information. The longer the sequence, the worse this gets.</p>
</section>
<section id="the-vanishing-gradient-problem" class="level3">
<h3 class="anchored" data-anchor-id="the-vanishing-gradient-problem">1.3 The Vanishing Gradient Problem</h3>
<p>The training-time failure mirrors the inference-time failure. When we backpropagate through an RNN, the gradient of the loss with respect to early hidden states is a product of Jacobians — one per time step:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20h_0%7D%20=%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20h_T%7D%20%5Cprod_%7Bt=1%7D%5E%7BT%7D%20%5Cfrac%7B%5Cpartial%20h_t%7D%7B%5Cpartial%20h_%7Bt-1%7D%7D"></p>
<p>If the entries of those Jacobians are consistently less than 1 (common with bounded activations like tanh), the product shrinks exponentially with <img src="https://latex.codecogs.com/png.latex?T">. Gradients from the loss signal barely reach the early time steps, so the model cannot learn from long-range dependencies even if it wanted to.</p>
<p>LSTMs and GRUs mitigate this with gating mechanisms, but they don’t eliminate it — they just slow the decay. (For a full treatment of LSTMs and their gating solution, see the <a href="../../posts/nlp/LSTM-Annotated-Implementation.html">Inside LSTMs</a> post.)</p>
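A back-of-the-envelope illustration of the decay (treating each Jacobian as a single scalar contraction factor, which is a deliberate simplification):

```python
# Each backprop step multiplies the gradient by one Jacobian; with bounded
# activations like tanh, per-step contraction factors below 1 are typical.
factor = 0.9
for T in (10, 50, 200):
    grad_scale = factor ** T  # product of T per-step factors
    print(f"T={T}: gradient scaled by {grad_scale:.2e}")
```

At T = 200 the scale is below 1e-9: the loss signal effectively never reaches the early time steps.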
</section>
<section id="the-sequential-processing-bottleneck" class="level3">
<h3 class="anchored" data-anchor-id="the-sequential-processing-bottleneck">1.4 The Sequential Processing Bottleneck</h3>
<p>RNNs are inherently sequential: you cannot compute <img src="https://latex.codecogs.com/png.latex?h_t"> until you have <img src="https://latex.codecogs.com/png.latex?h_%7Bt-1%7D">. This makes it impossible to parallelize across the time dimension. For a sequence of length <img src="https://latex.codecogs.com/png.latex?T">, the forward pass requires <img src="https://latex.codecogs.com/png.latex?T"> sequential steps regardless of how many GPUs you have.</p>
<p>Modern GPUs are massively parallel processors — they shine on matrix multiplications that can be batched across thousands of operations simultaneously. RNNs waste almost all of that capacity.</p>
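A toy contrast of the two computation patterns (illustrative shapes only): the recurrent loop below has a strict step-to-step dependency, while the attention-style score matrix is a single batched matmul over all positions at once.

```python
import torch

torch.manual_seed(0)
T, d = 512, 64
xs = torch.randn(T, d)
W = torch.randn(d, d) * 0.05

# Recurrent: T dependent steps; step t cannot start before step t-1 ends.
h = torch.zeros(d)
for x in xs:
    h = torch.tanh(h @ W + x)

# Attention-style: one matmul scores all T x T token pairs simultaneously,
# exactly the kind of work GPUs are built for.
scores = xs @ xs.T
print(h.shape, scores.shape)  # torch.Size([64]) torch.Size([512, 512])
```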
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Three Failure Modes
</div>
</div>
<div class="callout-body-container callout-body">
<p>RNNs fail in three compounding ways: (1) the hidden state <strong>bottleneck</strong> loses information over long sequences; (2) <strong>vanishing gradients</strong> prevent learning long-range relationships from the training signal; (3) <strong>sequential computation</strong> prevents parallelization, making training slow regardless of hardware. The Transformer addresses all three — not with patches, but by replacing sequential recurrence with a fundamentally different mechanism.</p>
</div>
</div>
</section>
</section>
<section id="the-big-idea-attention-as-direct-communication" class="level2">
<h2 class="anchored" data-anchor-id="the-big-idea-attention-as-direct-communication">2. The Big Idea: Attention as Direct Communication</h2>
<p>The central insight of the Transformer is deceptively simple: throw out sequential processing entirely and let every token communicate directly with every other token, in a single parallel operation.</p>
<section id="from-sequential-relay-to-direct-access" class="level3">
<h3 class="anchored" data-anchor-id="from-sequential-relay-to-direct-access">2.1 From Sequential Relay to Direct Access</h3>
<p>With an RNN, every relationship between tokens must be mediated through the hidden state — information travels through a long chain before it reaches its destination. With attention, every token asks every other token directly: <em>“How relevant are you to me?”</em> The answer shapes what information each token receives.</p>
<p>This is a fundamentally different computational paradigm: instead of routing information through a bottleneck, we create a <strong>direct, differentiable communication channel</strong> between all pairs of tokens simultaneously. The attention matrix for a sequence of length <img src="https://latex.codecogs.com/png.latex?T"> is <img src="https://latex.codecogs.com/png.latex?T%20%5Ctimes%20T"> — every pair gets its own weight.</p>
</section>
<section id="the-library-analogy-query-key-value" class="level3">
<h3 class="anchored" data-anchor-id="the-library-analogy-query-key-value">2.2 The Library Analogy: Query, Key, Value</h3>
<p>The attention mechanism is most naturally understood as a <strong>soft database lookup</strong>.</p>
<p>Imagine walking into a library. You have a <strong>query</strong> in mind — say, you’re looking for books about long-range dependencies in sequences. Every book in the library has a <strong>key</strong> on its spine: a short descriptor of what’s inside. You compare your query against every key, computing a relevance score for each book. Then you retrieve the <strong>values</strong> — the actual content — weighted by those relevance scores. The most relevant books contribute the most to what you walk away knowing.</p>
<p>This is exactly what the Transformer’s attention mechanism does at every layer, for every token:</p>
<ul>
<li><strong>Query (<img src="https://latex.codecogs.com/png.latex?Q">)</strong>: what this token is looking for</li>
<li><strong>Key (<img src="https://latex.codecogs.com/png.latex?K">)</strong>: what this token offers to match against</li>
<li><strong>Value (<img src="https://latex.codecogs.com/png.latex?V">)</strong>: what this token actually communicates if attended to</li>
</ul>
<p>The attended output for each token is a weighted mixture of all value vectors, where the weights are determined by the similarity between that token’s query and all other tokens’ keys.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Key Insight
</div>
</div>
<div class="callout-body-container callout-body">
<p>Attention is not a neural network layer in the traditional sense — it is a <strong>soft, differentiable database lookup</strong>. It is differentiable because the retrieval weights are produced by a smooth function (softmax), so gradients flow through the lookup operation during backpropagation. The queries, keys, and values are all learned — the model learns <em>what to look for</em>, <em>what to advertise</em>, and <em>what to say</em>.</p>
</div>
</div>
<p>We’ll return to this library analogy throughout — it explains why Q, K, and V need to be separate projections, and what the attention weights actually represent numerically.</p>
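The whole soft lookup can be sketched in a few lines (toy shapes, with random Q/K/V standing in for the learned projections covered later):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k = 5, 16                       # toy sequence length and head dim
Q = torch.randn(T, d_k)              # what each token is looking for
K = torch.randn(T, d_k)              # what each token advertises
V = torch.randn(T, d_k)              # what each token communicates

scores = Q @ K.T / d_k ** 0.5        # every query against every key
weights = F.softmax(scores, dim=-1)  # soft retrieval weights; rows sum to 1
out = weights @ V                    # weighted mixture of value vectors

print(weights.shape)  # torch.Size([5, 5]): one weight per token pair
print(out.shape)      # torch.Size([5, 16])
```

Because softmax is smooth, gradients flow through `weights` during backprop, which is what makes the lookup learnable.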
</section>
<section id="why-this-architecture-generalizes-beyond-language" class="level3">
<h3 class="anchored" data-anchor-id="why-this-architecture-generalizes-beyond-language">2.3 Why This Architecture Generalizes Beyond Language</h3>
<p>Here is the deeper insight that explains why Vision Transformers, AlphaFold2, audio Transformers, and point cloud Transformers all use the same architecture as BERT and GPT — often with almost no modification.</p>
<p><strong>The Transformer has almost no structural inductive bias.</strong> CNNs assume that nearby pixels are related — they bake in locality and translation equivariance as a prior. RNNs assume sequential order — they process left-to-right by construction. The Transformer assumes <em>nothing</em> about the structure of its input beyond what the positional encoding tells it. Every pair of positions is treated symmetrically by the attention mechanism until the training data says otherwise.</p>
<p>This is simultaneously the weakness and the superpower:</p>
<ul>
<li><strong>Weakness</strong>: Without structural priors, the model needs more data to learn relationships that CNNs or RNNs would pick up for free. A CNN learns “adjacent pixels tend to be related” from very few examples; a Transformer must discover this from data.</li>
<li><strong>Superpower</strong>: Any domain with a set of elements you want to relate to each other can be modeled by a Transformer. Images? Treat patches as tokens, inject 2D positional encodings (ViT). Proteins? Treat amino acids as tokens, use pairwise distances as positional information (AlphaFold2). Audio? Treat spectrogram frames as tokens. Graphs? Treat nodes as tokens.</li>
</ul>
<p>The key insight: <strong>positional encoding is the only thing that changes across domains.</strong> The attention mechanism, FFN, LayerNorm, and residual connections are entirely domain-agnostic. Swap the positional encoding and the same architecture processes any structured data. This is why the Transformer became the universal architecture — not because it is uniquely suited to language, but because it is uniquely <em>generic</em>.</p>
</section>
</section>
<section id="tokenization-from-text-to-numbers" class="level2">
<h2 class="anchored" data-anchor-id="tokenization-from-text-to-numbers">3. Tokenization: From Text to Numbers</h2>
<p>Before anything else, raw text must be converted into numbers that the model can process. This conversion — <strong>tokenization</strong> — splits text into a vocabulary of subword units and maps each unit to an integer ID. The Transformer receives a <code>B × T</code> matrix of integers as input, where <code>B</code> is the batch size and <code>T</code> is the sequence length.</p>
<p>There are three families of tokenization strategies — character-level, word-level, and subword — each with distinct tradeoffs in vocabulary size, sequence length, and out-of-vocabulary handling. Modern language models universally use <strong>subword tokenization</strong> (BPE or WordPiece), which offers a vocabulary of tens of thousands of tokens while gracefully handling rare and novel words by decomposing them into known pieces.</p>
<p>This post focuses on the Transformer architecture that consumes tokenized sequences, not on tokenization itself. For a deep dive into how tokenization works:</p>
<ul>
<li><a href="../../posts/nlp/Tokenization-Strategies.html"><strong>Breaking Text Apart (The Smart Way)</strong></a> — covers all three strategies, the four-stage tokenization pipeline (normalization, pretokenization, subword model, postprocessing), WordPiece (BERT), and SentencePiece (LLaMA, XLM-R)</li>
<li><a href="../../posts/nlp/BPE-Tokenizer.html"><strong>Byte Pair Encoding from Scratch</strong></a> — builds a BPE tokenizer from scratch, explains the training vs.&nbsp;encoding asymmetry, vocabulary size tradeoffs, and GPT-2’s regex pre-tokenization refinement</li>
</ul>
</section>
<section id="embedding-layer" class="level2">
<h2 class="anchored" data-anchor-id="embedding-layer">4. Embedding Layer</h2>
<p>The embedding layer is the first thing the model does with the token IDs it receives. It has two jobs: turn integers into meaningful vectors, and inject positional information so the model knows where each token sits in the sequence.</p>
<section id="token-embeddings" class="level3">
<h3 class="anchored" data-anchor-id="token-embeddings">4.1 Token Embeddings</h3>
<p>An integer ID has no geometric structure. The number 42 is not “close to” 41 in any meaningful sense for language — the token at position 42 in the vocabulary might be completely unrelated to token 41. Neural networks need continuous-valued vectors they can do math on: compute dot products, measure distances, apply linear transformations.</p>
<p>A token embedding is a <strong>lookup table</strong>: a matrix of shape <code>vocab_sz × embed_dim</code> where each row is a learnable vector associated with one token. When the model sees token ID <img src="https://latex.codecogs.com/png.latex?i">, it looks up row <img src="https://latex.codecogs.com/png.latex?i"> and uses that vector downstream.</p>
<p>What makes embeddings powerful is that training forces semantically similar tokens into nearby regions of this vector space. After training on enough text, the embedding for “king” minus the embedding for “man” plus the embedding for “woman” lands close to “queen” — not because we encoded this relationship by hand, but because the training signal shaped the space that way.</p>
<blockquote class="blockquote">
<p><em>An embedding turns a name tag into a GPS coordinate — suddenly you can measure distance, find neighbors, and do arithmetic.</em></p>
</blockquote>
<p><strong>Shape:</strong> <code>B × T</code> (integer IDs) → <code>B × T × embed_dim</code> (float vectors)</p>
<p><strong>Weight tying.</strong> In most language models, the embedding matrix is <em>reused</em> as the output projection at the end of the network — the final linear layer that maps from <code>d_model</code> back to <code>vocab_sz</code> uses the same weights, transposed. This is called <strong>weight tying</strong>. The core intuition: if two tokens have similar embeddings (i.e., they are semantically close), they should also receive similar probabilities when the model generates the next-token distribution. Since the LM head scores each candidate token by taking the dot product of the model’s output vector with that token’s embedding row, tokens whose embedding vectors are close to the output vector will score similarly — producing nearby probabilities. Weight tying enforces this consistency directly: the same geometry that groups similar tokens together in the input space also determines their relative scores in the output distribution. As a bonus, it halves the parameter count for the vocabulary components (often 30–100K tokens × 768 dims = a significant share of total parameters) and keeps input and output representations aligned throughout training.</p>
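A minimal sketch of both ideas, using toy sizes rather than a real vocabulary: the lookup is an `nn.Embedding`, and tying is literally assigning the same weight matrix to the output head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_sz, embed_dim = 100, 32                # toy sizes

tok_emb = nn.Embedding(vocab_sz, embed_dim)  # vocab_sz x embed_dim lookup table
lm_head = nn.Linear(embed_dim, vocab_sz, bias=False)
lm_head.weight = tok_emb.weight              # weight tying: one shared matrix

ids = torch.tensor([[3, 7, 7, 42]])          # B x T integer IDs
x = tok_emb(ids)                             # B x T x embed_dim float vectors
logits = lm_head(x)                          # B x T x vocab_sz token scores

print(x.shape, logits.shape)
print(lm_head.weight.data_ptr() == tok_emb.weight.data_ptr())  # True: shared storage
```

Each output logit is the dot product of the model's vector with one embedding row, so tokens with nearby embeddings necessarily receive nearby scores.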
</section>
<section id="positional-encodings" class="level3">
<h3 class="anchored" data-anchor-id="positional-encodings">4.2 Positional Encodings</h3>
<section id="why-position-carries-meaning" class="level4">
<h4 class="anchored" data-anchor-id="why-position-carries-meaning">Why position carries meaning</h4>
<p>Word order is one of the primary mechanisms through which human languages encode meaning. Consider how much information is carried purely by where a word sits in a sentence:</p>
<p><strong>Order determines who does what to whom.</strong> “The dog bit the man” and “The man bit the dog” contain identical tokens. The meaning is completely reversed. Without positional information, a model sees the same set of embeddings for both — it cannot distinguish them.</p>
<p><strong>Agreement and dependency span long distances.</strong> In <em>“The cats that live in the house <strong>are</strong> noisy”</em>, the verb “are” must agree with “cats” (plural), not “house” (singular). Correctly resolving this requires knowing that “cats” appears before the relative clause, and “house” is inside it — a structural relationship determined entirely by position.</p>
<p><strong>Negation scope is positional.</strong> <em>“I <strong>never</strong> said she stole the money”</em> and <em>“I said she <strong>never</strong> stole the money”</em> contain the same words. The position of “never” determines the scope of negation — whether the speaker denies making the claim or denies the theft itself.</p>
<p><strong>Modifier attachment is determined by proximity.</strong> <em>“I photographed the man with a telescope”</em> is ambiguous in isolation. In context, positional proximity to either “man” or “photographed” is the primary cue for whether the telescope was used for photographing or was held by the man.</p>
<p>In short: token embeddings capture <em>what</em> each word means in isolation; positional encodings capture <em>where</em> each word sits, which encodes its grammatical role, its relationships to surrounding words, and its structural function in the sentence.</p>
</section>
<section id="the-permutation-equivariance-problem" class="level4">
<h4 class="anchored" data-anchor-id="the-permutation-equivariance-problem">The permutation equivariance problem</h4>
<p>Here is the technical issue: attention is <strong>permutation equivariant</strong>. If you reorder the input tokens, the output tokens reorder identically — the attention mechanism has no internal sense of sequence order. From the model’s perspective, “the cat sat on the mat” and “the mat sat on the cat” produce the same set of output vectors (just shuffled). Position is invisible.</p>
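A toy check of this equivariance (using Q = K = V = x for brevity): shuffling the inputs just shuffles the outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 6, 8
x = torch.randn(T, d)  # token embeddings with NO positional information

def self_attn(x):
    # toy self-attention with Q = K = V = x
    w = F.softmax(x @ x.T / d ** 0.5, dim=-1)
    return w @ x

perm = torch.randperm(T)
attend_then_shuffle = self_attn(x)[perm]  # attend first, reorder after
shuffle_then_attend = self_attn(x[perm])  # reorder first, attend after
print(torch.allclose(attend_then_shuffle, shuffle_then_attend, atol=1e-6))  # True
```

The two orders of operations agree exactly (up to float error): attention by itself cannot tell which token came first.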
<p>To fix this, we add <strong>positional encodings</strong> to the token embeddings before feeding them into the Transformer. The result: two otherwise identical tokens at different positions get different combined representations, making order visible to every downstream layer.</p>
<p>There are three main strategies:</p>
</section>
<section id="strategy-1-sinusoidal-encodings-original-paper" class="level4">
<h4 class="anchored" data-anchor-id="strategy-1-sinusoidal-encodings-original-paper">Strategy 1: Sinusoidal Encodings (Original Paper)</h4>
<p>The original Transformer paper uses fixed, non-learned positional encodings based on sine and cosine functions at different frequencies:</p>
<p><img src="https://latex.codecogs.com/png.latex?PE_%7B(pos,%5C,%202i)%7D%20=%20%5Csin%5C!%5Cleft(%5Cfrac%7Bpos%7D%7B10000%5E%7B2i/d%7D%7D%5Cright)"> <img src="https://latex.codecogs.com/png.latex?PE_%7B(pos,%5C,%202i+1)%7D%20=%20%5Ccos%5C!%5Cleft(%5Cfrac%7Bpos%7D%7B10000%5E%7B2i/d%7D%7D%5Cright)"></p>
<p><strong>Why low dimensions oscillate fast and high dimensions oscillate slowly</strong> comes directly from the formula. The denominator <img src="https://latex.codecogs.com/png.latex?10000%5E%7B2i/d%7D"> is the key — it grows exponentially with the dimension index <img src="https://latex.codecogs.com/png.latex?i">. Dividing by a larger number slows the wave down:</p>
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Dimension <img src="https://latex.codecogs.com/png.latex?i"></th>
<th>Denominator <img src="https://latex.codecogs.com/png.latex?10000%5E%7B2i/d%7D"></th>
<th>Wave period (positions for one full cycle)</th>
<th>What it encodes</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?i%20=%200"></td>
<td><img src="https://latex.codecogs.com/png.latex?1"></td>
<td>~6 positions</td>
<td>Very fine — distinguishes adjacent tokens</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?i%20=%20d/8"></td>
<td><img src="https://latex.codecogs.com/png.latex?10000%5E%7B1/4%7D%20=%2010"></td>
<td>~63 positions</td>
<td>Phrase-level distance</td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?i%20=%20d/4"></td>
<td><img src="https://latex.codecogs.com/png.latex?10000%5E%7B1/2%7D%20=%20100"></td>
<td>~630 positions</td>
<td>Sentence-level distance</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?i%20=%20d/2"></td>
<td><img src="https://latex.codecogs.com/png.latex?10%7B,%7D000"></td>
<td>~62,800 positions</td>
<td>Barely changes — encodes very coarse, document-level position</td>
</tr>
</tbody>
</table>
<p>Think of it as an odometer: the rightmost digit (low dimension) flips every meter, the leftmost digit (high dimension) barely moves over a typical journey. Each digit alone is ambiguous — the rightmost digit of “7” could be position 7, 17, 27, or 107. But all digits together uniquely identify every position.</p>
<p>This multi-scale design is intentional. <strong>Low dimensions</strong> give the model a fine-grained signal that changes every few positions — useful for detecting whether two tokens are immediate neighbors. <strong>High dimensions</strong> give a coarse signal that changes only over long distances — useful for detecting whether two tokens are in the same half of the document. Together, the full vector is unique for every position from 0 to the maximum.</p>
<p>The advantage of sinusoidal encodings is that they can generalize to sequence lengths longer than those seen during training — the functions extend naturally to any position.</p>
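The formulas above can be computed directly. This sketch follows the paper's definition, with the denominator <code>10000^(2i/d)</code> indexed over even dimensions:

```python
import torch

def sinusoidal_pe(max_len: int, d: int) -> torch.Tensor:
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    dim = torch.arange(0, d, 2, dtype=torch.float32)               # 0, 2, ..., d-2
    denom = 10000 ** (dim / d)               # grows exponentially with dimension
    pe = torch.zeros(max_len, d)
    pe[:, 0::2] = torch.sin(pos / denom)     # even dims: sine
    pe[:, 1::2] = torch.cos(pos / denom)     # odd dims: cosine
    return pe

pe = sinusoidal_pe(128, 64)
print(pe.shape)   # torch.Size([128, 64])
print(pe[:4, 0])  # low dimension: oscillates every position
print(pe[:4, -2]) # high dimension: nearly constant over 4 positions
```

Because the function is defined for any `pos`, the same code produces encodings for positions never seen in training.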
</section>
<section id="strategy-2-learned-absolute-encodings" class="level4">
<h4 class="anchored" data-anchor-id="strategy-2-learned-absolute-encodings">Strategy 2: Learned Absolute Encodings</h4>
<p>Instead of fixing the positional encoding by formula, we can make it a learned parameter — another <code>nn.Embedding</code> table of shape <code>max_seq_len × embed_dim</code>. Each position from 0 to <code>max_seq_len-1</code> gets its own learnable row, updated via backpropagation just like token embeddings.</p>
<p>This is what BERT and GPT use. The model learns whatever positional fingerprints work best for its task. The downside: positions beyond <code>max_seq_len</code> have no row in the table at all, so the model cannot meaningfully process sequences longer than those it was trained on.</p>
</section>
<section id="strategy-3-rotary-positional-encoding-rope" class="level4">
<h4 class="anchored" data-anchor-id="strategy-3-rotary-positional-encoding-rope">Strategy 3: Rotary Positional Encoding (RoPE)</h4>
<p>RoPE, introduced by Su et al.&nbsp;(2021) and used in LLaMA, Mistral, and GPT-NeoX, takes a fundamentally different approach: instead of <em>adding</em> a fixed vector to the embeddings, it <em>rotates</em> the query and key vectors by an angle proportional to their absolute position before computing the attention dot product.</p>
<p>The key property: when you rotate <img src="https://latex.codecogs.com/png.latex?Q"> at position <img src="https://latex.codecogs.com/png.latex?m"> and <img src="https://latex.codecogs.com/png.latex?K"> at position <img src="https://latex.codecogs.com/png.latex?n">, their dot product depends on the two content vectors and only the <em>relative distance</em> <img src="https://latex.codecogs.com/png.latex?m%20-%20n">; the absolute positions cancel out:</p>
<p><img src="https://latex.codecogs.com/png.latex?Q_m%20%5Ccdot%20K_n%20=%20g(x_m,%5C,%20x_n,%5C,%20m%20-%20n)"></p>
<p>This is highly desirable. Relative position — how far apart two tokens are — is often more informative than absolute position. Whether “cat” is token 5 or token 50 in the sentence matters less than how far it sits from the verb it modifies. Syntactic dependencies (subject → verb, adjective → noun) tend to hold over short distances regardless of where the sentence begins. RoPE bakes this directly into the attention computation at every layer, without requiring separate positional embedding vectors.</p>
<p>RoPE also generalizes better to longer sequences than the model was trained on, making it the dominant choice in modern open-source LLMs.</p>
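The relative-distance property is easy to verify in two dimensions (a single RoPE frequency pair, with a made-up base angle `theta` for illustration): two query/key pairs at different absolute positions but the same offset produce the same dot product.

```python
import torch

def rotate(v, pos, theta=0.3):
    # Rotate a 2-D vector by pos * theta: one RoPE frequency pair.
    # theta is an arbitrary illustrative base angle.
    angle = torch.tensor(pos * theta)
    c, s = torch.cos(angle), torch.sin(angle)
    R = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
    return R @ v

q = torch.tensor([1.0, 2.0])   # toy content vectors
k = torch.tensor([0.5, -1.0])

a = rotate(q, 7) @ rotate(k, 4)        # positions (7, 4): offset 3
b = rotate(q, 104) @ rotate(k, 101)    # positions (104, 101): offset 3
print(torch.allclose(a, b, atol=1e-5)) # True: only m - n matters
```

This works because R(mθ)ᵀR(nθ) = R((n−m)θ): composing the two rotations inside the dot product leaves only the offset.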
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Implementation Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>The code below uses learned absolute positional embeddings — the simplest approach and standard for BERT-style encoder models. The embedding layer adds the token embedding and positional embedding, normalizes with LayerNorm, and applies dropout.</p>
</div>
</div>
<div id="9a092f68-2275-4456-acfa-39f02a2ffe24" class="cell">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> dataclasses <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> dataclass</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> nn</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn.functional <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> F</span></code></pre></div>
</details>
</div>
<div id="5a6d2afd-e157-421f-bfe6-d4dc0f73af5f" class="cell">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="at" style="color: #657422;
background-color: null;
font-style: inherit;">@dataclass</span></span>
<span id="cb2-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> TransformerConfig:</span>
<span id="cb2-3">    vocab_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb2-4">    block_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span></span>
<span id="cb2-5">    hidden_dropout_prob: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.2</span></span>
<span id="cb2-6">    num_attention_heads: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span></span>
<span id="cb2-7">    num_hidden_layers: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span></span>
<span id="cb2-8">    embed_dim: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">768</span></span>
<span id="cb2-9">    num_classes: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span>
<span id="cb2-10">    layer_norm_eps: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1e-12</span></span>
<span id="cb2-11">    intermediate_sz: <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">int</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># set to 4 * embed_dim in __post_init__</span></span>
<span id="cb2-12"></span>
<span id="cb2-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> __post_init__(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>):</span>
<span id="cb2-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.intermediate_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb2-15">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.intermediate_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.embed_dim</span>
<span id="cb2-16"></span>
<span id="cb2-17">config <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> TransformerConfig()</span></code></pre></div>
</details>
</div>
<div id="51562c60-9883-4e62-b7c0-6dd0091ef203" class="cell">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> Embeddings(nn.Module):</span>
<span id="cb3-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb3-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb3-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.token_embedding <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Embedding(config.vocab_sz, config.embed_dim)</span>
<span id="cb3-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.position_embedding <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Embedding(config.block_sz, config.embed_dim)</span>
<span id="cb3-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim, eps<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>config.layer_norm_eps)</span>
<span id="cb3-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(p<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>)</span>
<span id="cb3-8"></span>
<span id="cb3-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb3-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:              B x T  (integer token IDs)</span></span>
<span id="cb3-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## token_emb:      B x T x embed_dim</span></span>
<span id="cb3-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## position_emb:   T x embed_dim  (broadcast over batch)</span></span>
<span id="cb3-13">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## output:         B x T x embed_dim</span></span>
<span id="cb3-14">        seq_len <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb3-15">        positions <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.arange(seq_len, device<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>x.device)</span>
<span id="cb3-16">        embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.token_embedding(x) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.position_embedding(positions)</span>
<span id="cb3-17">        embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm(embeddings)</span>
<span id="cb3-18">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout(embeddings)</span></code></pre></div>
</div>
<div id="aa000001-pe00-heat-map0-000000000001" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb4-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb4-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib</span>
<span id="cb4-4"></span>
<span id="cb4-5">matplotlib.rcParams[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"figure.dpi"</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">150</span></span>
<span id="cb4-6"></span>
<span id="cb4-7"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> sinusoidal_pe(seq_len, d_model):</span>
<span id="cb4-8">    pe  <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((seq_len, d_model))</span>
<span id="cb4-9">    pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.arange(seq_len)[:, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>]</span>
<span id="cb4-10">    i   <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.arange(d_model)[<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>, :]</span>
<span id="cb4-11">    div <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10000</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> d_model)</span>
<span id="cb4-12">    pe[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>::<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.sin(pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> div[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>::<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb4-13">    pe[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>::<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.cos(pos <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> div[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>::<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb4-14">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> pe</span>
<span id="cb4-15"></span>
<span id="cb4-16">pe <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sinusoidal_pe(seq_len<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span>, d_model<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">128</span>)</span>
<span id="cb4-17"></span>
<span id="cb4-18">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">3.5</span>))</span>
<span id="cb4-19">img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ax.imshow(pe.T, cmap<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"RdBu_r"</span>, aspect<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"auto"</span>, vmin<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, vmax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb4-20">ax.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Position in sequence"</span>, fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>)</span>
<span id="cb4-21">ax.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Embedding dimension"</span>, fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>)</span>
<span id="cb4-22">ax.set_title(</span>
<span id="cb4-23">    <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Sinusoidal Positional Encoding  —  each column is a unique position fingerprint"</span>,</span>
<span id="cb4-24">    fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">11</span>, pad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span></span>
<span id="cb4-25">)</span>
<span id="cb4-26">plt.colorbar(img, ax<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>ax, fraction<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.015</span>, pad<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.02</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"Encoding value"</span>)</span>
<span id="cb4-27">plt.tight_layout()</span>
<span id="cb4-28">plt.savefig(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"positional-encoding-heatmap.png"</span>, dpi<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">150</span>, bbox_inches<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"tight"</span>)</span>
<span id="cb4-29">plt.show()</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="Transformer-Architecture-Explained_files/figure-html/cell-5-output-1.png" class="lightbox" data-glightbox="description: .lightbox-desc-2" data-gallery="quarto-lightbox-gallery-2" title="Figure 2: Sinusoidal positional encoding. Each column is a unique fingerprint for one position. Low-frequency components (bottom rows) vary slowly — encoding coarse, sentence-level position. High-frequency components (top rows) vary quickly — encoding fine, word-level position. Together, every position from 0 to max_seq_len gets a unique vector."><img src="https://imaddabbura.github.io/posts/nlp/Transformer-Architecture-Explained_files/figure-html/cell-5-output-1.png" class="img-fluid figure-img" alt="Figure 2: Sinusoidal positional encoding. Each column is a unique fingerprint for one position. Low-frequency components (bottom rows) vary slowly — encoding coarse, sentence-level position. High-frequency components (top rows) vary quickly — encoding fine, word-level position. Together, every position from 0 to max_seq_len gets a unique vector."></a></p>
<figcaption><strong>Figure 2:</strong> Sinusoidal positional encoding. Each column is a unique fingerprint for one position. Low-frequency components (bottom rows) vary slowly — encoding coarse, sentence-level position. High-frequency components (top rows) vary quickly — encoding fine, word-level position. Together, every position from 0 to max_seq_len gets a unique vector.</figcaption>
</figure>
</div>
</div>
</div>
</section>
</section>
</section>
<section id="scaled-dot-product-attention" class="level2">
<h2 class="anchored" data-anchor-id="scaled-dot-product-attention">5. Scaled Dot-Product Attention</h2>
<p>Attention is the core computation that makes everything else in the Transformer work. Everything up to this point — embeddings, positional encodings — has been preprocessing. <em>This</em> is the operation that enables direct token-to-token communication.</p>
<p>This section builds it up step by step: from the Q/K/V projections, through the dot-product similarity, scaling, softmax, and masking. By the end, the library analogy from Section 2 will have a precise mathematical form.</p>
<section id="the-three-projections-query-key-value" class="level3">
<h3 class="anchored" data-anchor-id="the-three-projections-query-key-value">5.1 The Three Projections: Query, Key, Value</h3>
<p>Given an input sequence <img src="https://latex.codecogs.com/png.latex?x"> of shape <code>B × T × embed_dim</code>, we produce three separate linear projections:</p>
<p><img src="https://latex.codecogs.com/png.latex?Q%20=%20xW_Q,%20%5Cquad%20K%20=%20xW_K,%20%5Cquad%20V%20=%20xW_V"></p>
<p>Each weight matrix (<img src="https://latex.codecogs.com/png.latex?W_Q">, <img src="https://latex.codecogs.com/png.latex?W_K">, <img src="https://latex.codecogs.com/png.latex?W_V">) has shape <code>embed_dim × head_dim</code>. These are learned parameters — different projection matrices produce different “perspectives” on the same input.</p>
<p><strong>Why three separate projections instead of one?</strong> Because what a token <em>wants</em> (its query), what it <em>offers to match against</em> (its key), and what it <em>actually communicates</em> (its value) are three genuinely different things. Consider how a search engine works: your search query text (Q) is compared against the indexed keywords of a web page (K), but what you actually receive when you click is the full page content (V) — which may be organized completely differently from the index terms. Separating these three roles gives the model the flexibility to learn very different relationships for each.</p>
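<p>As a minimal sketch (hypothetical names, not the post’s exact module code), the three projections are just three independent linear layers applied to the same input:</p>

```python
import torch
import torch.nn as nn

embed_dim, head_dim = 768, 64

# Three independent learned projections of the same input tensor.
W_q = nn.Linear(embed_dim, head_dim, bias=False)
W_k = nn.Linear(embed_dim, head_dim, bias=False)
W_v = nn.Linear(embed_dim, head_dim, bias=False)

x = torch.randn(2, 10, embed_dim)   # B x T x embed_dim
q, k, v = W_q(x), W_k(x), W_v(x)    # each: B x T x head_dim
print(q.shape, k.shape, v.shape)
```

Because the three weight matrices are learned separately, the same token can end up with very different query, key, and value vectors.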
</section>
<section id="computing-attention-weights-a-worked-example" class="level3">
<h3 class="anchored" data-anchor-id="computing-attention-weights-a-worked-example">5.2 Computing Attention Weights: A Worked Example</h3>
<p>With Q, K, V in hand, the attention weights are computed as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bweights%7D%20=%20%5Ctext%7Bsoftmax%7D%5C!%5Cleft(%5Cfrac%7BQK%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D%5Cright)"></p>
<p>Let’s trace through this with a concrete 3-token example. Suppose our sequence is [“the”, “cat”, “sat”], with <code>head_dim = 4</code>. After the Q and K projections, imagine we have:</p>
<p><img src="https://latex.codecogs.com/png.latex?Q%20=%20%5Cbegin%7Bbmatrix%7D%201%20&amp;%201%20&amp;%200%20&amp;%200%20%5C%5C%200%20&amp;%200%20&amp;%201%20&amp;%201%20%5C%5C%201%20&amp;%201%20&amp;%201%20&amp;%200%20%5Cend%7Bbmatrix%7D,%5Cquad%20K%20=%20%5Cbegin%7Bbmatrix%7D%201%20&amp;%200%20&amp;%200%20&amp;%201%20%5C%5C%200%20&amp;%201%20&amp;%201%20&amp;%200%20%5C%5C%201%20&amp;%201%20&amp;%200%20&amp;%200%20%5Cend%7Bbmatrix%7D"></p>
<p><strong>Step 1 — Dot products <img src="https://latex.codecogs.com/png.latex?QK%5ET"></strong> (shape <code>3 × 3</code>): Every token’s query is dotted with every token’s key. The <img src="https://latex.codecogs.com/png.latex?(i,j)"> entry measures how much token <img src="https://latex.codecogs.com/png.latex?i"> “wants” to attend to token <img src="https://latex.codecogs.com/png.latex?j">.</p>
<p><img src="https://latex.codecogs.com/png.latex?QK%5ET%20=%20%5Cbegin%7Bbmatrix%7D%201%20&amp;%201%20&amp;%202%20%5C%5C%201%20&amp;%201%20&amp;%200%20%5C%5C%201%20&amp;%202%20&amp;%202%20%5Cend%7Bbmatrix%7D"></p>
<p>Why dot products? Geometrically, the dot product of two vectors is large when they point in similar directions (small angle) and small when they are orthogonal. If a query and key are aligned — the token is “looking for” exactly what the other token “offers” — the dot product is high, and that token will receive a large attention weight.</p>
<p><strong>Step 2 — Scale by <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7B%5Csqrt%7Bd_k%7D%7D%20=%20%5Cfrac%7B1%7D%7B2%7D">:</strong></p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7BQK%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D%20=%20%5Cbegin%7Bbmatrix%7D%200.5%20&amp;%200.5%20&amp;%201.0%20%5C%5C%200.5%20&amp;%200.5%20&amp;%200.0%20%5C%5C%200.5%20&amp;%201.0%20&amp;%201.0%20%5Cend%7Bbmatrix%7D"></p>
<p><strong>Step 3 — Softmax row-wise</strong> (each row sums to 1):</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bweights%7D%20=%20%5Cbegin%7Bbmatrix%7D%200.27%20&amp;%200.27%20&amp;%200.45%20%5C%5C%200.38%20&amp;%200.38%20&amp;%200.23%20%5C%5C%200.23%20&amp;%200.38%20&amp;%200.38%20%5Cend%7Bbmatrix%7D"></p>
<p>Row 1 (token “the”): attends most strongly to “sat” (0.45). Row 2 (“cat”): attends more to “the” and itself (0.38 each) than to “sat” (0.23). Row 3 (“sat”): attends most to “cat” and itself (0.38 each).</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/attention-weights-matrix.png" class="lightbox" data-glightbox="description: .lightbox-desc-3" data-gallery="quarto-lightbox-gallery-3" title="Figure 3: Attention weights for the [‘the’, ‘cat’, ‘sat’] example. Rows are query tokens; columns are keys. Each row is a probability distribution — how much each token attends to every other token."><img src="https://imaddabbura.github.io/posts/nlp/images/attention-weights-matrix.png" class="quarto-figure quarto-figure-center figure-img" width="400" height="400" alt="Figure 3: Attention weights for the [‘the’, ‘cat’, ‘sat’] example. Rows are query tokens; columns are keys. Each row is a probability distribution — how much each token attends to every other token."></a></p>
</figure>
</div>
<figcaption><strong>Figure 3:</strong> Attention weights for the [‘the’, ‘cat’, ‘sat’] example. Rows are query tokens; columns are keys. Each row is a probability distribution — how much each token attends to every other token.</figcaption>
</figure>
</div>
<p><strong>Step 4 — Multiply by V</strong> (shape <code>3 × head_dim</code>): The output for each token is a weighted combination of all value vectors, with weights from the softmax step. Token “the” will receive a mix of all three value vectors, weighted 27%/27%/46%. The output is a <strong>contextualized representation</strong> — the same token in a different sentence would produce different weights and therefore a different output vector.</p>
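<p>Independent of the specific numbers in the example, the four steps can be sketched in a few lines of numpy (a framework-agnostic sketch, not the post’s PyTorch implementation):</p>

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # Steps 1-2: dot products, scaled
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # Step 3: row-wise softmax
    return weights @ V, weights                      # Step 4: weighted mix of values

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(3, 4)) for _ in range(3))
out, w = scaled_dot_product_attention(Q, K, V)
print(w.round(2))   # each row is a probability distribution over the 3 tokens
print(out.shape)    # (3, 4): one contextualized vector per token
```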
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Departure from Static Embeddings
</div>
</div>
<div class="callout-body-container callout-body">
<p>Notice what just happened: the token “the” — which starts with a fixed embedding vector identical in every sentence — now has a representation shaped by the presence of “cat” and “sat.” Run the same token through a different sentence (“the table broke”), and it emerges from attention with a different output vector.</p>
<p>This is the fundamental departure from static word embeddings like word2vec or GloVe: those give every token a single, context-free vector that never changes. A Transformer gives every token a <em>contextual</em> representation — numerically different depending on what surrounds it. “bank” in “river bank” and “bank” in “bank account” start with the same embedding but diverge after attention. This is why Transformer-based representations are so dramatically better at tasks requiring word sense disambiguation, coreference resolution, and syntactic parsing.</p>
</div>
</div>
</section>
<section id="why-scale-by-sqrtd_k" class="level3">
<h3 class="anchored" data-anchor-id="why-scale-by-sqrtd_k">5.3 Why Scale by <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7Bd_k%7D">?</h3>
<p>This is one of those design choices that looks arbitrary until you understand the numerical reason behind it.</p>
<p>Q and K are both initialized as approximately unit-variance random vectors. When you compute their dot product across <code>d_k</code> dimensions, the result has <strong>variance equal to <img src="https://latex.codecogs.com/png.latex?d_k"></strong> (sum of <code>d_k</code> independent unit-variance terms). For a typical <code>head_dim</code> of 64, the raw dot products have standard deviation 8. For <code>head_dim = 768</code>, the standard deviation is about 28.</p>
<p>Large-magnitude inputs to softmax cause a saturation problem. When one logit is much larger than the others, softmax approaches a one-hot distribution — almost all weight goes to one token, and gradients for every other position become negligibly small. The model can only learn from the one token it attends to, and ignores all the rest.</p>
<p>Dividing by <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7Bd_k%7D"> rescales the dot products back to approximately unit variance, regardless of <code>head_dim</code>. Softmax then produces a diffuse distribution — not too concentrated, not too uniform — and gradients flow to all positions during training.</p>
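<p>The variance claim is easy to verify numerically (a quick sketch):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
n = 100_000

# Unit-variance random queries and keys.
q = rng.normal(size=(n, d_k))
k = rng.normal(size=(n, d_k))

raw = (q * k).sum(axis=1)        # raw dot products across d_k dimensions
scaled = raw / np.sqrt(d_k)      # after the 1/sqrt(d_k) rescaling

print(raw.var())     # close to d_k = 64
print(scaled.var())  # close to 1
```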
<blockquote class="blockquote">
<p><em>Without scaling, attention becomes a dictatorship: one token captures all the weight and the rest are ignored. Scaling preserves the democracy: every token can contribute to the output.</em></p>
</blockquote>
</section>
<section id="softmax-competition-not-independence" class="level3">
<h3 class="anchored" data-anchor-id="softmax-competition-not-independence">5.4 Softmax: Competition, Not Independence</h3>
<p>Why use softmax and not sigmoid (or any other normalization)?</p>
<p>Sigmoid applied to each attention logit independently would allow a token to “attend highly to everyone” at the same time, with no trade-off. But attention should be selective: attending more to one token means attending less to others.</p>
<p>Softmax is a <strong>competitive normalization</strong> — its outputs sum to 1, so the weights form a probability distribution over the context window. Increasing attention to one token necessarily decreases attention to all others. This forces the model to make decisions about what is relevant rather than attending indiscriminately to everything.</p>
<p><strong>The exponential creates sparsity, not just competition.</strong> Softmax uses <img src="https://latex.codecogs.com/png.latex?e%5Ex">, not a simple linear normalization like dividing by the sum. The exponential amplifies differences: if one logit is 2 points higher than another, it receives <img src="https://latex.codecogs.com/png.latex?e%5E2%20%5Capprox%207%5Ctimes"> more weight — not just 2× more. In practice this means attention patterns are often <em>peaky</em>: a small number of tokens receive the vast majority of the weight, and the rest are nearly zero. This emergent sparsity is what makes attention heads interpretable — a head that attends sharply to the syntactic subject has learned a crisp, readable pattern, not a diffuse smear. It also means that a single highly-relevant token can dominate the output almost entirely, which is the mechanism behind induction heads and other sharp attention circuits found in mechanistic interpretability research.</p>
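<p>Both points can be demonstrated in a few lines: softmax’s competitive normalization versus sigmoid’s independence, and the roughly 7× amplification of a 2-point logit gap (a small sketch):</p>

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())      # subtract max for numerical stability
    return e / e.sum()

logits = np.array([3.0, 1.0, 1.0])

w = softmax(logits)
print(w.round(3))      # peaky: the first token dominates
print(w[0] / w[1])     # e^2 ~ 7.39: a 2-point gap gives ~7x the weight

# Sigmoid, applied independently, lets every logit be "high" at once:
sig = 1 / (1 + np.exp(-logits))
print(sig.sum())       # not constrained to sum to 1 -- no competition
```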
</section>
<section id="causal-masking-decoder-only" class="level3">
<h3 class="anchored" data-anchor-id="causal-masking-decoder-only">5.5 Causal Masking (Decoder Only)</h3>
<p>In a language model, the task is to predict the next token from all previous tokens. If the model can see token <img src="https://latex.codecogs.com/png.latex?t+1"> while predicting token <img src="https://latex.codecogs.com/png.latex?t">, that is data leakage — the model would just copy the future token rather than learning to predict it.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/causal-mask-matrix.png" class="lightbox" data-glightbox="description: .lightbox-desc-4" data-gallery="quarto-lightbox-gallery-4" title="Figure 4: Causal mask for a 5-token sequence. Green cells (✓) are positions the query token is allowed to attend to; red cells (−∞) are masked out and become 0 after softmax. Each row is a query token; each column is a key token."><img src="https://imaddabbura.github.io/posts/nlp/images/causal-mask-matrix.png" class="img-fluid quarto-figure quarto-figure-center figure-img" alt="Figure 4: Causal mask for a 5-token sequence. Green cells (✓) are positions the query token is allowed to attend to; red cells (−∞) are masked out and become 0 after softmax. Each row is a query token; each column is a key token."></a></p>
</figure>
</div>
<figcaption><strong>Figure 4:</strong> Causal mask for a 5-token sequence. Green cells (✓) are positions the query token is allowed to attend to; red cells (−∞) are masked out and become 0 after softmax. Each row is a query token; each column is a key token.</figcaption>
</figure>
</div>
<p>After softmax, <img src="https://latex.codecogs.com/png.latex?-%5Cinfty"> becomes exactly 0. Token 1 can only attend to itself. Token 3 can attend to tokens 1, 2, and 3 but not 4. The mask enforces a strict information asymmetry: <strong>you can read anything in the past, but nothing in the future</strong>.</p>
<p>This is implemented by registering a lower-triangular buffer in the <code>AttentionHead</code> and calling <code>masked_fill</code> before softmax.</p>
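<p>That implementation pattern looks roughly like this (a sketch of the buffer-plus-<code>masked_fill</code> idiom, not the post’s exact <code>AttentionHead</code>):</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalScores(nn.Module):
    """Applies a causal mask to raw attention scores before softmax."""
    def __init__(self, block_sz):
        super().__init__()
        # Lower-triangular buffer: moves with the module (device/dtype)
        # but is not a trainable parameter.
        self.register_buffer("mask", torch.tril(torch.ones(block_sz, block_sz)))

    def forward(self, scores):                 # scores: B x T x T
        T = scores.shape[-1]
        # Future positions (upper triangle) get -inf, which softmax maps to 0.
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        return F.softmax(scores, dim=-1)

m = CausalScores(block_sz=8)
weights = m(torch.randn(1, 5, 5))
print(weights[0].round(decimals=2))   # strictly lower-triangular weights
```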
</section>
<section id="self-attention-vs.-cross-attention" class="level3">
<h3 class="anchored" data-anchor-id="self-attention-vs.-cross-attention">5.6 Self-Attention vs.&nbsp;Cross-Attention</h3>
<p><strong>Self-attention</strong>: Q, K, and V all come from the same input sequence <img src="https://latex.codecogs.com/png.latex?x">. Every token attends to every other token within the same sequence. This is what the encoder uses (bidirectional) and the decoder uses for its first sublayer (causal).</p>
<p><strong>Cross-attention</strong>: Q comes from one sequence (the decoder’s hidden state), while K and V come from a different sequence (the encoder’s output). The decoder “reads” the encoder’s representation of the source sequence. This is the mechanism that connects the two halves of an encoder-decoder model.</p>
<p>The generalization is worth stating explicitly: <strong>any two sequences can be related through cross-attention</strong>, simply by using one as the source of Q and the other as the source of K and V. This is the same operation that connects modalities in vision-language models (text queries attend to image patch keys/values), that lets perceiver architectures compress long inputs (a small set of learned query vectors attends to a large input), and that underlies virtually all multi-modal conditioning. Cross-attention is not a feature of encoder-decoder models — it is a universal conditioning primitive.</p>
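<p>A shape-level sketch makes the asymmetry concrete: only Q carries the decoder’s length, and the output has one row per query regardless of how long the encoder’s sequence is (numpy sketch; all names hypothetical):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
T_dec, T_enc = 3, 50               # decoder queries vs. encoder keys/values

Q = rng.normal(size=(T_dec, d))    # from the decoder's hidden state
K = rng.normal(size=(T_enc, d))    # from the encoder's output
V = rng.normal(size=(T_enc, d))    # also from the encoder's output

scores = Q @ K.T / np.sqrt(d)                        # T_dec x T_enc
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)          # each decoder token's
out = weights @ V                                    # read over the source
print(out.shape)   # (3, 8): one output vector per decoder query
```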
<p>The full attention equation:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7BAttention%7D(Q,%20K,%20V)%20=%20%5Ctext%7Bsoftmax%7D%5C!%5Cleft(%5Cfrac%7BQK%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D%5Cright)V"></p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/scaled-dot-product-attention.png" class="lightbox" data-glightbox="description: .lightbox-desc-5" data-gallery="quarto-lightbox-gallery-5" title="Figure 5: Scaled Dot-Product Attention (source)"><img src="https://imaddabbura.github.io/posts/nlp/images/scaled-dot-product-attention.png" class="quarto-figure quarto-figure-center figure-img" height="400" alt="Figure 5: Scaled Dot-Product Attention (source)"></a></p>
<figcaption><strong>Figure 5:</strong> Scaled Dot-Product Attention (<a href="https://arxiv.org/abs/1706.03762">source</a>)</figcaption>
</figure>
</div>
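<p>The equation above can be checked end-to-end in a few lines. The sketch below (NumPy, single unbatched head, illustrative shapes only) deliberately uses different query and key lengths to emphasize that the same operation serves both self- and cross-attention:</p>

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for a single (unbatched) head."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # T_q x T_k similarity matrix
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)            # each row sums to 1
    return w @ V                                  # T_q x d_v

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 8))   # 3 query positions (e.g. decoder states)
K = rng.normal(size=(5, 8))   # 5 key positions (e.g. encoder outputs)
V = rng.normal(size=(5, 8))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (3, 8): one output per query, whatever the key count
```

<p>Self-attention is the special case where <code>Q</code>, <code>K</code>, and <code>V</code> come from the same sequence; cross-attention simply draws <code>Q</code> from one sequence and <code>K</code>, <code>V</code> from another.</p>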
</section>
<section id="the-quadratic-cost-attentions-fundamental-bottleneck" class="level3">
<h3 class="anchored" data-anchor-id="the-quadratic-cost-attentions-fundamental-bottleneck">5.7 The Quadratic Cost: Attention’s Fundamental Bottleneck</h3>
<p>Computing attention requires forming the full <img src="https://latex.codecogs.com/png.latex?T%20%5Ctimes%20T"> weight matrix — every token’s query dotted against every token’s key. This is <img src="https://latex.codecogs.com/png.latex?O(T%5E2%20%5Ccdot%20d_k)"> time and <img src="https://latex.codecogs.com/png.latex?O(T%5E2)"> memory. For most sentences this is fine. For long documents, it becomes the dominant constraint:</p>
<table class="table">
<thead>
<tr class="header">
<th>Sequence length</th>
<th>Attention matrix</th>
<th>Memory (fp16, 1 head)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>512 tokens</td>
<td>512 × 512 = 262K</td>
<td>~0.5 MB</td>
</tr>
<tr class="even">
<td>4,096 tokens</td>
<td>4K × 4K = 16.8M</td>
<td>~32 MB</td>
</tr>
<tr class="odd">
<td>128K tokens</td>
<td>128K × 128K = 16.4B</td>
<td>~31 GB</td>
</tr>
</tbody>
</table>
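<p>The memory figures in the table are straightforward to reproduce. A quick sanity check (fp16, one head, counting only the attention matrix itself):</p>

```python
BYTES_FP16 = 2  # bytes per fp16 entry

def attn_matrix_mib(T):
    """Memory for one T x T attention matrix, in MiB."""
    return T * T * BYTES_FP16 / 2**20

for T in (512, 4_096, 128_000):
    print(f"{T:>7} tokens: {T * T:>14,} entries, {attn_matrix_mib(T):>10,.1f} MiB")
```

<p>These are per-head numbers; multiplying by the number of heads (and, during training, by the layers whose activations are kept for backpropagation) makes the long-context rows far worse in practice.</p>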
<p>This quadratic growth is why early BERT was capped at 512 tokens, why getting GPT-3 to handle long documents required tricks, and why an entire subfield of <strong>efficient attention</strong> exists — sliding-window attention (Longformer), linear attention, sparse attention (BigBird), and state-space models like Mamba are all attempts to approximate or restructure the <img src="https://latex.codecogs.com/png.latex?T%20%5Ctimes%20T"> computation to grow linearly with sequence length.</p>
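<p>To see why restructuring helps, it is enough to count query-key pairs. A rough comparison for sliding-window attention (illustrative window of 256, using an upper bound on scored pairs per query):</p>

```python
def full_attention_pairs(T):
    return T * T                      # every query scores every key

def sliding_window_pairs(T, w):
    return T * (2 * w + 1)            # upper bound: at most 2w+1 keys per query

for T in (512, 4_096, 128_000):
    full, windowed = full_attention_pairs(T), sliding_window_pairs(T, w=256)
    print(f"T={T:>7}: full {full:>14,}  window {windowed:>12,}  (~{full / windowed:,.0f}x reduction)")
```

<p>The windowed count grows linearly in <code>T</code>, which is exactly the point of these methods; what they trade away is direct interaction between distant positions, which designs like Longformer recover through a handful of global tokens and layer stacking.</p>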
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
FlashAttention Changes the Hardware Utilization, Not the Complexity
</div>
</div>
<div class="callout-body-container callout-body">
<p>FlashAttention (Dao et al., 2022) is often described as “making attention faster.” What it actually does: reorders the computation to tile through the attention matrix in blocks that fit in GPU SRAM (fast memory), avoiding slow round-trips to HBM (GPU global memory). The FLOPs are identical to standard attention; the memory bandwidth cost drops dramatically — 2–4× wall-clock speedup with numerically identical outputs. It also reduces peak memory from <img src="https://latex.codecogs.com/png.latex?O(T%5E2)"> to <img src="https://latex.codecogs.com/png.latex?O(T)"> by never materializing the full attention matrix. This is why FlashAttention is the standard in every modern training stack, but it does not fix the fundamental quadratic scaling problem for very long contexts.</p>
</div>
</div>
<div id="0540e571-f6de-49d2-9535-e91755a7a78f" class="cell">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> AttentionHead(nn.Module):</span>
<span id="cb5-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config, head_dim, is_decoder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb5-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb5-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, head_dim, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb5-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, head_dim, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb5-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, head_dim, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb5-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.is_decoder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> is_decoder</span>
<span id="cb5-8">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.is_decoder:</span>
<span id="cb5-9">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.register_buffer(</span>
<span id="cb5-10">                <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"mask"</span>, torch.tril(torch.ones(config.block_sz, config.block_sz))</span>
<span id="cb5-11">            )</span>
<span id="cb5-12"></span>
<span id="cb5-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, query, key, value):</span>
<span id="cb5-14">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## query: B x T_q x embed_dim  (source of queries)</span></span>
<span id="cb5-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## key:   B x T_k x embed_dim  (source of keys)</span></span>
<span id="cb5-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## value: B x T_k x embed_dim  (source of values)</span></span>
<span id="cb5-17">        q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.q(query)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_q x head_dim</span></span>
<span id="cb5-18">        k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.k(key)    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_k x head_dim</span></span>
<span id="cb5-19">        v <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.v(value)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_k x head_dim</span></span>
<span id="cb5-20">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## w: B x T_q x T_k  — pairwise similarity between every query and every key</span></span>
<span id="cb5-21">        w <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> q <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> k.transpose(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (k.shape[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)</span>
<span id="cb5-22">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.is_decoder:</span>
<span id="cb5-23">            T <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> w.shape[<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]</span>
<span id="cb5-24">            w <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> w.masked_fill(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.mask[:T, :T] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">float</span>(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"inf"</span>))</span>
<span id="cb5-25">        w <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F.softmax(w, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb5-26">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## output: B x T_q x head_dim</span></span>
<span id="cb5-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> w <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> v</span></code></pre></div>
</div>
</section>
</section>
<section id="multi-head-attention" class="level2">
<h2 class="anchored" data-anchor-id="multi-head-attention">6. Multi-Head Attention</h2>
<section id="why-multiple-heads" class="level3">
<h3 class="anchored" data-anchor-id="why-multiple-heads">6.1 Why Multiple Heads?</h3>
<p>A single attention head learns one type of relationship between tokens. For example, it might learn to focus on the syntactic subject of a sentence whenever any token is processed — a subject-finding head. But language has many simultaneous relationship types that are all relevant at once:</p>
<ul>
<li><em>Syntactic</em>: subject-verb agreement, noun-adjective agreement</li>
<li><em>Semantic</em>: coreference (“it” → “the trophy”), negation scope</li>
<li><em>Structural</em>: attending to nearby tokens for local context</li>
<li><em>Task-specific</em>: attending to sentiment-bearing words for classification</li>
</ul>
<p>Multiple heads allow the model to learn all of these in parallel. Each head has its own independent weight matrices <img src="https://latex.codecogs.com/png.latex?W_Q%5Eh">, <img src="https://latex.codecogs.com/png.latex?W_K%5Eh">, <img src="https://latex.codecogs.com/png.latex?W_V%5Eh"> that project the same input <img src="https://latex.codecogs.com/png.latex?x"> into a different lower-dimensional subspace:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Ctext%7Bhead%5C_dim%7D%20=%20%5Cfrac%7B%5Ctext%7Bembed%5C_dim%7D%7D%7B%5Ctext%7Bnum%5C_heads%7D%7D"></p>
<p>This subspace separation is the mechanism that makes specialization both possible and stable. Head 3’s attention weights are determined by <img src="https://latex.codecogs.com/png.latex?W_Q%5E3%20%5Ccdot%20W_K%5E3"> inner products, which have nothing to do with what <img src="https://latex.codecogs.com/png.latex?W_Q%5E7%20%5Ccdot%20W_K%5E7"> computes for head 7. Because each head projects the input through its own independently learned matrices into a separate low-dimensional subspace (independent, though not necessarily orthogonal), heads need not interfere with each other: a coreference head and a subject-finding head can coexist without one corrupting the other.</p>
<p>Empirical findings from BERTology (Clark et al., 2019) confirm that this specialization emerges after training: some heads consistently track syntactic dependencies across the entire network; others attend primarily to adjacent tokens, effectively implementing a local sliding window; some heads in BERT-style models attend heavily to the <code>[SEP]</code> token — a kind of “no-op” head that routes excess attention somewhere harmless when no strong relationship exists.</p>
<p>Importantly, this specialization is <strong>not designed in</strong>. It arises entirely from the training signal. The architecture only provides the capacity for parallel, independent subspace projections; training discovers what each subspace should track.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/multi-head-attention.png" class="lightbox" data-glightbox="description: .lightbox-desc-6" data-gallery="quarto-lightbox-gallery-6" title="Figure 6: Multi-Head Attention with several attention layers running in parallel (source)"><img src="https://imaddabbura.github.io/posts/nlp/images/multi-head-attention.png" class="quarto-figure quarto-figure-center figure-img" height="400" alt="Figure 6: Multi-Head Attention with several attention layers running in parallel (source)"></a></p>
<figcaption><strong>Figure 6:</strong> Multi-Head Attention with several attention layers running in parallel (<a href="https://arxiv.org/abs/1706.03762">source</a>)</figcaption>
</figure>
</div>
</section>
<section id="implementation-parallel-heads-final-projection" class="level3">
<h3 class="anchored" data-anchor-id="implementation-parallel-heads-final-projection">6.2 Implementation: Parallel Heads, Final Projection</h3>
<p>Each head produces an output of shape <code>B × T × head_dim</code>. All heads run entirely in parallel — there is <strong>no communication between heads</strong> during the forward pass. The outputs of all heads are concatenated along the last dimension: <code>num_heads × head_dim = embed_dim</code>. The concatenated tensor then passes through a final linear projection <img src="https://latex.codecogs.com/png.latex?W_O"> of shape <code>embed_dim × embed_dim</code>.</p>
<p><strong>Why the final projection?</strong> The heads operated in isolation — each found something different in its own subspace. The <img src="https://latex.codecogs.com/png.latex?W_O"> projection is the first opportunity for the model to mix information <em>across</em> heads: to combine what the coreference head found with what the subject-finding head found into a single coherent output vector. But <img src="https://latex.codecogs.com/png.latex?W_O"> does more than concatenate — it <em>filters and compresses</em>. The 12 concatenated head outputs may contain redundant information, conflicting signals, or noise from heads that found nothing relevant. <img src="https://latex.codecogs.com/png.latex?W_O"> is a learned projection that selects which cross-head combinations to amplify and which to suppress. Think of it as the editor who takes 12 reporters’ raw notes and synthesises them into a single coherent paragraph — not every detail makes it through.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Pedagogical vs.&nbsp;Efficient Implementation
</div>
</div>
<div class="callout-body-container callout-body">
<p>The implementation below uses a Python loop over heads for clarity. In practice, all heads are computed in a single batched matrix multiply by reshaping the input to <code>B × T × num_heads × head_dim</code> and transposing — this is the approach used in production (and in <code>torch.nn.MultiheadAttention</code>). The pedagogical loop is equivalent but slower.</p>
</div>
</div>
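<p>For concreteness, here is a minimal NumPy sketch of that batched approach: one reshape-and-transpose so all heads share a single batched matrix multiply. Shapes and names are illustrative, self-attention only, with no mask, biases, or dropout, and the per-head projections folded into single <code>embed_dim × embed_dim</code> matrices:</p>

```python
import numpy as np

def multi_head_attention_batched(x, Wq, Wk, Wv, Wo, num_heads):
    """All heads computed in one batched matmul (self-attention, no mask)."""
    B, T, E = x.shape
    D = E // num_heads                                 # head_dim
    def project(W):                                    # B x T x E -> B x H x T x D
        return (x @ W).reshape(B, T, num_heads, D).transpose(0, 2, 1, 3)
    q, k, v = project(Wq), project(Wk), project(Wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)  # B x H x T x T
    scores -= scores.max(axis=-1, keepdims=True)       # stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    out = w @ v                                        # B x H x T x D
    out = out.transpose(0, 2, 1, 3).reshape(B, T, E)   # concatenate heads
    return out @ Wo                                    # mix across heads

rng = np.random.default_rng(0)
B, T, E, H = 2, 5, 16, 4
x = rng.normal(size=(B, T, E))
Wq, Wk, Wv, Wo = (rng.normal(size=(E, E)) for _ in range(4))
print(multi_head_attention_batched(x, Wq, Wk, Wv, Wo, H).shape)  # (2, 5, 16)
```

<p>The loop-based implementation below computes the same function head by head; production frameworks go further and fuse the Q, K, and V projections into a single matrix.</p>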
<p>However, there is still a problem: multi-head attention is a weighted <em>averaging</em> operation — it is linear in V. Stacking attention layers with nothing in between gives the network no way to transform an individual token’s representation non-linearly. The network needs a pointwise nonlinearity. That is the feed-forward network’s job.</p>
<div id="cf91afcc-af53-4bd9-84c6-bd3dd69bd49f" class="cell">
<div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> MultiHeadAttention(nn.Module):</span>
<span id="cb6-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config, is_decoder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb6-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb6-4">        head_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> config.embed_dim <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">//</span> config.num_attention_heads</span>
<span id="cb6-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.heads <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.ModuleList(</span>
<span id="cb6-6">            [</span>
<span id="cb6-7">                AttentionHead(config, head_dim, is_decoder)</span>
<span id="cb6-8">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.num_attention_heads)</span>
<span id="cb6-9">            ]</span>
<span id="cb6-10">        )</span>
<span id="cb6-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Final projection mixes information across heads: embed_dim -&gt; embed_dim</span></span>
<span id="cb6-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.output_proj <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, config.embed_dim)</span>
<span id="cb6-13"></span>
<span id="cb6-14">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, query, key, value):</span>
<span id="cb6-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## query: B x T_q x embed_dim</span></span>
<span id="cb6-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## key:   B x T_k x embed_dim</span></span>
<span id="cb6-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## value: B x T_k x embed_dim</span></span>
<span id="cb6-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Each head produces B x T_q x head_dim; cat gives B x T_q x embed_dim</span></span>
<span id="cb6-19">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.cat([head(query, key, value) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> head <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.heads], dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb6-20">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.output_proj(x)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_q x embed_dim</span></span></code></pre></div>
</div>
</section>
</section>
<section id="feed-forward-network" class="level2">
<h2 class="anchored" data-anchor-id="feed-forward-network">7. Feed-Forward Network</h2>
<section id="why-is-it-needed" class="level3">
<h3 class="anchored" data-anchor-id="why-is-it-needed">7.1 Why Is It Needed?</h3>
<p>Attention is a weighted averaging operation. It is <strong>linear in V</strong>: the output for each position is a linear combination of value vectors, where the combination weights come from the attention scores. The softmax makes those weights depend non-linearly on the input, but nothing ever transforms an individual token’s representation non-linearly. Stacking attention layers with no pointwise nonlinearity in between composes linear maps and data-dependent averages, which sharply limits what additional depth can express.</p>
<p>This is the same reason we use activation functions between layers in any neural network: without them, depth buys us nothing.</p>
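<p>The collapse argument for purely linear layers can be checked directly, along with how a single nonlinearity breaks it (a NumPy sketch, using the tanh approximation of GELU):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(8, 8)), rng.normal(size=(8, 8))
x = rng.normal(size=(3, 8))

# Two stacked linear layers are exactly equivalent to one linear layer:
assert np.allclose((x @ W1) @ W2, x @ (W1 @ W2))

# A nonlinearity in between breaks the collapse (tanh approximation of GELU):
gelu = lambda z: 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))
assert not np.allclose(gelu(x @ W1) @ W2, x @ (W1 @ W2))
print("linear stack collapses; GELU stack does not")
```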
<p>The feed-forward network (FFN) adds the essential nonlinearity. It processes each token’s representation <strong>independently</strong> after the attention layer. There is no mixing of tokens in the FFN — that is attention’s job. The clean separation of concerns is intentional:</p>
<ul>
<li><strong>Attention</strong>: mixes information across positions (who talks to whom)</li>
<li><strong>FFN</strong>: transforms each position’s representation non-linearly (what to say)</li>
</ul>
<p><strong>The FFN as a knowledge store.</strong> Research by Geva et al.&nbsp;(2021) provides a compelling interpretation: FFN layers function as associative memories. The first linear layer acts as a set of keys that pattern-match against the input; the second linear layer acts as the corresponding values that are retrieved and output. Most of a Transformer’s factual knowledge — the associations between entities, relations, and attributes — is hypothesized to live in FFN weights, not in the attention matrices.</p>
<blockquote class="blockquote">
<p><em>Attention is the routing system. The FFN is the knowledge store.</em></p>
</blockquote>
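<p>A toy version of this key-value reading (NumPy, with a ReLU gate standing in for the model’s actual activation, and randomly generated “memories” purely for illustration):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, n_memories = 8, 32

# Rows of the first FFN matrix act as keys; rows of the second as values.
keys = rng.normal(size=(n_memories, embed_dim))
keys /= np.linalg.norm(keys, axis=1, keepdims=True)
values = rng.normal(size=(n_memories, embed_dim))

def ffn_as_memory(x):
    gate = np.maximum(x @ keys.T, 0.0)  # how strongly x matches each key
    return gate @ values                # weighted sum of retrieved values

x = 2.0 * keys[3]                       # an input aligned with key 3
gate = np.maximum(x @ keys.T, 0.0)
print(gate.argmax())                    # 3: memory 3 fires hardest
```

<p>An input that pattern-matches a key retrieves (mostly) that key’s value; in a trained model the keys respond to interpretable input patterns and the values push the residual stream toward associated output tokens.</p>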
</section>
<section id="architecture-details" class="level3">
<h3 class="anchored" data-anchor-id="architecture-details">7.2 Architecture Details</h3>
<p>The FFN has a characteristic structure: expand, activate, contract.</p>
<ol type="1">
<li><strong>Expand</strong>: Linear projection from <code>embed_dim</code> → <code>4 × embed_dim</code>. The 4x factor is empirical — found to work well across a range of model sizes. The expanded intermediate dimension is where most of the model’s representational capacity lives, and it is the dimension that is typically scaled up when making larger models.</li>
<li><strong>Activate</strong>: GELU (Gaussian Error Linear Unit) nonlinearity. Unlike ReLU, GELU applies a smooth, probabilistic gate proportional to the Gaussian CDF. Empirically, GELU consistently outperforms ReLU in Transformer training. Modern models (LLaMA, PaLM) use SwiGLU — a gated variant — which further improves performance.</li>
<li><strong>Contract</strong>: Linear projection from <code>4 × embed_dim</code> → <code>embed_dim</code>, restoring the original dimension for the residual connection.</li>
</ol>
<p><strong>Why position-wise?</strong> The FFN applies the same learned transformation to every position independently and in parallel. Within a layer, the weight matrices are shared across all positions: every token passes through the same expand-and-contract transformation, while each layer has its own set of weights. This is sometimes called a “position-wise” or “point-wise” feed-forward layer.</p>
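<p>The position-wise property can be demonstrated directly: running the FFN on a whole sequence and running it one position at a time give identical results (a NumPy sketch with illustrative dimensions, biases omitted):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
T, embed_dim, hidden = 6, 8, 32
W1 = rng.normal(size=(embed_dim, hidden))
W2 = rng.normal(size=(hidden, embed_dim))
gelu = lambda z: 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

def ffn(x):
    return gelu(x @ W1) @ W2  # expand, activate, contract

seq = rng.normal(size=(T, embed_dim))

# Whole sequence at once vs. one position at a time: identical outputs.
whole = ffn(seq)
per_position = np.stack([ffn(seq[t]) for t in range(T)])
assert np.allclose(whole, per_position)

# Permuting positions just permutes outputs: no cross-position interaction.
perm = rng.permutation(T)
assert np.allclose(ffn(seq[perm]), whole[perm])
print("position-wise: verified")
```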
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p><strong>Most of the parameters live here.</strong> Each attention layer has four weight matrices (Q, K, V, O), each of size <img src="https://latex.codecogs.com/png.latex?d_%7B%5Ctext%7Bmodel%7D%7D%20%5Ctimes%20d_%7B%5Ctext%7Bmodel%7D%7D">, totalling <img src="https://latex.codecogs.com/png.latex?4d_%7B%5Ctext%7Bmodel%7D%7D%5E2"> parameters. The FFN has two matrices of size <img src="https://latex.codecogs.com/png.latex?d_%7B%5Ctext%7Bmodel%7D%7D%20%5Ctimes%204d_%7B%5Ctext%7Bmodel%7D%7D">, totalling <img src="https://latex.codecogs.com/png.latex?8d_%7B%5Ctext%7Bmodel%7D%7D%5E2"> parameters — <strong>twice as many as attention</strong>. Across a full model, the FFN accounts for roughly two-thirds of all trainable parameters. When people talk about “scaling” a Transformer, they mostly mean growing <img src="https://latex.codecogs.com/png.latex?d_%7B%5Ctext%7Bmodel%7D%7D"> and <img src="https://latex.codecogs.com/png.latex?d_%7B%5Ctext%7Bff%7D%7D">, which expands this majority share.</p>
</div>
</div>
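<p>The parameter arithmetic in the note is easy to verify (bias terms omitted; <code>d = 768</code> chosen as a familiar example size):</p>

```python
def attention_params(d_model):
    return 4 * d_model * d_model                # W_Q, W_K, W_V, W_O

def ffn_params(d_model, expansion=4):
    return 2 * d_model * (expansion * d_model)  # expand + contract

d = 768  # e.g. BERT-base hidden size
attn, ffn = attention_params(d), ffn_params(d)
print(ffn // attn)                   # 2: the FFN holds twice attention's parameters
print(round(ffn / (attn + ffn), 3))  # 0.667: about two-thirds of the block
```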
<div id="c7ceda3b-1880-495c-8a32-04005afe4260" class="cell">
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> FeedForwardNN(nn.Module):</span>
<span id="cb7-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb7-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb7-4">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Expand to 4x hidden dim, then contract back — most capacity lives here</span></span>
<span id="cb7-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.l1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, config.intermediate_sz)</span>
<span id="cb7-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.l2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.intermediate_sz, config.embed_dim)</span>
<span id="cb7-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.hidden_dropout_prob)</span>
<span id="cb7-8"></span>
<span id="cb7-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb7-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:        B x T x embed_dim</span></span>
<span id="cb7-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## after l1: B x T x intermediate_sz  (expand)</span></span>
<span id="cb7-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## after l2: B x T x embed_dim        (contract)</span></span>
<span id="cb7-13">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.l2(F.gelu(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.l1(x))))</span></code></pre></div>
</div>
</section>
</section>
<section id="layer-normalization" class="level2">
<h2 class="anchored" data-anchor-id="layer-normalization">8. Layer Normalization</h2>
<section id="why-normalize-at-all" class="level3">
<h3 class="anchored" data-anchor-id="why-normalize-at-all">8.1 Why Normalize at All?</h3>
<p>Deep networks have a training stability problem: as signals propagate through many layers, the distribution of activations tends to shift and grow — a phenomenon called <strong>internal covariate shift</strong>. Layers that receive wildly varying input distributions must constantly adjust their weights just to track the shifting scale, not to learn meaningful transformations. This wastes capacity and slows training.</p>
<blockquote class="blockquote">
<p><em>Think of it as keeping the working range of each layer consistent. Without normalization, earlier layers can produce outputs 100x larger than what later layers expect — the later layers waste capacity on a bookkeeping problem rather than learning anything about language.</em></p>
</blockquote>
<p>Normalization is the engineering fix: explicitly constrain activation distributions to zero mean and unit variance at key points in the network, keeping signals in a regime where gradients are well-behaved throughout training.</p>
</section>
<section id="batch-normalization-vs.-layer-normalization" class="level3">
<h3 class="anchored" data-anchor-id="batch-normalization-vs.-layer-normalization">8.2 Batch Normalization vs.&nbsp;Layer Normalization</h3>
<p>Batch Normalization (Ioffe &amp; Szegedy, 2015) normalizes each feature across the batch dimension. This works well for CNNs on images but has two critical failure modes for sequence models:</p>
<ol type="1">
<li><strong>Small batches</strong>: with batch size 1, the batch mean and variance are undefined (or estimated from a single sample). Transformers are often trained with small batch sizes per GPU.</li>
<li><strong>Variable-length sequences</strong>: different positions in a batch may have very different activation statistics. Normalizing across a mixed batch conflates these.</li>
</ol>
<p>Layer Normalization (Ba et al., 2016) normalizes across the <strong>feature dimension</strong> instead of the batch dimension:</p>
<p><img src="https://latex.codecogs.com/png.latex?y%20=%20%5Cfrac%7Bx%20-%20%5Cmathbb%7BE%7D%5Bx%5D%7D%7B%5Csqrt%7B%5Ctext%7BVar%7D%5Bx%5D%20+%20%5Cepsilon%7D%7D%20%5Ccdot%20%5Cgamma%20+%20%5Cbeta"></p>
<p>The mean and variance are computed independently for each example, over all features of that example. This makes LayerNorm completely independent of batch size — it works identically whether batch size is 1 or 1000.</p>
<table class="table">
<thead>
<tr class="header">
<th></th>
<th>Batch Norm</th>
<th>Layer Norm</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Normalizes over</td>
<td>Batch dimension</td>
<td>Feature dimension</td>
</tr>
<tr class="even">
<td>Running statistics for inference</td>
<td>Yes</td>
<td>No</td>
</tr>
<tr class="odd">
<td>Breaks for batch_size = 1</td>
<td>Yes</td>
<td>No</td>
</tr>
<tr class="even">
<td>Variable-length sequences</td>
<td>Awkward</td>
<td>Natural</td>
</tr>
<tr class="odd">
<td>Common in</td>
<td>CNNs, image models</td>
<td>Transformers, RNNs</td>
</tr>
</tbody>
</table>
<p><strong>The learnable parameters <img src="https://latex.codecogs.com/png.latex?%5Cgamma"> and <img src="https://latex.codecogs.com/png.latex?%5Cbeta"></strong>: After normalization, every layer’s output would have zero mean and unit variance — too rigid. The learned scale (<img src="https://latex.codecogs.com/png.latex?%5Cgamma">) and shift (<img src="https://latex.codecogs.com/png.latex?%5Cbeta">) let each layer restore whatever distribution works best for its downstream computation. Without them, normalization would over-constrain the model.</p>
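<p>As a sanity check on the formula, here is a framework-free sketch of the per-example computation (plain Python with toy values; real models use <code>nn.LayerNorm</code>):</p>

```python
import math

def layer_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    # Statistics come from this one example's own features -- no batch involved,
    # which is why the result is identical at batch size 1 or 1000.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]

out = layer_norm([1.0, 2.0, 3.0, 4.0])
mean_out = sum(out) / len(out)            # ~0 after normalization
var_out = sum((v - mean_out) ** 2 for v in out) / len(out)  # ~1
```

With the defaults <code>gamma=1, beta=0</code> this is pure normalization; the learned scale and shift then restore whatever distribution the next layer prefers.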
</section>
<section id="pre-norm-vs.-post-norm-a-critical-implementation-choice" class="level3">
<h3 class="anchored" data-anchor-id="pre-norm-vs.-post-norm-a-critical-implementation-choice">8.3 Pre-Norm vs.&nbsp;Post-Norm: A Critical Implementation Choice</h3>
<p>The original Transformer paper placed LayerNorm <em>after</em> the residual addition (Post-LayerNorm). GPT-2 and virtually every modern large model places it <em>before</em> (Pre-LayerNorm). This seemingly minor change has significant consequences for training stability.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    subgraph PostLN["Post-LN (original paper)"]
        A1[x] --&gt; B1[Sublayer]
        A1 --&gt; C1[+]
        B1 --&gt; C1
        C1 --&gt; D1[LayerNorm]
        D1 --&gt; E1[output]
    end
    subgraph PreLN["Pre-LN (GPT-2, modern default)"]
        A2[x] --&gt; B2[LayerNorm]
        B2 --&gt; C2[Sublayer]
        A2 --&gt; D2[+]
        C2 --&gt; D2
        D2 --&gt; E2[output]
    end
</pre>
</div>
<p></p><figcaption> <strong>Figure 7:</strong> Post-LayerNorm (left) vs Pre-LayerNorm (right). Modern models use Pre-LN.</figcaption> </figure><p></p>
</div>
</div>
</div>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Post-LN: <code>LN(x + sublayer(x))</code></th>
<th>Pre-LN: <code>x + sublayer(LN(x))</code></th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Gradient path</td>
<td>Normalization sits outside the residual — gradients must pass through it</td>
<td>Normalization is inside — clean gradient highway through the residual</td>
</tr>
<tr class="even">
<td>Training stability</td>
<td>Sensitive; requires careful learning rate warm-up; can diverge</td>
<td>More stable; trains without warm-up</td>
</tr>
<tr class="odd">
<td>Final performance</td>
<td>Marginally better with enough tuning</td>
<td>Slightly lower ceiling, but much easier to train</td>
</tr>
</tbody>
</table>
<p>Modern practice defaults to Pre-LN: training stability at scale is worth more than marginal final performance differences. If you are building a new model, use Pre-LN.</p>
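<p>The two arrangements reduce to a one-line difference. A toy scalar sketch (both <code>ln</code> and <code>sublayer</code> here are hypothetical placeholders, not real layers):</p>

```python
def ln(x):          # placeholder "LayerNorm" acting on a scalar
    return x * 0.5

def sublayer(x):    # placeholder attention/FFN sublayer
    return x + 1.0

def post_ln(x):     # original paper: normalize AFTER the residual add
    return ln(x + sublayer(x))

def pre_ln(x):      # GPT-2 onward: normalize only the sublayer input;
    return x + sublayer(ln(x))  # the residual path itself stays an identity

# In pre_ln, the raw x reaches the output untouched by normalization,
# which is exactly the "clean gradient highway" from the table above.
```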
</section>
</section>
<section id="skip-residual-connections" class="level2">
<h2 class="anchored" data-anchor-id="skip-residual-connections">9. Skip (Residual) Connections</h2>
<section id="the-residual-stream-mental-model" class="level3">
<h3 class="anchored" data-anchor-id="the-residual-stream-mental-model">9.1 The Residual Stream Mental Model</h3>
<p>Think of a Transformer as a <strong>residual stream</strong> — a river of information that flows from the input through all the layers to the output. Each layer (attention + FFN) reads from the stream and writes a correction back to it via addition:</p>
<p><img src="https://latex.codecogs.com/png.latex?x_%7B%5Ctext%7Bout%7D%7D%20=%20x_%7B%5Ctext%7Bin%7D%7D%20+%20%5Ctext%7Bsublayer%7D(x_%7B%5Ctext%7Bin%7D%7D)"></p>
<p>No single layer “owns” the representation. Each layer adds its contribution to a shared river. The residual stream at any point contains the sum of everything all previous layers have written.</p>
<p>This framing, developed in mechanistic interpretability research, makes it immediately clear why attention heads can specialize: each head writes its contribution to the stream independently and additively. Heads don’t compete or overwrite one another; the stream simply accumulates everything they write.</p>
</section>
<section id="why-residual-connections-work" class="level3">
<h3 class="anchored" data-anchor-id="why-residual-connections-work">9.2 Why Residual Connections Work</h3>
<p><strong>Gradient highways.</strong> When backpropagating through <img src="https://latex.codecogs.com/png.latex?y%20=%20x%20+%20F(x)">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20x%7D%20=%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20y%7D%20%5Ccdot%20%5Cleft(1%20+%20%5Cfrac%7B%5Cpartial%20F%7D%7B%5Cpartial%20x%7D%5Cright)"></p>
<p>The term <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20y%7D"> reaches <img src="https://latex.codecogs.com/png.latex?x"> directly, through the identity path — regardless of what <img src="https://latex.codecogs.com/png.latex?F(x)"> does. Even if <img src="https://latex.codecogs.com/png.latex?F"> has saturated activations or near-zero gradients, the loss signal still flows back to earlier layers. This is why ResNets with skip connections can be trained to hundreds of layers while the same architecture without them fails beyond a dozen.</p>
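<p>This claim can be checked numerically. In the sketch below (plain Python, central finite differences), the sublayer is a deeply saturated <code>tanh</code> whose own gradient is essentially zero, yet the residual form still passes gradient through at full strength:</p>

```python
import math

def F(x):
    return math.tanh(x)  # saturates hard for large |x|

def grad(f, x, h=1e-6):
    # central finite-difference estimate of df/dx
    return (f(x + h) - f(x - h)) / (2 * h)

x = 10.0                             # deep in tanh's saturated regime
g_sublayer = grad(F, x)              # ~0: the sublayer alone blocks gradient
g_residual = grad(lambda v: v + F(v), x)  # ~1: identity path keeps it alive
```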
<p><strong>Loss landscape smoothing.</strong> Li et al.&nbsp;(2018) visualized the loss surfaces of deep networks with and without skip connections. Without them: chaotic, sharp, with many high-curvature local minima that trap gradient descent. With them: smooth, nearly convex, and far easier to navigate.</p>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/loss-landscape-with-skip-connections.png" class="lightbox" data-glightbox="description: .lightbox-desc-7" data-gallery="quarto-lightbox-gallery-7" title="Figure 8: Loss surfaces of ResNet-56 with/without skip connections (source)"><img src="https://imaddabbura.github.io/posts/nlp/images/loss-landscape-with-skip-connections.png" class="img-fluid figure-img" alt="Figure 8: Loss surfaces of ResNet-56 with/without skip connections (source)"></a></p>
<figcaption><strong>Figure 8:</strong> Loss surfaces of ResNet-56 with/without skip connections (<a href="https://arxiv.org/abs/1712.09913">source</a>)</figcaption>
</figure>
</div>
<p><strong>The forgetting argument.</strong> Without skip connections, each layer must preserve all useful information from its input in its output — if the layer wants to pass something unchanged, it must learn to do so explicitly. With skip connections, the <strong>default is identity</strong> — the layer only needs to learn what to <em>add</em>, not what to keep. This dramatically reduces the effective depth that the gradient must overcome.</p>
<p>However, training deep networks reliably requires one more ingredient beyond gradient highways — preventing the network from memorizing noise. That is dropout’s job.</p>
</section>
</section>
<section id="dropout" class="level2">
<h2 class="anchored" data-anchor-id="dropout">10. Dropout</h2>
<p>Dropout (Srivastava et al., 2014) randomly zeros a fraction <code>p</code> of activations during training. Each training step uses a different random mask, forcing the model not to rely on any particular activation path. This breaks up <strong>co-adaptation</strong>, where groups of units learn to function only in each other’s presence.</p>
<p>The regularization effect comes from two mechanisms:</p>
<ol type="1">
<li><strong>Network size reduction</strong>: Dropping units creates a smaller effective network per step. A smaller network has fewer parameters to overfit.</li>
<li><strong>Implicit ensembling</strong>: Each step trains a different subnetwork. At inference, the full network approximates an average over all these subnetworks, acting as an inexpensive bagging-style ensemble.</li>
</ol>
<p>In Transformers, dropout is applied after the embedding layer (after adding token + positional embeddings), after each attention sublayer, and after each FFN sublayer.</p>
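<p>A minimal sketch of <em>inverted</em> dropout (plain Python), the standard formulation used by <code>nn.Dropout</code>: survivors are scaled by <code>1/(1-p)</code> at training time so the expected activation matches inference, where dropout is a no-op:</p>

```python
import random

def dropout(x, p, training):
    if not training or p == 0.0:
        return list(x)  # inference: identity, no rescaling needed
    keep = 1.0 - p
    # zero each activation with probability p; rescale survivors by 1/keep
    return [v / keep if random.random() < keep else 0.0 for v in x]

random.seed(0)
x = [1.0] * 100_000
y = dropout(x, p=0.1, training=True)
mean_y = sum(y) / len(y)  # expected value stays ~1.0, matching inference
```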
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Modern Large Models Often Skip Dropout
</div>
</div>
<div class="callout-body-container callout-body">
<p>LLaMA, Mistral, and other recent large models use no dropout at all. At sufficient scale with enough data, the regularization effect of dropout is less necessary, and it slows training. Dropout remains important for smaller models trained on limited data, and for fine-tuning where overfitting is a risk.</p>
</div>
</div>
<div class="quarto-figure quarto-figure-center">
<figure class="figure">
<p><a href="images/dropout.png" class="lightbox" data-glightbox="description: .lightbox-desc-8" data-gallery="quarto-lightbox-gallery-8" title="Figure 9: Left: standard neural net. Right: thinned net after applying dropout — crossed units are dropped. (source)"><img src="https://imaddabbura.github.io/posts/nlp/images/dropout.png" class="img-fluid figure-img" alt="Figure 9: Left: standard neural net. Right: thinned net after applying dropout — crossed units are dropped. (source)"></a></p>
<figcaption><strong>Figure 9:</strong> Left: standard neural net. Right: thinned net after applying dropout — crossed units are dropped. (<a href="https://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf">source</a>)</figcaption>
</figure>
</div>
<p>With all the individual components understood — attention, FFN, LayerNorm, skip connections, dropout — it’s time to see how they snap together into a complete layer.</p>
</section>
<section id="assembling-the-encoder-layer" class="level2">
<h2 class="anchored" data-anchor-id="assembling-the-encoder-layer">11. Assembling the Encoder Layer</h2>
<p>Now that we have all the building blocks, let us see how they snap together into a single encoder layer — the repeated unit that makes up the encoder stack.</p>
<p>An encoder layer applies two sublayers in sequence, each wrapped in a residual connection and LayerNorm. Tracing the shapes at every step (using Pre-LN convention):</p>
<table class="table">
<colgroup>
<col style="width: 15%">
<col style="width: 28%">
<col style="width: 18%">
<col style="width: 36%">
</colgroup>
<thead>
<tr class="header">
<th>Step</th>
<th>Operation</th>
<th>Shape</th>
<th>What it does</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>1</td>
<td>Input Embeddings</td>
<td><code>(B, T, d_model)</code></td>
<td>Token IDs → dense vectors</td>
</tr>
<tr class="even">
<td>2</td>
<td>+ Positional Encoding</td>
<td><code>(B, T, d_model)</code></td>
<td>Inject position information</td>
</tr>
<tr class="odd">
<td>3</td>
<td>Self-Attention (×h heads)</td>
<td><code>(B, T, d_model)</code></td>
<td>Each token attends to all others</td>
</tr>
<tr class="even">
<td>4</td>
<td>Add &amp; Norm</td>
<td><code>(B, T, d_model)</code></td>
<td>Residual connection + layer norm</td>
</tr>
<tr class="odd">
<td>5</td>
<td>Feed-Forward</td>
<td><code>(B, T, d_ff)</code> → <code>(B, T, d_model)</code></td>
<td>Non-linear transformation</td>
</tr>
<tr class="even">
<td>6</td>
<td>Add &amp; Norm</td>
<td><code>(B, T, d_model)</code></td>
<td>Residual connection + layer norm</td>
</tr>
<tr class="odd">
<td>7</td>
<td>[Repeat × N layers]</td>
<td><code>(B, T, d_model)</code></td>
<td>Stack N encoder layers</td>
</tr>
<tr class="even">
<td>8</td>
<td>Encoder Output</td>
<td><code>(B, T, d_model)</code></td>
<td>Rich contextual representations</td>
</tr>
</tbody>
</table>
<p><code>B = batch size, T = sequence length, d_model = model dimension</code></p>
<p><strong>Why every row says <code>d_model</code>.</strong> Residual connections require that the sublayer output has exactly the same shape as its input — otherwise you cannot add them together. This is a hard architectural constraint: every sublayer (attention, FFN, LayerNorm) must consume and produce tensors of shape <code>(B, T, d_model)</code>. It is the reason <code>head_dim = d_model / num_heads</code> (the concatenation of all heads must restore <code>d_model</code>), and why the FFN contracts back from <code>d_ff</code> → <code>d_model</code> at the end. The entire Transformer is shaped around this single number.</p>
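<p>Concretely, with the original base configuration (values shown are illustrative):</p>

```python
d_model, num_heads = 512, 8        # Transformer-base sizes
assert d_model % num_heads == 0    # hard architectural requirement
head_dim = d_model // num_heads    # 64

# Concatenating all heads must restore d_model exactly, or the
# residual add (B, T, d_model) + (B, T, d_model) cannot happen.
concat_dim = head_dim * num_heads
```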
<p>Every token’s representation enters with shape <code>d_model</code>. After the attention sublayer, it has been updated by attending to all other tokens — information has been mixed across positions. After the FFN sublayer, each position’s representation has been transformed nonlinearly — independently from all other positions.</p>
<blockquote class="blockquote">
<p><em>An encoder layer does two things: (1) let tokens talk to each other via attention, then (2) let each token digest what it heard via the FFN.</em></p>
</blockquote>
<p>A full encoder stacks <img src="https://latex.codecogs.com/png.latex?N"> of these layers (typically 6–24). Each layer refines the representations further — early layers tend to capture surface-level patterns, later layers capture increasingly abstract semantic relationships.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Post-LN in the Code
</div>
</div>
<div class="callout-body-container callout-body">
<p>The implementation below uses the Post-LayerNorm arrangement from the original paper: <code>LN(x + sublayer(x))</code>. The Pre-LN alternative is shown in comments. For new models, prefer Pre-LN.</p>
</div>
</div>
<div id="b25d08fb-7099-40a1-9175-f53442e43b89" class="cell">
<div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> EncoderLayer(nn.Module):</span>
<span id="cb8-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb8-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb8-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MultiHeadAttention(config)</span>
<span id="cb8-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> FeedForwardNN(config)</span>
<span id="cb8-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb8-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb8-8"></span>
<span id="cb8-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb8-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x: B x T x embed_dim  (input and output shape are identical)</span></span>
<span id="cb8-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##</span></span>
<span id="cb8-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Post-LayerNorm arrangement (original Transformer paper):</span></span>
<span id="cb8-13">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn(x, x, x))  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## bidirectional self-attention</span></span>
<span id="cb8-14">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff(x))</span>
<span id="cb8-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##</span></span>
<span id="cb8-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Pre-LayerNorm alternative (GPT-2+, more stable — recommended for new models):</span></span>
<span id="cb8-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x = x + self.attn(self.layer_norm_1(x), self.layer_norm_1(x), self.layer_norm_1(x))</span></span>
<span id="cb8-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x = x + self.ff(self.layer_norm_2(x))</span></span>
<span id="cb8-19">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span></code></pre></div>
</div>
<div id="af565905-5b13-4773-b44f-d0a2e78a3831" class="cell">
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> TransformerEncoder(nn.Module):</span>
<span id="cb9-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb9-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb9-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Embeddings(config)</span>
<span id="cb9-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_blocks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb9-6">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>[EncoderLayer(config) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.num_hidden_layers)]</span>
<span id="cb9-7">        )</span>
<span id="cb9-8"></span>
<span id="cb9-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb9-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:    B x T  (integer token IDs)</span></span>
<span id="cb9-11">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.embeddings(x)        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span>
<span id="cb9-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_blocks(x)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span></code></pre></div>
</div>
</section>
<section id="assembling-the-decoder-layer" class="level2">
<h2 class="anchored" data-anchor-id="assembling-the-decoder-layer">12. Assembling the Decoder Layer</h2>
<p>The decoder layer differs from the encoder in one critical way: it adds a <strong>cross-attention sublayer</strong> between the masked self-attention and the FFN. This is the mechanism that lets the decoder read the encoder’s output.</p>
<p>A decoder layer applies three sublayers:</p>
<table class="table">
<colgroup>
<col style="width: 15%">
<col style="width: 28%">
<col style="width: 18%">
<col style="width: 36%">
</colgroup>
<thead>
<tr class="header">
<th>Step</th>
<th>Operation</th>
<th>Shape</th>
<th>What it does</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>1</td>
<td>Target Embeddings</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Target token IDs → dense vectors</td>
</tr>
<tr class="even">
<td>2</td>
<td>+ Positional Encoding</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Inject position information</td>
</tr>
<tr class="odd">
<td>3</td>
<td>Masked Self-Attention (×h)</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Attends only to past and current positions (causal mask)</td>
</tr>
<tr class="even">
<td>4</td>
<td>Add &amp; Norm</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Residual connection + layer norm</td>
</tr>
<tr class="odd">
<td>5</td>
<td>Cross-Attention (×h)</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Q from decoder; K, V from encoder output</td>
</tr>
<tr class="even">
<td>6</td>
<td>Add &amp; Norm</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Residual connection + layer norm</td>
</tr>
<tr class="odd">
<td>7</td>
<td>Feed-Forward</td>
<td><code>(B, T_tgt, d_ff)</code> → <code>(B, T_tgt, d_model)</code></td>
<td>Non-linear transformation</td>
</tr>
<tr class="even">
<td>8</td>
<td>Add &amp; Norm</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Residual connection + layer norm</td>
</tr>
<tr class="odd">
<td>9</td>
<td>[Repeat × N layers]</td>
<td><code>(B, T_tgt, d_model)</code></td>
<td>Stack N decoder layers</td>
</tr>
<tr class="even">
<td>10</td>
<td>Linear + Softmax</td>
<td><code>(B, T_tgt, vocab_size)</code></td>
<td>Project to vocabulary probabilities</td>
</tr>
</tbody>
</table>
<p><code>B = batch size, T_tgt = target sequence length</code></p>
<p><strong>Sublayer 1 — Masked self-attention</strong>: Decoder tokens attend to each other, but only to past and current positions (causal mask). This builds a contextualized representation of the target sequence generated so far.</p>
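<p>The causal mask itself is just a lower-triangular boolean matrix over target positions (plain-Python sketch; real implementations apply <code>torch.tril</code>-style masking to the attention scores):</p>

```python
T = 4  # target sequence length

# mask[i][j] is True where position i may attend to position j (j <= i)
mask = [[j <= i for j in range(T)] for i in range(T)]

# Position 0 sees only itself; the last position sees the whole prefix.
visible = [sum(row) for row in mask]  # -> [1, 2, 3, 4]
```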
<p><strong>Sublayer 2 — Cross-attention</strong>: The decoder’s hidden state becomes the query. The encoder’s final output provides the keys and values. Every decoder position can attend to all encoder positions — this is how the decoder “reads” the full source sequence at every generation step.</p>
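<p>A miniature, framework-free version of scaled dot-product cross-attention makes the asymmetry explicit (toy sizes, hypothetical values): the output has one row per decoder position, but each row is a softmax-weighted mix of encoder values:</p>

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

# Toy sizes: 2 decoder positions, 3 encoder positions, feature dim 2.
Q = [[1.0, 0.0], [0.0, 1.0]]              # queries: decoder hidden states
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # keys:   encoder output
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # values: encoder output

d_k = len(K[0])
out = []
for q in Q:  # every decoder position attends over ALL encoder positions
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
    w = softmax(scores)
    out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
```

Note there is no causal mask here: the source sequence is fully known, so restricting which encoder positions the decoder may read would only discard information.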
<p><strong>Sublayer 3 — FFN</strong>: Same position-wise transformation as in the encoder.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note on the Code Below
</div>
</div>
<div class="callout-body-container callout-body">
<p>The DecoderLayer shown uses only masked self-attention (no cross-attention sublayer). It is therefore suited for the decoder-only (GPT-style) architecture. Cross-attention is addressed in the Encoder-Decoder section.</p>
</div>
</div>
<div id="d684880a-3132-487a-ab51-fc3b20978e5f" class="cell">
<div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> DecoderLayer(nn.Module):</span>
<span id="cb10-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb10-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb10-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MultiHeadAttention(config, is_decoder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb10-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> FeedForwardNN(config)</span>
<span id="cb10-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb10-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb10-8"></span>
<span id="cb10-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb10-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x: B x T x embed_dim</span></span>
<span id="cb10-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Masked self-attention: each token only attends to past and current positions</span></span>
<span id="cb10-12">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.attn(x, x, x))</span>
<span id="cb10-13">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff(x))</span>
<span id="cb10-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span></code></pre></div>
</div>
<div id="89f920e8-82b9-4975-841d-94613c5bfa7d" class="cell">
<div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> TransformerDecoder(nn.Module):</span>
<span id="cb11-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-&gt;</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">None</span>:</span>
<span id="cb11-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb11-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Embeddings(config)</span>
<span id="cb11-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_blocks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Sequential(</span>
<span id="cb11-6">            <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span>[DecoderLayer(config) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.num_hidden_layers)]</span>
<span id="cb11-7">        )</span>
<span id="cb11-8"></span>
<span id="cb11-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb11-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:    B x T  (integer token IDs)</span></span>
<span id="cb11-11">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.embeddings(x)         <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span>
<span id="cb11-12">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_blocks(x)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T x embed_dim</span></span></code></pre></div>
</div>
</section>
<section id="architecture-variants" class="level2">
<h2 class="anchored" data-anchor-id="architecture-variants">13. Architecture Variants</h2>
<p>The same building blocks support three distinct architectures, differing only in which sublayers are present and whether attention is masked. Here is the full comparison before diving into each:</p>
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Encoder-Only</th>
<th>Decoder-Only</th>
<th>Encoder-Decoder</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Attention masking</strong></td>
<td>Bidirectional</td>
<td>Causal</td>
<td>Causal in decoder; bidirectional in encoder</td>
</tr>
<tr class="even">
<td><strong>Cross-attention</strong></td>
<td>No</td>
<td>No</td>
<td>Yes</td>
</tr>
<tr class="odd">
<td><strong>Input → Output</strong></td>
<td>Text → hidden states</td>
<td>Text → next token</td>
<td>Source text → target text</td>
</tr>
<tr class="even">
<td><strong>Canonical task</strong></td>
<td>Classification, NER, embeddings</td>
<td>Text generation, LM</td>
<td>Translation, summarization</td>
</tr>
<tr class="odd">
<td><strong>Examples</strong></td>
<td>BERT, RoBERTa, DistilBERT</td>
<td>GPT, LLaMA, Mistral</td>
<td>T5, BART, mT5</td>
</tr>
</tbody>
</table>
<section id="encoder-only-architecture" class="level3">
<h3 class="anchored" data-anchor-id="encoder-only-architecture">13.1 Encoder-Only Architecture</h3>
<p>Encoder-only models use bidirectional self-attention — every token attends to every other token with no masking. This means the representation of each token is conditioned on the full context: tokens to the left <em>and</em> the right. Bidirectional context makes encoder-only models excellent at understanding tasks: text classification, named entity recognition, extractive question answering, and computing sentence embeddings.</p>
<p><strong>Why bidirectional?</strong> Classification does not require generating new tokens — it requires understanding the full input. A model that sees the entire sentence simultaneously can build richer representations than one forced to read left-to-right.</p>
<p><strong>How is it trained?</strong> BERT-style models are trained with <strong>Masked Language Modeling (MLM)</strong>: 15% of tokens are randomly masked (<code>[MASK]</code>), and the model must predict the original token at each masked position. Because the model can see all tokens to the left <em>and</em> right of the mask, this forces it to build bidirectional representations.</p>
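<p>The masking step can be sketched framework-agnostically in NumPy. The <code>-100</code> sentinel follows the common PyTorch convention for an ignored label index, and <code>mask_id</code> stands in for the tokenizer's <code>[MASK]</code> token ID; both are assumptions of this sketch, not details from the text above:</p>

```python
import numpy as np

def mlm_mask(token_ids, mask_id, mask_prob=0.15, rng=None):
    """Corrupt ~15% of positions for Masked Language Modeling.

    Returns (corrupted_input, labels). labels is -100 (ignored by the
    loss) everywhere except masked positions, where it holds the
    original token the model must recover.
    """
    rng = rng or np.random.default_rng(0)
    ids = np.asarray(token_ids)
    masked = rng.random(ids.shape) < mask_prob       # choose ~15% of positions
    labels = np.where(masked, ids, -100)             # targets only where masked
    corrupted = np.where(masked, mask_id, ids)       # replace inputs with [MASK]
    return corrupted, labels
```

<p>The full BERT recipe additionally replaces some selected positions with random tokens or leaves them unchanged rather than always inserting <code>[MASK]</code>; the sketch keeps only the core idea.</p>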
<p><strong>Classification head.</strong> A special <code>[CLS]</code> token is prepended to every sequence before the encoder. The encoder’s output at the <code>[CLS]</code> position — <code>encoder_output[:, 0, :]</code> — serves as an aggregate representation of the full sequence. This vector is passed through a linear classification head to produce logits.</p>
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Why [CLS] and Not Mean Pooling?
</div>
</div>
<div class="callout-body-container callout-body">
<p>BERT uses <code>[CLS]</code> because it is trained to aggregate sequence-level information during pretraining (next sentence prediction task). In practice, mean pooling over all token representations often performs equally well or better for downstream tasks. Modern models trained without NSP use mean pooling as the default.</p>
</div>
</div>
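<p>Both pooling choices can be sketched directly over an encoder output of shape <code>B x T x embed_dim</code>; the attention mask in <code>mean_pool</code> keeps padding positions out of the average. A minimal NumPy sketch:</p>

```python
import numpy as np

def cls_pool(hidden):
    """Take the representation at position 0, the [CLS] slot (B x T x d -> B x d)."""
    return hidden[:, 0, :]

def mean_pool(hidden, attention_mask):
    """Average token representations, excluding padding.

    hidden:         B x T x d
    attention_mask: B x T, 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., None].astype(hidden.dtype)   # B x T x 1
    return (hidden * mask).sum(axis=1) / mask.sum(axis=1)
```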
<div id="df92c42f-8650-4ef6-9865-686d183d61cc" class="cell">
<div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> TransformerForSequenceClassification(nn.Module):</span>
<span id="cb12-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb12-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb12-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> TransformerEncoder(config)</span>
<span id="cb12-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.hidden_dropout_prob)</span>
<span id="cb12-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.classifier <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, config.num_classes)</span>
<span id="cb12-7"></span>
<span id="cb12-8">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb12-9">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:              B x T  (integer token IDs)</span></span>
<span id="cb12-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## encoder output: B x T x embed_dim</span></span>
<span id="cb12-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## [CLS] vector:   B x embed_dim  (position 0 aggregates sequence meaning)</span></span>
<span id="cb12-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## logits:         B x num_classes</span></span>
<span id="cb12-13">        cls_output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder(x)[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, :]</span>
<span id="cb12-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.classifier(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout(cls_output))</span></code></pre></div>
</div>
</section>
<section id="decoder-only-architecture" class="level3">
<h3 class="anchored" data-anchor-id="decoder-only-architecture">13.2 Decoder-Only Architecture</h3>
<p>Decoder-only models use causal self-attention — each token can only attend to itself and previous tokens. This is the natural architecture for <strong>language modeling</strong>: predicting the next token given all previous tokens.</p>
<p><strong>Why causal?</strong> Generating text requires predicting one token at a time. If the model could see future tokens while predicting token <img src="https://latex.codecogs.com/png.latex?t">, it would simply copy them. The causal mask enforces the constraint that prediction at position <img src="https://latex.codecogs.com/png.latex?t"> uses only information from positions <img src="https://latex.codecogs.com/png.latex?0,%201,%20%5Cldots,%20t">.</p>
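<p>The constraint is typically enforced by setting the upper triangle of the <code>T x T</code> score matrix to negative infinity before the softmax, which zeroes the corresponding attention weights. A NumPy sketch of that masking:</p>

```python
import numpy as np

def causal_attention_weights(scores):
    """Apply a causal mask to raw attention scores (T x T), then softmax.

    scores[i, j] is the dot-product score of query i against key j.
    Positions j > i are set to -inf so, after softmax, each token
    attends only to itself and earlier positions.
    """
    T = scores.shape[-1]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True strictly above diagonal
    masked = np.where(future, -np.inf, scores)
    e = np.exp(masked - masked.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```

<p>With all-zero scores, row <code>i</code> of the result is uniform over positions <code>0..i</code> and exactly zero on every future position.</p>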
<p><strong>The training objective: Causal Language Modeling (CLM).</strong> Decoder-only models are trained by next-token prediction: given a sequence of tokens, predict the next one at every position simultaneously. The loss is the average cross-entropy over all positions. Because the causal mask prevents each position from seeing future tokens, a single forward pass generates <img src="https://latex.codecogs.com/png.latex?T"> training examples from one sequence — every position is simultaneously a training target. This is why CLM scales so efficiently: a 2048-token document yields 2048 gradient signals per forward pass. The training objective directly shapes what the model learns: because it must predict the next token from all preceding context, the model is forced to compress everything useful about the past into each position’s representation — which is why later layers hold increasingly abstract, predictive features.</p>
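<p>Concretely, the loss is computed by shifting: the logits at position <code>t</code> are scored against the token at position <code>t + 1</code>, so each position except the last is a supervised prediction. A NumPy sketch for a single sequence:</p>

```python
import numpy as np

def clm_loss(logits, token_ids):
    """Average next-token cross-entropy for one sequence.

    logits:    T x vocab_sz, one next-token distribution per position
    token_ids: length-T integer array
    Position t's logits are scored against token t+1; the final
    position has no target inside the sequence and is dropped.
    """
    preds, targets = logits[:-1], token_ids[1:]
    m = preds.max(axis=-1, keepdims=True)    # subtract max for numerical stability
    log_probs = preds - m - np.log(np.exp(preds - m).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()
```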
<p><strong>Autoregressive generation.</strong> At inference, the decoder generates text by repeating:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">graph LR
    A["Input tokens
[BOS, t₁, t₂]"] --&gt; B["Decoder
(causal attention)"]
    B --&gt; C["LM head
(linear + softmax)"]
    C --&gt; D["Next token
t₃"]
    D --&gt; A
</pre>
</div>
<p></p><figcaption> <strong>Figure 10:</strong> Autoregressive generation loop in decoder-only models.</figcaption> </figure><p></p>
</div>
</div>
</div>
<ol type="1">
<li>Feed current token sequence through the decoder</li>
<li>Take the output at the last position → pass through the LM head (linear projection to <code>vocab_sz</code>, then softmax)</li>
<li>Sample the next token from the resulting distribution</li>
<li>Append the sampled token to the sequence and repeat</li>
</ol>
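<p>The four steps above can be sketched as a short loop. Here <code>model</code> stands in for any callable mapping a token-ID sequence to per-position logits (the <code>GPT</code> module above has this shape), and the toy model is purely illustrative:</p>

```python
import numpy as np

def generate_greedy(model, prompt_ids, max_new_tokens, eos_id=None):
    """Autoregressive decoding loop: feed the growing sequence, take the
    logits at the last position, pick the argmax token, append, repeat."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = model(np.array(ids))        # T x vocab_sz
        next_id = int(logits[-1].argmax())   # greedy: highest-probability token
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids

# Toy stand-in "model": always predicts token (last_token + 1) mod vocab_sz.
def toy_model(ids, vocab_sz=5):
    logits = np.zeros((len(ids), vocab_sz))
    logits[np.arange(len(ids)), (ids + 1) % vocab_sz] = 1.0
    return logits

generate_greedy(toy_model, [0], max_new_tokens=3)  # → [0, 1, 2, 3]
```

<p>Swapping the <code>argmax</code> for a draw from the distribution gives the sampling-based strategies described next.</p>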
<p><strong>Sampling strategies</strong> control how token <img src="https://latex.codecogs.com/png.latex?t+1"> is chosen from the distribution:</p>
<ul>
<li><strong>Greedy</strong>: always pick the highest-probability token. Fast but repetitive.</li>
<li><strong>Top-k</strong>: sample from the top-<img src="https://latex.codecogs.com/png.latex?k"> tokens by probability. Controls diversity.</li>
<li><strong>Top-p (nucleus)</strong>: sample from the smallest set of tokens whose cumulative probability exceeds <img src="https://latex.codecogs.com/png.latex?p">. Adaptive — uses fewer options when one token is dominant.</li>
<li><strong>Temperature</strong>: divide all logits by temperature <img src="https://latex.codecogs.com/png.latex?%5Ctau"> before softmax. <img src="https://latex.codecogs.com/png.latex?%5Ctau%20%3C%201"> sharpens the distribution (more confident); <img src="https://latex.codecogs.com/png.latex?%5Ctau%20%3E%201"> flattens it (more random).</li>
</ul>
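<p>These strategies compose: temperature first rescales the logits, then top-k or top-p truncates the candidate set before sampling. A NumPy sketch of the filtering step (the resulting distribution would then be drawn from with <code>np.random.choice</code> or similar):</p>

```python
import numpy as np

def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Turn next-token logits (vocab_sz,) into a sampling distribution."""
    scaled = logits / temperature                 # tau < 1 sharpens, tau > 1 flattens
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    if top_k is not None:
        # Keep only the k most probable tokens (ties keep all equal entries).
        cutoff = np.sort(probs)[-top_k]
        probs = np.where(probs >= cutoff, probs, 0.0)
    if top_p is not None:
        # Smallest set of tokens whose cumulative probability reaches p.
        order = np.argsort(probs)[::-1]
        csum = np.cumsum(probs[order])
        keep = np.zeros(probs.shape, dtype=bool)
        keep[order[: np.searchsorted(csum, top_p) + 1]] = True
        probs = np.where(keep, probs, 0.0)
    return probs / probs.sum()
```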
<p><strong>KV caching: why inference is efficient.</strong> The loop as described implies re-computing attention over the full growing sequence at every step — which would scale as <img src="https://latex.codecogs.com/png.latex?O(T%5E2)"> for a <img src="https://latex.codecogs.com/png.latex?T">-token generation. Production systems avoid this with a <strong>KV cache</strong>: the K and V tensors for all past positions are stored after their first computation and reused on every subsequent step. Only the new token’s Q needs to be computed; it attends to the cached K/V from all prior positions. Each generation step then costs <img src="https://latex.codecogs.com/png.latex?O(T%20%5Ccdot%20d)"> instead of <img src="https://latex.codecogs.com/png.latex?O(T%5E2%20%5Ccdot%20d)">.</p>
<p>The KV cache is a first-class engineering constraint in LLM deployment. For a model with <img src="https://latex.codecogs.com/png.latex?L"> layers, <img src="https://latex.codecogs.com/png.latex?H"> heads, head dimension <img src="https://latex.codecogs.com/png.latex?d_k">, and current sequence length <img src="https://latex.codecogs.com/png.latex?T">, the cache requires <img src="https://latex.codecogs.com/png.latex?2%20%5Ccdot%20L%20%5Ccdot%20H%20%5Ccdot%20d_k%20%5Ccdot%20T"> values — for a 70B-scale model with full multi-head attention (80 layers, 64 heads, head dimension 128) at 4K context, that is over 5 billion values, roughly 10 GB in fp16. This is precisely why <strong>Grouped Query Attention (GQA)</strong> exists: by sharing a single K/V head across multiple Q heads, the cache shrinks by a factor of <code>num_heads / num_kv_heads</code> — often 8x. Every major modern model (LLaMA 2/3, Mistral, Gemma) uses GQA for exactly this reason.</p>
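<p>The arithmetic is easy to check directly. The sketch below evaluates the formula above; the 7B-scale configuration (32 layers, 32 heads, head dimension 128, 4K context) is an illustrative assumption, not a published measurement:</p>

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Per-sequence KV-cache size: 2 (K and V) * L * H_kv * d_k * T values,
    times bytes per value (2 for fp16)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

# Full multi-head attention at 7B scale: 32 KV heads, fp16, 4K context.
full_mha = kv_cache_bytes(32, 32, 128, 4096)   # 2_147_483_648 bytes = 2 GiB
# Same model with GQA sharing K/V across groups (e.g. 8 KV heads): 4x smaller.
gqa = kv_cache_bytes(32, 8, 128, 4096)
```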
<div id="8d2737a7-c025-4578-882e-0352341d7e95" class="cell">
<div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> GPT(nn.Module):</span>
<span id="cb13-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb13-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb13-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> TransformerDecoder(config)</span>
<span id="cb13-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Dropout(config.hidden_dropout_prob)</span>
<span id="cb13-6">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Project from embed_dim to vocab_sz to get next-token logits</span></span>
<span id="cb13-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, config.vocab_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb13-8"></span>
<span id="cb13-9">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x):</span>
<span id="cb13-10">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:       B x T  (integer token IDs)</span></span>
<span id="cb13-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## decoded: B x T x embed_dim</span></span>
<span id="cb13-12">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## logits:  B x T x vocab_sz  (next-token distribution at every position)</span></span>
<span id="cb13-13">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.dropout(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder(x))</span>
<span id="cb13-14">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head(x)</span></code></pre></div>
</div>
<div id="b3e1f2a0-seq2-seq0-0000-000000000001" class="cell">
<div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> CrossAttentionDecoderLayer(nn.Module):</span>
<span id="cb14-2">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Decoder layer with three sublayers:</span></span>
<span id="cb14-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    (1) masked causal self-attention, (2) cross-attention to encoder, (3) FFN.</span></span>
<span id="cb14-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb14-5">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb14-6">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb14-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.self_attn    <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MultiHeadAttention(config, is_decoder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb14-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.cross_attn   <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> MultiHeadAttention(config, is_decoder<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb14-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff           <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> FeedForwardNN(config)</span>
<span id="cb14-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb14-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb14-12">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_3 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LayerNorm(config.embed_dim)</span>
<span id="cb14-13"></span>
<span id="cb14-14">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x, encoder_output):</span>
<span id="cb14-15">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x:              B x T_dec x embed_dim</span></span>
<span id="cb14-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## encoder_output: B x T_enc x embed_dim</span></span>
<span id="cb14-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##</span></span>
<span id="cb14-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## 1. Masked self-attention — decoder tokens attend to each other causally</span></span>
<span id="cb14-19">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_1(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.self_attn(x, x, x))</span>
<span id="cb14-20">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## 2. Cross-attention — Q from decoder, K and V from encoder</span></span>
<span id="cb14-21">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">##    Every decoder position can attend to all encoder positions</span></span>
<span id="cb14-22">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_2(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.cross_attn(x, encoder_output, encoder_output))</span>
<span id="cb14-23">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## 3. Position-wise FFN</span></span>
<span id="cb14-24">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.layer_norm_3(x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ff(x))</span>
<span id="cb14-25">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_dec x embed_dim</span></span>
<span id="cb14-26"></span>
<span id="cb14-27"></span>
<span id="cb14-28"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> Seq2SeqTransformer(nn.Module):</span>
<span id="cb14-29">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">"""Encoder-decoder Transformer for sequence-to-sequence tasks</span></span>
<span id="cb14-30"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    such as machine translation and summarization.</span></span>
<span id="cb14-31"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    """</span></span>
<span id="cb14-32">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, config):</span>
<span id="cb14-33">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb14-34">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Embeddings(config)</span>
<span id="cb14-35">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_embeddings <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Embeddings(config)</span>
<span id="cb14-36">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_blocks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.ModuleList(</span>
<span id="cb14-37">            [EncoderLayer(config) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.num_hidden_layers)]</span>
<span id="cb14-38">        )</span>
<span id="cb14-39">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_blocks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.ModuleList(</span>
<span id="cb14-40">            [CrossAttentionDecoderLayer(config) <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> _ <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(config.num_hidden_layers)]</span>
<span id="cb14-41">        )</span>
<span id="cb14-42">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(config.embed_dim, config.vocab_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">False</span>)</span>
<span id="cb14-43"></span>
<span id="cb14-44">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> encode(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, src):</span>
<span id="cb14-45">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## src: B x T_enc  →  B x T_enc x embed_dim</span></span>
<span id="cb14-46">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_embeddings(src)</span>
<span id="cb14-47">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> block <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encoder_blocks:</span>
<span id="cb14-48">            x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> block(x)</span>
<span id="cb14-49">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_enc x embed_dim</span></span>
<span id="cb14-50"></span>
<span id="cb14-51">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> decode(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, tgt, encoder_output):</span>
<span id="cb14-52">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## tgt:            B x T_dec</span></span>
<span id="cb14-53">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## encoder_output: B x T_enc x embed_dim</span></span>
<span id="cb14-54">        x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_embeddings(tgt)  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_dec x embed_dim</span></span>
<span id="cb14-55">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> block <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decoder_blocks:</span>
<span id="cb14-56">            x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> block(x, encoder_output)</span>
<span id="cb14-57">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> x  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_dec x embed_dim</span></span>
<span id="cb14-58"></span>
<span id="cb14-59">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, src, tgt):</span>
<span id="cb14-60">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## src: B x T_enc  (source token IDs, e.g. English)</span></span>
<span id="cb14-61">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## tgt: B x T_dec  (target token IDs, e.g. German — teacher-forced during training)</span></span>
<span id="cb14-62">        encoder_output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.encode(src)                    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_enc x embed_dim</span></span>
<span id="cb14-63">        decoder_output <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.decode(tgt, encoder_output)    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_dec x embed_dim</span></span>
<span id="cb14-64">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.lm_head(decoder_output)                  <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x T_dec x vocab_sz</span></span></code></pre></div>
</div>
</section>
<section id="encoder-decoder-architecture" class="level3">
<h3 class="anchored" data-anchor-id="encoder-decoder-architecture">13.3 Encoder-Decoder Architecture</h3>
<p>The encoder-decoder (or “sequence-to-sequence”) architecture is the original Transformer from Vaswani et al.&nbsp;(2017). It is designed for tasks where both input and output are text sequences — particularly tasks where the input and output are structurally different, like machine translation or summarization.</p>
<p><strong>The two-phase interpretation:</strong></p>
<ul>
<li><strong>Encoder</strong>: reads the full source sequence with bidirectional attention and produces a rich, contextualized representation. Think of this as “understanding the source.”</li>
<li><strong>Decoder</strong>: generates the target sequence token by token, conditioned on the encoder’s representation at every step. Think of this as “generating the target given the understanding.”</li>
</ul>
<p><strong>How cross-attention implements conditioning.</strong> At every decoder step, the cross-attention sublayer takes:</p>
<ul>
<li>Queries from the decoder’s current hidden state: <em>“What do I need from the source?”</em></li>
<li>Keys and Values from the encoder’s final output: <em>“Here is everything in the source.”</em></li>
</ul>
<p>Every decoder position attends to all encoder positions simultaneously. The model learns which parts of the source to focus on when generating each target token — the alignment between source and target.</p>
<p>Note that unlike self-attention (where the <img src="https://latex.codecogs.com/png.latex?T%20%5Ctimes%20T"> weight matrix is square), cross-attention produces a <strong>rectangular</strong> weight matrix of shape <img src="https://latex.codecogs.com/png.latex?T_%7B%5Ctext%7Btgt%7D%7D%20%5Ctimes%20T_%7B%5Ctext%7Bsrc%7D%7D">: one row per decoder query position, one column per encoder key position. During early generation when only a few target tokens exist, this matrix might be <img src="https://latex.codecogs.com/png.latex?3%20%5Ctimes%2050"> — three decoder positions each attending over fifty source positions. The asymmetry is intentional: the decoder decides what to ask (Q), the encoder provides the full library of keys and values (K, V), and the weight matrix records what each decoder step borrows from each source position.</p>
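<p>The rectangular shape is easy to verify directly. Below is a minimal single-head sketch in PyTorch; the learned Q/K/V projections are omitted for clarity, so the tensors are illustrative stand-ins for the decoder hidden states and encoder output:</p>

```python
import torch
import torch.nn.functional as F

B, T_dec, T_enc, embed_dim = 1, 3, 50, 768

# Stand-ins for the decoder hidden states and the encoder output
decoder_hidden = torch.randn(B, T_dec, embed_dim)
encoder_output = torch.randn(B, T_enc, embed_dim)

# Cross-attention: Q from the decoder, K and V from the encoder
q, k, v = decoder_hidden, encoder_output, encoder_output

scores = q @ k.transpose(-2, -1) / embed_dim**0.5  # B x T_dec x T_enc -- rectangular
weights = F.softmax(scores, dim=-1)                # each decoder row sums to 1
out = weights @ v                                  # B x T_dec x embed_dim

print(weights.shape)  # torch.Size([1, 3, 50])
print(out.shape)      # torch.Size([1, 3, 768])
```

<p>The <code>1 × 3 × 50</code> weight tensor is exactly the “three decoder positions, each attending over fifty source positions” case described above.</p>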
<p><strong>When encoder-decoder vs.&nbsp;decoder-only?</strong> Encoder-decoder models are preferred when source and target are structurally different (e.g., English → German, document → summary). For tasks where both input and output are similar in format (e.g., open-domain conversation, code completion), decoder-only models have largely taken over — they are simpler to train and scale, and can handle both input and output within a single sequence by formatting the task as a text completion problem.</p>
<p>Notable encoder-decoder models: <strong>T5</strong> (Text-to-Text Transfer Transformer), <strong>BART</strong>, <strong>mT5</strong>, <strong>NLLB</strong>.</p>
</section>
</section>
<section id="end-to-end-forward-pass-walkthrough" class="level2">
<h2 class="anchored" data-anchor-id="end-to-end-forward-pass-walkthrough">14. End-to-End Forward Pass Walkthrough</h2>
<p>Let’s trace a complete forward pass through an encoder-decoder Transformer to see how all the pieces compose. We’ll use a small example: translating the English sentence “The cat sat” into German.</p>
<p><strong>Setup:</strong> batch size <img src="https://latex.codecogs.com/png.latex?B%20=%201">, source length <img src="https://latex.codecogs.com/png.latex?T_%7Benc%7D%20=%203">, <code>embed_dim = 768</code>, <code>num_heads = 12</code>, <code>head_dim = 64</code>.</p>
<hr>
<p><strong>Step 1 — Tokenize the source.</strong></p>
<p>“The cat sat” → subword tokenizer → <code>[2, 47, 193]</code> (integer IDs)</p>
<p>Shape: <code>1 × 3</code> (integers)</p>
<hr>
<p><strong>Step 2 — Token embedding lookup.</strong></p>
<p>Each integer is mapped to a 768-dimensional vector via the embedding table.</p>
<p>Shape: <code>1 × 3</code> → <code>1 × 3 × 768</code></p>
<hr>
<p><strong>Step 3 — Add positional encodings.</strong></p>
<p>A positional encoding vector is added to each token’s embedding. The result encodes both <em>what</em> the token is (token embedding) and <em>where</em> it sits (positional encoding).</p>
<p>Shape: <code>1 × 3 × 768</code> (unchanged)</p>
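<p>Steps 1–3 can be sketched in a few lines of PyTorch. The vocabulary size, the token IDs, and the learned positional table below are illustrative assumptions, not values from any particular model:</p>

```python
import torch
import torch.nn as nn

B, T_enc, embed_dim, vocab_sz, max_len = 1, 3, 768, 32000, 512

token_ids = torch.tensor([[2, 47, 193]])     # Step 1: B x T_enc integer IDs
tok_emb = nn.Embedding(vocab_sz, embed_dim)  # token embedding table
pos_emb = nn.Embedding(max_len, embed_dim)   # learned positional table (one option)

x = tok_emb(token_ids)                       # Step 2: B x T_enc x embed_dim
x = x + pos_emb(torch.arange(T_enc))         # Step 3: shape unchanged
print(x.shape)  # torch.Size([1, 3, 768])
```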
<hr>
<p><strong>Step 4 — N encoder layers.</strong></p>
<p>Each encoder layer applies two sublayers:</p>
<ul>
<li><strong>Multi-head self-attention</strong>: All 3 tokens attend to all 3 tokens. The <img src="https://latex.codecogs.com/png.latex?3%20%5Ctimes%203"> attention weight matrix (12 heads, each with its own <img src="https://latex.codecogs.com/png.latex?3%20%5Ctimes%203"> weights) is computed, and each token’s representation is updated as a weighted mix of all token values.</li>
<li><strong>FFN</strong>: Each token’s updated representation passes through the 2-layer FFN independently.</li>
</ul>
<p>Shape at every encoder layer: <code>1 × 3 × 768</code> (unchanged throughout)</p>
<p>After <img src="https://latex.codecogs.com/png.latex?N"> encoder layers, each of the 3 token positions holds a deeply <strong>contextualized representation</strong> — the meaning of “cat” is now informed by the presence of “The” and “sat” in context.</p>
<p><strong>Encoder output:</strong> <code>1 × 3 × 768</code> — this is what the decoder will attend to.</p>
<hr>
<p><strong>Step 5 — Decoder receives the start token.</strong></p>
<p>Decoder input starts with a start-of-sequence token <code>[BOS]</code>.</p>
<p>Shape: <code>1 × 1</code> → (after embedding) <code>1 × 1 × 768</code></p>
<hr>
<p><strong>Step 6 — N decoder layers.</strong></p>
<p>Each decoder layer applies three sublayers:</p>
<ol type="1">
<li><p><strong>Masked self-attention</strong>: Only 1 token so far, so the <img src="https://latex.codecogs.com/png.latex?1%20%5Ctimes%201"> causal attention matrix is trivially “attend to self.” Shape: <code>1 × 1 × 768</code>.</p></li>
<li><p><strong>Cross-attention</strong>: Q comes from the decoder hidden state (<code>1 × 1 × 768</code>). K and V come from the encoder output (<code>1 × 3 × 768</code>). Attention weights have shape <code>1 × 1 × 3</code> — the single decoder position attends to all 3 encoder positions. Output: <code>1 × 1 × 768</code>.</p></li>
<li><p><strong>FFN</strong>: <code>1 × 1 × 768</code> processed position-wise.</p></li>
</ol>
<hr>
<p><strong>Step 7 — LM head.</strong></p>
<p>The decoder output at the final position (<code>1 × 1 × 768</code>) is projected to <code>vocab_sz</code> via a linear layer, then softmax gives a probability distribution over the vocabulary.</p>
<p>Shape: <code>1 × 1 × 768</code> → <code>1 × 1 × vocab_sz</code> → sample token → e.g., <code>"Die"</code> (German “The”)</p>
<hr>
<p><strong>Step 8 — Autoregressive loop.</strong></p>
<p>Append <code>"Die"</code> to the decoder input. Repeat Steps 6–7 with decoder input <code>[BOS, "Die"]</code> to generate the next token. Continue until <code>[EOS]</code> is sampled or the maximum length is reached.</p>
<hr>
<p>The key insight from this walkthrough: <strong>the encoder runs once</strong> for the full source sequence. The <strong>decoder runs once per generated token</strong>, attending to the full encoder output (which never changes) at every step via cross-attention.</p>
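<p>Steps 5–8 amount to a generation loop around the <code>encode</code>/<code>decode</code>/<code>lm_head</code> interface shown earlier. A minimal sketch with greedy sampling (the <code>bos_id</code>/<code>eos_id</code> arguments and greedy pick are illustrative choices; real systems often use beam search or temperature sampling):</p>

```python
import torch

@torch.no_grad()
def greedy_translate(model, src, bos_id, eos_id, max_len=50):
    # Run the encoder ONCE for the whole source sequence
    encoder_output = model.encode(src)              # B x T_enc x embed_dim
    tgt = torch.full((src.size(0), 1), bos_id)      # start with [BOS]: B x 1
    for _ in range(max_len):
        dec = model.decode(tgt, encoder_output)     # B x T_dec x embed_dim
        logits = model.lm_head(dec[:, -1])          # last position only: B x vocab_sz
        next_tok = logits.argmax(-1, keepdim=True)  # greedy pick: B x 1
        tgt = torch.cat([tgt, next_tok], dim=1)     # append and repeat
        if (next_tok == eos_id).all():              # stop once every row emits [EOS]
            break
    return tgt
```

<p>Note that <code>encoder_output</code> is computed once and reused on every iteration, while the decoder is re-run per generated token.</p>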
</section>
<section id="what-transformers-actually-learn" class="level2">
<h2 class="anchored" data-anchor-id="what-transformers-actually-learn">15. What Transformers Actually Learn</h2>
<p>Understanding the architecture is one thing; understanding what trained Transformers actually compute is another. Here is a brief map of empirical findings.</p>
<section id="attention-head-specialization" class="level3">
<h3 class="anchored" data-anchor-id="attention-head-specialization">15.1 Attention Head Specialization</h3>
<p>Clark et al.&nbsp;(2019) systematically analyzed BERT’s attention patterns across all layers and heads and found striking specialization:</p>
<ul>
<li><strong>Syntactic dependency heads</strong>: Certain heads consistently attend from a token to its syntactic governor (the word it depends on), recovering dependency parse relationships with high accuracy — without ever being trained on parse labels.</li>
<li><strong>Positional heads</strong>: Some heads attend predominantly to adjacent tokens (the previous or next token), implementing local sliding-window attention.</li>
<li><strong><code>[SEP]</code> heads</strong>: Many heads in middle layers attend heavily to <code>[SEP]</code> tokens. The interpretation: when no strong relationship exists, these heads use <code>[SEP]</code> as a “garbage collector” — routing excess attention somewhere harmless.</li>
</ul>
<p>This specialization is <strong>emergent</strong>, not designed. It arises purely from the training signal on downstream tasks.</p>
</section>
<section id="ffn-layers-as-factual-memories" class="level3">
<h3 class="anchored" data-anchor-id="ffn-layers-as-factual-memories">15.2 FFN Layers as Factual Memories</h3>
<p>Geva et al.&nbsp;(2021) showed that FFN sublayers act as key-value memories. The first linear layer’s weight rows act as “keys” that activate on specific input patterns; the second linear layer’s corresponding columns act as “values” that are retrieved and output.</p>
<p>This framing explains where factual knowledge lives in a language model. When a model correctly completes “The Eiffel Tower is located in ___”, the relevant association (Eiffel Tower → Paris) is likely stored as a key-value pair in the FFN weights of one or more layers — not in the attention matrices.</p>
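<p>The key-value view is concrete enough to write down. In the toy sketch below (random weights, ReLU standing in for the usual activation), the FFN output is literally a weighted sum of the second matrix’s rows — the “values” — weighted by how strongly the input activated each “key”:</p>

```python
import torch

embed_dim, ffn_dim = 4, 3  # toy sizes

W1 = torch.randn(embed_dim, ffn_dim)  # column i acts as "key" i
W2 = torch.randn(ffn_dim, embed_dim)  # row i is the "value" paired with key i

x = torch.randn(embed_dim)            # one token's representation
coeffs = torch.relu(x @ W1)           # how strongly each key matches: (ffn_dim,)
output = coeffs @ W2                  # FFN output: (embed_dim,)

# The same computation written as explicit memory retrieval:
retrieved = sum(c * W2[i] for i, c in enumerate(coeffs))
print(torch.allclose(output, retrieved, atol=1e-5))  # True
```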
</section>
<section id="layer-depth-and-abstraction" class="level3">
<h3 class="anchored" data-anchor-id="layer-depth-and-abstraction">15.3 Layer Depth and Abstraction</h3>
<p>Probing classifiers — small models trained to predict linguistic properties from internal representations — consistently find that:</p>
<ul>
<li><strong>Early layers</strong> (1–4): Surface-level features — part-of-speech tags, token identity, local syntax.</li>
<li><strong>Middle layers</strong> (5–12): Syntactic structure, phrase-level groupings, coreference.</li>
<li><strong>Later layers</strong>: Task-specific, abstract semantic features.</li>
</ul>
<p>The architecture explains <em>why</em> this gradient exists. Early layers receive representations that have undergone very little contextualization — essentially just the token and positional embeddings. They can only access local, surface-level patterns. Later layers, on the other hand, are reading from a residual stream that has already accumulated many rounds of attention and FFN processing. Each layer builds on the contextualized representations produced by all previous layers, enabling increasingly abstract structures to emerge. The depth gradient is not a design choice — it is a direct consequence of how information accumulates through residual connections.</p>
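<p>For intuition, a probing classifier is nothing more than a frozen-representation experiment. The sketch below uses random tensors as stand-ins for a model’s layer activations and part-of-speech labels — the point is the shape of the experiment, not the data:</p>

```python
import torch
import torch.nn as nn

num_tokens, embed_dim, num_tags = 1000, 768, 17  # 17 ~ universal POS tag count

# Stand-ins: frozen activations from some layer, plus gold POS labels
layer_activations = torch.randn(num_tokens, embed_dim)
pos_tags = torch.randint(0, num_tags, (num_tokens,))

# The probe is deliberately tiny -- a single linear layer. If it predicts the
# property well from frozen activations, that information is present there.
probe = nn.Linear(embed_dim, num_tags)
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)
for _ in range(10):  # a few steps, just to show the loop
    loss = nn.functional.cross_entropy(probe(layer_activations), pos_tags)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

<p>Running the same probe against every layer’s activations, and comparing accuracies, is what produces the depth-vs-abstraction gradient described above.</p>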
</section>
</section>
<section id="modern-improvements" class="level2">
<h2 class="anchored" data-anchor-id="modern-improvements">16. Modern Improvements</h2>
<p>The original Transformer (2017) has been refined substantially. Here are the key improvements that appear in modern LLMs, with brief explanations of why each was adopted:</p>
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Improvement</th>
<th>What changes</th>
<th>Why</th>
<th>Used in</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Pre-LayerNorm</strong></td>
<td>LN moves inside the residual branch</td>
<td>Training stability at scale; no warm-up required</td>
<td>GPT-2, LLaMA, Mistral</td>
</tr>
<tr class="even">
<td><strong>Rotary Position Embedding (RoPE)</strong></td>
<td>Replaces absolute pos. embeddings with rotation of Q and K</td>
<td>Better length generalization; relative position naturally encoded at every layer</td>
<td>LLaMA, Mistral, GPT-NeoX, Qwen</td>
</tr>
<tr class="odd">
<td><strong>Grouped Query Attention (GQA)</strong></td>
<td>Multiple Q heads share a single K and V head</td>
<td>Reduces KV cache memory at inference without meaningful accuracy loss</td>
<td>LLaMA 2/3, Mistral</td>
</tr>
<tr class="even">
<td><strong>SwiGLU activation</strong></td>
<td>Replaces GELU in FFN with a gated linear unit: <img src="https://latex.codecogs.com/png.latex?%5Ctext%7BSwiGLU%7D(x)%20=%20%5Ctext%7BSwish%7D(xW_1)%20%5Codot%20xW_2"></td>
<td>Consistently higher benchmark performance at equivalent parameter counts</td>
<td>LLaMA, PaLM, Gemma</td>
</tr>
<tr class="odd">
<td><strong>FlashAttention</strong></td>
<td>Reorders attention computation to minimize memory bandwidth</td>
<td><img src="https://latex.codecogs.com/png.latex?O(N)"> memory instead of <img src="https://latex.codecogs.com/png.latex?O(N%5E2)">; 2–4x faster; exact attention, not an approximation</td>
<td>Used in most modern training stacks</td>
</tr>
<tr class="even">
<td><strong>RMSNorm</strong></td>
<td>Replaces LayerNorm with root-mean-square normalization (no mean subtraction)</td>
<td>Simpler, ~10% faster, equivalent quality</td>
<td>LLaMA, Mistral, Gemma</td>
</tr>
</tbody>
</table>
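<p>Two of these improvements are small enough to implement directly. A minimal sketch of RMSNorm and a SwiGLU FFN (hidden sizes are illustrative; LLaMA-style models choose the FFN width by a specific rule):</p>

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Rescale by root-mean-square only -- no mean subtraction, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLUFFN(nn.Module):
    """Gated FFN: Swish(x W1) gates the parallel projection x W2."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # down-projection

    def forward(self, x):
        return self.w3(nn.functional.silu(self.w1(x)) * self.w2(x))

x = torch.randn(1, 3, 768)
y = SwiGLUFFN(768, 2048)(RMSNorm(768)(x))
print(y.shape)  # torch.Size([1, 3, 768])
```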
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">17. Conclusion</h2>
<p>In this post, we built the Transformer architecture from scratch — starting from the failure modes of RNNs, building the attention mechanism step by step, implementing each component in PyTorch with annotated shapes, and assembling the encoder-only, decoder-only, and encoder-decoder variants. We also traced a complete end-to-end forward pass and surveyed what trained Transformers empirically learn.</p>
<p>The architecture’s dominance across language, vision, speech, and biology stems from a coherent set of design choices — each solving a specific problem with a specific mechanism.</p>
<section id="key-takeaways" class="level3">
<h3 class="anchored" data-anchor-id="key-takeaways">Key Takeaways</h3>
<ol type="1">
<li><p><strong>Attention replaces sequential recurrence with parallel direct communication.</strong> Every token attends to every other token in a single matrix operation. No hidden state bottleneck, no sequential dependency, no vanishing gradient through time — the fundamental failures of RNNs are eliminated at the architectural level, not patched over.</p></li>
<li><p><strong>Q, K, V separation is intentional, not arbitrary.</strong> What a token <em>wants</em> (query), what it <em>offers</em> (key), and what it <em>says</em> (value) are three genuinely different roles. Separating them — as in the library lookup analogy — gives the model the flexibility to learn very different relationships for each. A single projection would conflate all three.</p></li>
<li><p><strong>Multi-head attention gives the model multiple simultaneous perspectives.</strong> Each head operates in its own lower-dimensional subspace and learns to track different relationship types: one head for syntax, one for coreference, one for local context. This specialization is emergent — it arises from the training signal, not from any explicit design constraint.</p></li>
<li><p><strong>The FFN is the knowledge store; attention is the routing system.</strong> Attention decides which tokens talk to which and mixes their representations. The FFN then transforms each token’s representation independently and nonlinearly — this is where factual associations are stored. Without the FFN, stacked attention layers collapse to a single linear transformation.</p></li>
<li><p><strong>Skip connections and LayerNorm make depth trainable.</strong> Residual connections create gradient highways that bypass each sublayer entirely, making it possible to train networks dozens of layers deep. Pre-LayerNorm (inside the residual branch) stabilizes training at scale without requiring learning rate warm-up.</p></li>
<li><p><strong>Architecture determines what tokens can see; everything else is shared.</strong> The only fundamental difference between an encoder and a decoder is the causal mask. The same attention mechanism, FFN, LayerNorm, and residual structure underlies all three variants — encoder-only, decoder-only, and encoder-decoder — differing only in which tokens each position is allowed to attend to.</p></li>
</ol>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Core Architecture Is Stable
</div>
</div>
<div class="callout-body-container callout-body">
<p>Despite years of improvements — RoPE, GQA, SwiGLU, FlashAttention, RMSNorm — the fundamental architecture described in this post has not changed since 2017. The overall structure (attention + FFN + residual + norm, stacked <img src="https://latex.codecogs.com/png.latex?N"> times) is the same in GPT-4, LLaMA 3, and Gemini as it was in the original “Attention Is All You Need.” If you understand this post, you understand the backbone of essentially all modern AI.</p>
</div>
</div>
<p><strong>What to explore next:</strong></p>
<ul>
<li><a href="../../posts/nlp/GPT2-From-Scratch.html"><strong>Building GPT-2 from Scratch</strong></a> — takes the decoder-only architecture from this post and implements a full GPT-2 training run, including mixed precision, Flash Attention, and distributed training</li>
<li><a href="../../posts/nlp/BPE-Tokenizer.html"><strong>BPE Tokenizer from Scratch</strong></a> — implements the tokenizer that sits upstream of everything in this post</li>
<li><a href="../../posts/nlp/Tokenization-Strategies.html"><strong>Tokenization Strategies</strong></a> — compares character, word, and subword tokenization with code examples and real model outputs</li>
</ul>
</section>
</section>
<section id="references-resources" class="level2">
<h2 class="anchored" data-anchor-id="references-resources">References &amp; Resources</h2>
<ul>
<li>Vaswani et al.&nbsp;(2017). <a href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a>. — The original Transformer paper.</li>
<li>Devlin et al.&nbsp;(2018). <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>.</li>
<li>Ba et al.&nbsp;(2016). <a href="https://arxiv.org/abs/1607.06450">Layer Normalization</a>. — The original LayerNorm paper.</li>
<li>He et al.&nbsp;(2016). <a href="https://arxiv.org/abs/1512.03385">Deep Residual Learning for Image Recognition</a>. — Skip connections and loss landscape analysis.</li>
<li>Clark et al.&nbsp;(2019). <a href="https://arxiv.org/abs/1906.04341">What Does BERT Look At? An Analysis of BERT’s Attention</a>. — BERTology: what different attention heads learn.</li>
<li>Geva et al.&nbsp;(2021). <a href="https://arxiv.org/abs/2012.14913">Transformer Feed-Forward Layers Are Key-Value Memories</a>. — FFN layers as associative memory stores.</li>
<li>Su et al.&nbsp;(2021). <a href="https://arxiv.org/abs/2104.09864">RoFormer: Enhanced Transformer with Rotary Position Embedding</a>. — The RoPE paper.</li>
<li>Dao et al.&nbsp;(2022). <a href="https://arxiv.org/abs/2205.14135">FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness</a>.</li>
<li>Srivastava et al.&nbsp;(2014). <a href="https://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf">Dropout: A Simple Way to Prevent Neural Networks from Overfitting</a>.</li>
<li><a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html">The Annotated Transformer</a> — line-by-line walkthrough of the original paper’s code.</li>
<li><a href="https://github.com/karpathy/nanoGPT">Andrej Karpathy’s NanoGPT</a> — minimal, readable GPT implementation.</li>
<li><a href="https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/">Lilian Weng’s The Transformer Family v2.0</a> — comprehensive survey of Transformer variants.</li>
</ul>



</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/Transformer-Architecture-Explained.html</guid>
  <pubDate>Mon, 14 Feb 2022 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/transformer-arch.png" medium="image" type="image/png" height="96" width="144"/>
</item>
<item>
  <title>Inside LSTMs: Implementing and Optimizing Sequential Models from First Principles</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/nlp/LSTM-Annotated-Implementation.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="why-implement-an-lstm-from-scratch" class="level2">
<h2 class="anchored" data-anchor-id="why-implement-an-lstm-from-scratch">Why Implement an LSTM from Scratch?</h2>
<p>If you’ve used <code>nn.LSTM</code> in PyTorch, you’ve seen it work. But <em>how</em> does it decide what to remember and what to forget? Why does it need four gates instead of one? And why is it so much better than a vanilla RNN at handling long sequences?</p>
<p>The best way to answer these questions is to build one yourself. In this post, we’ll start with the problem that motivated LSTMs (vanishing gradients), build up the intuition for how they solve it, then implement both <code>LSTMCell</code> and a multi-layer <code>LSTM</code> from scratch in PyTorch — verifying each against the official implementation down to floating-point precision.</p>
</section>
<section id="the-vanishing-gradient-problem" class="level2">
<h2 class="anchored" data-anchor-id="the-vanishing-gradient-problem">The Vanishing Gradient Problem</h2>
<p><strong>Long Short-Term Memory (LSTM)</strong> is a recurrent neural network architecture introduced by <a href="https://www.bioinf.jku.at/publications/older/2604.pdf">Hochreiter and Schmidhuber (1997)</a> to solve the <strong>vanishing gradient problem</strong> — the central failure mode of vanilla RNNs on long sequences.</p>
<p>To understand why LSTMs exist, we first need to understand what goes wrong. In a vanilla RNN, the hidden state is <em>completely overwritten</em> at every time step:</p>
<p><img src="https://latex.codecogs.com/png.latex?h_t%20=%20%5Ctanh(W_%7Bhh%7D%20%5Ccdot%20h_%7Bt-1%7D%20+%20W_%7Bxh%7D%20%5Ccdot%20x_t%20+%20b)"></p>
<p>During backpropagation, the gradient of the loss with respect to an early hidden state <img src="https://latex.codecogs.com/png.latex?h_1"> must pass through the <img src="https://latex.codecogs.com/png.latex?%5Ctanh"> nonlinearity and the weight matrix <img src="https://latex.codecogs.com/png.latex?W_%7Bhh%7D"> at <em>every single time step</em> between <img src="https://latex.codecogs.com/png.latex?h_T"> and <img src="https://latex.codecogs.com/png.latex?h_1">. If the sequence has 100 tokens, the gradient is multiplied by <img src="https://latex.codecogs.com/png.latex?W_%7Bhh%7D"> roughly 100 times. If the dominant eigenvalue of <img src="https://latex.codecogs.com/png.latex?W_%7Bhh%7D"> is even slightly less than 1 — say 0.9 — the gradient shrinks by a factor of <img src="https://latex.codecogs.com/png.latex?0.9%5E%7B100%7D%20%5Capprox%200.00003">. The signal from early tokens effectively disappears.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Fundamental Issue
</div>
</div>
<div class="callout-body-container callout-body">
<p>The problem isn’t just mathematical — it has a concrete consequence: <strong>vanilla RNNs can’t learn long-range dependencies</strong>. If the answer to a question depends on a word 50 tokens earlier in the sentence, the gradient signal connecting them is essentially zero. The model can’t learn that relationship, no matter how long you train.</p>
</div>
</div>
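<p>The shrinkage argument is easy to check numerically. A toy sketch: scale a random matrix so its spectral radius is roughly 0.9, then watch what 100 repeated multiplications do to a vector (a stand-in for the backward-flowing gradient):</p>

```python
import torch

torch.manual_seed(0)
hidden = 64
W = torch.randn(hidden, hidden)
W = W * (0.9 / torch.linalg.eigvals(W).abs().max())  # spectral radius ~= 0.9

g = torch.ones(hidden)  # stand-in for a gradient flowing backward in time
norms = []
for _ in range(100):
    g = W @ g
    norms.append(g.norm().item())

# After 100 steps the norm has collapsed by orders of magnitude,
# mirroring the 0.9^100 ~= 0.00003 factor discussed above
print(norms[0], norms[-1])
```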
</section>
<section id="how-lstms-fix-it" class="level2">
<h2 class="anchored" data-anchor-id="how-lstms-fix-it">How LSTMs Fix It</h2>
<p>The LSTM introduces a <strong>cell state</strong> <img src="https://latex.codecogs.com/png.latex?c_t"> — a separate memory channel that runs parallel to the hidden state. The critical difference is in <em>how</em> it gets updated:</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Vanilla RNN</th>
<th>LSTM Cell State</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Update rule</strong></td>
<td><img src="https://latex.codecogs.com/png.latex?h_t%20=%20%5Ctanh(W%20%5Ccdot%20h_%7Bt-1%7D%20+%20%5Cldots)"></td>
<td><img src="https://latex.codecogs.com/png.latex?c_t%20=%20f_t%20%5Codot%20c_%7Bt-1%7D%20+%20i_t%20%5Codot%20g_t"></td>
</tr>
<tr class="even">
<td><strong>Mechanism</strong></td>
<td>Complete <em>replacement</em> through nonlinearity</td>
<td>Selective <em>modification</em> via additive gating</td>
</tr>
<tr class="odd">
<td><strong>Gradient flow</strong></td>
<td>Must pass through <img src="https://latex.codecogs.com/png.latex?%5Ctanh"> and <img src="https://latex.codecogs.com/png.latex?W"> at every step</td>
<td>Can flow <em>directly</em> through the forget gate <img src="https://latex.codecogs.com/png.latex?f_t"></td>
</tr>
<tr class="even">
<td><strong>Long-range memory</strong></td>
<td>Exponential decay</td>
<td>Controlled retention</td>
</tr>
</tbody>
</table>
<p>The cell state update is <strong>additive</strong>: when the forget gate <img src="https://latex.codecogs.com/png.latex?f_t"> is close to 1 and the input gate <img src="https://latex.codecogs.com/png.latex?i_t"> is close to 0, the cell state passes through <em>unchanged</em>: <img src="https://latex.codecogs.com/png.latex?c_t%20%5Capprox%20c_%7Bt-1%7D">. Gradients flow backward through time with minimal decay — no weight matrix or nonlinearity in the way.</p>
<p>If this looks familiar, it should — it’s the same principle behind <strong>residual connections</strong> in ResNets. In a ResNet, each layer computes <img src="https://latex.codecogs.com/png.latex?y%20=%20F(x)%20+%20x">: the input passes through unchanged, and the layer only learns the <em>residual</em>. The LSTM cell state works the same way, but across <strong>time instead of depth</strong>: the previous cell state passes through (scaled by <img src="https://latex.codecogs.com/png.latex?f_t">), and the network adds a residual update (<img src="https://latex.codecogs.com/png.latex?i_t%20%5Codot%20g_t">). Both create a gradient highway. ResNets made it possible to train 100+ layer networks; the LSTM cell state makes it possible to learn dependencies across 100+ time steps. Same insight, different axis.</p>
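<p>The full update fits in a few lines. This sketch follows PyTorch’s gate ordering (input, forget, cell, output) and packs all four gate computations into one matrix multiply; the parameter names are illustrative:</p>

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W_x, W_h, b):
    """One LSTM time step. x_t: (B, input_sz); h_prev, c_prev: (B, hidden_sz);
    W_x: (input_sz, 4*hidden_sz); W_h: (hidden_sz, 4*hidden_sz); b: (4*hidden_sz,)."""
    gates = x_t @ W_x + h_prev @ W_h + b             # all four gates at once
    i, f, g, o = gates.chunk(4, dim=-1)              # PyTorch's i, f, g, o ordering
    i, f, o = i.sigmoid(), f.sigmoid(), o.sigmoid()  # gates squashed into (0, 1)
    g = g.tanh()                                     # candidate values in (-1, 1)

    c_t = f * c_prev + i * g   # ADDITIVE update: the gradient highway
    h_t = o * c_t.tanh()       # exposed output, read selectively from c_t
    return h_t, c_t
```

<p>Setting <code>f ≈ 1</code> and <code>i ≈ 0</code> in the last two lines reproduces <code>c_t ≈ c_prev</code> — the unchanged pass-through described above.</p>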
<p align="center">
<img src="https://imaddabbura.github.io/posts/nlp/images/lstm-cell.jpeg" style="width: 500px;"><br>

</p><center>
<u><b><font color="00b7e4">Figure 1:</font></b></u> The LSTM cell. The horizontal line at the top is the cell state — the “highway” through time. The four yellow boxes (<img src="https://latex.codecogs.com/png.latex?%5Csigma,%20%5Csigma,%20%5Ctanh,%20%5Csigma">) are the forget, input, cell, and output gates respectively. The cell state is updated additively (the ⊕ node), while the gates use element-wise multiplication (⊗) to control information flow.
</center>

<p></p>
<section id="why-two-states" class="level3">
<h3 class="anchored" data-anchor-id="why-two-states">Why Two States?</h3>
<p>A vanilla RNN has a single hidden state that must do <em>everything</em>: store long-term memory, carry short-term context, and produce the output that downstream layers consume. That’s too many jobs for one vector — optimizing the hidden state for the current prediction destroys the long-term information stored in it.</p>
<p>LSTMs split this into two specialized roles:</p>
<p><strong>Cell state (<img src="https://latex.codecogs.com/png.latex?c_t">): the long-term internal memory.</strong> The cell state is the LSTM’s private memory — never directly exposed to the rest of the network. Its job is to <em>retain information across long distances</em> without interference. Because it’s updated additively, gradients can flow through it across hundreds of time steps. Think of it as a notebook that the LSTM writes to and reads from, but never shows to anyone directly.</p>
<p><strong>Hidden state (<img src="https://latex.codecogs.com/png.latex?h_t">): the short-term working output.</strong> The hidden state is what the LSTM <em>exposes</em> to the outside world — the input to the next layer, the softmax, or whatever comes next. It’s computed by selectively reading from the cell state via the output gate: <img src="https://latex.codecogs.com/png.latex?h_t%20=%20o_t%20%5Codot%20%5Ctanh(c_t)">. The output gate decides: <em>“Given everything I know and the current context, what’s relevant right now?”</em></p>
<p>This separation is crucial. The cell state can hold information like “the subject is plural” or “we’re inside a quotation” for as long as needed, without being distorted by the demands of predicting intermediate tokens. When it <em>is</em> needed, the output gate reads it out at exactly the right moment.</p>
<table class="table">
<colgroup>
<col style="width: 33%">
<col style="width: 33%">
<col style="width: 33%">
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Cell State (<img src="https://latex.codecogs.com/png.latex?c_t">)</th>
<th>Hidden State (<img src="https://latex.codecogs.com/png.latex?h_t">)</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><strong>Role</strong></td>
<td>Long-term memory</td>
<td>Short-term working output</td>
</tr>
<tr class="even">
<td><strong>Visible to</strong></td>
<td>Only the LSTM itself (internal)</td>
<td>Next layer, softmax, classifier (external)</td>
</tr>
<tr class="odd">
<td><strong>Updated by</strong></td>
<td>Forget gate (erase) + input gate (write)</td>
<td>Output gate reading from cell state</td>
</tr>
<tr class="even">
<td><strong>Gradient flow</strong></td>
<td>Additive — gradients pass through cleanly</td>
<td>Through tanh and output gate — more lossy</td>
</tr>
<tr class="odd">
<td><strong>Analogy</strong></td>
<td>A notebook you write in privately</td>
<td>The answer you speak aloud when asked</td>
</tr>
</tbody>
</table>
</section>
<section id="a-concrete-example" class="level3">
<h3 class="anchored" data-anchor-id="a-concrete-example">A Concrete Example</h3>
<p>Consider: <em>“The cat, which sat on the mat in the living room near the window overlooking the garden, <strong>was</strong> sleeping.”</em> The verb “was” must agree with “cat” (singular), not “garden” or “window” — a dependency spanning ~15 tokens. A vanilla RNN’s gradient signal from “was” back to “cat” would be multiplied by <img src="https://latex.codecogs.com/png.latex?W_%7Bhh%7D"> fifteen times — likely vanishing. An LSTM can keep “cat = singular noun” in its cell state with the forget gate near 1, preserving the information until it’s needed at “was.”</p>
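<p>This gradient collapse is easy to simulate. In the sketch below (the weight scale and forget-gate value are illustrative, not trained), a gradient vector pushed backward through fifteen multiplicative steps shrinks toward zero, while the same vector carried along an additive, gated path barely changes:</p>

```python
import torch

torch.manual_seed(0)
hidden_sz = 100

# A recurrent weight matrix with deliberately small random entries.
W_hh = 0.02 * torch.randn(hidden_sz, hidden_sz)

# Multiplicative path (vanilla RNN): the gradient from "was" back to "cat"
# is hit by W_hh^T once per intervening token.
grad = torch.ones(hidden_sz)
for _ in range(15):
    grad = W_hh.T @ grad
mult_norm = grad.norm().item()

# Additive path (LSTM cell state): with the forget gate held near 1,
# each step only scales the gradient by f_t.
grad = torch.ones(hidden_sz)
for _ in range(15):
    grad = 0.99 * grad
add_norm = grad.norm().item()

print(f"multiplicative path: {mult_norm:.2e}")  # shrinks toward zero
print(f"additive path:       {add_norm:.2e}")   # stays near its initial norm
```

<p>The exact numbers depend on the weight scale, but the qualitative gap is the point: repeated multiplication compounds, addition does not.</p>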
<p>One important constraint: RNNs and LSTMs are <strong>sequential models</strong> — the output at time <img src="https://latex.codecogs.com/png.latex?t"> depends on the hidden state from <img src="https://latex.codecogs.com/png.latex?t-1">. We cannot parallelize across time steps; we must iterate one token at a time. This is the limitation that the Transformer (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al., 2017</a>) later addressed with self-attention.</p>
</section>
</section>
<section id="inside-the-lstm-cell" class="level2">
<h2 class="anchored" data-anchor-id="inside-the-lstm-cell">Inside the LSTM Cell</h2>
<p>An <code>LSTMCell</code> computes four gates, then uses them to update the cell and hidden states. Each gate has the same dimension as the hidden state:</p>
<img src="https://latex.codecogs.com/png.latex?%5Cbegin%7Barray%7D%7Bll%7D%20%5C%5C%0Ai_t%20=%20%5Csigma(W_%7Bii%7D%20x_t%20+%20b_%7Bii%7D%20+%20W_%7Bhi%7D%20h_%7Bt-1%7D%20+%20b_%7Bhi%7D)%20%5C%5C%0Af_t%20=%20%5Csigma(W_%7Bif%7D%20x_t%20+%20b_%7Bif%7D%20+%20W_%7Bhf%7D%20h_%7Bt-1%7D%20+%20b_%7Bhf%7D)%20%5C%5C%0Ag_t%20=%20%5Ctanh(W_%7Big%7D%20x_t%20+%20b_%7Big%7D%20+%20W_%7Bhg%7D%20h_%7Bt-1%7D%20+%20b_%7Bhg%7D)%20%5C%5C%0Ao_t%20=%20%5Csigma(W_%7Bio%7D%20x_t%20+%20b_%7Bio%7D%20+%20W_%7Bho%7D%20h_%7Bt-1%7D%20+%20b_%7Bho%7D)%20%5C%5C%0Ac_t%20=%20f_t%20%5Codot%20c_%7Bt-1%7D%20+%20i_t%20%5Codot%20g_t%20%5C%5C%0Ah_t%20=%20o_t%20%5Codot%20%5Ctanh(c_t)%20%5C%5C%0A%5Cend%7Barray%7D">
<table class="table">
<colgroup>
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
<col style="width: 25%">
</colgroup>
<thead>
<tr class="header">
<th>Gate</th>
<th>Name</th>
<th>Activation</th>
<th>What It Does</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?i_t"></td>
<td><strong>Input gate</strong></td>
<td>Sigmoid (0–1)</td>
<td>How much of the <em>new</em> candidate values to write into the cell</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?f_t"></td>
<td><strong>Forget gate</strong></td>
<td>Sigmoid (0–1)</td>
<td>How much of the <em>old</em> cell state to keep (1 = remember everything, 0 = forget everything)</td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?g_t"></td>
<td><strong>Cell gate</strong></td>
<td>Tanh (-1 to 1)</td>
<td>The candidate new values to potentially add to the cell state</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?o_t"></td>
<td><strong>Output gate</strong></td>
<td>Sigmoid (0–1)</td>
<td>How much of the cell state to expose as the hidden state output</td>
</tr>
</tbody>
</table>
<p>Notice the activation functions: three gates use <strong>sigmoid</strong>, but the cell gate uses <strong>tanh</strong>. This isn’t arbitrary — it reflects their different roles. The sigmoid gates (<img src="https://latex.codecogs.com/png.latex?i_t,%20f_t,%20o_t">) answer <em>“how much?”</em> questions: how much to write, how much to keep, how much to expose. Sigmoid squashes values to (0, 1), making each gate a dimmer switch that scales its input between “fully off” and “fully on.” The cell gate <img src="https://latex.codecogs.com/png.latex?g_t"> answers a different question: <em>“what values?”</em> It proposes candidate content to write into the cell state. Tanh maps to (-1, 1), which is critical — it allows the cell state to both <strong>increase and decrease</strong>. If <img src="https://latex.codecogs.com/png.latex?g_t"> used sigmoid (0, 1), the additive update <img src="https://latex.codecogs.com/png.latex?i_t%20%5Codot%20g_t"> could only ever push the cell state upward, and it would grow without bound. Tanh lets the network write negative corrections, keeping the cell state centered and bounded.</p>
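<p>A small simulation makes the asymmetry concrete. Holding the gates at fixed, hypothetical values (forget gate fully open, input gate at 0.5) and feeding random pre-activations, a sigmoid candidate can only ratchet the cell state upward, while tanh writes largely cancel out:</p>

```python
import torch

torch.manual_seed(0)

# Hypothetical pre-activations for the candidate gate over 200 steps.
pre_act = torch.randn(200)

# Hold the forget gate fully open (f = 1) and the input gate at 0.5,
# isolating the effect of the candidate's activation function.
f, i = 1.0, 0.5

c_sigmoid = 0.0
c_tanh = 0.0
for z in pre_act:
    c_sigmoid = f * c_sigmoid + i * torch.sigmoid(z).item()  # every write is positive
    c_tanh = f * c_tanh + i * torch.tanh(z).item()           # writes can cancel out

print(f"cell state, sigmoid candidate: {c_sigmoid:+.1f}")  # drifts steadily upward
print(f"cell state, tanh candidate:    {c_tanh:+.1f}")     # hovers near zero
```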
<section id="independent-gates-four-operating-modes" class="level3">
<h3 class="anchored" data-anchor-id="independent-gates-four-operating-modes">Independent Gates: Four Operating Modes</h3>
<p>A critical design choice is that the input gate and forget gate are <strong>completely independent</strong> — computed from separate weight matrices and biases, with nothing constraining them to sum to 1. The network is free to set both high, both low, or any combination.</p>
<p>Contrast this with the GRU (Gated Recurrent Unit), where the equivalent gates <em>are</em> complementary: a single update gate <img src="https://latex.codecogs.com/png.latex?z_t"> weights new content by <img src="https://latex.codecogs.com/png.latex?z_t"> and old content by <img src="https://latex.codecogs.com/png.latex?(1%20-%20z_t)">, forcing a trade-off. The GRU is more parameter-efficient, but less expressive — it can only interpolate between “keep old” and “write new.”</p>
<p>The LSTM’s independence gives it four distinct operating modes:</p>
<table class="table">
<colgroup>
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
<col style="width: 20%">
</colgroup>
<thead>
<tr class="header">
<th>Forget <img src="https://latex.codecogs.com/png.latex?f_t"></th>
<th>Input <img src="https://latex.codecogs.com/png.latex?i_t"></th>
<th>Mode</th>
<th>Effect</th>
<th>When It’s Useful</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%201"></td>
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%201"></td>
<td><strong>Accumulate</strong></td>
<td>Keep old state <em>and</em> write new info</td>
<td>Building up a running representation (e.g., accumulating features of a described entity)</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%200"></td>
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%201"></td>
<td><strong>Replace</strong></td>
<td>Flush old state, write new info</td>
<td>Topic change, sentence boundary — start fresh with new content</td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%201"></td>
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%200"></td>
<td><strong>Preserve</strong></td>
<td>Keep old state, ignore current input</td>
<td>Carrying information across irrelevant tokens (e.g., remembering subject across a parenthetical)</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%200"></td>
<td><img src="https://latex.codecogs.com/png.latex?%5Capprox%200"></td>
<td><strong>Reset</strong></td>
<td>Forget old state <em>and</em> ignore input</td>
<td>Clearing a dimension that’s no longer needed</td>
</tr>
</tbody>
</table>
<p>The GRU can express only the Replace and Preserve rows of this table, since its single update gate forces forgetting and writing to trade off against each other. This is why LSTMs tend to outperform GRUs on tasks requiring long-range memory: the accumulate mode lets information persist indefinitely while still absorbing new inputs, and the reset mode provides a clean mechanism for freeing capacity.</p>
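<p>All four modes fall directly out of the cell update <code>c_t = f * c_prev + i * g</code>. A one-dimensional sketch with hand-picked gate values:</p>

```python
import torch

# One cell-state dimension, shown across the four gate regimes.
c_prev = torch.tensor(0.8)   # existing memory
g = torch.tensor(-0.5)       # candidate value proposed from the current input

modes = {
    "accumulate": (1.0, 1.0),  # keep old AND write new
    "replace":    (0.0, 1.0),  # flush old, write new
    "preserve":   (1.0, 0.0),  # keep old, ignore input
    "reset":      (0.0, 0.0),  # clear everything
}

results = {}
for name, (f, i) in modes.items():
    # The LSTM cell update: c_t = f * c_{t-1} + i * g
    results[name] = (f * c_prev + i * g).item()

for name, c_t in results.items():
    print(f"{name:>10}: c_t = {c_t:+.2f}")
```

<p>Note that “accumulate” is the one combination a complementary-gate design cannot produce: the old value survives <em>and</em> the new value is added.</p>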
</section>
<section id="gates-as-learned-pattern-detectors" class="level3">
<h3 class="anchored" data-anchor-id="gates-as-learned-pattern-detectors">Gates as Learned Pattern Detectors</h3>
<p>It’s tempting to think of gates as simple switches, but each gate is a <strong>learned pattern detector</strong>. Just as a CNN filter activates when an input patch matches its learned visual pattern, a gate’s weight matrix learns to activate on specific <em>contextual patterns</em>: it produces a high output (close to 1 after the sigmoid) when the combination of <img src="https://latex.codecogs.com/png.latex?x_t"> and <img src="https://latex.codecogs.com/png.latex?h_%7Bt-1%7D"> matches its learned pattern. The difference is the domain: CNN filters detect <em>spatial</em> patterns in pixel neighborhoods, while gate weights detect <em>contextual</em> patterns across the current token and the sequence history.</p>
<p>Consider the forget gate: <img src="https://latex.codecogs.com/png.latex?f_t%20=%20%5Csigma(W_%7Bif%7D%20%5Ccdot%20x_t%20+%20W_%7Bhf%7D%20%5Ccdot%20h_%7Bt-1%7D%20+%20b_f)">. After training, specific rows of these weight matrices become specialized detectors:</p>
<ul>
<li>Some rows might detect <strong>“end of clause”</strong> patterns (a period, “but”) — signaling that old context should be flushed</li>
<li>Other rows might detect <strong>“continuation”</strong> patterns (a comma, “which”) — signaling that existing context should be preserved</li>
<li>Rows in the input gate might detect <strong>“salient new information”</strong> patterns (a named entity, a negation word) — signaling that this input should be written into memory</li>
</ul>
<p>This happens <strong>per dimension</strong> of the hidden state. The gate output is a vector, not a scalar — dimension 42 of the forget gate might be close to 0 (forget) while dimension 73 is close to 1 (keep), because each dimension stores different information and each gate dimension detects different patterns.</p>
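<p>A tiny illustration with hand-picked values: a forget-gate <em>vector</em> keeps some dimensions nearly intact and wipes others in a single elementwise update:</p>

```python
import torch

# A 4-dimensional cell state, gated by a forget-gate *vector* (not a scalar):
# each dimension is kept or dropped independently.
c_prev = torch.tensor([0.9, -0.4, 0.7, 0.2])
f_t = torch.tensor([0.98, 0.02, 0.95, 0.01])  # per-dimension keep/forget decisions

# Ignoring the write term (i_t ~= 0), the update is purely elementwise:
c_t = f_t * c_prev
print(c_t)  # dims 0 and 2 survive almost intact; dims 1 and 3 are wiped
```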
<div class="callout callout-style-default callout-tip callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Single-Matrix Trick
</div>
</div>
<div class="callout-body-container callout-body">
<p>Even though we describe four separate gates, in practice we compute them all in <strong>one matrix multiplication</strong> by concatenating the four weight matrices into a single <code>4 * hidden_size</code> matrix. We then split the result into four chunks. This is much faster because it replaces four small matmuls with one large one — better utilizing GPU parallelism and memory bandwidth.</p>
</div>
</div>
</section>
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation">Implementation</h2>
<p>With the conceptual foundation in place, let’s turn these equations into code. We’ll build two modules — <code>LSTMCell</code> (one time step) and <code>LSTM</code> (full sequences with multiple layers) — verifying each against PyTorch’s official implementation.</p>
<section id="lstmcell" class="level3">
<h3 class="anchored" data-anchor-id="lstmcell"><code>LSTMCell</code></h3>
<p>We implement two versions: a verbose one that makes every operation explicit (separate weight matrices for each gate), and a compact one using <code>nn.Linear</code> with the single-matrix trick. Both produce identical results — the compact version is what you’d use in practice.</p>
<div id="4f325555-fdab-47c3-94a6-38a4a2e6dd11" class="cell" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> nn</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> torch.nn.functional <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> F</span></code></pre></div>
</details>
</div>
<div id="e7aa6509-79db-4a54-9d47-e8c656fa28d7" class="cell" data-execution_count="2">
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Long version</span></span>
<span id="cb2-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> LSTMCellNew(nn.Module):</span>
<span id="cb2-3">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, input_sz, hidden_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb2-4">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb2-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.weight_ih <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Parameter(torch.randn((input_sz, hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>)))</span>
<span id="cb2-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.weight_hh <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Parameter(torch.randn((hidden_sz, hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>)))</span>
<span id="cb2-7">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bias_ih <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Parameter(torch.zeros(hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb2-8">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bias_hh <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Parameter(torch.zeros(hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>))</span>
<span id="cb2-9"></span>
<span id="cb2-10">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x, h, c):</span>
<span id="cb2-11">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## B x hidden_sz</span></span>
<span id="cb2-12">        out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.weight_ih <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> h <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">@</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.weight_hh <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bias_ih <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.bias_hh</span>
<span id="cb2-13">        i, f, g, o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.chunk(out, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-14">        i, f, o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)</span>
<span id="cb2-15">        g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tanh(g)</span>
<span id="cb2-16">        c_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> f <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> g</span>
<span id="cb2-17">        h_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> torch.tanh(c_t)</span>
<span id="cb2-18">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> h_t, c_t</span></code></pre></div>
</div>
<div id="d5630f35-d6a3-435a-8a8d-0ae5c2b87601" class="cell" data-execution_count="3">
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Short version utilizing linear layer module</span></span>
<span id="cb3-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> LSTMCellNew(nn.Module):</span>
<span id="cb3-3">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, input_sz, hidden_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>):</span>
<span id="cb3-4">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb3-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ih <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(input_sz, hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>bias)</span>
<span id="cb3-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.hh <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.Linear(hidden_sz, hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>bias)</span>
<span id="cb3-7"></span>
<span id="cb3-8">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x, h, c):</span>
<span id="cb3-9">        out <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.ih(x) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.hh(h)</span>
<span id="cb3-10">        i, f, g, o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.chunk(out, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, dim<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb3-11">        i, f, o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)</span>
<span id="cb3-12">        g <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.tanh(g)</span>
<span id="cb3-13">        c_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> f <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> g</span>
<span id="cb3-14">        h_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> o <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> torch.tanh(c_t)</span>
<span id="cb3-15">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> h_t, c_t</span></code></pre></div>
</div>
<div id="2c9e3a64-ba74-4c22-84ad-77d212fe2f31" class="cell" data-execution_count="4">
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1">batch_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">64</span></span>
<span id="cb4-2">seq_len <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span></span>
<span id="cb4-3">input_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span></span>
<span id="cb4-4">hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span></span>
<span id="cb4-5">num_layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span></span></code></pre></div>
</div>
<div id="18f4dcb4-49b0-4f90-bf15-961483ed0471" class="cell" data-execution_count="5">
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(seq_len, batch_sz, input_sz, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.float32)</span>
<span id="cb5-2">c_0 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(num_layers, batch_sz, hidden_sz, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.float32)</span>
<span id="cb5-3">h_0 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.randn(num_layers, batch_sz, hidden_sz, dtype<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>torch.float32)</span></code></pre></div>
</div>
<div id="13915d7b-7746-480b-8bfe-85d2000c21a1" class="cell" data-execution_count="6">
<div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1">pytorch_cell <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LSTMCell(input_sz, hidden_sz, bias<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">True</span>)</span>
<span id="cb6-2">(</span>
<span id="cb6-3">    pytorch_cell.weight_hh.shape,</span>
<span id="cb6-4">    pytorch_cell.weight_ih.shape,</span>
<span id="cb6-5">    pytorch_cell.bias_ih.shape,</span>
<span id="cb6-6">    pytorch_cell.bias_hh.shape,</span>
<span id="cb6-7">)</span></code></pre></div>
<div class="cell-output cell-output-display" data-execution_count="6">
<pre><code>(torch.Size([400, 100]),
 torch.Size([400, 20]),
 torch.Size([400]),
 torch.Size([400]))</code></pre>
</div>
</div>
<div id="36dbbe52-f7c4-4dff-9361-6757dffbd917" class="cell" data-execution_count="7">
<div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## h: B x hidden_sz</span></span>
<span id="cb8-2"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## c: B x hidden_sz</span></span>
<span id="cb8-3">pytorch_h, pytorch_c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_cell(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], (h_0[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], c_0[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]))</span></code></pre></div>
</div>
<div id="9f7dd706-59f6-4162-b00c-8c3610c573d3" class="cell" data-execution_count="8">
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1">cell <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> LSTMCellNew(input_sz, hidden_sz)</span>
<span id="cb9-2"></span>
<span id="cb9-3"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## To make sure pytorch and our implementation both</span></span>
<span id="cb9-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## have the same weights so we can compare them</span></span>
<span id="cb9-5">cell.ih.weight.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_cell.weight_ih.data</span>
<span id="cb9-6">cell.hh.weight.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_cell.weight_hh.data</span>
<span id="cb9-7">cell.ih.bias.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_cell.bias_ih.data</span>
<span id="cb9-8">cell.hh.bias.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_cell.bias_hh.data</span></code></pre></div>
</div>
<div id="1a03238d-994b-427e-92bd-7fb1abff8b62" class="cell" data-execution_count="9">
<div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1">h_t, c_t <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> cell(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], h_0[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], c_0[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span></code></pre></div>
</div>
<div id="0bc1356c-0d3e-4e79-a174-0b717407882c" class="cell" data-execution_count="10">
<div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(</span>
<span id="cb11-2">    np.linalg.norm(pytorch_h.detach().numpy() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> h_t.detach().numpy()),</span>
<span id="cb11-3">    np.linalg.norm(pytorch_c.detach().numpy() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> c_t.detach().numpy()),</span>
<span id="cb11-4">)</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>0.0 0.0</code></pre>
</div>
</div>
</section>
<section id="from-cell-to-sequence-the-full-lstm" class="level3">
<h3 class="anchored" data-anchor-id="from-cell-to-sequence-the-full-lstm">From Cell to Sequence: The Full <code>LSTM</code></h3>
<p>With <code>LSTMCell</code> verified, let’s build the full <code>LSTM</code> module that handles entire sequences and optionally stacks multiple layers.</p>
<p>There are several important design decisions in a production LSTM implementation:</p>
<p><strong>Memory layout: sequence-first (<code>T × B × D</code>).</strong> We use the sequence length as the first dimension instead of batch-first. Why? We iterate over time steps in the inner loop, and we want each <code>x[t]</code> to be a contiguous slice of memory. If batch were first, each time step’s data would be non-contiguous, requiring a copy on every iteration.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
The Contiguity Trap
</div>
</div>
<div class="callout-body-container callout-body">
<p>If you pass batch-first tensors (<code>B × T × D</code>) to an LSTM that expects sequence-first, it will still “work” — but each time step access triggers an implicit copy because the memory isn’t contiguous along the time dimension. This can silently slow down training. PyTorch’s <code>nn.LSTM</code> has a <code>batch_first</code> flag that handles the transpose for you, but internally it still processes sequence-first.</p>
</div>
</div>
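<p>The contiguity claim is easy to check directly. NumPy arrays share the row-major memory model of PyTorch tensors, so a minimal sketch (the sizes here are made up for illustration):</p>

```python
import numpy as np

T, B, D = 16, 8, 32  # illustrative sizes

# Sequence-first layout: each time step x[t] is one contiguous block.
x_seq_first = np.zeros((T, B, D), dtype=np.float32)
step_seq = x_seq_first[0]                  # a view, no copy needed
assert step_seq.flags['C_CONTIGUOUS']

# Batch-first layout: a time-step slice strides across the batch
# dimension, so it is NOT contiguous and must be copied before use.
x_batch_first = np.zeros((B, T, D), dtype=np.float32)
step_batch = x_batch_first[:, 0]           # a strided view
assert not step_batch.flags['C_CONTIGUOUS']
```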
<p><strong>Truncated Backpropagation Through Time (TBPTT).</strong> Since weights are shared across all time steps within a layer, backpropagating through very long sequences causes severe vanishing/exploding gradients <em>and</em> extreme memory usage (all intermediate activations must be stored). The standard solution: <strong>detach</strong> the hidden and cell states from the computation graph after each batch. Gradients can flow within a batch’s time steps but not across batch boundaries.</p>
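<p>A minimal sketch of the detach pattern, assuming a toy <code>nn.LSTMCell</code> and random data rather than the modules built in this post:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=8)
h = torch.zeros(2, 8)   # B x H initial hidden state
c = torch.zeros(2, 8)   # B x H initial cell state

for batch in range(3):                    # a stream of consecutive batches
    x = torch.randn(5, 2, 4)              # one T x B x D chunk
    for t in range(x.shape[0]):
        h, c = cell(x[t], (h, c))
    h.sum().backward()                    # gradients stop at the last detach
    # Truncated BPTT: keep the values, cut the graph at the batch boundary.
    h, c = h.detach(), c.detach()

assert not h.requires_grad                # carried state holds no graph
```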
<p><strong>Multi-layer stacking.</strong> We can stack LSTMs by feeding the hidden state output of layer <img src="https://latex.codecogs.com/png.latex?l"> as the input to layer <img src="https://latex.codecogs.com/png.latex?l+1">. Each layer has its own <code>LSTMCell</code> with independent weights. The first layer’s cell takes input of size <code>input_sz</code>; all subsequent layers take input of size <code>hidden_sz</code>. This increases model capacity — deeper layers can learn more abstract representations.</p>
<p><strong>Layer iteration order.</strong> With multiple layers, there are two valid iteration orders: (1) iterate all time steps for layer 0, then all time steps for layer 1, etc., or (2) at each time step, iterate through all layers before moving to the next time step. Our implementation uses option (1), which is simpler and matches PyTorch’s behavior.</p>
<p><strong>Handling variable-length sequences.</strong> Not all sequences have the same length. Two approaches:</p>
<ol type="1">
<li><strong>Padding</strong>: pad shorter sequences to the longest length with zeros (pre- or post-padding). Simple but wasteful — the model does unnecessary computation on padding tokens.</li>
<li><strong>Packed sequences</strong>: combine all sequences together with index metadata marking boundaries. More efficient but more complex to implement. PyTorch provides <code>pack_padded_sequence</code> and <code>pad_packed_sequence</code> utilities for this.</li>
</ol>
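<p>The second approach can be sketched with PyTorch's packing utilities on three toy sequences of different lengths (the data here is random, for illustration only):</p>

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Three sequences of lengths 4, 2, 1, post-padded with zeros to T=4.
lengths = torch.tensor([4, 2, 1])
padded = torch.zeros(4, 3, 5)             # T x B x D, sequence-first
for b, L in enumerate(lengths):
    padded[:L, b] = torch.randn(L, 5)

packed = pack_padded_sequence(padded, lengths, enforce_sorted=True)
# Only sum(lengths) = 7 real steps are stored; padding is skipped entirely.
assert packed.data.shape[0] == int(lengths.sum())

# Round-trip back to a padded tensor.
unpacked, out_lengths = pad_packed_sequence(packed)
assert unpacked.shape == padded.shape
assert torch.equal(out_lengths, lengths)
```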
<div id="d7691b4e-235e-434c-b22f-a5130c6ad864" class="cell" data-execution_count="11">
<div class="sourceCode cell-code" id="cb13" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb13-1"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> LSTMNew(nn.Module):</span>
<span id="cb13-2">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, input_sz, hidden_sz, num_layers<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>):</span>
<span id="cb13-3">        <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">super</span>().<span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>()</span>
<span id="cb13-4">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.num_layers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> num_layers</span>
<span id="cb13-5">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.hidden_sz <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> hidden_sz</span>
<span id="cb13-6">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.cells <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.ModuleList(</span>
<span id="cb13-7">            [</span>
<span id="cb13-8">                LSTMCellNew(input_sz, hidden_sz)</span>
<span id="cb13-9">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb13-10">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">else</span> LSTMCellNew(hidden_sz, hidden_sz)</span>
<span id="cb13-11">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.num_layers)</span>
<span id="cb13-12">            ]</span>
<span id="cb13-13">        )</span>
<span id="cb13-14"></span>
<span id="cb13-15">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> forward(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, x, h_t, c_t):</span>
<span id="cb13-16">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## x  :      T     x B x hidden_sz</span></span>
<span id="cb13-17">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## h_t: num_layers x B x hidden_sz</span></span>
<span id="cb13-18">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## c_t: num_layers x B x hidden_sz</span></span>
<span id="cb13-19">        T, B, _ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> x.shape</span>
<span id="cb13-20">        H <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> torch.zeros(T, B, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.hidden_sz)</span>
<span id="cb13-21">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, cell <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.cells):</span>
<span id="cb13-22">            h, c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> h_t[i], c_t[i]</span>
<span id="cb13-23">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>:</span>
<span id="cb13-24">                x <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> H</span>
<span id="cb13-25">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> t <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(T):</span>
<span id="cb13-26">                h, c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> cell(x[t], h, c)</span>
<span id="cb13-27">                H[t] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> h</span>
<span id="cb13-28">            <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## last hidden state for each layer</span></span>
<span id="cb13-29">            h_t[i], c_t[i] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> h, c</span>
<span id="cb13-30">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">## Truncated BPTT</span></span>
<span id="cb13-31">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> H, (h_t.detach(), c_t.detach())</span></code></pre></div>
</div>
<div id="d2c608d9-11f0-49a6-bcbf-e8a7387cb205" class="cell" data-execution_count="12">
<div class="sourceCode cell-code" id="cb14" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb14-1">pytorch_lstm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> nn.LSTM(input_sz, hidden_sz, num_layers<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>num_layers)</span>
<span id="cb14-2">pytorch_H, (pytorch_h, pytorch_c) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pytorch_lstm(X, (h_0, c_0))</span></code></pre></div>
</div>
<div id="0966232f-132c-4e2f-b97c-8fa2fab3ba56" class="cell" data-execution_count="13">
<div class="sourceCode cell-code" id="cb15" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb15-1">lstm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> LSTMNew(input_sz, hidden_sz, num_layers<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>num_layers)</span>
<span id="cb15-2"></span>
<span id="cb15-3"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(num_layers):</span>
<span id="cb15-4">    lstm.cells[i].ih.weight.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(pytorch_lstm, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"weight_ih_l</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>i<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>).data</span>
<span id="cb15-5">    lstm.cells[i].hh.weight.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(pytorch_lstm, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"weight_hh_l</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>i<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>).data</span>
<span id="cb15-6">    lstm.cells[i].ih.bias.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(pytorch_lstm, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"bias_ih_l</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>i<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>).data</span>
<span id="cb15-7">    lstm.cells[i].hh.bias.data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">getattr</span>(pytorch_lstm, <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f"bias_hh_l</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>i<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">"</span>).data</span>
<span id="cb15-8"></span>
<span id="cb15-9">H, (h_t, c_t) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> lstm(X, h_0, c_0)</span></code></pre></div>
</div>
<div id="652fc3bb-7371-489b-85c4-8cd2e97fd60f" class="cell" data-execution_count="14">
<div class="sourceCode cell-code" id="cb16" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb16-1"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(</span>
<span id="cb16-2">    np.linalg.norm(pytorch_H.detach().numpy() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> H.detach().numpy()),</span>
<span id="cb16-3">    np.linalg.norm(pytorch_h.detach().numpy() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> h_t.detach().numpy()),</span>
<span id="cb16-4">    np.linalg.norm(pytorch_c.detach().numpy() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> c_t.detach().numpy()),</span>
<span id="cb16-5">)</span></code></pre></div>
<div class="cell-output cell-output-stdout">
<pre><code>0.0 0.0 0.0</code></pre>
</div>
</div>
</section>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>LSTMs were the dominant architecture for sequence modeling in NLP for years — powering machine translation, text classification, language modeling, and speech recognition before Transformers took over. In this post, we implemented both <code>LSTMCell</code> and a multi-layer <code>LSTM</code> from scratch, verified them against PyTorch’s official implementation, and discussed the performance decisions that go into a production implementation.</p>
<section id="key-takeaways" class="level3">
<h3 class="anchored" data-anchor-id="key-takeaways">Key Takeaways</h3>
<ol type="1">
<li><p><strong>LSTMs solve vanishing gradients through additive cell state updates.</strong> The forget gate can stay close to 1, allowing gradients to flow through many time steps without exponential decay. This is fundamentally different from vanilla RNNs, where the hidden state is completely overwritten at each step.</p></li>
<li><p><strong>Four gates, one matrix multiplication.</strong> The input, forget, cell, and output gates are computed together in a single fused operation, then split — a practical optimization that significantly improves throughput by better utilizing hardware parallelism.</p></li>
<li><p><strong>Sequential processing is the fundamental bottleneck.</strong> The output at time <img src="https://latex.codecogs.com/png.latex?t"> depends on the hidden state from <img src="https://latex.codecogs.com/png.latex?t-1">, making parallelization across time steps impossible. This is the limitation that motivated the Transformer’s self-attention mechanism.</p></li>
<li><p><strong>Truncated BPTT is essential for long sequences.</strong> Detaching hidden states between batches prevents gradient computation from spanning the entire sequence, reducing both memory usage and gradient instability.</p></li>
<li><p><strong>Memory layout matters.</strong> Using sequence-first tensors (<code>T × B × D</code>) ensures contiguous memory access at each time step, avoiding hidden performance penalties from implicit copies.</p></li>
</ol>
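<p>Takeaway 2 can be sketched in a few lines of NumPy (the shapes are illustrative, and the gate order follows PyTorch's convention of input, forget, cell, output):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, H = 2, 3, 4                        # batch, input size, hidden size

x = rng.standard_normal((B, D))
W = rng.standard_normal((4 * H, D))      # all four gate weights stacked row-wise
b = rng.standard_normal(4 * H)

fused = x @ W.T + b                      # ONE matmul for all gates: B x 4H
i, f, g, o = np.split(fused, 4, axis=1)  # slice out the four B x H gate blocks

# Equivalent to four separate matmuls against W's row blocks:
assert np.allclose(i, x @ W[:H].T + b[:H])
assert all(m.shape == (B, H) for m in (i, f, g, o))
```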
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
LSTMs in the Transformer Era
</div>
</div>
<div class="callout-body-container callout-body">
<p>While Transformers have largely replaced LSTMs for most NLP tasks, understanding LSTMs remains valuable. They’re still used in streaming/online settings where you process one token at a time, in resource-constrained environments where the <img src="https://latex.codecogs.com/png.latex?O(n%5E2)"> attention cost is prohibitive, and as components in hybrid architectures. More importantly, the concepts — gating, cell states, truncated BPTT — appear in many modern architectures in different forms.</p>
</div>
</div>
</section>
</section>
<section id="references-resources" class="level2">
<h2 class="anchored" data-anchor-id="references-resources">References &amp; Resources</h2>
<ul>
<li><strong>Hochreiter, S. &amp; Schmidhuber, J.</strong> (1997). <a href="https://www.bioinf.jku.at/publications/older/2604.pdf">Long Short-Term Memory</a>. <em>Neural Computation</em>, 9(8), 1735–1780. The original LSTM paper.</li>
<li><strong>Olah, C.</strong> (2015). <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">Understanding LSTM Networks</a>. The classic visual explainer — the single best resource for building intuition about LSTM gates.</li>
<li><strong>Greff, K. et al.</strong> (2017). <a href="https://arxiv.org/abs/1503.04069">LSTM: A Search Space Odyssey</a>. <em>IEEE TNNLS</em>. Comprehensive study of LSTM variants — concludes that the forget gate and output activation are the most critical components.</li>
<li><strong>Vaswani, A. et al.</strong> (2017). <a href="https://arxiv.org/abs/1706.03762">Attention Is All You Need</a>. <em>NeurIPS 2017</em>. The Transformer architecture that largely replaced LSTMs for NLP tasks.</li>
<li><strong>Merity, S. et al.</strong> (2018). <a href="https://arxiv.org/abs/1708.02182">Regularizing and Optimizing LSTM Language Models</a>. <em>ICLR 2018</em>. AWD-LSTM — pushed LSTM language models to their limits with careful regularization.</li>
<li><strong>PyTorch Documentation</strong>. <a href="https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html">nn.LSTM</a> and <a href="https://pytorch.org/docs/stable/generated/torch.nn.LSTMCell.html">nn.LSTMCell</a>. Official reference for the implementation we verified against.</li>
</ul>


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>NLP</category>
  <guid>https://imaddabbura.github.io/posts/nlp/LSTM-Annotated-Implementation.html</guid>
  <pubDate>Tue, 10 Mar 2020 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/nlp/images/lstm-cell.jpeg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Anomaly Detection</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/anomaly-detection/Anomaly-Detection.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p>Anomaly detection is the identification of examples or events that don’t conform to an expected pattern or to the majority of examples. Roughly speaking, it’s the process of identifying an example that is not <em>normal (an outlier)</em> given the distribution of the data. An <strong>outlier</strong> is an example that deviates so much from the other examples that it arouses suspicion it was generated by a different data-generating process. Such outliers have a very low probability of belonging to the same data-generating process; they lie on the far left or right tail of its probability density function.</p>
<p>The algorithm works as follows:</p>
<ol type="1">
<li><p>Fit a <em>Gaussian Probability Density Function (PDF)</em> for each feature in the training dataset:</p>
<ul>
<li><p>Calculate the mean and the variance of each feature: <img src="https://latex.codecogs.com/png.latex?%5Cmu_j%20=%20%5Cfrac%7B1%7D%7Bm%7D%5Csum_%7Bi%20=%201%7D%5Emx_j%5Ei%5C%5C%7B%7D"> <img src="https://latex.codecogs.com/png.latex?%5Csigma%5E2_j%20=%20%5Cfrac%7B1%7D%7Bm%7D%5Csum_%7Bi%20=%201%7D%5Em(x_j%5Ei%20-%20%5Cmu_j)%5E2%5C%5C%7B%7D"> where <img src="https://latex.codecogs.com/png.latex?%5Cmu"> is the mean and <img src="https://latex.codecogs.com/png.latex?%5Csigma%5E2"> is the variance that controls the shape of the density function.</p></li>
<li><p>Compute the density function for each feature using the following formula:<br>
<img src="https://latex.codecogs.com/png.latex?p(x;%20%5Cmu,%20%5Csigma%5E2)%20=%20%5Cfrac%7B1%7D%7B%5Csqrt%7B2%5Cpi%7D%5Csigma%7De%5E%7B-%5Cfrac%7B(x%20-%20%5Cmu)%5E2%7D%7B2%5Csigma%5E2%7D%7D%5C%5C%7B%7D"> Since the mean and the variance are sensitive to outliers, we fit the model on a training dataset that contains only normal examples when estimating the mean vector and the covariance matrix.</p></li>
</ul></li>
<li><p>Compute the Gaussian density of each example by taking the product of all features’ density functions.</p></li>
<li><p>If <img src="https://latex.codecogs.com/png.latex?p(x)%20%3C%20%5Cepsilon">, flag the example as an anomaly; otherwise, treat it as normal. Epsilon controls how sensitive the detection algorithm is: if <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> is large <img src="https://latex.codecogs.com/png.latex?%5Crightarrow"> many examples are flagged as anomalous, which increases the <em>false positives</em>; if <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> is small <img src="https://latex.codecogs.com/png.latex?%5Crightarrow"> only a very small portion of the examples is flagged as anomalous, which increases the <em>false negatives</em>.</p></li>
<li><p>Use <em>cross-validation</em> to tune the hyperparameter <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> for the best value of the performance metric. The F1 score is commonly used: <img src="https://latex.codecogs.com/png.latex?F_1%20=%202%20%5Cfrac%7Bprecision%20*%20recall%7D%7Bprecision%20+%20recall%7D%5C%5C%7B%7D"> where <img src="https://latex.codecogs.com/png.latex?precision%20=%20%5Cfrac%7Btp%7D%7Btp%20+%20fp%7D%5C%5C%7B%7D"> <img src="https://latex.codecogs.com/png.latex?recall%20=%20%5Cfrac%7Btp%7D%7Btp%20+%20fn%7D%5C%5C%7B%7D"> <em>tp: true positives, fp: false positives, fn: false negatives</em>.</p></li>
</ol>
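<p>The recipe above can be sketched in a few lines of NumPy (the data and threshold are toy values, not the dataset used below):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=1.0, size=(300, 2))  # normal examples only

# Step 1: per-feature mean and variance (MLE estimates, so ddof=0)
mu = X_train.mean(axis=0)
sigma2 = X_train.var(axis=0)

def p(X, mu, sigma2):
    """Product of per-feature univariate Gaussian densities."""
    densities = np.exp(-(X - mu) ** 2 / (2 * sigma2)) / np.sqrt(2 * np.pi * sigma2)
    return densities.prod(axis=1)

# Step 3: flag examples whose density falls below epsilon
eps = 1e-4
normal_pt = np.array([[5.0, 5.0]])
outlier_pt = np.array([[12.0, -3.0]])
assert p(normal_pt, mu, sigma2)[0] >= eps     # near the mean: not flagged
assert p(outlier_pt, mu, sigma2)[0] < eps     # far from the mean: flagged
```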
<p>We have two kinds of anomaly detection algorithms:</p>
<ol type="1">
<li><strong>Univariate Gaussian Density Function</strong> <img src="https://latex.codecogs.com/png.latex?p(x)%20=%20%5Cprod_%7Bj%20=%201%7D%5E%7Bn%7Dp(x_j;%20%5Cmu_j,%20%5Csigma_j%5E2)%5C%5C%7B%7D"> <img src="https://latex.codecogs.com/png.latex?%20=%20p(x_1;%20%5Cmu_1,%20%5Csigma_1%5E2)*p(x_2;%20%5Cmu_2,%20%5Csigma_2%5E2)*%20...%20*%20p(x_n;%20%5Cmu_n,%20%5Csigma_n%5E2)%5C%5C%7B%7D">
<ul>
<li>It assumes that all features are independent; therefore, the covariance between every pair of features is zero.</li>
<li>It’s computationally faster and more efficient.</li>
<li>Use it when we have a very large number of features.</li>
<li>Make sure to manually add features that capture unusual combinations of feature values, such as <img src="https://latex.codecogs.com/png.latex?x_3%20=%20%5Cfrac%20%7Bx_2%7D%7Bx_1%7D">. Otherwise, the algorithm may fail to detect anomalies whose values look normal for each feature separately but are unusual when all features are considered together, such as a high value for feature 2 combined with a low value for feature 1.</li>
</ul></li>
</ol>
<ol start="2" type="1">
<li><strong>Multivariate Gaussian Density Function</strong> <img src="https://latex.codecogs.com/png.latex?p(x;%20%5Cmu,%20%5CSigma)%20=%20%5Cfrac%7B1%7D%7B(2%5Cpi)%5E%7B(n%20/%202)%7D(%5Cdet%5CSigma)%5E%7B1%20/%202%7D%7De%5E%7B%5Cfrac%7B-1%7D%7B2%7D(x%20-%20%5Cmu)%5ET%5CSigma%5E%7B-1%7D(x%20-%20%5Cmu)%7D%5C%5C%7B%7D"> where <img src="https://latex.codecogs.com/png.latex?%5CSigma"> is the n x n covariance matrix: <img src="https://latex.codecogs.com/png.latex?%5CSigma%20=%20%5Cbegin%7Bbmatrix%7D%0A%5Csigma_1%5E2&amp;%5Csigma_%7B12%7D&amp;%5Ccdots&amp;%5Csigma_%7B1n%7D%5C%5C%0A%5Csigma_%7B21%7D&amp;%5Csigma_2%5E2&amp;%5Ccdots&amp;%5Csigma_%7B2n%7D%5C%5C%0A%5Cvdots%20&amp;%20%5Cvdots%20&amp;%20%5Cddots%20&amp;%20%5Cvdots%20%5C%5C%0A%5Csigma_%7Bn1%7D%20&amp;%20%5Csigma_%7Bn2%7D%20&amp;%20%5Ccdots%20&amp;%20%5Csigma_n%5E2%0A%5Cend%7Bbmatrix%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Csigma_%7B12%7D%20=%20%5Csigma_%7B21%7D"> is the covariance between features 1 and 2. Therefore, the covariance matrix is <em>symmetric positive (semi-)definite</em>.
<ul>
<li>Computationally expensive</li>
<li>Use it when the number of examples is <img src="https://latex.codecogs.com/png.latex?%5Cgeq"> 10 times the number of features, i.e.&nbsp;<img src="https://latex.codecogs.com/png.latex?m%20%5Cgeq%2010n"></li>
<li>If some features are linearly dependent, or the number of examples is less than the number of features, the covariance matrix won’t be invertible</li>
<li>There is no need to add extra features to capture unusual combinations of feature values, because the covariances between all pairs of features already capture that</li>
<li>The univariate density function can be derived from the multivariate density function by making the covariance matrix diagonal; therefore, <img src="https://latex.codecogs.com/png.latex?%5Csigma_%7Bij%7D%20=%200"> for all <img src="https://latex.codecogs.com/png.latex?i%20%5Cneq%20j"></li>
</ul></li>
</ol>
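<p>A sketch of the multivariate density using <code>pinv</code> and <code>det</code>, as imported later in this post (the toy data and covariance here are made up):</p>

```python
import numpy as np
from numpy.linalg import pinv, det

rng = np.random.default_rng(1)
# Toy 2-feature data with positively correlated features.
X = rng.multivariate_normal(mean=[0, 0], cov=[[2.0, 1.2], [1.2, 1.0]], size=500)

mu = X.mean(axis=0)
Sigma = np.cov(X, rowvar=False, bias=True)   # MLE covariance estimate, n x n

def multivariate_p(X, mu, Sigma):
    n = mu.shape[0]
    diff = X - mu
    norm = (2 * np.pi) ** (n / 2) * det(Sigma) ** 0.5
    # einsum computes the quadratic form (x-mu)^T Sigma^-1 (x-mu) per row
    quad = np.einsum('ij,jk,ik->i', diff, pinv(Sigma), diff)
    return np.exp(-0.5 * quad) / norm

# A point aligned with the correlation looks normal; one against it does not,
# even though each coordinate is unremarkable on its own.
assert multivariate_p(np.array([[1.0, 0.8]]), mu, Sigma)[0] > \
       multivariate_p(np.array([[1.0, -0.8]]), mu, Sigma)[0]
```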
<p>There are some assumptions made implicitly here:</p>
<ul>
<li>For each feature, the <img src="https://latex.codecogs.com/png.latex?X_i">’s are IID (independent and identically distributed).</li>
<li>By the Central Limit Theorem (CLT), the distribution of a sum of IID random variables is approximately normal. This justifies fitting a normal distribution parameterized by <img src="https://latex.codecogs.com/png.latex?%5Cmu"> and <img src="https://latex.codecogs.com/png.latex?%5Csigma%5E2">.</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmu"> and <img src="https://latex.codecogs.com/png.latex?%5CSigma"> are estimated using the maximum-likelihood method.</li>
</ul>
<p>After fitting the multivariate PDF under the above assumptions, we use it to estimate the probability that each example in the validation/test set was generated by it. If that probability is smaller than <img src="https://latex.codecogs.com/png.latex?%5Cepsilon">, we conclude that the example was generated by a different multivariate PDF and, therefore, classify it as an <em>anomaly</em> (outlier).</p>
<p>In this exercise, we’ll implement an anomaly detection algorithm to detect anomalous behavior in server computers. The features measure the throughput (mb/s) and latency (ms) of each server’s response. While the servers were operating, <img src="https://latex.codecogs.com/png.latex?m%20=%20307"> examples of their behavior were captured. We suspect that the vast majority of them are normal (non-anomalous) examples of the servers operating normally.</p>
<p>Let’s first load and plot the data:</p>
<div id="cell-4" class="cell" data-code_folding="[0]" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> numpy.linalg <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pinv, det</span>
<span id="cb1-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb1-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb1-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> scipy.io <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> loadmat, whosmat</span>
<span id="cb1-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> scipy.optimize <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> opt</span>
<span id="cb1-7"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> seaborn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> sns</span>
<span id="cb1-8"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> warnings <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> filterwarnings</span>
<span id="cb1-9"></span>
<span id="cb1-10"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>matplotlib inline</span>
<span id="cb1-11">sns.set_context(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'notebook'</span>)</span>
<span id="cb1-12">plt.style.use(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'fivethirtyeight'</span>)</span>
<span id="cb1-13">filterwarnings(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'ignore'</span>)</span></code></pre></div>
</details>
</div>
</section>
<section id="functions" class="level2">
<h2 class="anchored" data-anchor-id="functions">Functions</h2>
<div id="cell-6" class="cell" data-code_folding="[1,46]" data-execution_count="2">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
# Compute guassian">
font-style: inherit;"># Compute gaussian distribution fn</span></span>
<span id="cb2-2"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> gaussian_estimate(X_train, X_val, gaussian_type<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'univariate'</span>):</span>
<span id="cb2-3">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">'''</span></span>
<span id="cb2-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    parameters</span></span>
<span id="cb2-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    ----------</span></span>
<span id="cb2-6"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    X_train: array-like</span></span>
<span id="cb2-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        training features matrix m x n that has only normal examples.</span></span>
<span id="cb2-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    X_val: array-like</span></span>
<span id="cb2-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        cross validation features matrix that has anomalous and normal</span></span>
<span id="cb2-10"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        examples.</span></span>
<span id="cb2-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    gussian_type: str</span></span>
<span id="cb2-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        univariate or multivariate.</span></span>
<span id="cb2-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb2-14"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Returns</span></span>
<span id="cb2-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    -------</span></span>
<span id="cb2-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    pdf: array-like</span></span>
<span id="cb2-17"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        multivariate pdf vector of n x 1</span></span>
<span id="cb2-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    '''</span></span>
<span id="cb2-19">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># number of training examples and features</span></span>
<span id="cb2-20">    m, n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_train.shape</span>
<span id="cb2-21">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># number of cv examples</span></span>
<span id="cb2-22">    mval <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_val.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>]</span>
<span id="cb2-23"></span>
<span id="cb2-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute mean and covariance matrix</span></span>
<span id="cb2-25">    mu <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_train.mean(axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb2-26">    cov <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (m)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (X_train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> mu).T.dot(X_train <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> mu)</span>
<span id="cb2-27"></span>
<span id="cb2-28">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># convert the covariance matrix to diagonal if it's a univariate</span></span>
<span id="cb2-29">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> gaussian_type <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'univariate'</span>:</span>
<span id="cb2-30">        z <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros_like(cov)</span>
<span id="cb2-31">        np.fill_diagonal(z, np.diagonal(cov))</span>
<span id="cb2-32">        cov <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> z</span>
<span id="cb2-33"></span>
<span id="cb2-34">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute determinant and inverse of covariance matrix</span></span>
<span id="cb2-35">    cov_det <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> det(cov)</span>
<span id="cb2-36">    cov_inv <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pinv(cov)</span>
<span id="cb2-37"></span>
<span id="cb2-38">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute pdf vector</span></span>
<span id="cb2-39">    pdf <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ((<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> np.pi) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span>n <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> (cov_det <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">**</span> (<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>)) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*\</span></span>
<span id="cb2-40">        np.exp(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(np.multiply((X_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> mu).dot(cov_inv),</span>
<span id="cb2-41">                                         (X_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> mu)), axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb2-42"></span>
<span id="cb2-43">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> pdf</span>
<span id="cb2-44"></span>
<span id="cb2-45"></span>
<span id="cb2-46"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Hyperparameter tuning of epsilon using cv dataset</span></span>
<span id="cb2-47"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> select_threshold(y_val, p_val):</span>
<span id="cb2-48">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">'''</span></span>
<span id="cb2-49"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    parameters</span></span>
<span id="cb2-50"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    ----------</span></span>
<span id="cb2-51"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    y_val: array-like</span></span>
<span id="cb2-52"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        label whether a validation example is normal (0) or anomaly (1).</span></span>
<span id="cb2-53"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    p_val: array-like</span></span>
<span id="cb2-54"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        pdf for validated examples.</span></span>
<span id="cb2-55"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    </span></span>
<span id="cb2-56"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    Returns</span></span>
<span id="cb2-57"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    -------</span></span>
<span id="cb2-58"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    eplsion : float</span></span>
<span id="cb2-59"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        best epsilon value tuned on validation data.</span></span>
<span id="cb2-60"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    F1_score : float</span></span>
<span id="cb2-61"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">        F1 score using epsilon tuned on validation data.</span></span>
<span id="cb2-62"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">    '''</span></span>
<span id="cb2-63">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># initialize epsilon and F1 score values</span></span>
<span id="cb2-64">    best_epsilon <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb2-65">    best_F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb2-66"></span>
<span id="cb2-67">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute stepsize for each iteration</span></span>
<span id="cb2-68">    epsilon_stepsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (p_val.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>() <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> p_val.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>()) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span></span>
<span id="cb2-69"></span>
<span id="cb2-70">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> epsilon <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> np.arange(p_val.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">min</span>(), p_val.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">max</span>(), epsilon_stepsize):</span>
<span id="cb2-71">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># get predictions vector</span></span>
<span id="cb2-72">        pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> ((p_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> epsilon) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>).reshape(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb2-73"></span>
<span id="cb2-74">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute true positives, false positives, false negatives</span></span>
<span id="cb2-75">        tp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>((pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&amp;</span> (y_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb2-76">        fp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>((pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&amp;</span> (y_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>))</span>
<span id="cb2-77">        fn <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>((pred <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&amp;</span> (y_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb2-78"></span>
<span id="cb2-79">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute precision and recall</span></span>
<span id="cb2-80">        precision_ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (tp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> fp)</span>
<span id="cb2-81">        recall_ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> tp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (tp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> fn)</span>
<span id="cb2-82"></span>
<span id="cb2-83">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># compute F1 score</span></span>
<span id="cb2-84">        F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> ((precision_ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> recall_) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> (precision_ <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> recall_))</span>
<span id="cb2-85">        <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># if F1 score &gt; best_F1, set best_F1 = F1</span></span>
<span id="cb2-86">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&gt;</span> best_F1:</span>
<span id="cb2-87">            best_F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> F1</span>
<span id="cb2-88">            best_epsilon <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> epsilon</span>
<span id="cb2-89"></span>
<span id="cb2-90">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> best_epsilon, best_F1</span></code></pre></div>
</details>
</div>
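<p>As a quick sanity check (not part of the original notebook), the density that <code>gaussian_estimate</code> computes should agree with SciPy’s <code>multivariate_normal</code> when both are fit with the same mean and biased (1/m) covariance. The sketch below uses synthetic data and re-derives the pdf formula inline:</p>

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
# Synthetic "normal" training examples (hypothetical, not the post's dataset)
X_train = rng.normal(loc=[5.0, 10.0], scale=[1.0, 2.0], size=(500, 2))

# Mean and biased (1/m) covariance, matching the convention in gaussian_estimate
mu = X_train.mean(axis=0)
cov = (X_train - mu).T @ (X_train - mu) / X_train.shape[0]

# Manual multivariate Gaussian pdf, term by term as in the function above
n = X_train.shape[1]
p_manual = (
    (2 * np.pi) ** (-n / 2)
    * np.linalg.det(cov) ** (-0.5)
    * np.exp(-0.5 * np.sum((X_train - mu) @ np.linalg.pinv(cov) * (X_train - mu), axis=1))
)

# SciPy's frozen distribution computes the same quantity
p_scipy = multivariate_normal(mean=mu, cov=cov).pdf(X_train)

print(np.allclose(p_manual, p_scipy))  # True
```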
<div id="cell-7" class="cell" data-code_folding="[0]" data-tags="[]" data-execution_count="6">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Load data</span></span>
<span id="cb3-2">data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loadmat(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'../data/servers_anomaly_detection.mat'</span>)</span>
<span id="cb3-3"></span>
<span id="cb3-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Training data</span></span>
<span id="cb3-5">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'X'</span>]</span>
<span id="cb3-6"></span>
<span id="cb3-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Cross validation data</span></span>
<span id="cb3-8">X_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Xval'</span>]</span>
<span id="cb3-9">y_val <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'yval'</span>]</span>
<span id="cb3-10"></span>
<span id="cb3-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot data</span></span>
<span id="cb3-12">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>))</span>
<span id="cb3-13">plt.scatter(X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], s <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span>, c <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'blue'</span>)</span>
<span id="cb3-14">plt.axis([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>])</span>
<span id="cb3-15">plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Latency (ms)'</span>)</span>
<span id="cb3-16">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Throughput (mb/s)'</span>)</span>
<span id="cb3-17">plt.gca().set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span>
<span id="cb3-18"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># plt.title('Scatter plot of the first dataset');</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Anomaly-Detection_files/figure-html/cell-4-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-1"><img src="https://imaddabbura.github.io/posts/anomaly-detection/Anomaly-Detection_files/figure-html/cell-4-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<div id="cell-8" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># plt.subplots(1, 2, 1)</span></span>
<span id="cb4-2">sns.kdeplot(X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb4-3">sns.kdeplot(X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Anomaly-Detection_files/figure-html/cell-5-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-2"><img src="https://imaddabbura.github.io/posts/anomaly-detection/Anomaly-Detection_files/figure-html/cell-5-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>Now we’ll estimate the Gaussian distribution for both the training and cross-validation sets. Note that we fit the mean and covariance on the training set, which contains ONLY normal examples, and then use the cross-validation set, which contains both normal and anomalous examples, to find the best epsilon.</p>
<div id="cell-10" class="cell" data-code_folding="[0]" data-execution_count="10">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Fit Gaussian distribution on both training and CV examples</span></span>
<span id="cb5-2">ptrain <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gaussian_estimate(X, X)</span>
<span id="cb5-3">pval <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gaussian_estimate(X, X_val, gaussian_type<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'multivariate'</span>)</span>
<span id="cb5-4"></span>
<span id="cb5-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Tune epsilon</span></span>
<span id="cb5-6">epsilon, F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> select_threshold(y_val, pval)</span>
<span id="cb5-7"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'The best epsilon tuned using CV that yielded the best '</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span></span>
<span id="cb5-8">      <span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'F1-score </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>F1<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.3f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;"> is: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>epsilon<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">.'</span>)</span></code></pre></div>
</details>
<div class="cell-output cell-output-stdout">
<pre><code>The best epsilon tuned using CV that yielded the best F1-score 0.875 is: 9.065769728392737e-05.</code></pre>
</div>
</div>
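<p>The threshold sweep can also be illustrated on synthetic scores. Everything here is hypothetical (score ranges and sample sizes are made up for the sketch); the point is that sweeping epsilon over the observed pdf range and keeping the best F1 recovers a separating threshold when one exists:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical pdf values: normal points score high, anomalies score low
p_val = np.concatenate([rng.uniform(0.4, 1.0, 95), rng.uniform(0.0, 0.05, 5)])
y_val = np.concatenate([np.zeros(95), np.ones(5)])

best_epsilon, best_f1 = 0.0, 0.0
for epsilon in np.linspace(p_val.min(), p_val.max(), 1000):
    # Flag an example as anomalous when its density falls below epsilon
    pred = (p_val < epsilon).astype(int)
    tp = np.sum((pred == 1) & (y_val == 1))
    fp = np.sum((pred == 1) & (y_val == 0))
    fn = np.sum((pred == 0) & (y_val == 1))
    if tp == 0:
        continue  # skip degenerate thresholds (the notebook suppresses these warnings instead)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    if f1 > best_f1:
        best_f1, best_epsilon = f1, epsilon

print(best_f1)  # 1.0 -- a threshold between the two score ranges separates them perfectly
```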
<p>We’ll use the epsilon tuned on the cross-validation set to see which training examples the algorithm flags as anomalous. Below is a scatter plot of the training data, where red points mark the anomalous examples.</p>
<div id="cell-12" class="cell" data-code_folding="[0]" data-tags="[]" data-execution_count="11">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Get the index of the outlier</span></span>
<span id="cb7-2">outliers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.where(ptrain <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">&lt;</span> epsilon)</span>
<span id="cb7-3"></span>
<span id="cb7-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot data</span></span>
<span id="cb7-5">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb7-6">plt.scatter(X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span>, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'blue'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Normal Examples'</span>)</span>
<span id="cb7-7">plt.scatter(X[outliers[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[outliers[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">60</span>, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Anomalous Examples'</span>)</span>
<span id="cb7-8">plt.axis([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>])</span>
<span id="cb7-9">plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Latency (ms)'</span>)</span>
<span id="cb7-10">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Throughput (mb/s)'</span>)</span>
<span id="cb7-11">plt.legend(loc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'upper right'</span>)</span>
<span id="cb7-12">plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Scatter plot of the training dataset'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Anomaly-Detection_files/figure-html/cell-7-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-3"><img src="https://imaddabbura.github.io/posts/anomaly-detection/Anomaly-Detection_files/figure-html/cell-7-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>Finally, we’ll fit a Gaussian distribution on a training dataset that has 1000 examples and 11 features. Note that in both examples we used the <em>Multivariate</em>, not the <em>Univariate</em>, Gaussian distribution.</p>
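<p>The univariate model is exactly the multivariate one restricted to a diagonal covariance matrix, which is what the <code>fill_diagonal</code> step in <code>gaussian_estimate</code> exploits. A minimal check of that identity on synthetic data (not from the post’s datasets):</p>

```python
import numpy as np
from scipy.stats import norm, multivariate_normal

rng = np.random.default_rng(2)
X = rng.normal(size=(200, 3))

mu = X.mean(axis=0)
var = X.var(axis=0)  # biased (1/m) variance, matching the convention above

# Univariate model: product of independent 1-D Gaussians, one per feature
p_uni = np.prod(norm(loc=mu, scale=np.sqrt(var)).pdf(X), axis=1)

# Multivariate model with a diagonal covariance matrix
p_diag = multivariate_normal(mean=mu, cov=np.diag(var)).pdf(X)

print(np.allclose(p_uni, p_diag))  # True
```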
<div id="cell-14" class="cell" data-code_folding="[]" data-execution_count="12">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Load data</span></span>
<span id="cb8-2">data <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> loadmat(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'../data/ex8data2.mat'</span>)</span>
<span id="cb8-3"></span>
<span id="cb8-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Training data</span></span>
<span id="cb8-5">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'X'</span>]</span>
<span id="cb8-6"></span>
<span id="cb8-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Cross validation data</span></span>
<span id="cb8-8">Xval <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Xval'</span>]</span>
<span id="cb8-9">yval <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> data[<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'yval'</span>]</span>
<span id="cb8-10"></span>
<span id="cb8-11"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Fit Gaussian distribution on both training and CV examples</span></span>
<span id="cb8-12">ptrain <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gaussian_estimate(X, X, gaussian_type<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'multivariate'</span>)</span>
<span id="cb8-13">pval <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> gaussian_estimate(X, Xval, gaussian_type<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'multivariate'</span>)</span>
<span id="cb8-14"></span>
<span id="cb8-15"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Tune epsilon</span></span>
<span id="cb8-16">epsilon, F1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> select_threshold(yval, pval)</span>
<span id="cb8-17"><span class="bu" style="color: null;
background-color: null;
font-style: inherit;">print</span>(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'The best epsilon tuned using CV that yielded the best '</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">\</span></span>
<span id="cb8-18">      <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">f'F1-score </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{F1:.3f}</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;"> is: </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{epsilon}</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">.'</span>)</span></code></pre></div>
</details>
<div class="cell-output cell-output-error">
<pre><code>FileNotFoundError: [Errno 2] No such file or directory: '../data/ex8data2.mat'</code></pre>
</div>
</div>
<p>Using the best epsilon value we found above, we can then classify any example as an anomaly if <img src="https://latex.codecogs.com/png.latex?p(x)%20%3C%20%5Cepsilon">; otherwise, it’s normal.</p>
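<p>As a minimal sketch of this thresholding step (with made-up densities and a made-up <code>epsilon</code>, not the values computed above):</p>

```python
import numpy as np

# Hypothetical densities p(x) for five examples and a hypothetical threshold
p = np.array([0.08, 1e-6, 0.03, 5e-7, 0.2])
epsilon = 1e-5

# An example is flagged as anomalous when p(x) < epsilon
is_anomaly = p < epsilon
print(is_anomaly)  # -> [False  True False  True False]
```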
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<ul>
<li>The implementation of the variance/covariance in the detection algorithm has <img src="https://latex.codecogs.com/png.latex?m"> in the denominator, not <img src="https://latex.codecogs.com/png.latex?(m%20-%201)">, because with large datasets the difference is negligible. Note, however, that the unbiased estimator of the variance has <img src="https://latex.codecogs.com/png.latex?(m%20-%201)"> in the denominator, not <img src="https://latex.codecogs.com/png.latex?m">.</li>
<li>Anomaly detection vs.&nbsp;supervised learning:
<ul>
<li>Use anomaly detection when you have a large number of negative examples and a very small number of positive examples. A supervised learning algorithm wouldn’t have enough positive examples to learn from, especially if future anomalies look nothing like the training anomalies.</li>
<li>Use supervised learning algorithms such as logistic regression if you have enough positive examples to make learning feasible; in that case, supervised learning would probably outperform anomaly detection algorithms.</li>
</ul></li>
<li>The univariate PDF performs well most of the time compared to the multivariate PDF, and it scales much better.</li>
</ul>
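<p>To make the first point concrete, here is a small sketch (using NumPy on made-up data, not the post’s dataset) of the two variance estimators:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)  # m = 1000 examples

var_m = x.var(ddof=0)    # divides by m (what the detection algorithm uses)
var_m1 = x.var(ddof=1)   # divides by (m - 1) (the unbiased estimator)

# With m = 1000 the two differ only by a factor of m / (m - 1)
print(var_m1 / var_m)  # -> 1000 / 999, i.e. about 1.001
```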


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Machine Learning</category>
  <guid>https://imaddabbura.github.io/posts/anomaly-detection/Anomaly-Detection.html</guid>
  <pubDate>Wed, 11 Sep 2019 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/anomaly-detection/feature.jpg" medium="image" type="image/jpeg"/>
</item>
<item>
  <title>Gradient Descent Algorithm and Its Variants</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/optimization/gradient-descent.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge evergreen">evergreen</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p><strong>Optimization</strong> refers to the task of minimizing/maximizing an objective function <img src="https://latex.codecogs.com/png.latex?f(x)"> parameterized by <img src="https://latex.codecogs.com/png.latex?x">. In machine/deep learning terminology, it’s the task of minimizing the cost/loss function <img src="https://latex.codecogs.com/png.latex?J(w)"> parameterized by the model’s parameters <img src="https://latex.codecogs.com/png.latex?w%20%5Cin%20%5Cmathbb%7BR%7D%5Ed">. Optimization algorithms (in the case of minimization) have one of the following goals:</p>
<ul>
<li>Find the global minimum of the objective function. This is feasible if the objective function is convex, i.e.&nbsp;any local minimum is a global minimum.</li>
<li>Find the lowest possible value of the objective function within its neighborhood. That’s usually the case when the objective function is not convex, as in most deep learning problems.</li>
</ul>
<p>There are three kinds of optimization algorithms:</p>
<ul>
<li>Optimization algorithms that are not iterative and simply solve for one point.</li>
<li>Optimization algorithms that are iterative in nature and converge to an acceptable solution regardless of parameter initialization, such as gradient descent applied to logistic regression.</li>
<li>Optimization algorithms that are iterative in nature and applied to problems with non-convex cost functions, such as neural networks. Here, parameter initialization plays a critical role in speeding up convergence and achieving lower error rates.</li>
</ul>
<p><strong>Gradient Descent</strong> is the most common optimization algorithm in <em>machine learning</em> and <em>deep learning</em>. It is a first-order optimization algorithm, which means it only takes the first derivative into account when updating the parameters. On each iteration, we update the parameters in the opposite direction of the gradient of the objective function <img src="https://latex.codecogs.com/png.latex?J(w)"> w.r.t. the parameters, since the gradient gives the direction of steepest ascent. The size of the step we take on each iteration toward the local minimum is determined by the learning rate <img src="https://latex.codecogs.com/png.latex?%5Calpha">. Therefore, we follow the direction of the slope downhill until we reach a local minimum.</p>
<p>In this notebook, we’ll cover gradient descent algorithm and its variants: <em>Batch Gradient Descent, Mini-batch Gradient Descent, and Stochastic Gradient Descent</em>.</p>
<p>Let’s first see how gradient descent and its associated steps works on logistic regression before going into the details of its variants. For the sake of simplicity, let’s assume that the logistic regression model has only two parameters: weight <img src="https://latex.codecogs.com/png.latex?w"> and bias <img src="https://latex.codecogs.com/png.latex?b">.</p>
<ol type="1">
<li>Initialize weight <img src="https://latex.codecogs.com/png.latex?w"> and bias <img src="https://latex.codecogs.com/png.latex?b"> to any random numbers.</li>
<li>Pick a value for the learning rate <img src="https://latex.codecogs.com/png.latex?%5Calpha">. The learning rate determines how big the step would be on each iteration.</li>
</ol>
<ul>
<li>If <img src="https://latex.codecogs.com/png.latex?%5Calpha"> is very small, it will take a long time to converge and become computationally expensive.</li>
<li>If <img src="https://latex.codecogs.com/png.latex?%5Calpha"> is large, it may fail to converge and overshoot the minimum.</li>
</ul>
<p>Therefore, plot the cost function for different values of <img src="https://latex.codecogs.com/png.latex?%5Calpha"> and pick the largest value that still converges — the one right before the first value that diverges — so that we have a fast learning algorithm that converges (see figure 1).</p>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/learning_rate.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-1" data-glightbox="description: .lightbox-desc-1" title="Figure 1: Gradient descent with different learning rates Source"><img src="https://imaddabbura.github.io/posts/optimization/images/learning_rate.PNG" class="quarto-figure quarto-figure-left figure-img" width="600" height="400" alt="Figure 1: Gradient descent with different learning rates Source"></a></p>
</figure>
</div>
<figcaption><strong>Figure 1</strong>: Gradient descent with different learning rates <a href="http://cs231n.github.io/neural-networks-3/">Source</a></figcaption>
</figure>
</div>
<ul>
<li>The most commonly used rates are : <em>0.001, 0.003, 0.01, 0.03, 0.1, 0.3</em>.</li>
</ul>
<ol start="3" type="1">
<li>Make sure to scale the data if the features are on very different scales. If we don’t scale the data, the level curves (contours) will be narrower and taller, which means it will take a longer time to converge (see figure 2).</li>
</ol>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/normalized-vs-unnormalized.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-2" data-glightbox="description: .lightbox-desc-2" title="Figure 2: Gradient descent: normalized versus unnormalized level curves"><img src="https://imaddabbura.github.io/posts/optimization/images/normalized-vs-unnormalized.PNG" class="quarto-figure quarto-figure-left figure-img" width="800" height="300" alt="Figure 2: Gradient descent: normalized versus unnormalized level curves"></a></p>
</figure>
</div>
<figcaption><strong>Figure 2</strong>: Gradient descent: normalized versus unnormalized level curves</figcaption>
</figure>
</div>
<p>Scale the data to have <img src="https://latex.codecogs.com/png.latex?%5Cmu%20=%200"> and <img src="https://latex.codecogs.com/png.latex?%5Csigma%20=%201">. Below is the formula for scaling each example: <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7Bx_i%20-%20%5Cmu%7D%7B%5Csigma%7D%5Ctag%7B1%7D"></p>
<ol start="4" type="1">
<li>On each iteration, take the partial derivative of the cost function <img src="https://latex.codecogs.com/png.latex?J(w)"> w.r.t. each parameter (gradient): <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20w%7DJ(w)%20=%20%5Cnabla_w%20J%5Ctag%7B2%7D"> <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20b%7DJ(w)%20=%20%5Cnabla_b%20J%5Ctag%7B3%7D"> The update equations are: <img src="https://latex.codecogs.com/png.latex?w%20=%20w%20-%20%5Calpha%20%5Cnabla_w%20J%5Ctag%7B4%7D"> <img src="https://latex.codecogs.com/png.latex?b%20=%20b%20-%20%5Calpha%20%5Cnabla_b%20J%5Ctag%7B5%7D"></li>
</ol>
<ul>
<li>For the sake of illustration, assume we don’t have a bias term. If the slope at the current value of <img src="https://latex.codecogs.com/png.latex?w"> is positive, we are to the right of the optimal <img src="https://latex.codecogs.com/png.latex?w%5E*">; the update will be negative and move us closer to the optimal <img src="https://latex.codecogs.com/png.latex?w%5E*">. If the slope is negative, the update will be positive and increase the current value of <img src="https://latex.codecogs.com/png.latex?w"> toward the optimal <img src="https://latex.codecogs.com/png.latex?w%5E*"> (see figure 3):</li>
</ul>
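<p>The scaling formula and the update equations above can be sketched as follows (the gradient values here are assumed, purely for illustration):</p>

```python
import numpy as np

def standardize(X):
    """Scale each feature to mean 0 and standard deviation 1 (equation 1)."""
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    return (X - mu) / sigma

X = np.array([[1.0, 100.0], [2.0, 300.0], [3.0, 500.0]])
X_scaled = standardize(X)
print(X_scaled.mean(axis=0), X_scaled.std(axis=0))  # means ~0, stds ~1

# One update step (equations 4 and 5) with made-up gradients of J
w, b, alpha = np.array([0.5, -1.0]), 0.1, 0.01
grad_w, grad_b = np.array([0.2, -0.4]), 0.3  # assumed values of the gradients
w = w - alpha * grad_w
b = b - alpha * grad_b
print(w, b)  # -> w is approximately [0.498, -0.996], b approximately 0.097
```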
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/gradients.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-3" data-glightbox="description: .lightbox-desc-3" title="Figure 3: Gradient descent. An illustration of how gradient descent algorithm uses the first derivative of the loss function to follow downhill it’s minimum."><img src="https://imaddabbura.github.io/posts/optimization/images/gradients.PNG" class="quarto-figure quarto-figure-left figure-img" width="600" height="400" alt="Figure 3: Gradient descent. An illustration of how gradient descent algorithm uses the first derivative of the loss function to follow downhill it’s minimum."></a></p>
</figure>
</div>
<figcaption><strong>Figure 3</strong>: Gradient descent. An illustration of how gradient descent algorithm uses the first derivative of the loss function to follow downhill it’s minimum.</figcaption>
</figure>
</div>
<ul>
<li>Continue the process until the cost function converges. That is, until the error curve becomes flat and doesn’t change.</li>
<li>In addition, on each iteration, the step would be in the direction that gives the maximum change since it’s perpendicular to level curves at each step.</li>
</ul>
<p>Now let’s discuss the three variants of gradient descent algorithm. The main difference between them is the amount of data we use when computing the gradients for each learning step. The trade-off between them is the accuracy of the gradient versus the time complexity to perform each parameter’s update (learning step).</p>
</section>
<section id="batch-gradient-descent" class="level2">
<h2 class="anchored" data-anchor-id="batch-gradient-descent">Batch Gradient Descent</h2>
<p>Batch Gradient Descent sums over all examples on each iteration when performing the parameter updates. Therefore, for each update, we have to go over the entire dataset: <img src="https://latex.codecogs.com/png.latex?w%20=%20w%20-%20%5Calpha%20%5Cnabla_w%20J(w)%5Ctag%7B6%7D"></p>
<div class="sourceCode" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(num_epochs):</span>
<span id="cb1-2">    grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> compute_gradient(data, params)</span>
<span id="cb1-3">    params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> learning_rate <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> grad</span></code></pre></div>
<p>The main advantages:</p>
<ul>
<li>We can use a fixed learning rate during training without worrying about learning rate decay.</li>
<li>It has a straight trajectory toward the minimum and, in theory, is guaranteed to converge to the global minimum if the loss function is convex, and to a local minimum otherwise.</li>
<li>It gives an unbiased estimate of the gradients. The more examples, the lower the standard error.</li>
</ul>
<p>The main disadvantages:</p>
<ul>
<li>Even though we can use a vectorized implementation, it may still be slow to go over all examples, especially with large datasets.</li>
<li>Each learning step happens only after going over all examples, some of which may be redundant and contribute little to the update.</li>
</ul>
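<p>The batch loop shown above can be made runnable on a tiny linear-regression problem; the <code>compute_gradient</code> here is an assumed mean-squared-error gradient and the data is synthetic, purely for illustration:</p>

```python
import numpy as np

def compute_gradient(X, y, w):
    """Gradient of the mean-squared-error cost, averaged over ALL examples."""
    m = len(y)
    return (X.T @ (X @ w - y)) / m

# Tiny synthetic dataset following y = 2x
X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([2.0, 4.0, 6.0, 8.0])

w = np.zeros(1)
learning_rate = 0.05
for _ in range(500):  # every epoch uses the full dataset
    w = w - learning_rate * compute_gradient(X, y, w)

print(w)  # converges to approximately [2.]
```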
</section>
<section id="mini-batch-gradient-descent" class="level2">
<h2 class="anchored" data-anchor-id="mini-batch-gradient-descent">Mini-Batch Gradient Descent</h2>
<p>Instead of going over all examples, Mini-batch Gradient Descent sums over a smaller number of examples based on the batch size. Therefore, learning happens on each mini-batch of <img src="https://latex.codecogs.com/png.latex?b"> examples:</p>
<p><img src="https://latex.codecogs.com/png.latex?w%20=%20w%20-%20%5Calpha%20%5Cnabla_w%20J(x%5E%7B%5C%7Bi:i%20+%20b%5C%7D%7D,%20y%5E%7B%5C%7Bi:%20i%20+%20b%5C%7D%7D;%20w)%5Ctag%7B7%7D%5C%5C%7B%7D"></p>
<ul>
<li>Shuffle the training dataset to avoid pre-existing order of examples.</li>
<li>Partition the training dataset into mini-batches based on the batch size. If the training set size is not divisible by the batch size, the remaining examples form their own (smaller) batch.</li>
</ul>
<div class="sourceCode" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(num_epochs):</span>
<span id="cb2-2">    np.random.shuffle(data)</span>
<span id="cb2-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> batch <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> random_minibatches(data, batch_size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">32</span>):</span>
<span id="cb2-4">        grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> compute_gradient(batch, params)</span>
<span id="cb2-5">        params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> learning_rate <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> grad</span></code></pre></div>
<p>The batch size is a hyperparameter we can tune. It is usually chosen as a power of 2, such as 32, 64, 128, 256, 512, etc. The reason is that some hardware, such as GPUs, achieves better runtime with batch sizes that are powers of 2.</p>
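<p>The <code>random_minibatches</code> helper used in the pseudocode above isn’t defined in this post; one assumed implementation could look like this:</p>

```python
import numpy as np

def random_minibatches(data, batch_size=32):
    """Yield shuffled mini-batches; the last batch may be smaller when the
    dataset size is not divisible by the batch size."""
    indices = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield data[indices[start:start + batch_size]]

data = np.arange(100).reshape(50, 2)  # 50 examples, 2 features each
sizes = [len(batch) for batch in random_minibatches(data, batch_size=32)]
print(sizes)  # -> [32, 18]
```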
<p>The main advantages:</p>
<ul>
<li>Faster than the Batch version because it goes through far fewer examples per update than Batch (which uses all examples).</li>
<li>Randomly selecting examples helps avoid redundant or very similar examples that contribute little to the learning.</li>
<li>With batch size &lt; size of the training set, it adds noise to the learning process that helps improve generalization error.</li>
<li>Even though more examples yield a gradient estimate with lower standard error, the return is less than linear relative to the computational burden we incur.</li>
</ul>
<p>The main disadvantages:</p>
<ul>
<li>It won’t converge exactly. On each iteration, the learning step may go back and forth due to the noise, so it wanders around the minimum region but never settles at the minimum.</li>
<li>Due to the noise, the learning steps have more oscillations (see figure 4) and require adding learning-rate decay to decrease the learning rate as we get closer to the minimum.</li>
</ul>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/batch-vs-minibatch.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-4" data-glightbox="description: .lightbox-desc-4" title="Figure 4: Gradient descent: batch versus mini-batch loss function"><img src="https://imaddabbura.github.io/posts/optimization/images/batch-vs-minibatch.PNG" class="quarto-figure quarto-figure-left figure-img" width="800" height="300" alt="Figure 4: Gradient descent: batch versus mini-batch loss function"></a></p>
</figure>
</div>
<figcaption><strong>Figure 4</strong>: Gradient descent: batch versus mini-batch loss function</figcaption>
</figure>
</div>
<p>With large training datasets, we usually don’t need more than 2-10 passes over all training examples (epochs). Note: with batch size <img src="https://latex.codecogs.com/png.latex?b%20=%20m">, we recover Batch Gradient Descent.</p>
</section>
<section id="stochastic-gradient-descent" class="level2">
<h2 class="anchored" data-anchor-id="stochastic-gradient-descent">Stochastic Gradient Descent</h2>
<p>Instead of going through all examples, Stochastic Gradient Descent (SGD) performs the parameters update on each example <img src="https://latex.codecogs.com/png.latex?(x%5Ei,%20y%5Ei)">. Therefore, learning happens on every example:</p>
<p><img src="https://latex.codecogs.com/png.latex?w%20=%20w%20-%20%5Calpha%20%5Cnabla_w%20J(x%5Ei,%20y%5Ei;%20w)%5Ctag%7B7%7D"></p>
<ul>
<li>Shuffle the training dataset to avoid pre-existing order of examples.</li>
<li>Treat each of the <img src="https://latex.codecogs.com/png.latex?m"> examples as its own mini-batch.</li>
</ul>
<div class="sourceCode" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(num_epochs):</span>
<span id="cb3-2">    np.random.shuffle(data)</span>
<span id="cb3-3">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> example <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> data:</span>
<span id="cb3-4">        grad <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> compute_gradient(example, params)</span>
<span id="cb3-5">        params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> params <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> learning_rate <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> grad</span></code></pre></div>
<p>It shares most of the advantages and disadvantages of the mini-batch version. Below are the ones that are specific to SGD:</p>
<ul>
<li>It adds even more noise to the learning process than mini-batch, which helps improve generalization error. However, this also increases the run time.</li>
<li>We can’t exploit vectorization over a single example, so it becomes very slow. Also, the variance of the gradient estimate is large since we use only one example per learning step.</li>
</ul>
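<p>The variance point can be illustrated with a quick simulation (on assumed synthetic data): the spread of the gradient estimate shrinks as the batch size grows:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=10_000)
y = 3 * X + rng.normal(scale=0.1, size=10_000)

def grad_estimate(batch_size, w=0.0):
    """MSE gradient w.r.t. w estimated from one random batch."""
    idx = rng.choice(len(y), size=batch_size, replace=False)
    return np.mean((w * X[idx] - y[idx]) * X[idx])

stds = []
for b in (1, 32, 10_000):
    grads = [grad_estimate(b) for _ in range(200)]
    stds.append(np.std(grads))
print(stds)  # the standard deviation shrinks as the batch size grows
```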
<p>Below is a graph that shows the gradient descent’s variants and their direction towards the minimum:</p>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/batch-vs-minibatch-vs-stochastic.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-5" data-glightbox="description: .lightbox-desc-5" title="Figure 5: Gradient descent variants’ trajectory towards minimum"><img src="https://imaddabbura.github.io/posts/optimization/images/batch-vs-minibatch-vs-stochastic.PNG" class="quarto-figure quarto-figure-left figure-img" width="600" height="300" alt="Figure 5: Gradient descent variants’ trajectory towards minimum"></a></p>
</figure>
</div>
<figcaption><strong>Figure 5</strong>: Gradient descent variants’ trajectory towards minimum</figcaption>
</figure>
</div>
<p>As the figure above shows, SGD’s trajectory is very noisy compared to mini-batch’s.</p>
</section>
<section id="challenges" class="level2">
<h2 class="anchored" data-anchor-id="challenges">Challenges</h2>
<p>Below are some challenges regarding gradient descent algorithm in general as well as its variants - mainly batch and mini-batch:</p>
<ul>
<li>Gradient descent is a first-order optimization algorithm, which means it doesn’t take the second derivatives of the cost function into account. However, the curvature of the function affects the size of each learning step. The gradient measures the steepness of the curve, while the second derivative measures its curvature. Therefore, if:</li>
<li>Second derivative = 0 <img src="https://latex.codecogs.com/png.latex?%5Crightarrow"> the curvature is linear. Therefore, the step size = the learning rate <img src="https://latex.codecogs.com/png.latex?%5Calpha">.</li>
<li>Second derivative &gt; 0 <img src="https://latex.codecogs.com/png.latex?%5Crightarrow"> the curvature is going upward. Therefore, the step size &lt; the learning rate <img src="https://latex.codecogs.com/png.latex?%5Calpha"> and may lead to divergence.</li>
<li>Second derivative &lt; 0 <img src="https://latex.codecogs.com/png.latex?%5Crightarrow"> the curvature is going downward. Therefore, the step size &gt; the learning rate <img src="https://latex.codecogs.com/png.latex?%5Calpha">.</li>
</ul>
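<p>A toy quadratic (an assumed example, not from the original post) shows how curvature changes the actual decrease relative to what the first derivative alone predicts:</p>

```python
# For f(w) = 0.5 * c * w**2: gradient g = c * w, second derivative = c.
def true_decrease(c, w, alpha):
    """Actual decrease in f after one gradient step w - alpha * g."""
    g = c * w
    w_new = w - alpha * g
    return 0.5 * c * w**2 - 0.5 * c * w_new**2

def first_order_prediction(c, w, alpha):
    """Decrease predicted by the gradient alone: alpha * g**2."""
    g = c * w
    return alpha * g**2

w, alpha = 1.0, 0.1
for c in (1.0, 10.0, 25.0):  # increasing curvature
    print(c, first_order_prediction(c, w, alpha), true_decrease(c, w, alpha))
# With c = 25 the true "decrease" is negative: the step overshoots and diverges.
```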
<p>As a result, a direction that looks promising to the gradient may not actually be so, and may slow the learning process or even cause divergence.</p>
<ul>
<li>If the Hessian matrix has a poor condition number, i.e.&nbsp;the direction of most curvature has much more curvature than the direction of least curvature, the cost function will be very sensitive in some directions and insensitive in others. This makes things harder for the gradient, because a direction that looks promising to the gradient may not lead to big changes in the cost function (see figure 6).</li>
</ul>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/curvature.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-6" data-glightbox="description: .lightbox-desc-6" title="Figure 6: Gradient descent fails to exploit the curvature information contained in the Hessian matrix. Source"><img src="https://imaddabbura.github.io/posts/optimization/images/curvature.PNG" class="quarto-figure quarto-figure-left figure-img" height="400" alt="Figure 6: Gradient descent fails to exploit the curvature information contained in the Hessian matrix. Source"></a></p>
</figure>
</div>
<figcaption><strong>Figure 6</strong>: Gradient descent fails to exploit the curvature information contained in the Hessian matrix. <a href="http://www.deeplearningbook.org/contents/numerical.html">Source</a></figcaption>
</figure>
</div>
<ul>
<li>The norm of the gradient <img src="https://latex.codecogs.com/png.latex?g%5ETg"> is supposed to decrease slowly with each learning step because the curve gets flatter and its steepness decreases. However, we often see the norm of the gradient increasing because of the curvature of the curve. Nonetheless, even though the gradients’ norm is increasing, we’re able to achieve very low error rates (see figure 7).</li>
</ul>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/gradient_norm.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-7" data-glightbox="description: .lightbox-desc-7" title="Figure 7: Gradient norm. Source"><img src="https://imaddabbura.github.io/posts/optimization/images/gradient_norm.PNG" class="quarto-figure quarto-figure-left figure-img" width="600" height="300" alt="Figure 7: Gradient norm. Source"></a></p>
</figure>
</div>
<figcaption><strong>Figure 7</strong>: Gradient norm. <a href="http://www.deeplearningbook.org/contents/optimization.html">Source</a></figcaption>
</figure>
</div>
<ul>
<li>In low dimensions, local minima are common; however, in high dimensions, saddle points are more common. A saddle point is where the function curves up in some directions and down in others. In other words, a saddle point looks like a minimum from one direction and a maximum from another (see figure 8). This happens when at least one eigenvalue of the Hessian matrix is negative and the rest are positive.</li>
</ul>
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<div class="quarto-figure quarto-figure-left">
<figure class="figure">
<p><a href="images/saddle.PNG" class="lightbox" data-gallery="quarto-lightbox-gallery-8" data-glightbox="description: .lightbox-desc-8" title="Figure 8: Saddle point"><img src="https://imaddabbura.github.io/posts/optimization/images/saddle.PNG" class="quarto-figure quarto-figure-left figure-img" width="600" height="300" alt="Figure 8: Saddle point"></a></p>
</figure>
</div>
<figcaption><strong>Figure 8</strong>: Saddle point</figcaption>
</figure>
</div>
<ul>
<li>As discussed previously, choosing a proper learning rate is hard. Also, for mini-batch gradient descent, we have to adjust the learning rate during the training process to make sure it converges to the local minimum and not wander around it. Figuring out the decay rate of the learning rate is also hard and changes with different datasets.</li>
<li>All parameters share the same learning rate; however, we may want to perform larger updates for parameters whose directional derivatives are more in line with the trajectory toward the minimum than others.</li>
</ul>



</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Deep Learning</category>
  <guid>https://imaddabbura.github.io/posts/optimization/gradient-descent.html</guid>
  <pubDate>Mon, 18 Feb 2019 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/optimization/images/gradient_cover.PNG" medium="image"/>
</item>
<item>
  <title>Conda Essentials</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/swe/conda-essentials.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p><strong>Conda</strong> is an open source package management system that works on all platforms. It is a tool that helps manage packages and environments for different programming languages. Developing a high-level understanding of how Conda works has helped me at so many levels, especially when it comes to managing environments and making my work more reproducible. Below are the notes that I wrote down during my journey of learning Conda, and I always refer back to them:</p>
</section>
<section id="general" class="level2">
<h2 class="anchored" data-anchor-id="general">General</h2>
<ul>
<li>Conda packages are files and executables that can in principle contain images, data, notebooks, etc.</li>
<li>Conda is mainly used in the Python ecosystem; however, it can be used with other languages such as R, Julia, Scala, etc.</li>
<li>When installing a package using Conda, it installs its dependencies with it. Also, Conda is able to figure out the platform you’re using without the need to specify the platform when installing packages.</li>
<li>When installing a package, Conda:
<ul>
<li>Checks the platform.</li>
<li>Checks the Python version.</li>
<li>Installs the latest version of the package that is compatible with Python.</li>
<li>If it has dependencies, installs the latest versions of the dependencies that are also compatible with each other.</li>
</ul></li>
<li>Under semantic versioning, software is labeled with a three-part version identifier of the form <code>MAJOR.MINOR.PATCH</code>; the label components are non-negative integers separated by periods. Assuming all software starts at version 0.0.0, the <code>MAJOR</code> version number is increased when significant new functionality is introduced (often with corresponding API changes). Increases in the <code>MINOR</code> version number generally reflect improvements (e.g., new features) that avoid backward-incompatible API changes. For instance, adding an optional argument to a function API (in a way that allows old code to run unchanged) is a change worthy of increasing the <code>MINOR</code> version number. An increment to the <code>PATCH</code> version number is appropriate mostly for bug fixes that preserve the same <code>MAJOR</code> and <code>MINOR</code> revision numbers. Software patches do not typically introduce new features or change APIs at all (except sometimes to address security issues).</li>
<li>We can specify <code>MAJOR</code>, <code>MAJOR.MINOR</code>, or <code>MAJOR.MINOR.PATCH</code> when installing any package.</li>
<li>We can use logical operators to install versions of a package. Examples:
<ul>
<li><code>conda install 'python=3.6|3.7'</code>.</li>
<li><code>conda install 'python=3.6|3.7*'</code> .</li>
<li><code>conda install 'python&gt;=3.6, &lt;=3.7'</code>.</li>
</ul></li>
</ul>
</section>
<section id="common-commands" class="level2">
<h2 class="anchored" data-anchor-id="common-commands">Common Commands</h2>
<ul>
<li>To update a package, <code>conda update pckg</code>.</li>
<li>To uninstall a package, <code>conda remove pckg</code>.</li>
<li>To search which versions of a specific package are available, use <code>conda search pckg</code>.</li>
<li><code>conda list</code> will list all installed packages.</li>
<li><code>conda list -n env-name</code> will list all packages in the environment env-name.</li>
<li><code>conda list pckg</code> will give information about pckg.</li>
<li>When installing a pckg without including a channel, it defaults to the main channel that is maintained by Anaconda Inc.</li>
<li>There are other channels where people can upload their packages, and we can point to those channels when installing packages such as fastai. We use <code>conda install -c fastai fastai</code>. Here the channel is fastai and the pckg is also fastai.</li>
<li><code>conda search -c conda-forge -c fastai --override-channels --platform osx-64 fastai</code> means:
<ul>
<li>Search for fastai in two channels: conda-forge, fastai.</li>
<li><code>--override-channels</code> means do not search the default main channel.</li>
<li><code>--platform</code> specifies which platform.</li>
</ul></li>
<li>Sometimes we don’t know the channel of a pckg; we can use <code>anaconda search pckg</code>, which returns all the channels that host the pckg and their versions.</li>
<li>conda-forge, which is led by the community, is almost as good as the main channel and has a lot more packages.</li>
<li>There is no system that rates channels, so be careful when installing packages from any channel.</li>
<li>We can list all packages in a channel; for example, <code>conda search -c conda-forge --override-channels</code> lists all packages in the conda-forge channel.</li>
</ul>
</section>
<section id="environments" class="level2">
<h2 class="anchored" data-anchor-id="environments">Environments</h2>
<ul>
<li>Environments are a good practice of documenting data science/software development work.</li>
<li>Environments are nothing more than a directory that contains all the packages, so that when trying to import them, they are imported from this directory only. We can use <code>conda env list</code> to see all the available environments on our machine.</li>
<li>To get the packages from a specific environment by name, use <code>conda list -n env-name</code>. Otherwise, we get the packages from the current environment.</li>
<li>To activate an environment, use <code>conda activate env-name</code>. To deactivate, <code>conda deactivate</code>.</li>
<li>Environments usually don’t take a lot of space.</li>
<li>We can remove environments using <code>conda env remove -n env-name</code>.</li>
<li>To create an environment, use <code>conda create -n env-name</code>. We can also add additional package names to install after creation such as <code>conda create -n env-name python=3.6* numpy&gt;=1.1</code>.</li>
<li>To export an environment, use <code>conda env export -n env-name</code>. This will print the output to the terminal. We can also export to a file; for that, use <code>conda env export -n env-name -f env-name.yml</code>. The ‘.yml’ extension is strongly encouraged. Doing this ensures that all the packages used can be installed by others exactly.</li>
<li>We can also create an environment from a .yml file using <code>conda env create -f env-name.yml</code>. Note also that if we only use <code>conda env create</code>, it will look for a file that has the .yml extension and the same name as env-name in the current local directory. Moreover, we can write the .yml file ourselves, without doing the export, and only specify what is important in our environments.</li>
</ul>
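<p>As an illustration of the last point, a hand-written environment file might look like the following (the name, channels, and version pins are made up for this example):</p>

```yaml
name: env-name
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.6
  - numpy>=1.1
  - pandas
```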


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Software Engineering</category>
  <guid>https://imaddabbura.github.io/posts/swe/conda-essentials.html</guid>
  <pubDate>Mon, 18 Feb 2019 06:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/swe/images/conda-essentials.png" medium="image" type="image/png" height="58" width="144"/>
</item>
<item>
  <title>K-means Clustering: Algorithm, Applications, Evaluation Methods, and Drawbacks</title>
  <dc:creator>Imad Dabbura</dc:creator>
  <link>https://imaddabbura.github.io/posts/ml/Kmeans-Clustering.html</link>
  <description><![CDATA[ 





<div class="status-badge-container" style="margin-bottom: 1rem;"><span class="status-badge growing">growing</span></div>
<section id="introduction" class="level2">
<h2 class="anchored" data-anchor-id="introduction">Introduction</h2>
<p><strong>Clustering</strong> is one of the most common exploratory data analysis techniques, used to get an intuition about the structure of the data. It can be defined as the task of identifying subgroups in the data such that data points in the same subgroup (cluster) are very similar while data points in different clusters are very different. In other words, we try to find homogeneous subgroups within the data such that data points in each cluster are as similar as possible according to a similarity measure such as Euclidean distance or correlation-based distance. The decision of which similarity measure to use is application-specific.</p>
<p>Clustering analysis can be done on the basis of features, where we try to find subgroups of samples based on features, or on the basis of samples, where we try to find subgroups of features based on samples. We’ll cover here clustering based on features. Clustering is used in market segmentation, where we try to find customers that are similar to each other in terms of behaviors or attributes; image segmentation/compression, where we try to group similar regions together; document clustering based on topics; etc.</p>
<p>Unlike supervised learning, clustering is considered an unsupervised learning method since we don’t have the ground truth to compare the output of the clustering algorithm to the true labels to evaluate its performance. We only want to try to investigate the structure of the data by grouping the data points into distinct subgroups.</p>
<p>In this post, we will cover only <strong>Kmeans</strong>, which is considered one of the most widely used clustering algorithms due to its simplicity.</p>
</section>
<section id="kmeans-algorithm" class="level2">
<h2 class="anchored" data-anchor-id="kmeans-algorithm">Kmeans Algorithm</h2>
<p><strong>Kmeans</strong> algorithm is an iterative algorithm that tries to partition the dataset into <img src="https://latex.codecogs.com/png.latex?K"> pre-defined distinct non-overlapping subgroups (clusters) where each data point belongs to <strong>only one group</strong>. It tries to make the intra-cluster data points as similar as possible while also keeping the clusters as different (far) as possible. It assigns data points to a cluster such that the sum of the squared distance between the data points and the cluster’s centroid (arithmetic mean of all the data points that belong to that cluster) is at the minimum. The less variation we have within clusters, the more homogeneous (similar) the data points are within the same cluster.</p>
<p>The way kmeans algorithm works is as follows:</p>
<ol type="1">
<li>Specify number of clusters <img src="https://latex.codecogs.com/png.latex?K">.</li>
<li>Initialize centroids by first shuffling the dataset and then randomly selecting <img src="https://latex.codecogs.com/png.latex?K"> data points for the centroids without replacement.</li>
<li>Keep iterating until there is no change to the centroids, i.e., the assignment of data points to clusters isn’t changing.
<ul>
<li>Compute the sum of the squared distance between data points and all centroids.</li>
<li>Assign each data point to the closest cluster (centroid).</li>
<li>Compute the centroids for the clusters by taking the average of all data points that belong to each cluster.</li>
</ul></li>
</ol>
<p>The approach kmeans follows to solve the problem is called <strong>Expectation-Maximization</strong>. The E-step is assigning the data points to the closest cluster. The M-step is computing the centroid of each cluster. Below is a breakdown of how we can solve it mathematically (feel free to skip it).</p>
<p>The objective function is: <img src="https://latex.codecogs.com/png.latex?J%20=%20%5Csum_%7Bi%20=%201%7D%5E%7Bm%7D%5Csum_%7Bk%20=%201%7D%5E%7BK%7Dw_%7Bik%7D%5C%7Cx%5Ei%20-%20%5Cmu_k%5C%7C%5E2%5C%5C%7B%7D"> where <img src="https://latex.codecogs.com/png.latex?w_%7Bik%7D%20=%201"> for data point <img src="https://latex.codecogs.com/png.latex?x%5Ei"> if it belongs to cluster <img src="https://latex.codecogs.com/png.latex?k">; otherwise, <img src="https://latex.codecogs.com/png.latex?w_%7Bik%7D%20=%200">. Also, <img src="https://latex.codecogs.com/png.latex?%5Cmu_k"> is the centroid of <img src="https://latex.codecogs.com/png.latex?x%5Ei">’s cluster.</p>
<p>It’s a minimization problem of two parts. We first minimize J w.r.t. <img src="https://latex.codecogs.com/png.latex?w_%7Bik%7D"> and treat <img src="https://latex.codecogs.com/png.latex?%5Cmu_k"> fixed. Then we minimize J w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmu_k"> and treat <img src="https://latex.codecogs.com/png.latex?w_%7Bik%7D"> fixed. Technically speaking, we differentiate J w.r.t. <img src="https://latex.codecogs.com/png.latex?w_%7Bik%7D"> first and update cluster assignments (<em>E-step</em>). Then we differentiate J w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bk%7D"> and recompute the centroids after the cluster assignments from previous step (<em>M-step</em>). Therefore, E-step is: <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20J%7D%7B%5Cpartial%20w_%7Bik%7D%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7Bm%7D%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%5C%7Cx%5Ei%20-%20%5Cmu_k%5C%7C%5E2%5C%5C%7B%7D"> <img src="https://latex.codecogs.com/png.latex?%0A%5CRightarrow%0A%5Cbegin%7Bequation%7D%0A%20%20w_%7Bik%7D%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%201%20&amp;%20%5Ctext%7Bif%20$k%20=%20arg%20min_j%5C%20%5C%7Cx%5Ei%20-%20%5Cmu_j%5C%7C%5E2$%7D%5C%5C%0A%20%20%20%200%20&amp;%20%5Ctext%7Botherwise%7D.%0A%20%20%5Cend%7Bcases%7D%0A%5Cend%7Bequation%7D%5Ctag%7B1%7D%5C%5C%7B%7D%0A"> In other words, assign the data point <img src="https://latex.codecogs.com/png.latex?x%5Ei"> to the closest cluster judged by its sum of squared distance from cluster’s centroid.</p>
<p>And M-step is: <img src="https://latex.codecogs.com/png.latex?%5C%20%5Cfrac%7B%5Cpartial%20J%7D%7B%5Cpartial%20%5Cmu_k%7D%20=%202%5Csum_%7Bi%20=%201%7D%5E%7Bm%7Dw_%7Bik%7D(x%5Ei%20-%20%5Cmu_k)%20=%200%5C%5C%7B%7D"> <img src="https://latex.codecogs.com/png.latex?%5CRightarrow%20%5Cmu_k%20=%20%5Cfrac%7B%5Csum_%7Bi%20=%201%7D%5E%7Bm%7Dw_%7Bik%7Dx%5Ei%7D%7B%5Csum_%7Bi%20=%201%7D%5E%7Bm%7Dw_%7Bik%7D%7D%5Ctag%7B2%7D%5C%5C%7B%7D"> Which translates to recomputing the centroid of each cluster to reflect the new assignments.</p>
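<p>One E-step/M-step iteration of equations (1) and (2) takes only a few lines of NumPy. The sketch below uses random data and initializes the centroids from the first <code>K</code> points purely for brevity:</p>

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(10, 2)        # 10 data points, 2 features
K = 3
centroids = X[:K].copy()    # illustrative initialization

# E-step: assign each point to its nearest centroid (equation 1).
sq_dist = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
labels = sq_dist.argmin(axis=1)

# M-step: recompute each centroid as the mean of its assigned points (equation 2).
centroids = np.array([X[labels == k].mean(axis=0) for k in range(K)])
print(labels.shape, centroids.shape)  # (10,) (3, 2)
```

The full algorithm simply repeats these two steps until the assignments stop changing, as shown in the implementation later in the post.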
<p>Few things to note here:</p>
<ul>
<li>Since clustering algorithms including kmeans use distance-based measurements to determine the similarity between data points, it’s recommended to standardize the data to have a mean of zero and a standard deviation of one since almost always the features in any dataset would have different units of measurements such as age vs income.</li>
<li>Given kmeans’ iterative nature and the random initialization of centroids at the start of the algorithm, different initializations may lead to different clusters, since the algorithm may <em>get stuck in a local optimum and not converge to the global optimum</em>. Therefore, it’s recommended to run the algorithm with different centroid initializations and pick the results of the run that yielded the lowest sum of squared distances.</li>
<li>No change in the assignment of examples is equivalent to no change in within-cluster variation: <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B1%7D%7Bm_k%7D%5Csum_%7Bi%20=%201%7D%5E%7Bm_k%7D%5C%7Cx%5Ei%20-%20%5Cmu_%7Bc%5Ek%7D%5C%7C%5E2"></li>
</ul>
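<p>The first note can be illustrated directly: with features on very different scales (e.g., age vs. income), the larger-scale feature dominates squared Euclidean distances. A minimal NumPy sketch of standardization (the synthetic values are invented for illustration; <code>sklearn</code>’s <code>StandardScaler</code> performs the same transformation):</p>

```python
import numpy as np

rng = np.random.RandomState(0)
age = rng.normal(40, 10, 200)           # years
income = rng.normal(50000, 15000, 200)  # dollars
X = np.column_stack([age, income])

# Standardize each feature: zero mean, unit standard deviation.
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)

# Before scaling, income's spread swamps age's in any distance computation;
# after scaling, both features contribute comparably.
print(X.std(axis=0))         # roughly [10, 15000]
print(X_scaled.std(axis=0))  # ~[1. 1.]
```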
</section>
<section id="implementation" class="level2">
<h2 class="anchored" data-anchor-id="implementation">Implementation</h2>
<p>We’ll use a simple implementation of kmeans here just to illustrate some concepts. Then we will use the <code>sklearn</code> implementation, which is more efficient and takes care of many things for us.</p>
<div id="cell-7" class="cell" data-code_folding="[]" data-execution_count="1">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb1" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb1-1"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb1-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> numpy.linalg <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> norm</span>
<span id="cb1-3"></span>
<span id="cb1-4"></span>
<span id="cb1-5"><span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">class</span> Kmeans:</span>
<span id="cb1-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">'''Implementing Kmeans algorithm.'''</span></span>
<span id="cb1-7"></span>
<span id="cb1-8">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> <span class="fu" style="color: #4758AB;
background-color: null;
font-style: inherit;">__init__</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, n_clusters, max_iter<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>, random_state<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">123</span>):</span>
<span id="cb1-9">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> n_clusters</span>
<span id="cb1-10">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.max_iter <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> max_iter</span>
<span id="cb1-11">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.random_state <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> random_state</span>
<span id="cb1-12"></span>
<span id="cb1-13">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> initializ_centroids(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X):</span>
<span id="cb1-14">        rng <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.RandomState(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.random_state)</span>
<span id="cb1-15">        random_idx <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> rng.permutation(X.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb1-16">        centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X[random_idx[:<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters]]</span>
<span id="cb1-17">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> centroids</span>
<span id="cb1-18"></span>
<span id="cb1-19">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> compute_centroids(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X, labels):</span>
<span id="cb1-20">        centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters, X.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]))</span>
<span id="cb1-21">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters):</span>
<span id="cb1-22">            centroids[k, :] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.mean(X[labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> k, :], axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>)</span>
<span id="cb1-23">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> centroids</span>
<span id="cb1-24"></span>
<span id="cb1-25">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> compute_distance(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X, centroids):</span>
<span id="cb1-26">        distance <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros((X.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters))</span>
<span id="cb1-27">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters):</span>
<span id="cb1-28">            row_norm <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> norm(X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> centroids[k, :], axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-29">            distance[:, k] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.square(row_norm)</span>
<span id="cb1-30">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> distance</span>
<span id="cb1-31"></span>
<span id="cb1-32">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> find_closest_cluster(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, distance):</span>
<span id="cb1-33">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> np.argmin(distance, axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-34"></span>
<span id="cb1-35">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> compute_sse(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X, labels, centroids):</span>
<span id="cb1-36">        distance <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.zeros(X.shape[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb1-37">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.n_clusters):</span>
<span id="cb1-38">            distance[labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> k] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> norm(X[labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> k] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span> centroids[k], axis<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb1-39">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">sum</span>(np.square(distance))</span>
<span id="cb1-40">    </span>
<span id="cb1-41">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> fit(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X):</span>
<span id="cb1-42">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.initializ_centroids(X)</span>
<span id="cb1-43">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.max_iter):</span>
<span id="cb1-44">            old_centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.centroids</span>
<span id="cb1-45">            distance <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.compute_distance(X, old_centroids)</span>
<span id="cb1-46">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.find_closest_cluster(distance)</span>
<span id="cb1-47">            <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.compute_centroids(X, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.labels)</span>
<span id="cb1-48">            <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">if</span> np.<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">all</span>(old_centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.centroids):</span>
<span id="cb1-49">                <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">break</span></span>
<span id="cb1-50">        <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.error <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.compute_sse(X, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.labels, <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.centroids)</span>
<span id="cb1-51">    </span>
<span id="cb1-52">    <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">def</span> predict(<span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>, X):</span>
<span id="cb1-53">        distance <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.compute_distance(X, self.centroids)</span>
<span id="cb1-54">        <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">return</span> <span class="va" style="color: #111111;
background-color: null;
font-style: inherit;">self</span>.find_closest_cluster(distance)</span></code></pre></div>
</details>
</div>
</section>
<section id="applications" class="level2">
<h2 class="anchored" data-anchor-id="applications">Applications</h2>
<p>The kmeans algorithm is very popular and used in a variety of applications such as market segmentation, document clustering, image segmentation, and image compression. The goal when we undertake a cluster analysis is usually one of the following:</p>
<ol type="1">
<li>Get a meaningful intuition of the structure of the data we’re dealing with.</li>
<li>Cluster-then-predict, where different models are built for different subgroups, if we believe there is a wide variation in the behaviors of different subgroups. An example is clustering patients into different subgroups and building a model for each subgroup to predict the risk of having a heart attack.</li>
</ol>
<p>In this post, we’ll apply clustering to two cases:</p>
<ul>
<li>Geyser eruptions segmentation (2-D dataset).</li>
<li>Image compression.</li>
</ul>
<section id="kmeans-on-geysers-eruptions-segmentation" class="level3">
<h3 class="anchored" data-anchor-id="kmeans-on-geysers-eruptions-segmentation">Kmeans on Geyser’s Eruptions Segmentation</h3>
<p>We’ll first apply the kmeans algorithm to a 2-D dataset and see how it works. The dataset has 272 observations and 2 features. The data covers the waiting time between eruptions and the duration of the eruption for the Old Faithful geyser in Yellowstone National Park, Wyoming, USA. We will try to find <img src="https://latex.codecogs.com/png.latex?K"> subgroups within the data points and group them accordingly. Below is the description of the features:</p>
<ul>
<li>eruptions (float): Eruption time in minutes.</li>
<li>waiting (int): Waiting time to next eruption.</li>
</ul>
<p>Let’s plot the data first:</p>
<div id="cell-12" class="cell" data-code_folding="[]" data-execution_count="2">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb2" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb2-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Modules</span></span>
<span id="cb2-2"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> matplotlib.pyplot <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> plt</span>
<span id="cb2-3"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> matplotlib.image <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> imread</span>
<span id="cb2-4"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> pandas <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> pd</span>
<span id="cb2-5"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> seaborn <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> sns</span>
<span id="cb2-5a"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> numpy <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">as</span> np</span>
<span id="cb2-6"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> sklearn.datasets <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> (make_blobs,</span>
<span id="cb2-7">                                                make_circles,</span>
<span id="cb2-8">                                                make_moons)</span>
<span id="cb2-9"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> sklearn.cluster <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> KMeans, SpectralClustering</span>
<span id="cb2-10"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> sklearn.preprocessing <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> StandardScaler</span>
<span id="cb2-11"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> sklearn.metrics <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> silhouette_samples, silhouette_score</span>
<span id="cb2-12"></span>
<span id="cb2-13"><span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span>matplotlib inline</span>
<span id="cb2-14">sns.set_context(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'notebook'</span>)</span>
<span id="cb2-15">plt.style.use(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'fivethirtyeight'</span>)</span>
<span id="cb2-16"><span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">from</span> warnings <span class="im" style="color: #00769E;
background-color: null;
font-style: inherit;">import</span> filterwarnings</span>
<span id="cb2-17">filterwarnings(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'ignore'</span>)</span></code></pre></div>
</details>
</div>
<div id="cell-13" class="cell" data-execution_count="3">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb3" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb3-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Import the data</span></span>
<span id="cb3-2">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> pd.read_csv(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'../data/old_faithful.csv'</span>)</span>
<span id="cb3-3"></span>
<span id="cb3-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot the data</span></span>
<span id="cb3-5">plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb3-6">plt.scatter(df.iloc[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], df.iloc[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb3-7">plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Eruption time in mins'</span>)</span>
<span id="cb3-8">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Waiting time to next eruption'</span>)</span>
<span id="cb3-9">plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Visualization of raw data'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-4-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-1"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-4-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>We’ll use this data because it’s easy to plot and to visually spot the clusters, since it’s a 2-dimensional dataset. It’s obvious that we have 2 clusters. Let’s standardize the data first and run the kmeans algorithm on the standardized data with <img src="https://latex.codecogs.com/png.latex?K%20=%202">.</p>
<div id="cell-15" class="cell" data-code_folding="[]" data-execution_count="4">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb4" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb4-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Standardize the data</span></span>
<span id="cb4-2">X_std <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> StandardScaler().fit_transform(df)</span>
<span id="cb4-3"></span>
<span id="cb4-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run local implementation of kmeans</span></span>
<span id="cb4-5">km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Kmeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, max_iter<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">100</span>)</span>
<span id="cb4-6">km.fit(X_std)</span>
<span id="cb4-7">centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.centroids</span>
<span id="cb4-8"></span>
<span id="cb4-9"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot the clustered data</span></span>
<span id="cb4-10">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb4-11">plt.scatter(X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb4-12">            c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'green'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cluster 1'</span>)</span>
<span id="cb4-13">plt.scatter(X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb4-14">            c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'blue'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cluster 2'</span>)</span>
<span id="cb4-15">plt.scatter(centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'*'</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">300</span>,</span>
<span id="cb4-16">            c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'centroid'</span>)</span>
<span id="cb4-17">plt.legend()</span>
<span id="cb4-18">plt.xlim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb4-19">plt.ylim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb4-20">plt.xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Eruption time in mins'</span>)</span>
<span id="cb4-21">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Waiting time to next eruption'</span>)</span>
<span id="cb4-22">plt.title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Visualization of clustered data'</span>, fontweight<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'bold'</span>)</span>
<span id="cb4-23">ax.set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-5-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-2"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-5-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>The above graph shows the scatter plot of the data colored by the cluster they belong to. In this example, we chose <img src="https://latex.codecogs.com/png.latex?K%20=%202">. The <strong>‘*’</strong> symbol marks the centroid of each cluster. We can think of these 2 clusters as the geyser exhibiting different kinds of behavior under different scenarios.</p>
<p>Next, we’ll show that different initializations of the centroids may yield different results. I’ll use 9 different <code>random_state</code> values to change the initialization of the centroids and plot the results. The title of each plot is the sum of squared distance for that initialization.</p>
<p>As a side note, this dataset is considered very easy and converges in fewer than 10 iterations. Therefore, to see the effect of random initialization on convergence, I’ll use only 3 iterations to illustrate the concept. In real-world applications, however, datasets are rarely this clean and well behaved!</p>
<div id="cell-17" class="cell" data-execution_count="5">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb5" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb5-1">n_iter <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">9</span></span>
<span id="cb5-2">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>))</span>
<span id="cb5-3">ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.ravel(ax)</span>
<span id="cb5-4">centers <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb5-5"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(n_iter):</span>
<span id="cb5-6">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run local implementation of kmeans</span></span>
<span id="cb5-7">    km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> Kmeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>,</span>
<span id="cb5-8">                max_iter<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>,</span>
<span id="cb5-9">                random_state<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>np.random.randint(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1000</span>, size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb5-10">    km.fit(X_std)</span>
<span id="cb5-11">    centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.centroids</span>
<span id="cb5-12">    centers.append(centroids)</span>
<span id="cb5-13">    ax[i].scatter(X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb5-14">                  c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'green'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cluster 1'</span>)</span>
<span id="cb5-15">    ax[i].scatter(X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_std[km.labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb5-16">                  c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'blue'</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'cluster 2'</span>)</span>
<span id="cb5-17">    ax[i].scatter(centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>],</span>
<span id="cb5-18">                  c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>, marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'*'</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">300</span>, label<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'centroid'</span>)</span>
<span id="cb5-19">    ax[i].set_xlim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb5-20">    ax[i].set_ylim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb5-21">    ax[i].legend(loc<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'lower right'</span>)</span>
<span id="cb5-22">    ax[i].set_title(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>km<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">.</span>error<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">:.4f}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>)</span>
<span id="cb5-23">    ax[i].set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span>
<span id="cb5-24">plt.tight_layout()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-6-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-3"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-6-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>As the graph above shows, we ended up with only two distinct clusterings across the different initializations. We would pick the one with the lowest sum of squared distance.</p>
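<p>This “keep the lowest SSE” rule can be sketched with <code>sklearn</code>’s <code>KMeans</code>, whose <code>inertia_</code> attribute is exactly this sum of squared distances (shown on synthetic blobs rather than the geyser data):</p>

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=2, random_state=42)

# One kmeans run per seed, each with a single random initialization
runs = []
for seed in range(9):
    km = KMeans(n_clusters=2, init="random", n_init=1,
                max_iter=3, random_state=seed).fit(X)
    runs.append((km.inertia_, seed))

# Keep the initialization with the lowest sum of squared distances
best_sse, best_seed = min(runs)
```

<p>Passing a larger <code>n_init</code> to a single <code>KMeans</code> call performs the same selection internally.</p>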
</section>
<section id="image-compression" class="level3">
<h3 class="anchored" data-anchor-id="image-compression">Image Compression</h3>
<p>In this part, we’ll use kmeans to compress an image. The image we’ll be working with is 396 x 396 x 3, so each pixel location holds 3 8-bit integers specifying the red, green, and blue intensity values. Our goal is to reduce the number of colors to 30 and represent (compress) the photo using only those 30 colors. To pick which colors to use, we’ll run the kmeans algorithm on the image and treat every pixel as a data point. That means reshaping the image from height x width x channels to (height * width) x channels, i.e., we would have 396 x 396 = 156,816 data points in 3-dimensional space, which are the RGB intensities. Doing so allows us to represent each pixel by one of the 30 centroids, shrinking the image by roughly a factor of 5. The original image takes 396 x 396 x 24 = 3,763,584 bits, whereas the compressed image takes 30 x 24 + 396 x 396 x 5 = 784,800 bits. The difference comes from using the centroids as a lookup table for pixel colors: each pixel location then needs only a 5-bit index (the minimum needed to address 30 colors, since 2^5 = 32) instead of 24 bits of RGB.</p>
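<p>The size arithmetic can be checked in a few lines; note that indexing <code>k</code> colors needs ceil(log2(k)) bits per pixel, which is 5 bits for 30 colors:</p>

```python
from math import ceil, log2

h = w = 396          # image height and width
k = 30               # number of colors (centroids)
channels, bits_per_channel = 3, 8

original_bits = h * w * channels * bits_per_channel   # 24 bits per pixel
bits_per_index = ceil(log2(k))                        # bits to address k colors
palette_bits = k * channels * bits_per_channel        # the k-color lookup table
compressed_bits = palette_bits + h * w * bits_per_index
ratio = original_bits / compressed_bits               # roughly 5x smaller
```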
<p>From now on, we will be using the <code>sklearn</code> implementation of kmeans. A few things to note here:</p>
<ul>
<li><code>n_init</code> is the number of times kmeans is run with different centroid initializations. The best result (lowest SSE) is reported.</li>
<li><code>tol</code> is the convergence tolerance: the run stops once the change between iterations falls below it.</li>
<li>The default <code>init</code> is <strong>k-means++</strong>, which is supposed to yield better results than purely random initialization of the centroids.</li>
</ul>
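<p>A minimal sketch of these parameters in use (the values shown are <code>sklearn</code> defaults at the time of writing, made explicit here):</p>

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=200, centers=3, random_state=0)

km = KMeans(
    n_clusters=3,
    init="k-means++",  # spread-out initial centroids
    n_init=10,         # 10 initializations; the best (lowest SSE) run is kept
    tol=1e-4,          # convergence tolerance on centroid updates
    random_state=0,
).fit(X)

labels = km.labels_   # cluster index for each point
sse = km.inertia_     # sum of squared distances to centroids
```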
<div id="cell-21" class="cell" data-code_folding="[]" data-execution_count="6">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb6" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb6-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Read the image</span></span>
<span id="cb6-2">img <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> imread(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'images/my_image.jpg'</span>)</span>
<span id="cb6-3">img_size <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> img.shape</span>
<span id="cb6-4"></span>
<span id="cb6-5"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Reshape it to be 2-dimension</span></span>
<span id="cb6-6">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> img.reshape(img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>] <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">*</span> img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb6-7"></span>
<span id="cb6-8"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run the Kmeans algorithm</span></span>
<span id="cb6-9">km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">30</span>)</span>
<span id="cb6-10">km.fit(X)</span>
<span id="cb6-11"></span>
<span id="cb6-12"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Use the centroids to compress the image</span></span>
<span id="cb6-13">X_compressed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.cluster_centers_[km.labels_]</span>
<span id="cb6-14">X_compressed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.clip(X_compressed.astype(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'uint8'</span>), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">255</span>)</span>
<span id="cb6-15"></span>
<span id="cb6-16"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Reshape X_recovered to have the same dimension as the original image 128 * 128 * 3</span></span>
<span id="cb6-17">X_compressed <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> X_compressed.reshape(img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], img_size[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb6-18"></span>
<span id="cb6-19"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot the original and the compressed image next to each other</span></span>
<span id="cb6-20">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> (<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">12</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">8</span>))</span>
<span id="cb6-21">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].imshow(img)</span>
<span id="cb6-22">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Original Image'</span>)</span>
<span id="cb6-23">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].imshow(X_compressed)</span>
<span id="cb6-24">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Compressed Image with 30 colors'</span>)</span>
<span id="cb6-25"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> ax <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> fig.axes:</span>
<span id="cb6-26">    ax.axis(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'off'</span>)</span>
<span id="cb6-27">plt.tight_layout()<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-7-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-4"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-7-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>We can see the comparison between the original image and the compressed one. The compressed image looks close to the original, which means we were able to retain most of the characteristics of the original image. With a smaller number of clusters we would get a higher compression rate at the expense of image quality. As a side note, this image compression method is called <em>lossy data compression</em> because we can’t reconstruct the original image exactly from the compressed one.</p>
</section>
</section>
<section id="evaluation-methods" class="level2">
<h2 class="anchored" data-anchor-id="evaluation-methods">Evaluation Methods</h2>
<p>Contrary to supervised learning, where we have the ground truth to evaluate the model’s performance, clustering analysis doesn’t have a solid evaluation metric for comparing the outcomes of different clustering algorithms. Moreover, since kmeans requires <img src="https://latex.codecogs.com/png.latex?k"> as an input and doesn’t learn it from the data, there is no single right answer for the number of clusters in any given problem. Sometimes domain knowledge and intuition help, but usually that is not the case. In the cluster-then-predict methodology, we can evaluate how well the downstream models perform for different values of <img src="https://latex.codecogs.com/png.latex?K">, since the clusters feed into the downstream modeling.</p>
<p>In this post we’ll cover two metrics that may give us some intuition about <img src="https://latex.codecogs.com/png.latex?k">:</p>
<ul>
<li>Elbow method</li>
<li>Silhouette analysis</li>
</ul>
<section id="elbow-method" class="level3">
<h3 class="anchored" data-anchor-id="elbow-method">Elbow Method</h3>
<p>The <strong>elbow</strong> method gives us an idea of a good <img src="https://latex.codecogs.com/png.latex?k"> based on the sum of squared distances (SSE) between data points and their assigned clusters’ centroids. We pick <img src="https://latex.codecogs.com/png.latex?k"> at the spot where the SSE starts to flatten out, forming an elbow. We’ll use the geyser dataset, evaluate the SSE for different values of <img src="https://latex.codecogs.com/png.latex?k">, and see where the curve forms an elbow and flattens out.</p>
<div id="cell-27" class="cell" data-execution_count="7">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb7" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb7-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run the Kmeans algorithm and get the index of data points clusters</span></span>
<span id="cb7-2">sse <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb7-3">list_k <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">list</span>(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>))</span>
<span id="cb7-4"></span>
<span id="cb7-5"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> k <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> list_k:</span>
<span id="cb7-6">    km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>k)</span>
<span id="cb7-7">    km.fit(X_std)</span>
<span id="cb7-8">    sse.append(km.inertia_)</span>
<span id="cb7-9"></span>
<span id="cb7-10"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot sse against k</span></span>
<span id="cb7-11">plt.figure(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb7-12">plt.plot(list_k, sse, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'-o'</span>)</span>
<span id="cb7-13">plt.xlabel(<span class="vs" style="color: #20794D;
background-color: null;
font-style: inherit;">r'Number of clusters $k$'</span>)</span>
<span id="cb7-14">plt.ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Sum of squared distance'</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-8-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-5"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-8-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>The graph above shows that <img src="https://latex.codecogs.com/png.latex?k%20=%202"> is not a bad choice. Sometimes it’s still hard to pick a good number of clusters because the curve is monotonically decreasing and may not show an elbow or any obvious point where it starts flattening out.</p>
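<p>When no elbow is visually obvious, one crude heuristic is to pick the <img src="https://latex.codecogs.com/png.latex?k"> at the sharpest bend, i.e.&nbsp;the largest second difference of the SSE curve. Here is a minimal sketch on synthetic blobs (not the geyser data); both the heuristic and the <code>make_blobs</code> setup are illustrative assumptions, not part of the original analysis:</p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
list_k = list(range(1, 10))
sse = [KMeans(n_clusters=k, n_init=10, random_state=0).fit(X).inertia_
       for k in list_k]

# Second difference of SSE: the largest value marks the sharpest bend ("elbow")
curvature = np.diff(sse, 2)                 # length len(sse) - 2
elbow_k = list_k[int(np.argmax(curvature)) + 1]
print(elbow_k)
```

This is only a rule of thumb; it can be fooled by noisy SSE curves, so it should complement, not replace, looking at the plot.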
</section>
<section id="silhouette-analysis" class="level3">
<h3 class="anchored" data-anchor-id="silhouette-analysis">Silhouette Analysis</h3>
<p><strong>Silhouette analysis</strong> can be used to determine the degree of separation between clusters. For each sample:</p>
<ul>
<li>Compute the average distance from all data points in the same cluster (<img src="https://latex.codecogs.com/png.latex?a%5Ei">).</li>
<li>Compute the average distance from all data points in the nearest cluster the sample is not a part of (<img src="https://latex.codecogs.com/png.latex?b%5Ei">).</li>
<li>Compute the coefficient: <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7Bb%5Ei%20-%20a%5Ei%7D%7Bmax(a%5Ei,%20b%5Ei)%7D"> The coefficient can take values in the interval [-1, 1].
<ul>
<li>If it is 0 –&gt; the sample is very close to the neighboring clusters.</li>
<li>If it is 1 –&gt; the sample is far away from the neighboring clusters.</li>
<li>If it is -1 –&gt; the sample is assigned to the wrong cluster.</li>
</ul></li>
</ul>
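<p>The three steps above map directly onto NumPy. Here is a minimal sketch of the per-sample coefficient on a made-up toy dataset, checked against scikit-learn’s <code>silhouette_samples</code>:</p>

```python
import numpy as np
from sklearn.metrics import silhouette_samples

# Tiny illustrative dataset: two obvious pairs of points
X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
labels = np.array([0, 0, 1, 1])

def silhouette_coef(i, X, labels):
    d = np.linalg.norm(X - X[i], axis=1)           # distances from sample i
    same = labels == labels[i]
    a = d[same & (np.arange(len(X)) != i)].mean()  # mean intra-cluster distance
    b = min(d[labels == c].mean()                  # nearest other cluster
            for c in np.unique(labels) if c != labels[i])
    return (b - a) / max(a, b)

ours = np.array([silhouette_coef(i, X, labels) for i in range(len(X))])
assert np.allclose(ours, silhouette_samples(X, labels))
```

(For single-member clusters scikit-learn defines the coefficient as 0; this sketch skips that edge case.)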
<p>Therefore, we want the coefficients to be as big as possible, and close to 1, to have good clusters. We’ll use the geyser dataset again because it’s cheaper to run the silhouette analysis on, and it is actually obvious that there are most likely only two groups of data points.</p>
<div id="cell-31" class="cell" data-execution_count="8">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb8" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb8-1"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, k <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>]):</span>
<span id="cb8-2">    fig, (ax1, ax2) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb8-3">    fig.set_size_inches(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">18</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>)</span>
<span id="cb8-4">    </span>
<span id="cb8-5">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run the Kmeans algorithm</span></span>
<span id="cb8-6">    km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>k)</span>
<span id="cb8-7">    labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.fit_predict(X_std)</span>
<span id="cb8-8">    centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.cluster_centers_</span>
<span id="cb8-9"></span>
<span id="cb8-10">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Get silhouette samples</span></span>
<span id="cb8-11">    silhouette_vals <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> silhouette_samples(X_std, labels)</span>
<span id="cb8-12"></span>
<span id="cb8-13">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Silhouette plot</span></span>
<span id="cb8-14">    y_ticks <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> []</span>
<span id="cb8-15">    y_lower, y_upper <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span></span>
<span id="cb8-16">    <span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, cluster <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(np.unique(labels)):</span>
<span id="cb8-17">        cluster_silhouette_vals <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> silhouette_vals[labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">==</span> cluster]</span>
<span id="cb8-18">        cluster_silhouette_vals.sort()</span>
<span id="cb8-19">        y_upper <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(cluster_silhouette_vals)</span>
<span id="cb8-20">        ax1.barh(<span class="bu" style="color: null;
background-color: null;
font-style: inherit;">range</span>(y_lower, y_upper), cluster_silhouette_vals, edgecolor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'none'</span>, height<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>)</span>
<span id="cb8-21">        ax1.text(<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.03</span>, (y_lower <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> y_upper) <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">/</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">str</span>(i <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+</span> <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>))</span>
<span id="cb8-22">        y_lower <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">+=</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">len</span>(cluster_silhouette_vals)</span>
<span id="cb8-23"></span>
<span id="cb8-24">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Get the average silhouette score and plot it</span></span>
<span id="cb8-25">    avg_score <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.mean(silhouette_vals)</span>
<span id="cb8-26">    ax1.axvline(avg_score, linestyle<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'--'</span>, linewidth<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'green'</span>)</span>
<span id="cb8-27">    ax1.set_yticks([])</span>
<span id="cb8-28">    ax1.set_xlim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb8-29">    ax1.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Silhouette coefficient values'</span>)</span>
<span id="cb8-30">    ax1.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Cluster labels'</span>)</span>
<span id="cb8-31">    ax1.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Silhouette plot for the various clusters'</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.02</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span>
<span id="cb8-32">    </span>
<span id="cb8-33">    <span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Scatter plot of data colored with labels</span></span>
<span id="cb8-34">    ax2.scatter(X_std[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_std[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels)</span>
<span id="cb8-35">    ax2.scatter(centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'*'</span>, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">250</span>)</span>
<span id="cb8-36">    ax2.set_xlim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb8-37">    ax2.set_ylim([<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">-</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>])</span>
<span id="cb8-38">    ax2.set_xlabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Eruption time in mins'</span>)</span>
<span id="cb8-39">    ax2.set_ylabel(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Waiting time to next eruption'</span>)</span>
<span id="cb8-40">    ax2.set_title(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Visualization of clustered data'</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.02</span>)</span>
<span id="cb8-41">    ax2.set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span>
<span id="cb8-42">    plt.tight_layout()</span>
<span id="cb8-43">    plt.suptitle(<span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">f'Silhouette analysis using k = </span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">{</span>k<span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">}</span><span class="ss" style="color: #20794D;
background-color: null;
font-style: inherit;">'</span>,</span>
<span id="cb8-44">                 fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">16</span>, fontweight<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'semibold'</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.05</span>)<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">;</span></span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-9-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-6"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-9-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-9-output-2.png" class="lightbox" data-gallery="quarto-lightbox-gallery-7"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-9-output-2.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-9-output-3.png" class="lightbox" data-gallery="quarto-lightbox-gallery-8"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-9-output-3.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>As the above plots show, <code>n_clusters=2</code> has the best average silhouette score, around 0.75, and the fact that all clusters sit above the average shows that it is actually a good choice. Also, the thickness of each silhouette gives an indication of how big the corresponding cluster is. The plot shows that cluster 1 has almost double the samples of cluster 2. However, as we increase <code>n_clusters</code> to 3 and 4, the average silhouette score decreases dramatically, to around 0.48 and 0.39 respectively, and the silhouette thicknesses start to show wide fluctuations. The bottom line: a good <code>n_clusters</code> has an average silhouette score well above 0.5, with every cluster scoring higher than the average.</p>
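<p>This rule of thumb can be automated by picking the <code>n_clusters</code> with the highest mean silhouette score. A minimal sketch on synthetic, well-separated blobs (an illustrative setup, not the geyser data):</p>

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Two well-separated blobs, so the right answer is known in advance
X, _ = make_blobs(n_samples=300, centers=[[0, 0], [6, 6]],
                  cluster_std=0.8, random_state=0)

# Pick the k whose clustering has the highest mean silhouette score
scores = {}
for k in range(2, 6):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    scores[k] = silhouette_score(X, labels)

best_k = max(scores, key=scores.get)
print(best_k)   # -> 2 for these well-separated blobs
```

Note that a high average score alone is not enough; as the plots above show, you also want every cluster's silhouette to clear the average.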
</section>
</section>
<section id="drawbacks" class="level2">
<h2 class="anchored" data-anchor-id="drawbacks">Drawbacks</h2>
<p>The kmeans algorithm is good at capturing the structure of the data if the clusters have a spherical-like shape; it always tries to construct a nice spherical shape around each centroid. This means that the minute the clusters have complicated geometric shapes, kmeans does a poor job of clustering the data. We’ll illustrate three cases where kmeans does not perform well.</p>
<p>First, the kmeans algorithm doesn’t let data points that are far away from each other share the same cluster, even when they obviously belong together. Below is an example of data points lying on two horizontal lines, which illustrates how kmeans groups half of the data points of each line together.</p>
<div id="cell-35" class="cell" data-code_folding="[]" data-execution_count="9">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb9" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb9-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Create horizontal data</span></span>
<span id="cb9-2">X <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.tile(np.linspace(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb9-3">y <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.repeat(np.array([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>]), <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>)</span>
<span id="cb9-4">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.c_[X, y]</span>
<span id="cb9-5"></span>
<span id="cb9-6">km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb9-7">km.fit(df)</span>
<span id="cb9-8">labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.predict(df)</span>
<span id="cb9-9">centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.cluster_centers_</span>
<span id="cb9-10"></span>
<span id="cb9-11">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>))</span>
<span id="cb9-12">plt.scatter(X, y, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels)</span>
<span id="cb9-13">plt.xlim([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>])</span>
<span id="cb9-14">plt.ylim([<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>])</span>
<span id="cb9-15">plt.text(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'A'</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>)</span>
<span id="cb9-16">plt.text(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">5.1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'B'</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>)</span>
<span id="cb9-17">plt.text(<span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">2.8</span>, <span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">4.1</span>, <span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'C'</span>, color<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'red'</span>)</span>
<span id="cb9-18">ax.set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-10-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-9"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-10-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>Kmeans considers point ‘B’ closer to point ‘A’ than to point ‘C’ because the clusters have a non-spherical shape. Therefore, points ‘A’ and ‘B’ end up in the same cluster while point ‘C’ lands in a different one. (Note that the <strong>single linkage</strong> hierarchical clustering method gets this right because it doesn’t separate similar points.)</p>
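<p>To see that behavior concretely, here is a sketch that reruns the two-lines example with scikit-learn’s <code>AgglomerativeClustering</code>; single linkage is chosen here for illustration, since it merges by nearest neighbors and therefore follows each line:</p>

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering

# Two horizontal lines of points, as in the example above
X = np.tile(np.linspace(1, 5, 20), 2)
y = np.repeat([2.0, 4.0], 20)
df = np.c_[X, y]

# Single linkage chains nearest neighbors, so each line becomes one cluster
labels = AgglomerativeClustering(n_clusters=2, linkage='single').fit_predict(df)

# Each horizontal line ends up entirely in its own cluster
print(len(set(labels[:20])), len(set(labels[20:])))   # -> 1 1
```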
<p>Second, we’ll generate data from multivariate normal distributions with different means and covariances, so we have 3 groups of data where each group was generated from a different distribution. One group will have far more data points than the other two combined. Next, we’ll run kmeans on the data with <img src="https://latex.codecogs.com/png.latex?K%20=%203"> and see whether it clusters the data correctly. To make the comparison easier, I’ll first plot the data colored by the distribution it came from, then plot the same data colored by the cluster it was assigned to.</p>
<div id="cell-38" class="cell" data-code_folding="[]" data-execution_count="10">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb10" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb10-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Create data from three different multivariate distributions</span></span>
<span id="cb10-2">X_1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.multivariate_normal(mean<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">4</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], cov<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>]], size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">75</span>)</span>
<span id="cb10-3">X_2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.multivariate_normal(mean<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">6</span>], cov<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]], size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">250</span>)</span>
<span id="cb10-4">X_3 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.random.multivariate_normal(mean<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">5</span>], cov<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>[[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], [<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>]], size<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">20</span>)</span>
<span id="cb10-5">df <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> np.concatenate([X_1, X_2, X_3])</span>
<span id="cb10-6"></span>
<span id="cb10-7"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Run kmeans</span></span>
<span id="cb10-8">km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">3</span>)</span>
<span id="cb10-9">km.fit(df)</span>
<span id="cb10-10">labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.predict(df)</span>
<span id="cb10-11">centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.cluster_centers_</span>
<span id="cb10-12"></span>
<span id="cb10-13"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Plot the data</span></span>
<span id="cb10-14">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, figsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">10</span>))</span>
<span id="cb10-15">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].scatter(X_1[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_1[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb10-16">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].scatter(X_2[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_2[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb10-17">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].scatter(X_3[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X_3[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>])</span>
<span id="cb10-18">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>].set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span>
<span id="cb10-19">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].scatter(df[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], df[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels)</span>
<span id="cb10-20">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].scatter(centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'o'</span>,</span>
<span id="cb10-21">                c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">"white"</span>, alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">200</span>, edgecolor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'k'</span>)</span>
<span id="cb10-22"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, c <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>(centroids):</span>
<span id="cb10-23">    ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].scatter(c[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], c[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'$</span><span class="sc" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%d</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">$'</span> <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">%</span> i, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">50</span>, alpha<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, edgecolor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>)</span>
<span id="cb10-24">ax[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>].set_aspect(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'equal'</span>)</span>
<span id="cb10-25">plt.tight_layout()</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-11-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-10"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-11-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>It looks like kmeans couldn’t recover the clusters correctly. Because it minimizes the total within-cluster variation, it effectively gives more weight to bigger clusters than to smaller ones. In other words, data points in a small cluster may end up far from their centroid so that the algorithm can better fit the larger cluster.</p>
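<p>To make this size effect concrete, below is a minimal sketch (not part of the original experiments; the cluster sizes and variable names are illustrative assumptions) that fits kmeans to one large and one small Gaussian cluster and prints each cluster’s contribution to the total within-cluster variation (inertia):</p>

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# One large cluster (1000 points) and one small cluster (50 points)
big = rng.normal(loc=[0, 0], scale=1.0, size=(1000, 2))
small = rng.normal(loc=[4, 0], scale=1.0, size=(50, 2))
X = np.vstack([big, small])

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
labels = km.labels_

# Per-cluster contribution to the objective kmeans minimizes
for k in range(km.n_clusters):
    pts = X[labels == k]
    contrib = ((pts - km.cluster_centers_[k]) ** 2).sum()
    print(f"cluster {k}: {len(pts)} points, inertia contribution {contrib:.1f}")
```

<p>The larger cluster dominates the objective, which is why kmeans is willing to misplace points from the smaller cluster.</p>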
<p>Last, we’ll generate data with complicated geometric shapes, such as moons and concentric circles, and test kmeans on both datasets.</p>
<div id="cell-40" class="cell" data-code_folding="[]" data-execution_count="11">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb11" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb11-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Circles</span></span>
<span id="cb11-2">X1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_circles(factor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, noise<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>, n_samples<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1500</span>)</span>
<span id="cb11-3"></span>
<span id="cb11-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Moons</span></span>
<span id="cb11-5">X2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_moons(n_samples<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1500</span>, noise<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>)</span>
<span id="cb11-6"></span>
<span id="cb11-7">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb11-8"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, X <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>([X1, X2]):</span>
<span id="cb11-9">    fig.set_size_inches(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">18</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>)</span>
<span id="cb11-10">    km <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> KMeans(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb11-11">    km.fit(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb11-12">    labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.predict(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb11-13">    centroids <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> km.cluster_centers_</span>
<span id="cb11-14"></span>
<span id="cb11-15">    ax[i].scatter(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels)</span>
<span id="cb11-16">    ax[i].scatter(centroids[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'*'</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">400</span>, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'r'</span>)</span>
<span id="cb11-17">    ax[i].scatter(centroids[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], centroids[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], marker<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'+'</span>, s<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">300</span>, c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'green'</span>)</span>
<span id="cb11-18">plt.suptitle(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Simulated data'</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.05</span>, fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">22</span>, fontweight<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'semibold'</span>)</span>
<span id="cb11-19">plt.tight_layout()</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-12-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-11"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-12-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
<p>As expected, kmeans couldn’t recover the correct clusters for either dataset. However, we can handle these kinds of datasets using kernel methods: transform the data into a higher-dimensional representation that makes it linearly separable (the same idea used in SVMs). Algorithms such as <code>SpectralClustering</code> work very well in these scenarios, as shown below:</p>
<div id="cell-42" class="cell" data-execution_count="12">
<details class="code-fold">
<summary>Code</summary>
<div class="sourceCode cell-code" id="cb12" style="background: #f1f3f5;"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb12-1"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Circles</span></span>
<span id="cb12-2">X1 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_circles(factor<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.5</span>, noise<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>, n_samples<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1500</span>)</span>
<span id="cb12-3"></span>
<span id="cb12-4"><span class="co" style="color: #5E5E5E;
background-color: null;
font-style: inherit;"># Moons</span></span>
<span id="cb12-5">X2 <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> make_moons(n_samples<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1500</span>, noise<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">0.05</span>)</span>
<span id="cb12-6"></span>
<span id="cb12-7">fig, ax <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> plt.subplots(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>)</span>
<span id="cb12-8"><span class="cf" style="color: #003B4F;
background-color: null;
font-style: inherit;">for</span> i, X <span class="kw" style="color: #003B4F;
background-color: null;
font-style: inherit;">in</span> <span class="bu" style="color: null;
background-color: null;
font-style: inherit;">enumerate</span>([X1, X2]):</span>
<span id="cb12-9">    fig.set_size_inches(<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">18</span>, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">7</span>)</span>
<span id="cb12-10">    sp <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> SpectralClustering(n_clusters<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">2</span>, affinity<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'nearest_neighbors'</span>)</span>
<span id="cb12-11">    sp.fit(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>])</span>
<span id="cb12-12">    labels <span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span> sp.labels_</span>
<span id="cb12-13">    ax[i].scatter(X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>], X[<span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">0</span>][:, <span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">1</span>], c<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span>labels)</span>
<span id="cb12-14">plt.suptitle(<span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'Simulated data'</span>, y<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="fl" style="color: #AD0000;
background-color: null;
font-style: inherit;">1.05</span>, fontsize<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="dv" style="color: #AD0000;
background-color: null;
font-style: inherit;">22</span>, fontweight<span class="op" style="color: #5E5E5E;
background-color: null;
font-style: inherit;">=</span><span class="st" style="color: #20794D;
background-color: null;
font-style: inherit;">'semibold'</span>)</span>
<span id="cb12-15">plt.tight_layout()</span></code></pre></div>
</details>
<div class="cell-output cell-output-display">
<div>
<figure class="figure">
<p><a href="Kmeans-Clustering_files/figure-html/cell-13-output-1.png" class="lightbox" data-gallery="quarto-lightbox-gallery-12"><img src="https://imaddabbura.github.io/posts/ml/Kmeans-Clustering_files/figure-html/cell-13-output-1.png" class="img-fluid figure-img"></a></p>
</figure>
</div>
</div>
</div>
</section>
<section id="conclusion" class="level2">
<h2 class="anchored" data-anchor-id="conclusion">Conclusion</h2>
<p>Kmeans clustering is one of the most popular clustering algorithms and is usually the first thing practitioners apply to a clustering task to get an idea of the structure of the dataset. The goal of kmeans is to group data points into distinct, non-overlapping subgroups. It does a very good job when the clusters are roughly spherical, but it suffers as cluster geometry deviates from that shape. Moreover, it doesn’t learn the number of clusters from the data and requires it to be pre-defined. To be a good practitioner, it helps to know the assumptions behind each algorithm so that you understand its strengths and weaknesses and can decide when to use it and under what circumstances. In this post, we covered kmeans’ strengths and weaknesses along with some related evaluation methods.</p>
<p>Below are the main takeaways:</p>
<ul>
<li>Scale/standardize the data before applying the kmeans algorithm.</li>
<li>The elbow method for selecting the number of clusters doesn’t always work, because the error function is monotonically decreasing in <img src="https://latex.codecogs.com/png.latex?k">.</li>
<li>Kmeans gives more weight to the bigger clusters.</li>
<li>Kmeans assumes spherical clusters (with radius equal to the distance between the centroid and the furthest data point) and doesn’t work well when clusters have other shapes, such as elliptical ones.</li>
<li>When clusters overlap, kmeans has no intrinsic measure of uncertainty for the examples in the overlapping region, so it cannot express how confident it is about which cluster each such point belongs to.</li>
<li>Kmeans may still cluster the data even when there is no real cluster structure, such as data drawn from a <em>uniform distribution</em>.</li>
</ul>
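<p>As a quick illustration of the first takeaway, here is a minimal sketch (an illustrative example, not from the post; the feature scales are assumptions) that standardizes features of very different magnitudes before running kmeans, using a scikit-learn pipeline so the scaling is always applied consistently:</p>

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(1)
# Two features on very different scales: without scaling, Euclidean
# distance would be dominated almost entirely by the second feature
X = np.column_stack([rng.normal(0, 1, 300), rng.normal(0, 1000, 300)])

# Standardizing first puts both features on equal footing
pipe = make_pipeline(StandardScaler(), KMeans(n_clusters=3, n_init=10, random_state=0))
labels = pipe.fit_predict(X)
```

<p>Bundling the scaler and the model in one pipeline also prevents a common leak: fitting the scaler on data the model never sees at prediction time.</p>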


</section>

<a onclick="window.scrollTo(0, 0); return false;" id="quarto-back-to-top"><i class="bi bi-arrow-up"></i> Back to top</a> ]]></description>
  <category>Machine Learning</category>
  <guid>https://imaddabbura.github.io/posts/ml/Kmeans-Clustering.html</guid>
  <pubDate>Tue, 11 Sep 2018 05:00:00 GMT</pubDate>
  <media:content url="https://imaddabbura.github.io/posts/ml/images/kmeans-clustering.png" medium="image" type="image/png" height="80" width="144"/>
</item>
</channel>
</rss>
