Let's build GPT: from scratch, in code, spelled out.
Chapters
0:00 intro: ChatGPT, Transformers, nanoGPT, Shakespeare
7:52 reading and exploring the data
9:28 tokenization, train/val split
14:27 data loader: batches of chunks of data
22:11 simplest baseline: bigram language model, loss, generation
34:53 training the bigram model
38:00 port our code to a script
42:13 version 1: averaging past context with for loops, the weakest form of aggregation
47:11 the trick in self-attention: matrix multiply as weighted aggregation
51:54 version 2: using matrix multiply
54:42 version 3: adding softmax
58:26 minor code cleanup
60:18 positional encoding
62:00 THE CRUX OF THE VIDEO: version 4: self-attention
71:38 note 1: attention as communication
72:46 note 2: attention has no notion of space, operates over sets
73:40 note 3: there is no communication across batch dimension
74:14 note 4: encoder blocks vs. decoder blocks
75:39 note 5: attention vs. self-attention vs. cross-attention
76:56 note 6: "scaled" self-attention. why divide by sqrt(head_size)
79:11 inserting a single self-attention block to our network
81:59 multi-headed self-attention
84:25 feedforward layers of transformer block
86:48 residual connections
92:51 layernorm (and its relationship to our previous batchnorm)
97:49 scaling up the model! creating a few variables. adding dropout
102:39 encoder vs. decoder vs. both (?) Transformers
106:22 super quick walkthrough of nanoGPT, batched multi-headed self-attention
108:53 back to ChatGPT, GPT-3, pretraining vs. finetuning, RLHF
114:32 conclusions
00:00:00.000 |
Hi everyone. So by now you have probably heard of ChatGPT. It has taken the world and the AI 00:00:05.280 |
community by storm and it is a system that allows you to interact with an AI and give it text-based 00:00:11.840 |
tasks. So for example, we can ask ChatGPT to write us a small haiku about how important it is that 00:00:16.960 |
people understand AI and then they can use it to improve the world and make it more prosperous. 00:00:20.800 |
So when we run this, AI knowledge brings prosperity for all to see, embrace its power. 00:00:28.080 |
Okay, not bad. And so you could see that ChatGPT went from left to right and generated all these 00:00:33.200 |
words sequentially. Now I asked it already the exact same prompt a little bit earlier and it 00:00:39.760 |
generated a slightly different outcome. AI's power to grow, ignorance holds us back, learn, prosperity 00:00:46.080 |
waits. So pretty good in both cases and slightly different. So you can see that ChatGPT is a 00:00:51.440 |
probabilistic system and for any one prompt it can give us multiple answers sort of replying to it. 00:00:58.160 |
Now this is just one example of a prompt. People have come up with many, many examples and there 00:01:02.640 |
are entire websites that index interactions with ChatGPT and so many of them are quite humorous. 00:01:09.360 |
Explain HTML to me like I'm a dog, write release notes for Chess 2, write a note about Elon Musk 00:01:15.760 |
buying Twitter and so on. So as an example, please write a breaking news article about a 00:01:22.080 |
leaf falling from a tree. In a shocking turn of events, a leaf has fallen from a tree in the local 00:01:27.920 |
park. Witnesses report that the leaf, which was previously attached to a branch of a tree, detached 00:01:32.800 |
itself and fell to the ground. Very dramatic. So you can see that this is a pretty remarkable system 00:01:38.560 |
and it is what we call a language model because it models the sequence of words or characters or 00:01:46.320 |
tokens more generally and it knows how certain words follow each other in English language. 00:01:52.000 |
And so from its perspective, what it is doing is it is completing the sequence. So I give it the 00:01:57.840 |
start of a sequence and it completes the sequence with the outcome. And so it's a language model 00:02:03.280 |
in that sense. Now I would like to focus on the under the hood components of what makes 00:02:10.080 |
ChatGPT work. So what is the neural network under the hood that models the sequence of these words? 00:02:16.400 |
And that comes from the 2017 paper "Attention is All You Need", 00:02:22.880 |
a landmark paper in AI that proposed the transformer architecture. 00:02:27.600 |
So GPT is short for Generative Pre-trained Transformer. So transformer is the neural 00:02:35.440 |
net that actually does all the heavy lifting under the hood. It comes from this paper in 2017. 00:02:41.040 |
Now if you read this paper, this reads like a pretty random machine translation paper and that's 00:02:46.000 |
because I think the authors didn't fully anticipate the impact that the transformer would have on the 00:02:49.920 |
field. And this architecture that they produced in the context of machine translation in their case 00:02:55.600 |
actually ended up taking over the rest of AI in the next five years after. And so this architecture 00:03:02.320 |
with minor changes was copy pasted into a huge amount of applications in AI in more recent years. 00:03:08.960 |
And that includes the transformer at the core of ChatGPT. What I'd like to do now is 00:03:15.280 |
build out something like ChatGPT, but of course we're not going to be able to reproduce 00:03:20.560 |
ChatGPT. This is a very serious production-grade system. It is trained on a good chunk of the internet, 00:03:27.760 |
and then there are a lot of pre-training and fine-tuning stages to it. And so it's very complicated. 00:03:32.960 |
What I'd like to focus on is just to train a transformer based language model. And in our case 00:03:39.200 |
it's going to be a character-level language model. I still think that is very educational with 00:03:44.480 |
respect to how these systems work. So I don't want to train on a chunk of the internet. We need a smaller 00:03:49.520 |
data set. In this case, I propose that we work with my favorite toy data set. It's called Tiny 00:03:54.800 |
Shakespeare. And what it is is basically it's a concatenation of all of the works of Shakespeare 00:04:00.240 |
in my understanding. And so this is all of Shakespeare in a single file. This file is about 00:04:05.760 |
one megabyte and it's just all of Shakespeare. And what we are going to do now is we're going 00:04:11.440 |
to basically model how these characters follow each other. So for example, given a chunk of 00:04:16.560 |
these characters like this, given some context of characters in the past, the transformer neural 00:04:23.440 |
network will look at the characters that I've highlighted and is going to predict that G 00:04:27.680 |
is likely to come next in the sequence. And it's going to do that because we're going to train that 00:04:32.000 |
transformer on Shakespeare. And it's just going to try to produce character sequences that look 00:04:37.840 |
like this. And in that process, it's going to model all the patterns inside this data. 00:04:42.160 |
So once we've trained the system, I'd just like to give you a preview. 00:04:46.160 |
We can generate infinite Shakespeare. And of course, it's a fake thing that looks kind of like 00:04:52.640 |
Shakespeare. Apologies, there's some jank here that I'm not able to resolve, but 00:05:03.760 |
you can see how this is going character by character. And it's kind of like predicting 00:05:07.600 |
Shakespeare-like language. So "Verily, my lord, the sites have left thee again, the king, 00:05:14.960 |
coming with my curses with precious pale." And then "Tranio says something else," et cetera. 00:05:20.880 |
And this is just coming out of the transformer in a very similar manner as it would come out 00:05:25.200 |
in ChatGPT. In our case, character by character; in ChatGPT, it's coming out on the token-by-token 00:05:32.240 |
level. And tokens are these sort of like little sub-word pieces. So they're not word level. They're 00:05:36.880 |
kind of like word chunk level. And now I've already written this entire code to train these 00:05:45.040 |
transformers. And it is in a GitHub repository that you can find, and it's called nanoGPT. 00:05:51.600 |
So nanoGPT is a repository that you can find on my GitHub. And it's a repository for training 00:05:57.680 |
transformers on any given text. And what I think is interesting about it, because there's many ways 00:06:03.440 |
to train transformers, but this is a very simple implementation. So it's just two files of 300 00:06:08.560 |
lines of code each. One file defines the GPT model, the transformer, and one file trains it 00:06:14.400 |
on some given text dataset. And here I'm showing that if you train it on the OpenWebText dataset, 00:06:19.600 |
which is a fairly large dataset of web pages, then I reproduce the performance of GPT-2. 00:06:26.320 |
So GPT-2 is an early version of OpenAI's GPT, from 2019. And I've only so far 00:06:34.400 |
reproduced the smallest 124 million parameter model. But basically, this is just proving that 00:06:39.280 |
the code base is correctly arranged, and I'm able to load the neural network weights that OpenAI has 00:06:45.440 |
released. So you can take a look at the finished code here in nanoGPT. But what I would 00:06:51.200 |
like to do in this lecture is I would like to basically write this repository from scratch. 00:06:57.040 |
So we're going to begin with an empty file, and we're going to define a transformer piece by piece. 00:07:02.160 |
We're going to train it on the tiny Shakespeare dataset, and we'll see how we can then generate 00:07:08.640 |
infinite Shakespeare. And of course, this can copy paste to any arbitrary text dataset that you like. 00:07:14.000 |
But my goal really here is to just make you understand and appreciate how under the hood 00:07:19.600 |
ChatGPT works. And really, all that's required is proficiency in Python and some basic 00:07:27.040 |
understanding of calculus and statistics. And it would help if you also see my previous videos 00:07:32.960 |
on the same YouTube channel, in particular, my Make More series, where I define smaller and 00:07:40.400 |
simpler neural network language models. So multilayered perceptrons and so on. It really 00:07:45.520 |
introduces the language modeling framework. And then here in this video, we're going to focus on 00:07:50.000 |
the transformer neural network itself. Okay, so I created a new Google Colab Jupyter notebook here. 00:07:56.240 |
And this will allow me to later easily share this code that we're going to develop together 00:08:00.960 |
with you so you can follow along. So this will be in the video description later. Now, here I've 00:08:06.800 |
just done some preliminaries. I downloaded the dataset, the tiny Shakespeare dataset at this URL, 00:08:11.680 |
and you can see that it's about a one megabyte file. Then here I opened the input.txt file and 00:08:17.120 |
just read in all the text as a string. And we see that we are working with one million characters 00:08:21.920 |
roughly. And the first 1000 characters, if we just print them out, are basically what you would 00:08:26.960 |
expect. This is the first 1000 characters of the tiny Shakespeare dataset, roughly up to here. 00:08:32.400 |
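For reference, a minimal sketch of those preliminaries, assuming the tiny Shakespeare file has already been downloaded and saved as input.txt (the filename here is an assumption):

```python
# Read in the tiny Shakespeare text and take a first look at it.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(len(text))    # roughly one million characters
print(text[:1000])  # the first 1000 characters
```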
So, so far, so good. Next, we're going to take this text. And the text is a sequence of characters 00:08:39.200 |
in Python. So when I call the set constructor on it, I'm just going to get the set of all the 00:08:45.360 |
characters that occur in this text. And then I call list on that to create a list of those 00:08:51.600 |
characters instead of just a set so that I have an ordering, an arbitrary ordering. And then I sort 00:08:57.040 |
that. So basically, we get just all the characters that occur in the entire dataset, and they're 00:09:01.600 |
sorted. Now, the number of them is going to be our vocabulary size. These are the possible elements 00:09:07.280 |
of our sequences. And we see that when I print here the characters, there's 65 of them in total. 00:09:14.160 |
There's a space character, and then all kinds of special characters, and then capitals and 00:09:19.760 |
lowercase letters. So that's our vocabulary. And that's the sort of like possible characters that 00:09:25.440 |
the model can see or emit. Okay, so next, we would like to develop some strategy to tokenize 00:09:32.080 |
the input text. Now, when people say tokenize, they mean convert the raw text as a string 00:09:38.640 |
to some sequence of integers according to some vocabulary of possible elements. 00:09:44.640 |
So as an example, here, we are going to be building a character-level language model. 00:09:49.360 |
So we're simply going to be translating individual characters into integers. 00:09:52.640 |
So let me show you a chunk of code that sort of does that for us. 00:09:56.400 |
So we're building both the encoder and the decoder. And let me just talk through what's 00:10:01.040 |
happening here. When we encode an arbitrary text, like "Hi there," we're going to receive 00:10:07.840 |
a list of integers that represents that string. So for example, 46, 47, etc. And then we also 00:10:15.200 |
have the reverse mapping. So we can take this list and decode it to get back the exact same string. 00:10:21.600 |
So it's really just like a translation to integers and back for arbitrary string. And for us, 00:10:27.440 |
it is done on a character level. Now, the way this was achieved is we just iterate over all 00:10:32.640 |
the characters here and create a lookup table from the character to the integer and vice versa. 00:10:37.920 |
And then to encode some string, we simply translate all the characters individually. 00:10:42.320 |
And to decode it back, we use the reverse mapping and concatenate all of it. 00:10:46.720 |
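A minimal sketch of that character-level tokenizer (variable names follow the usual nanoGPT-style conventions; the exact notebook code may differ slightly):

```python
chars = sorted(list(set(text)))   # all unique characters in the text, sorted
vocab_size = len(chars)           # 65 for tiny Shakespeare

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer
itos = {i: ch for i, ch in enumerate(chars)}  # integer -> character

encode = lambda s: [stoi[c] for c in s]           # string -> list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # list of integers -> string

print(encode("hi there"))          # e.g. starts with 46, 47, ...
print(decode(encode("hi there")))  # round-trips back to "hi there"
```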
Now, this is only one of many possible encodings or many possible sort of tokenizers. And it's 00:10:52.720 |
a very simple one. But there's many other schemas that people have come up with in practice. 00:10:57.360 |
So for example, Google uses SentencePiece. SentencePiece will also encode text into 00:11:03.280 |
integers, but in a different schema and using a different vocabulary. And SentencePiece is a 00:11:10.720 |
sub-word sort of tokenizer. And what that means is that you're not encoding entire words, but 00:11:16.880 |
you're also not encoding individual characters. It's a sub-word unit level. And that's usually 00:11:22.800 |
what's adopted in practice. For example, OpenAI has this library called tiktoken that 00:11:27.600 |
uses a byte pair encoding tokenizer. And that's what GPT uses. And you can encode whole words, 00:11:34.960 |
like "hello world", into a list of integers. So as an example, I'm using the tiktoken library 00:11:41.120 |
here. I'm getting the encoding that was used for GPT-2. Instead of just having 65 00:11:47.600 |
possible characters or tokens, they have 50,000 tokens. And so when they encode the exact same 00:11:54.720 |
string, hi there, we only get a list of three integers. But those integers are not between 0 00:12:00.480 |
and 64. They are between 0 and 50,256. So basically, you can trade off the codebook size 00:12:10.320 |
and the sequence lengths. So you can have very long sequences of integers with very small 00:12:15.040 |
vocabularies, or you can have short sequences of integers with very large vocabularies. And so 00:12:23.520 |
typically people use in practice these sub-word encodings, but I'd like to keep our tokenizer 00:12:29.600 |
very simple. So we're using a character-level tokenizer. And that means that we have very 00:12:33.760 |
small codebooks. We have very simple encode and decode functions, but we do get very long 00:12:40.160 |
sequences as a result. But that's the level at which we're going to stick with this lecture, 00:12:44.160 |
because it's the simplest thing. Okay, so now that we have an encoder and a decoder, 00:12:48.400 |
effectively a tokenizer, we can tokenize the entire training set of Shakespeare. So here's a 00:12:53.680 |
chunk of code that does that. And I'm going to start to use the PyTorch library and specifically 00:12:58.000 |
the torch.tensor from the PyTorch library. So we're going to take all of the text in Tiny 00:13:03.280 |
Shakespeare, encode it, and then wrap it into a torch.tensor to get the data tensor. So here's 00:13:09.360 |
what the data tensor looks like when I look at just the first 1000 characters or the 1000 elements 00:13:14.400 |
of it. So we see that we have a massive sequence of integers. And this sequence of integers here 00:13:20.080 |
is basically an identical translation of the first 1000 characters here. So I believe, for example, 00:13:26.400 |
that 0 is a newline character, and maybe 1 is a space. I'm not 100% sure. But from now on, 00:13:32.560 |
the entire dataset of text is re-represented as just a single, 00:13:37.040 |
very large sequence of integers. Let me do one more thing before we move on here. I'd like to 00:13:42.880 |
separate out our data set into a train and a validation split. So in particular, we're going 00:13:48.320 |
to take the first 90% of the data set and consider that to be the training data for the transformer. 00:13:53.920 |
And we're going to withhold the last 10% at the end of it to be the validation data. And this 00:13:59.360 |
will help us understand to what extent our model is overfitting. So we're going to basically hide 00:14:04.000 |
and keep the validation data on the side, because we don't want just a perfect memorization of this 00:14:08.880 |
exact Shakespeare. We want a neural network that sort of creates Shakespeare-like text. And so it 00:14:15.200 |
should be fairly likely for it to produce the actual, stowed away, true Shakespeare text. 00:14:22.640 |
And so we're going to use this to get a sense of the overfitting. 00:14:27.200 |
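A sketch of the encoding and the 90/10 split just described (assuming the encode function and text from above):

```python
import torch

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])   # the first 1000 elements: an integer translation of the text

n = int(0.9 * len(data))   # first 90% is the training set
train_data = data[:n]
val_data = data[n:]        # last 10% is held out for validation
```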
Okay, so now we would like to start plugging these text sequences or integer sequences into 00:14:31.920 |
the transformer so that it can train and learn those patterns. Now, the important thing to realize 00:14:37.600 |
is we're never going to actually feed entire text into a transformer all at once. That would be 00:14:42.000 |
computationally very expensive and prohibitive. So when we actually train a transformer on a lot 00:14:46.960 |
of these data sets, we only work with chunks of the data set. And when we train the transformer, 00:14:51.680 |
we basically sample random little chunks out of the training set and train on just chunks at a 00:14:56.480 |
time. And these chunks have basically some kind of a length and some maximum length. Now, the maximum 00:15:04.240 |
length typically, at least in the code I usually write, is called block size. You can find it under 00:15:10.560 |
different names like context length or something like that. Let's start with the block size of just 00:15:14.560 |
eight. And let me look at the first train data characters, the first block size plus one 00:15:19.920 |
characters. I'll explain why plus one in a second. So this is the first nine characters in the 00:15:26.000 |
sequence, in the training set. Now, what I'd like to point out is that when you sample a chunk of 00:15:31.520 |
data like this, so say these nine characters out of the training set, this actually has multiple 00:15:37.440 |
examples packed into it. And that's because all of these characters follow each other. And so what 00:15:44.560 |
this thing is going to say when we plug it into a transformer is we're going to actually 00:15:49.600 |
simultaneously train it to make prediction at every one of these positions. Now, in a chunk 00:15:56.240 |
of nine characters, there's actually eight individual examples packed in there. So there's 00:16:01.440 |
the example that when 18, in the context of 18, 47 likely comes next. In a context of 18 and 47, 00:16:10.080 |
56 comes next. In the context of 18, 47, 56, 57 can come next, and so on. So that's the eight 00:16:18.080 |
individual examples. Let me actually spell it out with code. So here's a chunk of code to illustrate. 00:16:24.160 |
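That chunk of code is roughly the following (assuming train_data from the split above):

```python
block_size = 8
x = train_data[:block_size]       # inputs: the first block_size characters
y = train_data[1:block_size+1]    # targets: the same characters, offset by one
for t in range(block_size):
    context = x[:t+1]             # everything up to and including position t
    target = y[t]                 # the character that comes next
    print(f"when input is {context} the target: {target}")
```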
X are the inputs to the transformer. It will just be the first block size characters. Y will be the 00:16:31.760 |
next block size characters. So it's offset by one. And that's because Y are the targets for each 00:16:38.960 |
position in the input. And then here I'm iterating over all the block size of eight. And the context 00:16:45.680 |
is always all the characters in X up to T and including T. And the target is always the T-th 00:16:52.720 |
character, but in the targets array Y. So let me just run this. And basically it spells out what 00:16:59.600 |
I said in words. These are the eight examples hidden in a chunk of nine characters that we 00:17:05.680 |
sampled from the training set. I want to mention one more thing. We train on all the 00:17:13.120 |
eight examples here with context between one all the way up to context of block size. And we train 00:17:19.440 |
on that not just for computational reasons because we happen to have the sequence already or something 00:17:23.280 |
like that. It's not just done for efficiency. It's also done to make the transformer network 00:17:29.440 |
be used to seeing contexts all the way from as little as one all the way to block size. 00:17:35.120 |
And we'd like the transformer to be used to seeing everything in between. And that's going to be 00:17:40.000 |
useful later during inference because while we're sampling, we can start sampling generation with as 00:17:45.440 |
little as one character of context. And the transformer knows how to predict the next 00:17:49.200 |
character with all the way up to just context of one. And so then it can predict everything up to 00:17:54.320 |
block size. And after block size, we have to start truncating because the transformer will never 00:17:59.040 |
receive more than block size inputs when it's predicting the next character. 00:18:03.520 |
Okay, so we've looked at the time dimension of the tensors that are going to be feeding into 00:18:08.560 |
the transformer. There's one more dimension to care about, and that is the batch dimension. 00:18:12.000 |
And so as we're sampling these chunks of text, we're going to be actually every time we're going 00:18:17.760 |
to feed them into a transformer, we're going to have many batches of multiple chunks of text that 00:18:22.080 |
are all stacked up in a single tensor. And that's just done for efficiency just so that we can keep 00:18:26.720 |
the GPUs busy because they are very good at parallel processing of data. And so we just 00:18:33.920 |
want to process multiple chunks all at the same time. But those chunks are processed completely 00:18:38.320 |
independently, they don't talk to each other, and so on. So let me basically just generalize this 00:18:43.200 |
and introduce a batch dimension. Here's a chunk of code. Let me just run it, and then I'm going 00:18:48.560 |
to explain what it does. So here, because we're going to start sampling random locations in the 00:18:54.640 |
data sets to pull chunks from, I am setting the seed in the random number generator so that the 00:19:01.360 |
numbers I see here are going to be the same numbers you see later if you try to reproduce this. 00:19:06.320 |
Now the batch size here is how many independent sequences we are processing every forward-backward 00:19:11.120 |
pass of the transformer. The block size, as I explained, is the maximum context length 00:19:16.640 |
to make those predictions. So let's say batch size 4, block size 8, and then here's how we get batch 00:19:22.240 |
for any arbitrary split. If the split is a training split, then we're going to look at train data, 00:19:27.600 |
otherwise at val data. That gives us the data array. And then when I generate random positions 00:19:35.280 |
to grab a chunk out of, I actually generate batch size number of random offsets. So because this is 00:19:44.000 |
4, ix is going to be 4 numbers that are randomly generated between 0 and len of data minus block 00:19:51.360 |
size. So it's just random offsets into the training set. And then x's, as I explained, 00:19:57.600 |
are the first block size characters starting at i. The y's are the offset by 1 of that, so just add 00:20:06.640 |
plus 1. And then we're going to get those chunks for every one of integers i in ix and use a torch 00:20:14.320 |
dot stack to take all those one-dimensional tensors as we saw here, and we're going to 00:20:21.760 |
stack them up as rows. And so they all become a row in a 4 by 8 tensor. 00:20:28.640 |
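A sketch of that batching code (the particular seed value is how I recall it; treat it as an assumption):

```python
torch.manual_seed(1337)   # so the sampled offsets are reproducible
batch_size = 4            # how many independent sequences per forward/backward pass
block_size = 8            # maximum context length for predictions

def get_batch(split):
    # sample a small batch of inputs x and targets y from the chosen split
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))   # random offsets
    x = torch.stack([data[i:i+block_size] for i in ix])         # (4, 8)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])     # (4, 8), shifted by one
    return x, y

xb, yb = get_batch('train')
print(xb.shape, yb.shape)   # torch.Size([4, 8]) each
```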
So here's where I'm printing them. When I sample a batch xb and yb, 00:20:33.600 |
the inputs to the transformer now are the input x is the 4 by 8 tensor, four rows of eight columns, 00:20:44.960 |
and each one of these is a chunk of the training set. And then the targets here are in the 00:20:52.160 |
associated array y, and they will come in to the transformer all the way at the end 00:20:55.920 |
to create the loss function. So they will give us the correct answer for every single position 00:21:02.960 |
inside x. And then these are the four independent rows. So spelled out as we did before, 00:21:12.320 |
this 4 by 8 array contains a total of 32 examples, and they're completely independent 00:21:18.640 |
as far as the transformer is concerned. So when the input is 24, the target is 43, 00:21:26.560 |
or rather 43 here in the y array. When the input is 24, 43, the target is 58. 00:21:31.840 |
When the input is 24, 43, 58, the target is 5, etc. Or like when it is a 52, 58, 1, the target is 58. 00:21:41.760 |
Right, so you can sort of see this spelled out. These are the 32 independent examples 00:21:46.560 |
packed in to a single batch of the input x, and then the desired targets are in y. 00:21:52.720 |
And so now this integer tensor of x is going to feed into the transformer, 00:22:00.560 |
and that transformer is going to simultaneously process all these examples, and then look up the 00:22:05.840 |
correct integers to predict in every one of these positions in the tensor y. Okay, so now that we 00:22:12.560 |
have our batch of input that we'd like to feed into a transformer, let's start basically feeding 00:22:16.960 |
this into neural networks. Now we're going to start off with the simplest possible neural network, 00:22:21.760 |
which in the case of language modeling, in my opinion, is the bigram language model. 00:22:25.280 |
And we've covered the bigram language model in my Make More series in a lot of depth. And so here 00:22:30.320 |
I'm going to sort of go faster, and let's just implement the PyTorch module directly that 00:22:34.960 |
implements the bigram language model. So I'm importing the PyTorch nn module and setting a manual seed for reproducibility, 00:22:43.120 |
and then here I'm constructing a bigram language model, which is a subclass of nn.Module. 00:22:47.680 |
And then I'm calling it, and I'm passing in the inputs and the targets, and I'm just printing. 00:22:54.480 |
Now when the inputs and targets come here, you see that I'm just taking the index, 00:22:59.520 |
the inputs x here, which I renamed to idx, and I'm just passing them into this token embedding table. 00:23:05.120 |
So what's going on here is that here in the constructor, we are creating a token embedding 00:23:10.880 |
table, and it is of size vocab size by vocab size. And we're using an nn.Embedding, which is a 00:23:18.480 |
very thin wrapper around basically a tensor of shape vocab size by vocab size. And what's happening 00:23:24.640 |
here is that when we pass idx here, every single integer in our input is going to refer to this 00:23:30.640 |
embedding table, and is going to pluck out a row of that embedding table corresponding to its index. 00:23:36.000 |
So 24 here will go to the embedding table, and will pluck out the 24th row. And then 43 will go 00:23:43.040 |
here and pluck out the 43rd row, etc. And then PyTorch is going to arrange all of this into a 00:23:48.880 |
batch by time by channel tensor. In this case, batch is 4, time is 8, and c, which is the channels, 00:23:58.480 |
is vocab size or 65. And so we're just going to pluck out all those rows, arrange them in a b by 00:24:04.320 |
t by c, and now we're going to interpret this as the logits, which are basically the scores 00:24:09.520 |
for the next character in the sequence. And so what's happening here is we are predicting what 00:24:15.120 |
comes next based on just the individual identity of a single token. And you can do that because, 00:24:21.200 |
I mean, currently the tokens are not talking to each other, and they're not seeing any context, 00:24:26.080 |
except for they're just seeing themselves. So I'm a token number 5, and then I can actually 00:24:32.320 |
make pretty decent predictions about what comes next just by knowing that I'm token 5, 00:24:36.560 |
because some characters follow other characters in typical scenarios. So we saw a lot of this 00:24:43.680 |
in a lot more depth in the MakeMore series. And here, if I just run this, then we currently get 00:24:49.120 |
the predictions, the scores, the logits for every one of the 4 by 8 positions. Now that we've made 00:24:56.000 |
predictions about what comes next, we'd like to evaluate the loss function. And so in MakeMore 00:25:00.320 |
series, we saw that a good way to measure a loss or a quality of the predictions is to use the 00:25:06.000 |
negative log likelihood loss, which is also implemented in PyTorch under the name cross 00:25:10.320 |
entropy. So what we'd like to do here is loss is the cross entropy on the predictions and the 00:25:17.680 |
targets. And so this measures the quality of the logits with respect to the targets. In other words, 00:25:23.600 |
we have the identity of the next character, so how well are we predicting the next character 00:25:28.560 |
based on the logits? And intuitively, the correct dimension of logits, depending on whatever the 00:25:36.960 |
target is, should have a very high number, and all the other dimensions should be a very low number. 00:25:41.040 |
Now, the issue is this. We want to basically output 00:25:47.280 |
the logits and the loss, but unfortunately, this won't actually run. 00:25:55.040 |
We get an error message. But intuitively, we want to measure this. Now, when we go to the PyTorch 00:26:03.040 |
cross entropy documentation here, we're trying to call the cross entropy in its functional form. 00:26:10.800 |
So that means we don't have to create a module for it. But here, when we go to the documentation, 00:26:16.000 |
you have to look into the details of how PyTorch expects these inputs. And basically, 00:26:20.720 |
the issue here is PyTorch expects, if you have multidimensional input, which we do because we 00:26:26.160 |
have a b by t by c tensor, then it actually really wants the channels to be the second dimension 00:26:33.440 |
here. So basically, it wants a b by c by t instead of a b by t by c. And so it's just the details of 00:26:43.680 |
how PyTorch treats these kinds of inputs. And so we don't actually want to deal with that. So what 00:26:51.200 |
we're going to do instead is we need to basically reshape our logits. So here's what I like to do. 00:26:55.280 |
I like to basically give names to the dimensions. So logits.shape is b by t by c and unpack those 00:27:01.440 |
numbers. And then let's say that logits equals logits.view. And we want it to be a b times t 00:27:09.920 |
by c, so just a two-dimensional array. So we're going to take all of these positions here, and 00:27:19.760 |
we're going to stretch them out in a one-dimensional sequence and preserve the channel dimension as 00:27:25.520 |
the second dimension. So we're just kind of like stretching out the array so it's two-dimensional. 00:27:30.640 |
And in that case, it's going to better conform to what PyTorch sort of expects in its dimensions. 00:27:35.440 |
Now, we have to do the same to targets because currently targets are of shape b by t, 00:27:44.560 |
and we want it to be just b times t, so one-dimensional. Now, alternatively, you could 00:27:49.760 |
always still just do minus one because PyTorch will guess what this should be if you want to 00:27:54.320 |
lay it out. But let me just be explicit and say b times t. Once we reshape this, it will match 00:28:00.400 |
the cross-entropy case, and then we should be able to evaluate our loss. 00:28:04.560 |
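To make the shape issue concrete, here is a tiny standalone illustration of the reshape, not the notebook's code, just the same idea with random logits:

```python
import torch
import torch.nn.functional as F

B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)            # what the model produces: (B, T, C)
targets = torch.randint(0, C, (B, T))    # ground-truth next characters: (B, T)

# F.cross_entropy wants (N, C) logits and (N,) targets (or channels as the
# second dimension), so stretch batch and time out into one dimension.
loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
print(loss)   # a scalar loss
```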
Okay, so with that in place, we can now print the loss. And currently we see that the loss is 4.87. 00:28:14.480 |
Now, because we have 65 possible vocabulary elements, we can actually guess at what the loss 00:28:20.480 |
should be. And in particular, we covered negative log-likelihood in a lot of detail. We are expecting 00:28:26.640 |
log or ln of 1/65 and negative of that. So we're expecting the loss to be about 4.17, 00:28:37.280 |
but we're getting 4.87. And so that's telling us that the initial predictions are not super diffuse. 00:28:42.880 |
They've got a little bit of entropy, and so we're guessing wrong. 00:28:45.680 |
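As a quick sanity check on that number:

```python
import math
print(-math.log(1/65))   # ~4.17: the loss of a perfectly uniform prediction over 65 characters
```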
So yes, but actually we are able to evaluate the loss. Okay, so now that we can evaluate the 00:28:54.000 |
quality of the model on some data, we'd like to also be able to generate from the model. 00:28:59.120 |
So let's do the generation. Now, I'm going to go again a little bit faster here because I covered 00:29:03.680 |
all this already in previous videos. So here's a generate function for the model. 00:29:12.240 |
So we take the same kind of input, idx here, and basically this is the current 00:29:20.400 |
context of some characters in some batch. So it's also b by t, and the job of generate 00:29:29.040 |
is to basically take this b by t and extend it to be b by t plus 1, plus 2, plus 3. And so it's 00:29:33.920 |
just basically it continues the generation in all the batch dimensions in the time dimension. 00:29:39.120 |
So that's its job, and it will do that for max new tokens. So you can see here on the bottom, 00:29:44.240 |
there's going to be some stuff here, but on the bottom, whatever is predicted is concatenated on 00:29:49.760 |
top of the previous idx along the first dimension, which is the time dimension, to create a b by t 00:29:55.120 |
plus 1. So that becomes a new idx. So the job of generate is to take a b by t and make it a b by t 00:30:01.280 |
plus 1, plus 2, plus 3, as many as we want max new tokens. So this is the generation from the model. 00:30:08.160 |
Now inside the generation, what are we doing? We're taking the current indices, 00:30:12.000 |
we're getting the predictions. So we get those are in the logits, and then the loss here is 00:30:19.200 |
going to be ignored because we're not using that, and we have no targets that are sort of ground 00:30:24.400 |
truth targets that we're going to be comparing with. Then once we get the logits, we are only 00:30:30.400 |
focusing on the last step. So instead of a b by t by c, we're going to pluck out the negative one, 00:30:37.520 |
the last element in the time dimension, because those are the predictions for what comes next. 00:30:41.840 |
So that gives us the logits, which we then convert to probabilities via softmax. And then we use 00:30:47.840 |
torch.multinomial to sample from those probabilities, and we ask PyTorch to give us 00:30:52.080 |
one sample. And so idx next will become a b by 1, because in each one of the batch dimensions, 00:31:00.080 |
we're going to have a single prediction for what comes next. So this num_samples equals 1 00:31:04.560 |
will make this a B by 1. And then we're going to take those integers that come from the sampling 00:31:10.400 |
process according to the probability distribution given here, and those integers get just concatenated 00:31:15.360 |
on top of the current sort of like running stream of integers. And this gives us a b by t plus 1. 00:31:20.720 |
And then we can return that. Now one thing here is you see how I'm calling self of idx, which will 00:31:28.720 |
end up going to the forward function. I'm not providing any targets, so currently this would 00:31:33.520 |
give an error because targets is sort of like not given. So targets has to be optional. So targets 00:31:40.480 |
is none by default. And then if targets is none, then there's no loss to create. So it's just loss 00:31:47.680 |
is none. But else all of this happens and we can create a loss. So this will make it so if we have 00:31:55.840 |
the targets, we provide them and get a loss. If we have no targets, we'll just get the logits. 00:32:01.360 |
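Putting the pieces together, a consolidated sketch of the bigram model at this point, with targets made optional and the generate method added (a sketch of what's on screen, not a verbatim copy):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)   # (B, T, C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            logits, loss = self(idx)            # get the predictions
            logits = logits[:, -1, :]           # focus only on the last time step -> (B, C)
            probs = F.softmax(logits, dim=-1)   # convert to probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1) sample
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(loss)   # about 4.87 at random initialization
```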
So this here will generate from the model. And let's take that for a ride now. 00:32:07.200 |
Oops. So I have another code chunk here, which will generate 00:32:13.280 |
from the model. And OK, this is kind of crazy. So maybe let 00:32:16.720 |
me break this down. So this is the idx, right? 00:32:20.960 |
I'm creating it with a batch of just one, and time of just one. 00:32:29.600 |
So I'm creating a little one-by-one tensor, and it's holding a zero, and the dtype, the data type, 00:32:35.520 |
is integer. So zero is going to be how we kick off the generation. And remember that zero is 00:32:41.920 |
the element standing for a new line character. So it's kind of like a reasonable thing to 00:32:47.520 |
feed in as the very first character in a sequence to be the new line. So it's going to be idx, 00:32:54.080 |
which we're going to feed in here. Then we're going to ask for 100 tokens 00:32:58.640 |
and then generate will continue that. Now, because generate works on the level of batches, 00:33:06.160 |
we then have to index into the zeroth row to basically unplug the single batch dimension 00:33:13.760 |
that exists. And then that gives us the time steps, just a one-dimensional array of all the indices, 00:33:21.600 |
which we will convert from a PyTorch tensor to a simple Python list so that it can feed into 00:33:28.560 |
our decode function and convert those integers into text. So let me bring this back and we're 00:33:35.680 |
generating a hundred tokens. Let's run. And here's the generation that we achieved. 00:33:41.200 |
So obviously it's garbage. And the reason it's garbage is because this is a totally random model. 00:33:46.400 |
So next up, we're going to want to train this model. Now, one more thing I wanted to point 00:33:50.320 |
out here is this function is written to be general, but it's kind of like ridiculous right now because 00:33:58.000 |
we're feeding in all this, we're building out this context and we're concatenating it all. 00:34:02.800 |
And we're always feeding it all into the model. But that's kind of ridiculous because this is 00:34:08.560 |
just a simple bigram model. So to make, for example, this prediction about K, 00:34:11.920 |
we only needed this W, but actually what we fed into the model is we fed the entire sequence. 00:34:17.520 |
And then we only looked at the very last piece and predicted K. So the only reason I'm writing 00:34:23.440 |
it in this way is because right now this is a bigram model, but I'd like to keep this function 00:34:28.320 |
fixed. And I'd like it to work later when our characters actually basically look further in 00:34:36.000 |
the history. And so right now the history is not used. So this looks silly, but eventually the 00:34:41.120 |
history will be used. And so that's why we want to do it this way. So just a quick comment on that. 00:34:46.720 |
So now we see that this is random. So let's train the model. So it becomes a bit less random. 00:34:53.280 |
Okay, let's now train the model. So first what I'm going to do is I'm going to create a PyTorch 00:34:57.760 |
optimization object. So here we are using the optimizer AdamW. Now in the Makemore series, 00:35:05.360 |
we've only ever used stochastic gradient descent, the simplest possible optimizer, which you can get 00:35:09.520 |
using the SGD instead. But I want to use Adam, which is a much more advanced and popular optimizer 00:35:14.880 |
and it works extremely well. A typical good setting for the learning rate is roughly 3e-4. 00:35:21.760 |
But for very, very small networks, like is the case here, you can get away with much, 00:35:25.520 |
much higher learning rates, 1e-3 or even higher probably. But let me create the optimizer object, 00:35:31.280 |
which will basically take the gradients and update the parameters using the gradients. 00:35:36.400 |
And then here, our batch size up above was only 4. So let me actually use something bigger, 00:35:42.720 |
let's say 32. And then for some number of steps, we are sampling a new batch of data, 00:35:48.480 |
we're evaluating the loss, we're zeroing out all the gradients from the previous step, 00:35:53.200 |
getting the gradients for all the parameters, and then using those gradients to update our 00:35:57.920 |
parameters. So typical training loop, as we saw in the Makemore series. So let me now run this 00:36:04.400 |
for say 100 iterations and let's see what kind of losses we're going to get. 00:36:08.400 |
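A sketch of that optimizer and training loop (the learning rate and step counts follow what's said above; treat the exact values as approximate):

```python
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)   # a higher LR is fine for this tiny model

batch_size = 32
for steps in range(100):            # increase this (e.g. to 10000) to train for longer
    xb, yb = get_batch('train')     # sample a new batch of data

    logits, loss = m(xb, yb)        # evaluate the loss
    optimizer.zero_grad(set_to_none=True)   # zero out gradients from the previous step
    loss.backward()                 # compute gradients for all parameters
    optimizer.step()                # use the gradients to update the parameters

print(loss.item())
```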
So we started around 4.7 and now we're getting down to like 4.6, 4.5, etc. So the optimization 00:36:18.320 |
is definitely happening, but let's sort of try to increase the number of iterations and only 00:36:24.800 |
print at the end, because we probably will not train for longer. Okay, so we're down to 3.6, 00:36:31.840 |
roughly. Roughly down to 3. This is the most janky optimization. 00:36:49.040 |
And then from here, we want to copy this. And hopefully, we're going to get something 00:36:57.280 |
reasonable. And of course, it's not going to be Shakespeare from a bigram model, but at least 00:37:01.280 |
we see that the loss is improving. And hopefully, we're expecting something a bit more reasonable. 00:37:06.480 |
Okay, so we're down at about 2.5-ish. Let's see what we get. Okay, dramatic improvements, 00:37:13.600 |
certainly on what we had here. So let me just increase the number of tokens. 00:37:18.240 |
Okay, so we see that we're starting to get something at least like reasonable-ish. 00:37:23.680 |
Certainly not Shakespeare, but the model is making progress. 00:37:30.640 |
So that is the simplest possible model. So now what I'd like to do is... 00:37:35.440 |
Obviously, this is a very simple model because the tokens are not talking to each other. 00:37:41.600 |
So given the previous context of whatever was generated, we're only looking at the very last 00:37:46.160 |
character to make the predictions about what comes next. So now these tokens have to start 00:37:51.280 |
talking to each other and figuring out what is in the context so that they can make better 00:37:56.000 |
predictions for what comes next. And this is how we're going to kick off the transformer. 00:38:00.400 |
Okay, so next, I took the code that we developed in this Jupyter notebook, 00:38:03.360 |
and I converted it to be a script. And I'm doing this because I just want to simplify 00:38:08.640 |
our intermediate work into just the final product that we have at this point. 00:38:11.760 |
So in the top here, I put all the hyperparameters that we've defined. 00:38:16.640 |
I introduced a few, and I'm going to speak to that in a little bit. 00:38:19.360 |
Otherwise, a lot of this should be recognizable. Reproducibility, read data, 00:38:25.200 |
get the encoder and the decoder, create the train and test splits, use the data loader 00:38:31.440 |
that gets a batch of the inputs and targets. This is new, and I'll talk about it in a second. 00:38:38.800 |
Now, this is the bigram language model that we developed, and it can forward and give us 00:38:43.120 |
logits and loss, and it can generate. And then here, we are creating the optimizer, 00:38:48.720 |
and this is the training loop. So everything here should look pretty familiar. Now, 00:38:54.480 |
some of the small things that I added, number one, I added the ability to run on a GPU if you have it. 00:39:00.560 |
So if you have a GPU, then this will use CUDA instead of just the CPU, and everything will be a 00:39:05.520 |
lot faster. Now, when device becomes CUDA, then we need to make sure that when we load the 00:39:11.040 |
data, we move it to device. When we create the model, we want to move the model parameters to 00:39:17.840 |
device. So as an example, here we have the nn.Embedding table, and it's got a .weight inside 00:39:24.080 |
it, which stores the lookup table. So that would be moved to the GPU so that all the calculations 00:39:30.160 |
here happen on the GPU, and they can be a lot faster. And then finally here, when I'm creating 00:39:35.120 |
the context that feeds into generate, I have to make sure that I create it on the device. 00:39:39.200 |
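Roughly, that device handling looks like this (a sketch, assuming the earlier definitions of train_data, block_size, batch_size, the model class, and decode; the script's exact structure may differ):

```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)        # move the sampled batch to the device

model = BigramLanguageModel(vocab_size)
m = model.to(device)                         # moves the parameters, e.g. the embedding's .weight

# the context we generate from has to be created on the same device
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=100)[0].tolist()))
```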
Number two, what I introduced is the fact that here in the training loop, 00:39:45.680 |
here I was just printing the loss.item inside the training loop, but this is a very noisy 00:39:53.840 |
measurement of the current loss because every batch will be more or less lucky. 00:39:58.400 |
And so what I want to do usually is I have an estimate loss function, and the estimate loss 00:40:05.360 |
basically then goes up here, and it averages up the loss over multiple batches. So in particular, 00:40:14.560 |
we're going to iterate eval_iters times, and we're going to basically get our loss, and then we're 00:40:19.520 |
going to get the average loss for both splits. And so this will be a lot less noisy. So here, 00:40:25.280 |
when we call the estimate loss, we're going to report the pretty accurate train and validation 00:40:30.240 |
loss. Now, when we come back up, you'll notice a few things here. I'm setting the model to 00:40:35.680 |
evaluation phase, and down here I'm resetting it back to training phase. Now, right now for 00:40:41.520 |
our model as is, this doesn't actually do anything because the only thing inside this model is this 00:40:46.720 |
nn.embedding, and this network would behave the same in both evaluation mode and training mode. 00:40:56.320 |
We have no dropout layers, we have no batch norm layers, etc. But it is a good practice to think 00:41:01.040 |
through what mode your neural network is in because some layers will have different behavior 00:41:06.400 |
at inference time or training time. And there's also this context manager, torch.no_grad, 00:41:13.680 |
and this is just telling PyTorch that everything that happens inside this function, 00:41:17.280 |
we will not call .backward on. And so PyTorch can be a lot more efficient with its memory use 00:41:23.200 |
because it doesn't have to store all the intermediate variables because we're never 00:41:27.120 |
going to call backward. And so it can be a lot more memory efficient in that way. 00:41:31.680 |
So also a good practice to tell PyTorch when we don't intend to do backpropagation. 00:41:37.600 |
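A sketch of estimate_loss as described (eval_iters is a hyperparameter defined at the top of the script; the value used here is a plausible placeholder):

```python
eval_iters = 200

@torch.no_grad()                   # no backward pass inside: more memory efficient
def estimate_loss():
    out = {}
    model.eval()                   # switch to evaluation mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean() # average over many batches: much less noisy
    model.train()                  # switch back to training mode
    return out
```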
So right now, this script is about 120 lines of code, and that's kind of our starter code. 00:41:44.320 |
I'm calling it bigram.py, and I'm going to release it later. Now running this script 00:41:49.760 |
gives us output in the terminal, and it looks something like this. 00:41:53.600 |
It basically, as I ran this code, it was giving me the train loss and the val loss, 00:41:59.520 |
and we see that we converge to somewhere around 2.5 with the bigram model. And then here's the 00:42:04.960 |
sample that we produced at the end. And so we have everything packaged up in the script, 00:42:10.640 |
and we're in a good position now to iterate on this. Okay, so we are almost ready to start 00:42:15.280 |
writing our very first self-attention block for processing these tokens. Now, 00:42:21.280 |
before we actually get there, I want to get you used to a mathematical trick that is used in the 00:42:26.800 |
self-attention inside a transformer, and is really just at the heart of an efficient implementation 00:42:32.880 |
of self-attention. And so I want to work with this toy example to just get you used to this 00:42:37.360 |
operation, and then it's going to make it much more clear once we actually get to it in the 00:42:43.360 |
script again. So let's create a B by T by C tensor, where B, T, and C are just 4, 8, and 2 in this toy example. 00:42:50.240 |
And these are basically channels, and we have batches, and we have the time component, 00:42:56.640 |
and we have some information at each point in the sequence, so c. Now, what we would like to do is 00:43:03.280 |
we would like these tokens, so we have up to 8 tokens here in a batch, and these 8 tokens are 00:43:09.600 |
currently not talking to each other, and we would like them to talk to each other. We'd like to 00:43:13.040 |
couple them. And in particular, we want to couple them in a very specific way. So the token, for 00:43:20.960 |
example, at the fifth location, it should not communicate with tokens in the sixth, seventh, 00:43:26.000 |
and eighth location, because those are future tokens in the sequence. The token on the fifth 00:43:31.760 |
location should only talk to the one in the fourth, third, second, and first. So information 00:43:37.600 |
only flows from previous context to the current time step, and we cannot get any information from 00:43:42.640 |
the future, because we are about to try to predict the future. So what is the easiest way for tokens 00:43:49.600 |
to communicate? The easiest way, I would say, is if we're a fifth token and I'd like to communicate 00:43:56.640 |
with my past, the simplest way we can do that is to just do an average of all the preceding elements. 00:44:06.000 |
So for example, if I'm the fifth token, I would like to take the channels that make up, that are 00:44:12.080 |
information at my step, but then also the channels from the fourth step, third step, second step, 00:44:16.880 |
and the first step, I'd like to average those up, and then that would become sort of like a 00:44:21.200 |
feature vector that summarizes me in the context of my history. Now, of course, just doing a sum, 00:44:27.360 |
or like an average, is an extremely weak form of interaction. Like this communication is extremely 00:44:32.160 |
lossy. We've lost a ton of information about the spatial arrangements of all those tokens, 00:44:36.160 |
but that's okay for now. We'll see how we can bring that information back later. 00:44:40.080 |
For now, what we would like to do is, for every single batch element independently, 00:44:45.920 |
for every t-th token in that sequence, we'd like to now calculate the average of all the 00:44:52.880 |
vectors in all the previous tokens, and also at this token. So let's write that out. 00:44:58.320 |
I have a small snippet here, and instead of just fumbling around, 00:45:02.720 |
let me just copy paste it and talk to it. So in other words, we're going to create x, 00:45:09.760 |
and B-O-W is short for bag of words, because bag of words is kind of like a term that people use 00:45:17.360 |
when you are just averaging up things. So this is just a bag of words. Basically, there's a word 00:45:21.760 |
stored on every one of these eight locations, and we're doing a bag of words, we're just averaging. 00:45:26.080 |
So in the beginning, we're going to say that it's just initialized at zero, 00:45:30.480 |
and then I'm doing a for loop here, so we're not being efficient yet. That's coming. But for now, 00:45:34.800 |
we're just iterating over all the batch dimensions independently, iterating over time, and then 00:45:40.560 |
the previous tokens are at this batch dimension, and then everything up to and including the t-th 00:45:48.720 |
token. So when we slice out x in this way, xprev becomes of shape how many t elements there were 00:45:58.640 |
in the past, and then of course, c, so all the two-dimensional information from these little 00:46:03.600 |
tokens. So that's the previous sort of chunk of tokens from my current sequence. And then I'm 00:46:12.160 |
just doing the average, or the mean, over the zeroth dimension. So I'm averaging out the time here, 00:46:17.520 |
and I'm just going to get a little c one-dimensional vector, which I'm going to 00:46:22.080 |
store in x bag of words. So I can run this, and this is not going to be very informative, because 00:46:31.040 |
let's see, so this is x of zero, so this is the zeroth batch element, and then xbow at zero. 00:46:36.400 |
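That snippet is roughly the following (version 1, the explicit loops):

```python
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

xbow = torch.zeros((B, T, C))      # x "bag of words": running averages of the past
for b in range(B):                 # every batch element independently
    for t in range(T):             # every time step
        xprev = x[b, :t+1]         # (t+1, C): everything up to and including t
        xbow[b, t] = torch.mean(xprev, 0)   # average over the time dimension

print(x[0])
print(xbow[0])   # row t is the average of rows 0..t of x[0]
```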
Now you see how at the first location here, you see that the two are equal, and that's because 00:46:44.320 |
we're just doing an average of this one token. But here, this one is now an average of these two, 00:46:50.560 |
and now this one is an average of these three, and so on. And this last one is the average 00:47:01.120 |
of all of these elements, so vertical average, just averaging up all the tokens, now gives this 00:47:06.240 |
outcome here. So this is all well and good, but this is very inefficient. Now the trick is that 00:47:13.040 |
we can be very, very efficient about doing this using matrix multiplication. So that's the 00:47:18.000 |
mathematical trick, and let me show you what I mean. Let's work with the toy example here. 00:47:22.000 |
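A sketch of that toy example (the particular seed and random values are assumptions; the structure is what matters):

```python
torch.manual_seed(42)
a = torch.ones(3, 3)                       # 3x3 of all ones
b = torch.randint(0, 10, (3, 2)).float()   # 3x2 of random numbers
c = a @ b                                  # (3, 3) @ (3, 2) -> (3, 2)
print('a='); print(a)
print('b='); print(b)
print('c='); print(c)   # each row of c is the column-wise sum of b
```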
Let me run it, and I'll explain. I have a simple matrix here that is a three by three of all ones. 00:47:29.920 |
A matrix B of just random numbers, and it's a three by two, and a matrix C, which will be three 00:47:34.800 |
by three multiplied by three by two, which will give us a three by two. So here we're just using 00:47:40.240 |
matrix multiplication. So A multiply B gives us C. Okay, so how are these numbers in C 00:47:50.080 |
achieved, right? So this number in the top left is the first row of A dot product with the first 00:47:58.320 |
column of B. And since all the row of A right now is all just ones, then the dot product here with 00:48:05.680 |
this column of B is just going to do a sum of this column. So two plus six plus six is 14. 00:48:12.400 |
The next element here in the output of C is again the first row of A, 00:48:18.320 |
multiplied now with the second column of B. So seven plus four plus five is 16. Now you see that 00:48:25.280 |
there's repeating elements here. So this 14 again is because this row is again all ones, 00:48:29.520 |
and it's multiplying the first column of B. So we get 14. And this one is, and so on. So this last 00:48:35.360 |
number here is the last row dot product last column. Now the trick here is the following. 00:48:42.400 |
This is just a boring array of all ones. But torch has this 00:48:49.840 |
function called tril, which is short for triangular, something like that. And you can 00:48:56.880 |
apply it to torch.ones, and it will just return the lower triangular portion of it. 00:49:01.520 |
So now it will basically zero out these guys here. So we just get the lower triangular part. 00:49:11.680 |
So now we'll have A like this and B like this. And now what are we getting here in C? 00:49:20.240 |
Well, what is this number? Well, this is the first row times the first column. And because 00:49:25.680 |
this is zeros, these elements here are now ignored. So we just get a two. And then this 00:49:32.720 |
number here is the first row times the second column. And because these are zeros, they get 00:49:37.760 |
ignored, and it's just seven. The seven multiplies this one. But look what happened here. Because 00:49:43.520 |
this is one and then zeros, what ended up happening is we're just plucking out the row, 00:49:48.800 |
this row of B, and that's what we got. Now here we have one, one, zero. So here, 00:49:55.840 |
one, one, zero dot product with these two columns will now give us two plus six, which is eight, 00:50:00.880 |
and seven plus four, which is 11. And because this is one, one, one, we ended up with the addition 00:50:07.040 |
of all of them. And so basically, depending on how many ones and zeros we have here, we are 00:50:12.720 |
basically doing a sum currently of the variable number of these rows, and that gets deposited 00:50:19.520 |
into C. So currently, we're doing sums because these are ones, but we can also do average, 00:50:25.680 |
right? And you can start to see how we could do average of the rows of B sort of in an incremental 00:50:32.160 |
fashion. Because we don't have to, we can basically normalize these rows so that they sum to one, 00:50:38.400 |
and then we're going to get an average. So if we took A, and then we did A equals A divided by 00:50:43.520 |
torch.sum of A along dimension 1, with keepdim set to True. So therefore, 00:50:56.800 |
the broadcasting will work out. So if I rerun this, you see now that these rows now sum to one. 00:51:03.600 |
So this row is one; this row is 0.5, 0.5, and zero; and here we get one-thirds. And now when we do A 00:51:09.760 |
multiply B, what are we getting? Here we are just getting the first row, first row. Here now we are 00:51:16.320 |
getting the average of the first two rows. Okay, so two and six average is four, and four and seven 00:51:23.920 |
average is 5.5. And on the bottom here, we are now getting the average of these three rows. 00:51:31.360 |
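As a rough sketch, here is that toy example written out in PyTorch, using the same numbers that appear above (the names a, b, c are just illustrative):

```python
import torch

a = torch.tril(torch.ones(3, 3))         # lower triangular ones
a = a / torch.sum(a, 1, keepdim=True)    # normalize rows so each sums to one
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b                                # row i of c is the average of rows 0..i of b
print(c)
# tensor([[2.0000, 7.0000],
#         [4.0000, 5.5000],
#         [4.6667, 5.3333]])
```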
So the average of all the elements of B is now deposited here. And so you can see that by 00:51:37.600 |
manipulating these elements of this multiplying matrix, and then multiplying it with any given 00:51:44.400 |
matrix, we can do these averages in this incremental fashion, 00:51:49.680 |
and we can manipulate how much gets mixed in based on the elements of A. Okay, so that's very convenient. So let's 00:51:56.640 |
swing back up here and see how we can vectorize this and make it much more efficient using what 00:52:00.880 |
we've learned. So in particular, we are going to produce an array A, but here I'm going to call it 00:52:07.840 |
"wei," short for "weights." But this is our A, and this is how much of every row we want to average 00:52:15.920 |
up. And it's going to be an average because you can see that these rows sum to one. So this is our 00:52:21.600 |
A, and then our B in this example, of course, is X. So what's going to happen here now is that we 00:52:29.520 |
are going to have an xbow2. And this xbow2 is going to be wei multiplying our x. 00:52:38.000 |
So let's think this through. Wei is T by T, and this is matrix multiplying in PyTorch, 00:52:45.200 |
a B by T by C. And it's giving us what shape. So PyTorch will come here and it will see that 00:52:53.360 |
these shapes are not the same. So it will create a batch dimension here. And this is a batch matrix 00:52:59.280 |
multiply. And so it will apply this matrix multiplication in all the batch elements in 00:53:05.040 |
parallel and individually. And then for each batch element, there will be a T by T multiplying T by C 00:53:11.760 |
exactly as we had below. So this will now create B by T by C, and xbow2 will now become identical 00:53:11.760 |
to xbow. So we can see that torch.allclose of xbow and xbow2 should be true. Now, 00:53:24.480 |
so this kind of like convinces us that these are in fact the same. So xbow and xbow2, 00:53:35.360 |
if I just print them, okay, we're not going to be able to just stare it down. But 00:53:52.960 |
well, let me try xbow basically just at the zeroth element and xbow2 at the zeroth element. So 00:53:57.360 |
just the first batch, and we should see that this and that should be identical, which they are. 00:54:02.400 |
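A minimal sketch of this vectorized version, checked against the loop-based average from earlier (the variable names and the random seed are assumptions for illustration):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# version 1: explicit loops, averaging each token with all tokens before it
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t+1].mean(0)

# version 2: the same aggregation as a batched matrix multiply
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)   # (T, T), rows sum to one
xbow2 = wei @ x                        # (T, T) @ (B, T, C) -> (B, T, C) via broadcasting
print(torch.allclose(xbow, xbow2))     # True
print(xbow[0], xbow2[0], sep="\n")     # the first batch element matches
```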
Right. So what happened here? The trick is we were able to use batch matrix multiply 00:54:08.720 |
to do this aggregation, really. And it's a weighted aggregation. And the weights are specified in this 00:54:18.160 |
T by T array. And we're basically doing weighted sums. And these weighted sums are according to 00:54:25.920 |
the weights inside here, they take on sort of this triangular form. And so that means that a token at 00:54:32.640 |
the t-th position will only get sort of information from the tokens preceding it. So that's exactly 00:54:40.240 |
what we want. And finally, I would like to rewrite it in one more way. And we're going to see why 00:54:45.520 |
that's useful. So this is the third version. And it's also identical to the first and second. 00:54:51.280 |
But let me talk through it. It uses softmax. So tril here is this matrix, lower triangular ones, 00:55:00.560 |
and wei begins as all zero. Okay, so if I just print wei in the beginning, it's all zero, 00:55:09.520 |
then I use masked_fill. So what this is doing is wei.masked_fill, it's all zeros. And I'm saying 00:55:17.840 |
for all the elements where tril is equal to zero, make them be negative infinity. 00:55:23.760 |
So all the elements where tril is zero will become negative infinity now. 00:55:28.080 |
So this is what we get. And then the final one here is softmax. 00:55:36.240 |
So if I take a softmax along every single, so dim is negative one, so along every single row, 00:55:40.800 |
if I do a softmax, what is that going to do? Well, softmax is also like a normalization operation, 00:55:52.160 |
right? And so spoiler alert, you get the exact same matrix. Let me bring back the softmax. 00:55:59.120 |
And recall that in softmax, we're going to exponentiate every single one of these. 00:56:04.720 |
And then we're going to divide by the sum. And so if we exponentiate every single element here, 00:56:09.760 |
we're going to get a one. And here we're going to get basically zero, zero, zero, zero, everywhere 00:56:15.040 |
else. And then when we normalize, we just get one. Here, we're going to get one, one, and then 00:56:20.960 |
zeros. And the softmax will again divide, and this will give us 0.5, 0.5, and so on. And so this is 00:56:27.680 |
also the same way to produce this mask. Now, the reason that this is a bit more interesting, 00:56:34.560 |
and the reason we're going to end up using it in self-attention, is that these weights here begin 00:56:41.120 |
with zero. And you can think of this as like an interaction strength, or like an affinity. So 00:56:47.600 |
basically, it's telling us how much of each token from the past do we want to aggregate and average 00:56:54.720 |
up. And then this line is saying, tokens from the future cannot communicate. By setting them to 00:57:02.160 |
negative infinity, we're saying that we will not aggregate anything from those tokens. And so 00:57:08.000 |
basically, this then goes through the softmax, and then into the weighted aggregation 00:57:11.840 |
through matrix multiplication. And so what this is now is, you can think of these as, these zeros 00:57:19.920 |
are currently just set by us to be zero. But a quick preview is that these affinities between 00:57:26.240 |
the tokens are not going to be just constant at zero, they're going to be data dependent. These 00:57:31.280 |
tokens are going to start looking at each other, and some tokens will find other tokens more or 00:57:35.920 |
less interesting. And depending on what their values are, they're going to find each other 00:57:40.880 |
interesting to different amounts, and I'm going to call those affinities, I think. And then here, 00:57:45.680 |
we are saying, the future cannot communicate with the past. We're going to clamp them. And then when 00:57:51.440 |
we normalize and sum, we're going to aggregate their values, depending on how interesting they 00:57:57.120 |
find each other. And so that's the preview for self-attention. And basically, long story short 00:58:03.280 |
from this entire section is that you can do weighted aggregations of your past elements 00:58:09.440 |
by using matrix multiplication of a lower triangular fashion. And then the elements 00:58:16.640 |
here in the lower triangular part are telling you how much of each element fuses into this position. 00:58:22.480 |
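A sketch of this third, softmax-based version of the same trick (variable names follow the spoken ones above):

```python
import torch
import torch.nn.functional as F

T = 8
tril = torch.tril(torch.ones(T, T))

wei = torch.zeros((T, T))                          # affinities start out at zero
wei = wei.masked_fill(tril == 0, float('-inf'))    # future tokens cannot be aggregated
wei = F.softmax(wei, dim=-1)                       # each row becomes a uniform average over the past
print(wei)                                         # the same triangular averaging matrix as before
# xbow3 = wei @ x  would reproduce the earlier result
```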
So we're going to use this trick now to develop the self-attention block. 00:58:26.240 |
So first, let's get some quick preliminaries out of the way. First, the thing I'm kind of 00:58:31.200 |
bothered by is that you see how we're passing in vocab size into the constructor? There's no need 00:58:35.280 |
to do that because vocab size is already defined up top as a global variable. So there's no need 00:58:39.840 |
to pass this stuff around. Next, what I want to do is 00:58:46.160 |
create a level of indirection here where we don't directly go to the embedding for the logits, 00:58:51.920 |
but instead we go through this intermediate phase because we're going to start making that bigger. 00:58:56.400 |
So let me introduce a new variable, n_embed. It's short for number of embedding dimensions. 00:59:03.520 |
So n_embed here will be, say, 32. That was a suggestion from GitHub Copilot, by the way. 00:59:11.120 |
It also suggested 32, which is a good number. So this is an embedding table and only 32-dimensional 00:59:18.160 |
embeddings. So then here, this is not going to give us logits directly. Instead, this is going 00:59:23.760 |
to give us token embeddings. That's what I'm going to call it. And then to go from the token embeddings 00:59:28.720 |
to the logits, we're going to need a linear layer. So self.lm_head, let's call it, short for language 00:59:34.560 |
modeling head, is nn.Linear from n_embed up to vocab size. And then when we swing over here, 00:59:41.040 |
we're actually going to get the logits by exactly what the Copilot says. Now, we have to be careful 00:59:46.720 |
here because this C and this C are not equal. This one is n_embed and this one is vocab size. So let's just 00:59:55.440 |
say that n_embed is equal to C. And then this just creates one spurious layer of interaction through 01:00:02.160 |
a linear layer, but this should basically run. So we see that this runs and this currently looks 01:00:15.600 |
kind of spurious, but we're going to build on top of this. Now, next up. So far, we've taken these 01:00:21.280 |
indices and we've encoded them based on the identity of the tokens inside IDX. The next 01:00:28.320 |
thing that people very often do is that we're not just encoding the identity of these tokens, 01:00:32.800 |
but also their position. So we're going to have a second position embedding table here. 01:00:38.160 |
So self.position_embedding_table is an embedding of block size by n_embed. And so each position 01:00:44.960 |
from zero to block size minus one will also get its own embedding vector. And then here, 01:00:50.400 |
first, let me decode B by T from idx.shape. And then here, we're also going to have a 01:00:56.560 |
pos embedding, which is the positional embedding. And this is torch.arange. So this will be 01:01:02.560 |
basically just integers from zero to t minus one. And all of those integers from zero to t minus one 01:01:08.160 |
get embedded through the table to create a t by c. And then here, this gets renamed to just say x. 01:01:16.160 |
And x will be the addition of the token embeddings with the positional embeddings. 01:01:20.560 |
And here, the broadcasting will work out. So B by T by C plus T by C, this gets right-aligned, 01:01:27.760 |
a new dimension of one gets added, and it gets broadcasted across batch. 01:01:31.360 |
So at this point, x holds not just the token identities, but the positions at which these 01:01:37.920 |
tokens occur. And this is currently not that useful because, of course, we just have a simple 01:01:42.480 |
bigram model. So it doesn't matter if you're in the fifth position, the second position, or wherever. 01:01:46.880 |
It's all translation invariant at this stage. So this information currently wouldn't help. 01:01:51.200 |
But as we work on the self-attention block, we'll see that this starts to matter. 01:01:55.520 |
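As a sketch of where the model now stands after these two changes (the hyperparameter values at the top are assumed for illustration; vocab_size in particular depends on the dataset):

```python
import torch
import torch.nn as nn

vocab_size, n_embed, block_size = 65, 32, 8   # illustrative values

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)   # language modeling head

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)        # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embed)
        x = tok_emb + pos_emb        # broadcasting adds the positions across the batch
        logits = self.lm_head(x)     # (B, T, vocab_size)
        return logits
```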
Okay, so now we get the crux of self-attention. So this is probably the most important part of 01:02:03.840 |
this video to understand. We're going to implement a small self-attention for a single individual 01:02:09.280 |
head, as they're called. So we start off with where we were. So all of this code is familiar. 01:02:14.400 |
So right now, I'm working with an example where I changed the number of channels from 2 to 32. 01:02:19.920 |
So we have a 4 by 8 arrangement of tokens. And the information at each token is currently 32 01:02:27.360 |
dimensional. But we just are working with random numbers. Now, we saw here that the code as we had 01:02:34.480 |
it before does a simple average of all the past tokens and the current token. So it's 01:02:42.800 |
just the previous information and current information is just being mixed together in an 01:02:46.080 |
average. And that's what this code currently achieves. And it does so by creating this lower 01:02:50.800 |
triangular structure, which allows us to mask out this weight matrix that we create. So we mask it 01:02:59.040 |
out and then we normalize it. And currently, when we initialize the affinities between all the 01:03:05.040 |
different sort of tokens or nodes, I'm going to use those terms interchangeably. So when we 01:03:11.200 |
initialize the affinities between all the different tokens to be 0, then we see that 01:03:15.680 |
wei gives us this structure where every single row has these uniform numbers. And so that's what 01:03:23.760 |
then in this matrix multiply makes it so that we're doing a simple average. Now, we don't actually 01:03:31.360 |
want this to be all uniform because different tokens will find different other tokens more or 01:03:38.880 |
less interesting. And we want that to be data dependent. So, for example, if I'm a vowel, 01:03:42.880 |
then maybe I'm looking for consonants in my past and maybe I want to know what those consonants 01:03:48.000 |
are and I want that information to flow to me. And so I want to now gather information from the 01:03:53.520 |
past, but I want to do it in a data dependent way. And this is the problem that self-attention 01:03:57.840 |
solves. Now, the way self-attention solves this is the following. Every single node or every single 01:04:04.400 |
token at each position will emit two vectors. It will emit a query and it will emit a key. 01:04:13.520 |
Now, the query vector, roughly speaking, is what am I looking for? And the key vector, 01:04:19.200 |
roughly speaking, is what do I contain? And then the way we get affinities between these tokens now 01:04:26.960 |
in a sequence is we basically just do a dot product between the keys and the queries. So, 01:04:32.880 |
my query dot products with all the keys of all the other tokens and that dot product now becomes 01:04:40.320 |
wei. And so if the key and the query are sort of aligned, they will interact to a very high amount 01:04:48.880 |
and then I will get to learn more about that specific token as opposed to any other token 01:04:55.040 |
in the sequence. So, let's implement this now. We're going to implement a single 01:05:04.560 |
what's called head of self-attention. So, this is just one head. There's a hyperparameter involved 01:05:10.800 |
with these heads, which is the head size. And then here I'm initializing linear modules and 01:05:16.720 |
I'm using bias equals false. So, these are just going to apply a matrix multiply with some fixed 01:05:21.040 |
weights. And now let me produce a key and a query, k and q, by forwarding these modules on x. So, 01:05:30.960 |
the size of this will now become b by t by 16 because that is the head size and the same here, 01:05:38.640 |
b by t by 16. So, this being the head size. So, you see here that when I forward this linear 01:05:49.760 |
on top of my x, all the tokens in all the positions in the b by t arrangement, all of them 01:05:55.760 |
in parallel and independently produce a key and a query. So, no communication has happened yet. 01:06:01.040 |
But the communication comes now. All the queries will dot product with all the keys. 01:06:06.960 |
So, basically what we want is we want wei now, or the affinities between these, to be query 01:06:14.480 |
multiplying key. But we have to be careful with, we can't matrix multiply this. We actually need 01:06:19.440 |
to transpose k, but we have to be also careful because these are, when you have the batch 01:06:25.760 |
dimension. So, in particular, we want to transpose the last two dimensions, dimension negative one 01:06:32.160 |
and dimension negative two. So, negative two, negative one. And so, this matrix multiply now 01:06:39.840 |
will basically do the following b by t by 16. Matrix multiplies b by 16 by t to give us b by t 01:06:51.600 |
by t. Right? So, for every row of b, we're now going to have a t square matrix giving us the 01:07:00.960 |
affinities. And these are now the wei. So, they're not zeros. They are now coming from this dot 01:07:07.120 |
product between the keys and the queries. So, this can now run. I can run this. And the weighted 01:07:13.760 |
aggregation now is a function in a data-dependent manner between the keys and queries of these 01:07:18.960 |
nodes. So, just inspecting what happened here, the wei takes on this form. And you see that 01:07:27.440 |
before wei was just a constant. So, it was applied in the same way to all the batch elements. But 01:07:33.040 |
now every single batch element will have a different sort of wei because every single batch element 01:07:38.480 |
contains different tokens at different positions. And so, this is now data-dependent. So, when we 01:07:44.480 |
look at just the zeroth row, for example, in the input, these are the weights that came out. And 01:07:50.960 |
so, you can see now that they're not just exactly uniform. And in particular, as an example here 01:07:56.320 |
for the last row, this was the eighth token. And the eighth token knows what content it has, and 01:08:01.680 |
it knows at what position it's in. And now the eighth token, based on that, creates a query. 01:08:08.400 |
Hey, I'm looking for this kind of stuff. I'm a vowel. I'm on the eighth position. I'm looking 01:08:12.880 |
for any consonants at positions up to four. And then all the nodes get to emit keys. And maybe 01:08:19.760 |
one of the channels could be I am a consonant, and I am in a position up to four. And that key 01:08:26.080 |
would have a high number in that specific channel. And that's how the query and the key when they 01:08:30.880 |
dot product, they can find each other and create a high affinity. And when they have a high affinity, 01:08:35.760 |
like say this token was pretty interesting to this eighth token, when they have a high affinity, 01:08:43.680 |
then through the softmax, I will end up aggregating a lot of its information into my position. And so, 01:08:49.360 |
I'll get to learn a lot about it. Now, we're looking at wei after this has already happened. 01:08:59.120 |
Let me erase this operation as well. So, let me erase the masking and the softmax, 01:09:03.120 |
just to show you the under the hood internals and how that works. So, without the masking 01:09:07.840 |
and the softmax, wei comes out like this, right? These are the outputs of the dot products. 01:09:12.320 |
And these are the raw outputs, and they take on values from negative two to positive two, 01:09:17.920 |
et cetera. So, that's the raw interactions and raw affinities between all the nodes. 01:09:24.160 |
But now, if I'm a fifth node, I will not want to aggregate anything from the sixth node, 01:09:29.360 |
seventh node, and the eighth node. So, actually, we use the upper triangular masking. So, those 01:09:35.200 |
are not allowed to communicate. And now, we actually want to have a nice distribution. 01:09:41.680 |
So, we don't want to aggregate negative 0.11 of this node. That's crazy. So, instead, 01:09:46.880 |
we exponentiate and normalize. And now, we get a nice distribution that sums to one. 01:09:51.440 |
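Putting the pieces narrated so far into a single sketch (head size 16, as above; the seed and names are illustrative, and the scaling and value vector come later):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
head_size = 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)     # (B, T, 16): what each token contains
q = query(x)   # (B, T, 16): what each token is looking for

wei = q @ k.transpose(-2, -1)                     # (B, T, 16) @ (B, 16, T) -> (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))   # future positions are masked out
wei = F.softmax(wei, dim=-1)                      # each row is a data-dependent distribution
print(wei[0])                                     # affinities for the zeroth batch element
```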
And this is telling us now in a data-dependent manner how much of information to aggregate 01:09:55.600 |
from any of these tokens in the past. So, that's wei, and it's not zeros anymore, 01:10:02.160 |
but it's calculated in this way. Now, there's one more part to a single self-attention head. 01:10:10.080 |
And that is that when we do the aggregation, we don't actually aggregate the tokens exactly. 01:10:14.240 |
We aggregate, we produce one more value here, and we call that the value. 01:10:21.040 |
So, in the same way that we produced key and query, we're also going to create a value. 01:10:24.640 |
And then, here, we don't aggregate x. We calculate a v, which is just achieved by propagating this 01:10:36.400 |
linear on top of x again. And then, we output wei multiplied by v. So, v is the elements that 01:10:44.720 |
we aggregate, or the vector that we aggregate, instead of the raw x. And now, of course, this 01:10:51.040 |
will make it so that the output here of the single head will be 16-dimensional, because that is the 01:10:55.680 |
head size. So, you can think of x as kind of like private information to this token, 01:11:01.600 |
if you think about it that way. So, x is kind of private to this token. So, I'm a fifth token, 01:11:06.720 |
and I have some identity, and my information is kept in vector x. And now, for the purposes of 01:11:14.400 |
the single head, here's what I'm interested in, here's what I have, and if you find me interesting, 01:11:21.200 |
here's what I will communicate to you. And that's stored in v. And so, v is the thing that gets 01:11:26.320 |
aggregated for the purposes of this single head between the different nodes. And that's basically 01:11:33.840 |
the self-attention mechanism. This is what it does. There are a few notes that I would like to make 01:11:40.080 |
about attention. Number one, attention is a communication mechanism. You can really think 01:11:45.680 |
about it as a communication mechanism where you have a number of nodes in a directed graph, 01:11:50.240 |
where basically you have edges pointed between nodes like this. And what happens is every node 01:11:56.480 |
has some vector of information, and it gets to aggregate information via a weighted sum from all 01:12:02.160 |
of the nodes that point to it. And this is done in a data-dependent manner, so depending on whatever 01:12:07.760 |
data is actually stored at each node at any point in time. Now, our graph doesn't look like this. 01:12:13.840 |
Our graph has a different structure. We have eight nodes because the block size is eight, 01:12:18.240 |
and there's always eight tokens. And the first node is only pointed to by itself. The second 01:12:25.040 |
node is pointed to by the first node and itself, all the way up to the eighth node, which is pointed 01:12:30.160 |
to by all the previous nodes and itself. And so, that's the structure that our directed graph has, 01:12:36.240 |
or happens to have, in an autoregressive sort of scenario like language modeling. 01:12:40.400 |
But in principle, attention can be applied to any arbitrary directed graph, and it's just a 01:12:44.560 |
communication mechanism between the nodes. The second note is that, notice that there's no notion 01:12:49.680 |
of space. So, attention simply acts over a set of vectors in this graph. And so, by default, 01:12:56.400 |
these nodes have no idea where they are positioned in the space. And that's why we need to encode 01:13:00.480 |
them positionally and sort of give them some information that is anchored to a specific 01:13:04.800 |
position so that they sort of know where they are. And this is different than, for example, 01:13:09.840 |
from convolution, because if you run, for example, a convolution operation over some input, 01:13:13.840 |
there is a very specific sort of layout of the information in space, and the convolutional 01:13:19.280 |
filters sort of act in space. And so, it's not like that in attention. Attention simply acts over a set 01:13:25.920 |
of vectors out there in space. They communicate. And if you want them to have a notion of space, 01:13:30.720 |
you need to specifically add it, which is what we've done when we calculated the 01:13:34.480 |
positional encodings and added that information to the vectors. 01:13:40.160 |
The next thing that I hope is very clear is that the elements across the batch dimension, 01:13:44.400 |
which are independent examples, never talk to each other. They're always processed independently. 01:13:48.480 |
And this is a batched matrix multiply that applies basically a matrix multiplication 01:13:52.400 |
kind of in parallel across the batch dimension. So, maybe it would be more accurate to say that 01:13:57.280 |
in this analogy of a directed graph, we really have, because the batch size is four, 01:14:02.000 |
we really have four separate pools of eight nodes, and those eight nodes only talk to each other. 01:14:07.200 |
But in total, there's like 32 nodes that are being processed, but there's sort of four separate 01:14:12.480 |
pools of eight. You can look at it that way. The next note is that here in the case of language 01:14:18.240 |
modeling, we have this specific structure of directed graph where the future tokens will 01:14:24.320 |
not communicate to the past tokens. But this doesn't necessarily have to be the constraint 01:14:28.960 |
in the general case. And in fact, in many cases, you may want to have all of the nodes talk to each 01:14:34.800 |
other fully. So, as an example, if you're doing sentiment analysis or something like that with 01:14:39.200 |
a transformer, you might have a number of tokens and you may want to have them all talk to each 01:14:43.920 |
other fully because later you are predicting, for example, the sentiment of the sentence. 01:14:48.320 |
And so, it's okay for these nodes to talk to each other. And so, in those cases, 01:14:53.440 |
you will use an encoder block of self-attention. And all it means for it to be an encoder block is 01:14:59.600 |
that you will delete this line of code, allowing all the nodes to completely talk to each other. 01:15:04.560 |
What we're implementing here is sometimes called a decoder block. And it's called a decoder 01:15:09.360 |
because it is sort of like decoding language. And it's got this autoregressive format where you have 01:15:16.880 |
to mask with the triangular matrix so that nodes from the future never talk to the past. 01:15:23.200 |
Because they would give away the answer. And so, basically, in encoder blocks, you would delete 01:15:27.920 |
this, allow all the nodes to talk. In decoder blocks, this will always be present so that you 01:15:33.040 |
have this triangular structure. But both are allowed and attention doesn't care. Attention 01:15:37.280 |
supports arbitrary connectivity between nodes. The next thing I wanted to comment on is you keep 01:15:41.680 |
hearing me say attention, self-attention, etc. There's actually also something called cross 01:15:46.400 |
attention. What is the difference? Basically, the reason this attention is self-attention 01:15:53.760 |
is because the keys, queries, and the values are all coming from the same source, from x. 01:16:00.240 |
So the same source, x, produces keys, queries, and values. So these nodes are self-attending. 01:16:06.240 |
But in principle, attention is much more general than that. For example, in encoder-decoder 01:16:11.360 |
transformers, you can have a case where the queries are produced from x, but the keys and 01:16:17.120 |
the values come from a whole separate external source. And sometimes from encoder blocks that 01:16:22.640 |
encode some context that we'd like to condition on. And so the keys and the values will actually 01:16:26.960 |
come from a whole separate source. Those are nodes on the side. And here we're just producing 01:16:31.760 |
queries and we're reading off information from the side. So cross attention is used when there's a 01:16:38.080 |
separate source of nodes we'd like to pull information from into our nodes. And it's 01:16:44.080 |
self-attention if we just have nodes that would like to look at each other and talk to each other. 01:16:47.440 |
So this attention here happens to be self-attention. But in principle, 01:16:53.440 |
attention is a lot more general. Okay, and the last note at this stage is if we come to the 01:16:59.520 |
attention is all you need paper here. We've already implemented attention. So given query, 01:17:03.840 |
key and value, we've multiplied the query on the key. We've soft-maxed it. And then we are 01:17:09.680 |
aggregating the values. There's one more thing that we're missing here, which is the dividing 01:17:13.680 |
by the square root of the head size. The d_k here is the head size. Why are they doing 01:17:18.960 |
this? Why is this important? So they call it a scaled attention. And it's kind of like an 01:17:24.880 |
important normalization to basically have. The problem is if you have unit Gaussian inputs, 01:17:30.160 |
so zero mean unit variance, k and q are unit Gaussian. And if you just compute wei naively, 01:17:35.520 |
then you see that your wei actually will be, the variance will be on the order of head size, 01:17:40.000 |
which in our case is 16. But if you multiply by one over head size square root, so this is square 01:17:46.000 |
root and this is one over, then the variance of wei will be one. So it will be preserved. 01:17:51.440 |
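A quick check of that claim (a sketch; the shapes are the ones used above, the seed is illustrative):

```python
import torch

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
k = torch.randn(B, T, head_size)   # unit Gaussian keys
q = torch.randn(B, T, head_size)   # unit Gaussian queries

wei = q @ k.transpose(-2, -1)
print(wei.var())                                 # on the order of head_size, i.e. ~16

wei_scaled = q @ k.transpose(-2, -1) * head_size**-0.5
print(wei_scaled.var())                          # roughly 1
```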
Now, why is this important? You'll notice that wei here will feed into softmax. 01:17:59.520 |
And so it's really important, especially at initialization, that wei be fairly diffuse. 01:18:03.920 |
So in our case here, we sort of lucked out, and wei had fairly diffuse numbers here. So 01:18:11.760 |
like this. Now, the problem is that because of softmax, if wei takes on very positive and 01:18:18.160 |
very negative numbers inside it, softmax will actually converge towards one-hot vectors. 01:18:24.160 |
And so I can illustrate that here. Say we are applying softmax to a tensor of values that 01:18:30.880 |
are very close to zero, then we're going to get a diffuse thing out of softmax. But the moment I 01:18:36.000 |
take the exact same thing and I start sharpening it and making it bigger by multiplying these 01:18:40.000 |
numbers by eight, for example, you'll see that the softmax will start to sharpen. And in fact, 01:18:44.560 |
it will sharpen towards the max. So it will sharpen towards whatever number here is the highest. 01:18:50.080 |
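For instance (the particular values are just illustrative):

```python
import torch

logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(torch.softmax(logits, dim=-1))       # fairly diffuse distribution
print(torch.softmax(logits * 8, dim=-1))   # sharpens toward the largest value
```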
And so basically, we don't want these values to be too extreme, especially at initialization. 01:18:54.480 |
Otherwise, softmax will be way too peaky. And you're basically aggregating information from 01:19:00.800 |
like a single node. Every node just aggregates information from a single other node. That's 01:19:04.720 |
not what we want, especially at initialization. And so the scaling is used just to control the 01:19:09.840 |
variance at initialization. Okay, so having said all that, let's now take our self-attention 01:19:15.040 |
knowledge and let's take it for a spin. So here in the code, I've created this head module and 01:19:21.120 |
implements a single head of self-attention. So you give it a head size, and then here it creates the 01:19:26.720 |
key, query, and value linear layers. Typically, people don't use biases in these. So those are 01:19:32.640 |
the linear projections that we're going to apply to all of our nodes. Now here, I'm creating this 01:19:37.760 |
tril variable. Tril is not a parameter of the module. So in sort of PyTorch naming conventions, 01:19:43.360 |
this is called a buffer. It's not a parameter. And you have to assign it to the module using 01:19:48.240 |
register_buffer. So that creates the tril, the lower triangular matrix. And when we're given the 01:19:54.960 |
input x, this should look very familiar now. We calculate the keys, the queries. We calculate the 01:20:00.240 |
attention scores inside wei. We normalize it. So we're using scaled attention here. 01:20:05.280 |
Then we make sure that future doesn't communicate with the past. So this makes it a decoder block. 01:20:11.920 |
And then softmax, and then aggregate the value and output. Then here in the language model, 01:20:17.520 |
I'm creating a head in the constructor, and I'm calling it self-attention head. 01:20:21.760 |
And the head size, I'm going to keep the same as n_embed, just for now. And then here, 01:20:29.040 |
once we've encoded the information with the token embeddings and the position embeddings, 01:20:34.160 |
we're simply going to feed it into the self-attention head. And then the output of 01:20:38.080 |
that is going to go into the decoder language modeling head and create the logits. So this is 01:20:44.720 |
the simplest way to plug in a self-attention component into our network right now. 01:20:50.160 |
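Here is a sketch of that single head and of how it plugs into the model, following the description above (the class and variable names are assumptions about the script, not verbatim code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embed, block_size = 32, 8   # illustrative values, as above

class Head(nn.Module):
    """One head of self-attention."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # tril is not a trainable parameter, so it is registered as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5             # scaled attention scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # decoder: no looking ahead
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        return wei @ v     # (B, T, head_size)

# inside the language model: self.sa_head = Head(n_embed), and in forward:
#   x = tok_emb + pos_emb
#   x = self.sa_head(x)
#   logits = self.lm_head(x)
```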
I had to make one more change, which is that here in the generate, we have to make sure that our 01:20:58.240 |
IDX that we feed into the model, because now we're using positional embeddings, 01:21:02.800 |
we can never have more than block size coming in. Because if IDX is more than block size, 01:21:08.720 |
then our position embedding table is going to run out of scope, because it only has embeddings for 01:21:12.640 |
up to block size. And so therefore, I added some code here to crop the context that we're going to 01:21:18.560 |
feed into self, so that we never pass in more than block size elements. So those are the changes, 01:21:26.400 |
and let's now train the network. So I also came up to the script here, and I decreased the learning 01:21:31.120 |
rate, because the self-attention can't tolerate very, very high learning rates. And then I also 01:21:36.560 |
increased the number of iterations, because the learning rate is lower. And then I trained it, 01:21:40.080 |
and previously we were only able to get to up to 2.5, and now we are down to 2.4. So we definitely 01:21:45.840 |
see a little bit of an improvement from 2.5 to 2.4, roughly, but the text is still not amazing. 01:21:51.280 |
So clearly, the self-attention head is doing some useful communication, 01:21:56.320 |
but we still have a long way to go. Okay, so now we've implemented scaled dot-product attention. 01:22:01.920 |
Now next up, in the attention is all you need paper, there's something called multi-head 01:22:06.000 |
attention. And what is multi-head attention? It's just applying multiple attentions in parallel, 01:22:11.760 |
and concatenating their results. So they have a little bit of diagram here. I don't know if this 01:22:16.720 |
is super clear. It's really just multiple attentions in parallel. So let's implement that. 01:22:23.120 |
Fairly straightforward. If we want a multi-head attention, then we want multiple heads of 01:22:28.160 |
self-attention running in parallel. So in PyTorch, we can do this by simply creating multiple heads. 01:22:34.720 |
So however many heads you want, and then what is the head size of each. And then we run all of them 01:22:43.280 |
in parallel into a list, and simply concatenate all of the outputs. And we're concatenating over 01:22:49.440 |
the channel dimension. So the way this looks now is, we don't have just a single attention 01:22:54.720 |
that has a head size of 32, because remember, n_embed is 32. Instead of having one communication 01:23:03.680 |
channel, we now have four communication channels in parallel. And each one of these communication 01:23:09.360 |
channels typically will be smaller correspondingly. So because we have four communication channels, 01:23:16.640 |
we want eight-dimensional self-attention. And so from each communication channel, we're getting 01:23:21.120 |
together eight-dimensional vectors. And then we have four of them, and that concatenates to give 01:23:25.840 |
us 32, which is the original n_embed. And so this is kind of similar to, if you're familiar 01:23:31.440 |
with convolutions, this is kind of like a group convolution. Because basically, instead of having 01:23:36.080 |
one large convolution, we do convolution in groups, and that's multi-headed self-attention. 01:23:42.400 |
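A sketch of that, reusing the Head module sketched earlier (names are assumptions):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention running in parallel."""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        # concatenate the per-head outputs over the channel dimension
        return torch.cat([h(x) for h in self.heads], dim=-1)

# e.g. with n_embed = 32: self.sa_heads = MultiHeadAttention(4, n_embed // 4)
```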
And so then here, we just use SA heads, self-attention heads, instead. Now, I actually 01:23:48.320 |
ran it, and scrolling down, I ran the same thing, and then we now get down to 2.28, roughly. And 01:23:57.040 |
the output is still, the generation is still not amazing, but clearly the validation loss is 01:24:01.200 |
improving, because we were at 2.4 just now. And so it helps to have multiple communication channels, 01:24:07.120 |
because obviously, these tokens have a lot to talk about. They want to find the consonants, 01:24:11.920 |
the vowels, they want to find the vowels just from certain positions, they want to find any kinds of 01:24:17.200 |
different things. And so it helps to create multiple independent channels of communication, 01:24:21.440 |
gather lots of different types of data, and then decode the output. 01:24:25.920 |
Now, going back to the paper for a second, of course, I didn't explain this figure in full 01:24:29.440 |
detail, but we are starting to see some components of what we've already implemented. We have the 01:24:33.520 |
positional encodings, the token encodings that add, we have the masked multi-headed attention 01:24:38.560 |
implemented. Now, here's another multi-headed attention, which is a cross-attention to an 01:24:43.680 |
encoder, which we're not going to implement in this case. I'm going to come back 01:24:47.680 |
to that later. But I want you to notice that there's a feedforward part here, and then this 01:24:52.560 |
is grouped into a block that gets repeated again and again. Now, the feedforward part here is just 01:24:57.280 |
a simple multi-layer perceptron. So here, the position-wise feedforward 01:25:04.960 |
network is just a simple little MLP. So I want to start basically in a similar fashion, also 01:25:10.560 |
adding computation into the network. And this computation is on a per node level. So I've 01:25:17.440 |
already implemented it, and you can see the diff highlighted on the left here when I've added or 01:25:21.440 |
changed things. Now, before we had the multi-headed self-attention that did the communication, 01:25:26.640 |
but we went way too fast to calculate the logits. So the tokens looked at each other, but didn't 01:25:32.560 |
really have a lot of time to think on what they found from the other tokens. And so what I've 01:25:38.880 |
implemented here is a little feedforward single layer. And this little layer is just a linear 01:25:44.240 |
followed by a ReLU non-linearity, and that's it. So it's just a little layer, and then I call it 01:25:51.120 |
feedforward, with n_embed. And then this feedforward is just called sequentially right after the 01:25:57.680 |
self-attention. So we self-attend, then we feedforward. And you'll notice that the feedforward 01:26:03.040 |
here, when it's applying linear, this is on a per token level. All the tokens do this independently. 01:26:08.560 |
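A sketch of this little per-token MLP (names are assumptions):

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity, applied to each token independently."""
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, n_embed),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)   # (B, T, n_embed) -> (B, T, n_embed), per token
```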
So the self-attention is the communication, and then once they've gathered all the data, 01:26:12.960 |
now they need to think on that data individually. And so that's what feedforward is doing, 01:26:17.920 |
and that's why I've added it here. Now, when I train this, the validation loss actually continues 01:26:22.720 |
to go down, now to 2.24, which is down from 2.28. The output still looks kind of terrible, 01:26:29.440 |
but at least we've improved the situation. And so as a preview, we're going to now start to 01:26:35.120 |
intersperse the communication with the computation. And that's also what the transformer does when it 01:26:42.320 |
has blocks that communicate and then compute, and it groups them and replicates them. 01:26:47.360 |
Okay, so let me show you what we'd like to do. We'd like to do something like this. We have 01:26:52.560 |
a block, and this block is basically this part here, except for the cross-attention. 01:26:57.280 |
Now, the block basically intersperses communication and then computation. 01:27:01.920 |
The communication is done using multi-headed self-attention, and then the computation is 01:27:07.600 |
done using a feedforward network on all the tokens independently. Now, what I've added 01:27:13.840 |
here also is, you'll notice, this takes n_embed, the embedding dimension, and the 01:27:19.600 |
number of heads that we would like, which is kind of like group size in group convolution. 01:27:23.280 |
And I'm saying that the number of heads we'd like is four. And so because this is 32, 01:27:28.720 |
we calculate that 01:27:31.920 |
the head size should be eight, so that everything sort of works out channel-wise. 01:27:37.520 |
So this is how the transformer structures the sizes, typically. So the head size will become 01:27:45.120 |
eight, and then this is how we want to intersperse them. And then here, I'm trying to create blocks, 01:27:49.920 |
which is just a sequential application of block, block, block. So then we're interspersing 01:27:54.800 |
communication feedforward many, many times, and then finally we decode. Now, I actually tried to 01:28:00.720 |
run this, and the problem is, this doesn't actually give a very good result. And the reason for that 01:28:07.520 |
is, we're starting to actually get a pretty deep neural net. And deep neural nets suffer 01:28:12.480 |
from optimization issues, and I think that's what we're kind of like slightly starting to run into. 01:28:16.560 |
So we need one more idea that we can borrow from the transformer paper to resolve those 01:28:21.760 |
difficulties. Now, there are two optimizations that dramatically help with the depth of these 01:28:26.400 |
networks and make sure that the networks remain optimizable. Let's talk about the first one. 01:28:31.120 |
The first one in this diagram is, you see this arrow here, and then this arrow and this arrow, 01:28:37.520 |
those are skip connections, or sometimes called residual connections. They come from this paper, 01:28:42.400 |
Deep Residual Learning for Image Recognition, from about 2015, that introduced the concept. 01:28:48.960 |
Now, these are basically, what it means is, you transform the data, but then you have a skip 01:28:55.120 |
connection with addition from the previous features. Now, the way I like to visualize it, 01:29:00.720 |
that I prefer, is the following. Here, the computation happens from the top to bottom, 01:29:07.440 |
and basically, you have this residual pathway, and you are free to fork off from the residual 01:29:12.880 |
pathway, perform some computation, and then project back to the residual pathway via addition. 01:29:17.600 |
And so you go from the inputs to the targets only via plus, and plus, and plus. And the reason this 01:29:25.920 |
is useful is because during backpropagation, remember from our micrograd video earlier, 01:29:30.480 |
addition distributes gradients equally to both of its branches that fed us the input. 01:29:36.960 |
And so the supervision, or the gradients from the loss, basically hop through every addition node 01:29:44.640 |
all the way to the input, and then also fork off into the residual blocks. But basically, 01:29:51.680 |
you have this gradient superhighway that goes directly from the supervision all the way to 01:29:56.000 |
the input, unimpeded. And then these residual blocks are usually initialized in the beginning, 01:30:01.040 |
so they contribute very, very little, if anything, to the residual pathway. They are initialized 01:30:05.680 |
that way. So in the beginning, they are almost kind of like not there. But then during the 01:30:10.480 |
optimization, they come online over time, and they start to contribute. But at least at the 01:30:17.040 |
initialization, you can go directly from the supervision to the input, the gradient is unimpeded and just 01:30:22.080 |
flows, and then the blocks over time kick in. And so that dramatically helps with the optimization. 01:30:28.480 |
So let's implement this. So coming back to our block here, basically what we want to do is 01:30:33.360 |
we want to do x equals x plus self-attention of x, and x equals x plus feedforward of x. 01:30:39.440 |
So this is x, and then we fork off and do some communication and come back. And we fork off, 01:30:46.400 |
and we do some computation and come back. So those are residual connections. And then 01:30:51.440 |
swinging back up here, we also have to introduce this projection. So nn.Linear. 01:30:58.480 |
And this is going to be applied after we concatenate. This is of size n_embed. So this is the 01:31:05.360 |
output of the self-attention itself. But then we actually want to apply the projection, 01:31:11.040 |
and that's the result. So the projection is just a linear transformation of the outcome of this 01:31:16.800 |
layer. So that's the projection back into the residual pathway. And then here in the feed 01:31:22.560 |
forward, it's going to be the same thing. I could have a self-dot projection here as well. But let 01:31:27.920 |
me just simplify it, and let me couple it inside the same sequential container. And so this is the 01:31:35.360 |
projection layer going back into the residual pathway. And so that's it. So now we can train 01:31:43.280 |
this. So I implemented one more small change. When you look into the paper again, you see that 01:31:49.280 |
the dimensionality of input and output is 512 for them. And they're saying that the inner layer 01:31:54.080 |
here in the feed forward has dimensionality of 2048. So there's a multiplier of 4. And so the 01:32:00.080 |
inner layer of the feed forward network should be multiplied by 4 in terms of channel sizes. 01:32:05.040 |
So I came here, and I multiplied 4 times n_embed here for the feed forward, and then from 4 times 01:32:10.560 |
n_embed coming back down to n_embed when we go back to the projection. So adding a bit of computation 01:32:16.560 |
here and growing that layer that is in the residual block on the side of the residual pathway. 01:32:22.960 |
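Putting the residual connections, the projections, and the 4x inner layer together, a sketch of the block at this point might look like this (MultiHeadAttention is as sketched earlier, with its own nn.Linear projection at the end as described above; names are assumptions):

```python
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),   # inner layer grown by the 4x multiplier
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),   # projection back into the residual pathway
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: communication followed by computation."""
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)

    def forward(self, x):
        x = x + self.sa(x)     # fork off, communicate, add back into the residual pathway
        x = x + self.ffwd(x)   # fork off, compute per token, add back in
        return x
```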
And then I train this, and we actually get down all the way to 2.08 validation loss. And we also 01:32:28.480 |
see that the network is starting to get big enough that our train loss is getting ahead 01:32:32.000 |
of validation loss. So we start to see a little bit of overfitting. And our generations here are 01:32:40.720 |
still not amazing, but at least you see that we can see like is here, this now, grief, sink. 01:32:45.440 |
Like this starts to almost look like English. So yeah, we're starting to really get there. 01:32:51.520 |
Okay, and the second innovation that is very helpful for optimizing very deep neural networks 01:32:55.600 |
is right here. So we have this addition now, that's the residual part. But this norm is referring to 01:33:00.480 |
something called layer norm. So layer norm is implemented in PyTorch. It's a paper that came out 01:33:05.360 |
a while back here. And layer norm is very, very similar to batch norm. So remember back to 01:33:13.600 |
our Make More Series part three, we implemented batch normalization. 01:33:18.560 |
And batch normalization basically just made sure that across the batch dimension, 01:33:24.080 |
any individual neuron had unit Gaussian distribution. So it was zero mean and unit 01:33:31.840 |
standard deviation, one standard deviation output. So what I did here is I'm copy pasting the batch 01:33:38.080 |
norm 1D that we developed in our Make More Series. And see here, we can initialize, for example, 01:33:43.680 |
this module, and we can have a batch of 32 100 dimensional vectors feeding through the batch 01:33:48.800 |
norm layer. So what this does is it guarantees that when we look at just the zeroth column, 01:33:55.520 |
it's a zero mean, one standard deviation. So it's normalizing every single column of this input. 01:34:02.880 |
Now the rows are not going to be normalized by default, because we're just normalizing columns. 01:34:09.600 |
So let's now implement layer norm. It's very complicated. Look, we come here, 01:34:14.960 |
we change this from zero to one. So we don't normalize the columns, we normalize the rows. 01:34:20.880 |
And now we've implemented layer norm. So now the columns are not going to be normalized. 01:34:28.080 |
But the rows are going to be normalized: for every individual example, its 100-dimensional vector 01:34:34.960 |
is normalized in this way. And because our computation now does not span across examples, 01:34:40.640 |
we can delete all of this buffers stuff, because we can always apply this operation, 01:34:48.080 |
and don't need to maintain any running buffers. So we don't need the buffers. 01:34:52.080 |
There's no distinction between training and test time. 01:34:57.280 |
And we don't need these running buffers. We do keep gamma and beta, we don't need the momentum, 01:35:04.800 |
we don't care if it's training or not. And this is now a layer norm. And it normalizes 01:35:12.080 |
the rows instead of the columns. And this here is identical to basically this here. 01:35:18.240 |
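A sketch of that change, starting from the BatchNorm1d-style code and normalizing dimension 1 instead of dimension 0, with the running buffers removed (a minimal illustration, not the exact lecture code):

```python
import torch

class LayerNorm1d:
    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        self.gamma = torch.ones(dim)   # trainable scale
        self.beta = torch.zeros(dim)   # trainable shift

    def __call__(self, x):
        xmean = x.mean(1, keepdim=True)   # mean over the feature dimension (the rows)
        xvar = x.var(1, keepdim=True)     # variance over the feature dimension
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

module = LayerNorm1d(100)
x = torch.randn(32, 100)                # a batch of 32 hundred-dimensional vectors
x = module(x)
print(x[0, :].mean(), x[0, :].std())    # each row is now roughly zero mean, unit std
print(x[:, 0].mean(), x[:, 0].std())    # the columns are no longer normalized
```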
So let's now implement layer norm in our transformer. Before I incorporate the layer norm, 01:35:23.760 |
I just wanted to note that, as I said, very few details about the transformer have changed in the 01:35:28.160 |
last five years. But this is actually something that slightly departs from the original paper. 01:35:32.480 |
You see that the add and norm is applied after the transformation. But now it is a bit more, 01:35:40.400 |
basically, common to apply the layer norm before the transformation. So there's a reshuffling of 01:35:45.360 |
the layer norms. So this is called the pre-norm formulation, and that's the one that we're going 01:35:49.680 |
to implement as well. So slight deviation from the original paper. Basically, we need two layer 01:35:54.320 |
norms. Layer norm one is nn.LayerNorm, and we tell it how many, what is the embedding dimension. 01:36:02.640 |
And we need the second layer norm. And then here, the layer norms are applied immediately on x. 01:36:08.400 |
So self.layernorm1 applied on x, and self.layernorm2 applied on x, before it goes into 01:36:16.240 |
self-attention and feedforward. And the size of the layer norm here is n_embed, so 32. 01:36:22.320 |
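So the pre-norm block sketched above becomes (again, the names are assumptions, with MultiHeadAttention and FeedForward as sketched earlier):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))     # pre-norm: layer norm applied before self-attention
        x = x + self.ffwd(self.ln2(x))   # and before the feedforward
        return x
```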
So when the layer norm is normalizing our features, it is the normalization here 01:36:30.160 |
happens, the mean and the variance are taken over 32 numbers. So the batch and the time act as batch 01:36:36.320 |
dimensions, both of them. So this is kind of like a per-token transformation that just normalizes 01:36:42.720 |
the features and makes them zero mean and unit variance at initialization. But of course, 01:36:49.680 |
because these layer norms inside it have these gamma and beta trainable parameters, 01:36:55.200 |
the layer norm will eventually create outputs that might not be unit Gaussian, 01:37:00.640 |
but the optimization will determine that. So for now, this is incorporating the layer norms, 01:37:06.720 |
and let's train them up. Okay, so I let it run, and we see that we get down to 2.06, 01:37:11.840 |
which is better than the previous 2.08. So a slight improvement by adding the layer norms. 01:37:16.560 |
And I'd expect that they help even more if we had a bigger and deeper network. 01:37:20.880 |
One more thing I forgot to add is that there should be a layer norm here also typically, 01:37:25.280 |
at the end of the transformer, right before the final linear layer that decodes into the vocabulary. 01:37:32.560 |
So I added that as well. So at this stage, we actually have a pretty complete transformer 01:37:37.360 |
according to the original paper, and it's a decoder-only transformer. I'll talk about that 01:37:42.000 |
in a second. But at this stage, the major pieces are in place, so we can try to scale this up and 01:37:47.040 |
see how well we can push this number. Now, in order to scale up the model, I had to perform 01:37:51.600 |
some cosmetic changes here to make it nicer. So I introduced this variable called n_layer, 01:37:56.640 |
which just specifies how many layers of the blocks we're going to have. I create a bunch 01:38:01.600 |
of blocks, and we have a new variable, number of heads as well. I pulled out the layer norm here, 01:38:07.040 |
and so this is identical. Now, one thing that I did briefly change is I added dropout. So dropout 01:38:14.000 |
is something that you can add right before the residual connection back into the residual pathway. 01:38:20.880 |
So we can drop out that as the last layer here. We can drop out here at the end of the 01:38:27.200 |
multi-headed attention as well. And we can also drop out here when we calculate the 01:38:32.880 |
basically affinities, and after the softmax, we can drop out some of those. So we can randomly 01:38:39.200 |
prevent some of the nodes from communicating. And so dropout comes from this paper from 2014 or so, 01:38:46.720 |
and basically it takes your neural net, and it randomly, every forward-backward pass, 01:38:54.080 |
shuts off some subset of neurons. So randomly drops them to zero and trains without them. 01:39:01.520 |
And what this does effectively is because the mask of what's being dropped out has changed 01:39:06.880 |
every single forward-backward pass, it ends up kind of training an ensemble of subnetworks. 01:39:12.400 |
And then at test time, everything is fully enabled and kind of all of those subnetworks 01:39:16.720 |
are merged into a single ensemble, if you want to think about it that way. 01:39:20.160 |
So I would read the paper to get the full detail. For now, we're just going to stay on the level of 01:39:25.200 |
this is a regularization technique, and I added it because I'm about to scale up the model quite a 01:39:30.320 |
bit, and I was concerned about overfitting. So now when we scroll up to the top, we'll see that 01:39:36.400 |
I changed a number of hyperparameters here about our neural net. So I made the batch size be much 01:39:41.040 |
larger, now it's 64. I changed the block size to be 256, so previously it was just 8 characters 01:39:47.440 |
of context. Now it is 256 characters of context to predict the 257th. I brought down the learning 01:39:55.440 |
rate a little bit because the neural net is now much bigger, so I brought down the learning rate. 01:40:00.080 |
The embedding dimension is now 384, and there are six heads. So 384 divide 6 means that every head 01:40:07.200 |
is 64-dimensional as a standard. And then there are going to be six layers of that, 01:40:13.120 |
and the dropout will be at 0.2. So every forward-backward pass, 20% of all these 01:40:18.240 |
intermediate calculations are disabled and dropped to zero. 01:40:22.880 |
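For reference, the scaled-up hyperparameters as described (only the values stated above; the learning rate and iteration count are not spelled out here, so they are left out):

```python
batch_size = 64    # sequences processed in parallel
block_size = 256   # context of 256 characters to predict the 257th
n_embed = 384      # embedding dimension
n_head = 6         # 384 / 6 = 64 dimensions per head
n_layer = 6        # number of transformer blocks
dropout = 0.2      # 20% of intermediate activations dropped during training
```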
And then I already trained this and I ran it, so drumroll, how well does it perform? 01:40:29.440 |
So let me just scroll up here. We get a validation loss of 1.48, which is actually quite a bit of an 01:40:37.040 |
improvement on what we had before, which I think was 2.07. So we went from 2.07 all the way down 01:40:42.000 |
to 1.48 just by scaling up this neural net with the code that we have. And this, of course, ran 01:40:47.120 |
for a lot longer. This may be trained for, I want to say, about 15 minutes on my A100 GPU, so that's 01:40:52.960 |
a pretty good GPU. And if you don't have a GPU, you're not going to be able to reproduce this. 01:40:57.600 |
On a CPU, this would be, I would not run this on a CPU or a MacBook or something like that. 01:41:02.640 |
You'll have to bring down the number of layers and the embedding dimension and so on. 01:41:07.040 |
But in about 15 minutes, we can get this kind of a result. And I'm printing some of the Shakespeare 01:41:14.640 |
here, but what I did also is I printed 10,000 characters, so a lot more, and I wrote them to 01:41:19.040 |
a file. And so here we see some of the outputs. So it's a lot more recognizable as the input text 01:41:27.360 |
file. So the input text file, just for reference, looked like this. So there's always someone 01:41:32.960 |
speaking in this manner, and our predictions now take on that form. Except, of course, they're 01:41:40.720 |
nonsensical when you actually read them. So it is, "Every crimpty be a house. Oh, those probation." 01:41:57.840 |
Anyway, so you can read through this. It's nonsensical, of course, but this is just a 01:42:05.120 |
transformer trained on the character level for 1 million characters that come from Shakespeare. 01:42:10.080 |
So there's sort of like blabbers on in Shakespeare-like manner, but it doesn't, 01:42:14.160 |
of course, make sense at this scale. But I think it's still a pretty good demonstration 01:42:19.520 |
of what's possible. So now I think that kind of concludes the programming section of this video. 01:42:28.320 |
We basically kind of did a pretty good job of implementing this transformer, but the picture 01:42:34.880 |
doesn't exactly match up to what we've done. So what's going on with all these additional parts 01:42:39.040 |
here? So let me finish explaining this architecture and why it looks so funky. Basically, what's 01:42:44.400 |
happening here is what we implemented here is a decoder-only transformer. So there's no component 01:42:50.640 |
here. This part is called the encoder, and there's no cross-attention block here. Our block only has 01:42:56.720 |
a self-attention and the feedforward, so it is missing this third in-between piece here. This 01:43:02.960 |
piece does cross-attention. So we don't have it, and we don't have the encoder. We just have the 01:43:07.200 |
decoder. And the reason we have a decoder only is because we are just generating text, and it's 01:43:13.360 |
unconditioned on anything. We're just blabbering on according to a given dataset. What makes it a 01:43:19.040 |
decoder is that we are using the triangular mask in our transformer. So it has this autoregressive 01:43:25.040 |
property where we can just go and sample from it. So the fact that it's using the triangular mask 01:43:31.440 |
to mask out the attention makes it a decoder, and it can be used for language modeling. 01:43:36.560 |
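To make the triangular mask concrete, here is a minimal sketch of the masking step inside one attention head, in the spirit of the code we wrote earlier; the tensor names (wei, tril) follow our earlier convention and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

B, T, head_size = 4, 8, 16                 # arbitrary illustrative sizes
q = torch.randn(B, T, head_size)           # queries
k = torch.randn(B, T, head_size)           # keys

wei = q @ k.transpose(-2, -1) * head_size**-0.5   # (B, T, T) scaled affinities
tril = torch.tril(torch.ones(T, T))               # lower-triangular (causal) mask
wei = wei.masked_fill(tril == 0, float('-inf'))   # a token cannot attend to the future
wei = F.softmax(wei, dim=-1)                      # each row averages over the past only
```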
Now, the reason that the original paper had an encoder-decoder architecture is because it is a 01:43:41.520 |
machine translation paper. So it is concerned with a different setting in particular. It expects some 01:43:48.240 |
tokens that encode, say for example, French, and then it is expected to decode the translation 01:43:54.560 |
in English. So typically, these here are special tokens. So you are expected to read in this and 01:44:01.760 |
condition on it. And then you start off the generation with a special token called start. 01:44:06.560 |
So this is a special new token that you introduce and always place in the beginning. And then the 01:44:12.800 |
network is expected to output "neural networks are awesome", and then a special end token to finish 01:44:18.800 |
the generation. So this part here will be decoded exactly as we've done it. "Neural networks are 01:44:26.080 |
awesome" will be identical to what we did. But unlike what we did, they want to condition the 01:44:32.240 |
generation on some additional information. And in that case, this additional information is the 01:44:37.680 |
French sentence that they should be translating. So what they do now is they bring the encoder. 01:44:43.680 |
Now the encoder reads this part here. So we're only going to take the part of French, and we're 01:44:50.400 |
going to create tokens from it exactly as we've seen in our video. And we're going to put a 01:44:55.920 |
transformer on it. But there's going to be no triangular mask. And so all the tokens are allowed 01:45:00.640 |
to talk to each other as much as they want. And they're just encoding whatever's the content of 01:45:05.440 |
this French sentence. Once they've encoded it, the encodings basically come out at the top here. 01:45:12.480 |
And then what happens here is in our decoder, which does the language modeling, there's an 01:45:18.960 |
additional connection here to the outputs of the encoder. And that is brought in through 01:45:24.640 |
cross attention. So the queries are still generated from x, but now the keys and the values 01:45:30.800 |
are coming in from the side: they are generated by the nodes at the top 01:45:36.320 |
of the encoder. And those keys and values from the top of the encoder 01:45:41.840 |
feed in on the side into every single block of the decoder. And so that's why there's an 01:45:48.400 |
additional cross attention. And really what it's doing is it's conditioning the decoding, 01:45:53.520 |
not just on the past of this current decoding, but also on having seen the fully encoded 01:46:01.440 |
French prompt. And so it's an encoder-decoder model, which is why we have those two 01:46:07.360 |
transformers, the additional block, and so on. So we did not do this because we have 01:46:12.480 |
nothing to encode, there's no conditioning, we just have a text file, and we just want to imitate it. 01:46:16.640 |
And that's why we are using a decoder only transformer, exactly as done in GPT. 01:46:21.440 |
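As a rough sketch of the difference, a single head of cross-attention might look something like this. The class is hypothetical and simplified; the only point is where q, k, and v come from: queries from the decoder stream x, keys and values from the encoder output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionHead(nn.Module):
    """Hypothetical, simplified single head of cross-attention (sketch only)."""
    def __init__(self, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

    def forward(self, x, enc_out):
        # x:       (B, T_dec, n_embd) decoder-side tokens (what is being generated)
        # enc_out: (B, T_enc, n_embd) encoder output (e.g. the encoded French sentence)
        q = self.query(x)         # queries come from the decoder
        k = self.key(enc_out)     # keys come from the encoder
        v = self.value(enc_out)   # values come from the encoder
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5   # (B, T_dec, T_enc)
        wei = F.softmax(wei, dim=-1)    # note: no triangular mask on the encoder side
        return wei @ v                  # (B, T_dec, head_size)
```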
Okay, so now I wanted to do a very brief walkthrough of nanoGPT, which you can find 01:46:26.560 |
on my GitHub. And nanoGPT is basically two files of interest: there's train.py and model.py. 01:46:33.040 |
train.py has all the boilerplate code for training the network. It is basically all the 01:46:38.320 |
stuff that we had here, the training loop. It's just that it's a lot more complicated, 01:46:43.280 |
because we're saving and loading checkpoints and pretrained weights, and we are 01:46:47.120 |
decaying the learning rate, compiling the model, and using distributed training across 01:46:51.520 |
multiple nodes or GPUs. So train.py gets a little bit more hairy and complicated, 01:46:56.720 |
there's more options, etc. But model.py should look very, very similar to what we've done 01:47:03.520 |
here. In fact, the model is almost identical. So first, here we have the causal self attention 01:47:10.080 |
block. And all of this should look very, very recognizable to you. We're producing queries, 01:47:14.720 |
keys, and values, we're doing dot products, we're masking, applying softmax, optionally dropping 01:47:20.560 |
out. And here we are doing the weighted aggregation of the values. What is different here is that in our code, 01:47:26.480 |
I have separated out the multi headed attention into just a single individual head. 01:47:32.720 |
And then here, I have multiple heads, and I explicitly concatenate them. 01:47:37.680 |
Whereas here, all of it is implemented in a batched manner inside a single causal self attention. 01:47:43.200 |
And so we don't just have a B and a T and a C dimension, we also end up with a fourth dimension, 01:47:48.080 |
which is the heads. And so it just gets a lot more hairy, because we have four-dimensional 01:47:53.360 |
tensors now, but it is equivalent mathematically. So the exact same thing is 01:47:59.280 |
happening as what we have, it's just it's a bit more efficient, because all the heads are now 01:48:03.600 |
treated as a batch dimension as well. Then we have the multilayer perceptron, it's using the 01:48:08.720 |
Gelu nonlinearity, which is defined here, instead of Relu. And this is done just because OpenAI 01:48:14.640 |
used it, and I want to be able to load their checkpoints. The blocks of the transformer are 01:48:19.520 |
identical, with the communicate and compute phases as we saw, and then the GPT module is identical: 01:48:24.720 |
we have the position encodings, token encodings, the blocks, the layer norm at the end, 01:48:30.080 |
the final linear layer. And this should look all very recognizable. And there's a bit more here, 01:48:36.160 |
because I'm loading checkpoints and stuff like that. I'm separating out the parameters into those 01:48:40.480 |
that should be weight decayed and those that shouldn't. But the generate function should also 01:48:45.280 |
be very, very similar. So a few details are different, but you should definitely be able to 01:48:49.600 |
look at this file and be able to understand a lot of the pieces now. 01:48:55.040 |
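To make the batched-heads point concrete, here is a rough sketch of how all heads can be processed at once by adding a fourth tensor dimension. It is in the spirit of nanoGPT's CausalSelfAttention but simplified, so treat it as an illustration rather than the actual model.py code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BatchedCausalSelfAttention(nn.Module):
    """Simplified sketch: all heads computed together as one extra dimension."""
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)   # q, k, v for all heads at once
        self.c_proj = nn.Linear(n_embd, n_embd)       # output projection
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.c_attn(x).split(C, dim=2)            # each (B, T, C)
        hs = C // self.n_head                                # head size
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)    # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        att = q @ k.transpose(-2, -1) * hs**-0.5             # (B, nh, T, T)
        att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v                                          # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)     # re-assemble all heads
        return self.c_proj(y)
```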
So let's now bring things back to ChatGPT. What would it look like if we wanted to train ChatGPT ourselves? And how does it relate 01:49:00.320 |
to what we learned today? Well, to train ChatGPT, there are roughly two stages. First is the 01:49:05.920 |
pre-training stage, and then the fine-tuning stage. In the pre-training stage, we are training on a 01:49:11.840 |
large chunk of internet, and just trying to get a first decoder only transformer to babble text. 01:49:18.560 |
So it's very, very similar to what we've done ourselves. Except we've done like a tiny little 01:49:24.080 |
baby pre-training step. And so in our case, this is how you print the number of parameters. I printed 01:49:32.080 |
it and it's about 10 million. So this little Shakespeare 01:49:37.120 |
transformer that I created here has about 10 million parameters. Our dataset is roughly 1 million characters, 01:49:44.880 |
so roughly 1 million tokens. But you have to remember that OpenAI uses a different vocabulary. 01:49:49.440 |
They're not on the character level. They use these sub word chunks of words. And so they have a 01:49:54.960 |
vocabulary of 50,000 roughly elements. And so their sequences are a bit more condensed. So our data set, 01:50:02.320 |
the Shakespeare data set would be probably around 300,000 tokens in the OpenAI vocabulary, roughly. 01:50:08.160 |
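As an aside, a parameter count like the roughly 10 million quoted above can be printed with a one-liner over model.parameters(); the nn.Linear below is just a stand-in module to keep the snippet self-contained.

```python
import torch.nn as nn

model = nn.Linear(384, 65)  # stand-in module; in practice this would be the GPT model
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params/1e6:.2f} M parameters")
```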
So we trained a roughly 10-million-parameter model on roughly 300,000 tokens. Now, when you go to the 01:50:14.800 |
GPT-3 paper and you look at the transformers that they trained, they trained a number of 01:50:22.320 |
transformers of different sizes. But the biggest transformer here has 175 billion parameters. 01:50:27.680 |
So ours is, again, 10 million. They used this number of layers in the transformer. This is the n_embd. 01:50:34.160 |
This is the number of heads. And this is the head size. And then this is the batch size. 01:50:41.040 |
So ours was 65. And the learning rate is similar. Now, when they trained this transformer, 01:50:47.600 |
they trained on 300 billion tokens. So again, remember, ours is about 300,000. So this is 01:50:53.600 |
about a million-fold increase. And this number would not even be that large by today's standards; 01:50:59.200 |
it'd be going up to 1 trillion and above. So they are training a significantly larger model 01:51:07.200 |
on a good chunk of the internet. And that is the pre-training stage. But otherwise, 01:51:12.400 |
these hyperparameters should be fairly recognizable to you. And the architecture 01:51:15.920 |
is actually nearly identical to what we implemented ourselves. But of course, 01:51:19.760 |
it's a massive infrastructure challenge to train this. You're talking about typically thousands of 01:51:24.560 |
GPUs having to talk to each other to train models of this size. So that's just the pre-training 01:51:30.720 |
stage. Now, after you complete the pre-training stage, you don't get something that responds to 01:51:36.480 |
your questions with answers and is helpful, et cetera. You get a document completer. 01:51:41.520 |
So it babbles, but it doesn't babble Shakespeare. It babbles internet. It will create arbitrary 01:51:48.000 |
news articles and documents, and it will try to complete documents because that's what it's 01:51:51.760 |
trained for. It's trying to complete the sequence. So when you give it a question, 01:51:55.360 |
it would potentially just give you more questions; it would follow with more questions. 01:52:00.240 |
It will do whatever it looks like some similar document would do in the training data on the 01:52:05.840 |
internet. And so who knows, you're getting kind of like undefined behavior. It might basically 01:52:10.240 |
answer your question with other questions. It might ignore your question. It might just try 01:52:15.120 |
to complete some news article. It's totally unaligned, as we say. So the second fine-tuning stage 01:52:21.360 |
is to actually align it to be an assistant. And this is the second stage. And so this ChatGPT 01:52:28.240 |
blog post from OpenAI talks a little bit about how this stage is achieved. Basically, 01:52:35.120 |
there are roughly three steps to this stage. So what they do here is they start to collect 01:52:39.920 |
training data that looks specifically like what an assistant would do. So there are documents 01:52:44.800 |
that have the format where the question is on top and then an answer is below. And they have 01:52:49.200 |
a large number of these, but probably not on the order of the internet. This is probably on the 01:52:53.520 |
order of maybe thousands of examples. And so they then fine tune the model to basically only focus 01:53:02.480 |
on documents that look like that. And so you're starting to slowly align it. So it's going to 01:53:06.880 |
expect a question at the top and it's going to expect to complete the answer. And these very, 01:53:11.840 |
very large models are very sample efficient during their fine tuning. So this actually somehow works. 01:53:16.960 |
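A hedged sketch of what that first fine-tuning step amounts to: question/answer documents are formatted into plain text and then trained on with the same next-token prediction loss as pre-training. The template and example data below are hypothetical, purely to illustrate the idea; OpenAI's actual formatting is not public.

```python
# Hypothetical illustration: supervised fine-tuning reuses the same language-modeling
# objective, only the data changes to question-on-top / answer-below documents.
def format_example(question: str, answer: str) -> str:
    # The exact template is an assumption, not OpenAI's real format.
    return f"Question: {question}\nAnswer: {answer}\n"

examples = [
    ("What is 2 + 2?", "4."),
    ("Name a play by Shakespeare.", "Hamlet."),
]
documents = [format_example(q, a) for q, a in examples]
# These documents would then be tokenized and fed through the usual
# (inputs, targets shifted by one) cross-entropy loss, exactly as in pre-training.
```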
But that's just step one. That's just fine tuning. So then they actually have more steps where, 01:53:21.760 |
okay, the second step is you let the model respond and then different raters look at the 01:53:26.800 |
different responses and rank them for their preference as to which one is better than the 01:53:31.040 |
other. They use that to train a reward model, so they can basically predict, using a different 01:53:36.160 |
network, how desirable any candidate response would be. And then once they have a reward 01:53:43.920 |
model, they run PPO, which is a policy gradient reinforcement learning optimizer, to 01:53:51.360 |
fine-tune this sampling policy so that the answers that ChatGPT now generates are expected to score a 01:54:00.000 |
high reward according to the reward model. And so basically there's a whole aligning stage here, 01:54:06.240 |
or fine tuning stage. It's got multiple steps in between there as well. And it takes the model from 01:54:11.840 |
being a document completer to a question answerer. And that's like a whole separate stage. A lot of 01:54:18.560 |
this data is not available publicly. It is internal to OpenAI and it's much harder to replicate this 01:54:24.640 |
stage. And so that's roughly what would give you a ChatGPT. And nanoGPT focuses on the pre-training 01:54:31.680 |
stage. Okay. And that's everything that I wanted to cover today. So, to summarize, we trained a 01:54:38.000 |
decoder-only transformer, following this famous paper, Attention Is All You Need, from 2017. 01:54:43.840 |
And so that's basically a GPT. We trained it on the Tiny Shakespeare dataset and got sensible results. 01:54:52.640 |
All of the training code is roughly 200 lines of code. I will be releasing this code base. 01:55:00.320 |
It also comes with all the Git log commits along the way as we built it up. 01:55:05.200 |
In addition to this code, I'm of course going to release the notebook, the Google Colab. 01:55:11.520 |
And I hope that gave you a sense for how you can train these models, like say GPT-3, that will be 01:55:19.120 |
architecturally basically identical to what we have, but they are somewhere between 10,000 and 01:55:23.280 |
1 million times bigger, depending on how you count. And so that's all I have for now. We did 01:55:30.400 |
not talk about any of the fine tuning stages that would typically go on top of this. So if you're 01:55:34.880 |
interested in something that's not just language modeling, but you actually want to, you know, 01:55:38.320 |
say perform tasks or you want them to be aligned in a specific way, or you want to detect sentiment 01:55:45.280 |
or anything like that, basically anytime you don't want something that's just a document completer, 01:55:49.440 |
you have to complete further stages of fine tuning, which we did not cover. And that could 01:55:54.560 |
be simple supervised fine-tuning, or it can be something more fancy, like we see in ChatGPT, 01:55:58.880 |
where we actually train a reward model and then do rounds of PPO to align it with respect to 01:56:03.840 |
the reward model. So there's a lot more that can be done on top of it. I think for now we're 01:56:08.240 |
starting to get to about the two-hour mark. So I'm going to kind of finish here. I hope you enjoyed 01:56:14.880 |
the lecture, and yeah, go forth and transform. See you later.