{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget \"https://www.perseus.tufts.edu/hopper/dltext?doc=Perseus%3Atext%3A1999.02.0008\" -O atticus.xml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Byte-Pair Encoding Tokenization\n", "\n", "In the past two weeks, we have looked at the transformer architecture in depth. Both versions, GPT and BERT, are useful for different tasks, but we saw that **tokenization** played a key role regardless of which model we were training. Now that you have a sense of how these models are trained, we'll take a closer look at tokenization and why it matters so much for training, especially for non-English languages.\n", "\n", "**Learning objectives**:\n", "* Grasp the concepts behind the Byte-Pair Encoding algorithm and how we can use it for tokenization\n", "* Understand why traditional tokenization schemes are insufficient for modern models\n", "* See how to use `HuggingFace Tokenizers` to train and save your own tokenizer\n", "* Apply this framework to your own language and improve your own tools using custom tokenizers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Byte-Pair Encoding Algorithm\n", "\n", "We haven't used the word *algorithm* very much in this class, so as a reminder: an algorithm is just a list of steps that a computer (or human) can follow to accomplish a task. In this case, we want to take some text and figure out the best way to split it up. Below is the list of steps we will follow; we will see each one in detail in the code.\n", "\n", "1. Start with a base vocabulary of characters or bytes (we'll start with bytes and then move to characters/traditional words)\n", "2. Count the frequency of all adjacent pairs\n", "3. Find the most frequent pair\n", "4. Add this pair to the vocabulary as a new token\n", "5. Replace all occurrences of the pair with the new token\n", "6. 
Repeat steps 2 through 5 until the vocabulary reaches the desired size.\n", "\n", "Importantly, this algorithm could keep running until the whole text is merged into a single token, but that would be unhelpful, so we introduce a new hyperparameter that tells it when to stop: the vocabulary size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# extracting text from xml\n", "from bs4 import BeautifulSoup\n", "import re\n", "\n", "with open(\"atticus.xml\", \"r\", encoding=\"utf-8\") as f:\n", "    soup = BeautifulSoup(f.read(), features=\"xml\")\n", "\n", "letters = []\n", "for d in soup.find_all(\"div2\"):\n", "    dateline = d.dateline.extract().get_text().strip()\n", "    salute = d.salute.extract().get_text().strip()\n", "    # collapse every run of whitespace (including newlines) into a single space\n", "    text = re.sub(r\"\\s+\", \" \", d.get_text().strip())\n", "    letters.append(dateline + \"\\n\" + salute + \"\\n\" + text)\n", "\n", "text = \"\\n\\n\".join(letters)\n", "print(len(text))\n", "print(text[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_sample = text[:25]\n", "text_sample" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To show you the most basic form of the algorithm, we are going to start by using bytes as the base units of our vocabulary.\n", "\n", "Bytes are how the computer represents characters internally. There are several ways to encode characters as bytes, but the most common is called *UTF-8*. UTF-8 can encode every Unicode character, which covers virtually all writing systems, using between one and four bytes per character, so it has become a standard across all of computing." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[\n", "    ord(x) for x in text_sample\n", "]  # ord returns the Unicode code point of each character" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "list(text_sample.encode(\"utf-8\"))  # the UTF-8 encoded string as raw bytes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# step 1 in our algorithm: creating a base set of bytes for our vocabulary\n", "\n", "tokens = text.encode(\"utf-8\")  # raw bytes\n", "tokens = list(\n", "    map(int, tokens)\n", ")  # convert to a list of integers in range 0..255 for convenience\n", "print(\"---\")\n", "print(text[:100], \"...\")\n", "print(\"full length:\", len(text))\n", "print(\"---\")\n", "print(tokens[:100], \"...\")  # preview only; the full list is very long\n", "print(\"full length:\", len(tokens))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# step 2: count the frequency of adjacent pairs of tokens\n", "def get_stats(ids):\n", "    counts = {}\n", "    for pair in zip(ids, ids[1:]):  # Pythonic way to iterate consecutive elements\n", "        counts[pair] = counts.get(pair, 0) + 1\n", "    return counts\n", "\n", "\n", "stats = get_stats(tokens)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "stats" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# step 3: find the most common pair\n", "top_pair = max(stats, key=stats.get)\n", "top_pair" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# step 3/4/5: find the most common pair. then add it to the vocabulary as a single token. 
last replace the pair with a single token\n", "def merge(ids, pair, idx):\n", "    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx\n", "    newids = []\n", "    i = 0\n", "    while i < len(ids):\n", "        # if we are not at the very last position AND the pair matches, replace it\n", "        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:\n", "            newids.append(idx)\n", "            i += 2\n", "        else:\n", "            newids.append(ids[i])\n", "            i += 1\n", "    return newids\n", "\n", "\n", "print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokens2 = merge(tokens, top_pair, 256)\n", "print(tokens2[:50])\n", "print(\"length:\", len(tokens2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_stats(ids):\n", "    counts = {}\n", "    for pair in zip(ids, ids[1:]):\n", "        counts[pair] = counts.get(pair, 0) + 1\n", "    return counts\n", "\n", "\n", "def merge(ids, pair, idx):\n", "    newids = []\n", "    i = 0\n", "    while i < len(ids):\n", "        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:\n", "            newids.append(idx)\n", "            i += 2\n", "        else:\n", "            newids.append(ids[i])\n", "            i += 1\n", "    return newids\n", "\n", "\n", "# step 6: repeat all steps until we get a desired vocab length\n", "vocab_size = 276  # the desired final vocabulary size\n", "num_merges = vocab_size - 256\n", "ids = list(tokens)  # copy so we don't destroy the original list\n", "\n", "merges = {}  # (int, int) -> int\n", "for i in range(num_merges):\n", "    stats = get_stats(ids)\n", "    pair = max(stats, key=stats.get)\n", "    idx = 256 + i\n", "    print(f\"merging {pair} into a new token {idx}\")\n", "    ids = merge(ids, pair, idx)\n", "    merges[pair] = idx" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"tokens length:\", len(tokens))\n", "print(\"ids length:\", len(ids))\n", "print(\n", " 
f\"compression ratio: {len(tokens) / len(ids):.2f}X\"\n", ")  # BPE was originally invented as a compression algorithm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab = {idx: bytes([idx]) for idx in range(256)}\n", "for (p0, p1), idx in merges.items():\n", "    vocab[idx] = vocab[p0] + vocab[p1]\n", "\n", "\n", "def decode(ids):\n", "    # given ids (list of integers), return Python string\n", "    tokens = b\"\".join(vocab[idx] for idx in ids)\n", "    # errors=\"replace\" is required: not every byte sequence is valid UTF-8\n", "    # (a lone byte 128 is not), and strict decoding would raise UnicodeDecodeError\n", "    text = tokens.decode(\"utf-8\", errors=\"replace\")\n", "    return text\n", "\n", "\n", "print(decode([128]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "merges" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def encode(text):\n", "    # given a string, return list of integers (the tokens)\n", "    tokens = list(text.encode(\"utf-8\"))\n", "    while len(tokens) >= 2:\n", "        stats = get_stats(tokens)\n", "        # pick the pair that was merged earliest during training\n", "        pair = min(stats, key=lambda p: merges.get(p, float(\"inf\")))\n", "        if pair not in merges:\n", "            break  # nothing else can be merged\n", "        idx = merges[pair]\n", "        tokens = merge(tokens, pair, idx)\n", "    return tokens\n", "\n", "\n", "print(encode(\"qualis\"))\n", "print(encode(\"artifex\"))\n", "print(encode(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-tokenization\n", "We saw how the BPE algorithm works with raw bytes, but it works just as well with characters, character strings, or full words. Tokenizers very commonly run a simpler splitting pass first, confusingly called pre-tokenization, before the BPE algorithm begins." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import regex as re\n", "\n", "gpt2pat = re.compile(\n", "    r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\"\n", ")\n", "\n", "print(re.findall(gpt2pat, \"Hello've world123 how's are you!!!?\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def split_latin_tokens(text, enclitics=None):\n", "    if not enclitics:\n", "        enclitics = [\"que\", \"ve\", \"ne\", \"met\", \"ce\", \"ci\"]\n", "    enclitics_or = \"|\".join(enclitics)\n", "\n", "    text = re.sub(rf\"(\\w)({enclitics_or})\\b\", r\"\\1 \\2\", text)  # spaces before enclitics\n", "\n", "    # \\w and \\d stand in for \\p{L} and \\p{N} so the pattern also works with the\n", "    # standard-library re module, which a later cell imports over regex\n", "    # pattern = r'\\s*\\p{L}+|\\s*\\p{N}+|\\s*[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+'\n", "    pattern = r\"\\s*\\w+|\\s*\\d+|\\s*[^\\s\\w\\d]+|\\s+(?!\\S)|\\s+\"\n", "    tokens = re.findall(pattern, text)\n", "    tokens = [t.strip() if t.strip() in enclitics else t for t in tokens]\n", "    return tokens\n", "\n", "\n", "# Example usage\n", "text = \"arma virumque cano\"\n", "result = split_latin_tokens(text)\n", "print(result)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reload the letters, since the example above overwrote `text`\n", "from bs4 import BeautifulSoup\n", "import re\n", "\n", "with open(\"atticus.xml\", \"r\", encoding=\"utf-8\") as f:\n", "    soup = BeautifulSoup(f.read(), features=\"xml\")\n", "\n", "letters = []\n", "for d in soup.find_all(\"div2\"):\n", "    dateline = d.dateline.extract().get_text().strip()\n", "    salute = d.salute.extract().get_text().strip()\n", "    # collapse every run of whitespace (including newlines) into a single space\n", "    text = re.sub(r\"\\s+\", \" \", d.get_text().strip())\n", "    letters.append(dateline + \"\\n\" + salute + \"\\n\" + text)\n", "\n", "text = \"\\n\\n\".join(letters)\n", "print(len(text))\n", "print(text[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preproc = split_latin_tokens(text)\n", "len(preproc)" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_vocab(tokens):\n", "    # Start with a simple mapping of unique tokens to IDs\n", "    vocab = {}\n", "    for token in tokens:\n", "        if token not in vocab:\n", "            vocab[token] = len(vocab)\n", "    return vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab = create_vocab(preproc)\n", "len(vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ids = [vocab[token] for token in preproc]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "stats = get_stats(ids)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# takes about 10 minutes\n", "vocab_size = len(vocab) + 1000  # the desired final vocabulary size\n", "num_merges = vocab_size - len(vocab)\n", "print(f\"num_merges: {num_merges}\")\n", "# start from the word-level ids built above, not the byte-level `tokens` list\n", "ids = [vocab[token] for token in preproc]\n", "\n", "merges = {}  # (int, int) -> int\n", "for i in range(num_merges):\n", "    stats = get_stats(ids)\n", "    pair = max(stats, key=stats.get)\n", "    idx = len(vocab) + i\n", "    print(f\"merging {pair} into a new token {idx}\")\n", "    ids = merge(ids, pair, idx)\n", "    merges[pair] = idx" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# to encode\n", "new_latin = \" arma virumque\"\n", "tokens = split_latin_tokens(new_latin)\n", "print(\"Preprocessed tokens: \", tokens)\n", "ids = [vocab[token] for token in tokens]  # raises KeyError for words not seen in training\n", "print(\"Token IDs: \", ids)\n", "while len(ids) >= 2:\n", "    stats = get_stats(ids)\n", "    pair = min(stats, key=lambda p: merges.get(p, float(\"inf\")))\n", "    if pair not in merges:\n", "        break  # nothing else can be merged\n", "    idx = merges[pair]\n", "    ids = merge(ids, pair, idx)\n", "print(\"Encoded IDs: \", ids)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# to decode\n", "itos = {idx: token for token, idx in vocab.items()}\n", "# merged tokens are not in vocab, so rebuild their text from the merge table\n", "for (p0, p1), idx in merges.items():\n", "    itos[idx] = itos[p0] + itos[p1]\n", "tokens = [itos[idx] for idx in ids]\n", "print(\"Decoded tokens: \", tokens)\n", "text = \"\".join(tokens)\n", "print(\"Decoded text: \", text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using HuggingFace `tokenizers`\n", "The code and concepts above aren't too challenging to wrap your head around, but when you are juggling many different alphabets and a lot of training data, this can get very complicated.\n", "\n", "Thankfully, `HuggingFace` have provided the world with their tokenizer library `Tokenizers`. It's very simple to use, so let's take a look at it.\n", "\n", "In this example, we download all of the Ancient Greek texts from Perseus and train a tokenizer on a small subset of this material. We can then use this custom tokenizer to train a new GPT on Ancient Greek." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Getting data from Perseus" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!git clone https://github.com/PerseusDL/canonical-greekLit.git" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "from bs4 import BeautifulSoup\n", "\n", "greek_texts = {}\n", "\n", "# keep only the Greek-language (\"grc\") XML files from the cloned repository\n", "for path, directories, files in os.walk(\"/content/canonical-greekLit\"):\n", "    for file in files:\n", "        if file.endswith(\".xml\") and (\"grc\" in file):\n", "            with open(os.path.join(path, file), \"r\", encoding=\"utf-8\") as f:\n", "                soup = BeautifulSoup(f.read(), features=\"xml\")\n", "                greek_texts[file] = soup.body.get_text().strip()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.makedirs(\"greek_texts\", exist_ok=True)\n", "for k, v in greek_texts.items():\n", "    with open(f\"greek_texts/{k}\", \"w\", encoding=\"utf-8\") as f:\n", "        f.write(v)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tok_split = 0.01  # fraction of the files used to train the tokenizer\n", "tok_set = list(greek_texts.items())[: int(len(greek_texts) * tok_split)]\n", "len(tok_set)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_and_valid = list(greek_texts.items())[int(len(greek_texts) * tok_split) :]\n", "len(training_and_valid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "greek_texts_for_tok = \"\"\n", "for k, v in tok_set:\n", "    greek_texts_for_tok += v + \"\\n[EOS]\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using `Tokenizers`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install tokenizers -q" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tokenizers import Tokenizer\n", "from tokenizers.models import 
BPE\n", "\n", "# the unknown token is configured on the BPE model itself\n", "tokenizer = Tokenizer(BPE(unk_token=\"[UNK]\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tokenizers.trainers import BpeTrainer\n", "\n", "trainer = BpeTrainer(special_tokens=[\"[EOS]\", \"[UNK]\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tokenizers.pre_tokenizers import Whitespace\n", "\n", "# Whitespace splits on whitespace and punctuation only;\n", "# handling Greek enclitics would need a custom pre-tokenizer\n", "tokenizer.pre_tokenizer = Whitespace()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "files = [f\"greek_texts/{k}\" for k, v in tok_set]\n", "tokenizer.train(files, trainer)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer.save(\"ancient_greek_tokenizer.json\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example = training_and_valid[0][1]\n", "tok_example = tokenizer.encode(example)\n", "\n", "for i, t in zip(tok_example.ids[:20], tok_example.tokens[:20]):\n", "    print(f\"{i:6}: {t}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reloading the saved tokenizer should give exactly the same encoding\n", "loaded_tokenizer = Tokenizer.from_file(\"ancient_greek_tokenizer.json\")\n", "tok_example = loaded_tokenizer.encode(example)\n", "\n", "for i, t in zip(tok_example.ids[:20], tok_example.tokens[:20]):\n", "    print(f\"{i:6}: {t}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using our new tokenizer\n", "We can now use this custom tokenizer to train a new GPT using the same code as before." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Our GPT code from two weeks ago" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# same code from before\n", "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "\n", "class Head(nn.Module):\n", "    \"\"\"one head of self-attention\"\"\"\n", "\n", "    def __init__(self, head_size, n_embd, dropout):\n", "        super().__init__()\n", "        self.key = nn.Linear(n_embd, head_size, bias=False)\n", "        self.query = nn.Linear(n_embd, head_size, bias=False)\n", "        self.value = nn.Linear(n_embd, head_size, bias=False)\n", "        self.register_buffer(\"tril\", torch.tril(torch.ones(block_size, block_size)))\n", "\n", "        self.dropout = nn.Dropout(dropout)  # standard dropout\n", "\n", "    def forward(self, x):\n", "        B, T, C = x.shape\n", "        k = self.key(x)  # (B,T,C)\n", "        q = self.query(x)  # (B,T,C)\n", "        # compute attention scores (\"affinities\")\n", "        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, C) @ (B, C, T) -> (B, T, T)\n", "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float(\"-inf\"))  # (B, T, T)\n", "        wei = F.softmax(wei, dim=-1)  # (B, T, T)\n", "        wei = self.dropout(wei)\n", "        # perform the weighted aggregation of the values\n", "        v = self.value(x)  # (B,T,C)\n", "        out = wei @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)\n", "        return out\n", "\n", "\n", "class MultiHeadAttention(nn.Module):\n", "    \"\"\"multiple heads of self-attention in parallel\"\"\"\n", "\n", "    def __init__(self, num_heads, head_size, n_embd, dropout):\n", "        super().__init__()\n", "        self.heads = nn.ModuleList(\n", "            [Head(head_size, n_embd, dropout) for _ in range(num_heads)]\n", "        )\n", "        self.proj = nn.Linear(n_embd, n_embd)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, x):\n", "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n", "        out = self.dropout(self.proj(out))\n", "        return out\n", "\n", "\n", "class FeedFoward(nn.Module):\n", "    \"\"\"a simple linear layer followed by a non-linearity\"\"\"\n", "\n", "    def __init__(self, n_embd, dropout=0.0):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "\n", "class Block(nn.Module):\n", "    \"\"\"Transformer block: communication followed by computation\"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, dropout=0.0):\n", "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n", "        super().__init__()\n", "        head_size = n_embd // n_head\n", "        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)\n", "        self.ffwd = FeedFoward(n_embd)\n", "        self.ln1 = nn.LayerNorm(n_embd)\n", "        self.ln2 = nn.LayerNorm(n_embd)\n", "\n", "    def forward(self, x):\n", "        x = x + self.sa(self.ln1(x))\n", "        x = x + self.ffwd(self.ln2(x))\n", "        return x\n", "\n", "\n", "class Transformer(nn.Module):\n", "\n", "    def __init__(self, n_embd, n_head, n_layer, device):\n", "        super().__init__()\n", "        # token and position embeddings are both looked up from learned tables\n", "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", "        self.blocks = nn.Sequential(\n", "            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]\n", "        )\n", "        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm\n", "        self.lm_head = nn.Linear(n_embd, vocab_size)\n", "        self.device = device\n", "\n", "    def forward(self, idx, targets=None):\n", "        B, T = idx.shape\n", "\n", "        # idx and targets are both (B,T) tensor of integers\n", "        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n", "        pos_emb = self.position_embedding_table(\n", "            torch.arange(T, device=self.device)\n", "        )  # (T,C)\n", "        x = tok_emb + pos_emb  # (B,T,C)\n", "        x = self.blocks(x)  # (B,T,C)\n", "        x = self.ln_f(x)  # (B,T,C)\n", "        logits = self.lm_head(x)  # (B,T,vocab_size)\n", "\n", "        if targets is None:\n", "            loss = None\n", "        else:\n", "            B, T, C = logits.shape\n", "            logits = logits.view(B * T, C)\n", "            targets = targets.view(B * T)\n", "            loss = F.cross_entropy(logits, targets)\n", "\n", "        return logits, loss\n", "\n", "    def generate(self, idx, max_new_tokens):\n", "        # idx is (B, T) array of indices in the current context\n", "        for _ in range(max_new_tokens):\n", "            # crop idx to the last block_size tokens\n", "            idx_cond = idx[:, -block_size:]\n", "            # get the predictions\n", "            logits, loss = self(idx_cond)\n", "            # focus only on the last time step\n", "            logits = logits[:, -1, :]  # becomes (B, C)\n", "            # apply softmax to get probabilities\n", "            probs = F.softmax(logits, dim=-1)  # (B, C)\n", "            # sample from the distribution\n", "            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n", "            # append sampled index to the running sequence\n", "            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n", "        return idx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Same training loop too" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm import tqdm\n", "\n", "# hyperparameters\n", "batch_size = 64  # how many independent sequences will we process in parallel\n", "block_size = 64  # what is the maximum context length for predictions\n", "max_iters = 5000  # number of training iterations (batches), not full epochs\n", "eval_interval = 100  # every this many iterations we look at the validation set\n", "learning_rate = 5e-5  # learning rate for the optimizer\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # what device to use\n", "eval_iters = 200  # how many batches to average over in each evaluation\n", "n_embd = 128  # embedding size\n", "n_head = 16  # attention heads\n", "n_layer = 16  # how many blocks\n", "dropout = 0.2  # amount of dropout\n", "# ------------\n", "\n", "# tokenize the texts that were not used to train the tokenizer\n", "tokenized_ids = []\n", "for k, v in tqdm(training_and_valid): # takes a 
few minutes\n", "    tokenized_ids.extend(tokenizer.encode(v).ids)\n", "\n", "# train and validation splits\n", "data = torch.tensor(tokenized_ids, dtype=torch.long)\n", "n = int(0.9 * len(data))  # first 90% will be train, rest val\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab_size = tokenizer.get_vocab_size()\n", "model = Transformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, device=device)\n", "m = model.to(device)\n", "print(sum(p.numel() for p in m.parameters()) / 1e6, \"M parameters\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# functions for sampling batches and estimating the loss during evaluation\n", "def get_batch(split):\n", "    # generate a small batch of data of inputs x and targets y\n", "    data = train_data if split == \"train\" else val_data  # choosing the right data split\n", "    ix = torch.randint(\n", "        len(data) - block_size, (batch_size,)\n", "    )  # random start positions, one per sequence in the batch\n", "    x = torch.stack(\n", "        [data[i : i + block_size] for i in ix]\n", "    )  # create contexts for each start position\n", "    y = torch.stack(\n", "        [data[i + 1 : i + block_size + 1] for i in ix]\n", "    )  # the targets are the contexts shifted one token to the right\n", "    return x, y\n", "\n", "\n", "@torch.no_grad()\n", "def estimate_loss():\n", "    out = {}\n", "    model.eval()\n", "    for split in [\"train\", \"val\"]:\n", "        losses = torch.zeros(eval_iters)\n", "        for k in range(eval_iters):\n", "            X, Y = get_batch(split)\n", "            X, Y = X.to(device), Y.to(device)\n", "            logits, loss = model(X, Y)\n", "            losses[k] = loss.item()\n", "        out[split] = losses.mean()\n", "    model.train()\n", "    return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_losses = []\n", "valid_losses = []\n", "\n", "# use the device chosen above so the notebook also runs without a GPU\n", "data = data.to(device)\n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", "for epoch 
in range(max_iters):\n", "    if epoch % eval_interval == 0 or epoch == max_iters - 1:\n", "        losses = estimate_loss()\n", "        train_losses.append(losses[\"train\"])\n", "        valid_losses.append(losses[\"val\"])\n", "        print(\n", "            f\"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\"\n", "        )\n", "\n", "    xb, yb = get_batch(\"train\")\n", "    xb, yb = xb.to(device), yb.to(device)\n", "\n", "    logits, loss = model(xb, yb)\n", "    optimizer.zero_grad(set_to_none=True)\n", "    loss.backward()\n", "    optimizer.step()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# one point per evaluation, i.e. every eval_interval training iterations\n", "plt.plot(train_losses, label=\"train\")\n", "plt.plot(valid_losses, label=\"valid\")\n", "plt.legend()\n", "plt.title(\"Loss\")\n", "plt.xlabel(\"Evaluation step\")\n", "plt.ylabel(\"Loss\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# start generation from token id 0 as the only context\n", "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", "print(tokenizer.decode(m.generate(context, max_new_tokens=2000)[0].tolist()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "iliad_context = torch.tensor(\n", "    [tokenizer.encode(\"μῆνιν\").ids], dtype=torch.long, device=device\n", ")  # opening of the iliad\n", "print(tokenizer.decode(m.generate(iliad_context, max_new_tokens=2000)[0].tolist()))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "TVpSzDU7k_So" ], "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 0 }