{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget https://www.perseus.tufts.edu/hopper/dltext?doc=Perseus%3Atext%3A1999.02.0008 -O atticus.xml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Encoder-only Transformers: Bidirectional Encoder Representations from Transformers (BERT)\n", "\n", "As we saw last week with the decoder-only architecture, attention-based transformers are very good at learning text features and predicting the next token from a sequence of tokens. Today, we will explore the other half of the transformer: the encoder and encoder-only architectures. In this lesson, we will implement a specific encoder-only transformer called BERT, from the title of the paper that introduced it: *Bidirectional Encoder Representations from Transformers*. BERT was the cutting edge of NLP for many years before being unseated by decoder-only transformers, but it is still used for many different applications. As with word2vec, BERT gives us embeddings for individual words. This kind of feature extraction allows us to build further models for tasks like NER and token classification, as we did in week 10.\n", "\n", "\n", "## Parts of the Encoder-only transformer\n", "The encoder-only transformer is made up of several parts (see schematic below):\n", "* Embeddings: Just like word2vec, the RNN and the decoder-only architecture, the encoder-only architecture takes advantage of an embedding layer. As in the decoder-only transformer, there are two different types of embeddings: token embeddings and positional encodings.\n", "* Positional Encodings: These are added to the input embeddings to give the model information about the position of each token in the sequence. 
Like the token embeddings, this is just an embedding layer, but it is indexed by position (from 0 up to `block_size - 1`) rather than by token, so the model learns one vector per position in the sequence.\n", "* Masked Multi-Head Attention: Unlike what we saw with the decoder-only model, we train encoder-only models by masking a certain percentage of the tokens in each sequence and having the model guess which tokens we masked. Attention itself works exactly the same way, however!\n", "* Feed forward: This layer allows the model to process the information from the attention layer through non-linear transformations, increasing the model's capacity to learn complex patterns.\n", "* Last linear layer: This last linear layer allows the model to make its predictions for the masked tokens in the sequence.\n", "* Softmax: As we have seen since word2vec, this function transforms the logits of a linear layer into a probability distribution, from which we can pick the index of the most likely token for each masked position.\n", "\n", "It is worth noting that a \"Block\" is made up of the masked multi-head attention, the normalization layers and the feed forward layer. This Block can be repeated many times before a prediction is actually made. In fact, one of the main differences between smaller and larger models is how many repetitions of these blocks there are.\n", "\n", "![image](https://github.com/pnadelofficial/nlp-and-the-human-record2024/blob/07b020394c111d3f4e5b90661200e2214186302f/encoder.png?raw=true)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "We are going to follow the decoder-only notebook as closely as possible to show you how similar these two architectures really are. (Truly, they are just two sides of the same coin.) So, just like in that notebook, I will be using Cicero's Letters to Atticus from Perseus.\n", "\n", "Unlike the decoder-only notebook, I will be using a more standard tokenization scheme. 
Rather than each of our tokens being characters, we will use `nltk`'s `word_tokenize` function to tokenize our sentences. In a later lesson, we will see how to create our own very robust tokenizer, but for now, this will do." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# extracting text from XML\n", "from bs4 import BeautifulSoup\n", "import re\n", "\n", "soup = BeautifulSoup(open(\"atticus.xml\", \"r\").read(), features=\"xml\")\n", "\n", "letters = []\n", "for d in soup.find_all(\"div2\"):\n", "    dateline = d.dateline.extract().get_text().strip()\n", "    salute = d.salute.extract().get_text().strip()\n", "    text = re.sub(r\"\\s+\", \" \", d.get_text().strip().replace(\"\\n\", \"\"))\n", "    letters.append(dateline + \"\\n\" + salute + \"\\n\" + text)\n", "\n", "text = \"\\n\\n\".join(letters)\n", "print(len(text))\n", "print(text[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import nltk\n", "\n", "nltk.download(\"punkt_tab\")\n", "\n", "tokenized_text = nltk.word_tokenize(text)  # tokenizing the text\n", "print(len(tokenized_text))\n", "print(tokenized_text[:100])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned above, BERT is a Masked Language Model (MLM), meaning that we mask a certain percentage of tokens and ask the model to fill in the gaps. To that end, we need to add a `MASK` token which will stand in for the masked tokens.\n", "\n", "Also unlike the decoder-only model, we will need a `PAD` token, so that all of our sequences are the same length. In the decoder, we relied on next token prediction to create batches of training data. In this model, we cannot rely on that, so some sequences will be shorter than others, specifically when a sequence is shorter than `block_size`. In these cases, we can use this `PAD` token to fill the remaining positions."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokens = sorted(set(w.lower() for w in tokenized_text)) + [\n", "    \"MASK\",\n", "    \"PAD\",\n", "]  # added tokens; sorting keeps the ids stable between runs\n", "print(len(tokens))\n", "print(tokens[:100])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# same as what we saw with the decoder\n", "stoi = {ch: i for i, ch in enumerate(tokens)}\n", "itos = {i: ch for i, ch in enumerate(tokens)}\n", "# the special tokens are stored in upper case, so they must not be lowercased\n", "encode = lambda s: [\n", "    stoi[c] if c in (\"MASK\", \"PAD\") else stoi[c.lower()] for c in nltk.word_tokenize(s)\n", "]\n", "decode = lambda l: \" \".join([itos[i] for i in l])\n", "\n", "print(encode(\"salve mundus\"))\n", "print(decode(encode(\"salve mundus\")))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "data = torch.tensor(encode(text), dtype=torch.long)  # tokenizing our data\n", "print(data.shape, data.dtype)\n", "print(data[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# as before, reserving 10% for validation\n", "n = int(0.9 * len(data))\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that our data is tokenized, we can work on developing a single method `get_batch` which will *collate* the data tensor above into multiple training examples.\n", "\n", "In the last notebook, this was somewhat straightforward as we knew we were trying to predict the next token based on a given sequence of tokens. Recall, though, we want today's language model to predict randomly masked tokens. 
This task will train the token embeddings to reflect the semantic relationships between words.\n", "\n", "Here is how we'll set up our training examples:\n", "* Select a random sequence of training data (just the token numbers)\n", "* From this sequence, select a subset of tokens as \"masked\" tokens, which are covered up and unknown to the model\n", "* Return the newly masked sequence (x), the target sequence (y) and the token mask itself, along with any other data structures we need.\n", "\n", "\n", "To this end, our new `get_batch` method will need to:\n", "1. Select a random sequence of training data (this is the same code as we saw in the last lesson).\n", "2. Randomly cut off the end of some sequences and 'pad' them, to give the model a variety of context lengths.\n", "3. Create an 'attention mask', which starts off as all 1s, with every padded position then set to 0. This attention mask plays the role that the `tril` mask played in the decoder-only model: it tells attention which positions to ignore. The mask itself is not learned; the model learns the attention weights over the positions that the mask leaves visible.\n", "4. Randomly mask some of the tokens, as mentioned above, keeping track of the masked tokens in a specific data structure."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# necessary hyperparameters\n", "batch_size = 4\n", "block_size = 8\n", "vocab_size = len(tokens)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ix = torch.randint(\n", "    len(train_data) - block_size, (batch_size,)\n", ")  # random sequence of data\n", "print(ix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.stack([train_data[i : i + block_size] for i in ix])  # will be masked\n", "y = x.clone()  # will become targets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pad_token_id = stoi[\"PAD\"]\n", "pad_token_id" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mask_token_id = stoi[\"MASK\"]\n", "mask_token_id" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 50/50 chance that the sequence will be cut off and padded with the pad token\n", "# helps the model learn to embed words from a variety of sequence lengths\n", "import random\n", "\n", "for i in range(batch_size):\n", "    if random.random() < 0.5:\n", "        pad_length = random.randint(1, block_size // 2)  # random amount to pad\n", "        x[i, -pad_length:] = pad_token_id\n", "        y[i, -pad_length:] = pad_token_id" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# attention mask: 1 for real tokens, 0 for padding (fixed, not learned)\n", "attention_mask = (x != pad_token_id).float()\n", "attention_mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# masking 15% of the tokens in the sequence\n", "mask = torch.rand(x.shape) < 0.15\n", "mask = mask & (x != pad_token_id)\n", "mask" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The original BERT paper did the following to the *masked* tokens (not 
all of the tokens):\n", "\n", "* 80% are replaced with the `MASK` token\n", "* 10% are replaced with a random token\n", "* 10% are left unchanged" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 80% are replaced with the MASK token\n", "mask_replace = mask & (torch.rand(x.shape) < 0.8)\n", "mask_replace" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 10% (50% of the leftover masked tokens) are replaced with a random token\n", "# 10% (the other 50% of the leftover masked tokens) are left unchanged\n", "mask_random = mask & (torch.rand(x.shape) < 0.5) & ~mask_replace\n", "mask_random" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# applying the mask token to selected ids\n", "x[mask_replace] = mask_token_id\n", "x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# drawing random replacement tokens for the selected ids\n", "# PAD is the last id, so sampling below vocab_size - 1 never draws it\n", "random_tokens = torch.randint(vocab_size - 1, x[mask_random].shape)\n", "random_tokens = torch.where(random_tokens == pad_token_id, mask_token_id, random_tokens)  # extra safety check\n", "random_tokens" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# pulling it all together into a single tensor\n", "x[mask_random] = random_tokens\n", "x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# return the:\n", "## training example (x), masked tensor\n", "## targets for this example (y)\n", "## attention mask - depends only on where the padding is\n", "## mask - \"answer key\" for the targets\n", "x, y, attention_mask, mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# single function that does all of this\n", "def get_batch(split, mask_ratio=0.15):\n", "    data = train_data if split == \"train\" else val_data\n", "    ix = torch.randint(len(data) - 
block_size, (batch_size,))\n", "    x = torch.stack([data[i : i + block_size] for i in ix])\n", "    y = x.clone()\n", "\n", "    for i in range(batch_size):\n", "        if random.random() < 0.5:\n", "            pad_length = random.randint(1, block_size // 2)\n", "            x[i, -pad_length:] = pad_token_id\n", "            y[i, -pad_length:] = pad_token_id\n", "\n", "    attention_mask = (x != pad_token_id).float()\n", "    mask = torch.rand(x.shape) < mask_ratio\n", "    mask = mask & (x != pad_token_id)\n", "\n", "    mask_replace = mask & (torch.rand(x.shape) < 0.8)\n", "    mask_random = mask & (torch.rand(x.shape) < 0.5) & ~mask_replace\n", "\n", "    x[mask_replace] = mask_token_id\n", "    random_tokens = torch.randint(vocab_size - 1, x[mask_random].shape)\n", "    random_tokens = torch.where(\n", "        random_tokens == pad_token_id, mask_token_id, random_tokens\n", "    )\n", "    x[mask_random] = random_tokens\n", "\n", "    return x, y, attention_mask, mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# what the tokens look like\n", "xb, yb, attention_mask, pred_mask = get_batch(\"train\")\n", "xb, yb, attention_mask, pred_mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# what the actual words look like\n", "for b in range(batch_size):\n", "    print(decode(xb[b].tolist()))\n", "    print(decode(yb[b].tolist()))\n", "    print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attention\n", "\n", "Our data, though in a different configuration than in the last notebook, is ready for a single forward pass through an attention head. As we saw in the last notebook, a single head of attention is made up of:\n", "* Attention mask: in this example this came from our `get_batch` method. 
In the last notebook, we used `tril` to create this.\n", "* Three linear projection layers:\n", "  * Key\n", "  * Query\n", "  * Value\n", "* A projection layer that projects the attention output from `head_size` back to `n_embd`\n", "\n", "In addition to this, to complete a full forward pass we'll also need:\n", "* A token embedding table: these are learnable parameters that will become the word embeddings/vectors.\n", "* A positional embedding table: these learnable parameters tell the model where each token sits in the sequence.\n", "* A feed forward layer: Containing a non-linearity, this layer allows the model to capture complex patterns beyond linear transformations.\n", "* A final projection layer: This layer takes our activations from `n_embd` to `vocab_size`, so that the token most likely to fill a masked position gets the highest score.\n", "* Cross entropy loss function (negative log likelihood): This is the loss function for choosing one class out of many (here, one token out of the whole vocabulary), as we have seen in the past."
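, "\n", "\n", "As a compact summary of the attention step, here is the standard scaled dot-product formula with a padding mask added. $Q$, $K$ and $V$ are the outputs of the query, key and value projections and $d_k$ is `head_size`. The symbol $M$ is our own shorthand (an assumption for this summary, not a name used elsewhere in this notebook) for a matrix that is $0$ wherever the key is a real token and $-\\infty$ wherever it is a `PAD` token:\n", "\n", "$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{QK^{\\top}}{\\sqrt{d_k}} + M\\right)V$$"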
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "n_embd = 64\n", "vocab_size = len(tokens)\n", "token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", "position_embedding_table = nn.Embedding(block_size, n_embd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tok_emb = token_embedding_table(xb)  # token embeddings\n", "pos_emb = position_embedding_table(torch.arange(block_size))  # position embeddings\n", "tok_emb.shape, pos_emb.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = tok_emb + pos_emb  # elementwise addition (pos_emb is broadcast over the batch)\n", "x.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "head_size = 16\n", "key = nn.Linear(n_embd, head_size, bias=False)\n", "query = nn.Linear(n_embd, head_size, bias=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "k = key(x)\n", "k.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "q = query(x)\n", "q.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = (\n", "    q @ k.transpose(-2, -1) * head_size**-0.5\n", ")  # transpose k so the matmul works, then scale by 1 / sqrt(head_size)\n", "weights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "attention_mask = attention_mask.unsqueeze(1).expand(-1, block_size, -1)\n", "attention_mask.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = weights.masked_fill(attention_mask == 0, float(\"-inf\"))\n", "weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = F.softmax(weights, dim=-1)\n", "weights" ] }, 
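{ "cell_type": "markdown", "metadata": {}, "source": [ "To see why filling with `-inf` works, here is a small toy example. It is only a sketch: the `toy_` names and the made-up scores are ours and are not used anywhere else in this notebook. Every raw score is equal and the last of three positions is treated as padding, so after the softmax each row should spread its weight evenly over the two real tokens and give the padded position a weight of exactly zero." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.nn import functional as F\n", "\n", "# hypothetical toy inputs: 1 sequence, 3 positions, identical raw scores\n", "toy_scores = torch.zeros(1, 3, 3)  # (B, T, T)\n", "toy_mask = torch.tensor([[1.0, 1.0, 0.0]])  # (B, T): the last position is padding\n", "toy_mask = toy_mask.unsqueeze(1).expand(-1, 3, -1)  # (B, T, T): same key mask for every query\n", "\n", "toy_weights = F.softmax(toy_scores.masked_fill(toy_mask == 0, float(\"-inf\")), dim=-1)\n", "print(toy_weights)  # each row: 0.5, 0.5, 0.0" ] }, 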
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "value = nn.Linear(n_embd, head_size, bias=False)\n", "v = value(x)\n", "v.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out = weights @ v\n", "out.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "proj = nn.Linear(head_size, n_embd)\n", "out = proj(out)\n", "out.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FeedFoward(nn.Module):\n", "    \"\"\"a simple linear layer followed by a non-linearity\"\"\"\n", "\n", "    def __init__(self, n_embd, dropout=0.0):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "\n", "ffwd = FeedFoward(n_embd)\n", "out = ffwd(out)\n", "out.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lm_head = nn.Linear(n_embd, vocab_size)\n", "logits = lm_head(out)\n", "logits.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "logits = logits.view(-1, vocab_size)\n", "logits.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "yb.view(-1).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "targets = yb.view(-1)\n", "targets.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "targets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred_mask = pred_mask.view(-1)\n", "pred_mask.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred_mask" ]
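}, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step indexes the logits and the targets with this boolean tensor. As a minimal sketch of what that indexing does (the `toy_` tensors are made-up values for illustration only and are not used anywhere else in this notebook), a boolean mask keeps just the rows where the mask is `True`, so the loss is only computed on the masked positions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# hypothetical toy values: 4 positions and a 3-word vocabulary\n", "toy_logits = torch.arange(12.0).view(4, 3)\n", "toy_pred_mask = torch.tensor([True, False, True, False])  # positions 0 and 2 were masked\n", "\n", "toy_selected = toy_logits[toy_pred_mask]  # keeps rows 0 and 2 only\n", "print(toy_selected.shape)  # torch.Size([2, 3])" ]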
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "masked_logits = logits[pred_mask]\n", "masked_logits.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "masked_logits" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "masked_targets = targets[pred_mask]\n", "masked_targets.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "masked_targets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss = F.cross_entropy(masked_logits, masked_targets)\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Full model\n", "Below are all of the modules needed to fully construct the BERT model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Single head of attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Head(nn.Module):\n", "    def __init__(self, head_size, n_embd=64, dropout=0.0):\n", "        super().__init__()\n", "        self.key = nn.Linear(n_embd, head_size, bias=False)\n", "        self.query = nn.Linear(n_embd, head_size, bias=False)\n", "        self.value = nn.Linear(n_embd, head_size, bias=False)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, x, attention_mask):\n", "        B, T, C = x.shape\n", "        k = self.key(x)\n", "        q = self.query(x)\n", "\n", "        # scale by the head size (not n_embd), as in the walkthrough above\n", "        weights = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5\n", "        # (B, T) -> (B, T, T); using T means inputs shorter than block_size also work\n", "        attention_mask = attention_mask.unsqueeze(1).expand(-1, T, -1)\n", "        weights = weights.masked_fill(attention_mask == 0, float(\"-inf\"))\n", "        weights = F.softmax(weights, dim=-1)\n", "        weights = self.dropout(weights)\n", "\n", "        v = self.value(x)\n", "        out = weights @ v\n", "        return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multihead attention, feedforward layer and a single Block" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n", "    def __init__(self, num_heads, head_size, n_embd=64, dropout=0.0):\n", "        super().__init__()\n", "        self.heads = nn.ModuleList(\n", "            [Head(head_size, n_embd, dropout) for _ in range(num_heads)]\n", "        )\n", "        self.proj = nn.Linear(head_size * num_heads, n_embd)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, x, attention_mask):\n", "        out = torch.cat([h(x, attention_mask) for h in self.heads], dim=-1)\n", "        out = self.dropout(self.proj(out))\n", "        return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FeedFoward(nn.Module):\n", "    \"\"\"a simple linear layer followed by a non-linearity\"\"\"\n", "\n", "    def __init__(self, n_embd, dropout=0.0):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Block(nn.Module):\n", "    \"\"\"Transformer block: communication followed by computation\"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, dropout=0.0):\n", "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n", "        super().__init__()\n", "        head_size = n_embd // n_head\n", "        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)\n", "        self.ffwd = FeedFoward(n_embd, dropout=dropout)  # pass dropout through here too\n", "        self.ln1 = nn.LayerNorm(n_embd)\n", "        self.ln2 = nn.LayerNorm(n_embd)\n", "\n", "    def forward(self, x, attention_mask=None):\n", "        x = x + self.sa(self.ln1(x), attention_mask)\n", "        x = x + self.ffwd(self.ln2(x))\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Final transformer all together" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Transformer(nn.Module):\n", "    def __init__(\n", "        self, vocab_size, n_embd, n_head, n_layer, block_size, device, dropout=0.0\n", "    ):\n", "        super().__init__()\n", "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", "        # ModuleList rather than Sequential: each block takes the attention mask as a second argument\n", "        self.blocks = nn.ModuleList(\n", "            [Block(n_embd, n_head, dropout) for _ in range(n_layer)]\n", "        )\n", "        self.ln_f = nn.LayerNorm(n_embd)\n", "        self.lm_head = nn.Linear(n_embd, vocab_size)\n", "        self.device = device\n", "\n", "    def forward(self, idx, targets=None, attention_mask=None, pred_mask=None):\n", "        B, T = idx.shape\n", "\n", "        tok_emb = self.token_embedding_table(idx)\n", "        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device))\n", "        x = tok_emb + pos_emb\n", "\n", "        for block in self.blocks:\n", "            x = block(x, attention_mask)\n", "\n", "        x = self.ln_f(x)\n", "        logits = self.lm_head(x)\n", "\n", "        if targets is None:\n", "            loss = None\n", "        else:\n", "            # flatten to (B * T, vocab_size) and score only the masked positions\n", "            logits = logits.view(-1, logits.size(-1))\n", "            targets = targets.view(-1)\n", "            pred_mask = pred_mask.view(-1)\n", "            masked_logits = logits[pred_mask]\n", "            masked_targets = targets[pred_mask]\n", "            loss = F.cross_entropy(masked_logits, masked_targets)\n", "\n", "        return logits, loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training and evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "# hyperparameters\n", "batch_size = 32  # how many independent sequences will we process in parallel\n", "block_size = 64  # what is the maximum context length for predictions\n", "max_iters = 5000  # number of training iterations\n", "eval_interval = 100  # every this many iterations we look at the validation set\n", "learning_rate = 1e-5  # learning rate for the optimizer\n", "device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"  # what device to use\n", "eval_iters = 200  # how many batches to average over in each evaluation\n", "n_embd = 256  # embedding size\n", "n_head = 8  # attention heads\n", "n_layer = 4  # how many blocks\n", "dropout = 0.1  # amount of dropout\n", "# ------------\n", "\n", "model = Transformer(\n", "    n_embd=n_embd,\n", "    n_head=n_head,\n", "    n_layer=n_layer,\n", "    vocab_size=vocab_size,\n", "    block_size=block_size,\n", "    device=device,\n", "    dropout=dropout,  # without this the dropout hyperparameter above is never used\n", ")\n", "m = model.to(device)\n", "print(sum(p.numel() for p in m.parameters()) / 1e6, \"M parameters\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def estimate_loss():\n", "    out = {}\n", "    model.eval()\n", "    for split in [\"train\", \"val\"]:\n", "        losses = torch.zeros(eval_iters)\n", "        for k in range(eval_iters):\n", "            X, Y, attention_mask, pred_mask = get_batch(split, mask_ratio=0.15)\n", "            X, Y, attention_mask, pred_mask = (\n", "                X.to(device),\n", "                Y.to(device),\n", "                attention_mask.to(device),\n", "                pred_mask.to(device),\n", "            )\n", "            logits, loss = model(X, Y, attention_mask, pred_mask)\n", "            losses[k] = loss.item()\n", "        out[split] = losses.mean()\n", "    model.train()\n", "    return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_losses = []\n", "valid_losses = []\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", "for step in range(max_iters):\n", "    if step % eval_interval == 0 or step == max_iters - 1:\n", "        losses = estimate_loss()\n", "        train_losses.append(losses[\"train\"])\n", "        valid_losses.append(losses[\"val\"])\n", "        print(\n", "            f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\"\n", "        )\n", "\n", "    xb, yb, attention_mask, pred_mask = get_batch(\"train\")\n", "    xb, yb, attention_mask, pred_mask = (\n", "        xb.to(device),\n", "        yb.to(device),\n", "        attention_mask.to(device),\n", "        
pred_mask.to(device),\n", "    )\n", "\n", "    logits, loss = model(xb, yb, attention_mask, pred_mask)\n", "    optimizer.zero_grad(set_to_none=True)\n", "    loss.backward()\n", "    optimizer.step()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(train_losses, label=\"train\")\n", "plt.plot(valid_losses, label=\"valid\")\n", "plt.legend()\n", "plt.title(\"Loss\")\n", "plt.xlabel(\"Evaluation step (one every eval_interval iterations)\")\n", "plt.ylabel(\"Loss\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fero_idx = stoi[\"fero\"]\n", "with torch.no_grad():\n", "    # index the embedding weights directly, so the model itself is not moved off its device\n", "    fero_embedding = model.token_embedding_table.weight[[fero_idx]].cpu()\n", "fero_embedding.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fero_embedding" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_embedding(word):\n", "    idx = stoi[word]\n", "    with torch.no_grad():\n", "        # shape (1, n_embd), copied to the CPU without moving the model\n", "        embedding = model.token_embedding_table.weight[[idx]].cpu()\n", "    return embedding" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "visualize_words = [\n", "    # example words - names\n", "    \"antonius\",\n", "    \"caesar\",\n", "    \"pompei\",\n", "    \"galba\",\n", "    \"catilina\",\n", "    \"cornificius\",\n", "    \"scipio\",\n", "    \"lucullus\",\n", "    \"pontius\",\n", "]\n", "embeddings = [get_embedding(word) for word in visualize_words]\n", "visualize_vecs = torch.stack(embeddings)\n", "visualize_vecs = visualize_vecs.squeeze(1).to(\"cpu\").numpy()\n", "visualize_idx = [stoi[word] for word in visualize_words]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# PCA by hand: center the vectors, then project onto the top two directions of the covariance\n", "temp = visualize_vecs - np.mean(visualize_vecs, axis=0)\n", "covariance = 1.0 / 
len(visualize_idx) * temp.T.dot(temp)\n", "U, S, V = np.linalg.svd(covariance)\n", "coord = temp.dot(U[:, 0:2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for i in range(len(visualize_words)):\n", "    plt.text(\n", "        coord[i, 0],\n", "        coord[i, 1],\n", "        visualize_words[i],\n", "        bbox=dict(facecolor=\"green\", alpha=0.1),\n", "    )\n", "\n", "plt.xlim((np.min(coord[:, 0] - 0.5), np.max(coord[:, 0] + 0.5)))\n", "plt.ylim((np.min(coord[:, 1] - 0.5), np.max(coord[:, 1] + 0.5)))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 0 }