Byte-level Byte-Pair Encoding (BPE)

NLP
Deep Learning
Author

Imad Dabbura

Published

April 10, 2024

Introduction

Byte-level Byte-Pair Encoding (BPE) is a subword tokenization strategy that starts from a base vocabulary of all 256 byte values and repeatedly merges the most frequent pair of adjacent tokens until a desired vocabulary size is reached. In each iteration:

  • Count all bigrams (pairs of adjacent tokens) in the dataset
  • Pick the bigram with the highest frequency and add it to the vocabulary as a new token
  • Merge every occurrence of that bigram into the new token
  • Repeat until we reach the predefined vocabulary size (a toy example of one iteration follows)
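
To make this concrete, here is a minimal self-contained sketch of a single iteration on a toy string of my own ("aaabab" is a hypothetical example, not from the walk-through below):

from collections import Counter

ids = list("aaabab".encode("utf-8"))  # [97, 97, 97, 98, 97, 98]
pairs = Counter(zip(ids, ids[1:]))    # counts of adjacent pairs
top = max(pairs, key=pairs.get)       # (97, 97), i.e. "aa", with count 2
# Merging (97, 97) into a new token 256 yields [256, 97, 98, 97, 98]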

Implementation

Code
from typing import Iterable
import requests

Detailed Walk-through

text = "A Programmer’s Introduction to Unicode March 3, 2017"
tokens = text.encode("utf-8")  # raw bytes
tokens = list(tokens)  # convert to a list of integers in range 0..255 for convenience
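Working at the byte level means a single Unicode character can span several base tokens. The curly apostrophe in the text above, for instance, is three bytes long in UTF-8 (a quick check of my own):

list("’".encode("utf-8"))
[226, 128, 153]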
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
list(get_stats(tokens).items())[:10]  # first 10 pairs with their counts
[((65, 32), 1),
 ((32, 80), 1),
 ((80, 114), 1),
 ((114, 111), 2),
 ((111, 103), 1),
 ((103, 114), 1),
 ((114, 97), 1),
 ((97, 109), 1),
 ((109, 109), 1),
 ((109, 101), 1)]
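Picking the most frequent pair is just max over these counts. For this sentence the winner is (114, 111), the bytes of "ro", which occurs in both "Programmer" and "Introduction":

stats = get_stats(tokens)
max(stats, key=stats.get)
(114, 111)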
def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and tuple(ids[i : i + 2]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
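For example, replacing every occurrence of the pair (6, 7) with a new token id 99 (a made-up input to show the mechanics):

merge([5, 6, 6, 7, 9, 1], (6, 7), 99)
[5, 6, 99, 9, 1]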
vocab_sz = 276  # i.e. we want 20 merges on top of the 256 base bytes
n_merges = vocab_sz - 256
ids = list(tokens)
merges = {}
for i in range(n_merges):
    stats = get_stats(ids)
    top_pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = merge(ids, top_pair, idx)
    merges[top_pair] = idx
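Each merge shortens the sequence wherever its pair occurs, so a quick sanity check of my own (output omitted) is to compare the sequence length before and after training:

print(f"bytes: {len(tokens)}, tokens: {len(ids)}, compression: {len(tokens) / len(ids):.2f}x")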
def encode(text, merges):
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        tokens = merge(tokens, pair, merges[pair])
    return tokens
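The min with a default of infinity is the subtle part: among the pairs currently present in the sequence, it selects the one that was learned earliest (lowest merge index), and pairs that were never learned compare as infinity. A small illustration with made-up merges:

toy_merges = {(101, 32): 256, (256, 116): 257}  # hypothetical merges, in learned order
toy_stats = {(104, 101): 1, (256, 116): 1, (101, 32): 1}
min(toy_stats, key=lambda p: toy_merges.get(p, float("inf")))
(101, 32)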
def decode(ids, vocab):
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]
text == decode(encode(text, merges), vocab)
True
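
The errors="replace" argument matters because an arbitrary token sequence can decode to bytes that are not valid UTF-8. A lone continuation byte, for example, becomes the Unicode replacement character instead of raising UnicodeDecodeError:

decode([128], vocab)
'�'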

Clean Implementation

class BPETokenizer:
    """Byte-pair encoder."""

    def __init__(self, vocab_sz: int):
        """
        Args:
            vocab_sz (int): Vocabulary size.
        """
        self.vocab_sz = vocab_sz
        self.vocab = {}
        self.merges = {}

    def train(self, text: str):
        """Train Byte-pair encoder."""
        ids = list(text.encode("utf-8"))
        for idx in range(256, self.vocab_sz):
            stats = self._get_stats(ids)
            pair = max(stats, key=stats.get)
            self.merges[pair] = idx
            ids = self._merge(ids, pair, idx)
        self.vocab = self._build_vocab()

    def encode(self, text):
        """Encode string to bytes using vocabulary built during training."""
        ids = list(text.encode("utf-8"))

        # If text is empty or a single byte, there is nothing to merge
        while len(ids) >= 2:
            # stats is only used here to enumerate the pairs of adjacent tokens
            stats = self._get_stats(ids)
            # Merges must be applied in the order they were learned (lowest
            # index first), because later merges can depend on tokens produced
            # by earlier ones. Pairs that were never learned map to infinity,
            # so they are only selected when no learned pair remains.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # No more pairs to merge
            idx = self.merges[pair]
            ids = self._merge(ids, pair, idx)
        return ids

    def decode(self, tokens: Iterable[int]):
        """Decode tokens into string using the vocabulary built during training."""
        tokens = b"".join(self.vocab[idx] for idx in tokens)
        # errors="replace" substitutes invalid UTF-8 sequences with the Unicode
        # replacement character; without it, decoding such bytes would raise
        # UnicodeDecodeError
        return tokens.decode("utf-8", errors="replace")

    def _get_stats(self, ids: Iterable[int]):
        """Get pair counts."""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def _merge(self, ids: Iterable[int], pair: tuple[int, int], idx: int):
        """Merge pairs that match `pair` with new index `idx`."""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and tuple(pair) == tuple(ids[i : i + 2]):
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

    def _build_vocab(self):
        """Build vocabulary from the 256 base bytes plus learned merges."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Dict items are returned in insertion order (guaranteed since Python
        # 3.7), so merges are replayed in the order they were learned
        for (p0, p1), idx in self.merges.items():
            # Each merged token's bytes are the concatenation of its parents'
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab
text = requests.get("https://docs.python.org/3/library/stdtypes.html#bytes.decode").text
tokenizer = BPETokenizer(300)
tokenizer.train(text)
tokenizer.decode(tokenizer.encode(text)) == text
True
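
Because the base vocabulary covers all 256 byte values, the trained tokenizer can encode any string, even text it never saw during training. A round-trip check of my own on a made-up string:

new_text = "bytes it has never seen: ÿ"
tokenizer.decode(tokenizer.encode(new_text)) == new_text
True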

Resources