{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget https://www.perseus.tufts.edu/hopper/dltext?doc=Perseus%3Atext%3A1999.02.0008 -O atticus.xml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Decoder-only Transformers: Generative Pre-trained Transformers (GPTs)\n", "\n", "With the release of ChatGPT by OpenAI in the autumn of 2022, many began to flock to \"AI\", treating it like magic. Today, we will investigate the modeling techniques at the core of this technology: the decoder-only transformer.\n", "\n", "The original transformer, as proposed by Vaswani et al., contained two parts: the encoder and the decoder. This architecture is still used for transformer-based machine translation, but researchers have also split up these two parts and found that each has useful features by itself.\n", "\n", "In a future lesson, we will take a close look at the encoder and how it is suited for representing semantic meaning with word vectors. Today, though, we will explore the decoder and what it is capable of.\n", "\n", "**Learning objectives:**\n", "* Understand how to run inference with GPTs and how they are trained so that this inference is possible.\n", "* Examine the internal states of models, including attention heads and MLPs. Learn more about activations and how the model works in practice.\n", "* Connect sentence transformer (encoder-only) work to how GPTs work (decoder-only).\n", "* Access and utilize the free GPU resources on Colab, and learn why we need GPUs in order to do this work.\n", "\n", "*Note on terminology:* Unfortunately, there are many overlapping terms in this field. For example, a \"GPT\" is just a \"decoder-only transformer\" that has been trained on a huge amount of data. (We'll talk about \"pretraining\" vs. training soon.) This confusion becomes all the more manifest when talking about \"Artificial Intelligence\" and \"Large Language Models\". 
This terminology is problematic and a serious detriment to the field. I will attempt to be as consistent as possible with my use of terms.\n", "\n", "## Parts of the Decoder-only transformer\n", "The decoder-only transformer is made up of several parts (see the schematic below):\n", "* Embeddings: as with all of the language modeling techniques we have seen in these notebooks, the transformer relies on embeddings to internally represent token meaning. In this case we have two different types:\n", "  - Input embedding: This is the embedding for the new token entering the model. The embedding of the next predicted token becomes the next input to the model, as in RNNs.\n", "  - Output embeddings: In the original encoder-decoder schematic, this is the embedding of the previously generated tokens (shifted right) that are fed back into the decoder.\n", "* Positional Encodings: These are added to the input embeddings to give the model information about the position of each token in the sequence. Like the token embeddings, in this notebook this is just an embedding layer, with one learned vector for each position up to `block_size`.\n", "* Masked Multi-Head Attention: We got acquainted with attention in our exploration of machine translation, where we used it to move between our encoder and decoder. For transformer attention, the model learns the parameters that produce the attention weights, rather than us modeling weights on our embeddings directly. This is where the title of Vaswani et al.'s paper, \"Attention Is All You Need\", comes from. 
In addition, we will also have a \"causal\" mask, which stops each position from attending to later positions, so that the model has to learn to predict the next token from only what came before it.\n", "* Normalization: These layers make sure that the data passing through the network stays normalized and well behaved, so that gradients remain stable during training.\n", "* Feed forward: This layer allows the model to process the information from the attention layer through non-linear transformations, increasing the model's capacity to learn complex patterns.\n", "* Last linear layer: This last linear layer allows the model to make its predictions for the next token in the sequence.\n", "* Softmax: As we have seen since word2vec, this function transforms the logits of a linear layer into a probability distribution from which we can sample to get the index of the predicted next token.\n", "\n", "It is worth noting that a \"Block\" is made up of the masked multi-head attention, the normalization layers and the feed forward layer. This Block can be repeated many times before a prediction is actually made. In fact, much of the difference between smaller and larger models often comes down to how many repetitions of these blocks there are.\n", "\n", "\n", "\n", "*Last note*: This notebook is *heavily* inspired by Andrej Karpathy's fabulous [Let's build GPT: from scratch, in code, spelled out](https://www.youtube.com/watch?v=kCc8FmEb1nY). In fact, it's mostly the same, besides these textual additions for explanation and what data we use. I would highly recommend that you also watch this video. Karpathy does a wonderful job explaining these concepts with code and is a treasure to the deep learning world. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "Let's start off by preparing our data. 
As we have seen, this process does not depend on a certain language, so I will be using Cicero's Letters to Atticus from Perseus.\n", "\n", "Unlike more common GPT implementations, I will not be tokenizing the text with a standard method. Instead, each of our tokens will be a single character. This keeps tokenization simple, since its details could take up an entire course and are not the focus of this notebook. This, however, will seriously hold back our performance. At the end of the notebook, I incorporate a standard tokenizer into model training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# extracting text from XML\n", "from bs4 import BeautifulSoup\n", "import re\n", "\n", "with open(\"atticus.xml\", \"r\") as f:\n", "    soup = BeautifulSoup(f.read(), features=\"xml\")\n", "\n", "letters = []\n", "for d in soup.find_all(\"div2\"):\n", "    dateline = d.dateline.extract().get_text().strip()\n", "    salute = d.salute.extract().get_text().strip()\n", "    text = re.sub(r\"\\s+\", \" \", d.get_text().strip().replace(\"\\n\", \"\"))\n", "    letters.append(dateline + \"\\n\" + salute + \"\\n\" + text)\n", "\n", "text = \"\\n\\n\".join(letters)\n", "print(len(text))\n", "print(text[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \"tokenization\": getting each character for simplicity\n", "chars = sorted(list(set(text)))\n", "vocab_size = len(chars)\n", "print(\"\".join(chars))\n", "print(vocab_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# these are data structures that we can use to easily move between the integer representation of the text and the character representation\n", "stoi = {ch: i for i, ch in enumerate(chars)}\n", "itos = {i: ch for i, ch in enumerate(chars)}\n", "encode = lambda s: [stoi[c] for c in s]\n", "decode = lambda l: \"\".join([itos[i] for i in l])\n", "\n", "print(encode(\"salve 
mundus\"))\n", "print(decode(encode(\"salve mundus\")))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "data = torch.tensor(\n", "    encode(text), dtype=torch.long\n", ")  # turning our encoded data into a tensor\n", "print(data.shape, data.dtype)\n", "print(data[:1000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reserving 10% of the data for validation\n", "n = int(0.9 * len(data))\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "block_size = 8  # small block size to get started\n", "train_data[: block_size + 1]  # first chunk: block_size inputs plus one extra token for the last target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Causal (next-token) language modeling task:** Our goal for this model is to have it predict the next token given all of the tokens in our sequence thus far, as we have seen in other models. Below is how we would set this up for training, also called *collation*." 
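, "\n", "\n", "As a sketch of this objective (the symbols are notation introduced here for illustration, not taken from the source): for a chunk of tokens $x_1, \\dots, x_{T+1}$ with $T$ equal to `block_size`, and a model with parameters $\\theta$, training minimizes the average negative log-probability of each target given its context:\n", "\n", "$$\\mathcal{L}(\\theta) = -\\frac{1}{T} \\sum_{t=1}^{T} \\log p_\\theta(x_{t+1} \\mid x_1, \\dots, x_t)$$\n", "\n", "Each of the $T$ context/target pairs in a chunk contributes one term to this sum."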
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = train_data[:block_size]\n", "y = train_data[1 : block_size + 1]\n", "for t in range(block_size):\n", "    context = x[: t + 1]\n", "    target = y[t]\n", "    print(f\"when input token(s) is/are {context} the target: {target}\")\n", "    print(\n", "        f\"when input character(s) is/are *{decode([c.item() for c in context])}* the target: *{decode([target.item()])}*\"\n", "    )\n", "    print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(itos[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# putting it all together\n", "torch.manual_seed(1337)  # seed for reproducibility\n", "batch_size = 4  # how many starting ids we get initially\n", "block_size = 8  # size of context as before\n", "\n", "\n", "def get_batch(split):\n", "    # generate a small batch of data of inputs x and targets y\n", "    data = train_data if split == \"train\" else val_data  # choosing the right data split\n", "    ix = torch.randint(\n", "        len(data) - block_size, (batch_size,)\n", "    )  # get a random batch of ids\n", "    x = torch.stack(\n", "        [data[i : i + block_size] for i in ix]\n", "    )  # create contexts for each id\n", "    y = torch.stack(\n", "        [data[i + 1 : i + block_size + 1] for i in ix]\n", "    )  # create the targets for each context\n", "    return x, y\n", "\n", "\n", "xb, yb = get_batch(\"train\")\n", "print(\"inputs:\")\n", "print(xb.shape)\n", "print(xb)\n", "print(\"targets:\")\n", "print(yb.shape)\n", "print(yb)\n", "\n", "print(\"-\" * 20)\n", "\n", "for b in range(batch_size):\n", "    for t in range(block_size):\n", "        context = xb[b, : t + 1]\n", "        target = yb[b, t]\n", "        print(f\"when input token(s) is/are {context} the target: {target}\")\n", "        print(\n", "            f\"when input character(s) is/are *{decode([c.item() for c in context])}* the target: *{decode([target.item()])}*\"\n", "        )\n", "        print()\n", "    if b < 
batch_size - 1:\n", "        print(\"-\" * 20)\n", "        print(\"Next set of contexts/targets\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Super simple Language Model\n", "\n", "Before we start looking at the decoder-only transformer, let's just see if our data is working by training a super simple model. Again, this idea/code is taken from Karpathy's video.\n", "\n", "All this model does is use an embedding table of size `vocab_size` x `vocab_size`: each token in the context looks up its row in the table, and that row is used directly as the `logits` for the next token. It then uses cross entropy loss (which applies softmax internally) to compare those logits against the targets. This is particularly weak here, as our \"tokens\" are just single characters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "torch.manual_seed(1337)\n", "\n", "\n", "class BigramLanguageModel(nn.Module):\n", "\n", "    def __init__(self, vocab_size):\n", "        super().__init__()\n", "        # each token directly reads off the logits for the next token from a lookup table\n", "        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n", "\n", "    def forward(self, idx, targets=None):\n", "\n", "        # idx and targets are both (B,T) tensor of integers\n", "        logits = self.token_embedding_table(idx)  # (B,T,C)\n", "\n", "        if targets is None:\n", "            loss = None\n", "        else:\n", "            B, T, C = logits.shape\n", "            logits = logits.view(B * T, C)\n", "            targets = targets.view(B * T)\n", "            loss = F.cross_entropy(logits, targets)\n", "\n", "        return logits, loss\n", "\n", "    def generate(self, idx, max_new_tokens):\n", "        # idx is (B, T) array of indices in the current context\n", "        for _ in range(max_new_tokens):\n", "            # get the predictions\n", "            logits, loss = self(idx)\n", "            # focus only on the last time step\n", "            logits = logits[:, -1, :]  # becomes (B, C)\n", "            # apply softmax to get probabilities\n", "            probs = 
F.softmax(logits, dim=-1)  # (B, C)\n", "            # sample from the distribution\n", "            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n", "            # append sampled index to the running sequence\n", "            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n", "        return idx\n", "\n", "\n", "m = BigramLanguageModel(vocab_size)\n", "logits, loss = m(xb, yb)\n", "print(logits.shape)\n", "print(loss)\n", "\n", "print(\n", "    decode(\n", "        m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[\n", "            0\n", "        ].tolist()\n", "    )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# optimizer for this very simple network\n", "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "for steps in range(1000):  # increase number of steps for \"good\" results...\n", "\n", "    # sample a batch of data\n", "    xb, yb = get_batch(\"train\")\n", "\n", "    # evaluate the loss\n", "    logits, loss = m(xb, yb)\n", "    optimizer.zero_grad(set_to_none=True)\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "print(loss.item())  # much smaller loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\n", "    decode(  # from our \"tokenizer\"\n", "        m.generate(  # from our model\n", "            idx=torch.zeros((1, 1), dtype=torch.long),  # starting context: a single token with index 0\n", "            max_new_tokens=500,\n", "        )[0].tolist()\n", "    )\n", ")  # but these results are not very good" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attention\n", "\n", "Now that we know our data is working and we can use it to reduce the loss in a very simple network, we can increase the complexity by examining the attention mechanism at the core of the transformer architecture.\n", "\n", "As Karpathy tells us, attention is really just a mathematical trick for aggregating weights in a parallelizable and easy to 
compute way. It consists of having two matrices, called $a$ and $b$, and taking their matrix product such that each row of the result is a weighted sum of the rows of $b$. In fact, this is always what happens when we take a matrix product, but this $a$ matrix is special.\n", "\n", "We create $a$ as a matrix with a top right triangle of zeros. This matrix will tell the result matrix which parts of the $b$ matrix to pay **attention** to, and so this triangular shape will tell the model to only look at certain tokens, specifically so that it replicates the order of the sequence of tokens through the time dimension of our training example.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = torch.tril(\n", "    torch.ones(3, 3)\n", ")  # tril creates a matrix with the top triangle made of zeros (masked)\n", "a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# normalize a\n", "a = a / torch.sum(a, 1, keepdim=True)\n", "a  # now the \"weight\" of each row is split up between the non-zero terms" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(22091997)  # for reproducibility\n", "b = torch.randint(0, 10, (3, 2)).float()  # random matrix\n", "b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we take the matrix product of $a$ and $b$ the sizes must line up:\n", "\n", "* $a$ - 3 x 3\n", "* $b$ - 3 x 2\n", "* $c$ (result) - 3 x 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c = a @ b\n", "c  # the weights from a have been distributed across b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# the first row is exactly the same because a tells c to only pay ATTENTION to the first row of b\n", "print(a[0])  # just 1, 0, 0, refers to rows of b\n", "print()\n", "print(b[0])\n", "print(c[0])" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# second row is the average of the first two rows of b\n", "print(\n", "    a[1]\n", ")  # tells c to pay ATTENTION to the first two rows of b but weight your attention by .5 (the normal average)\n", "print()\n", "print(b[1])\n", "print(c[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# second row spelled out without `a` matrix\n", "print(f\"First row of b: {b[0]}\")\n", "print(f\"Second row of b: {b[1]}\")\n", "print(f\"Normal average of first two rows of b: {(b[0] + b[1])/2}\")\n", "print(f\"Second row of c: {c[1]}\")  # same!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# third row is the average of all three rows of b\n", "print(\n", "    a[2]\n", ")  # tells c to pay ATTENTION to all three rows of b but weight your attention by .33 (the normal average)\n", "print()\n", "print(b[2])\n", "print(c[2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# third row spelled out without `a` matrix\n", "print(f\"First row of b: {b[0]}\")\n", "print(f\"Second row of b: {b[1]}\")\n", "print(f\"Third row of b: {b[2]}\")\n", "print(f\"Normal average of all three rows of b: {(b[0] + b[1] + b[2])/3}\")\n", "print(f\"Third row of c: {c[2]}\")  # same!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# consider the following toy example:\n", "torch.manual_seed(1337)\n", "B, T, C = 4, 8, 2  # batch, time, channels\n", "x = torch.randn(B, T, C)\n", "x.shape, x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# We want xbow[b,t] = mean_{i<=t} x[b,i]\n", "# doing this without a matrix - slow with large matrices\n", "xbow = torch.zeros((B, T, C))\n", "for b in range(B):\n", "    for t in range(T):\n", "        xprev = x[b, : t + 1]  # (t+1,C)\n", "        xbow[b, t] = torch.mean(xprev, 0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# version 2: using matrix multiply for a weighted aggregation\n", "wei = torch.tril(torch.ones(T, T))\n", "wei = wei / wei.sum(1, keepdim=True)\n", "xbow2 = wei @ x  # (T, T) is broadcast over the batch: (B, T, T) @ (B, T, C) ----> (B, T, C)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xbow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xbow2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Self-attention\n", "\n", "As we can see, attention allows us to scale the importance of tokens over the sequence of the training example. In our network, we want to learn how best to scale this $a$ matrix so that it is all we need to predict the next token in the sequence. This is where the famous title of the paper that introduced the transformer comes from: \"Attention Is All You Need\".\n", "\n", "To do this scaling, we introduce three new matrices: a *key* matrix (K), a *value* matrix (V) and a *query* matrix (Q). 
These linear projections learn the affinities between different tokens, so that when we apply the $a$ matrix, we do so as a data-driven, non-arbitrary weight aggregation rather than a simple average.\n", "\n", "We can conceptually understand what these linear projections are doing in the schematic and description below:\n", "* Token embedding for a given token: \"What I am\" (`x`, below)\n", "* Key vector for a given token: \"What do I contain\" (`k`, below)\n", "* Query vector for a given token: \"What am I looking for\" (`q`, below)\n", "* Value vector for a given token: \"What I will communicate to you\" (`v`, below)\n", "\n", "![image](https://cdn-images-1.medium.com/max/2000/1*sQP6cxjpXZ_lxDFYYe9Vdw.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Random example\n", "\n", "This example uses a random matrix. Next we'll look at this with a real training example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "B, T, C = 4, 8, 32  # batch, time, channels\n", "x = torch.randn(B, T, C)  # this would be from our training examples in a real model\n", "x.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# creating the key and query projections for each token in the training example\n", "head_size = 16\n", "key = nn.Linear(C, head_size, bias=False)\n", "query = nn.Linear(C, head_size, bias=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# key projection\n", "k = key(x)\n", "k.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# query projection\n", "q = query(x)\n", "q.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# determine affinities between what each token wants (query) and what each token has (key)\n", "# each element is a score between one token's query and another token's key\n", "weights = q @ 
k.transpose(-2, -1)  # transpose the last two dims of k so the matmul works: (B,T,hs) @ (B,hs,T) -> (B,T,T)\n", "weights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tril = torch.tril(torch.ones(T, T))  # the triangular matrix used as a mask\n", "weights = weights.masked_fill(\n", "    tril == 0, float(\"-inf\")\n", ")  # the affinities for future positions are now masked out with -inf\n", "weights.shape, weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = F.softmax(weights, dim=-1)  # pass through softmax to get a prob distribution\n", "weights.shape, weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# value projection\n", "value = nn.Linear(C, head_size, bias=False)\n", "v = value(x)\n", "v.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# last matmul to apply the affinities to the values\n", "out = weights @ v\n", "out.shape, out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out[0], out[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Real training example\n", "Now we can see what this looks like with a real training example." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", "block_size = 8\n", "x, y = get_batch(\"train\")\n", "x.shape, y.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_embd = 64\n", "token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", "position_embedding_table = nn.Embedding(block_size, n_embd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tok_emb = token_embedding_table(x)  # token embeddings\n", "pos_emb = position_embedding_table(torch.arange(block_size))  # position embeddings\n", "tok_emb.shape, pos_emb.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = tok_emb + pos_emb  # elementwise addition to create x\n", "x.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# initialize our projections\n", "head_size = 16\n", "key = nn.Linear(n_embd, head_size, bias=False)\n", "query = nn.Linear(n_embd, head_size, bias=False)\n", "value = nn.Linear(n_embd, head_size, bias=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "k = key(x)\n", "q = query(x)\n", "k.shape, q.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = q @ k.transpose(-2, -1) * head_size**-0.5  # scaling by 1/sqrt(head_size)\n", "weights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tril = torch.tril(torch.ones(block_size, block_size))\n", "weights = weights.masked_fill(tril == 0, float(\"-inf\"))\n", "weights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights = F.softmax(weights, dim=-1)\n", "weights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = value(x)\n", "v.shape" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out = weights @ v\n", "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notes from Karpathy:\n", "- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n", "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n", "- Examples across the batch dimension are of course processed completely independently and never \"talk\" to each other.\n", "- In an \"encoder\" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n", "- \"Self-attention\" just means that the keys and values are produced from the same source as the queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).\n", "- \"Scaled\" attention additionally scales `wei` by 1/sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, `wei` will have unit variance too, and softmax will stay diffuse and not saturate too much.\n", "\n", "To finish this illustration, I will take this example to a calculation of loss for this training example." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FeedFoward(nn.Module):\n", "    \"\"\"a simple linear layer followed by a non-linearity\"\"\"\n", "\n", "    def __init__(self, n_embd, dropout=0.0):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "\n", "ffwd = FeedFoward(n_embd)\n", "lm_head = nn.Linear(n_embd, vocab_size)\n", "\n", "# note: this passes x (n_embd wide) rather than `out`, since a single head's `out` is only head_size wide\n", "x = ffwd(x)\n", "logits = lm_head(x)\n", "logits.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "B, T, C = logits.shape\n", "logits = logits.view(B * T, C)\n", "targets = y.view(B * T)  # targets from above\n", "loss = F.cross_entropy(logits, targets)\n", "loss.item()  # loss for this training example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we can calculate a loss value for this model, we could call `loss.backward()` and get the gradients needed to take a step with our optimizer (`optimizer.step()`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling\n", "\n", "Seeing self-attention gave us the tools needed to fully implement the decoder-only transformer. Below are all of the modules that we wrote out above." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Head\n", "\n", "A single \"head\" of attention is just what we saw above with some slight alterations." 
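, "\n", "\n", "As a compact sketch of what one head computes (the symbols are notation introduced here for illustration, not taken from the source): let $X$ be the $T \\times n_{embd}$ matrix of embeddings for one example, let $W_Q$, $W_K$ and $W_V$ be the learned matrices of the `query`, `key` and `value` projections, and let $d_h$ be the `head_size`. Following the 1/sqrt(head_size) scaling from the notes above, and leaving out dropout:\n", "\n", "$$\\mathrm{head}(X) = \\mathrm{softmax}\\left(\\frac{(X W_Q)(X W_K)^\\top}{\\sqrt{d_h}} + M\\right) X W_V$$\n", "\n", "where $M_{ij} = 0$ if $j \\le i$ and $M_{ij} = -\\infty$ otherwise. Adding $M$ before the softmax plays the same role as `masked_fill` with the `tril` matrix: position $i$ gives zero weight to every later position $j > i$."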
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Head(nn.Module):\n", "    \"\"\"one head of self-attention\"\"\"\n", "\n", "    def __init__(self, head_size, n_embd=64, dropout=0.0):\n", "        super().__init__()\n", "        self.key = nn.Linear(n_embd, head_size, bias=False)\n", "        self.query = nn.Linear(n_embd, head_size, bias=False)\n", "        self.value = nn.Linear(n_embd, head_size, bias=False)\n", "        self.register_buffer(\"tril\", torch.tril(torch.ones(block_size, block_size)))\n", "\n", "        self.dropout = nn.Dropout(dropout)  # standard dropout\n", "\n", "    def forward(self, x):\n", "        B, T, C = x.shape\n", "        k = self.key(x)  # (B,T,hs)\n", "        q = self.query(x)  # (B,T,hs)\n", "        # compute attention scores (\"affinities\"), scaled by 1/sqrt(head_size)\n", "        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n", "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float(\"-inf\"))  # (B, T, T)\n", "        wei = F.softmax(wei, dim=-1)  # (B, T, T)\n", "        wei = self.dropout(wei)\n", "        # perform the weighted aggregation of the values\n", "        v = self.value(x)  # (B,T,hs)\n", "        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n", "        return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ex_head = Head(head_size=16)\n", "ex_head" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "B, T, C = 4, 8, 64  # pos + tok embeddings\n", "x = torch.randn(B, T, C)\n", "out = ex_head(x)\n", "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multihead\n", "\n", "Now we can group these heads together into multi-headed attention. This module is able to do the attention calculation in parallel. That's all that's different about it." 
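, "\n", "\n", "In symbols, as a sketch (the notation is introduced here for illustration, not taken from the source): with $h$ heads, each mapping the $T \\times n_{embd}$ input $X$ to a $T \\times d_h$ output, the module concatenates the head outputs along the feature dimension and applies one more linear projection with weights $W_O$ and bias $b_O$ (dropout left out):\n", "\n", "$$\\mathrm{MultiHead}(X) = \\left[\\mathrm{head}_1(X) \\; ; \\; \\dots \\; ; \\; \\mathrm{head}_h(X)\\right] W_O + b_O$$\n", "\n", "When $h \\cdot d_h = n_{embd}$, as in the example with 4 heads of size 16 and `n_embd=64`, the output has the same shape as the input."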
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n", "    \"\"\"multiple heads of self-attention in parallel\"\"\"\n", "\n", "    def __init__(self, num_heads, head_size, n_embd=64, dropout=0.0):\n", "        super().__init__()\n", "        self.heads = nn.ModuleList(\n", "            [Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)]\n", "        )\n", "        # the concatenated heads are num_heads * head_size wide\n", "        self.proj = nn.Linear(num_heads * head_size, n_embd)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, x):\n", "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n", "        out = self.dropout(self.proj(out))\n", "        return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ex_multihead = MultiHeadAttention(num_heads=4, head_size=16)\n", "ex_multihead" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out = ex_multihead(x)\n", "out.shape  # 16 * 4 = 64" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Feed forward and Block\n", "\n", "As we saw above, the feed forward layer adds a nonlinearity, which allows the model to learn more complex features. In a normal RNN, this layer takes as input the actual token embeddings, but in the transformer, it takes in the output of the attention heads.\n", "\n", "The Block just wraps all of what we've seen so far in a single module." 
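, "\n", "\n", "As a sketch of how the pieces fit together (the notation is introduced here for illustration, not taken from the source): the feed forward layer is two linear maps with a ReLU in between, and the Block applies attention and feed forward in turn, each on a layer-normalized input and each added back onto its input as a residual (dropout left out):\n", "\n", "$$\\mathrm{FFN}(Z) = \\mathrm{ReLU}(Z W_1 + b_1) W_2 + b_2$$\n", "\n", "$$X' = X + \\mathrm{MultiHead}(\\mathrm{LN}_1(X)), \\qquad Y = X' + \\mathrm{FFN}(\\mathrm{LN}_2(X'))$$\n", "\n", "Here $W_1$ expands from $n_{embd}$ to $4 \\cdot n_{embd}$ features and $W_2$ projects back down, so $Y$ has the same shape as $X$ and Blocks can be stacked."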
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FeedFoward(nn.Module):\n", "    \"\"\"a simple linear layer followed by a non-linearity\"\"\"\n", "\n", "    def __init__(self, n_embd, dropout=0.0):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Block(nn.Module):\n", "    \"\"\"Transformer block: communication followed by computation\"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, dropout=0.0):\n", "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n", "        super().__init__()\n", "        head_size = n_embd // n_head\n", "        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)\n", "        self.ffwd = FeedFoward(n_embd, dropout=dropout)\n", "        self.ln1 = nn.LayerNorm(n_embd)\n", "        self.ln2 = nn.LayerNorm(n_embd)\n", "\n", "    def forward(self, x):\n", "        # pre-norm residual connections around attention and feed forward\n", "        x = x + self.sa(self.ln1(x))\n", "        x = x + self.ffwd(self.ln2(x))\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Full model\n", "\n", "We have seen all of the foundations of the transformer, so now we can put it all together in a single model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class Transformer(nn.Module):\n",
"\n",
"    def __init__(self, n_embd, n_head, n_layer, device):\n",
"        super().__init__()\n",
"        # token and position embedding tables; vocab_size and block_size are globals defined earlier\n",
"        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
"        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
"        self.blocks = nn.Sequential(\n",
"            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]\n",
"        )\n",
"        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm\n",
"        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
"        self.device = device\n",
"\n",
"    def forward(self, idx, targets=None):\n",
"        B, T = idx.shape\n",
"\n",
"        # idx and targets are both (B,T) tensor of integers\n",
"        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n",
"        pos_emb = self.position_embedding_table(\n",
"            torch.arange(T, device=self.device)\n",
"        )  # (T,C)\n",
"        x = tok_emb + pos_emb  # (B,T,C)\n",
"        x = self.blocks(x)  # (B,T,C)\n",
"        x = self.ln_f(x)  # (B,T,C)\n",
"        logits = self.lm_head(x)  # (B,T,vocab_size)\n",
"\n",
"        if targets is None:\n",
"            loss = None\n",
"        else:\n",
"            B, T, C = logits.shape\n",
"            logits = logits.view(B * T, C)\n",
"            targets = targets.view(B * T)\n",
"            loss = F.cross_entropy(logits, targets)\n",
"\n",
"        return logits, loss\n",
"\n",
"    def generate(self, idx, max_new_tokens):\n",
"        # idx is (B, T) array of indices in the current context\n",
"        for _ in range(max_new_tokens):\n",
"            # crop idx to the last block_size tokens\n",
"            idx_cond = idx[:, -block_size:]\n",
"            # get the predictions\n",
"            logits, loss = self(idx_cond)\n",
"            # focus only on the last time step\n",
"            logits = logits[:, -1, :]  # becomes (B, C)\n",
"            # apply softmax to get probabilities\n",
"            probs = F.softmax(logits, dim=-1)  # (B, C)\n",
"            # sample from the distribution\n",
"            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n",
"            # append sampled index 
to the running sequence\n",
"            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n",
"        return idx"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"## Initial Training\n",
"\n",
"Finally, we can start training with our Latin dataset.\n",
"\n",
"Two things are not typical about this training:\n",
"1. Our tokenizer is still primitive\n",
"2. We are running this on the CPU"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"# hyperparameters\n",
"batch_size = 16  # how many independent sequences will we process in parallel\n",
"block_size = 32  # what is the maximum context length for predictions\n",
"max_iters = 5000  # number of training iterations (one batch each, not full epochs)\n",
"eval_interval = 100  # every this many iterations we look at the validation set\n",
"learning_rate = 1e-3  # learning rate for the optimizer\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # what device to use\n",
"eval_iters = 200  # how many iterations in the evaluation\n",
"n_embd = 64  # embedding size\n",
"n_head = 4  # attention heads\n",
"n_layer = 4  # how many blocks\n",
"dropout = 0.0  # amount of dropout\n",
"# ------------\n",
"\n",
"model = Transformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, device=device)\n",
"m = model.to(device)\n",
"print(sum(p.numel() for p in m.parameters()) / 1e6, \"M parameters\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# function for estimating the loss during evaluation\n",
"@torch.no_grad()\n",
"def estimate_loss():\n",
"    out = {}\n",
"    model.eval()\n",
"    for split in [\"train\", \"val\"]:\n",
"        losses = torch.zeros(eval_iters)\n",
"        for k in range(eval_iters):\n",
"            X, Y = get_batch(split)\n",
"            # move the batch to the same device as the model (works on CPU as well as GPU)\n",
"            X, Y = X.to(device), Y.to(device)\n",
"            logits, loss = model(X, Y)\n",
"            losses[k] = loss.item()\n",
"        out[split] = losses.mean()\n",
"    model.train()\n",
"    return out"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"train_losses = []\n",
"valid_losses = []\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
"for epoch in range(max_iters):\n",
"    # periodically estimate the loss on the train and validation splits\n",
"    if epoch % eval_interval == 0 or epoch == max_iters - 1:\n",
"        losses = estimate_loss()\n",
"        train_losses.append(losses[\"train\"])\n",
"        valid_losses.append(losses[\"val\"])\n",
"        print(\n",
"            f\"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\"\n",
"        )\n",
"\n",
"    # sample a batch and take one optimization step\n",
"    xb, yb = get_batch(\"train\")\n",
"    xb, yb = xb.to(device), yb.to(device)\n",
"\n",
"    logits, loss = model(xb, yb)\n",
"    optimizer.zero_grad(set_to_none=True)\n",
"    loss.backward()\n",
"    optimizer.step()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(train_losses, label=\"train\")\n",
"plt.plot(valid_losses, label=\"valid\")\n",
"plt.legend()\n",
"plt.title(\"Loss\")\n",
"plt.xlabel(\"Evaluation step (every eval_interval iterations)\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
"print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"## More typical training\n",
"\n",
"Unlike the run above, this section trains the transformer with a more typical tokenizer and on the GPU."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# tiktoken: OpenAI's BPE tokenizer library; below we load cl100k_base, the encoding used by GPT-3.5 and GPT-4\n",
"!pip install tiktoken -Uq"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# extracting text from XML\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"\n",
"with open(\"atticus.xml\", \"r\") as f:\n",
"    soup = BeautifulSoup(f.read(), features=\"xml\")\n",
"\n",
"letters = []\n",
"for d in soup.find_all(\"div2\"):\n",
"    # pull the dateline and salutation out of each letter, then flatten the remaining body text\n",
"    dateline = d.dateline.extract().get_text().strip()\n",
"    salute = d.salute.extract().get_text().strip()\n",
"    text = re.sub(r\"\\s+\", \" \", d.get_text().strip().replace(\"\\n\", \"\"))\n",
"    letters.append(dateline + \"\\n\" + salute + \"\\n\" + text)\n",
"\n",
"text = \"\\n\\n\".join(letters)\n",
"print(len(text))\n",
"text[:100]"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import tiktoken\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
"\n",
"# example\n",
"tokenizer.encode(text[:100])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"data = torch.tensor(\n",
"    tokenizer.encode(text), dtype=torch.long\n",
")  # turning our encoded data into a tensor, as above\n",
"print(data.shape, data.dtype)\n",
"print(data[:100])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"vocab_size = tokenizer.n_vocab  # number of all unique tokens in the tokenizer\n",
"vocab_size"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# reserving 10% of the data for validation, as above\n",
"n = int(0.9 * len(data))\n",
"train_data = data[:n]\n",
"val_data = data[n:]"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"block_size = 8  # small block size to get started, as above\n",
"train_data[: block_size + 1]  # first 
block_size chunk, as above"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"x = train_data[:block_size]\n",
"y = train_data[1 : block_size + 1]\n",
"for t in range(block_size):\n",
"    context = x[: t + 1]\n",
"    target = y[t]\n",
"    print(f\"when input token(s) is/are {context} the target: {target}\")\n",
"    print(\n",
"        f\"when input text is *{tokenizer.decode([c.item() for c in context])}* the target: *{tokenizer.decode([target.item()])}*\"\n",
"    )\n",
"    print()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# putting it all together\n",
"torch.manual_seed(1337)  # seed for reproducibility\n",
"batch_size = 4  # how many starting ids we get initially\n",
"block_size = 8  # size of context as before\n",
"\n",
"\n",
"def get_batch(split):\n",
"    # generate a small batch of data of inputs x and targets y\n",
"    data = train_data if split == \"train\" else val_data  # choosing the right data split\n",
"    ix = torch.randint(\n",
"        len(data) - block_size, (batch_size,)\n",
"    )  # get a random batch of ids\n",
"    x = torch.stack(\n",
"        [data[i : i + block_size] for i in ix]\n",
"    )  # create contexts for each id\n",
"    y = torch.stack(\n",
"        [data[i + 1 : i + block_size + 1] for i in ix]\n",
"    )  # create the targets for each context\n",
"    return x, y\n",
"\n",
"\n",
"xb, yb = get_batch(\"train\")\n",
"print(\"inputs:\")\n",
"print(xb.shape)\n",
"print(xb)\n",
"print(\"targets:\")\n",
"print(yb.shape)\n",
"print(yb)\n",
"\n",
"print(\"-\" * 20)\n",
"\n",
"for b in range(batch_size):\n",
"    for t in range(block_size):\n",
"        context = xb[b, : t + 1]\n",
"        target = yb[b, t]\n",
"        print(f\"when input token(s) is/are {context} the target: {target}\")\n",
"        print(\n",
"            f\"when input text is *{tokenizer.decode([c.item() for c in context])}* the target: *{tokenizer.decode([target.item()])}*\"\n",
"        )\n",
"        print()\n",
"    if b < batch_size - 1:\n",
"        print(\"-\" * 20)\n",
"        print(\"Next set of contexts/targets\")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"### Model\n",
"\n",
"The same model code as above, repeated here so that this section can be run on its own."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F  # needed for F.softmax and F.cross_entropy below\n",
"\n",
"\n",
"class Head(nn.Module):\n",
"    \"\"\"one head of self-attention\"\"\"\n",
"\n",
"    def __init__(self, head_size, n_embd=64, dropout=0.0):\n",
"        super().__init__()\n",
"        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
"        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
"        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
"        self.register_buffer(\"tril\", torch.tril(torch.ones(block_size, block_size)))\n",
"\n",
"        self.dropout = nn.Dropout(dropout)  # standard dropout\n",
"\n",
"    def forward(self, x):\n",
"        B, T, C = x.shape\n",
"        k = self.key(x)  # (B,T,C)\n",
"        q = self.query(x)  # (B,T,C)\n",
"        # compute attention scores (\"affinities\")\n",
"        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
"        wei = wei.masked_fill(self.tril[:T, :T] == 0, float(\"-inf\"))  # (B, T, T)\n",
"        wei = F.softmax(wei, dim=-1)  # (B, T, T)\n",
"        wei = self.dropout(wei)\n",
"        # perform the weighted aggregation of the values\n",
"        v = self.value(x)  # (B,T,C)\n",
"        out = wei @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
"        return out\n",
"\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
"    \"\"\"multiple heads of self-attention in parallel\"\"\"\n",
"\n",
"    def __init__(self, num_heads, head_size, n_embd=64, dropout=0.0):\n",
"        super().__init__()\n",
"        self.heads = nn.ModuleList([Head(head_size, n_embd) for _ in range(num_heads)])\n",
"        self.proj = nn.Linear(n_embd, n_embd)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
"        out = self.dropout(self.proj(out))\n",
"        return out\n",
"\n",
"\n",
"class FeedFoward(nn.Module):\n",
"    \"\"\"two linear layers with a non-linearity in between\"\"\"\n",
"\n",
"    def __init__(self, n_embd, dropout=0.0):\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(\n",
"            nn.Linear(n_embd, 4 * n_embd),\n",
"            nn.ReLU(),\n",
"            nn.Linear(4 * n_embd, n_embd),\n",
"            nn.Dropout(dropout),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.net(x)\n",
"\n",
"\n",
"class Block(nn.Module):\n",
"    \"\"\"Transformer block: communication followed by computation\"\"\"\n",
"\n",
"    def __init__(self, n_embd, n_head, dropout=0.0):\n",
"        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
"        super().__init__()\n",
"        head_size = n_embd // n_head\n",
"        self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout)\n",
"        self.ffwd = FeedFoward(n_embd, dropout=dropout)\n",
"        self.ln1 = nn.LayerNorm(n_embd)\n",
"        self.ln2 = nn.LayerNorm(n_embd)\n",
"\n",
"    def forward(self, x):\n",
"        x = x + self.sa(self.ln1(x))\n",
"        x = x + self.ffwd(self.ln2(x))\n",
"        return x\n",
"\n",
"\n",
"class Transformer(nn.Module):\n",
"\n",
"    def __init__(self, n_embd, n_head, n_layer, device):\n",
"        super().__init__()\n",
"        # token and position embedding tables; vocab_size and block_size are globals defined earlier\n",
"        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
"        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
"        self.blocks = nn.Sequential(\n",
"            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]\n",
"        )\n",
"        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm\n",
"        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
"        self.device = device\n",
"\n",
"    def forward(self, idx, targets=None):\n",
"        B, T = idx.shape\n",
"\n",
"        # idx and targets are both (B,T) tensor of integers\n",
"        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n",
"        pos_emb = self.position_embedding_table(\n",
"            torch.arange(T, device=self.device)\n",
"        )  # (T,C)\n",
"        x = tok_emb + pos_emb  # (B,T,C)\n",
"        x = self.blocks(x)  # (B,T,C)\n",
"        x = self.ln_f(x)  # 
(B,T,C)\n",
"        logits = self.lm_head(x)  # (B,T,vocab_size)\n",
"\n",
"        if targets is None:\n",
"            loss = None\n",
"        else:\n",
"            B, T, C = logits.shape\n",
"            logits = logits.view(B * T, C)\n",
"            targets = targets.view(B * T)\n",
"            loss = F.cross_entropy(logits, targets)\n",
"\n",
"        return logits, loss\n",
"\n",
"    def generate(self, idx, max_new_tokens):\n",
"        # idx is (B, T) array of indices in the current context\n",
"        for _ in range(max_new_tokens):\n",
"            # crop idx to the last block_size tokens\n",
"            idx_cond = idx[:, -block_size:]\n",
"            # get the predictions\n",
"            logits, loss = self(idx_cond)\n",
"            # focus only on the last time step\n",
"            logits = logits[:, -1, :]  # becomes (B, C)\n",
"            # apply softmax to get probabilities\n",
"            probs = F.softmax(logits, dim=-1)  # (B, C)\n",
"            # sample from the distribution\n",
"            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n",
"            # append sampled index to the running sequence\n",
"            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n",
"        return idx"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"### Training\n",
"\n",
"Here we train the model with a full tokenizer. This is deeply inefficient, since the tokenizer's vocabulary is far larger than our small Latin corpus needs and the embedding table and output layer grow with `vocab_size`, but it is a useful example."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"\n",
"# hyperparameters\n",
"batch_size = 16  # how many independent sequences will we process in parallel\n",
"block_size = 64  # what is the maximum context length for predictions\n",
"max_iters = 10000  # number of training iterations (one batch each, not full epochs)\n",
"eval_interval = 100  # every this many iterations we look at the validation set\n",
"learning_rate = 5e-5  # learning rate for the optimizer\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # what device to use\n",
"eval_iters = 200  # how many iterations in the evaluation\n",
"n_embd = 128  # embedding size\n",
"n_head = 16  # attention heads\n",
"n_layer = 8  # how many blocks\n",
"dropout = 0.0  # amount of dropout\n",
"# ------------\n",
"\n",
"model = Transformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, device=device)\n",
"m = model.to(device)\n",
"print(sum(p.numel() for p in m.parameters()) / 1e6, \"M parameters\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# function for estimating the loss during evaluation\n",
"@torch.no_grad()\n",
"def estimate_loss():\n",
"    out = {}\n",
"    model.eval()\n",
"    for split in [\"train\", \"val\"]:\n",
"        losses = torch.zeros(eval_iters)\n",
"        for k in range(eval_iters):\n",
"            X, Y = get_batch(split)\n",
"            X, Y = X.to(device), Y.to(device)\n",
"            logits, loss = model(X, Y)\n",
"            losses[k] = loss.item()\n",
"        out[split] = losses.mean()\n",
"    model.train()\n",
"    return out"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"train_losses = []\n",
"valid_losses = []\n",
"\n",
"## training\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
"for epoch in range(max_iters):\n",
"    # periodically estimate the loss on the train and validation splits\n",
"    if epoch % eval_interval == 0 or epoch == max_iters - 1:\n",
"        losses = estimate_loss()\n",
"        train_losses.append(losses[\"train\"])\n",
"        valid_losses.append(losses[\"val\"])\n",
"        print(\n",
"            f\"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\"\n",
"        )\n",
"\n",
"    # sample a batch and take one optimization step\n",
"    xb, yb = get_batch(\"train\")\n",
"    xb, yb = xb.to(device), yb.to(device)\n",
"\n",
"    logits, loss = model(xb, yb)\n",
"    optimizer.zero_grad(set_to_none=True)\n",
"    loss.backward()\n",
"    optimizer.step()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(train_losses, label=\"train\")\n",
"plt.plot(valid_losses, label=\"valid\")\n",
"plt.legend()\n",
"plt.title(\"Loss\")\n",
"plt.xlabel(\"Evaluation step (every eval_interval iterations)\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
"print(tokenizer.decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# prompt the model with the typical opening of a letter's dateline\n",
"context_opening = torch.tensor(\n",
"    tokenizer.encode(\"Scr. Romae\"), dtype=torch.long, device=device\n",
").unsqueeze(dim=0)\n",
"print(tokenizer.decode(m.generate(context_opening, max_new_tokens=2000)[0].tolist()))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "TPLzDz-6BkmT", "PJG1LVY7Bm5h", "KParnKvQBokb" ], "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 0 }