
Let's build GPT: from scratch, in code, spelled out.


Chapters

0:00 intro: ChatGPT, Transformers, nanoGPT, Shakespeare
7:52 reading and exploring the data
9:28 tokenization, train/val split
14:27 data loader: batches of chunks of data
22:11 simplest baseline: bigram language model, loss, generation
34:53 training the bigram model
38:00 port our code to a script
42:13 version 1: averaging past context with for loops, the weakest form of aggregation
47:11 the trick in self-attention: matrix multiply as weighted aggregation
51:54 version 2: using matrix multiply
54:42 version 3: adding softmax
58:26 minor code cleanup
60:18 positional encoding
62:00 THE CRUX OF THE VIDEO: version 4: self-attention
71:38 note 1: attention as communication
72:46 note 2: attention has no notion of space, operates over sets
73:40 note 3: there is no communication across batch dimension
74:14 note 4: encoder blocks vs. decoder blocks
75:39 note 5: attention vs. self-attention vs. cross-attention
76:56 note 6: "scaled" self-attention. why divide by sqrt(head_size)
79:11 inserting a single self-attention block to our network
81:59 multi-headed self-attention
84:25 feedforward layers of transformer block
86:48 residual connections
92:51 layernorm (and its relationship to our previous batchnorm)
97:49 scaling up the model! creating a few variables. adding dropout
102:39 encoder vs. decoder vs. both (?) Transformers
106:22 super quick walkthrough of nanoGPT, batched multi-headed self-attention
108:53 back to ChatGPT, GPT-3, pretraining vs. finetuning, RLHF
114:32 conclusions

Whisper Transcript

00:00:00.000 | Hi everyone. So by now you have probably heard of ChatGPT. It has taken the world and the AI
00:00:05.280 | community by storm and it is a system that allows you to interact with an AI and give it text-based
00:00:11.840 | tasks. So for example, we can ask ChatGPT to write us a small haiku about how important it is that
00:00:16.960 | people understand AI and then they can use it to improve the world and make it more prosperous.
00:00:20.800 | So when we run this, AI knowledge brings prosperity for all to see, embrace its power.
00:00:28.080 | Okay, not bad. And so you could see that ChatGPT went from left to right and generated all these
00:00:33.200 | words sequentially. Now I asked it already the exact same prompt a little bit earlier and it
00:00:39.760 | generated a slightly different outcome. AI's power to grow, ignorance holds us back, learn, prosperity
00:00:46.080 | waits. So pretty good in both cases and slightly different. So you can see that ChatGPT is a
00:00:51.440 | probabilistic system and for any one prompt it can give us multiple answers sort of replying to it.
00:00:58.160 | Now this is just one example of a prompt. People have come up with many, many examples and there
00:01:02.640 | are entire websites that index interactions with ChatGPT and so many of them are quite humorous.
00:01:09.360 | Explain HTML to me like I'm a dog, write release notes for chess 2, write a note about Elon Musk
00:01:15.760 | buying Twitter, and so on. So as an example, please write a breaking news article about a
00:01:22.080 | leaf falling from a tree. In a shocking turn of events, a leaf has fallen from a tree in the local
00:01:27.920 | park. Witnesses report that the leaf, which was previously attached to a branch of a tree, detached
00:01:32.800 | itself and fell to the ground. Very dramatic. So you can see that this is a pretty remarkable system
00:01:38.560 | and it is what we call a language model because it models the sequence of words or characters or
00:01:46.320 | tokens more generally and it knows how certain words follow each other in English language.
00:01:52.000 | And so from its perspective, what it is doing is it is completing the sequence. So I give it the
00:01:57.840 | start of a sequence and it completes the sequence with the outcome. And so it's a language model
00:02:03.280 | in that sense. Now I would like to focus on the under the hood components of what makes
00:02:10.080 | ChatGPT work. So what is the neural network under the hood that models the sequence of these words?
00:02:16.400 | And that comes from this paper called "Attention is All You Need" in 2017, a landmark paper
00:02:22.880 | in AI that proposed the transformer architecture.
00:02:27.600 | So GPT is short for Generative Pre-trained Transformer. So the transformer is the neural
00:02:35.440 | net that actually does all the heavy lifting under the hood. It comes from this paper in 2017.
00:02:41.040 | Now if you read this paper, this reads like a pretty random machine translation paper and that's
00:02:46.000 | because I think the authors didn't fully anticipate the impact that the transformer would have on the
00:02:49.920 | field. And this architecture that they produced in the context of machine translation in their case
00:02:55.600 | actually ended up taking over the rest of AI in the next five years after. And so this architecture
00:03:02.320 | with minor changes was copy pasted into a huge amount of applications in AI in more recent years.
00:03:08.960 | And that includes the core of ChatGPT. Now, what I'd like to do is build out
00:03:15.280 | something like ChatGPT, but of course we're not going to be able to reproduce
00:03:20.560 | ChatGPT. This is a very serious production-grade system. It is trained on a good chunk of the internet,
00:03:27.760 | and then there are a lot of pre-training and fine-tuning stages to it. And so it's very complicated.
00:03:32.960 | What I'd like to focus on is just to train a transformer based language model. And in our case
00:03:39.200 | it's going to be a character-level language model. I still think that is very educational with
00:03:44.480 | respect to how these systems work. So I don't want to train on a chunk of the internet. We need a smaller
00:03:49.520 | data set. In this case, I propose that we work with my favorite toy data set. It's called Tiny
00:03:54.800 | Shakespeare. And what it is is basically it's a concatenation of all of the works of Shakespeare
00:04:00.240 | in my understanding. And so this is all of Shakespeare in a single file. This file is about
00:04:05.760 | one megabyte and it's just all of Shakespeare. And what we are going to do now is we're going
00:04:11.440 | to basically model how these characters follow each other. So for example, given a chunk of
00:04:16.560 | these characters like this, given some context of characters in the past, the transformer neural
00:04:23.440 | network will look at the characters that I've highlighted and is going to predict that G
00:04:27.680 | is likely to come next in the sequence. And it's going to do that because we're going to train that
00:04:32.000 | transformer on Shakespeare. And it's just going to try to produce character sequences that look
00:04:37.840 | like this. And in that process, it's going to model all the patterns inside this data.
00:04:42.160 | So once we've trained the system, I'd just like to give you a preview.
00:04:46.160 | We can generate infinite Shakespeare. And of course, it's a fake thing that looks kind of like
00:04:52.640 | Shakespeare. Apologies for there's some jank that I'm not able to resolve in here, but
00:05:03.760 | you can see how this is going character by character. And it's kind of like predicting
00:05:07.600 | Shakespeare-like language. So "Verily, my lord, the sites have left thee again, the king,
00:05:14.960 | coming with my curses with precious pale." And then "Tranio says something else," et cetera.
00:05:20.880 | And this is just coming out of the transformer in a very similar manner as it would come out
00:05:25.200 | in ChatGPT. In our case, it's character by character; in ChatGPT, it's coming out at the token-by-token
00:05:32.240 | level. And tokens are these sort of like little sub-word pieces. So they're not word level. They're
00:05:36.880 | kind of like word chunk level. And now I've already written this entire code to train these
00:05:45.040 | transformers. And it is in a GitHub repository that you can find, and it's called nanoGPT.
00:05:51.600 | So nanoGPT is a repository that you can find on my GitHub. And it's a repository for training
00:05:57.680 | transformers on any given text. And what I think is interesting about it, because there's many ways
00:06:03.440 | to train transformers, but this is a very simple implementation. So it's just two files of 300
00:06:08.560 | lines of code each. One file defines the GPT model, the transformer, and one file trains it
00:06:14.400 | on some given text dataset. And here I'm showing that if you train it on the OpenWebText dataset,
00:06:19.600 | which is a fairly large dataset of web pages, then I reproduce the performance of GPT-2.
00:06:26.320 | So GPT-2 is an earlier version of OpenAI's GPT, from 2019. And I've only so far
00:06:34.400 | reproduced the smallest 124 million parameter model. But basically, this is just proving that
00:06:39.280 | the code base is correctly arranged. And I'm able to load the neural network weights that OpenAI has
00:06:45.440 | released later. So you can take a look at the finished code here in nanoGPT. But what I would
00:06:51.200 | like to do in this lecture is I would like to basically write this repository from scratch.
00:06:57.040 | So we're going to begin with an empty file, and we're going to define a transformer piece by piece.
00:07:02.160 | We're going to train it on the tiny Shakespeare dataset, and we'll see how we can then generate
00:07:08.640 | infinite Shakespeare. And of course, this can copy paste to any arbitrary text dataset that you like.
00:07:14.000 | But my goal really here is to just make you understand and appreciate how under the hood
00:07:19.600 | ChatGPT works. And really, all that's required is a proficiency in Python and some basic
00:07:27.040 | understanding of calculus and statistics. And it would help if you also see my previous videos
00:07:32.960 | on the same YouTube channel, in particular, my Make More series, where I define smaller and
00:07:40.400 | simpler neural network language models. So multilayered perceptrons and so on. It really
00:07:45.520 | introduces the language modeling framework. And then here in this video, we're going to focus on
00:07:50.000 | the transformer neural network itself. Okay, so I created a new Google Colab Jupyter notebook here.
00:07:56.240 | And this will allow me to later easily share this code that we're going to develop together
00:08:00.960 | with you so you can follow along. So this will be in the video description later. Now, here I've
00:08:06.800 | just done some preliminaries. I downloaded the dataset, the tiny Shakespeare dataset at this URL,
00:08:11.680 | and you can see that it's about a one megabyte file. Then here I opened the input.txt file and
00:08:17.120 | just read in all the text as a string. And we see that we are working with one million characters
00:08:21.920 | roughly. And the first 1000 characters, if we just print them out, are basically what you would
00:08:26.960 | expect. This is the first 1000 characters of the tiny Shakespeare dataset, roughly up to here.
00:08:32.400 | So, so far, so good. Next, we're going to take this text. And the text is a sequence of characters
00:08:39.200 | in Python. So when I call the set constructor on it, I'm just going to get the set of all the
00:08:45.360 | characters that occur in this text. And then I call list on that to create a list of those
00:08:51.600 | characters instead of just a set so that I have an ordering, an arbitrary ordering. And then I sort
00:08:57.040 | that. So basically, we get just all the characters that occur in the entire dataset, and they're
00:09:01.600 | sorted. Now, the number of them is going to be our vocabulary size. These are the possible elements
00:09:07.280 | of our sequences. And we see that when I print here the characters, there's 65 of them in total.
00:09:14.160 | There's a space character, and then all kinds of special characters, and then capitals and
00:09:19.760 | lowercase letters. So that's our vocabulary. And that's the sort of like possible characters that
00:09:25.440 | the model can see or emit. Okay, so next, we would like to develop some strategy to tokenize
00:09:32.080 | the input text. Now, when people say tokenize, they mean convert the raw text as a string
00:09:38.640 | to some sequence of integers according to some vocabulary of possible elements.
00:09:44.640 | So as an example, here, we are going to be building a character-level language model.
00:09:49.360 | So we're simply going to be translating individual characters into integers.
00:09:52.640 | So let me show you a chunk of code that sort of does that for us.
00:09:56.400 | So we're building both the encoder and the decoder. And let me just talk through what's
00:10:01.040 | happening here. When we encode an arbitrary text, like "Hi there," we're going to receive
00:10:07.840 | a list of integers that represents that string. So for example, 46, 47, etc. And then we also
00:10:15.200 | have the reverse mapping. So we can take this list and decode it to get back the exact same string.
00:10:21.600 | So it's really just like a translation to integers and back for arbitrary string. And for us,
00:10:27.440 | it is done on a character level. Now, the way this was achieved is we just iterate over all
00:10:32.640 | the characters here and create a lookup table from the character to the integer and vice versa.
00:10:37.920 | And then to encode some string, we simply translate all the characters individually.
00:10:42.320 | And to decode it back, we use the reverse mapping and concatenate all of it.
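
For reference, that cell looks roughly like this (a sketch; it assumes the sorted chars list defined above):

    # build the lookup tables between characters and integers
    stoi = { ch: i for i, ch in enumerate(chars) }   # string to integer
    itos = { i: ch for i, ch in enumerate(chars) }   # integer to string
    encode = lambda s: [stoi[c] for c in s]            # take a string, output a list of integers
    decode = lambda l: ''.join([itos[i] for i in l])   # take a list of integers, output a string

    print(encode("hii there"))
    print(decode(encode("hii there")))   # round-trips back to "hii there"
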
00:10:46.720 | Now, this is only one of many possible encodings or many possible sort of tokenizers. And it's
00:10:52.720 | a very simple one. But there's many other schemas that people have come up with in practice.
00:10:57.360 | So for example, Google uses SentencePiece. So SentencePiece will also encode text into
00:11:03.280 | integers, but in a different schema and using a different vocabulary. And SentencePiece is a
00:11:10.720 | sub-word sort of tokenizer. And what that means is that you're not encoding entire words, but
00:11:16.880 | you're also not encoding individual characters. It's a sub-word unit level. And that's usually
00:11:22.800 | what's adopted in practice. For example, OpenAI also has this library called tiktoken that
00:11:27.600 | uses a byte pair encoding tokenizer, and that's what GPT uses. You could also just encode whole words,
00:11:34.960 | like "hello world", into a list of integers. So as an example, I'm using the tiktoken library
00:11:41.120 | here. I'm getting the encoding that was used for GPT-2. Instead of just having 65
00:11:47.600 | possible characters or tokens, they have 50,000 tokens. And so when they encode the exact same
00:11:54.720 | string, hi there, we only get a list of three integers. But those integers are not between 0
00:12:00.480 | and 64. They are between 0 and 50,256. So basically, you can trade off the codebook size
00:12:10.320 | and the sequence lengths. So you can have very long sequences of integers with very small
00:12:15.040 | vocabularies, or you can have short sequences of integers with very large vocabularies. And so
00:12:23.520 | typically people use in practice these sub-word encodings, but I'd like to keep our tokenizer
00:12:29.600 | very simple. So we're using character level tokenizer. And that means that we have very
00:12:33.760 | small codebooks. We have very simple encode and decode functions, but we do get very long
00:12:40.160 | sequences as a result. But that's the level at which we're going to stick with this lecture,
00:12:44.160 | because it's the simplest thing. Okay, so now that we have an encoder and a decoder,
00:12:48.400 | effectively a tokenizer, we can tokenize the entire training set of Shakespeare. So here's a
00:12:53.680 | chunk of code that does that. And I'm going to start to use the PyTorch library and specifically
00:12:58.000 | the torch.tensor from the PyTorch library. So we're going to take all of the text in Tiny
00:13:03.280 | Shakespeare, encode it, and then wrap it into a torch.tensor to get the data tensor. So here's
00:13:09.360 | what the data tensor looks like when I look at just the first 1000 characters or the 1000 elements
00:13:14.400 | of it. So we see that we have a massive sequence of integers. And this sequence of integers here
00:13:20.080 | is basically an identical translation of the first 1000 characters here. So I believe, for example,
00:13:26.400 | that 0 is a newline character, and maybe 1 is a space. I'm not 100% sure. But from now on,
00:13:32.560 | the entire data set of text is re-represented as just, it's just stretched out as a single,
00:13:37.040 | very large sequence of integers. Let me do one more thing before we move on here. I'd like to
00:13:42.880 | separate out our data set into a train and a validation split. So in particular, we're going
00:13:48.320 | to take the first 90% of the data set and consider that to be the training data for the transformer.
00:13:53.920 | And we're going to withhold the last 10% at the end of it to be the validation data. And this
00:13:59.360 | will help us understand to what extent our model is overfitting. So we're going to basically hide
00:14:04.000 | and keep the validation data on the side, because we don't want just a perfect memorization of this
00:14:08.880 | exact Shakespeare. We want a neural network that sort of creates Shakespeare-like text. And so it
00:14:15.200 | should be fairly likely for it to produce the actual, stowed away, true Shakespeare text.
00:14:22.640 | And so we're going to use this to get a sense of the overfitting.
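
A compact sketch of those two cells, assuming text and encode from above:

    import torch

    # encode the entire text and wrap it in a tensor of integers
    data = torch.tensor(encode(text), dtype=torch.long)

    # first 90% is the training split, the last 10% is held out for validation
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
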
00:14:27.200 | Okay, so now we would like to start plugging these text sequences or integer sequences into
00:14:31.920 | the transformer so that it can train and learn those patterns. Now, the important thing to realize
00:14:37.600 | is we're never going to actually feed entire text into a transformer all at once. That would be
00:14:42.000 | computationally very expensive and prohibitive. So when we actually train a transformer on a lot
00:14:46.960 | of these data sets, we only work with chunks of the data set. And when we train the transformer,
00:14:51.680 | we basically sample random little chunks out of the training set and train on just chunks at a
00:14:56.480 | time. And these chunks have basically some kind of a length and some maximum length. Now, the maximum
00:15:04.240 | length typically, at least in the code I usually write, is called block size. You can find it under
00:15:10.560 | different names like context length or something like that. Let's start with the block size of just
00:15:14.560 | eight. And let me look at the first train data characters, the first block size plus one
00:15:19.920 | characters. I'll explain why plus one in a second. So this is the first nine characters in the
00:15:26.000 | sequence, in the training set. Now, what I'd like to point out is that when you sample a chunk of
00:15:31.520 | data like this, so say these nine characters out of the training set, this actually has multiple
00:15:37.440 | examples packed into it. And that's because all of these characters follow each other. And so what
00:15:44.560 | this thing is going to say when we plug it into a transformer is we're going to actually
00:15:49.600 | simultaneously train it to make prediction at every one of these positions. Now, in a chunk
00:15:56.240 | of nine characters, there's actually eight individual examples packed in there. So there's
00:16:01.440 | the example that when 18, in the context of 18, 47 likely comes next. In a context of 18 and 47,
00:16:10.080 | 56 comes next. In the context of 18, 47, 56, 57 can come next, and so on. So that's the eight
00:16:18.080 | individual examples. Let me actually spell it out with code. So here's a chunk of code to illustrate.
00:16:24.160 | X are the inputs to the transformer. It will just be the first block size characters. Y will be the
00:16:31.760 | next block size characters. So it's offset by one. And that's because Y are the targets for each
00:16:38.960 | position in the input. And then here I'm iterating over all the block size of eight. And the context
00:16:45.680 | is always all the characters in X up to T and including T. And the target is always the T-th
00:16:52.720 | character, but in the targets array Y. So let me just run this. And basically it spells out what
00:16:59.600 | I said in words. These are the eight examples hidden in a chunk of nine characters that we
00:17:05.680 | sampled from the training set. I want to mention one more thing. We train on all the
00:17:13.120 | eight examples here with context between one all the way up to context of block size. And we train
00:17:19.440 | on that not just for computational reasons because we happen to have the sequence already or something
00:17:23.280 | like that. It's not just done for efficiency. It's also done to make the transformer network
00:17:29.440 | be used to seeing contexts all the way from as little as one all the way to block size.
00:17:35.120 | And we'd like the transformer to be used to seeing everything in between. And that's going to be
00:17:40.000 | useful later during inference because while we're sampling, we can start sampling generation with as
00:17:45.440 | little as one character of context. And the transformer knows how to predict the next
00:17:49.200 | character with all the way up to just context of one. And so then it can predict everything up to
00:17:54.320 | block size. And after block size, we have to start truncating because the transformer will never
00:17:59.040 | receive more than block size inputs when it's predicting the next character.
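
Spelled out in code, the chunking looks roughly like this:

    block_size = 8
    x = train_data[:block_size]      # inputs: the first block_size characters
    y = train_data[1:block_size+1]   # targets: the same sequence shifted by one
    for t in range(block_size):
        context = x[:t+1]   # everything up to and including position t
        target = y[t]       # the character that should come next
        print(f"when input is {context.tolist()} the target: {target}")
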
00:18:03.520 | Okay, so we've looked at the time dimension of the tensors that are going to be feeding into
00:18:08.560 | the transformer. There's one more dimension to care about, and that is the batch dimension.
00:18:12.000 | And so as we're sampling these chunks of text, we're going to be actually every time we're going
00:18:17.760 | to feed them into a transformer, we're going to have many batches of multiple chunks of text that
00:18:22.080 | are all stacked up in a single tensor. And that's just done for efficiency just so that we can keep
00:18:26.720 | the GPUs busy because they are very good at parallel processing of data. And so we just
00:18:33.920 | want to process multiple chunks all at the same time. But those chunks are processed completely
00:18:38.320 | independently, they don't talk to each other, and so on. So let me basically just generalize this
00:18:43.200 | and introduce a batch dimension. Here's a chunk of code. Let me just run it, and then I'm going
00:18:48.560 | to explain what it does. So here, because we're going to start sampling random locations in the
00:18:54.640 | data sets to pull chunks from, I am setting the seed in the random number generator so that the
00:19:01.360 | numbers I see here are going to be the same numbers you see later if you try to reproduce this.
00:19:06.320 | Now the batch size here is how many independent sequences we are processing every forward-backward
00:19:11.120 | pass of the transformer. The block size, as I explained, is the maximum context length
00:19:16.640 | to make those predictions. So let's say batch size 4, block size 8, and then here's how we get batch
00:19:22.240 | for any arbitrary split. If the split is a training split, then we're going to look at train data,
00:19:27.600 | otherwise at val data. That gives us the data array. And then when I generate random positions
00:19:35.280 | to grab a chunk out of, I actually generate batch size number of random offsets. So because this is
00:19:44.000 | 4, ix is going to be 4 numbers that are randomly generated between 0 and len of data minus block
00:19:51.360 | size. So it's just random offsets into the training set. And then x's, as I explained,
00:19:57.600 | are the first block size characters starting at i. The y's are the offset by 1 of that, so just add
00:20:06.640 | plus 1. And then we're going to get those chunks for every one of the integers i in ix and use torch.stack
00:20:14.320 | to take all those one-dimensional tensors as we saw here, and we're going to
00:20:21.760 | stack them up as rows. And so they all become a row in a 4 by 8 tensor.
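
The batching code is roughly this sketch:

    torch.manual_seed(1337)   # so the sampled offsets are reproducible
    batch_size = 4   # how many independent sequences we process in parallel
    block_size = 8   # maximum context length for predictions

    def get_batch(split):
        # sample a small batch of inputs x and targets y from the chosen split
        data = train_data if split == 'train' else val_data
        ix = torch.randint(len(data) - block_size, (batch_size,))    # batch_size random offsets
        x = torch.stack([data[i:i+block_size] for i in ix])          # (B, T) inputs
        y = torch.stack([data[i+1:i+block_size+1] for i in ix])      # (B, T) targets, offset by one
        return x, y

    xb, yb = get_batch('train')
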
00:20:28.640 | So here's where I'm printing them. When I sample a batch xb and yb,
00:20:33.600 | the inputs to the transformer now are the input x is the 4 by 8 tensor, four rows of eight columns,
00:20:44.960 | and each one of these is a chunk of the training set. And then the targets here are in the
00:20:52.160 | associated array y, and they will come in to the transformer all the way at the end
00:20:55.920 | to create the loss function. So they will give us the correct answer for every single position
00:21:02.960 | inside x. And then these are the four independent rows. So spelled out as we did before,
00:21:12.320 | this 4 by 8 array contains a total of 32 examples, and they're completely independent
00:21:18.640 | as far as the transformer is concerned. So when the input is 24, the target is 43,
00:21:26.560 | or rather 43 here in the y array. When the input is 24, 43, the target is 58.
00:21:31.840 | When the input is 24, 43, 58, the target is 5, etc. Or like when it is a 52, 58, 1, the target is 58.
00:21:41.760 | Right, so you can sort of see this spelled out. These are the 32 independent examples
00:21:46.560 | packed in to a single batch of the input x, and then the desired targets are in y.
00:21:52.720 | And so now this integer tensor of x is going to feed into the transformer,
00:22:00.560 | and that transformer is going to simultaneously process all these examples, and then look up the
00:22:05.840 | correct integers to predict in every one of these positions in the tensor y. Okay, so now that we
00:22:12.560 | have our batch of input that we'd like to feed into a transformer, let's start basically feeding
00:22:16.960 | this into neural networks. Now we're going to start off with the simplest possible neural network,
00:22:21.760 | which in the case of language modeling, in my opinion, is the bigram language model.
00:22:25.280 | And we've covered the bigram language model in my Make More series in a lot of depth. And so here
00:22:30.320 | I'm going to sort of go faster, and let's just implement the PyTorch module directly that
00:22:34.960 | implements the bigram language model. So I'm importing the PyTorch nn module and setting a manual seed for reproducibility,
00:22:43.120 | and then here I'm constructing a bigram language model, which is a subclass of nn.Module.
00:22:47.680 | And then I'm calling it, and I'm passing in the inputs and the targets, and I'm just printing.
00:22:54.480 | Now when the inputs and targets come here, you see that I'm just taking the index,
00:22:59.520 | the inputs x here, which I renamed to idx, and I'm just passing them into this token embedding table.
00:23:05.120 | So what's going on here is that here in the constructor, we are creating a token embedding
00:23:10.880 | table, and it is of size vocab size by vocab size. And we're using an nn.Embedding, which is a
00:23:18.480 | very thin wrapper around basically a tensor of shape vocab size by vocab size. And what's happening
00:23:24.640 | here is that when we pass idx here, every single integer in our input is going to refer to this
00:23:30.640 | embedding table, and is going to pluck out a row of that embedding table corresponding to its index.
00:23:36.000 | So 24 here will go to the embedding table, and will pluck out the 24th row. And then 43 will go
00:23:43.040 | here and pluck out the 43rd row, etc. And then PyTorch is going to arrange all of this into a
00:23:48.880 | batch by time by channel tensor. In this case, batch is 4, time is 8, and c, which is the channels,
00:23:58.480 | is vocab size or 65. And so we're just going to pluck out all those rows, arrange them in a b by
00:24:04.320 | t by c, and now we're going to interpret this as the logits, which are basically the scores
00:24:09.520 | for the next character in the sequence. And so what's happening here is we are predicting what
00:24:15.120 | comes next based on just the individual identity of a single token. And you can do that because,
00:24:21.200 | I mean, currently the tokens are not talking to each other, and they're not seeing any context,
00:24:26.080 | except for they're just seeing themselves. So I'm a token number 5, and then I can actually
00:24:32.320 | make pretty decent predictions about what comes next just by knowing that I'm token 5,
00:24:36.560 | because some characters follow other characters in typical scenarios. So we saw a lot of this
00:24:43.680 | in a lot more depth in the MakeMore series. And here, if I just run this, then we currently get
00:24:49.120 | the predictions, the scores, the logits for every one of the 4 by 8 positions. Now that we've made
00:24:56.000 | predictions about what comes next, we'd like to evaluate the loss function. And so in MakeMore
00:25:00.320 | series, we saw that a good way to measure a loss or a quality of the predictions is to use the
00:25:06.000 | negative log likelihood loss, which is also implemented in PyTorch under the name cross
00:25:10.320 | entropy. So what we'd like to do here is loss is the cross entropy on the predictions and the
00:25:17.680 | targets. And so this measures the quality of the logits with respect to the targets. In other words,
00:25:23.600 | we have the identity of the next character, so how well are we predicting the next character
00:25:28.560 | based on the logits? And intuitively, the correct dimension of logits, depending on whatever the
00:25:36.960 | target is, should have a very high number, and all the other dimensions should be a very low number.
00:25:41.040 | Now, this is what we want. We want to basically output
00:25:47.280 | the logits and the loss. But unfortunately, this won't actually run.
00:25:55.040 | We get an error message. But intuitively, we want to measure this. Now, when we go to the PyTorch
00:26:03.040 | cross entropy documentation here, we're trying to call the cross entropy in its functional form.
00:26:10.800 | So that means we don't have to create a module for it. But here, when we go to the documentation,
00:26:16.000 | you have to look into the details of how PyTorch expects these inputs. And basically,
00:26:20.720 | the issue here is PyTorch expects, if you have multidimensional input, which we do because we
00:26:26.160 | have a b by t by c tensor, then it actually really wants the channels to be the second dimension
00:26:33.440 | here. So basically, it wants a b by c by t instead of a b by t by c. And so it's just the details of
00:26:43.680 | how PyTorch treats these kinds of inputs. And so we don't actually want to deal with that. So what
00:26:51.200 | we're going to do instead is we need to basically reshape our logits. So here's what I like to do.
00:26:55.280 | I like to basically give names to the dimensions. So logits.shape is b by t by c and unpack those
00:27:01.440 | numbers. And then let's say that logits equals logits.view. And we want it to be a b times t
00:27:09.920 | by c, so just a two-dimensional array. So we're going to take all of these positions here, and
00:27:19.760 | we're going to stretch them out in a one-dimensional sequence and preserve the channel dimension as
00:27:25.520 | the second dimension. So we're just kind of like stretching out the array so it's two-dimensional.
00:27:30.640 | And in that case, it's going to better conform to what PyTorch sort of expects in its dimensions.
00:27:35.440 | Now, we have to do the same to targets because currently targets are of shape b by t,
00:27:44.560 | and we want it to be just b times t, so one-dimensional. Now, alternatively, you could
00:27:49.760 | always still just do minus one because PyTorch will guess what this should be if you want to
00:27:54.320 | lay it out. But let me just be explicit and say b times t. Once we reshape this, it will match
00:28:00.400 | the cross-entropy case, and then we should be able to evaluate our loss.
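
Putting the model together so far, it looks roughly like this sketch (the optional targets argument is the tweak discussed a bit further down, for generation):

    import torch
    import torch.nn as nn
    from torch.nn import functional as F

    class BigramLanguageModel(nn.Module):

        def __init__(self, vocab_size):
            super().__init__()
            # each token directly reads off the logits for the next token from a lookup table
            self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

        def forward(self, idx, targets=None):
            logits = self.token_embedding_table(idx)   # (B, T, C)
            if targets is None:
                loss = None
            else:
                B, T, C = logits.shape
                logits = logits.view(B*T, C)    # stretch batch and time out into one dimension
                targets = targets.view(B*T)
                loss = F.cross_entropy(logits, targets)   # at init, expect roughly -ln(1/65), about 4.17
            return logits, loss

    m = BigramLanguageModel(vocab_size)
    logits, loss = m(xb, yb)
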
00:28:04.560 | Okay, so with that right now, and we can do loss. And so currently we see that the loss is 4.87.
00:28:14.480 | Now, because we have 65 possible vocabulary elements, we can actually guess at what the loss
00:28:20.480 | should be. And in particular, we covered negative log-likelihood in a lot of detail. We are expecting
00:26:26.640 | the negative of ln(1/65). So we're expecting the loss to be about 4.17,
00:26:37.280 | but we're getting 4.87. And so that's telling us that the initial predictions are not perfectly diffuse.
00:26:42.880 | They've got a little bit of structure in them, and so we're guessing wrong.
00:28:45.680 | So yes, but actually we are able to evaluate the loss. Okay, so now that we can evaluate the
00:28:54.000 | quality of the model on some data, we'd like to also be able to generate from the model.
00:28:59.120 | So let's do the generation. Now, I'm going to go again a little bit faster here because I covered
00:29:03.680 | all this already in previous videos. So here's a generate function for the model.
00:29:12.240 | So we take the same kind of input, idx here, and basically this is the current
00:29:20.400 | context of some characters in some batch. So it's also b by t, and the job of generate
00:29:29.040 | is to basically take this b by t and extend it to be b by t plus 1, plus 2, plus 3. And so it's
00:29:33.920 | just basically it continues the generation in all the batch dimensions in the time dimension.
00:29:39.120 | So that's its job, and it will do that for max new tokens. So you can see here on the bottom,
00:29:44.240 | there's going to be some stuff here, but on the bottom, whatever is predicted is concatenated on
00:29:49.760 | top of the previous idx along the first dimension, which is the time dimension, to create a b by t
00:29:55.120 | plus 1. So that becomes a new idx. So the job of generate is to take a b by t and make it a b by t
00:30:01.280 | plus 1, plus 2, plus 3, as many as we want max new tokens. So this is the generation from the model.
00:30:08.160 | Now inside the generation, what are we doing? We're taking the current indices,
00:30:12.000 | we're getting the predictions. So we get those are in the logits, and then the loss here is
00:30:19.200 | going to be ignored because we're not using that, and we have no targets that are sort of ground
00:30:24.400 | truth targets that we're going to be comparing with. Then once we get the logits, we are only
00:30:30.400 | focusing on the last step. So instead of a b by t by c, we're going to pluck out the negative one,
00:30:37.520 | the last element in the time dimension, because those are the predictions for what comes next.
00:30:41.840 | So that gives us the logits, which we then convert to probabilities via softmax. And then we use
00:30:47.840 | torch.multinomial to sample from those probabilities, and we ask PyTorch to give us
00:30:52.080 | one sample. And so idx next will become a b by 1, because in each one of the batch dimensions,
00:31:00.080 | we're going to have a single prediction for what comes next. So this num_samples equals 1
00:31:04.560 | will make this a B by 1. And then we're going to take those integers that come from the sampling
00:31:10.400 | process according to the probability distribution given here, and those integers get just concatenated
00:31:15.360 | on top of the current sort of like running stream of integers. And this gives us a b by t plus 1.
00:31:20.720 | And then we can return that. Now one thing here is you see how I'm calling self of idx, which will
00:31:28.720 | end up going to the forward function. I'm not providing any targets, so currently this would
00:31:33.520 | give an error because targets is sort of like not given. So targets has to be optional. So targets
00:31:40.480 | is none by default. And then if targets is none, then there's no loss to create. So it's just loss
00:31:47.680 | is none. But else all of this happens and we can create a loss. So this will make it so if we have
00:31:55.840 | the targets, we provide them and get a loss. If we have no targets, we'll just get the logits.
00:32:01.360 | So this here will generate from the model. And let's take that for a ride now.
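
A sketch of that generate method; it lives inside the BigramLanguageModel class above:

    class BigramLanguageModel(nn.Module):
        # ... constructor and forward as above ...

        def generate(self, idx, max_new_tokens):
            # idx is a (B, T) array of indices in the current context
            for _ in range(max_new_tokens):
                logits, loss = self(idx)            # get the predictions (loss is None, no targets given)
                logits = logits[:, -1, :]           # focus only on the last time step -> (B, C)
                probs = F.softmax(logits, dim=-1)   # convert logits to probabilities
                idx_next = torch.multinomial(probs, num_samples=1)   # sample one token per batch row -> (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)              # append to the running sequence -> (B, T+1)
            return idx
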
00:32:07.200 | Oops. So I have another code chunk here, which will generate
00:32:13.280 | from the model. And OK, this is kind of crazy. So maybe let me
00:32:16.720 | break this down. So this is the idx, right?
00:32:20.960 | I'm creating it so that the batch will be just one and the time will be just one.
00:32:29.600 | So I'm creating a little one by one tensor and it's holding a zero, and the dtype, the data type,
00:32:35.520 | is integer. So zero is going to be how we kick off the generation. And remember that zero is
00:32:41.920 | the element standing for a new line character. So it's kind of like a reasonable thing to
00:32:47.520 | feed in as the very first character in a sequence to be the new line. So it's going to be idx,
00:32:54.080 | which we're going to feed in here. Then we're going to ask for 100 tokens
00:32:58.640 | and then end that generate will continue that. Now, because generate works on the level of batches,
00:33:06.160 | we then have to index into the zeroth row to basically unplug the single batch dimension
00:33:13.760 | that exists. And then that gives us a time steps, just a one dimensional array of all the indices,
00:33:21.600 | which we will convert to simple Python list from PyTorch tensor so that that can feed into
00:33:28.560 | our decode function and convert those integers into text. So let me bring this back and we're
00:33:35.680 | generating a hundred tokens. Let's run. And here's the generation that we achieved.
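
That kickoff cell is roughly the following (m being the model instance from above):

    idx = torch.zeros((1, 1), dtype=torch.long)   # batch of 1, time of 1, holding a 0 (the newline character)
    out = m.generate(idx, max_new_tokens=100)     # (1, 101) tensor of indices
    print(decode(out[0].tolist()))                # strip the batch dimension and decode back to text
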
00:33:41.200 | So obviously it's garbage. And the reason it's garbage is because this is a totally random model.
00:33:46.400 | So next up, we're going to want to train this model. Now, one more thing I wanted to point
00:33:50.320 | out here is this function is written to be general, but it's kind of like ridiculous right now because
00:33:58.000 | we're feeding in all this, we're building out this context and we're concatenating it all.
00:34:02.800 | And we're always feeding it all into the model. But that's kind of ridiculous because this is
00:34:08.560 | just a simple bigram model. So to make, for example, this prediction about K,
00:34:11.920 | we only needed this W, but actually what we fed into the model is we fed the entire sequence.
00:34:17.520 | And then we only looked at the very last piece and predicted K. So the only reason I'm writing
00:34:23.440 | it in this way is because right now this is a bigram model, but I'd like to keep this function
00:34:28.320 | fixed. And I'd like it to work later when our characters actually basically look further in
00:34:36.000 | the history. And so right now the history is not used. So this looks silly, but eventually the
00:34:41.120 | history will be used. And so that's why we want to do it this way. So just a quick comment on that.
00:34:46.720 | So now we see that this is random. So let's train the model. So it becomes a bit less random.
00:34:53.280 | Okay, let's now train the model. So first what I'm going to do is I'm going to create a PyTorch
00:34:57.760 | optimization object. So here we are using the optimizer AdamW. Now in the Makemore series,
00:35:05.360 | we've only ever used stochastic gradient descent, the simplest possible optimizer, which you can get
00:35:09.520 | using the SGD instead. But I want to use Adam, which is a much more advanced and popular optimizer
00:35:14.880 | and it works extremely well. A typical good setting for the learning rate is roughly 3e-4.
00:35:21.760 | But for very, very small networks, like is the case here, you can get away with much,
00:35:25.520 | much higher learning rates, 1e-3 or even higher probably. But let me create the optimizer object,
00:35:31.280 | which will basically take the gradients and update the parameters using the gradients.
00:35:36.400 | And then here, our batch size up above was only 4. So let me actually use something bigger,
00:35:42.720 | let's say 32. And then for some number of steps, we are sampling a new batch of data,
00:35:48.480 | we're evaluating the loss, we're zeroing out all the gradients from the previous step,
00:35:53.200 | getting the gradients for all the parameters, and then using those gradients to update our
00:35:57.920 | parameters. So typical training loop, as we saw in the Makemore series. So let me now run this
00:36:04.400 | for say 100 iterations and let's see what kind of losses we're going to get.
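
The training loop is roughly this sketch:

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)   # a higher learning rate is fine for this tiny network

    batch_size = 32
    for steps in range(100):   # increase this later (e.g. to 10,000)
        xb, yb = get_batch('train')            # sample a new batch of data
        logits, loss = m(xb, yb)               # evaluate the loss
        optimizer.zero_grad(set_to_none=True)  # zero out gradients from the previous step
        loss.backward()                        # get gradients for all parameters
        optimizer.step()                       # use the gradients to update the parameters

    print(loss.item())
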
00:36:08.400 | So we started around 4.7 and now we're getting down to like 4.6, 4.5, etc. So the optimization
00:36:18.320 | is definitely happening, but let's sort of try to increase the number of iterations and only
00:36:24.800 | print at the end, because we probably will not train for longer. Okay, so we're down to 3.6,
00:36:31.840 | roughly. Roughly down to 3. This is the most janky optimization.
00:36:47.200 | Okay, it's working. Let's just do 10,000.
00:36:49.040 | And then from here, we want to copy this. And hopefully, we're going to get something
00:36:57.280 | reasonable. And of course, it's not going to be Shakespeare from a bigram model, but at least
00:37:01.280 | we see that the loss is improving. And hopefully, we're expecting something a bit more reasonable.
00:37:06.480 | Okay, so we're down at about 2.5-ish. Let's see what we get. Okay, dramatic improvements,
00:37:13.600 | certainly on what we had here. So let me just increase the number of tokens.
00:37:18.240 | Okay, so we see that we're starting to get something at least like reasonable-ish.
00:37:23.680 | Certainly not Shakespeare, but the model is making progress.
00:37:30.640 | So that is the simplest possible model. So now what I'd like to do is...
00:37:35.440 | Obviously, this is a very simple model because the tokens are not talking to each other.
00:37:41.600 | So given the previous context of whatever was generated, we're only looking at the very last
00:37:46.160 | character to make the predictions about what comes next. So now these tokens have to start
00:37:51.280 | talking to each other and figuring out what is in the context so that they can make better
00:37:56.000 | predictions for what comes next. And this is how we're going to kick off the transformer.
00:38:00.400 | Okay, so next, I took the code that we developed in this Jupyter notebook,
00:38:03.360 | and I converted it to be a script. And I'm doing this because I just want to simplify
00:38:08.640 | our intermediate work into just the final product that we have at this point.
00:38:11.760 | So in the top here, I put all the hyperparameters that we've defined.
00:38:16.640 | I introduced a few, and I'm going to speak to that in a little bit.
00:38:19.360 | Otherwise, a lot of this should be recognizable. Reproducibility, read data,
00:38:25.200 | get the encoder and the decoder, create the train and test splits, use the data loader
00:38:31.440 | that gets a batch of the inputs and targets. This is new, and I'll talk about it in a second.
00:38:38.800 | Now, this is the bigram language model that we developed, and it can forward and give us
00:38:43.120 | logits and loss, and it can generate. And then here, we are creating the optimizer,
00:38:48.720 | and this is the training loop. So everything here should look pretty familiar. Now,
00:38:54.480 | some of the small things that I added, number one, I added the ability to run on a GPU if you have it.
00:39:00.560 | So if you have a GPU, then this will use CUDA instead of just CPU, and everything will be a
00:39:05.520 | lot faster. Now, when the device becomes CUDA, we need to make sure that when we load the
00:39:11.040 | data, we move it to device. When we create the model, we want to move the model parameters to
00:39:17.840 | device. So as an example, here we have the NN embedding table, and it's got a dot weight inside
00:39:24.080 | it, which stores the lookup table. So that would be moved to the GPU so that all the calculations
00:39:30.160 | here happen on the GPU, and they can be a lot faster. And then finally here, when I'm creating
00:39:35.120 | the context that feeds into generate, I have to make sure that I create it on the device.
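
Condensed, the device handling in the script boils down to a few lines like these (pulled from their respective places in the file):

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # inside get_batch: move each sampled batch to the device
    x, y = x.to(device), y.to(device)

    # when creating the model: move its parameters (e.g. the embedding table's weight) to the device
    m = model.to(device)

    # when kicking off generation: create the context tensor on the device too
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
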
00:39:39.200 | Number two, what I introduced is the fact that here in the training loop,
00:39:45.680 | here I was just printing the loss.item inside the training loop, but this is a very noisy
00:39:53.840 | measurement of the current loss because every batch will be more or less lucky.
00:39:58.400 | And so what I want to do usually is I have an estimate loss function, and the estimate loss
00:40:05.360 | basically then goes up here, and it averages up the loss over multiple batches. So in particular,
00:40:14.560 | we're going to iterate eval_iters times, and we're going to basically get our loss, and then we're
00:40:19.520 | going to get the average loss for both splits. And so this will be a lot less noisy. So here,
00:40:25.280 | when we call the estimate loss, we're going to report the pretty accurate train and validation
00:40:30.240 | loss. Now, when we come back up, you'll notice a few things here. I'm setting the model to
00:40:35.680 | evaluation phase, and down here I'm resetting it back to training phase. Now, right now for
00:40:41.520 | our model as is, this doesn't actually do anything because the only thing inside this model is this
00:40:46.720 | nn.embedding, and this network would behave the same in both evaluation mode and training mode.
00:40:56.320 | We have no dropout layers, we have no batch norm layers, etc. But it is a good practice to think
00:41:01.040 | through what mode your neural network is in because some layers will have different behavior
00:41:06.400 | at inference time or training time. And there's also this context manager, torch.no_grad,
00:41:13.680 | and this is just telling PyTorch that everything that happens inside this function,
00:41:17.280 | we will not call .backward on. And so PyTorch can be a lot more efficient with its memory use
00:41:23.200 | because it doesn't have to store all the intermediate variables because we're never
00:41:27.120 | going to call backward. And so it can be a lot more memory efficient in that way.
00:41:31.680 | So also a good practice to tell PyTorch when we don't intend to do backpropagation.
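
A sketch of that estimate_loss function (eval_iters is one of the hyperparameters defined at the top of the script):

    @torch.no_grad()   # we will never call .backward() in here, so PyTorch can skip storing intermediates
    def estimate_loss():
        out = {}
        model.eval()    # switch to evaluation mode (no effect for this model yet, but good practice)
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()   # average over many batches for a much less noisy estimate
        model.train()   # switch back to training mode
        return out
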
00:41:37.600 | So right now, this script is about 120 lines of code, and that's kind of our starter code.
00:41:44.320 | I'm calling it bigram.py, and I'm going to release it later. Now running this script
00:41:49.760 | gives us output in the terminal, and it looks something like this.
00:41:53.600 | It basically, as I ran this code, it was giving me the train loss and the val loss,
00:41:59.520 | and we see that we convert to somewhere around 2.5 with the bigram model. And then here's the
00:42:04.960 | sample that we produced at the end. And so we have everything packaged up in the script,
00:42:10.640 | and we're in a good position now to iterate on this. Okay, so we are almost ready to start
00:42:15.280 | writing our very first self-attention block for processing these tokens. Now,
00:42:21.280 | before we actually get there, I want to get you used to a mathematical trick that is used in the
00:42:26.800 | self-attention inside a transformer, and is really just at the heart of an efficient implementation
00:42:32.880 | of self-attention. And so I want to work with this toy example to just get you used to this
00:42:37.360 | operation, and then it's going to make it much more clear once we actually get to it in the
00:42:43.360 | script again. So let's create a B by T by C tensor, where B, T, and C are just 4, 8, and 2 in this toy example.
00:42:50.240 | And these are basically channels, and we have batches, and we have the time component,
00:42:56.640 | and we have some information at each point in the sequence, so c. Now, what we would like to do is
00:43:03.280 | we would like these tokens, so we have up to 8 tokens here in a batch, and these 8 tokens are
00:43:09.600 | currently not talking to each other, and we would like them to talk to each other. We'd like to
00:43:13.040 | couple them. And in particular, we want to couple them in a very specific way. So the token, for
00:43:20.960 | example, at the fifth location, it should not communicate with tokens in the sixth, seventh,
00:43:26.000 | and eighth location, because those are future tokens in the sequence. The token on the fifth
00:43:31.760 | location should only talk to the one in the fourth, third, second, and first. So information
00:43:37.600 | only flows from previous context to the current time step, and we cannot get any information from
00:43:42.640 | the future, because we are about to try to predict the future. So what is the easiest way for tokens
00:43:49.600 | to communicate? The easiest way, I would say, is if we're a fifth token and I'd like to communicate
00:43:56.640 | with my past, the simplest way we can do that is to just do an average of all the preceding elements.
00:44:06.000 | So for example, if I'm the fifth token, I would like to take the channels that make up, that are
00:44:12.080 | information at my step, but then also the channels from the fourth step, third step, second step,
00:44:16.880 | and the first step, I'd like to average those up, and then that would become sort of like a
00:44:21.200 | feature vector that summarizes me in the context of my history. Now, of course, just doing a sum,
00:44:27.360 | or like an average, is an extremely weak form of interaction. Like this communication is extremely
00:44:32.160 | lossy. We've lost a ton of information about the spatial arrangements of all those tokens,
00:44:36.160 | but that's okay for now. We'll see how we can bring that information back later.
00:44:40.080 | For now, what we would like to do is, for every single batch element independently,
00:44:45.920 | for every t-th token in that sequence, we'd like to now calculate the average of all the
00:44:52.880 | vectors in all the previous tokens, and also at this token. So let's write that out.
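
The snippet, copied here for reference, is roughly:

    torch.manual_seed(1337)
    B, T, C = 4, 8, 2           # batch, time, channels
    x = torch.randn(B, T, C)

    # we want xbow[b, t] = mean over i <= t of x[b, i]
    xbow = torch.zeros((B, T, C))   # "bow" is short for bag of words
    for b in range(B):
        for t in range(T):
            xprev = x[b, :t+1]                   # (t+1, C): everything up to and including the t-th token
            xbow[b, t] = torch.mean(xprev, 0)    # average over the time dimension
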
00:44:58.320 | I have a small snippet here, and instead of just fumbling around,
00:45:02.720 | let me just copy paste it and talk to it. So in other words, we're going to create x,
00:45:09.760 | and B-O-W is short for bag of words, because bag of words is kind of like a term that people use
00:45:17.360 | when you are just averaging up things. So this is just a bag of words. Basically, there's a word
00:45:21.760 | stored on every one of these eight locations, and we're doing a bag of words, we're just averaging.
00:45:26.080 | So in the beginning, we're going to say that it's just initialized at zero,
00:45:30.480 | and then I'm doing a for loop here, so we're not being efficient yet. That's coming. But for now,
00:45:34.800 | we're just iterating over all the batch dimensions independently, iterating over time, and then
00:45:40.560 | the previous tokens are at this batch dimension, and then everything up to and including the t-th
00:45:48.720 | token. So when we slice out x in this way, xprev becomes of shape how many t elements there were
00:45:58.640 | in the past, and then of course, c, so all the two-dimensional information from these little
00:46:03.600 | tokens. So that's the previous sort of chunk of tokens from my current sequence. And then I'm
00:46:12.160 | just doing the average, or the mean, over the zeroth dimension. So I'm averaging out the time here,
00:46:17.520 | and I'm just going to get a little c one-dimensional vector, which I'm going to
00:46:22.080 | store in x bag of words. So I can run this, and this is not going to be very informative, because
00:46:31.040 | let's see, so this is x of zero, so this is the zeroth batch element, and then xbow at zero.
00:46:36.400 | Now you see how at the first location here, you see that the two are equal, and that's because
00:46:44.320 | we're just doing an average of this one token. But here, this one is now an average of these two,
00:46:50.560 | and now this one is an average of these three, and so on. And this last one is the average
00:47:01.120 | of all of these elements, so vertical average, just averaging up all the tokens, now gives this
00:47:06.240 | outcome here. So this is all well and good, but this is very inefficient. Now the trick is that
00:47:13.040 | we can be very, very efficient about doing this using matrix multiplication. So that's the
00:47:18.000 | mathematical trick, and let me show you what I mean. Let's work with the toy example here.
00:47:22.000 | Let me run it, and I'll explain. I have a simple matrix here that is a three by three of all ones.
00:47:29.920 | A matrix B of just random numbers, and it's a three by two, and a matrix C, which will be three
00:47:34.800 | by three multiply three by two, which will give out a three by two. So here we're just using
00:47:40.240 | matrix multiplication. So A multiply B gives us C. Okay, so how are these numbers in C
00:47:50.080 | achieved, right? So this number in the top left is the first row of A dot product with the first
00:47:58.320 | column of B. And since all the row of A right now is all just ones, then the dot product here with
00:48:05.680 | this column of B is just going to do a sum of this column. So two plus six plus six is 14.
00:48:12.400 | The element here in the output of C is also the first column here, the first row of A,
00:48:18.320 | multiplied now with the second column of B. So seven plus four plus five is 16. Now you see that
00:48:25.280 | there's repeating elements here. So this 14 again is because this row is again all ones,
00:48:29.520 | and it's multiplying the first column of B. So we get 14. And this one is, and so on. So this last
00:48:35.360 | number here is the last row dot product last column. Now the trick here is the following.
00:48:42.400 | A is just a boring array of all ones. But torch has this
00:48:49.840 | function called tril, which is short for triangular, something like that. And you can
00:48:56.880 | call it on torch.ones, and it will just return the lower triangular portion of it.
00:49:01.520 | So now it will basically zero out these guys here. So we just get the lower triangular part.
00:49:09.520 | Well, what happens if we do that?
00:49:11.680 | So now we'll have A like this and B like this. And now what are we getting here in C?
00:49:20.240 | Well, what is this number? Well, this is the first row times the first column. And because
00:49:25.680 | this is zeros, these elements here are now ignored. So we just get a two. And then this
00:49:32.720 | number here is the first row times the second column. And because these are zeros, they get
00:49:37.760 | ignored, and it's just seven. The seven multiplies this one. But look what happened here. Because
00:49:43.520 | this is one and then zeros, what ended up happening is we're just plucking out the row,
00:49:48.800 | this row of B, and that's what we got. Now here we have one, one, zero. So here,
00:49:55.840 | one, one, zero dot product with these two columns will now give us two plus six, which is eight,
00:50:00.880 | and seven plus four, which is 11. And because this is one, one, one, we ended up with the addition
00:50:07.040 | of all of them. And so basically, depending on how many ones and zeros we have here, we are
00:50:12.720 | basically doing a sum currently of the variable number of these rows, and that gets deposited
00:50:19.520 | into C. So currently, we're doing sums because these are ones, but we can also do average,
00:50:25.680 | right? And you can start to see how we could do average of the rows of B sort of in an incremental
00:50:32.160 | fashion. Because we don't have to, we can basically normalize these rows so that they sum to one,
00:50:38.400 | and then we're going to get an average. So if we took A, and then we did A equals A divided by
00:50:43.520 | torch.sum of A in the first dimension, with keepdim as true. So therefore,
00:50:56.800 | the broadcasting will work out. So if I rerun this, you see now that these rows now sum to one.
00:51:03.600 | So this row is one, this row is 0.5, 0.5, and zero, and here we get one-thirds. And now when we do A
00:51:09.760 | multiply B, what are we getting? Here we are just getting the first row, first row. Here now we are
00:51:16.320 | getting the average of the first two rows. Okay, so two and six average is four, and four and seven
00:51:23.920 | average is 5.5. And on the bottom here, we are now getting the average of these three rows.
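A minimal sketch of this toy example (the variable names a, b, c here are just for illustration):

```python
import torch

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))           # lower triangular ones
a = a / torch.sum(a, 1, keepdim=True)      # normalize each row so it sums to 1
b = torch.randint(0, 10, (3, 2)).float()   # some arbitrary 3x2 matrix
c = a @ b                                  # row i of c is the average of rows 0..i of b

print(a)   # [[1, 0, 0], [0.5, 0.5, 0], [0.333, 0.333, 0.333]]
print(c)   # running (incremental) averages of the rows of b
```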
00:51:31.360 | So the average of all the elements of B is now deposited here. And so you can see that by
00:51:37.600 | manipulating these elements of this multiplying matrix, and then multiplying it with any given
00:51:44.400 | matrix, we can do these averages in this incremental fashion. Because we just get,
00:51:49.680 | and we can manipulate that based on the elements of A. Okay, so that's very convenient. So let's
00:51:56.640 | swing back up here and see how we can vectorize this and make it much more efficient using what
00:52:00.880 | we've learned. So in particular, we are going to produce an array A, but here I'm going to call it
00:52:07.840 | "weigh," short for "weights." But this is our A, and this is how much of every row we want to average
00:52:15.920 | up. And it's going to be an average because you can see that these rows sum to one. So this is our
00:52:21.600 | A, and then our B in this example, of course, is X. So what's going to happen here now is that we
00:52:29.520 | are going to have an xbow2. And this xbow2 is going to be wei multiplying our X.
00:52:38.000 | So let's think this through. wei is T by T, and this is matrix multiplying in PyTorch,
00:52:45.200 | a B by T by C. And it's giving us what shape. So PyTorch will come here and it will see that
00:52:53.360 | these shapes are not the same. So it will create a batch dimension here. And this is a batch matrix
00:52:59.280 | multiply. And so it will apply this matrix multiplication in all the batch elements in
00:53:05.040 | parallel and individually. And then for each batch element, there will be a T by T multiplying T by C
00:53:11.760 | exactly as we had below. So this will now create B by T by C, and xbow2 will now become identical
00:53:24.480 | to xbow. So we can see that torch.allclose of xbow and xbow2 should be true. Now,
00:53:35.360 | so this kind of like convinces us that these are in fact the same. So xbow and xbow2,
00:53:44.880 | if I just print them, okay, we're not going to be able to just stare it down. But
00:53:52.960 | well, let me try xbow basically just at the zeroth element and xbow2 at the zeroth element. So
00:53:57.360 | just the first batch, and we should see that this and that should be identical, which they are.
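A sketch of this vectorized version, following the shapes described in the video (xbow is the loop-based average from earlier, wei is the weight matrix, x is the random B by T by C input):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# version 1: explicit loops, averaging all past tokens and the current token
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]               # (t+1, C)
        xbow[b, t] = torch.mean(xprev, 0)

# version 2: the same thing as a batched matrix multiply
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)     # rows sum to 1, so this averages
xbow2 = wei @ x                          # (T, T) @ (B, T, C) -> (B, T, C)

print(torch.allclose(xbow, xbow2))       # True (up to floating point tolerance)
```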
00:54:02.400 | Right. So what happened here? The trick is we were able to use batch matrix multiply
00:54:08.720 | to do this aggregation, really. And it's a weighted aggregation. And the weights are specified in this
00:54:18.160 | T by T array. And we're basically doing weighted sums. And these weighted sums are according to
00:54:25.920 | the weights inside here, they take on sort of this triangular form. And so that means that a token at
00:54:32.640 | the T dimension will only get sort of information from the tokens preceding it. So that's exactly
00:54:40.240 | what we want. And finally, I would like to rewrite it in one more way. And we're going to see why
00:54:45.520 | that's useful. So this is the third version. And it's also identical to the first and second.
00:54:51.280 | But let me talk through it. It uses softmax. So tril here is this matrix, lower triangular ones,
00:55:00.560 | wei begins as all zeros. Okay, so if I just print wei in the beginning, it's all zeros,
00:55:09.520 | then I use masked_fill. So what this is doing is wei.masked_fill, it's all zeros. And I'm saying
00:55:17.840 | for all the elements where tril is equal to zero, make them be negative infinity.
00:55:23.760 | So all the elements where tril is zero will become negative infinity now.
00:55:28.080 | So this is what we get. And then the final one here is softmax.
00:55:36.240 | So if I take a softmax along every single, so dim is negative one, so along every single row,
00:55:40.800 | if I do a softmax, what is that going to do? Well, softmax is also like a normalization operation,
00:55:52.160 | right? And so spoiler alert, you get the exact same matrix. Let me bring back the softmax.
00:55:59.120 | And recall that in softmax, we're going to exponentiate every single one of these.
00:56:04.720 | And then we're going to divide by the sum. And so if we exponentiate every single element here,
00:56:09.760 | we're going to get a one. And here we're going to get basically zero, zero, zero, zero, everywhere
00:56:15.040 | else. And then when we normalize, we just get one. Here, we're going to get one, one, and then
00:56:20.960 | zeros. And the softmax will again divide, and this will give us 0.5, 0.5, and so on. And so this is
00:56:27.680 | also the same way to produce this mask. Now, the reason that this is a bit more interesting,
00:56:34.560 | and the reason we're going to end up using it in self-attention, is that these weights here begin
00:56:41.120 | with zero. And you can think of this as like an interaction strength, or like an affinity. So
00:56:47.600 | basically, it's telling us how much of each token from the past do we want to aggregate and average
00:56:54.720 | up. And then this line is saying, tokens from the past cannot communicate. By setting them to
00:57:02.160 | negative infinity, we're saying that we will not aggregate anything from those tokens. And so
00:57:08.000 | basically, this then goes through softmax, and through the weighted, and this is the aggregation
00:57:11.840 | through matrix multiplication. And so what this is now is, you can think of these as, these zeros
00:57:19.920 | are currently just set by us to be zero. But a quick preview is that these affinities between
00:57:26.240 | the tokens are not going to be just constant at zero, they're going to be data dependent. These
00:57:31.280 | tokens are going to start looking at each other, and some tokens will find other tokens more or
00:57:35.920 | less interesting. And depending on what their values are, they're going to find each other
00:57:40.880 | interesting to different amounts, and I'm going to call those affinities, I think. And then here,
00:57:45.680 | we are saying, the future cannot communicate with the past. We're going to clamp them. And then when
00:57:51.440 | we normalize and sum, we're going to aggregate their values, depending on how interesting they
00:57:57.120 | find each other. And so that's the preview for self-attention. And basically, long story short
00:58:03.280 | from this entire section is that you can do weighted aggregations of your past elements
00:58:09.440 | by using matrix multiplication of a lower triangular fashion. And then the elements
00:58:16.640 | here in the lower triangular part are telling you how much of each element fuses into this position.
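A sketch of this third version, using softmax (again with the names wei and tril):

```python
import torch
import torch.nn.functional as F

T = 8
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))                         # affinities all start at zero
wei = wei.masked_fill(tril == 0, float('-inf'))   # the future cannot be attended to
wei = F.softmax(wei, dim=-1)                      # exponentiate and normalize each row
# wei is now the same lower-triangular averaging matrix as before,
# so xbow3 = wei @ x would again match xbow.
```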
00:58:22.480 | So we're going to use this trick now to develop the self-attention block.
00:58:26.240 | So first, let's get some quick preliminaries out of the way. First, the thing I'm kind of
00:58:31.200 | bothered by is that you see how we're passing in vocab size into the constructor? There's no need
00:58:35.280 | to do that because vocab size is already defined up top as a global variable. So there's no need
00:58:39.840 | to pass this stuff around. Next, what I want to do is I don't want to actually create, I want to
00:58:46.160 | create like a level of indirection here where we don't directly go to the embedding for the logits,
00:58:51.920 | but instead we go through this intermediate phase because we're going to start making that bigger.
00:58:56.400 | So let me introduce a new variable, nembed. It's short for number of embedding dimensions.
00:59:03.520 | So nembed here will be, say, 32. That was a suggestion from GitHub Copilot, by the way.
00:59:11.120 | It also suggested 32, which is a good number. So this is an embedding table and only 32-dimensional
00:59:18.160 | embeddings. So then here, this is not going to give us logits directly. Instead, this is going
00:59:23.760 | to give us token embeddings. That's what I'm going to call it. And then to go from the token embeddings
00:59:28.720 | to the logits, we're going to need a linear layer. So self.lm_head, let's call it, short for language
00:59:34.560 | modeling head, is nn.Linear from nembed up to vocab size. And then when we swing over here,
00:59:41.040 | we're actually going to get the logits by exactly what the Copilot says. Now, we have to be careful
00:59:46.720 | here because this c and this c are not equal. This is nembed c and this is vocab size. So let's just
00:59:55.440 | say that nembed is equal to c. And then this just creates one spurious layer of interaction through
01:00:02.160 | a linear layer, but this should basically run. So we see that this runs and this currently looks
01:00:15.600 | kind of spurious, but we're going to build on top of this. Now, next up. So far, we've taken these
01:00:21.280 | indices and we've encoded them based on the identity of the tokens inside IDX. The next
01:00:28.320 | thing that people very often do is that we're not just encoding the identity of these tokens,
01:00:32.800 | but also their position. So we're going to have a second position embedding table here.
01:00:38.160 | So self.position_embedding_table is an embedding of block size by nembed. And so each position
01:00:44.960 | from zero to block size minus one will also get its own embedding vector. And then here,
01:00:50.400 | first, let me decode b by t from IDX.shape. And then here, we're also going to have a
01:00:56.560 | pos_emb, which is the positional embedding. And this is torch.arange. So this will be
01:01:02.560 | basically just integers from zero to t minus one. And all of those integers from zero to t minus one
01:01:08.160 | get embedded through the table to create a t by c. And then here, this gets renamed to just say x.
01:01:16.160 | And x will be the addition of the token embeddings with the positional embeddings.
01:01:20.560 | And here, note, the broadcasting will work out. So b by t by c plus t by c, this gets right-aligned,
01:01:27.760 | a new dimension of one gets added, and it gets broadcasted across batch.
01:01:31.360 | So at this point, x holds not just the token identities, but the positions at which these
01:01:37.920 | tokens occur. And this is currently not that useful because, of course, we just have a simple
01:01:42.480 | bigram model. So it doesn't matter if you're in the fifth position, the second position, or wherever.
01:01:46.880 | It's all translation invariant at this stage. So this information currently wouldn't help.
01:01:51.200 | But as we work on the self-attention block, we'll see that this starts to matter.
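Putting the last two changes together, the constructor and the start of the forward pass now look roughly like this (a sketch; n_embd and block_size are the global hyperparameters, and the vocab_size value below is just an assumed size for the character vocabulary):

```python
import torch
import torch.nn as nn

n_embd = 32
block_size = 8
vocab_size = 65   # assumed size of the character vocabulary for this dataset

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)   # language modeling head

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)        # (B, T, n_embd)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device))          # (T, n_embd)
        x = tok_emb + pos_emb      # broadcasting adds the positions across the batch
        logits = self.lm_head(x)   # (B, T, vocab_size)
        # ... loss computation continues as before ...
        return logits
```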
01:01:55.520 | Okay, so now we get the crux of self-attention. So this is probably the most important part of
01:02:03.840 | this video to understand. We're going to implement a small self-attention for a single individual
01:02:09.280 | head, as they're called. So we start off with where we were. So all of this code is familiar.
01:02:14.400 | So right now, I'm working with an example where I changed the number of channels from 2 to 32.
01:02:19.920 | So we have a 4 by 8 arrangement of tokens. And the information at each token is currently 32
01:02:27.360 | dimensional. But we just are working with random numbers. Now, we saw here that the code as we had
01:02:34.480 | it before does a simple weight, simple average of all the past tokens and the current token. So it's
01:02:42.800 | just the previous information and current information is just being mixed together in an
01:02:46.080 | average. And that's what this code currently achieves. And it does so by creating this lower
01:02:50.800 | triangular structure, which allows us to mask out this weight matrix that we create. So we mask it
01:02:59.040 | out and then we normalize it. And currently, when we initialize the affinities between all the
01:03:05.040 | different sort of tokens or nodes, I'm going to use those terms interchangeably. So when we
01:03:11.200 | initialize the affinities between all the different tokens to be 0, then we see that
01:03:15.680 | weight gives us this structure where every single row has these uniform numbers. And so that's what
01:03:23.760 | then in this matrix multiply makes it so that we're doing a simple average. Now, we don't actually
01:03:31.360 | want this to be all uniform because different tokens will find different other tokens more or
01:03:38.880 | less interesting. And we want that to be data dependent. So, for example, if I'm a vowel,
01:03:42.880 | then maybe I'm looking for consonants in my past and maybe I want to know what those consonants
01:03:48.000 | are and I want that information to flow to me. And so I want to now gather information from the
01:03:53.520 | past, but I want to do it in a data dependent way. And this is the problem that self-attention
01:03:57.840 | solves. Now, the way self-attention solves this is the following. Every single node or every single
01:04:04.400 | token at each position will emit two vectors. It will emit a query and it will emit a key.
01:04:13.520 | Now, the query vector, roughly speaking, is what am I looking for? And the key vector,
01:04:19.200 | roughly speaking, is what do I contain? And then the way we get affinities between these tokens now
01:04:26.960 | in a sequence is we basically just do a dot product between the keys and the queries. So,
01:04:32.880 | my query dot products with all the keys of all the other tokens and that dot product now becomes
01:04:40.320 | way. And so if the key and the query are sort of aligned, they will interact to a very high amount
01:04:48.880 | and then I will get to learn more about that specific token as opposed to any other token
01:04:55.040 | in the sequence. So, let's implement this now. We're going to implement a single
01:05:04.560 | what's called head of self-attention. So, this is just one head. There's a hyperparameter involved
01:05:10.800 | with these heads, which is the head size. And then here I'm initializing linear modules and
01:05:16.720 | I'm using bias equals false. So, these are just going to apply a matrix multiply with some fixed
01:05:21.040 | weights. And now let me produce a key and a query, k and q, by forwarding these modules on x. So,
01:05:30.960 | the size of this will now become b by t by 16 because that is the head size and the same here,
01:05:38.640 | b by t by 16. So, this being the head size. So, you see here that when I forward this linear
01:05:49.760 | on top of my x, all the tokens in all the positions in the b by t arrangement, all of them
01:05:55.760 | in parallel and independently produce a key and a query. So, no communication has happened yet.
01:06:01.040 | But the communication comes now. All the queries will dot product with all the keys.
01:06:06.960 | So, basically what we want is we want way now or the affinities between these to be query
01:06:14.480 | multiplying key. But we have to be careful with, we can't matrix multiply this. We actually need
01:06:19.440 | to transpose k, but we have to be also careful because these are, when you have the batch
01:06:25.760 | dimension. So, in particular, we want to transpose the last two dimensions, dimension negative one
01:06:32.160 | and dimension negative two. So, negative two, negative one. And so, this matrix multiply now
01:06:39.840 | will basically do the following b by t by 16. Matrix multiplies b by 16 by t to give us b by t
01:06:51.600 | by t. Right? So, for every row of b, we're now going to have a t square matrix giving us the
01:07:00.960 | affinities. And these are now the way. So, they're not zeros. They are now coming from this dot
01:07:07.120 | product between the keys and the queries. So, this can now run. I can run this. And the weighted
01:07:13.760 | aggregation now is a function in a data-dependent manner between the keys and queries of these
01:07:18.960 | nodes. So, just inspecting what happened here, the way takes on this form. And you see that
01:07:27.440 | before way was just a constant. So, it was applied in the same way to all the batch elements. But
01:07:33.040 | now every single batch element will have different sort of way because every single batch element
01:07:38.480 | contains different tokens at different positions. And so, this is now data-dependent. So, when we
01:07:44.480 | look at just the zeroth row, for example, in the input, these are the weights that came out. And
01:07:50.960 | so, you can see now that they're not just exactly uniform. And in particular, as an example here
01:07:56.320 | for the last row, this was the eighth token. And the eighth token knows what content it has, and
01:08:01.680 | it knows at what position it's in. And now the eighth token, based on that, creates a query.
01:08:08.400 | Hey, I'm looking for this kind of stuff. I'm a vowel. I'm on the eighth position. I'm looking
01:08:12.880 | for any consonants at positions up to four. And then all the nodes get to emit keys. And maybe
01:08:19.760 | one of the channels could be I am a consonant, and I am in a position up to four. And that key
01:08:26.080 | would have a high number in that specific channel. And that's how the query and the key when they
01:08:26.080 | dot product, they can find each other and create a high affinity. And when they have a high affinity,
01:08:35.760 | like say this token was pretty interesting to this eighth token, when they have a high affinity,
01:08:43.680 | then through the softmax, I will end up aggregating a lot of its information into my position. And so,
01:08:49.360 | I'll get to learn a lot about it. Now, we're looking at way after this has already happened.
01:08:59.120 | Let me erase this operation as well. So, let me erase the masking and the softmax,
01:09:03.120 | just to show you the under the hood internals and how that works. So, without the masking
01:09:07.840 | and the softmax, way comes out like this, right? These are the outputs of the dot products.
01:09:12.320 | And these are the raw outputs, and they take on values from negative two to positive two,
01:09:17.920 | et cetera. So, that's the raw interactions and raw affinities between all the nodes.
01:09:24.160 | But now, if I'm a fifth node, I will not want to aggregate anything from the sixth node,
01:09:29.360 | seventh node, and the eighth node. So, actually, we mask out the upper triangular part. So, those
01:09:35.200 | are not allowed to communicate. And now, we actually want to have a nice distribution.
01:09:41.680 | So, we don't want to aggregate negative 0.11 of this node. That's crazy. So, instead,
01:09:46.880 | we exponentiate and normalize. And now, we get a nice distribution that sums to one.
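A sketch of the single head up to this point (the value comes next), with the shapes used here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)     # (B, T, 16): every token emits a key, in parallel, no communication yet
q = query(x)   # (B, T, 16): every token emits a query, in parallel

# the communication: every query dot products with every key
wei = q @ k.transpose(-2, -1)    # (B, T, 16) @ (B, 16, T) -> (B, T, T) affinities

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))   # the future cannot communicate with the past
wei = F.softmax(wei, dim=-1)                      # each row is now a data-dependent distribution
```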
01:09:51.440 | And this is telling us now in a data-dependent manner how much of information to aggregate
01:09:55.600 | from any of these tokens in the past. So, that's way, and it's not zeros anymore,
01:10:02.160 | but it's calculated in this way. Now, there's one more part to a single self-attention head.
01:10:10.080 | And that is that when we do the aggregation, we don't actually aggregate the tokens exactly.
01:10:14.240 | We aggregate, we produce one more value here, and we call that the value.
01:10:21.040 | So, in the same way that we produced key and query, we're also going to create a value.
01:10:24.640 | And then, here, we don't aggregate x. We calculate a v, which is just achieved by propagating this
01:10:36.400 | linear on top of x again. And then, we output way multiplied by v. So, v is the elements that
01:10:44.720 | we aggregate, or the vector that we aggregate, instead of the raw x. And now, of course, this
01:10:51.040 | will make it so that the output here of the single head will be 16-dimensional, because that is the
01:10:55.680 | head size. So, you can think of x as kind of like private information to this token,
01:11:01.600 | if you think about it that way. So, x is kind of private to this token. So, I'm a fifth token,
01:11:06.720 | and I have some identity, and my information is kept in vector x. And now, for the purposes of
01:11:14.400 | the single head, here's what I'm interested in, here's what I have, and if you find me interesting,
01:11:21.200 | here's what I will communicate to you. And that's stored in v. And so, v is the thing that gets
01:11:26.320 | aggregated for the purposes of this single head between the different nodes. And that's basically
01:11:33.840 | the self-attention mechanism. This is what it does. There are a few notes that I would like to make
01:11:40.080 | about attention. Number one, attention is a communication mechanism. You can really think
01:11:45.680 | about it as a communication mechanism where you have a number of nodes in a directed graph,
01:11:50.240 | where basically you have edges pointed between nodes like this. And what happens is every node
01:11:56.480 | has some vector of information, and it gets to aggregate information via a weighted sum from all
01:12:02.160 | of the nodes that point to it. And this is done in a data-dependent manner, so depending on whatever
01:12:07.760 | data is actually stored at each node at any point in time. Now, our graph doesn't look like this.
01:12:13.840 | Our graph has a different structure. We have eight nodes because the block size is eight,
01:12:18.240 | and there's always eight tokens. And the first node is only pointed to by itself. The second
01:12:25.040 | node is pointed to by the first node and itself, all the way up to the eighth node, which is pointed
01:12:30.160 | to by all the previous nodes and itself. And so, that's the structure that our directed graph has,
01:12:36.240 | or happens to have, in an autoregressive sort of scenario like language modeling.
01:12:40.400 | But in principle, attention can be applied to any arbitrary directed graph, and it's just a
01:12:44.560 | communication mechanism between the nodes. The second note is that, notice that there's no notion
01:12:49.680 | of space. So, attention simply acts over a set of vectors in this graph. And so, by default,
01:12:56.400 | these nodes have no idea where they are positioned in the space. And that's why we need to encode
01:13:00.480 | them positionally and sort of give them some information that is anchored to a specific
01:13:04.800 | position so that they sort of know where they are. And this is different than, for example,
01:13:09.840 | from convolution, because if you run, for example, a convolution operation over some input,
01:13:13.840 | there is a very specific sort of layout of the information in space, and the convolutional
01:13:19.280 | filters sort of act in space. And so, it's not like an attention. An attention is just a set
01:13:25.920 | of vectors out there in space. They communicate. And if you want them to have a notion of space,
01:13:30.720 | you need to specifically add it, which is what we've done when we calculated the
01:13:34.480 | positional encodings and added that information to the vectors.
01:13:40.160 | The next thing that I hope is very clear is that the elements across the batch dimension,
01:13:44.400 | which are independent examples, never talk to each other. They're always processed independently.
01:13:48.480 | And this is a batched matrix multiply that applies basically a matrix multiplication
01:13:52.400 | kind of in parallel across the batch dimension. So, maybe it would be more accurate to say that
01:13:57.280 | in this analogy of a directed graph, we really have, because the batch size is four,
01:14:02.000 | we really have four separate pools of eight nodes, and those eight nodes only talk to each other.
01:14:07.200 | But in total, there's like 32 nodes that are being processed, but there's sort of four separate
01:14:12.480 | pools of eight. You can look at it that way. The next note is that here in the case of language
01:14:18.240 | modeling, we have this specific structure of directed graph where the future tokens will
01:14:24.320 | not communicate to the past tokens. But this doesn't necessarily have to be the constraint
01:14:28.960 | in the general case. And in fact, in many cases, you may want to have all of the nodes talk to each
01:14:34.800 | other fully. So, as an example, if you're doing sentiment analysis or something like that with
01:14:39.200 | a transformer, you might have a number of tokens and you may want to have them all talk to each
01:14:43.920 | other fully because later you are predicting, for example, the sentiment of the sentence.
01:14:48.320 | And so, it's okay for these nodes to talk to each other. And so, in those cases,
01:14:53.440 | you will use an encoder block of self-attention. And all it means that it's an encoder block is
01:14:59.600 | that you will delete this line of code, allowing all the nodes to completely talk to each other.
01:15:04.560 | What we're implementing here is sometimes called a decoder block. And it's called a decoder
01:15:09.360 | because it is sort of like decoding language. And it's got this autoregressive format where you have
01:15:16.880 | to mask with the triangular matrix so that nodes from the future never talk to the past.
01:15:23.200 | Because they would give away the answer. And so, basically, in encoder blocks, you would delete
01:15:27.920 | this, allow all the nodes to talk. In decoder blocks, this will always be present so that you
01:15:33.040 | have this triangular structure. But both are allowed and attention doesn't care. Attention
01:15:37.280 | supports arbitrary connectivity between nodes. The next thing I wanted to comment on is you keep
01:15:41.680 | hearing me say attention, self-attention, etc. There's actually also something called cross
01:15:46.400 | attention. What is the difference? Basically, the reason this attention is self-attention
01:15:53.760 | is because the keys, queries, and the values are all coming from the same source, from x.
01:16:00.240 | So the same source, x, produces keys, queries, and values. So these nodes are self-attending.
01:16:06.240 | But in principle, attention is much more general than that. For example, in encoder-decoder
01:16:11.360 | transformers, you can have a case where the queries are produced from x, but the keys and
01:16:17.120 | the values come from a whole separate external source. And sometimes from encoder blocks that
01:16:22.640 | encode some context that we'd like to condition on. And so the keys and the values will actually
01:16:26.960 | come from a whole separate source. Those are nodes on the side. And here we're just producing
01:16:31.760 | queries and we're reading off information from the side. So cross attention is used when there's a
01:16:38.080 | separate source of nodes we'd like to pull information from into our nodes. And it's
01:16:44.080 | self-attention if we just have nodes that would like to look at each other and talk to each other.
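A sketch of that distinction in code terms (skipping the linear projections for brevity; context here is just an assumed name standing in for, say, encoder outputs):

```python
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # generic (unmasked) attention: aggregate v according to query-key affinities
    wei = F.softmax(q @ k.transpose(-2, -1), dim=-1)
    return wei @ v

x = torch.randn(4, 8, 16)          # our own tokens
context = torch.randn(4, 10, 16)   # a separate external source of nodes

self_out = attention(x, x, x)               # self-attention: q, k, v all derived from x
cross_out = attention(x, context, context)  # cross-attention: q from x, k and v from the side
```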
01:16:47.440 | So this attention here happens to be self-attention. But in principle,
01:16:53.440 | attention is a lot more general. Okay, and the last note at this stage is if we come to the
01:16:59.520 | attention is all you need paper here. We've already implemented attention. So given query,
01:17:03.840 | key and value, we've multiplied the query on the key. We've soft-maxed it. And then we are
01:17:09.680 | aggregating the values. There's one more thing that we're missing here, which is the dividing
01:17:13.680 | by one over square root of the head size. The decay here is the head size. Why are they doing
01:17:18.960 | this? Why is this important? So they call it a scaled attention. And it's kind of like an
01:17:24.880 | important normalization to basically have. The problem is if you have unit Gaussian inputs,
01:17:30.160 | so zero mean unit variance, k and q are unit Gaussian. And if you just do weigh naively,
01:17:35.520 | then you see that your weigh actually will be, the variance will be on the order of head size,
01:17:40.000 | which in our case is 16. But if you multiply by one over head size square root, so this is square
01:17:46.000 | root and this is one over, then the variance of weigh will be one. So it will be preserved.
01:17:51.440 | Now, why is this important? You'll notice that weigh here will feed into soft-max.
01:17:59.520 | And so it's really important, especially at initialization, that weigh be fairly diffuse.
01:18:03.920 | So in our case here, we sort of locked out here and weigh had a fairly diffuse numbers here. So
01:18:11.760 | like this. Now, the problem is that because of soft-max, if weigh takes on very positive and
01:18:18.160 | very negative numbers inside it, soft-max will actually converge towards one-hot vectors.
01:18:24.160 | And so I can illustrate that here. Say we are applying soft-max to a tensor of values that
01:18:30.880 | are very close to zero, then we're going to get a diffuse thing out of soft-max. But the moment I
01:18:36.000 | take the exact same thing and I start sharpening it and making it bigger by multiplying these
01:18:40.000 | numbers by eight, for example, you'll see that the soft-max will start to sharpen. And in fact,
01:18:44.560 | it will sharpen towards the max. So it will sharpen towards whatever number here is the highest.
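A quick sketch of both effects (the numbers in the toy tensor are made up):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)

# the variance of the raw affinities grows with head_size...
B, T, head_size = 4, 8, 16
k = torch.randn(B, T, head_size)   # unit Gaussian keys
q = torch.randn(B, T, head_size)   # unit Gaussian queries
wei = q @ k.transpose(-2, -1)
print(wei.var())                        # roughly head_size, i.e. ~16
print((wei * head_size**-0.5).var())    # roughly 1 after scaling

# ...and softmax sharpens towards a one-hot as its inputs get more extreme
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(F.softmax(logits, dim=-1))        # fairly diffuse
print(F.softmax(logits * 8, dim=-1))    # much peakier, concentrating on the max
```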
01:18:50.080 | And so basically, we don't want these values to be too extreme, especially at initialization.
01:18:54.480 | Otherwise, soft-max will be way too peaky. And you're basically aggregating information from
01:19:00.800 | like a single node. Every node just aggregates information from a single other node. That's
01:19:04.720 | not what we want, especially at initialization. And so the scaling is used just to control the
01:19:09.840 | variance at initialization. Okay, so having said all that, let's now take our self-attention
01:19:15.040 | knowledge and let's take it for a spin. So here in the code, I've created this head module and
01:19:21.120 | implements a single head of self-attention. So you give it a head size, and then here it creates the
01:19:26.720 | key query and the value linear layers. Typically, people don't use biases in these. So those are
01:19:32.640 | the linear projections that we're going to apply to all of our nodes. Now here, I'm creating this
01:19:37.760 | tril variable. tril is not a parameter of the module. So in sort of PyTorch naming conventions,
01:19:43.360 | this is called a buffer. It's not a parameter. And you have to assign it to the module using
01:19:48.240 | register_buffer. So that creates the tril, the lower triangular matrix. And when we're given the
01:19:54.960 | input x, this should look very familiar now. We calculate the keys, the queries. We calculate the
01:20:00.240 | attention scores inside wei. We normalize it. So we're using scaled attention here.
01:20:05.280 | Then we make sure that future doesn't communicate with the past. So this makes it a decoder block.
01:20:11.920 | And then softmax, and then aggregate the value and output. Then here in the language model,
01:20:17.520 | I'm creating a head in the constructor, and I'm calling it self-attention head.
01:20:21.760 | And the head size, I'm going to keep the same as n_embd, just for now. And then here,
01:20:29.040 | once we've encoded the information with the token embeddings and the position embeddings,
01:20:34.160 | we're simply going to feed it into the self-attention head. And then the output of
01:20:38.080 | that is going to go into the decoder language modeling head and create the logits. So this is
01:20:44.720 | the simplest way to plug in a self-attention component into our network right now.
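A sketch of that Head module and how it plugs into the model (n_embd and block_size are the globals from earlier; dropout comes later):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One head of self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a parameter, so it is registered as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                     # (B, T, head_size)
        q = self.query(x)                                   # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5   # scaled attention scores (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # decoder block: no peeking ahead
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)                                   # (B, T, head_size)
        return wei @ v                                      # (B, T, head_size)

# in the language model:
#   self.sa_head = Head(n_embd)   # head_size kept equal to n_embd for now
#   ... in forward: x = tok_emb + pos_emb; x = self.sa_head(x); logits = self.lm_head(x)
# and in generate, the context is cropped to the last block_size tokens, e.g.
#   idx_cond = idx[:, -block_size:]
```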
01:20:50.160 | I had to make one more change, which is that here in the generate, we have to make sure that our
01:20:58.240 | IDX that we feed into the model, because now we're using positional embeddings,
01:21:02.800 | we can never have more than block size coming in. Because if IDX is more than block size,
01:21:08.720 | then our position embedding table is going to run out of scope, because it only has embeddings for
01:21:12.640 | up to block size. And so therefore, I added some code here to crop the context that we're going to
01:21:18.560 | feed into self, so that we never pass in more than block size elements. So those are the changes,
01:21:26.400 | and let's now train the network. So I also came up to the script here, and I decreased the learning
01:21:31.120 | rate, because the self-attention can't tolerate very, very high learning rates. And then I also
01:21:36.560 | increased the number of iterations, because the learning rate is lower. And then I trained it,
01:21:40.080 | and previously we were only able to get to up to 2.5, and now we are down to 2.4. So we definitely
01:21:45.840 | see a little bit of an improvement from 2.5 to 2.4, roughly, but the text is still not amazing.
01:21:51.280 | So clearly, the self-attention head is doing some useful communication,
01:21:56.320 | but we still have a long way to go. Okay, so now we've implemented the scaled dot-product attention.
01:22:01.920 | Now next up, in the attention is all you need paper, there's something called multi-head
01:22:06.000 | attention. And what is multi-head attention? It's just applying multiple attentions in parallel,
01:22:11.760 | and concatenating their results. So they have a little bit of diagram here. I don't know if this
01:22:16.720 | is super clear. It's really just multiple attentions in parallel. So let's implement that.
01:22:23.120 | Fairly straightforward. If we want a multi-head attention, then we want multiple heads of
01:22:28.160 | self-attention running in parallel. So in PyTorch, we can do this by simply creating multiple heads.
01:22:34.720 | So however many heads you want, and then what is the head size of each. And then we run all of them
01:22:43.280 | in parallel into a list, and simply concatenate all of the outputs. And we're concatenating over
01:22:49.440 | the channel dimension. So the way this looks now is, we don't have just a single attention
01:22:54.720 | that has a head size of 32, because remember, an embed is 32. Instead of having one communication
01:23:03.680 | channel, we now have four communication channels in parallel. And each one of these communication
01:23:09.360 | channels typically will be smaller correspondingly. So because we have four communication channels,
01:23:16.640 | we want eight-dimensional self-attention heads. And so from each communication channel, we're
01:23:21.120 | gathering eight-dimensional vectors. And then we have four of them, and that concatenates to give
01:23:25.840 | us 32, which is the original and embed. And so this is kind of similar to, if you're familiar
01:23:31.440 | with convolutions, this is kind of like a group convolution. Because basically, instead of having
01:23:36.080 | one large convolution, we do convolution in groups, and that's multi-headed self-attention.
01:23:42.400 | And so then here, we just use SA heads, self-attention heads, instead. Now, I actually
01:23:48.320 | ran it, and scrolling down, I ran the same thing, and then we now get down to 2.28, roughly. And
01:23:57.040 | the output is still, the generation is still not amazing, but clearly the validation loss is
01:24:01.200 | improving, because we were at 2.4 just now. And so it helps to have multiple communication channels,
01:24:07.120 | because obviously, these tokens have a lot to talk about. They want to find the consonants,
01:24:11.920 | the vowels, they want to find the vowels just from certain positions, they want to find any kinds of
01:24:17.200 | different things. And so it helps to create multiple independent channels of communication,
01:24:21.440 | gather lots of different types of data, and then decode the output.
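A sketch of multi-head attention as it stands at this point (the projection layer gets added a little later, together with the residual connections):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention running in parallel."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        # run every head independently and concatenate over the channel dimension
        return torch.cat([h(x) for h in self.heads], dim=-1)

# in the language model: four communication channels of 8 dimensions each, 4 * 8 = 32 = n_embd
#   self.sa_heads = MultiHeadAttention(4, n_embd // 4)
```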
01:24:25.920 | Now, going back to the paper for a second, of course, I didn't explain this figure in full
01:24:29.440 | detail, but we are starting to see some components of what we've already implemented. We have the
01:24:33.520 | positional encodings, the token encodings that add, we have the masked multi-headed attention
01:24:38.560 | implemented. Now, here's another multi-headed attention, which is a cross-attention to an
01:24:43.680 | encoder, which we haven't, we're not going to implement in this case. I'm going to come back
01:24:47.680 | to that later. But I want you to notice that there's a feedforward part here, and then this
01:24:52.560 | is grouped into a block that gets repeated again and again. Now, the feedforward part here is just
01:24:57.280 | a simple multi-layer perceptron. So here, the position-wise feedforward
01:25:04.960 | network is just a simple little MLP. So I want to start basically in a similar fashion, also
01:25:10.560 | adding computation into the network. And this computation is on a per node level. So I've
01:25:17.440 | already implemented it, and you can see the diff highlighted on the left here when I've added or
01:25:21.440 | changed things. Now, before we had the multi-headed self-attention that did the communication,
01:25:26.640 | but we went way too fast to calculate the logits. So the tokens looked at each other, but didn't
01:25:32.560 | really have a lot of time to think on what they found from the other tokens. And so what I've
01:25:38.880 | implemented here is a little feedforward single layer. And this little layer is just a linear
01:25:44.240 | followed by a ReLU non-linearity, and that's it. So it's just a little layer, and then I call it
01:25:51.120 | feedforward and embed. And then this feedforward is just called sequentially right after the
01:25:57.680 | self-attention. So we self-attend, then we feedforward. And you'll notice that the feedforward
01:26:03.040 | here, when it's applying linear, this is on a per token level. All the tokens do this independently.
01:26:08.560 | So the self-attention is the communication, and then once they've gathered all the data,
01:26:12.960 | now they need to think on that data individually. And so that's what feedforward is doing,
01:26:17.920 | and that's why I've added it here. Now, when I train this, the validation loss actually continues
01:26:22.720 | to go down, now to 2.24, which is down from 2.28. The output still looks kind of terrible,
01:26:29.440 | but at least we've improved the situation. And so as a preview, we're going to now start to
01:26:35.120 | intersperse the communication with the computation. And that's also what the transformer does when it
01:26:42.320 | has blocks that communicate and then compute, and it groups them and replicates them.
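A sketch of that little feedforward layer as it looks at this stage (the 4x expansion and the projection back into the residual pathway come a couple of steps later):

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """A simple per-token linear layer followed by a non-linearity."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.ReLU(),
        )

    def forward(self, x):
        # applied to every token position independently
        return self.net(x)

# in the model's forward: x = self.sa_heads(x); x = self.ffwd(x); logits = self.lm_head(x)
```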
01:26:47.360 | Okay, so let me show you what we'd like to do. We'd like to do something like this. We have
01:26:52.560 | a block, and this block is basically this part here, except for the cross-attention.
01:26:57.280 | Now, the block basically intersperses communication and then computation.
01:27:01.920 | The communication is done using multi-headed self-attention, and then the computation is
01:27:07.600 | done using a feedforward network on all the tokens independently. Now, what I've added
01:27:13.840 | here also is, you'll notice, this takes the number of embeddings in the embedding dimension and the
01:27:19.600 | number of heads that we would like, which is kind of like group size in group convolution.
01:27:23.280 | And I'm saying that the number of heads we'd like is four. And so because this is 32,
01:27:28.720 | and the number of heads is four, we calculate that the head size should be eight,
01:27:31.920 | so that everything sort of works out channel-wise.
01:27:37.520 | So this is how the transformer structures the sizes, typically. So the head size will become
01:27:45.120 | eight, and then this is how we want to intersperse them. And then here, I'm trying to create blocks,
01:27:49.920 | which is just a sequential application of block, block, block. So then we're interspersing
01:27:54.800 | communication feedforward many, many times, and then finally we decode. Now, I actually tried to
01:28:00.720 | run this, and the problem is, this doesn't actually give a very good result. And the reason for that
01:28:07.520 | is, we're starting to actually get a pretty deep neural net. And deep neural nets suffer
01:28:12.480 | from optimization issues, and I think that's what we're kind of like slightly starting to run into.
01:28:16.560 | So we need one more idea that we can borrow from the transformer paper to resolve those
01:28:21.760 | difficulties. Now, there are two optimizations that dramatically help with the depth of these
01:28:26.400 | networks and make sure that the networks remain optimizable. Let's talk about the first one.
01:28:31.120 | The first one in this diagram is, you see this arrow here, and then this arrow and this arrow,
01:28:37.520 | those are skip connections, or sometimes called residual connections. They come from this paper,
01:28:42.400 | Deep Residual Learning for Image Recognition, from about 2015, that introduced the concept.
01:28:48.960 | Now, these are basically, what it means is, you transform the data, but then you have a skip
01:28:55.120 | connection with addition from the previous features. Now, the way I like to visualize it,
01:29:00.720 | that I prefer, is the following. Here, the computation happens from the top to bottom,
01:29:07.440 | and basically, you have this residual pathway, and you are free to fork off from the residual
01:29:12.880 | pathway, perform some computation, and then project back to the residual pathway via addition.
01:29:17.600 | And so you go from the inputs to the targets only via plus, and plus, and plus. And the reason this
01:29:25.920 | is useful is because during backpropagation, remember from our micrograd video earlier,
01:29:30.480 | addition distributes gradients equally to both of its branches that fed us the input.
01:29:36.960 | And so the supervision, or the gradients from the loss, basically hop through every addition node
01:29:44.640 | all the way to the input, and then also fork off into the residual blocks. But basically,
01:29:51.680 | you have this gradient superhighway that goes directly from the supervision all the way to
01:29:56.000 | the input, unimpeded. And then these residual blocks are usually initialized in the beginning,
01:30:01.040 | so they contribute very, very little, if anything, to the residual pathway. They are initialized
01:30:05.680 | that way. So in the beginning, they are almost kind of like not there. But then during the
01:30:10.480 | optimization, they come online over time, and they start to contribute. But at least at the
01:30:17.040 | initialization, you can go from directly supervision to the input, gradient is unimpeded and just
01:30:22.080 | flows, and then the blocks over time kick in. And so that dramatically helps with the optimization.
01:30:28.480 | So let's implement this. So coming back to our block here, basically what we want to do is
01:30:33.360 | we want to do x equals x plus self-attention of x, and x equals x plus self dot feedforward of x.
01:30:39.440 | So this is x, and then we fork off and do some communication and come back. And we fork off,
01:30:46.400 | and we do some computation and come back. So those are residual connections. And then
01:30:51.440 | swinging back up here, we also have to introduce this projection. So nn.Linear.
01:30:58.480 | And this is going to be from after we concatenate this. This is the size and embed. So this is the
01:31:05.360 | output of the self-attention itself. But then we actually want to apply the projection,
01:31:11.040 | and that's the result. So the projection is just a linear transformation of the outcome of this
01:31:16.800 | layer. So that's the projection back into the residual pathway. And then here in the feed
01:31:22.560 | forward, it's going to be the same thing. I could have a self-dot projection here as well. But let
01:31:27.920 | me just simplify it, and let me couple it inside the same sequential container. And so this is the
01:31:35.360 | projection layer going back into the residual pathway. And so that's it. So now we can train
01:31:43.280 | this. So I implemented one more small change. When you look into the paper again, you see that
01:31:49.280 | the dimensionality of input and output is 512 for them. And they're saying that the inner layer
01:31:54.080 | here in the feed forward has dimensionality of 2048. So there's a multiplier of 4. And so the
01:32:00.080 | inner layer of the feed forward network should be multiplied by 4 in terms of channel sizes.
01:32:05.040 | So I came here, and I multiplied 4 times embed here for the feed forward, and then from 4 times
01:32:10.560 | nembed coming back down to nembed when we go back to the projection. So adding a bit of computation
01:32:16.560 | here and growing that layer that is in the residual block on the side of the residual pathway.
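Roughly, the pieces now look like this (a sketch showing the updated MultiHeadAttention and FeedForward with the projections and the 4x inner layer, plus the Block; the Head module is unchanged and layer norm is still to come):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)   # projection back into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),   # inner layer is 4x wider, as in the paper
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),   # projection back into the residual pathway
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (feedforward)."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        x = x + self.sa(x)     # fork off, communicate, add back into the residual pathway
        x = x + self.ffwd(x)   # fork off, compute per token, add back
        return x
```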
01:32:22.960 | And then I train this, and we actually get down all the way to 2.08 validation loss. And we also
01:32:28.480 | see that the network is starting to get big enough that our train loss is getting ahead
01:32:32.000 | of validation loss. So we start to see a little bit of overfitting. And our generations here are
01:32:40.720 | still not amazing, but at least you see that we can see like is here, this now, grief, sink.
01:32:45.440 | Like this starts to almost look like English. So yeah, we're starting to really get there.
01:32:51.520 | Okay, and the second innovation that is very helpful for optimizing very deep neural networks
01:32:55.600 | is right here. So we have this addition now, that's the residual part. But this norm is referring to
01:33:00.480 | something called layer norm. So layer norm is implemented in PyTorch. It's a paper that came out
01:33:05.360 | a while back here. And layer norm is very, very similar to batch norm. So remember back to
01:33:13.600 | our Make More Series part three, we implemented batch normalization.
01:33:18.560 | And batch normalization basically just made sure that across the batch dimension,
01:33:24.080 | any individual neuron had unit Gaussian distribution. So it was zero mean and unit
01:33:31.840 | standard deviation, one standard deviation output. So what I did here is I'm copy pasting the batch
01:33:38.080 | norm 1D that we developed in our Make More Series. And see here, we can initialize, for example,
01:33:43.680 | this module, and we can have a batch of 32 100 dimensional vectors feeding through the batch
01:33:48.800 | norm layer. So what this does is it guarantees that when we look at just the zeroth column,
01:33:55.520 | it's a zero mean, one standard deviation. So it's normalizing every single column of this input.
01:34:02.880 | Now the rows are not going to be normalized by default, because we're just normalizing columns.
01:34:09.600 | So let's now implement layer norm. It's very complicated. Look, we come here,
01:34:14.960 | we change this from zero to one. So we don't normalize the columns, we normalize the rows.
01:34:20.880 | And now we've implemented layer norm. So now the columns are not going to be normalized.
01:34:28.080 | But the rows are going to be normalized for every individual example, it's 100 dimensional vector
01:34:34.960 | is normalized in this way. And because our computation now does not span across examples,
01:34:40.640 | we can delete all of this buffers stuff, because we can always apply this operation,
01:34:48.080 | and don't need to maintain any running buffers. So we don't need the buffers.
01:34:52.080 | We don't, there's no distinction between training and test time.
01:34:57.280 | And we don't need these running buffers. We do keep gamma and beta, we don't need the momentum,
01:35:04.800 | we don't care if it's training or not. And this is now a layer norm. And it normalizes
01:35:12.080 | the rows instead of the columns. And this here is identical to basically this here.
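A sketch of that modified module (the batch norm from the makemore series with the normalization flipped from dimension 0 to dimension 1 and the running buffers removed):

```python
import torch

class LayerNorm1d:
    """Like batch norm, but normalizes each row (each example) instead of each column."""

    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        self.gamma = torch.ones(dim)    # trainable scale
        self.beta = torch.zeros(dim)    # trainable shift

    def __call__(self, x):
        xmean = x.mean(1, keepdim=True)    # mean over the feature dimension, per example
        xvar = x.var(1, keepdim=True)      # variance over the feature dimension, per example
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100)                 # batch of 32 hundred-dimensional vectors
x = module(x)
print(x[0, :].mean(), x[0, :].std())     # each row is now ~0 mean, ~1 std
print(x[:, 0].mean(), x[:, 0].std())     # the columns are not normalized
```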
01:35:18.240 | So let's now implement layer norm in our transformer. Before I incorporate the layer norm,
01:35:23.760 | I just wanted to note that, as I said, very few details about the transformer have changed in the
01:35:28.160 | last five years. But this is actually something that slightly departs from the original paper.
01:35:32.480 | You see that the add and norm is applied after the transformation. But now it is a bit more,
01:35:40.400 | basically, common to apply the layer norm before the transformation. So there's a reshuffling of
01:35:45.360 | the layer norms. So this is called the pre-norm formulation, and that's the one that we're going
01:35:49.680 | to implement as well. So slight deviation from the original paper. Basically, we need two layer
01:35:54.320 | norms. Layer norm one is nn.LayerNorm, and we tell it how many, what is the embedding dimension.
01:36:02.640 | And we need the second layer norm. And then here, the layer norms are applied immediately on x.
01:36:08.400 | So self.layernorm1 applied on x, and self.layernorm2 applied on x, before it goes into
01:36:16.240 | self-attention and feedforward. And the size of the layer norm here is n_embd, so 32.
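So the block's forward pass, in this pre-norm form, is roughly (a sketch, reusing the MultiHeadAttention and FeedForward from above):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # pre-norm: the layer norm is applied before each transformation, not after
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
```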
01:36:22.320 | So when the layer norm is normalizing our features, the normalization here
01:36:30.160 | happens such that the mean and the variance are taken over 32 numbers. So the batch and the time act as batch
01:36:36.320 | dimensions, both of them. So this is kind of like a per-token transformation that just normalizes
01:36:42.720 | the features and makes them unit mean, unit Gaussian at initialization. But of course,
01:36:49.680 | because these layer norms inside it have these gamma and beta trainable parameters,
01:36:55.200 | the layer norm will eventually create outputs that might not be unit Gaussian,
01:37:00.640 | but the optimization will determine that. So for now, this is incorporating the layer norms,
01:37:06.720 | and let's train them up. Okay, so I let it run, and we see that we get down to 2.06,
01:37:11.840 | which is better than the previous 2.08. So a slight improvement by adding the layer norms.
01:37:16.560 | And I'd expect that they help even more if we had a bigger and deeper network.
01:37:20.880 | One more thing I forgot to add is that there should be a layer norm here also typically,
01:37:25.280 | at the end of the transformer, right before the final linear layer that decodes into the vocabulary.
01:37:32.560 | So I added that as well. So at this stage, we actually have a pretty complete transformer
01:37:37.360 | according to the original paper, and it's a decoder-only transformer. I'll talk about that
01:37:42.000 | in a second. But at this stage, the major pieces are in place, so we can try to scale this up and
01:37:47.040 | see how well we can push this number. Now, in order to scale up the model, I had to perform
01:37:51.600 | some cosmetic changes here to make it nicer. So I introduced this variable called n_layer,
01:37:56.640 | which just specifies how many layers of the blocks we're going to have. I create a bunch
01:38:01.600 | of blocks, and we have a new variable, number of heads as well. I pulled out the layer norm here,
01:38:07.040 | and so this is identical. Now, one thing that I did briefly change is I added dropout. So dropout
01:38:14.000 | is something that you can add right before the residual connection back into the residual pathway.
01:38:20.880 | So we can drop out that as the last layer here. We can drop out here at the end of the
01:38:27.200 | multi-headed attention as well. And we can also drop out here when we calculate the
01:38:32.880 | basically affinities, and after the softmax, we can drop out some of those. So we can randomly
01:38:39.200 | prevent some of the nodes from communicating. And so dropout comes from this paper from 2014 or so,
01:38:46.720 | and basically it takes your neural net, and it randomly, every forward-backward pass,
01:38:54.080 | shuts off some subset of neurons. So randomly drops them to zero and trains without them.
01:39:01.520 | And what this does effectively is because the mask of what's being dropped out has changed
01:39:06.880 | every single forward-backward pass, it ends up kind of training an ensemble of subnetworks.
01:39:12.400 | And then at test time, everything is fully enabled and kind of all of those subnetworks
01:39:16.720 | are merged into a single ensemble, if you want to think about it that way.
01:39:20.160 | So I would read the paper to get the full detail. For now, we're just going to stay on the level of
01:39:25.200 | this is a regularization technique, and I added it because I'm about to scale up the model quite a
01:39:30.320 | bit, and I was concerned about overfitting. So now when we scroll up to the top, we'll see that
01:39:36.400 | I changed a number of hyperparameters here about our neural net. So I made the batch size be much
01:39:41.040 | larger, now it's 64. I changed the block size to be 256, so previously it was just 8 characters
01:39:47.440 | of context. Now it is 256 characters of context to predict the 257th. I brought down the learning
01:39:55.440 | rate a little bit because the neural net is now much bigger, so I brought down the learning rate.
01:40:00.080 | The embedding dimension is now 384, and there are six heads. So 384 divide 6 means that every head
01:40:07.200 | is 64-dimensional as a standard. And then there are going to be six layers of that,
01:40:13.120 | and the dropout will be at 0.2. So every forward-backward pass, 20% of all these
01:40:18.240 | intermediate calculations are disabled and dropped to zero.
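Those settings, roughly as they sit at the top of the script (a sketch; the exact lowered learning rate is not stated here, so the 3e-4 below is an assumption):

```python
# hyperparameters for the scaled-up model
batch_size = 64       # independent sequences processed in parallel
block_size = 256      # context length: 256 characters of context to predict the 257th
learning_rate = 3e-4  # brought down because the network is much bigger (assumed value)
n_embd = 384          # embedding dimension
n_head = 6            # 384 / 6 = 64 dimensions per head
n_layer = 6           # number of transformer blocks
dropout = 0.2         # 20% of intermediate activations dropped each forward/backward pass
```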
01:40:22.880 | And then I already trained this and I ran it, so drumroll, how well does it perform?
01:40:29.440 | So let me just scroll up here. We get a validation loss of 1.48, which is actually quite a bit of an
01:40:37.040 | improvement on what we had before, which I think was 2.07. So we went from 2.07 all the way down
01:40:42.000 | to 1.48 just by scaling up this neural net with the code that we have. And this, of course, ran
01:40:47.120 | for a lot longer. This may be trained for, I want to say, about 15 minutes on my A100 GPU, so that's
01:40:52.960 | a pretty good GPU. And if you don't have a GPU, you're not going to be able to reproduce this.
01:40:57.600 | On a CPU, this would be, I would not run this on a CPU or a MacBook or something like that.
01:41:02.640 | You'll have to bring down the number of layers and the embedding dimension and so on.
01:41:07.040 | But in about 15 minutes, we can get this kind of a result. And I'm printing some of the Shakespeare
01:41:14.640 | here, but what I did also is I printed 10,000 characters, so a lot more, and I wrote them to
01:41:19.040 | a file. And so here we see some of the outputs. So it's a lot more recognizable as the input text
01:41:27.360 | file. So the input text file, just for reference, looked like this. So there's always someone
01:41:32.960 | speaking in this manner, and our predictions now take on that form. Except, of course, they're
01:41:40.720 | nonsensical when you actually read them. So it is, "Every crimpty be a house. Oh, those probation."
01:41:49.440 | "We give heed."
01:41:50.240 | "Oho, sent me you mighty lord."
01:41:57.840 | Anyway, so you can read through this. It's nonsensical, of course, but this is just a
01:42:05.120 | transformer trained on the character level for 1 million characters that come from Shakespeare.
01:42:10.080 | So it sort of just blabbers on in a Shakespeare-like manner, but it doesn't,
01:42:14.160 | of course, make sense at this scale. But I think it's still a pretty good demonstration
01:42:19.520 | of what's possible. So now I think that kind of concludes the programming section of this video.
01:42:28.320 | We basically kind of did a pretty good job of implementing this transformer, but the picture
01:42:34.880 | doesn't exactly match up to what we've done. So what's going on with all these additional parts
01:42:39.040 | here? So let me finish explaining this architecture and why it looks so funky. Basically, what's
01:42:44.400 | happening here is what we implemented here is a decoder-only transformer. So there's no component
01:42:50.640 | here. This part is called the encoder, and there's no cross-attention block here. Our block only has
01:42:56.720 | a self-attention and the feedforward, so it is missing this third in-between piece here. This
01:43:02.960 | piece does cross-attention. So we don't have it, and we don't have the encoder. We just have the
01:43:07.200 | decoder. And the reason we have a decoder only is because we are just generating text, and it's
01:43:13.360 | unconditioned on anything. We're just blabbering on according to a given dataset. What makes it a
01:43:19.040 | decoder is that we are using the triangular mask in our transformer. So it has this autoregressive
01:43:25.040 | property where we can just go and sample from it. So the fact that it's using the triangular mask
01:43:31.440 | to mask out the attention makes it a decoder, and it can be used for language modeling.
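Concretely, that triangular mask is the same pattern we built earlier in the video; as a minimal self-contained sketch (the shapes and names here are illustrative):

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the triangular ("causal") masking that makes this a decoder.
B, T, head_size = 4, 8, 16
q = torch.randn(B, T, head_size)                   # queries
k = torch.randn(B, T, head_size)                   # keys

wei = q @ k.transpose(-2, -1) * head_size**-0.5    # (B, T, T) scaled affinities
tril = torch.tril(torch.ones(T, T))                # lower-triangular matrix
wei = wei.masked_fill(tril == 0, float('-inf'))    # a token cannot attend to the future
wei = F.softmax(wei, dim=-1)                       # each row: weights over past positions only
```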
01:43:36.560 | Now, the reason that the original paper had an encoder-decoder architecture is because it is a
01:43:41.520 | machine translation paper. So it is concerned with a different setting in particular. It expects some
01:43:48.240 | tokens that encode, say for example, French, and then it is expected to decode the translation
01:43:54.560 | in English. So typically, these here are special tokens. So you are expected to read in this and
01:44:01.760 | condition on it. And then you start off the generation with a special token called start.
01:44:06.560 | So this is a special new token that you introduce and always place in the beginning. And then the
01:44:12.800 | network is expected to output neural networks are awesome, and then a special end token to finish
01:44:18.800 | the generation. So this part here will be decoded exactly as we've done it. Neural networks are
01:44:26.080 | awesome will be identical to what we did. But unlike what we did, they want to condition the
01:44:32.240 | generation on some additional information. And in that case, this additional information is the
01:44:37.680 | French sentence that they should be translating. So what they do now is they bring the encoder.
01:44:43.680 | Now the encoder reads this part here. So we're only going to take the part of French, and we're
01:44:50.400 | going to create tokens from it exactly as we've seen in our video. And we're going to put a
01:44:55.920 | transformer on it. But there's going to be no triangular mask. And so all the tokens are allowed
01:45:00.640 | to talk to each other as much as they want. And they're just encoding whatever's the content of
01:45:05.440 | this French sentence. Once they've encoded it, they basically come out at the top here.
01:45:12.480 | And then what happens here is in our decoder, which does the language modeling, there's an
01:45:18.960 | additional connection here to the outputs of the encoder. And that is brought in through
01:45:24.640 | cross-attention. So the queries are still generated from x, but now the keys and the values
01:45:30.800 | are coming from the side: they are generated by the nodes
01:45:36.320 | at the top of the encoder. And those keys and values
01:45:41.840 | feed in from the side into every single block of the decoder. And so that's why there's an
01:45:48.400 | additional cross attention. And really what it's doing is it's conditioning the decoding,
01:45:53.520 | not just on the past of this current decoding, but also on the fully encoded
01:46:01.440 | French prompt. And so it's an encoder-decoder model, which is why we have those two
01:46:07.360 | transformers, the additional block, and so on. So we did not do this because we have
01:46:12.480 | nothing to encode; there's no conditioning. We just have a text file, and we just want to imitate it.
01:46:16.640 | And that's why we are using a decoder only transformer, exactly as done in GPT.
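We never built cross-attention in this video, but as a sketch of the idea just described, the only real change from self-attention is where the keys and values come from; everything here (the class name, enc_out, the dimensions) is illustrative, not code from the lecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionHead(nn.Module):
    """One head of cross-attention: queries from the decoder, keys/values from the encoder.
    Illustrative sketch only; not part of the code we built in this video."""
    def __init__(self, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

    def forward(self, x, enc_out):
        q = self.query(x)        # (B, T_dec, head_size) from the decoder stream
        k = self.key(enc_out)    # (B, T_enc, head_size) from the encoder output
        v = self.value(enc_out)  # (B, T_enc, head_size) from the encoder output
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T_dec, T_enc)
        wei = F.softmax(wei, dim=-1)  # no triangular mask: all encoder tokens are visible
        return wei @ v           # (B, T_dec, head_size)

x = torch.randn(2, 5, 32)        # decoder activations (B, T_dec, n_embd)
enc_out = torch.randn(2, 7, 32)  # encoder outputs     (B, T_enc, n_embd)
out = CrossAttentionHead(n_embd=32, head_size=16)(x, enc_out)  # (2, 5, 16)
```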
01:46:21.440 | Okay, so now I wanted to do a very brief walkthrough of nanoGPT, which you can find
01:46:26.560 | on my GitHub. nanoGPT is basically two files of interest: there's train.py and model.py.
01:46:33.040 | train.py has all the boilerplate code for training the network. It is basically all the
01:46:38.320 | stuff that we had here in the training loop. It's just that it's a lot more complicated,
01:46:43.280 | because we're saving and loading checkpoints and pretrained weights, and we are
01:46:47.120 | decaying the learning rate, compiling the model, and using distributed training across
01:46:51.520 | multiple nodes or GPUs.
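Just as a rough illustration of what those extras look like in generic PyTorch (this is a sketch under assumptions, not nanoGPT's actual train.py code; the model and optimizer here are dummies):

```python
import torch
import torch.nn as nn

# Dummy model/optimizer, standing in for the real GPT and its AdamW optimizer.
model = nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
max_iters = 1000

# learning rate decay: e.g. a cosine schedule over training
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters)

# checkpointing: save model and optimizer state so training can resume
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'ckpt.pt')
checkpoint = torch.load('ckpt.pt')
model.load_state_dict(checkpoint['model'])

# compiling the model for speed (PyTorch 2.0+)
model = torch.compile(model)

# distributed training would additionally wrap the model, roughly:
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
# after torch.distributed.init_process_group(...) has been called.
```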
01:46:56.720 | So train.py gets a little bit more hairy and complicated; there are more options, etc. But model.py should look very, very similar to what we've done
01:47:03.520 | here. In fact, the model is almost identical. So first, here we have the causal self attention
01:47:10.080 | block. And all of this should look very, very recognizable to you. We're producing queries,
01:47:14.720 | keys, and values, we're doing dot products, we're masking, applying softmax, optionally dropping
01:47:20.560 | out. And here we are aggregating the weighted values. What is different here is that in our code,
01:47:26.480 | I have separated out the multi headed attention into just a single individual head.
01:47:32.720 | And then here, I have multiple heads, and I explicitly concatenate them.
01:47:37.680 | Whereas here, all of it is implemented in a batched manner inside a single causal self attention.
01:47:43.200 | And so we don't just have a B and a T and a C dimension, we also end up with a fourth dimension,
01:47:48.080 | which is the heads. And so it just gets a lot more sort of hairy, because we have four dimensional
01:47:53.360 | tensors now, but it is equivalent mathematically. So the exact same thing is
01:47:59.280 | happening as what we have; it's just a bit more efficient, because all the heads are now
01:48:03.600 | treated as a batch dimension as well.
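To make the shapes concrete, here is a minimal sketch of that head-as-batch-dimension trick; the variable names are illustrative and this is not nanoGPT's exact code:

```python
import torch
import torch.nn.functional as F

B, T, C = 4, 8, 384        # batch, time, embedding dimension
n_head = 6
head_size = C // n_head    # 64

x = torch.randn(B, T, C)
qkv_proj = torch.nn.Linear(C, 3 * C, bias=False)   # project to q, k, v in one go
q, k, v = qkv_proj(x).split(C, dim=2)

# reshape so the heads become an extra batch dimension: (B, n_head, T, head_size)
q = q.view(B, T, n_head, head_size).transpose(1, 2)
k = k.view(B, T, n_head, head_size).transpose(1, 2)
v = v.view(B, T, n_head, head_size).transpose(1, 2)

wei = (q @ k.transpose(-2, -1)) * head_size**-0.5        # (B, n_head, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))          # causal mask, as before
wei = F.softmax(wei, dim=-1)
out = wei @ v                                            # (B, n_head, T, head_size)
out = out.transpose(1, 2).contiguous().view(B, T, C)     # concatenate heads back to (B, T, C)
```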
01:48:08.720 | Then we have the multilayer perceptron; it's using the GELU nonlinearity, which is defined here, instead of ReLU. And this is done just because OpenAI
01:48:14.640 | used it, and I want to be able to load their checkpoints. The blocks of the transformer are
01:48:19.520 | identical: the communicate and compute phases, as we saw. And then the GPT module will be identical:
01:48:24.720 | we have the position encodings, token encodings, the blocks, the layer norm at the end,
01:48:30.080 | the final linear layer. And this should look all very recognizable. And there's a bit more here,
01:48:36.160 | because I'm loading checkpoints and stuff like that. I'm separating out the parameters into those
01:48:40.480 | that should be weight decayed and those that shouldn't. But the generate function should also
01:48:45.280 | be very, very similar. So a few details are different, but you should definitely be able to
01:48:49.600 | look at this file and be able to understand a lot of the pieces now. So let's now bring things back
01:48:55.040 | to ChatGPT. What would it look like if we wanted to train ChatGPT ourselves? And how does it relate
01:49:00.320 | to what we learned today? Well, to train ChatGPT, there are roughly two stages. First is the pre-
01:49:05.920 | training stage, and then the fine-tuning stage. In the pre-training stage, we are training on a
01:49:11.840 | large chunk of internet, and just trying to get a first decoder only transformer to babble text.
01:49:18.560 | So it's very, very similar to what we've done ourselves. Except we've done like a tiny little
01:49:24.080 | baby pre-training step. And so in our case, this is how you print the number of parameters. I printed
01:49:32.080 | it and it's about 10 million. So this transformer that I created here to create a little Shakespeare
01:49:37.120 | transformer was about 10 million parameters. Our data set is roughly 1 million characters,
01:49:44.880 | so roughly 1 million tokens. But you have to remember that OpenAI uses a different vocabulary.
01:49:49.440 | They're not on the character level. They use these sub-word chunks of words. And so they have a
01:49:54.960 | vocabulary of roughly 50,000 elements. And so their sequences are a bit more condensed. So our data set,
01:50:02.320 | the Shakespeare data set would be probably around 300,000 tokens in the OpenAI vocabulary, roughly.
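If you want to sanity-check that estimate yourself, one way, assuming you have the tiktoken package installed and the Shakespeare text saved as input.txt, is roughly:

```python
import tiktoken

# Count how many GPT-2 BPE tokens the ~1M-character Shakespeare file becomes.
# 'input.txt' is assumed to be the Tiny Shakespeare text used in the lecture.
enc = tiktoken.get_encoding("gpt2")
text = open('input.txt', 'r', encoding='utf-8').read()
tokens = enc.encode(text)
print(len(text), "characters ->", len(tokens), "GPT-2 tokens")
```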
01:50:08.160 | So we trained a roughly 10-million-parameter model on roughly 300,000 tokens. Now, when you go to the
01:50:14.800 | GPT-3 paper and you look at the transformers that they trained, they trained a number of
01:50:22.320 | transformers of different sizes. But the biggest transformer here has 175 billion parameters.
01:50:27.680 | So ours is again 10 million. They used this number of layers in a transformer. This is the N embed.
01:50:34.160 | This is the number of heads. And this is the head size. And then this is the batch size.
01:50:41.040 | So ours was 65. And the learning rate is similar. Now, when they trained this transformer,
01:50:47.600 | they trained on 300 billion tokens. So again, remember, ours is about 300,000. So this is
01:50:53.600 | about a million-fold increase. And this number would not even be that large by today's standards.
01:50:59.200 | It'd be going up to 1 trillion and above. So they are training a significantly larger model
01:51:07.200 | on a good chunk of the internet. And that is the pre-training stage. But otherwise,
01:51:12.400 | these hyperparameters should be fairly recognizable to you. And the architecture
01:51:15.920 | is actually nearly identical to what we implemented ourselves. But of course,
01:51:19.760 | it's a massive infrastructure challenge to train this. You're talking about typically thousands of
01:51:24.560 | GPUs having to talk to each other to train models of this size. So that's just the pre-training
01:51:30.720 | stage. Now, after you complete the pre-training stage, you don't get something that responds to
01:51:36.480 | your questions with answers and is helpful and so on. You get a document completer.
01:51:41.520 | So it babbles, but it doesn't babble Shakespeare. It babbles internet. It will create arbitrary
01:51:48.000 | news articles and documents, and it will try to complete documents because that's what it's
01:51:51.760 | trained for. It's trying to complete the sequence. So when you give it a question,
01:51:55.360 | it would just potentially just give you more questions. It would follow with more questions.
01:52:00.240 | It will do whatever it looks like some document in the training data on the
01:52:05.840 | internet would do. And so who knows, you're getting kind of like undefined behavior. It might basically
01:52:10.240 | answer your questions with other questions. It might ignore your question. It might just try
01:52:15.120 | to complete some news article. It's totally unaligned, as we say. So the second fine-tuning stage
01:52:21.360 | is to actually align it to be an assistant. And this is the second stage. And so this ChatGPT
01:52:28.240 | blog post from OpenAI talks a little bit about how this stage is achieved. Basically,
01:52:35.120 | there are roughly three steps to this stage. So what they do here is they start to collect
01:52:39.920 | training data that looks specifically like what an assistant would do. So there are documents
01:52:44.800 | that have the format where the question is on top and then an answer is below. And they have
01:52:49.200 | a large number of these, but probably not on the order of the internet. This is probably on the
01:52:53.520 | order of maybe thousands of examples. And so they then fine tune the model to basically only focus
01:53:02.480 | on documents that look like that. And so you're starting to slowly align it. So it's going to
01:53:06.880 | expect a question at the top and it's going to expect to complete the answer. And these very,
01:53:11.840 | very large models are very sample efficient during their fine tuning. So this actually somehow works.
01:53:16.960 | But that's just step one. That's just fine tuning. So then they actually have more steps where,
01:53:21.760 | okay, the second step is you let the model respond and then different raters look at the
01:53:26.800 | different responses and rank them for their preference as to which one is better than the
01:53:31.040 | other. They use that to train a reward model. So they can predict basically using a different
01:53:36.160 | network, how desirable any candidate response would be. And then once they have a reward
01:53:43.920 | model, they run PPO, which is a policy-gradient reinforcement learning optimizer, to fine-
01:53:51.360 | tune this sampling policy so that the answers that ChatGPT now generates are expected to score a
01:54:00.000 | high reward according to the reward model.
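The lecture doesn't spell these steps out in code, but just as an illustration of the idea, the usual reward-model objective in this kind of InstructGPT-style setup is a pairwise ranking loss over the human preferences; here is a minimal sketch with made-up numbers standing in for the reward network's outputs:

```python
import torch
import torch.nn.functional as F

# Sketch of a pairwise ranking loss for a reward model (InstructGPT-style),
# not something we implemented in this lecture. The tensors below are
# placeholders for the scalar scores a reward network would output.
r_preferred = torch.tensor([1.3, 0.2, 0.9])   # rewards for responses the raters preferred
r_rejected  = torch.tensor([0.4, 0.5, -0.1])  # rewards for responses they ranked lower

# the loss pushes each preferred response to score higher than its rejected counterpart
loss = -F.logsigmoid(r_preferred - r_rejected).mean()
print(loss.item())
```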
01:54:06.240 | And so basically there's a whole aligning, or fine-tuning, stage here with multiple steps in between, and it takes the model from
01:54:11.840 | being a document completer to a question answerer. And that's like a whole separate stage. A lot of
01:54:18.560 | this data is not available publicly. It is internal to OpenAI and it's much harder to replicate this
01:54:24.640 | stage. And so that's roughly what would give you a ChatGPT. And nanoGPT focuses on the pre-training
01:54:31.680 | stage. Okay. And that's everything that I wanted to cover today. So, to summarize, we trained a
01:54:38.000 | decoder-only transformer following this famous paper, "Attention Is All You Need," from 2017.
01:54:43.840 | And so that's basically a GPT. We trained it on Tiny Shakespeare and got sensible results.
01:54:52.640 | All of the training code is roughly 200 lines of code. I will be releasing this code base.
01:55:00.320 | So also it comes with all the Git log commits along the way as we built it up.
01:55:05.200 | In addition to this code, I'm going to release the notebook, of course, the Google collab.
01:55:11.520 | And I hope that gave you a sense for how you can train these models, like say GPT-3, that will be
01:55:19.120 | architecturally basically identical to what we have, but they are somewhere between 10,000 and
01:55:23.280 | 1 million times bigger, depending on how you count. And so that's all I have for now. We did
01:55:30.400 | not talk about any of the fine tuning stages that would typically go on top of this. So if you're
01:55:34.880 | interested in something that's not just language modeling, but you actually want to, you know,
01:55:38.320 | say perform tasks or you want them to be aligned in a specific way, or you want to detect sentiment
01:55:45.280 | or anything like that, basically anytime you don't want something that's just a document completer,
01:55:49.440 | you have to perform further stages of fine-tuning, which we did not cover. And that could
01:55:54.560 | be simple supervised fine-tuning, or it can be something more fancy, like we see in ChatGPT,
01:55:58.880 | where we actually train a reward model and then do rounds of PPO to align it with respect to
01:56:03.840 | the reward model. So there's a lot more that can be done on top of it. I think for now we're
01:56:08.240 | starting to get to about the two-hour mark. So I'm going to kind of finish here. I hope you enjoyed
01:56:14.880 | the lecture and yeah go forth and transform. See you later.