!wget "https://www.perseus.tufts.edu/hopper/dltext?doc=Perseus%3Atext%3A1999.02.0008" -O atticus.xml

Byte-Pair Encoding Tokenization

In the past two weeks, we have looked at the transformer architecture in depth. Both variants, the decoder-only GPT and the encoder-only BERT, are useful for different tasks, but we saw that tokenization played a key role regardless of which model we were training. Now that you have a sense of how these models are trained, we'll take a closer look at tokenization and why it matters so much for training models, especially for languages other than English.

Learning objectives:

  • Grasp the concepts behind the Byte-Pair Encoding algorithm and how we can use it for tokenization

  • Understand why traditional tokenization schemes are insufficient for modern models

  • See how to use HuggingFace Tokenizers to train and save your own tokenizer

  • Apply this framework to a language of your choice and improve your own tools with custom tokenizers

The Byte-Pair Encoding Algorithm

We haven't used the word "algorithm" very much in this class, so as a reminder: an algorithm is just a list of steps that a computer (or a human) can follow to accomplish a task. Here, the task is to take some text and figure out the best way to split it up. Below are the steps we will follow; we will see each one in detail in the code.

  1. Start with a base vocabulary of characters or bytes (we’ll start with bytes and then move to characters/traditional words)

  2. Count the frequency of all adjacent pairs

  3. Find the most frequent pair

  4. Add this pair to a vocabulary list as a new token

  5. Replace all occurrences of the pair with the new token

  6. Repeat steps 2 to 5 until the vocabulary reaches the desired size.

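Left alone, the algorithm would keep merging until the whole text became a single token, which is useless for modeling. Since each merge adds exactly one token to the vocabulary, we instead stop once the vocabulary reaches a chosen size, which makes vocabulary size a new hyperparameter.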
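Before running these steps on the letters, it can help to trace them on a short string. The following is a minimal sketch that works on characters rather than bytes (a simplification so the merged tokens stay readable), uses `collections.Counter` for the pair counts, and stops after a fixed number of merges, which plays the role of `vocab_size - 256` in the byte version. `toy_bpe` is a hypothetical helper, not part of the lesson code.

```python
from collections import Counter


def toy_bpe(text, num_merges):
    """Character-level BPE on a short string; returns the final tokens and the merges."""
    tokens = list(text)  # step 1: base vocabulary = individual characters
    merges = []
    for _ in range(num_merges):  # step 6: stop after a fixed number of merges
        pair_counts = Counter(zip(tokens, tokens[1:]))  # step 2: count adjacent pairs
        if not pair_counts:
            break  # only one token left, nothing to merge
        (a, b), _count = pair_counts.most_common(1)[0]  # step 3: most frequent pair (ties: first seen)
        merges.append((a, b))  # step 4: the pair becomes a new vocabulary entry
        merged, i = [], 0
        while i < len(tokens):  # step 5: replace every occurrence, scanning left to right
            if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens, merges


toy_tokens, toy_merges = toy_bpe("aaabdaaabac", num_merges=3)
print(toy_merges)  # [('a', 'a'), ('aa', 'a'), ('aaa', 'b')]
print(toy_tokens)  # ['aaab', 'd', 'aaab', 'a', 'c']
```

The string shrinks from 11 tokens to 5, while the vocabulary grows from the 4 base characters to 7 tokens.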

# extracting text from xml
from bs4 import BeautifulSoup
import re

with open("atticus.xml", "r", encoding="utf-8") as f:
    soup = BeautifulSoup(f.read(), features="xml")

letters = []
for d in soup.find_all("div2"):
    dateline = d.dateline.extract().get_text().strip()
    salute = d.salute.extract().get_text().strip()
    text = re.sub(r"\s+", " ", d.get_text()).strip()  # collapse all whitespace, newlines included, into single spaces
    letters.append(dateline + "\n" + salute + "\n" + text)

text = "\n\n".join(letters)
print(len(text))
print(text[:1000])
text_sample = text[:25]
text_sample

To show the algorithm in its most basic form, we'll start with bytes as our base vocabulary.

Bytes are how the computer stores text internally: an encoding maps each character to one or more bytes. There are several encodings, but the most common is UTF-8. UTF-8 can represent every Unicode character, from every alphabet, using one to four bytes per character, which is why it has become the standard across computing.
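The gap between characters and bytes is easy to see on a few strings. In this sketch (the sample strings are illustrative choices), an ASCII letter takes one byte, an accented Latin letter two, and polytonic Greek letters two or three, so the Greek word μῆνιν starts out as 11 byte tokens for only 5 characters.

```python
# Characters vs. UTF-8 bytes: ASCII needs one byte, other scripts need more.
samples = {
    "a": "a",
    "e-acute": "\u00e9",  # é
    "menin": "\u03bc\u1fc6\u03bd\u03b9\u03bd",  # μῆνιν, written with a precomposed ῆ
}
for name, s in samples.items():
    raw = s.encode("utf-8")
    print(f"{name}: {len(s)} char(s) -> {len(raw)} byte(s) {list(raw)}")
```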

[ord(x) for x in text_sample]  # ord returns the Unicode code point of each character
list(text_sample.encode("utf-8"))  # the same string encoded as UTF-8, as raw byte values (0-255)
# step 1 in our algorithm: creating a base set of bytes for our vocabulary

tokens = text.encode("utf-8")  # raw bytes
tokens = list(
    map(int, tokens)
)  # convert to a list of integers in range 0..255 for convenience
print("---")
print(text[:100], "...")
print("full length:", len(text))
print("---")
print(tokens)
print("full length:", len(tokens))
# step 2: count the frequency of adjacent pairs of tokens
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts


stats = get_stats(tokens)
stats
# step 3: find the most common pair
top_pair = max(stats, key=stats.get)
top_pair
# steps 4 and 5: give the most common pair a new token id, then replace every occurrence of the pair with it
def merge(ids, pair, idx):
    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        # if we are not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))
tokens2 = merge(tokens, top_pair, 256)
print(tokens2[:50])
print("length:", len(tokens2))


# step 6: repeat all steps until we get a desired vocab length
vocab_size = 276  # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens)  # copy so we don't destroy the original list

merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")  # BPE was originally a data-compression algorithm
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]


def decode(ids):
    # given ids (list of integers), return Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")  # invalid UTF-8 (e.g. a lone byte like 128) becomes U+FFFD instead of raising
    return text


print(decode([128]))
merges
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens


print(encode("qualis"))
print(encode("artifex"))
print(encode(""))
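The `min` over merge ids in `encode` matters because later merges are built on earlier ones: in the example below, token 257 can only appear after 256 has been formed. This is a self-contained sketch with a tiny hand-written merge table (`toy_merges` is hypothetical, not learned from the letters). The helpers repeat the notebook's `get_stats` and `merge` so the cell runs on its own, and `encode_with_merges` / `decode_with_merges` are illustrative names that take the merge table as an argument.

```python
def get_stats(ids):
    """Count adjacent pairs (same behavior as the notebook's get_stats)."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    """Replace each non-overlapping occurrence of pair with idx (same as the notebook's merge)."""
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def encode_with_merges(text, merges):
    """Apply merges in training order: among the pairs present, the lowest new id goes first."""
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no remaining pair was ever merged during training
        ids = merge(ids, pair, merges[pair])
    return ids


def decode_with_merges(ids, merges):
    """Expand every id back into bytes, then decode them as UTF-8."""
    vocab = {i: bytes([i]) for i in range(256)}
    for (p0, p1), idx in merges.items():  # dicts keep insertion (= training) order
        vocab[idx] = vocab[p0] + vocab[p1]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")


# Hypothetical merge table: 256 = b"hi", 257 = b"hi" + b"!"
toy_merges = {(104, 105): 256, (256, 33): 257}
toy_ids = encode_with_merges("hi hi!", toy_merges)
print(toy_ids)  # [256, 32, 257]
print(decode_with_merges(toy_ids, toy_merges))  # hi hi!
```

Note that the second "hi" only becomes 257 because 256 was applied first; decoding then expands each merged id back into its bytes.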

Pre-tokenization

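We saw how the BPE algorithm works with raw bytes, but it works just as well with characters, character strings, and full words. Tokenizers very commonly run a simpler form of tokenization first, confusingly called pre-tokenization, which splits the text into chunks (for example, words and punctuation) before BPE begins. In GPT-2's tokenizer, for example, BPE merges are learned only inside these chunks; in the Latin example below we instead treat each pre-tokenized word as a base token and let BPE merge frequent word pairs.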
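A sketch of counting pairs only inside pre-tokenized chunks, as GPT-2-style tokenizers do. Here `str.split()` stands in for a real pre-tokenization regex (an assumption made for brevity), and `pair_counts_within_words` is a hypothetical helper. Because pairs are counted per distinct word and weighted by that word's frequency, no pair ever spans a space, and each repeated word only has to be scanned once.

```python
from collections import Counter


def pair_counts_within_words(text):
    """Count adjacent character pairs inside each whitespace-separated word only."""
    word_freqs = Counter(text.split())  # a deliberately simple pre-tokenizer
    counts = Counter()
    for word, freq in word_freqs.items():
        for pair in zip(word, word[1:]):
            counts[pair] += freq  # every occurrence of the word contributes its pairs
    return counts


sample_text = "arma virumque cano arma"
toy_counts = pair_counts_within_words(sample_text)
print(toy_counts.most_common(3))  # [(('a', 'r'), 2), (('r', 'm'), 2), (('m', 'a'), 2)]
print(("a", " ") in toy_counts)  # False: no pair crosses a word boundary
```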

import regex as re

gpt2pat = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

print(re.findall(gpt2pat, "Hello've world123 how's are you!!!?"))
def split_latin_tokens(text, enclitics=None):
    if not enclitics:
        enclitics = ["que", "ve", "ne", "met", "ce", "ci"]
    enclitics_or = "|".join(enclitics)

    text = re.sub(rf"(\w)({enclitics_or})\b", r"\1 \2", text)  # spaces before enclitics

    # pattern = r'\s*\p{L}+|\s*\p{N}+|\s*[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'
    pattern = r"\s*\w+|\s*\d+|\s*[^\s\w\d]+|\s+(?!\S)|\s+"  # \w-based pattern: \p{L} needs the third-party regex module and fails with the standard-library re
    tokens = re.findall(pattern, text)
    tokens = [t.strip() if t.strip() in enclitics else t for t in tokens]
    return tokens


# Example usage (a separate variable, so `text` still holds the letters)
example_latin = "arma virumque cano"
result = split_latin_tokens(example_latin)
print(result)
preproc = split_latin_tokens(text)
len(preproc)
def create_vocab(tokens):
    # Start with a simple mapping of unique tokens to IDs
    vocab = {}
    for token in tokens:
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab
vocab = create_vocab(preproc)
len(vocab)
ids = [vocab[token] for token in preproc]
stats = get_stats(ids)
sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]
# takes about 10 minutes
vocab_size = len(vocab) + 1000  # the desired final vocabulary size
num_merges = vocab_size - len(vocab)
print(f"num_merges: {num_merges}")
ids = [vocab[token] for token in preproc]  # word-level ids (not the byte-level `tokens` from earlier)

merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = len(vocab) + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
# to encode
new_latin = " arma virumque"
tokens = split_latin_tokens(new_latin)
print("Preprocessed tokens: ", tokens)
ids = [vocab[token] for token in tokens]
print("Token IDs: ", ids)
while len(ids) >= 2:
    stats = get_stats(ids)
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
        break  # nothing else can be merged
    idx = merges[pair]
    ids = merge(ids, pair, idx)
print("Encoded IDs: ", ids)
# to decode
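itos = {idx: token for token, idx in vocab.items()}
for (p0, p1), idx in merges.items():  # merged ids are not in vocab; rebuild them from their parts
    itos[idx] = itos[p0] + itos[p1]
tokens = [itos[idx] for idx in ids]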
print("Decoded tokens: ", tokens)
text = "".join(tokens)
print("Decoded text: ", text)
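The word-level vocabulary above is closed: `vocab[token]` raises a `KeyError` for any pre-token that never appeared in the letters. One common remedy, sketched here with hypothetical helpers `create_vocab_with_unk` and `tokens_to_ids`, is to reserve an `[UNK]` id before the merges are numbered, so it cannot collide with the merge ids that start at `len(vocab)`.

```python
def create_vocab_with_unk(tokens, unk_token="[UNK]"):
    """Like create_vocab, but reserves id 0 for unknown tokens before any merges exist."""
    vocab = {unk_token: 0}
    for token in tokens:
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def tokens_to_ids(tokens, vocab, unk_token="[UNK]"):
    """Map pre-tokens to ids, sending anything unseen to the [UNK] id."""
    unk_id = vocab[unk_token]
    return [vocab.get(t, unk_id) for t in tokens]


toy_vocab = create_vocab_with_unk(["arma", " virum", "que", " cano"])
print(toy_vocab)  # {'[UNK]': 0, 'arma': 1, ' virum': 2, 'que': 3, ' cano': 4}
print(tokens_to_ids([" arma", " virum", "que"], toy_vocab))  # [0, 2, 3]
```

Note that " arma" (with a leading space) is a different pre-token from "arma", so it falls back to `[UNK]` here.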

Using HuggingFace tokenizers

The code and concepts above aren't too hard to follow, but once you are juggling many different alphabets and large amounts of training data, a pure-Python implementation becomes complicated and slow.

Thankfully, HuggingFace provides the Tokenizers library, a fast implementation written in Rust with Python bindings. It's simple to use, so let's take a look.

In this example, I've collected the Ancient Greek texts from Perseus's canonical-greekLit repository and trained a tokenizer on a small subset (1% of the files). We can then use this custom tokenizer to train a new GPT on Ancient Greek.

Getting data from Perseus

!git clone https://github.com/PerseusDL/canonical-greekLit.git
import os
from bs4 import BeautifulSoup

greek_texts = {}

for path, directories, files in os.walk("canonical-greekLit"):  # relative to the working directory (/content on Colab)
    for file in files:
        if file.endswith(".xml") and ("grc" in file):
            with open(os.path.join(path, file), "r", encoding="utf-8") as f:
                soup = BeautifulSoup(f.read(), features="xml")
                greek_texts[file] = soup.body.get_text().strip()
os.makedirs("greek_texts", exist_ok=True)
for k, v in greek_texts.items():
    with open(f"greek_texts/{k}", "w", encoding="utf-8") as f:
        f.write(v)
tok_split = 0.01
tok_set = list(greek_texts.items())[: int(len(greek_texts) * tok_split)]
len(tok_set)
training_and_valid = list(greek_texts.items())[int(len(greek_texts) * tok_split) :]
len(training_and_valid)
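# tokenizer texts joined with [EOS] separators (kept for reference; the training below reads the files directly)
greek_texts_for_tok = ""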
for k, v in tok_set:
    greek_texts_for_tok += v + "\n[EOS]"
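`os.walk` yields directories and files in whatever order the filesystem returns them, so which texts land in the 1% tokenizer set can differ between machines. Below is a minimal sketch of a reproducible split, assuming that sorting by filename is acceptable; `split_documents` is a hypothetical helper that operates on a dict shaped like `greek_texts`.

```python
def split_documents(docs, frac):
    """Split {filename: text} into (first part, rest) in a reproducible order."""
    items = sorted(docs.items())  # sort by filename instead of relying on os.walk order
    cut = int(len(items) * frac)
    return items[:cut], items[cut:]


docs = {"b.xml": "beta", "a.xml": "alpha", "c.xml": "gamma"}
tok_docs, rest = split_documents(docs, 0.34)
print([name for name, _ in tok_docs])  # ['a.xml']
print([name for name, _ in rest])  # ['b.xml', 'c.xml']
```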

Using Tokenizers

!pip install tokenizers -q
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))  # the unknown token is configured on the BPE model
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["[EOS]", "[UNK]"])  # BpeTrainer has no unk_token option
from tokenizers.pre_tokenizers import Whitespace

tokenizer.pre_tokenizer = Whitespace()  # split on whitespace and punctuation before BPE (nothing Greek-specific)
files = [f"greek_texts/{k}" for k, v in tok_set]
tokenizer.train(files, trainer)
tokenizer.save("ancient_greek_tokenizer.json")
example = training_and_valid[0][1]
tok_example = tokenizer.encode(example)

for i, t in zip(tok_example.ids[:20], tok_example.tokens[:20]):
    print(f"{i:6}: {t}")
loaded_tokenizer = Tokenizer.from_file("ancient_greek_tokenizer.json")
tok_example = loaded_tokenizer.encode(example)

for i, t in zip(tok_example.ids[:20], tok_example.tokens[:20]):
    print(f"{i:6}: {t}")

Using our new tokenizer

We can then use this new custom tokenizer to train a new GPT using the same code as before.

Our GPT code from two weeks ago

# same code from before
import torch
import torch.nn as nn
from torch.nn import functional as F


class Head(nn.Module):
    """one head of self-attention"""

    def __init__(self, head_size, n_embd, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)  # standard dropout

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """multiple heads of self-attention in parallel"""

    def __init__(self, num_heads, head_size, n_embd, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(head_size, n_embd, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedFoward(nn.Module):
    """two linear layers with a ReLU non-linearity in between"""

    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Transformer block: communication followed by computation"""

    def __init__(self, n_embd, n_head, dropout=0.0):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class Transformer(nn.Module):

    def __init__(self, n_embd, n_head, n_layer, device):
        super().__init__()
        # token and position embedding tables (vocab_size and block_size are globals set before instantiation)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layer)]  # global dropout hyperparameter
        )
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.device = device

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=self.device)
        )  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
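For reference, each `Head` computes causal (masked) scaled dot-product attention. In the standard formulation, with $Q, K, V$ the query, key and value projections of one sequence, $d_k$ the dimension of each key vector (`head_size` for a single `Head`), and $M$ the causal mask that the `tril` buffer implements:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,
\qquad
M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}
$$

Entries set to $-\infty$ receive a softmax weight of exactly zero, so position $i$ only attends to positions $0, \dots, i$. The code additionally applies dropout to the attention weights before multiplying by $V$.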

Same training loop too

from tqdm import tqdm

# hyperparameters
batch_size = 64  # how many independent sequences will we process in parallel
block_size = 64  # what is the maximum context length for predictions
max_iters = 5000  # number of training iterations (one random batch each), not epochs
eval_interval = 100  # every this many iterations we estimate the train/val loss
learning_rate = 5e-5  # learning rate for the optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"  # what device to use
eval_iters = 200  # how many iterations in the evaluation
n_embd = 128  # embedding size
n_head = 16  # attention heads
n_layer = 16  # how many blocks
dropout = 0.2  # amount of dropout
# ------------

# tokenize every text that was not used to train the tokenizer
tokenized_ids = []
for k, v in tqdm(training_and_valid):  # takes a few minutes
    tokenized_ids.extend(tokenizer.encode(v).ids)

# train and test splits
data = torch.tensor(tokenized_ids, dtype=torch.long)
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
vocab_size = tokenizer.get_vocab_size()
model = Transformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, device=device)
m = model.to(device)
print(sum(p.numel() for p in m.parameters()) / 1e6, "M parameters")
# sample a random batch of contexts (x) and next-token targets (y)
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data  # choosing the right data split
    ix = torch.randint(
        len(data) - block_size, (batch_size,)
    )  # get a random batch of ids
    x = torch.stack(
        [data[i : i + block_size] for i in ix]
    )  # create contexts for each id
    y = torch.stack(
        [data[i + 1 : i + block_size + 1] for i in ix]
    )  # create the targets for each context
    return x, y


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
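`get_batch` draws random windows from the token stream, and each target row is its input row shifted one position to the right, so every position in a context is its own next-token prediction (which is why `forward` flattens the logits to `(B*T, C)`). Here is a pure-Python sketch of the same slicing, with no torch; `get_batch_lists` is a hypothetical name, and `random.Random.randrange` stands in for `torch.randint`.

```python
import random


def get_batch_lists(data, block_size, batch_size, rng):
    """Pure-Python analogue of get_batch: each target row is its input row shifted by one."""
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [data[i : i + block_size] for i in starts]
    y = [data[i + 1 : i + block_size + 1] for i in starts]
    return x, y


toy_data = list(range(20))  # stand-in for a stream of token ids
toy_x, toy_y = get_batch_lists(toy_data, block_size=4, batch_size=2, rng=random.Random(0))
for xi, yi in zip(toy_x, toy_y):
    print(xi, "->", yi)  # each row of y is the matching row of x advanced by one token
```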
train_losses = []
valid_losses = []

model = model.to(device)  # use the detected device instead of hard-coding "cuda"; batches are moved in the loop

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for step in range(max_iters):
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        train_losses.append(losses["train"])
        valid_losses.append(losses["val"])
        print(
            f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
import matplotlib.pyplot as plt

plt.plot(train_losses, label="train")
plt.plot(valid_losses, label="valid")
plt.legend()
plt.title("Loss")
plt.xlabel(f"Evaluation (every {eval_interval} iterations)")
plt.ylabel("Loss")
plt.show()
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # start from id 0, i.e. [EOS], the first special token
print(tokenizer.decode(m.generate(context, max_new_tokens=2000)[0].tolist()))
iliad_context = torch.tensor(
    [tokenizer.encode("μῆνιν").ids], dtype=torch.long, device=device
)  # μῆνιν, the first word of the Iliad
print(tokenizer.decode(m.generate(iliad_context, max_new_tokens=2000)[0].tolist()))