Let's reproduce GPT-2 (124M)
Chapters
0:00 intro: Let’s reproduce GPT-2 (124M)
3:39 exploring the GPT-2 (124M) OpenAI checkpoint
13:47 SECTION 1: implementing the GPT-2 nn.Module
28:08 loading the huggingface/GPT-2 parameters
31:00 implementing the forward pass to get logits
33:31 sampling init, prefix tokens, tokenization
37:02 sampling loop
41:47 sample, auto-detect the device
45:50 let’s train: data batches (B,T) → logits (B,T,C)
52:53 cross entropy loss
56:42 optimization loop: overfit a single batch
62:00 data loader lite
66:14 parameter sharing wte and lm_head
73:47 model initialization: std 0.02, residual init
82:18 SECTION 2: Let’s make it fast. GPUs, mixed precision, 1000ms
88:14 Tensor Cores, timing the code, TF32 precision, 333ms
99:38 float16, gradient scalers, bfloat16, 300ms
108:15 torch.compile, Python overhead, kernel fusion, 130ms
120:18 flash attention, 96ms
126:54 nice/ugly numbers. vocab size 50257 → 50304, 93ms
134:55 SECTION 3: hyperparameters, AdamW, gradient clipping
141:6 learning rate scheduler: warmup + cosine decay
146:21 batch size schedule, weight decay, FusedAdamW, 90ms
154:9 gradient accumulation
166:52 distributed data parallel (DDP)
190:21 datasets used in GPT-2, GPT-3, FineWeb (EDU)
203:10 validation data split, validation loss, sampling revive
208:23 evaluation: HellaSwag, starting the run
223:05 SECTION 4: results in the morning! GPT-2, GPT-3 repro
236:21 shoutout to llm.c, equivalent but faster code in raw C/CUDA
239:39 summary, phew, build-nanogpt github repo
00:00:00.000 |
Hi, everyone. So, today we are going to be continuing our Zero to Hero series, 00:00:04.320 |
and in particular, today we are going to reproduce the GPT-2 model, 00:00:07.840 |
the 124 million version of it. So, when OpenAI released GPT-2, this was 2019, 00:00:15.680 |
and they released it with this blog post. On top of that, they released this paper, 00:00:20.480 |
and on top of that, they released this code on GitHub. So, OpenAI/GPT-2. 00:00:24.560 |
Now, when we talk about reproducing GPT-2, we have to be careful, because in particular, 00:00:29.600 |
in this video, we're going to be reproducing the 124 million parameter model. So, the thing to 00:00:34.800 |
realize is that there's always a miniseries when these releases are made. So, there are the GPT-2 00:00:41.040 |
miniseries made up of models at different sizes, and usually the biggest model is called the GPT-2. 00:00:46.800 |
But basically, the reason we do that is because you can put the model sizes on the x-axis of 00:00:52.160 |
plots like this, and on the y-axis, you put a lot of downstream metrics that you're interested in, 00:00:56.720 |
like translation, summarization, question answering, and so on, and you can chart out 00:01:00.560 |
these scaling laws. So, basically, as the model size increases, you're getting better and better 00:01:05.680 |
at downstream metrics. And so, in particular for GPT-2, if we scroll down in the paper, 00:01:12.000 |
there are four models in the GPT-2 miniseries, starting at 124 million, all the way up to 00:01:18.320 |
1,558 million. Now, the reason my numbers, the way I say them, disagree with this table is that 00:01:24.560 |
this table is wrong. If you actually go to the GPT-2 GitHub repo, they sort of say that there 00:01:32.160 |
was an error in how they added up the parameters. But basically, this is the 124 million parameter 00:01:36.400 |
model, et cetera. So, the 124 million parameter model had 12 layers in the transformer, and it had 768 00:01:43.680 |
channels in the transformer, 768 dimensions. And I'm going to be assuming some familiarity with 00:01:49.360 |
what these terms mean, because I covered all of this in my previous video, 00:01:53.520 |
Let's build GPT from scratch. So, I covered that in the previous video in this playlist. 00:01:58.400 |
Now, if we do everything correctly and everything works out well, by the end of this video, 00:02:03.360 |
we're going to see something like this, where we're looking at the validation loss, which basically 00:02:08.160 |
measures how good we are at predicting the next token in a sequence on some validation data that 00:02:14.720 |
the model has not seen during training. And we see that we go from doing that task not very well, 00:02:20.320 |
because we're initializing from scratch, all the way to doing that task quite well 00:02:23.760 |
by the end of the training. And hopefully, we're going to beat the GPT-2 124M model. 00:02:29.920 |
Now, previously, when they were working on this, this is already five years ago. 00:02:34.560 |
So, this was probably a fairly complicated optimization at the time, and the GPUs and 00:02:38.560 |
the compute was a lot smaller. Today, you can reproduce this model in roughly an hour, 00:02:44.160 |
or probably less even, and it will cost you about 10 bucks if you want to do this on the cloud 00:02:48.320 |
compute, a sort of computer that you can all rent. And if you pay $10 for that computer, 00:02:54.400 |
you wait about an hour or less, you can actually achieve a model that is as good as 00:02:58.960 |
this model that OpenAI released. And one more thing to mention is, unlike many other models, 00:03:05.200 |
OpenAI did release the weights for GPT-2. So, those weights are all available in this repository. 00:03:11.040 |
But the GPT-2 paper is not always as good with all of the details of the training. 00:03:16.160 |
So, in addition to the GPT-2 paper, we're going to be referencing the GPT-3 paper, 00:03:20.320 |
which is a lot more concrete in a lot of the parameters and optimization settings and so on. 00:03:25.600 |
And it's not a huge departure in the architecture from the GPT-2 version of the model. So, 00:03:31.920 |
we're going to be referencing both GPT-2 and GPT-3 as we try to reproduce GPT-2 124M. So, 00:03:38.080 |
let's go. So, the first thing I would like to do is actually start at the end, or at the target. 00:03:42.960 |
So, in other words, let's load the GPT-2 124M model as it was released by OpenAI, 00:03:48.480 |
and maybe take it for a spin. Let's sample some tokens from it. Now, the issue with that is, 00:03:52.640 |
when you go to the code base of GPT-2 and you go into the source and you click in on the model.py, 00:03:58.160 |
you'll realize that actually this is using TensorFlow. So, the original GPT-2 code here 00:04:02.880 |
was written in TensorFlow, which is, you know, not, let's just say, not used as much anymore. 00:04:10.560 |
So, we'd like to use PyTorch, because it's a lot friendlier, easier, and I just personally like it 00:04:15.680 |
a lot more. The problem with that is the initial code is in TensorFlow. We'd like to use PyTorch. 00:04:20.160 |
So, instead, to get the target, we're going to use the hugging face transformers code, 00:04:26.320 |
which I like a lot more. So, when you go into the transformers, source, transformers, models, 00:04:30.400 |
gpt2, modeling_gpt2.py, you will see that they have the GPT-2 implementation of that transformer 00:04:36.640 |
here in this file. And it's, like, medium readable, but not fully readable. But what it does is it did 00:04:46.800 |
all the work of converting all those weights from TensorFlow to PyTorch friendly, and so it's much 00:04:52.480 |
easier to load and work with. So, in particular, we can look at the GPT-2 model here, and we can 00:04:59.120 |
load it using hugging face transformers. So, swinging over, this is what that looks like. 00:05:03.440 |
From transformers, we import GPT2LMHeadModel, and then call from_pretrained("gpt2"). 00:05:11.840 |
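For reference, a minimal sketch of what this cell looks like, assuming the standard transformers API; the name sd_hf is just a label chosen here for the Hugging Face state dict:

```python
# Sketch: load the OpenAI GPT-2 (124M) checkpoint via Hugging Face and inspect its raw tensors.
from transformers import GPT2LMHeadModel

model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # "gpt2" is the 124M model
sd_hf = model_hf.state_dict()                       # just the raw tensors, keyed by name

for k, v in sd_hf.items():
    print(k, v.shape)
```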
Now, one awkward thing about this is that when you do GPT-2 as the model that we're loading, 00:05:17.920 |
this actually is the 124 million parameter model. If you want the actual GPT-2, the 1.5 billion, 00:05:25.440 |
then you actually want to do gpt2-xl. So, this is the 124M, our target. Now, what we're doing is, 00:05:32.480 |
when we actually get this, we're initializing the PyTorch NN module as defined here in this class. 00:05:38.160 |
From it, I want to get just the state dict, which is just the raw tensors. So, we just have 00:05:44.800 |
the tensors of that file. And by the way, here, this is a Jupyter notebook, but this is a Jupyter 00:05:51.120 |
notebook running inside VS Code, so I like to work with it all in a single interface, so I like to 00:05:57.200 |
use VS Code, so this is the Jupyter notebook extension inside VS Code. So, when we get the 00:06:05.360 |
state dict, this is just a dict, so we can print the key and the value, which is the tensor, 00:06:11.280 |
and let's just look at the shapes. So, these are sort of the different parameters inside the GPT-2 00:06:18.080 |
model and their shape. So, the wte weight for token embedding is of size 50257 by 768. Where this is 00:06:30.400 |
coming from is that we have 50257 tokens in the GPT-2 vocabulary, and the tokens, by the way, 00:06:39.120 |
these are exactly the tokens that we've spoken about in the previous video on my tokenization 00:06:43.520 |
series. So, the previous video, just before this, I go into a ton of detail on tokenization. 00:06:48.400 |
GPT-2 tokenizer happens to have this many tokens. For each token, we have a 768-dimensional 00:06:56.240 |
embedding that is the distributed representation that stands in for that token. So, each token is 00:07:02.880 |
a little string piece, and then these 768 numbers are the vector that represents that token. 00:07:09.760 |
And so, this is just our lookup table for tokens, and then here, we have the lookup table for the 00:07:14.560 |
positions. So, because GPT-2 has a maximum sequence length of 1024, we have up to 1024 positions that 00:07:23.360 |
each token can be attending to in the past, and every one of those positions in GPT-2 has a fixed 00:07:29.920 |
vector of 768 that is learned by optimization. And so, this is the position embedding and the 00:07:37.680 |
token embedding, and then everything here is just the other weights and biases and everything else 00:07:43.760 |
of this transformer. So, when you just take, for example, the positional embeddings and flatten it 00:07:49.760 |
out and take just the 20 elements, you can see that these are just the parameters. These are 00:07:53.760 |
weights, floats, just we can take and we can plot them. So, these are the position embeddings, 00:08:00.640 |
and we get something like this, and you can see that this has structure, and it has structure 00:08:05.280 |
because what we have here really is every row in this visualization is a different position, 00:08:12.560 |
a fixed absolute position in the range from 0 to 1024, and each row here is the representation 00:08:20.640 |
of that position. And so, it has structure because these positional embeddings end up learning these 00:08:26.800 |
sinusoids and cosines that sort of like represent each of these positions, and each row here stands 00:08:34.960 |
in for that position and is processed by the transformer to recover all the relative positions 00:08:40.160 |
and sort of realize which token is where and attend to them depending on their position, 00:08:45.760 |
not just their content. So, when we actually just look into an individual column inside these, 00:08:53.200 |
and I just grabbed three random columns, you'll see that, for example, here we are focusing on 00:08:59.200 |
every single channel, and we're looking at what that channel is doing as a function of position 00:09:09.040 |
from 0 to 1023. And we can see that some of these channels basically like respond 00:09:17.200 |
more or less to different parts of the position spectrum. So, this green channel really likes to 00:09:23.200 |
fire for everything after 200 up to 800, but a lot less outside that range, and has a sharp drop-off 00:09:31.440 |
here near 0. So, who knows what these embeddings are doing and why they are the way they are. 00:09:36.160 |
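A sketch of the kind of inspection being described, assuming sd_hf from the snippet above and matplotlib; the specific column indices are arbitrary picks:

```python
import matplotlib.pyplot as plt

wpe = sd_hf["transformer.wpe.weight"]   # (1024, 768) learned position embedding table

# the whole table as an image: each row is one absolute position
plt.imshow(wpe, cmap="gray")

# a few individual channels as a function of position 0..1023
plt.figure()
for col in (150, 200, 250):   # arbitrary channels
    plt.plot(wpe[:, col])
plt.show()
```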
You can tell, for example, that because they're a bit more jagged and they're kind of noisy, 00:09:39.600 |
you can tell that this model was not fully trained. And the more trained this model was, 00:09:44.640 |
the more you would expect to smooth this out. And so, this is telling you that this is a little bit 00:09:48.560 |
of an under-trained model, but in principle, actually, these curves don't even have to be 00:09:54.160 |
smooth. This should just be totally random noise. And in fact, in the beginning of the optimization, 00:09:58.720 |
it is complete random noise, because this position embedding table is initialized completely at 00:10:03.600 |
random. So, in the beginning, you have jaggedness, and the fact that you end up with something 00:10:08.160 |
smooth is already kind of impressive, that that just falls out of the optimization, 00:10:13.280 |
because in principle, you shouldn't even be able to get any single graph out of this that makes 00:10:17.280 |
sense. But we actually get something that looks a little bit noisy, but for the most part looks 00:10:21.280 |
sinusoidal-like. In the original transformer paper, the attention is all you need paper, 00:10:29.120 |
the positional embeddings are actually initialized and fixed, if I remember correctly, to sinusoids 00:10:34.240 |
and cosines of different frequencies. And that's the positional encoding, and it's fixed. But in 00:10:39.520 |
GPT-2, these are just parameters, and they're trained from scratch, just like any other 00:10:43.440 |
parameter. And that seems to work about as well. And so what they do is they kind of recover these 00:10:48.720 |
sinusoidal-like features during the optimization. We can also look at any of the other matrices 00:10:55.440 |
here. So, here I took the first layer of the transformer, and looking at one of its weights, 00:11:03.280 |
and just the first block of 300 by 300, and you see some structure, but again, who knows what 00:11:10.560 |
any of this is. If you're into mechanistic interpretability, you might get a real kick out 00:11:15.040 |
of trying to figure out what is going on, what is this structure, and what does this all mean, 00:11:19.440 |
but we're not going to be doing that in this video. But we definitely see that there's some 00:11:22.880 |
interesting structure, and that's kind of cool. What we're most interested in is we've loaded 00:11:27.360 |
the weights of this model that was released by OpenAI, and now using the Hugging Face Transformers, 00:11:32.800 |
we can not just get all the raw weights, but we can also get what they call pipeline, 00:11:39.520 |
and sample from it. So, this is the prefix, "Hello, I'm a language model," 00:11:44.480 |
and then we're sampling 30 tokens, and we're getting five sequences, and I ran this, 00:11:51.760 |
and this is what it produced. "Hello, I'm a language model," but what I'm really doing 00:11:57.360 |
is making a human-readable document. There are other languages, but those are dot, dot, dot, 00:12:02.400 |
so you can read through these if you like, but basically, these are five different completions 00:12:05.840 |
of the same prefix from this GPT-2 124M. Now, if I go here, I took this example from here, 00:12:14.560 |
and sadly, even though we are fixing the seed, we are getting different generations 00:12:20.000 |
from the snippet than what they got, so presumably the code changed, but what we see, though, 00:12:28.240 |
at this stage that's important is that we are getting coherent text, so we've loaded the model 00:12:33.200 |
successfully, we can look at all its parameters, and the keys tell us where in the model these 00:12:38.560 |
come from, and we want to actually write our own GPT2 class so that we have a full understanding 00:12:43.760 |
of what's happening there. We don't want to be working with something like the modeling GPT2.py, 00:12:49.200 |
because it's just too complicated. We want to write this from scratch ourselves, 00:12:52.400 |
so we're going to be implementing the GPT model here in parallel, and as our first task, 00:12:57.120 |
let's load the GPT-2 124M into the class that we are going to develop here from scratch. 00:13:03.680 |
That's going to give us confidence that we can load the OpenAI model, and therefore, 00:13:09.200 |
there's a setting of weights that exactly is the 124M model, but then, of course, 00:13:13.680 |
what we're going to do is we're going to initialize the model from scratch instead, 00:13:16.880 |
and try to train it ourselves on a bunch of documents that we're going to get, 00:13:21.440 |
and we're going to try to surpass that model, so we're going to get different weights, 00:13:25.520 |
and everything's going to look different, hopefully better even, but we're going to 00:13:30.560 |
have a lot of confidence that because we can load the OpenAI model, we are in the same model family 00:13:34.800 |
and model class, and we just have to rediscover a good setting of the weights, but from scratch. 00:13:39.280 |
So let's now write the GPT2 model, and let's load the weights, and make sure that we can 00:13:45.360 |
also generate text that looks coherent. Okay, so let's now swing over to the 00:13:49.440 |
attention is all you need paper that started everything, and let's scroll over to the model 00:13:53.280 |
architecture, the original transformer. Now, remember that GPT2 is slightly modified from 00:13:58.720 |
the original transformer. In particular, we do not have the encoder. GPT2 is a decoder-only 00:14:05.120 |
transformer, as we call it, so this entire encoder here is missing, and in addition to that, 00:14:10.080 |
this cross-attention here that was using that encoder is also missing, so we delete this entire 00:14:16.240 |
part. Everything else stays almost the same, but there are some differences that we're going to 00:14:21.680 |
sort of look at here. So there are two main differences. When we go to the GPT2 paper under 00:14:29.680 |
2.3, Model, we notice that first, there's a reshuffling of the layer norms, so they change 00:14:35.680 |
place, and second, an additional layer normalization was added after the final 00:14:42.320 |
self-attention block. So basically, all the layer norms here, instead of being after the MLP or 00:14:48.000 |
after the attention, they swing before it, and an additional layer norm gets added here right before 00:14:53.040 |
the final classifier. So now let's implement some of the first sort of skeleton NN modules 00:14:59.200 |
here in our GPT NN module, and in particular, we're going to try to match up this schema here 00:15:05.760 |
that is used by Hugging Face Transformers because that will make it much easier to load these weights 00:15:10.480 |
from this state dict. So we want something that reflects this schema here. So here's what I came 00:15:17.120 |
up with. Basically, we see that the main container here that has all the modules is called transformer, 00:15:25.920 |
so I'm reflecting that with an NN module dict, and this is basically a module that allows you 00:15:30.400 |
to index into the sub-modules using keys, just like a dictionary strings. Within it, we have the 00:15:38.480 |
weights of the token embeddings, WTE, and that's an NN embedding, and the weights of the position 00:15:44.320 |
embeddings, which is also just an NN embedding, and if you remember, NN embedding is really just 00:15:48.400 |
a fancy little wrapper module around just a single array of numbers, a single block of 00:15:57.440 |
numbers just like this. It's a single tensor, and NN embedding is a glorified wrapper around a tensor 00:16:04.640 |
that allows you to access its elements by indexing into the rows. Now, in addition to that, we see 00:16:10.960 |
here that we have a .h, and then this is indexed using numbers instead of indexed using strings, 00:16:17.920 |
so there's a .h, .0, 1, 2, etc., all the way up till .h.11, and that's because there are 12 layers 00:16:25.680 |
here in this transformer. So to reflect that, I'm creating also an h, I think that probably 00:16:31.440 |
stands for hidden, and instead of a module dict, this is a module list, so we can index it using 00:16:36.480 |
integers exactly as we see here, .0, .1, 2, etc., and the module list has N layer blocks, and the 00:16:45.680 |
blocks are yet to be defined in a module in a bit. In addition to that, following the GPT-2 paper, 00:16:51.840 |
we need an additional final layer norm that we're going to put in there, and then we have the final 00:16:58.080 |
classifier, the language model head, which projects from 768, the number of embedding 00:17:06.000 |
dimensions in this GPT, all the way to the vocab size, which is 50,257, and GPT-2 uses no bias for 00:17:13.200 |
this final sort of projection. So this is the skeleton, and you can see that it reflects this, 00:17:20.160 |
so the WTE is the token embeddings, here it's called output embedding, but it's really the 00:17:26.000 |
token embeddings. The WPE is the positional encodings, those two pieces of information, 00:17:31.440 |
as we saw previously, are going to add, and then go into the transformer. The .h is all the blocks 00:17:37.120 |
in gray, and the ln_f is this new layer norm that gets added here by the GPT-2 model, and lm_head is this 00:17:45.040 |
linear part here. So that's the skeleton of the GPT-2. We now have to implement the block. 00:17:52.480 |
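In code, that skeleton looks roughly like this; a sketch, with Block and the config fields it reads filled in in the pieces that follow:

```python
import torch.nn as nn

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # mirror the Hugging Face naming scheme so the state dict keys line up
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            h   = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),                    # final layer norm added in GPT-2
        ))
        # final classifier, projecting n_embd -> vocab_size, with no bias
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
```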
Okay, so let's now recurse to the block itself. So we want to define the block. 00:17:56.400 |
So I'll start putting them here. So the block, I like to write out like this. 00:18:03.680 |
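A sketch of that block: pre-norm, with clean residual additions; CausalSelfAttention and MLP are defined next.

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # layer norms sit *before* attention and the MLP, and the residual stream
        # is a clean addition with nothing applied inside the skip connection
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```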
These are some of the initializations, and then this is the actual forward pass of what this block 00:18:09.040 |
computes. And notice here that there's a change from the transformer again, that is mentioned 00:18:14.800 |
in the GPT-2 paper. So here, the layer normalizations are after the application of 00:18:20.480 |
attention, or feedforward. In addition to that note, that the normalizations are inside the 00:18:25.840 |
residual stream. You see how feedforward is applied, and this arrow goes through and through 00:18:31.120 |
the normalization. So that means that your residual pathway has normalizations inside them. 00:18:36.800 |
And this is not very good or desirable. You actually prefer to have a single clean residual 00:18:42.800 |
stream, all the way from supervision, all the way down to the inputs, the tokens. And this is very 00:18:48.400 |
desirable and nice, because the gradients that flow from the top, if you remember from your 00:18:54.240 |
micrograd, addition just distributes gradients during the backward stage to both of its branches 00:19:00.480 |
equally. So addition is a branch in the gradients. And so that means that the gradients from the top 00:19:07.360 |
flow straight to the inputs, the tokens, through the residual pathways unchanged. But then in 00:19:13.280 |
addition to that, the gradient also flows through the blocks, and the blocks, you know, contribute 00:19:17.520 |
their own contribution over time, and kick in and change the optimization over time. But basically, 00:19:22.240 |
clean residual pathway is desirable from an optimization perspective. And then this is the 00:19:28.880 |
pre-normalization version, where you see that Rx first goes through the layer normalization, 00:19:33.920 |
and then the attention, and then goes back out to go to the layer normalization number two, 00:19:39.600 |
and the multilayer perceptron, sometimes also referred to as a feedforward network, or an FFN. 00:19:46.160 |
And then that goes into the residual stream again. And the one more thing that is kind of 00:19:50.800 |
interesting to note is that recall that attention is a communication operation. It is where all the 00:19:56.000 |
tokens, and there's 1024 tokens lined up in a sequence, and this is where the tokens communicate. 00:20:01.760 |
This is where they exchange information. So attention is an aggregation function. It's a 00:20:08.080 |
pooling function. It's a weighted sum function. It is a reduce operation. Whereas MLP, this MLP 00:20:17.360 |
here happens at every single token individually. There's no information being collected or 00:20:21.440 |
exchanged between the tokens. So the attention is the reduce, and the MLP is the map. And what you 00:20:27.680 |
end up with is that the transformer just ends up just being a repeated application of map reduce, 00:20:32.320 |
if you want to think about it that way. So this is where they communicate, and this is where they 00:20:38.000 |
think individually about the information that they gathered. And every one of these blocks 00:20:42.560 |
iteratively refines the representation inside the residual stream. So this is our block, 00:20:49.520 |
slightly modified from this picture. Okay, so let's now move on to the MLP. So the MLP block, 00:20:56.880 |
I implemented it as follows. It is relatively straightforward. We basically have two linear 00:21:02.080 |
projections here that are sandwiched in between the Gelu non-linearity. So nn.GELU, approximate 00:21:09.520 |
is 'tanh'. Now when we swing over to the PyTorch documentation, this is nn.GELU, and it has this 00:21:16.880 |
format, and it has two versions, the original version of Gelu, which we'll step into in a bit, 00:21:22.320 |
and the approximate version of Gelu, which we can request using tanh. So as you can see, just as a 00:21:28.080 |
preview here, Gelu is basically like a ReLU, except there's no flat, exactly flat tail here at exactly 00:21:36.400 |
zero. But otherwise it looks very much like a slightly smoother ReLU. It comes from this paper 00:21:42.080 |
here, Gaussian Error Linear Units, and you can step through this paper, and there's some mathematical 00:21:48.080 |
kind of like reasoning that leads to an interpretation that leads to this specific 00:21:51.840 |
formulation. It has to do with stochastic regularizers and the expectation of a modification to 00:21:57.360 |
adaptive dropout, so you can read through all of that if you'd like here. And there's a little bit 00:22:02.320 |
of a history as to why there's an approximate version of Gelu, and that comes from this issue 00:22:07.520 |
here, as far as I can tell. And in this issue, Dan Hendrycks mentions that at the time when 00:22:14.320 |
they developed this non-linearity, the erf function, which you need to evaluate the exact Gelu, 00:22:20.800 |
was very slow in TensorFlow, so they ended up basically developing this approximation. 00:22:24.720 |
And this approximation then ended up being picked up by BERT and by GPT-2, etc. 00:22:29.200 |
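For reference, a sketch of the MLP block described above, with the tanh-approximate GELU; the 4x expansion factor matches the 768-to-3072 c_fc shapes in the GPT-2 state dict:

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd)  # expand 4x
        self.gelu   = nn.GELU(approximate='tanh')                   # GPT-2 uses the tanh approximation
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)  # project back down

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))
```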
But today there's no real good reason to use the approximate version. You'd prefer to just use the 00:22:33.680 |
exact version, because my expectation is that there's no big difference anymore, and this is 00:22:39.760 |
kind of like a historical kind of quirk. But we are trying to reproduce GPT-2 exactly, and GPT-2 00:22:47.440 |
used the tanh approximate version, so we prefer to stick with that. Now one other reason to actually 00:22:55.200 |
just intuitively use Gelu instead of Relu is, previously in videos in the past, we've spoken 00:23:00.240 |
about the dead Relu neuron problem, where in this tail of a Relu, if it's exactly flat at zero, 00:23:07.360 |
any activations that fall there will get exactly zero gradient. There's no change, 00:23:11.440 |
there's no adaptation, there's no development of the network if any of these activations 00:23:16.160 |
end in this flat region. But the Gelu always contributes a local gradient, and so there's 00:23:21.680 |
always going to be a change, always going to be an adaptation, and sort of smoothing it out 00:23:25.840 |
ends up empirically working better in practice, as demonstrated in this paper, and also as 00:23:30.240 |
demonstrated by it being picked up by the BERT paper, GPT-2 paper, and so on. So for that reason 00:23:35.440 |
we adopt this non-linearity here in the GPT-2 reproduction. Now in more modern networks, 00:23:41.760 |
also like Llama 3 and so on, this non-linearity also further changes to SwiGLU and other variants like 00:23:48.160 |
that, but for GPT-2 they use this approximate Gelu. Okay, and finally we have the attention 00:23:54.320 |
operation. So let me paste in my attention. So I know this is a lot, so I'm gonna go through this 00:24:02.400 |
a bit quickly, a bit slowly, but not too slowly, because we have covered this in the previous video 00:24:07.200 |
and I would just point you there. So this is the attention operation. Now in the previous video 00:24:12.880 |
you will remember this is not just attention, this is multi-headed attention, right? And so in the 00:24:19.520 |
previous video we had this multi-headed attention module, and this implementation made it obvious 00:24:24.960 |
that these heads are not actually that complicated. There's basically, in parallel, inside every 00:24:30.480 |
attention block, there's multiple heads, and they're all functioning in parallel, and their 00:24:36.640 |
outputs are just being concatenated, and that becomes the output of the multi-headed attention. 00:24:41.520 |
So the heads are just kind of like parallel streams, and their outputs get concatenated. 00:24:46.480 |
And so it was very simple and made the head be kind of like fairly straightforward in terms of 00:24:52.640 |
its implementation. What happens here is that instead of having two separate modules, and 00:24:58.320 |
indeed many more modules that get concatenated, all of that is just put into a single self-attention 00:25:04.800 |
module. And instead I'm being very careful and doing a bunch of transpose-split tensor gymnastics 00:25:12.800 |
to make this very efficient in PyTorch, but fundamentally and algorithmically nothing is 00:25:16.720 |
different from the implementation we saw before in this Git repository. So to remind you very 00:25:25.520 |
briefly, and I don't want to go into this in too much time, but we have these tokens lined up in 00:25:32.080 |
a sequence, and there's 1024 of them. And then each token at this stage of the attention emits 00:25:38.320 |
three vectors, the query, key, and the value. And first what happens here is that the queries and 00:25:45.840 |
the keys have to multiply each other to get sort of the attention amount, like how interesting they 00:25:52.720 |
find each other. So they have to interact multiplicatively. So what we're doing here is 00:25:56.720 |
we're calculating the QKV while splitting it, and then there's a bunch of gymnastics as I mentioned 00:26:01.520 |
here. And the way this works is that we're basically making the number of heads, nh, 00:26:07.600 |
into a batch dimension. And so it's a batch dimension just like b, so that in these operations 00:26:13.360 |
that follow, PyTorch treats b and nh as batches, and it applies all the operations on all of them 00:26:20.480 |
in parallel, in both the batch and the heads. And the operations that get applied are, number one, 00:26:26.560 |
the queries and the keys interact to give us our attention. This is the autoregressive mask 00:26:31.840 |
that made sure that the tokens only attend to tokens before them and never to tokens in the 00:26:37.920 |
future. The softmax here normalizes the attention, so it sums to one always. And then recall from 00:26:46.160 |
the previous video that doing the attention matrix multiply with the values is basically 00:26:49.920 |
a way to do a weighted sum of the values of the tokens that we found interesting at every single 00:26:55.200 |
token. And then the final transpose contiguous and view is just reassembling all of that again, 00:27:01.520 |
and this actually performs a concatenation operation. So you can step through this 00:27:06.240 |
slowly if you'd like, but it is equivalent mathematically to our previous implementation, 00:27:12.240 |
it's just more efficient in PyTorch, so that's why I chose this implementation instead. 00:27:15.840 |
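A sketch of that multi-head attention, with the number of heads folded into the batch dimension as described; the bias buffer is the autoregressive mask, named to match the Hugging Face convention mentioned later:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one projection produces queries, keys and values for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # autoregressive mask (a buffer, not a parameter)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("bias", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dimension
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_size): heads become a batch-like dimension
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # queries and keys interact; mask out the future; normalize with softmax
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v                                          # weighted sum of the values
        y = y.transpose(1, 2).contiguous().view(B, T, C)     # re-assemble = concatenate heads
        return self.c_proj(y)
```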
Now in addition to that, I'm being careful with how I name my variables. So for example, 00:27:20.880 |
c_attn here is the same as c_attn there, and so actually our keys should basically exactly follow the 00:27:26.880 |
schema of the Hugging Face Transformers code, and that will make it very easy for us to now 00:27:31.040 |
port over all the weights from exactly this sort of naming conventions, because all of our variables 00:27:37.200 |
are named the same thing. But at this point we have finished the GPT-2 implementation, 00:27:42.800 |
and what that allows us to do is we don't have to basically use this file from HuggingFace, 00:27:47.920 |
which is fairly long. This is 2,000 lines of code, instead we just have less than 100 lines of code, 00:27:59.280 |
and this is the complete GPT-2 implementation. So at this stage we should just be able to 00:28:03.840 |
take over all the weights, set them, and then do generation. So let's see what that looks like. 00:28:08.960 |
Okay, so here I've also changed the GPT config so that the numbers here, 00:28:12.640 |
the hyperparameters, agree with the GPT-2 124M model. So the maximum sequence length, which I 00:28:17.920 |
call block size here, is 1024. The number of tokens is 50,257, which, if you watched my tokenizer video, you'll 00:28:26.400 |
know that this is 50,000 merges, BPE merges, 256 byte tokens, the leaves of the BPE tree, 00:28:34.880 |
and one special end-of-text token that delimits different documents and can start generation as 00:28:39.840 |
well. And there are 12 layers, there are 12 heads in the attention, and the dimension of the 00:28:45.360 |
transformer is 768. So here's how we can now load the parameters from HuggingFace to our code here, 00:28:53.200 |
and initialize the GPT class with those parameters. So let me just copy-paste a bunch of code here. 00:28:58.400 |
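The config itself is small; a sketch of it is below (the weight-copying code in from_pretrained is longer and not reproduced here):

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
```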
And I'm not going to go through this code too slowly, because honestly it's not that interesting, 00:29:07.920 |
it's not that exciting. We're just loading the weights, so it's kind of dry. But as I mentioned, 00:29:11.680 |
there are four models in this mini-series of GPT-2. This is some of the Jupyter code 00:29:16.400 |
that we had here on the right. I'm just porting it over. These are the hyperparameters of the 00:29:22.240 |
GPT-2 models. We're creating the config object and creating our own model. And then what's 00:29:27.600 |
happening here is we're creating the state dict, both for our model and for the HuggingFace model. 00:29:35.360 |
And then what we're doing here is we're going over to HuggingFace model keys, and we're copying over 00:29:41.600 |
those tensors. And in the process, we are kind of ignoring a few of the buffers. They're not 00:29:47.440 |
parameters, they're buffers. So for example, attention.bias, that's just used for the 00:29:51.600 |
autoregressive mask. And so we are ignoring some of those masks, and that's it. And then one 00:29:58.240 |
additional kind of annoyance is that this comes from the TensorFlow repo, and I'm not sure how... 00:30:03.280 |
This is a little bit annoying, but some of the weights are transposed from what PyTorch would 00:30:07.440 |
want. And so manually, I hardcoded the weights that should be transposed, and then we transpose 00:30:13.040 |
them if that is so. And then we return this model. So the from_pretrained is a constructor 00:30:19.920 |
or a class method in Python that returns the GPT object if we just give it the model type, 00:30:27.680 |
which in our case is GPT-2, the smallest model that we're interested in. So this is the code, 00:30:32.800 |
and this is how you would use it. And we can pop open the terminal here 00:30:37.200 |
in VS Code, and we can run python train_gpt2.py, and fingers crossed. 00:30:44.960 |
Okay, so we didn't crash. And so we can load the weights and the biases and everything else 00:30:53.680 |
into our NNModule. But now let's also get additional confidence that this is working, 00:30:58.080 |
and let's try to actually generate from this model. Okay, now before we can actually generate 00:31:02.160 |
from this model, we have to be able to forward it. We didn't actually write that code yet. 00:31:06.080 |
So here's the forward function. So the input to the forward is going to be our indices, 00:31:12.640 |
our token indices. And they are always of shape B by T. And so we have batch dimension of B, 00:31:20.720 |
and then we have the time dimension of up to T. And the T can't be more than the block size. 00:31:27.120 |
The block size is the maximum sequence length. So B by T indices are arranged in sort of like 00:31:32.720 |
a two-dimensional layout. And remember that basically every single row of this is of size 00:31:38.000 |
up to block size. And this is T tokens that are in a sequence. And then we have B independent 00:31:44.720 |
sequences stacked up in a batch so that this is efficient. Now here we are forwarding the position 00:31:51.440 |
embeddings and the token embeddings. And this code should be very recognizable from the previous 00:31:55.360 |
lecture. So we basically use torch.arange, which is kind of like a version of range, but for PyTorch. 00:32:02.080 |
And we're iterating from zero to T and creating this positions sort of indices. 00:32:08.880 |
And then we are making sure that they're on the same device as IDX, because we're not going to 00:32:14.800 |
be training on only CPU. That's going to be too inefficient. We want to be training on GPU, 00:32:18.640 |
and that's going to come in a bit. Then we have the position embeddings and the token embeddings, 00:32:23.680 |
and the addition operation of those two. Now notice that the position embeddings are going 00:32:28.240 |
to be identical for every single row of input. And so there's broadcasting hidden inside this plus, 00:32:36.000 |
where we have to create an additional dimension here, and then these two add up because the same 00:32:40.560 |
position embeddings apply at every single row of our examples stacked up in a batch. 00:32:44.560 |
Then we forward the transformer blocks, and finally the last layer norm and the lm_head. 00:32:51.120 |
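A sketch of that forward pass, returning just the logits for now; the loss is added later:

```python
def forward(self, idx):
    # idx: token indices of shape (B, T)
    B, T = idx.size()
    assert T <= self.config.block_size, "sequence length exceeds block size"
    # position and token embeddings; the addition broadcasts over the batch dimension
    pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
    tok_emb = self.transformer.wte(idx)   # (B, T, n_embd)
    pos_emb = self.transformer.wpe(pos)   # (T, n_embd)
    x = tok_emb + pos_emb
    # transformer blocks, then the final layer norm and classifier
    for block in self.transformer.h:
        x = block(x)
    x = self.transformer.ln_f(x)
    logits = self.lm_head(x)              # (B, T, vocab_size)
    return logits
```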
So what comes out after forward is the logits. And if the input was B by T indices, 00:32:57.120 |
then at every single B by T, we will calculate the logits for what token comes next in the sequence. 00:33:05.200 |
So what is the token B, T plus one, the one on the right of this token. And vocab size here 00:33:12.880 |
is the number of possible tokens. And so therefore this is the tensor that we're going to obtain. 00:33:18.800 |
And these logits are just a softmax away from becoming probabilities. So this is the forward 00:33:25.040 |
pass of the network, and now we can get logits. And so we're going to be able to generate from 00:33:29.360 |
the model imminently. Okay, so now we're going to try to set up the identical thing on the left here 00:33:35.440 |
that matches hug and face on the right. So here we've sampled from the pipeline, and we sampled 00:33:41.120 |
five times up to 30 tokens with the prefix of hello, I'm a language model. And these are the 00:33:46.560 |
completions that we achieved. So we're going to try to replicate that on the left here. 00:33:49.920 |
So number of turn sequences is five, max length is 30. So the first thing we do, of course, 00:33:54.320 |
is we initialize our model, then we put it into evaluation mode. Now, this is a good practice to 00:33:59.440 |
put the model into eval when you're not going to be training it, you're just going to be using it. 00:34:03.200 |
And I don't actually know if this is doing anything right now for the following reason. 00:34:08.720 |
Our model up above here contains no modules or layers that actually have a different behavior 00:34:14.880 |
at training or evaluation time. So for example, dropout, batch norm, and a bunch of other layers 00:34:19.600 |
have this kind of behavior. But all of these layers that we've used here should be identical 00:34:23.520 |
in both training and evaluation time. So potentially, model.eval does nothing. But 00:34:29.920 |
then I'm not actually sure if this is the case. And maybe PyTorch internals do some clever things, 00:34:34.960 |
depending on the evaluation mode inside here. The next thing we're doing here is we are moving 00:34:40.640 |
the entire model to CUDA. So we're moving all of the tensors to GPU. So I'm SSH'd here to a 00:34:47.520 |
cloud box, and I have a bunch of GPUs on this box. And here, I'm moving the entire model and all of 00:34:53.840 |
its members and all of its tensors and everything like that. Everything gets shipped off to basically 00:34:59.280 |
a whole separate computer that is sitting on the GPU. And the GPU is connected to the CPU, 00:35:05.040 |
and they can communicate, but it's basically a whole separate computer with its own computer 00:35:08.240 |
architecture. And it's really well catered to parallel processing tasks like those of 00:35:12.400 |
running neural networks. So I'm doing this so that the model lives on the GPU, a whole separate 00:35:18.080 |
computer. And it's just going to make our code a lot more efficient, because all of this stuff 00:35:22.720 |
runs a lot more efficiently on the GPUs. So that's the model itself. Now, the next thing we want to 00:35:31.760 |
do is we want to start with this as the prefix when we do the generation. So let's actually 00:35:37.920 |
create those prefix tokens. So here's the code that I've written. We're going to import the 00:35:42.960 |
tiktoken library from OpenAI, and we're going to get the GPT-2 encoding. So that's the tokenizer 00:35:49.120 |
for GPT-2. And then we're going to encode this string and get a list of integers, which are the 00:35:55.680 |
tokens. Now, these integers here should actually be fairly straightforward, because we can just 00:36:01.440 |
copy/paste this string, and we can sort of inspect what it is in tiktokenizer. So just pasting that 00:36:07.520 |
in, these are the tokens that are going to come out. So this list of integers is what we expect 00:36:12.800 |
tokens to become. And as you recall, if you saw my video, of course, all the tokens, they're just 00:36:18.400 |
little string chunks, right? So this is the tokenization of this string into GPT-2 tokens. 00:36:24.560 |
So once we have those tokens, it's a list of integers, we can create a torch tensor out of it. 00:36:31.280 |
In this case, it's eight tokens. And then we're going to replicate these eight tokens for five 00:36:35.920 |
times to get five rows of eight tokens. And that is our initial input X, as I call it here. And 00:36:45.120 |
it lives on the GPU as well. So X now is this IDX that we can pass into forward to get our logits so 00:36:54.080 |
that we know what comes as the sixth token, sorry, as the ninth token in every one of these five 00:37:01.760 |
rows. Okay, and we are now ready to generate. So let me paste in one more code block here. 00:37:06.880 |
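A sketch of the prefix setup and this sampling loop, assuming model is the GPT loaded above; top-k of 50 mirrors the Hugging Face pipeline default:

```python
import tiktoken
import torch
import torch.nn.functional as F

num_return_sequences = 5
max_length = 30

# prefix tokens: encode the prompt and replicate it into 5 rows
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("Hello, I'm a language model,")
tokens = torch.tensor(tokens, dtype=torch.long)            # (8,)
x = tokens.unsqueeze(0).repeat(num_return_sequences, 1)    # (5, 8)
x = x.to('cuda')

# sampling loop: keep appending one sampled column of tokens to x
torch.manual_seed(42)  # arbitrary seed
with torch.no_grad():
    while x.size(1) < max_length:
        logits = model(x)                  # (B, T, vocab_size)
        logits = logits[:, -1, :]          # only the last position's logits matter
        probs = F.softmax(logits, dim=-1)
        # top-k sampling: keep the 50 most likely tokens, renormalize, sample one
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1)        # (B, 1); multinomial renormalizes
        xcol = torch.gather(topk_indices, -1, ix)    # (B, 1) actual token ids
        x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
    print(">", enc.decode(x[i, :max_length].tolist()))
```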
So what's happening here in this code block is we have these X, which is of size B by T, right? So 00:37:14.560 |
batch by time. And we're going to be in every iteration of this loop, we're going to be adding 00:37:19.600 |
a column of new indices into each one of these rows, right? And so these are the new indices, 00:37:25.680 |
and we're appending them to the sequence as we're sampling. So with each loop iteration, 00:37:30.560 |
we get one more column into X. And all of the operations happen in the context manager of 00:37:35.920 |
torch.no_grad. This is just telling PyTorch that we're not going to be calling .backward on 00:37:40.080 |
any of this. So it doesn't have to cache all the intermediate tensors, it's not going to have to 00:37:44.480 |
prepare in any way for a potential backward later. And this saves a lot of space and also possibly 00:37:50.560 |
some time. So we get our logits, we get the logits at only the last location, we throw away all the 00:37:57.440 |
other logits, we don't need them, we only care about the last column's logits. So this is being 00:38:03.360 |
wasteful. But this is just kind of like an inefficient implementation of sampling. So it's 00:38:09.680 |
correct but inefficient. So we get the last column of logits, pass it through softmax to get our 00:38:15.040 |
probabilities. Then here I'm doing top k sampling of 50. And I'm doing that because this is the 00:38:19.760 |
HuggingFace default. So just looking at the HuggingFace docs here of a pipeline, there's a 00:38:26.880 |
bunch of quirks that go into HuggingFace. And I mean, it's kind of a lot, honestly. But I guess 00:38:35.600 |
the important one that I noticed is that they're using top k by default, which is 50. And what 00:38:40.240 |
that does is that, so that's being used here as well. And what that does is basically we want to 00:38:46.000 |
take our probabilities, and we only want to keep the top 50 probabilities. And anything that is 00:38:51.280 |
lower than the 50th probability, we just clamp to zero and renormalize. And so that way, we are 00:38:57.360 |
never sampling very rare tokens. The tokens we're going to be sampling are always in the top 50 of 00:39:03.280 |
most likely tokens. And this helps keep the model kind of on track, and it doesn't blabber on, and 00:39:08.080 |
it doesn't get lost, and doesn't go off the rails as easily. And it kind of like sticks in the 00:39:13.520 |
vicinity of likely tokens a lot better. So this is the way to do it in PyTorch. And you can step 00:39:18.480 |
through it if you like. I don't think it's super insightful, so I'll speed through it. But roughly 00:39:22.480 |
speaking, we get this new column of tokens. We append them on x. And basically the columns of x 00:39:29.280 |
grow until this while loop gets tripped up. And then finally, we have an entire x of size 00:39:35.760 |
5 by 30, in this case, in this example. And we can just basically print all those individual 00:39:44.240 |
rows. So I'm getting all the rows, I'm getting all the tokens that were sampled, and I'm using 00:39:49.760 |
the decode function from tiktoken to get back the string, which we can print. And so, in the terminal, 00:40:01.600 |
Okay. So these are the generations that we're getting. Hello, I'm a language model. Not a 00:40:14.160 |
program. New line, new line, et cetera. Hello, I'm a language model, and one of the main things 00:40:21.120 |
that bothers me when they create languages is how easy it becomes to create something that -- I mean, 00:40:25.760 |
so this will just like blabber on, right, in all these cases. Now, one thing you will notice is 00:40:29.680 |
that these generations are not the generations of HuggingFace here. And I can't find the discrepancy, 00:40:36.640 |
to be honest, and I didn't fully go through all these options, but probably there's something 00:40:40.480 |
else hiding in addition to the top P. So I'm not able to match it up. But just for correctness, 00:40:45.600 |
down here below in the Jupyter Notebook, I'm using the HuggingFace model. 00:40:51.120 |
So this is the HuggingFace model here. I replicated the code, and if I do this and I run that, 00:40:58.160 |
then I am getting the same results. So basically, the model internals are not wrong. It's just I'm 00:41:05.520 |
not 100% sure what the pipeline does in HuggingFace, and that's why we're not able to match 00:41:10.320 |
them up. But otherwise, the code is correct, and we've loaded all the tensors correctly. So we're 00:41:16.720 |
initializing the model correctly, and everything here works. So long story short, we've ported all 00:41:21.600 |
the weights. We initialized the GPT-2. This is the exact OpenAI GPT-2, and it can generate 00:41:27.360 |
sequences, and they look sensible. And now here, of course, we're initializing with GPT-2 model 00:41:33.040 |
weights. But now we want to initialize from scratch, from random numbers, and we want to 00:41:37.680 |
actually train the model that will give us sequences as good as or better than these ones 00:41:43.760 |
in quality. And so that's what we turn to next. So it turns out that using the random model is 00:41:49.040 |
actually fairly straightforward, because PyTorch already initializes our model randomly and by 00:41:53.520 |
default. So when we create the GPT model in the constructor, all of these layers and modules 00:42:02.880 |
have random initializers that are there by default. So when these linear layers get created and so on, 00:42:08.800 |
there's default constructors, for example, using the Xavier initialization that we saw in the past 00:42:13.520 |
to construct the weights of these layers. And so creating a random model instead of a GPT-2 model 00:42:19.920 |
is actually fairly straightforward. And we would just come here, and instead we would create model 00:42:25.120 |
equals GPT, and then we want to use the default config, GPT config. And the default config uses 00:42:32.080 |
the 124m parameters. So this is the random model initialization, and we can run it. 00:42:43.120 |
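Concretely, the change is just this (a sketch):

```python
# random initialization: PyTorch's default initializers fill in all the weights
model = GPT(GPTConfig())   # instead of GPT.from_pretrained("gpt2")
model.eval()
model.to('cuda')
```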
And we should be able to get results. Now, the results here, of course, are total garbage garble, 00:42:49.760 |
and that's because it's a random model. And so we're just getting all these random token string 00:42:53.680 |
pieces chunked up totally at random. So that's what we have right now. Now, one more thing I 00:42:59.360 |
wanted to point out, by the way, is in case you do not have CUDA available, because you don't have 00:43:03.360 |
a GPU, you can still follow along with what we're doing here to some extent. And probably not to the 00:43:10.320 |
very end, because by the end, we're going to be using multiple GPUs and actually doing a serious 00:43:14.400 |
training run. But for now, you can actually follow along decently okay. So one thing that I like to 00:43:19.440 |
do in PyTorch is I like to auto detect the device that is available to you. So in particular, you 00:43:24.640 |
could do that like this. So here we are trying to detect the device to run on that has the highest 00:43:31.680 |
compute capability. You can think about it that way. So by default, we start with CPU, which of 00:43:36.560 |
course is available everywhere, because every single computer will have a CPU. But then we can 00:43:41.280 |
try to detect, do you have a GPU? If so, you'd use CUDA. And then if you don't have CUDA, do you 00:43:47.040 |
at least have MPS? MPS is the backend for Apple Silicon. So if you have a MacBook that is fairly 00:43:52.720 |
new, you probably have Apple Silicon on the inside. And then that has a GPU that is actually 00:43:56.960 |
fairly capable, depending on which MacBook you have. And so you can use MPS, which will be 00:44:01.600 |
potentially faster than CPU. And so we can print the device here. Now once we have the device, 00:44:06.640 |
we can actually use it in place of CUDA. So we just swap it in. And notice that here when we 00:44:14.400 |
call model on x, if this x here is on CPU instead of GPU, then it will work fine because here in the 00:44:22.640 |
forward, which is where PyTorch will come, when we create pos, we are careful to use the device 00:44:28.880 |
of IDX to create this tensor as well. And so there won't be any mismatch where one tensor is on CPU, 00:44:34.880 |
one is on GPU, and that you can't combine those. But here we are carefully initializing on the 00:44:40.960 |
correct device as indicated by the input to this model. So this will auto detect device. For me, 00:44:48.080 |
this will be, of course, GPU. So using device CUDA. But you can also run with, as I mentioned, 00:44:58.800 |
another device. And it's not going to be too much slower. So if I override device here, 00:45:02.480 |
if I override device equals CPU, then we'll still print CUDA, of course, 00:45:11.120 |
but now we're actually using CPU. 1, 2, 3, 4, 5, 6. Okay, about six seconds. And actually, 00:45:21.280 |
we're not using Torch compile and stuff like that, which will speed up everything a lot 00:45:24.640 |
faster as well. But you can follow along even on a CPU, I think, to a decent extent. 00:45:29.280 |
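A sketch of that auto-detection, a minimal version of the snippet described above:

```python
import torch

# pick the most capable device available: CUDA GPU > Apple Silicon (MPS) > CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

# then use it everywhere in place of a hardcoded 'cuda'
model.to(device)
x = x.to(device)
```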
So that's a note on that. Okay, so I do want to loop around eventually into what it means to 00:45:35.440 |
have different devices in PyTorch and what it is exactly that PyTorch does in the background for 00:45:40.080 |
you when you do something like module.to device, or where you take a Torch tensor and do a .to 00:45:46.240 |
device, and what exactly happens and how that works. But for now, I'd like to get to training, 00:45:51.120 |
and I'd like to start training the model. And for now, let's just say the device makes code go fast. 00:45:56.080 |
And let's go into how we can actually train the model. So to train the model, we're going to need 00:46:01.360 |
some data set. And for me, the best debugging simplest data set that I like to use is the 00:46:06.000 |
tiny Shakespeare data set. And it's available at this URL, so you can wget it, or you can just 00:46:11.760 |
search tiny Shakespeare data set. And so in my file system, if I do ls, it's just input.txt. So I already 00:46:20.160 |
downloaded it. And here, I'm reading the data set, getting the first 1000 characters and printing the 00:46:25.600 |
first 100. Now remember that the GPT-2 tokenizer has a compression 00:46:32.800 |
ratio of roughly three to one. So 1000 characters is roughly 300 tokens here that will come out of 00:46:38.480 |
this in the slice that we're currently getting. So this is the first few characters. And if you 00:46:46.320 |
want to get a few more statistics on this, we can do word count on input.txt. So we can see that this 00:46:52.160 |
is 40,000 lines, about 200,000 words in this data set, and about 1 million bytes in this file. 00:47:00.320 |
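A sketch of that quick inspection, assuming input.txt has already been downloaded as described:

```python
# read tiny Shakespeare and peek at it
with open("input.txt", "r") as f:
    text = f.read()
data = text[:1000]    # ~1,000 characters is roughly 300 GPT-2 tokens at ~3:1 compression
print(data[:100])
print(len(text))      # roughly 1 million characters / bytes (plain ASCII)
```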
And knowing that this file is only ASCII characters, there's no crazy Unicode here, 00:47:03.840 |
as far as I know. And so every ASCII character is encoded with one byte. And so this is the same 00:47:09.360 |
number, roughly a million characters inside this data set. So that's the data set size by default, 00:47:16.080 |
very small and minimal data set for debugging. To get us off the ground, in order to tokenize 00:47:21.040 |
this data set, we're going to get tick token encoding for GPT-2, encode the data, the first 00:47:29.040 |
1000 characters, and then I'm only going to print the first 24 tokens. So these are the tokens as a 00:47:35.200 |
list of integers. And if you can read GPT-2 tokens, you will see that 198 here, you'll recognize that 00:47:41.360 |
as the \n character. So that is a newline. And then here, for example, we have two new lines, 00:47:46.080 |
so that's 198 twice here. So this is just a tokenization of the first 24 tokens. So what 00:47:52.560 |
we want to do now is we want to actually process these token sequences and feed them into a 00:47:57.760 |
transformer. And in particular, we want them, we want to rearrange these tokens into this IDX 00:48:04.960 |
variable that we're going to be feeding into the transformer. So we don't want a single very long 00:48:08.880 |
one-dimensional sequence. We want an entire batch where each sequence is up to, is basically T 00:48:16.080 |
tokens, and T cannot be larger than the maximum sequence length. And then we have these T long 00:48:23.680 |
sequences of tokens, and we have B independent examples of sequences. So how can we create a 00:48:29.600 |
B by T tensor that we can feed into the forward out of these one-dimensional sequences? 00:48:33.760 |
So here's my favorite way to achieve this. So if we take Torch, and then we create a tensor object 00:48:41.280 |
out of this list of integers and just the first 24 tokens, my favorite way to do this is basically 00:48:46.240 |
you do a dot view of, for example, 4 by 6, which multiplied to 24. And so it's just a two-dimensional 00:48:54.720 |
rearrangement of these tokens. And you'll notice that when you view this one-dimensional sequence 00:48:58.560 |
as two-dimensional 4 by 6 here, the first six tokens up to here end up being the first row. 00:49:07.280 |
The next six tokens here end up being the second row, and so on. And so basically, it's just going 00:49:13.200 |
to stack up every six tokens, in this case, as independent rows, and it creates a batch of tokens 00:49:21.920 |
in this case. And so for example, if we are token 25, in the transformer, when we feed this in and 00:49:28.160 |
this becomes the IDX, this token is going to see these three tokens and is going to try to predict 00:49:34.000 |
that 198 comes next. So in this way, we are able to create this two-dimensional batch. That's quite 00:49:41.440 |
nice. Now, in terms of the label that we're going to need for the target to calculate the loss 00:49:46.720 |
function, how do we get that? Well, we could write some code inside the forward pass because we know 00:49:51.760 |
that the next token in a sequence, which is the label, is just to the right of us. But you'll 00:49:56.960 |
notice that actually, for this token at the very end, 13, we don't actually have the next correct 00:50:02.960 |
token because we didn't load it. So we actually didn't get enough information here. So I'll show 00:50:09.680 |
you my favorite way of basically getting these batches. And I like to personally have not just 00:50:15.200 |
the input to the transformer, which I like to call X, but I also like to create the labels tensor, 00:50:21.760 |
which is of the exact same size as X but contains the targets at every single position. And so 00:50:27.360 |
here's the way that I like to do that. I like to make sure that I fetch +1 token because we need 00:50:32.720 |
the ground truth for the very last token, for 13. And then when we're creating the input, we take 00:50:39.920 |
everything up to the last token, not including, and view it as 4 by 6. And when we're creating 00:50:45.360 |
targets, we do the buffer, but starting at index 1, not index 0. So we're skipping the first element 00:50:52.720 |
and we view it in the exact same size. And then when I print this, 00:50:58.000 |
here's what happens, where we see that basically as an example for this token 25, 00:51:02.320 |
its target was 198. And that's now just stored at the exact same position in the target tensor, 00:51:08.320 |
which is 198. And also this last token 13 now has its label, which is 198. And that's just because 00:51:16.320 |
we loaded this +1 here. So basically, this is the way I like to do it. You take long sequences, 00:51:23.040 |
you view them in two-dimensional terms so that you get batches of time. And then we make sure 00:51:29.120 |
to load one additional token. So we basically load a buffer of tokens of B times T +1. And then 00:51:36.080 |
we sort of offset things and view them. And then we have two tensors. One of them is the input to 00:51:40.800 |
the transformer, and the other exactly is the labels. And so let's now reorganize this code 00:51:47.120 |
and create a very simple data loader object that tries to basically load these tokens and feed them 00:51:55.040 |
to the transformer and calculate the loss. Okay, so I reshuffled the code here accordingly. 00:51:59.600 |
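Roughly, in code, that input/target construction looks like this (a sketch; `tokens` here is assumed to be the list of GPT-2 token ids we got from tiktoken):

```python
import torch

B, T = 4, 6  # the tiny example from above: 4 rows of 6 tokens
buf = torch.tensor(tokens[:B*T + 1])  # fetch one extra token, for the very last target
x = buf[:-1].view(B, T)   # inputs to the transformer
y = buf[1:].view(B, T)    # targets: same shape, shifted one position to the right
```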
So as you can see here, I'm temporarily overriding to run on CPU. And importing tiktoken, 00:52:06.400 |
and all of this should look familiar. We're loading 1,000 characters. I'm setting B and T to just 00:52:10.720 |
be 4 and 32 right now, just because we're debugging. We just want to have a single batch that's very 00:52:15.680 |
small. And all of this should now look familiar and follows what we did on the right. And then 00:52:21.040 |
here we create the model and get the logits. And so here, as you see, I already ran this. It only 00:52:29.280 |
runs in a few seconds. But because we have a batch of 4 by 32, our logits are now of size 4 by 32 by 00:52:37.440 |
50,257. So those are the logits for what comes next at every position. And now we have the labels, 00:52:44.480 |
which are stored in Y. So now is the time to calculate the loss, and then do the backward 00:52:48.960 |
pass, and then do the optimization. So let's first calculate the loss. Okay, so to calculate the loss, 00:52:55.040 |
we're going to adjust the forward function of this NN module in the model. And in particular, 00:52:59.840 |
we're not just going to be returning logits, but also we're going to return the loss. 00:53:03.040 |
And we're going to not just pass in the input indices, but also the targets in Y. And now we 00:53:10.640 |
will print not logits.shape anymore, we're actually going to print the loss function, 00:53:15.600 |
and then sys.exit of zero so that we skip some of the sampling logic. 00:53:19.200 |
So now let's swing up to the forward function, which gets called there, because now we also 00:53:25.920 |
have these optional targets. And when we get the targets, we can also calculate the loss. And 00:53:32.800 |
remember that we want to basically return logits, loss, and loss by default is None. But 00:53:38.720 |
let's put this here. If targets is not none, then we want to calculate the loss. And Copilot is 00:53:50.320 |
already getting excited here and calculating the what looks to be correct loss. It is using the 00:53:55.600 |
cross-entropy loss as is documented here. So this is a function in PyTorch, under torch.nn.functional. 00:54:03.440 |
Now, what is actually happening here, because it looks a little bit scary, 00:54:07.280 |
basically F.cross_entropy does not like multi-dimensional inputs. It can't take a B by T 00:54:13.280 |
by vocab size. So what's happening here is that we are flattening out this three-dimensional tensor 00:54:18.720 |
into just two dimensions. The first dimension is going to be calculated automatically, and it's 00:54:23.120 |
going to be b times t. And then the last dimension is vocab size. So basically, this is flattening 00:54:29.760 |
out this three-dimensional tensor of logits to just be two-dimensional, b times t, all individual 00:54:35.600 |
examples, and vocab size in terms of the length of each row. And then it's also flattening out 00:54:42.640 |
the targets, which are also two-dimensional at this stage, but we're going to just flatten them 00:54:47.120 |
out so they're just a single tensor of b times t. And this can then pass into cross-entropy to 00:54:52.080 |
calculate a loss, which we return. So this should basically, at this point, run, because it's not 00:54:58.000 |
too complicated. So let's run it, and we should see it print the loss. 00:55:04.320 |
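For reference, the loss computation we just walked through looks roughly like this inside the forward pass (a sketch; `targets` is the optional second argument):

```python
import torch.nn.functional as F

# inside GPT.forward(self, idx, targets=None), after computing logits of shape (B, T, vocab_size):
loss = None
if targets is not None:
    # flatten (B, T, vocab_size) -> (B*T, vocab_size) and targets (B, T) -> (B*T,)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# the forward pass then returns logits, loss
```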
And here we see that we printed 11, roughly. And notice that this is the tensor of a single 00:55:19.280 |
element, which is this number 11. Now, we also want to be able to calculate a reasonable 00:55:23.840 |
kind of starting point for a randomly initialized network. So we covered this in previous videos, 00:55:29.040 |
but our vocabulary size is 50,257. At initialization of the network, you would hope 00:55:35.360 |
that every vocab element is getting roughly a uniform probability, so that we're not favoring, 00:55:42.960 |
at initialization, any token way too much. We're not confidently wrong at initialization. So we're 00:55:49.120 |
hoping that the probability of any arbitrary token is roughly 1/50,257. And now we can sanity 00:55:56.160 |
check the loss, because remember that the cross-entropy loss is just basically the negative 00:56:00.320 |
log likelihood. So if we now take this probability, and we take it through the natural logarithm, 00:56:07.120 |
and then we do the negative, that is the loss we expect at initialization, and we covered this in 00:56:13.120 |
previous videos. So I would expect something around 10.82, and we're seeing something around 00:56:18.080 |
11. So it's not way off. This is roughly the loss I expect at initialization. 00:56:22.640 |
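As a quick check of that arithmetic:

```python
import math

# a uniform distribution over the 50,257-token vocabulary gives an expected
# cross-entropy of -ln(1/50257) = ln(50257)
print(-math.log(1 / 50257))   # ~10.82
```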
So that tells me that at initialization, our probability distribution is roughly diffuse, 00:56:27.520 |
it's a good starting point, and we can now perform the optimization and tell the network which 00:56:33.120 |
tokens should follow in what order. So at this point, we can do loss.backward(), 00:56:39.120 |
calculate the gradients, and do an optimization. So let's get to that. 00:56:42.720 |
Okay, so let's do the optimization now. So here we have the loss, this is how we get the loss. 00:56:51.120 |
But now basically we want a little for loop here. So for i in range, let's do 50 steps or 00:56:56.560 |
something like that. Let's create an optimizer object in PyTorch. And so here we are using the 00:57:04.320 |
Adam optimizer, which is an alternative to the stochastic gradient descent optimizer, SGD, 00:57:10.320 |
that we were using. So SGD is a lot simpler, Adam is a bit more involved. And I actually 00:57:14.320 |
specifically like the AdamW variation, because in my opinion, it kind of just like fixes a bug. 00:57:19.920 |
So AdamW is a bug fix of Adam, is what I would say. When we go to the documentation for AdamW, 00:57:26.800 |
oh my gosh, we see that it takes a bunch of hyperparameters, and it's a little bit more 00:57:34.080 |
complicated than the SGD we were looking at before. Because in addition to basically updating 00:57:39.120 |
the parameters with the gradient scaled by the learning rate, it keeps these buffers around, 00:57:44.000 |
and it keeps two buffers, the M and the V, which it calls the first and the second moment. 00:57:49.280 |
So something that looks a bit like momentum and something that looks a bit like RMSProp, 00:57:53.120 |
if you're familiar with it. But you don't have to be, it's just kind of like a normalization 00:57:57.120 |
that happens on each gradient element individually, and speeds up the optimization, 00:58:01.600 |
especially for language models. But I'm not going to go into the detail right here. 00:58:05.360 |
We're going to treat this as a bit of a black box, and it just optimizes the objective faster than 00:58:12.000 |
SGD, which is what we've seen in the previous lectures. So let's use it as a black box in our 00:58:16.560 |
case. Create the optimizer object, and then go through the optimization. 00:58:23.680 |
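A minimal sketch of that loop (reusing the single x, y batch from before; the 3e-4 learning rate is the default mentioned a bit later):

```python
import torch

# AdamW: Adam with decoupled weight decay, used here as a black box that
# optimizes the objective faster than plain SGD
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(50):
    optimizer.zero_grad()        # gradients accumulate (+=), so reset them every step
    logits, loss = model(x, y)   # forward pass returns the logits and the loss
    loss.backward()              # deposits gradients into .grad
    optimizer.step()             # update the parameters to decrease the loss
    print(f"step {i}, loss: {loss.item()}")  # .item() ships the scalar back to the CPU
```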
The first thing to always make sure of is that Copilot did not forget to zero the gradients. 00:58:33.440 |
So always remember that you have to start with a zero gradient. Then when you get your loss, 00:58:38.880 |
and you do a dot backward, dot backward adds to gradients. So it deposits gradients. It always 00:58:44.800 |
does a plus equals on whatever the gradients are, which is why you must set them to zero. 00:58:48.640 |
So this accumulates the gradient from this loss, and then we call the step function on the optimizer 00:58:55.280 |
to update the parameters, and to decrease the loss. Then we print the step, and the loss 00:59:03.760 |
dot item is used here, because loss is a tensor with a single element. Dot item will actually 00:59:09.680 |
convert that to a single float, and this float will live on the CPU. So this gets to some of 00:59:16.880 |
the internals, again, of the devices, but loss is a tensor with a single element, and it lives on 00:59:22.800 |
GPU for me, because I'm using GPUs. When you call dot item, PyTorch behind the scenes will take that 00:59:29.120 |
one-dimensional tensor, ship it back to the CPU memory, and convert it into a float that we can 00:59:34.240 |
just print. So this is the optimization, and this should probably just work. Let's see what happens. 00:59:43.680 |
Actually, sorry. Instead of using CPU override, let me delete that so this is a bit faster for me, 00:59:51.600 |
and it runs on CUDA. Oh, expected all tensors to be on the same device, but found at least two 01:00:02.080 |
devices, CUDA0 and CPU. So CUDA0 is the 0th GPU, because I actually have eight GPUs on this box. 01:00:09.280 |
So the 0th GPU on my box, and CPU. And model, we have moved to device, but when I was writing this 01:00:17.440 |
code, I actually introduced a bug, because buf, we never moved to device. And you have to be 01:00:22.640 |
careful, because you can't just do buf.to(device). It's not stateful. It doesn't convert 01:00:30.240 |
the tensor in place on the device. It instead returns a new tensor in new memory, which is on the device. So you 01:00:36.960 |
see how we can just do model.to(device), but that does not apply to plain tensors. You have to do 01:00:41.360 |
buf = buf.to(device), and then this should work. Okay. So what do we expect to see? 01:00:51.920 |
We expect to see a reasonable loss in the beginning, and then we continue to optimize 01:00:55.600 |
just a single batch. And so we want to see that we can overfit this single batch. We can crush 01:01:00.720 |
this little batch, and we can perfectly predict the indices on just this little batch. And 01:01:05.520 |
indeed, that is roughly what we're seeing here. So we started off at roughly 10.82, 11 in this case, 01:01:13.440 |
and then as we continue optimizing on this single batch without loading new examples, 01:01:17.040 |
we are making sure that we can overfit a single batch, and we are getting to very, 01:01:20.800 |
very low loss. So the transformer is memorizing this single individual batch. 01:01:25.280 |
And one more thing I didn't mention is the learning rate here is 3e-4, 01:01:30.240 |
which is a pretty good default for most optimizations that you want to run at a very 01:01:35.360 |
early debugging stage. So this is our simple inner loop, and we are overfitting a single batch, 01:01:42.560 |
and this looks good. So now what comes next is we don't just want to overfit a single batch, 01:01:47.200 |
we actually want to do an optimization. So we actually need to iterate these XY batches 01:01:52.000 |
and create a little data loader that makes sure that we're always getting a fresh batch, 01:01:56.560 |
and that we're actually optimizing a reasonable objective. So let's do that next. 01:02:00.240 |
Okay, so this is what I came up with, and I wrote a little data loader lite. 01:02:03.360 |
So what this data loader does is we're importing tiktoken up here, 01:02:08.320 |
reading the entire text file from this single input.txt, tokenizing it, and then we're just 01:02:14.720 |
printing the number of tokens in total, and the number of batches in a single epoch of iterating 01:02:20.800 |
over this dataset. So how many unique batches do we output before we loop back around at the 01:02:25.840 |
beginning of the document and start reading it again? So we start off at position 0, and then 01:02:31.600 |
we simply walk the document in batches of B times T. So we take chunks of B times T, and then always 01:02:37.920 |
advance by B times T. And it's important to note that we're always advancing our position by exactly 01:02:44.800 |
B times T, but when we're fetching the tokens, we're actually fetching from current position 01:02:49.840 |
to B times T plus 1. And we need that plus 1 because remember, we need the target token for 01:02:56.960 |
the last token in the current batch. And so that way we can do the XY exactly as we did it before. 01:03:04.560 |
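A sketch of that loader, roughly along the lines of what's described here (including the wrap-around mentioned next; the class name DataLoaderLite is the name used in the video's code):

```python
import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T):
        self.B, self.T = B, T
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text))
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        # fetch B*T + 1 tokens so the last token in the batch also has a target
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        x = buf[:-1].view(B, T)   # inputs
        y = buf[1:].view(B, T)    # targets
        # advance by exactly B*T; loop back to the start if we'd run out of data
        self.current_position += B * T
        if self.current_position + B*T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y
```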
And if we are to run out of data, we'll just loop back around to 0. So this is one way to write a 01:03:12.640 |
very, very simple data loader that simply just goes through the file in chunks. And it's good 01:03:18.960 |
enough for us for current purposes. And we're going to complexify it later. And now we'd like 01:03:24.800 |
to come back around here, and we'd like to actually use our data loader. So the import 01:03:28.800 |
tiktoken has moved up. And actually, all of this is now useless. So instead, we just want a train 01:03:34.960 |
loader for the training data. And we want to use the same hyperparameters as before. So batch size was 01:03:41.600 |
4 and T was 32. And then here, we need to get the X, Y for the current batch. So let's see if 01:03:49.120 |
Copilot gets it, because this is simple enough. So we call the next batch. And then we make sure 01:03:56.080 |
that we have to move our tensors from CPU to the device. So here, when I converted the tokens, 01:04:05.040 |
notice that I didn't actually move these tokens to the GPU. I left them on the CPU, which is default. 01:04:11.600 |
And that's just because I'm trying not to waste too much memory on the GPU. In this case, 01:04:16.320 |
this is a tiny data set that it would fit. But it's fine to just ship it to GPU right now for 01:04:22.320 |
our purposes right now. So we get the next batch. We keep the data loader simple, on the CPU. And then 01:04:28.880 |
here, we actually ship it to the GPU and do all the computation. And let's see if this runs. 01:04:35.920 |
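Putting the pieces together, the loop now looks roughly like this (a sketch, assuming `device`, `model`, and `optimizer` are set up as before):

```python
train_loader = DataLoaderLite(B=4, T=32)

for i in range(50):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)  # the loader keeps tokens on the CPU; ship each batch to the GPU
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    print(f"step {i}, loss: {loss.item()}")
```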
So python train_gpt2.py. And what do we expect to see before this actually happens? 01:04:42.240 |
What we expect to see is now we're actually getting the next batch. So we expect to not 01:04:46.560 |
overfit a single batch. And so I expect our loss to come down, but not too much. And that's because 01:04:54.240 |
I still expect it to come down because in the 50,257 tokens, many of those tokens never occur 01:05:00.480 |
in our data set. So there are some very easy gains to be made here in the optimization by, 01:05:05.280 |
for example, taking the biases of all the logits that never occur and driving them to negative 01:05:09.920 |
infinity. And that would basically just, it's just that all of these crazy Unicodes or different 01:05:14.800 |
languages, those tokens never occur. So their probability should be very low. And so the gains 01:05:19.360 |
that we should be seeing are along the lines of basically deleting the usage of tokens that never 01:05:24.880 |
occur. That's probably most of the loss gain that we're going to see at this scale right now. 01:05:29.920 |
But we shouldn't expect to get to zero, because we are only doing 50 iterations. And I don't think that's 01:05:37.360 |
enough to do an epoch right now. So let's see what we got. We have 338,000 tokens, which makes sense 01:05:46.320 |
with our three to one compression ratio, because there are 1 million characters. So one epoch with 01:05:52.640 |
the current setting of B and T will take 2,600 batches. And we're only doing 50 batches of 01:05:58.880 |
optimization in here. So we start off in a familiar territory as expected, and then we seem 01:06:05.600 |
to come down to about 6.6. So basically, things seem to be working okay right now with respect 01:06:12.160 |
to our expectations. So that's good. Okay, next, I want to actually fix a bug that we have in our 01:06:17.520 |
code. It's not a major bug, but it is a bug with respect to how GPT-2 training should happen. 01:06:26.240 |
So the bug is the following. We were not being careful enough when we were loading the weights 01:06:30.480 |
from Hugging Face, and we actually missed a little detail. So if we come here, 01:06:34.400 |
notice that the shape of these two tensors is the same. So this one here is the token embedding at 01:06:43.520 |
the bottom of the transformer. And this one here is the language modeling head at the top of the 01:06:50.560 |
transformer. And both of these are basically two-dimensional tensors, and their shape is 01:06:56.080 |
identical. So here, the first one is the output embedding, the token embedding, and the second 01:07:02.240 |
one is this linear layer at the very top, the classifier layer. Both of them are of shape 01:07:07.520 |
50,257 by 768. This one here is giving us our token embeddings at the bottom, and this one 01:07:16.080 |
here is taking the 768 channels of the transformer and trying to upscale that to 50,257 to get the 01:07:23.680 |
logits for the next token. So they're both the same shape, but more than that, actually, if you look 01:07:29.520 |
at comparing their elements, in PyTorch, this is an element-wise equality. So then we use .all, 01:07:37.360 |
and we see that every single element is identical. And more than that, we see that if we actually 01:07:42.880 |
look at the data pointer, this is a way in PyTorch to get the actual pointer to the data 01:07:49.280 |
and the storage, we see that actually the pointer is identical. So not only are these two separate 01:07:54.480 |
tensors that happen to have the same shape and elements, they're actually pointing to the 01:07:58.480 |
identical tensor. So what's happening here is that this is a common weight tying scheme 01:08:04.480 |
that actually comes from the original "Attention is all you need" paper, 01:08:11.360 |
and actually even the reference before it. So if we come here... 01:08:15.040 |
Embeddings and Softmax. In the "Attention is all you need" paper, they mention that in our model, 01:08:25.920 |
we shared the same weight matrix between the two embedding layers and the pre-Softmax linear 01:08:30.720 |
transformation, similar to [30]. So this is an awkward way to phrase that these two are shared 01:08:37.520 |
and they're tied and they're the same matrix. And the 30 reference is this paper. So this came out 01:08:44.480 |
in 2017. And you can read the full paper, but basically it argues for this weight tying scheme. 01:08:50.400 |
And I think intuitively the idea for why you might want to do this 01:08:54.880 |
comes from this paragraph here. And basically, you can observe that 01:09:03.280 |
you actually want these two matrices to behave similar in the following sense. If two tokens 01:09:10.080 |
are very similar semantically, like maybe one of them is all lowercase and the other one is 01:09:14.560 |
all uppercase, or it's the same token in a different language or something like that, 01:09:18.240 |
if you have similarity between two tokens, presumably you would expect that they are 01:09:22.160 |
nearby in the token embedding space. But in the exact same way, you'd expect that if you 01:09:27.520 |
have two tokens that are similar semantically, you'd expect them to get the same probabilities 01:09:33.120 |
at the output of a transformer because they are semantically similar. 01:09:36.080 |
And so both positions in the transformer at the very bottom and at the top have this property 01:09:43.680 |
that similar tokens should have similar embeddings or similar weights. And so this is what motivates 01:09:50.400 |
their exploration here. And they kind of, you know, I don't want to go through the entire paper 01:09:54.400 |
and you can go through it, but this is what they observe. They also observe that if you look at the 01:10:00.240 |
output embeddings, they also behave like word embeddings. If you just kind of try to use those 01:10:06.880 |
weights as word embeddings. So they kind of observe this similarity, they try to tie them, 01:10:12.800 |
and they observe that they can get much better performance in that way. And so this was adopted 01:10:18.160 |
in the Attention Is All You Need paper, and then it was used again in GPT-2 as well. So I couldn't find 01:10:25.760 |
it in the transformers implementation. I'm not sure where they tie those embeddings, 01:10:30.080 |
but I can find it in the original GPT-2 code introduced by OpenAI. So this is OpenAI GPT-2 01:10:37.840 |
source model. And here where they are forwarding this model, and this is in TensorFlow, but 01:10:42.880 |
that's okay. We see that they get the WTE token embeddings. And then here is the encoder of the 01:10:50.480 |
token embeddings and the position. And then here at the bottom, they use the WTE again to do the 01:10:57.040 |
logits. So when they get the logits, it's a matmul of this output from the transformer and the WTE 01:11:03.920 |
tensor is reused. And so the WTE tensor basically is used twice on the bottom of the transformer 01:11:10.880 |
and on the top of the transformer. And in the backward pass, we'll get gradients contributions 01:11:16.640 |
from both branches, right? And these gradients will add up on the WTE tensor. So we'll get a 01:11:23.760 |
contribution from the classifier layer. And then at the very end of the transformer, we'll get a 01:11:27.760 |
contribution at the bottom of it, flowing again into the WTE tensor. So we are currently not 01:11:36.960 |
sharing WTE in our code, but we want to do that. So weight sharing scheme. And one way to do this, 01:11:48.720 |
let's see if Copilot gets it. Oh, it does. Okay. So this is one way to do it. Basically, 01:11:56.720 |
relatively straightforward. What we're doing here is we're taking the WTE.weight and we're simply 01:12:03.920 |
redirecting it to point to the LM head. So this basically copies the data pointer, 01:12:12.240 |
right? It copies the reference. And now the WTE.weight becomes orphaned, the old value of it, 01:12:18.800 |
and PyTorch will clean it up. Python will clean it up. And so we are only left with a single 01:12:25.440 |
tensor, and it's going to be used twice in the forward pass. And this is, to my knowledge, 01:12:32.720 |
all that's required. So we should be able to use this, and this should probably train. 01:12:37.440 |
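A sketch of how that looks at the end of the GPT module's __init__ (the attribute names transformer.wte and lm_head follow the module structure we've been using):

```python
# weight sharing scheme: the input token embedding and the output classifier
# use literally the same parameter tensor
self.transformer.wte.weight = self.lm_head.weight

# after this, both names point at one tensor, e.g. the data pointers match:
#   self.transformer.wte.weight.data_ptr() == self.lm_head.weight.data_ptr()  -> True
```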
We're just going to basically be using this exact same tensor twice. And we weren't being careful 01:12:45.920 |
with tracking the likelihoods, but according to the paper and according to the results, 01:12:50.240 |
you'd actually expect slightly better results doing this. And in addition to that, one other 01:12:54.480 |
reason that this is very, very nice for us is that this is a ton of parameters, right? 01:12:59.600 |
What is the size here? It's 768 times 50,257. So this is roughly 40 million parameters. 01:13:07.200 |
And this is a 124 million parameter model. So 40 divided by 124. So this is like 30% of the 01:13:14.000 |
parameters are being saved using this weight tying scheme. And so this might be one of the 01:13:19.440 |
reasons that this is working slightly better. If you're not training the model long enough, 01:13:23.120 |
because of the weight tying, you don't have to train as many parameters. And so you become more 01:13:27.760 |
efficient in terms of the training process, because you have fewer parameters and you're 01:13:33.200 |
putting in this inductive bias that these two embeddings should share similarities between 01:13:38.560 |
tokens. So this is the weight tying scheme, and we've saved a ton of parameters. And we expect 01:13:44.400 |
our model to work slightly better because of this scheme. Okay, next, I would like us to be a bit 01:13:48.480 |
more careful with the initialization and to try to follow the way GPT-2 initialized their model. 01:13:53.680 |
Now, unfortunately, the GPT-2 paper and the GPT-3 paper are not very explicit about 01:13:58.640 |
initialization. So we kind of have to read between the lines. And instead of going to the paper, 01:14:03.280 |
which is quite vague, there's a bit of information in the code that OpenAI released. So when we go 01:14:09.040 |
to the model.py, we see that when they initialize their weights, they are using the standard 01:14:14.560 |
deviation of 0.02. And that's how they, so this is a normal distribution for the weights, 01:14:21.600 |
and the standard deviation is 0.02. For the bias, they initialize that with zero. 01:14:26.560 |
And then when we scroll down here, why is this not scrolling? The token embeddings are 01:14:35.920 |
initialized at 0.02, and position embeddings at 0.01 for some reason. So those are the 01:14:42.320 |
initializations, and we'd like to mirror that in GPT-2 in our module here. So here's a snippet of 01:14:48.720 |
code that I sort of came up with very quickly. So what's happening here is at the end of our 01:14:56.400 |
initializer for the GPT module, we're calling the apply function of NNModule, and that iterates all 01:15:01.920 |
the sub-modules of this module, and applies init_weights function on them. And so what's 01:15:09.440 |
happening here is that we're iterating all the modules here, and if they are an NN.linear module, 01:15:16.560 |
then we're going to make sure to initialize the weight using a normal with a standard deviation 01:15:20.640 |
of 0.02. If there's a bias in this layer, we will make sure to initialize that to zero. Note that 01:15:27.200 |
zero initialization for the bias is not actually the PyTorch default. 01:15:30.160 |
By default, the bias here is initialized with a uniform, so that's interesting. So we make sure 01:15:37.520 |
to use zero. And for the embedding, we're just going to use 0.02 and keep it the same. So we're 01:15:44.080 |
not going to change it to 0.01 for positional, because it's about the same. And then if you look 01:15:48.960 |
through our model, the only other layer that requires initialization, and that has parameters, 01:15:53.520 |
is the layer norm. And the PyTorch default initialization sets the scale in the layer norm 01:15:58.240 |
to be one, and the offset in the layer norm to be zero. So that's exactly what we want, 01:16:02.560 |
and so we're just going to keep it that way. And so this is the default initialization 01:16:08.480 |
if we are following the, where is it, the GPT-2 source code that they released. 01:16:16.240 |
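A sketch of that snippet (assuming the GPT module calls self.apply(self._init_weights) at the end of its __init__, as described above):

```python
import torch
import torch.nn as nn

def _init_weights(self, module):
    # normal(mean=0, std=0.02) for Linear and Embedding weights, zeros for biases;
    # LayerNorm is left at the PyTorch default (weight=1, bias=0)
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
```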
I would like to point out, by the way, that typically the standard deviation here on this 01:16:21.760 |
initialization, if you follow the Xavier initialization, would be one over the square 01:16:25.440 |
root of the number of features that are incoming into this layer. But if you'll notice, actually, 01:16:31.120 |
0.02 is basically consistent with that, because the d model sizes inside these transformers for 01:16:36.800 |
GPT-2 are roughly 768, 1600, etc. So one over the square root of, for example, 768 gives us 0.03. 01:16:44.720 |
If we plug in 1600, we get 0.02. If we plug in three times that, 0.014, etc. So basically 0.02 01:16:55.520 |
is roughly in the vicinity of reasonable values for these initializations anyway. So it's not 01:17:04.640 |
completely crazy to be hard-coding 0.02 here, but you'd typically want something that scales with the 01:17:12.160 |
model size instead. But we will keep this because that is the GPT-2 initialization per their source 01:17:17.120 |
code. But we are not fully done yet on initialization, because there's one more caveat 01:17:21.120 |
here. So here, a modified initialization which accounts for the accumulation on the residual 01:17:27.520 |
path with model depth is used. We scale the weights of residual layers at initialization 01:17:32.160 |
by a factor of 1 over the square root of N, where N is the number of residual layers. 01:17:35.520 |
So this is what GPT-2 paper says. So we have not implemented that yet, and we can do so now. 01:17:41.760 |
Now, I'd like to actually kind of like motivate a little bit what they mean here, I think. So 01:17:47.440 |
here's roughly what they mean. If you start out with zeros in your residual stream, remember that 01:17:54.320 |
each residual stream is of this form, where we continue adding to it. x is x plus something, 01:18:01.520 |
some kind of contribution. So every single block of the residual network contributes some 01:18:07.040 |
amount, and it gets added. And so what ends up happening is that the variance of the activations 01:18:15.600 |
in the residual stream grows. So here's a small example. If we start at zero, we have 01:18:21.840 |
sort of this residual stream of 768 zeros. And then, 100 times, we add torch.randn, 01:18:31.040 |
which draws from a normal distribution with zero mean and unit standard deviation. If we keep adding, then by the 01:18:37.120 |
end, the residual stream has grown to have standard deviation of 10. And that's just because we're 01:18:43.600 |
always adding these numbers. And so this scaling factor that they use here exactly compensates for 01:18:52.080 |
that growth. So if we take n, and we basically scale down every one of these contributions into 01:18:59.600 |
the residual stream by one over the square root of n. So one over the square root of n is n to the 01:19:05.200 |
negative 0.5, right? Because n to the 0.5 is the square root, and then one over the square root is 01:19:12.800 |
n to the negative 0.5. If we scale each contribution in this way, then we see that the standard deviation actually comes out to one. So this is a way 01:19:23.120 |
to control the growth of activations inside the residual stream in the forward pass. 01:19:29.040 |
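Here is that little experiment as a standalone sketch:

```python
import torch

n = 100  # number of additions into the residual stream

x = torch.zeros(768)
for _ in range(n):
    x += torch.randn(768)             # unscaled: std grows like sqrt(n), ~10 here
print(x.std())

x = torch.zeros(768)
for _ in range(n):
    x += n**-0.5 * torch.randn(768)   # scaled by 1/sqrt(n): std stays around 1
print(x.std())
```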
And so we'd like to initialize in the same way, where these weights that are at the end of each block, so this 01:19:34.960 |
c_proj layer, the GPT-2 paper proposes to scale down those weights by one over the square root of the 01:19:42.240 |
number of residual layers. So one crude way to implement this is the following. I don't know if 01:19:48.240 |
this is PyTorch-sanctioned, but it works for me: in the initialization, we set a 01:19:56.320 |
special NANOGPT_SCALE_INIT attribute to one. So we're setting kind of like a flag on this module. 01:20:06.880 |
There must be a better way in PyTorch, right? But I don't know. Okay, so we're basically 01:20:13.200 |
attaching this flag and trying to make sure that it doesn't conflict with anything previously. 01:20:17.920 |
And then when we come down here, this std should be 0.02 by default. 01:20:24.880 |
But then if hasattr(module, 'NANOGPT_SCALE_INIT'), then std *= ... 01:20:33.600 |
Copilot is not guessing correctly. So we want one over the square root of the number of residual layers. 01:20:42.320 |
So the number of residual layers here is 2 times self.config.n_layer, and then this raised to the power of 01:20:54.560 |
negative 0.5. So we want to scale down that standard deviation by that factor, and this should correctly implement it. 01:21:01.360 |
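Putting that together, the modified _init_weights might look like this (a sketch; NANOGPT_SCALE_INIT is just the ad-hoc flag name, and it is set on the c_proj layers when they are created):

```python
def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        std = 0.02
        if hasattr(module, 'NANOGPT_SCALE_INIT'):
            # scale residual projections down by 1/sqrt(number of residual layers);
            # the factor of 2 is because every block adds to the stream twice (attention and MLP)
            std *= (2 * self.config.n_layer) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

# ...and right after each residual projection (c_proj) is created, the flag is attached:
# self.c_proj.NANOGPT_SCALE_INIT = 1
```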
I should clarify, by the way, that the two times number of layers comes from the 01:21:06.000 |
fact that every single one of our layers in the transformer actually has two blocks that add to 01:21:10.800 |
the residual pathway, right? We have the attention and then the MLP. So that's where the two times 01:21:15.280 |
comes from. And the other thing to mention is that what's slightly awkward, but we're not going to 01:21:21.520 |
fix it, is that because we are weight sharing the WTE and the LMHead, in this iteration over all the 01:21:29.600 |
submodules, we're going to actually come around to that tensor twice. So we're going to first 01:21:33.920 |
initialize it as an embedding with 0.02, and then we're going to come back around it again in the 01:21:38.800 |
linear and initialize it again using 0.02. And it's going to be 0.02 because the LMHead is, of 01:21:44.720 |
course, not scaled. So it's not going to come here. It's just it's going to be basically initialized 01:21:49.680 |
twice using the identical same initialization, but that's okay. And then scrolling over here, 01:21:55.920 |
I added some code here so that we have reproducibility to set the seeds. And now 01:22:03.680 |
we should be able to run python train_gpt2.py and let this run. And as far as I know, 01:22:09.360 |
this is the GPT2 initialization in the way we've implemented it right now. So this looks 01:22:17.360 |
reasonable to me. Okay. So at this point, we have the GPT2 model. We have some confidence that it's 01:22:22.640 |
correctly implemented. We've initialized it properly. And we have a data loader that's 01:22:26.400 |
iterating through data batches, and we can train. So now comes the fun part. I'd like us to speed 01:22:31.440 |
up the training by a lot. So we're getting our money's worth with respect to the hardware that 01:22:35.360 |
we are using here. And we're going to speed up the training by quite a bit. Now, you always want to 01:22:42.400 |
start with what hardware do you have? What does it offer? And are you fully utilizing it? So in my 01:22:47.520 |
case, if we go to NVIDIA SMI, we can see that I have eight GPUs. And each one of those GPUs is an 01:22:57.760 |
A100 SXM 80 gigabytes. So this is the GPU that I have available to me in this box. Now, when I want 01:23:07.680 |
to spin up these kinds of boxes, by the way, my favorite place to go to is Lambda Labs. They do 01:23:14.160 |
sponsor my development and that of my projects. But this is my favorite place to go. And this is 01:23:20.320 |
where you can spin up one of these machines, and you pay per hour, and it's very, very simple. 01:23:24.080 |
So I like to spin them up and then connect VS Code to it, and that's how I develop. 01:23:28.320 |
Now, when we look at the A100s that are available here, A100 80 gigabyte SXM is the 01:23:36.880 |
GPU that I have here. And we have a bunch of numbers here for how many calculations you can 01:23:42.720 |
expect out of this GPU. So when I come over here and I break in right after here, in Python. So 01:23:52.480 |
I'm breaking in right after we calculate the logits and the loss. And the interesting thing I'd like 01:23:57.680 |
you to note is when I do logits.dtype, this prints a torch.float32. So by default in PyTorch, when 01:24:06.960 |
you create tensors, and this is the case for all the activations and for the parameters of the 01:24:11.280 |
network and so on, by default, everything is in float32. That means that every single number, 01:24:17.680 |
activation or weight and so on, is using a float representation that has 32 bits. 01:24:24.560 |
And that's actually quite a bit of memory. And it turns out empirically that for deep learning 01:24:28.400 |
as a computational workload, this is way too much. And deep learning and the training of 01:24:33.040 |
these networks can tolerate significantly lower precisions. Not all computational workflows 01:24:38.960 |
can tolerate small precision. So for example, if we go back to the data sheet, you'll see that 01:24:46.000 |
actually these GPUs support up to FP64. And this is quite useful, I understand, for a lot of 01:24:52.080 |
scientific computing applications. And there they really need this. But we don't need that much 01:24:56.720 |
precision for deep learning training. So currently we are here, FP32. And with this code as it is 01:25:03.920 |
right now, we expect to get at most 19.5 teraflops of performance. That means we're doing 19.5 01:25:11.680 |
trillion operations, floating point operations. So this is floating point multiply, add, most 01:25:19.520 |
likely. And so these are the floating point operations. Now notice that if we are willing 01:25:27.360 |
to go down in precision, so TF32 is a lower precision format we're going to see in a second, 01:25:33.360 |
you can actually get an 8x improvement here. And if you're willing to go down to float16 or 01:25:38.160 |
bfloat16, you can actually get a 16x improvement in performance, all the way to 312 teraflops. 01:25:45.520 |
You see here that NVIDIA likes to cite numbers that have an asterisk here. 01:25:49.120 |
This asterisk says with sparsity. But we are not going to be using sparsity in our code. And I 01:25:55.120 |
don't know that this is very widely used in the industry right now. So most people look at this 01:25:59.840 |
number here without sparsity. And you'll notice that we could have got even more here. But this 01:26:05.600 |
is int8. And int8 is used for inference, not for training. Because int8 01:26:12.880 |
basically has uniform spacing. And we actually require a float so that we get a better match 01:26:24.320 |
to the normal distributions that occur during training of neural networks, where both 01:26:30.160 |
activations and weights are distributed as a normal distribution. And so floating points are 01:26:35.280 |
really important to match that representation. So we're not typically using int8 for training, 01:26:42.160 |
but we are using it for inference. And if we bring down the precision, we can get a lot more 01:26:47.680 |
teraflops out of the tensor cores available in the GPUs. We'll talk about that in a second. 01:26:52.400 |
But in addition to that, if all of these numbers have fewer bits of representation, it's going to 01:26:58.080 |
be much easier to move them around. And that's where we start to get into the memory bandwidth 01:27:02.640 |
and the memory of the model. So not only do we have a finite capacity of the number of bits that 01:27:07.760 |
our GPU can store, but in addition to that, there's a speed with which you can access this memory. 01:27:13.440 |
And you have a certain memory bandwidth. It's a very precious resource. And in fact, many of the 01:27:20.080 |
deep learning workloads for training are memory bound. And what that means is actually that the 01:27:25.520 |
tensor cores that do all these extremely fast multiplications, most of the time they're waiting 01:27:30.400 |
around, they're idle, because we can't feed them with data fast enough. We can't load the data 01:27:36.880 |
fast enough from memory. So typical utilizations of your hardware, if you're getting 60% utilization, 01:27:42.800 |
you're actually doing extremely well. So half of the time in a well-tuned application, 01:27:48.400 |
your tensor cores are not doing multiplies because the data is not available. 01:27:51.840 |
So the memory bandwidth here is extremely important as well. And if we come down in 01:27:56.320 |
the precision for all the floats, all the numbers, weights, and activations suddenly require less 01:28:01.920 |
memory. So we can store more and we can access it faster. So everything speeds up and it's amazing. 01:28:08.320 |
And now let's reap the benefits of it. And let's first look at the tensor float 32 format. 01:28:14.080 |
Okay, so first of all, what are tensor cores? Well, tensor core is just an instruction 01:28:20.800 |
in the A100 architecture, right? So what it does is it does basically a little 4x4 matrix multiply. 01:28:28.000 |
So this is just matrix multiplication here of 4x4 matrices. And there are multiple configurations 01:28:37.200 |
as to what precision any of these matrices are, in what precision the internal accumulate happens, 01:28:43.520 |
and then what is the output precision, input precision, etc. So there's a few switches, 01:28:48.560 |
but it's basically a 4x4 multiply. And then any time we have any operations that require matrix 01:28:54.000 |
multiplication, they get broken up into this instruction of a little 4x4 multiply. And so 01:29:01.040 |
everything gets broken up into this instruction because it's the fastest way to multiply matrices. 01:29:05.120 |
And it turns out that most of the computational work that we're doing up above, 01:29:08.960 |
all of it really is matrix multiplication. Most of the work computationally happens in 01:29:14.080 |
the linear layers, linear, linear, etc. There's a few things sandwiched in between. So there's 01:29:22.400 |
some additions in residuals, there's some GELU nonlinearities, there's some layer norms, etc. 01:29:27.760 |
But if you just time them, you'll see that these are nothing. Like basically, 01:29:31.760 |
the entire transformer is just a bunch of matrix multiplications, really. 01:29:35.680 |
And especially at this small scale, 124 million parameter model, actually the biggest matrix 01:29:42.720 |
multiplication by far is the classifier layer at the top. That is a massive matrix multiply 01:29:47.840 |
of going from 768 to 50,257. And that matrix multiply dominates anything else that happens 01:29:54.880 |
in that network, roughly speaking. So it's matrix multiplies that become a lot faster, 01:30:00.720 |
which are hidden inside our linear layers. And they're accelerated through tensor cores. 01:30:05.680 |
Now, the best reference I would say for tensor cores is basically just go to the 01:30:11.680 |
A100 architecture whitepaper. And then it's pretty detailed. But I think people, 01:30:17.760 |
it's like relatively readable mostly, if you half understand what's happening. 01:30:21.280 |
So figure 9 tensor float 32. So this is the explanation basically for TF32 and what happens 01:30:29.760 |
here. And you see that there's many configuration options here available. So the input operands, 01:30:35.840 |
and what precisions are they in, the accumulator, and what basically the 01:30:41.680 |
internal representation within the instruction when you do the accumulate of this matrix 01:30:47.840 |
multiplication. So the intermediate plus equals of the intermediate little vector multiplies here, 01:30:54.880 |
that all happens in FP32. And then this is an 8x improvement, as I mentioned, to the 01:31:01.520 |
ops that we got. So TF32 specifically, we're looking at this row here. And the way this works is 01:31:06.960 |
normally FP32 has 32 bits. TF32 keeps the same sign and exponent bits. We have one sign bit, 01:31:19.040 |
we have eight exponent bits, but the mantissa bits get cropped in the float. 01:31:25.120 |
And so basically, we end up with just 19 bits, instead of 32 bits, because the last 13 bits 01:31:32.640 |
get truncated, they get dropped. And all this is internal to the instruction. So none of it is 01:31:39.120 |
visible to anything in our PyTorch. None of our PyTorch code will change, all the numbers will 01:31:44.400 |
look identical. It's just that when you call the tensor core instruction, internally in the hardware, 01:31:51.280 |
it will crop out these 13 bits. And that allows it to calculate this little matrix multiply 01:31:58.320 |
significantly faster, 8x faster. Now, of course, this speed up comes at a cost. And the cost is 01:32:04.800 |
that we are reducing the precision. Our accumulate is still in FP32, our output is FP32, our inputs 01:32:10.880 |
are FP32. But internally, things get truncated in the operands to perform the operation faster. 01:32:17.440 |
And so our results are starting to be a bit more approximate. But empirically, when you actually 01:32:21.440 |
train with this, you basically can't tell the difference. So the reason I like TF32 is because 01:32:26.080 |
if you can tolerate a little bit of a precision fudge, then this is free, like none of your code 01:32:33.120 |
sees this, it's fully internal to the operation, and the operation to you just go 8x faster, 01:32:39.120 |
and it's a bit more approximate. And so it's a pretty sweet spot, I would say in optimization. 01:32:44.480 |
And let's see what that looks like first. So I've set up our codes to just time the iterations. 01:32:50.720 |
So import time, I changed the hyperparameters so that we have something that better reflects 01:32:56.240 |
the kind of workload that we want to run, because we want to do a fairly large run at the end of this. 01:33:01.040 |
So let's use batch size 16. And let's now use the actual GPT-2 maximum sequence length of 1024 01:33:08.000 |
tokens. So this is the configuration. And then for 50 iterations, I'm just doing something very lazy 01:33:16.640 |
here. I'm doing time.time to get the current time. And then this is the optimization loop. 01:33:22.240 |
And now I want to time how long this takes. Now, one issue with working with GPUs is that 01:33:31.040 |
as your CPU-- when your CPU runs, it's just scheduling work on GPU. It's ordering some work, 01:33:38.640 |
right? And so it sends a request, and then it continues running. And so it can happen sometimes 01:33:44.400 |
that we sort of speed through this, and we queue up a lot of kernels to run on the GPU, 01:33:51.360 |
and then the CPU sort of gets here and takes time.time. But actually, the GPU is still running, 01:33:56.160 |
because it takes it time to actually work through the work that was scheduled to run. 01:34:00.560 |
And so you're just building up a queue for the GPU. And so actually, if you need to, 01:34:06.080 |
you want to wait: torch.cuda.synchronize(). And this will wait for the GPU to finish all the work 01:34:11.680 |
that was scheduled to run up above here. And then we can actually take the time. So basically, 01:34:17.840 |
we're waiting for the GPU to finish this iteration, take the time, and then we're going to just print it. 01:34:23.040 |
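The timing wrapper around the loop then looks roughly like this (a sketch; tokens-per-second is included here because it comes up just below):

```python
import time

for i in range(50):
    t0 = time.time()
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()  # wait for the GPU to finish all the scheduled work
    t1 = time.time()
    dt = (t1 - t0) * 1000     # milliseconds per iteration
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():.6f}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec:.2f}")
```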
So here, I'm going to run the training loop. And here on the right, I'm watching NVIDIA SMI. 01:34:30.400 |
So we start off at 0. We're not using the GPU. And then by default, PyTorch will use GPU 0, 01:34:36.640 |
so we see that it gets filled up. And we're using 35 gigabytes out of 80 gigabytes available. 01:34:42.160 |
And then here on the left, we see that because we've cranked up the batch size, 01:34:49.840 |
now it's only 20 batches to do a single epoch on our tiny Shakespeare. 01:34:53.120 |
And we see that we're seeing roughly 1,000 milliseconds per iteration here, right? 01:34:58.160 |
So the first iteration sometimes is slower. And that's because PyTorch might be doing a lot of 01:35:05.200 |
initializations here on the very first iteration. And so it's probably initializing all these 01:35:09.760 |
tensors and buffers to hold all the gradients. And I'm not 100% sure all the work that happens here, 01:35:14.960 |
but this could be a slower iteration. When you're timing your logic, you always want 01:35:18.960 |
to be careful with that. But basically, we're seeing 1,000 milliseconds per iteration. 01:35:23.360 |
And so this will run for roughly 50 seconds as we have it right now. 01:35:28.000 |
So that's our baseline in float32. One more thing I wanted to mention is that 01:35:33.760 |
if this doesn't fit into your GPU and you're getting out of memory errors, 01:35:38.080 |
then start decreasing your batch size until things fit. So instead of 16, try 8 or 4 01:35:43.120 |
or whatever you need to fit the batch into your GPU. And if you have a bigger GPU, you can actually 01:35:49.040 |
potentially get away with 32 and so on. By default, you want to basically max out the batch size that 01:35:55.600 |
fits on your GPU. And you want to keep it nice numbers. So use numbers that have lots of powers 01:36:01.680 |
of 2 in them. So 16 is a good number. 8, 24, 32, 48, these are nice numbers. But don't use something 01:36:09.600 |
like 17, because that will run very inefficiently on the GPU. And we're going to see that a bit 01:36:14.560 |
later as well. So for now, let's just stick with 16, 1,024. And the one thing that I added also 01:36:22.320 |
here, and I ran it again, is I'm calculating tokens per second throughput during training. 01:36:28.960 |
Because we might end up changing the batch size around over time. But tokens per second is the 01:36:34.720 |
objective measure that we actually really care about. How many tokens of data are we training 01:36:38.960 |
on? And what is the throughput of tokens that we're getting in our optimization? So right now, 01:36:43.280 |
we're processing and training on roughly 16,300 tokens per second. And that's a bit more 01:36:49.360 |
objective metric. Okay, so let's now enable TF32. Now, luckily, PyTorch makes this fairly easy for 01:36:56.240 |
us. And to enable TF32, you just need to do a single line. And it's this. And when we go to the 01:37:03.600 |
PyTorch documentation here for this function, basically, this tells PyTorch what kind of 01:37:08.000 |
kernels to run. And by default, I believe it is highest. Highest precision for matmul. And that 01:37:15.520 |
means that everything happens in float32, just like it did before. But if we set it to high, 01:37:20.480 |
as we do right now, matrix multiplications will now use TensorFloat32 (TF32) when it's available. 01:37:26.480 |
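For reference, the single line being described is PyTorch's float32 matmul precision setting:

```python
# 'high' allows matmuls to run in TF32 on tensor cores where available;
# the default 'highest' keeps everything in full FP32
torch.set_float32_matmul_precision('high')
```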
My GPU is the A100. So it's an ampere series. And therefore, TF32 is available. If you have an older 01:37:34.640 |
GPU, this might not be available for you. But for my GPU, it's available. And so what I expect 01:37:40.080 |
PyTorch to do is that every single place where we see an nn.linear, inside there, there's a 01:37:45.040 |
matrix multiplication. And I expect that matrix multiplication now to be running on tensor cores, 01:37:51.200 |
utilizing the TF32 precision. So this is the single line of change that is, I believe, necessary. And 01:37:59.200 |
let's rerun this. Now, we saw that in terms of the throughput that is promised to us, we're supposed 01:38:05.680 |
to be getting 8x, roughly. So let's see what happens. And that 8x came from here, right? 8x. 01:38:16.640 |
And it also came from looking at it here, 156 tflops instead of 19.5. Okay, so what actually 01:38:26.240 |
happened? So we're seeing that our throughput improved roughly 3x, not 8x. So we are going from 1,000 01:38:35.600 |
milliseconds down to about 333 milliseconds, and our throughput is now about 50,000 tokens per 01:38:40.560 |
second. So we have a roughly 3x instead of 8x. So what happened? And basically, what's happening 01:38:46.240 |
here is, again, a lot of these workloads are memory bound. And so even though the TF32 offers, 01:38:54.000 |
in principle, a lot faster throughput, all of these numbers everywhere are still float32s. 01:39:01.280 |
And it's float32 numbers that are being shipped all over the place through the memory system. 01:39:05.840 |
And it's just costing us way too much time to shuttle around all this data. And so even though 01:39:10.160 |
we've made the multiply itself much faster, we are memory bound, and we're not actually seeing 01:39:14.480 |
the full benefit that would come from this napkin math here. That said, we are getting 01:39:21.200 |
3x faster throughput. And this is free. Single line of code in PyTorch. All your variables are 01:39:28.640 |
still float32 everywhere. It just runs faster. And it's slightly more approximate, but we're 01:39:33.760 |
not going to notice it, basically. So that's TF32. Okay, so let's now continue. So we've exercised 01:39:41.920 |
this row. And we saw that we can crop out some of the precision inside the operation itself. 01:39:48.560 |
But we saw that we're still memory bound. We're still moving around all these floats, 01:39:51.840 |
right? Otherwise. And we're paying that cost because of this. So let's now decrease the 01:39:56.400 |
amount of stuff that we're going to be moving around. And we're going to do that by dropping 01:40:00.960 |
down to bfloat16. So we're only going to be maintaining 16 bits per float. And we're going 01:40:07.440 |
to use the bfloat16. And I'll explain in a bit the difference from FP16. And we're going to be in this 01:40:12.880 |
row. So when we go back to the documentation here for the A100, we see here the precisions 01:40:23.120 |
that are available. And this is the original FP32. The TF32 crops out the precision. And then here 01:40:29.760 |
in bfloat16, you see that it is very similar to TF32. But it's even more aggressive in cropping 01:40:36.640 |
off the precision, the mantissa, of this float. So the important thing with bfloat16 is that the 01:40:42.480 |
exponent bits and the sign bit, of course, remain unchanged. So if you're familiar with your float 01:40:48.000 |
numbers, and I think this should probably be an entire video by itself, the exponent sets the 01:40:55.200 |
range that you can represent of your numbers. And the precision is how much precision you have 01:41:00.880 |
for your numbers. And so the range of numbers is identical. But we have fewer 01:41:07.520 |
possibilities within that range, because we are truncating the mantissa. So we have less precision 01:41:13.680 |
in that range. What that means is that things are actually fairly nice, because we have the 01:41:19.600 |
original range of numbers that are representable in float. But we just have less precision for it. 01:41:25.920 |
And the difference with FP16 is that they actually touch and change the range. So FP16 cannot 01:41:32.640 |
represent the full range of FP32. It has a reduced range. And that's where you start to 01:41:38.240 |
actually run into issues, because now you need these gradient scalers and things like that. 01:41:43.120 |
And I'm not going to go into the detail of that in this video, because that's a whole video by 01:41:48.240 |
itself. But FP16 actually historically came first. That was available in the Volta series before 01:41:54.400 |
Ampere. And so FP16 came first, and everyone started to train in FP16. But everyone had to 01:41:59.680 |
use all these gradient scaling operations, which are kind of annoying. And it's an additional source 01:42:04.240 |
of state and complexity. And the reason for that was because the exponent range was reduced in FP16. 01:42:10.160 |
So that's the IEEE FP16 spec. And then they came out with BF16 in Ampere. And they made it 01:42:17.680 |
much simpler, because we're just truncating mantissa, we have the exact same range, and we do 01:42:21.680 |
not need gradient scalers. So everything is much, much simpler. Now, when we do use BF16, though, 01:42:27.680 |
we are impacting the numbers that we might be seeing in our PyTorch code. 01:42:32.080 |
This change is not just local to the operation itself. So let's see how that works. 01:42:38.160 |
There's some documentation here that-- so I think this is probably the best page to explain how to 01:42:45.600 |
use mixed precision in PyTorch. Because there are many other tutorials and so on, even within 01:42:51.680 |
PyTorch documentation, that are a lot more confusing. And so I recommend specifically 01:42:55.920 |
this one. Because there's five other copies that I would not recommend. And then when we come here, 01:43:01.680 |
ignore everything about everything. Ignore everything about gradient scalers. 01:43:06.560 |
And only look at torch.autocast. And basically, this also comes down to a single line of code at 01:43:14.640 |
the end. So this is the context manager that we want. And we want to use that in our network. 01:43:22.400 |
When you click into the torch.autocast, autocasting, it has a few more-- a bit more 01:43:29.440 |
guideline for you. So it's telling you, do not call BFloat16 on any of your tensors. 01:43:35.200 |
Just use autocast. And only surround the forward pass of the model and the loss calculation. 01:43:41.600 |
And that's the only two things that you should be surrounding. Leave the backward and the 01:43:45.280 |
optimizer step alone. So that's the guidance that comes from the PyTorch team. So we're going to 01:43:50.400 |
follow that guidance. And for us, because the loss calculation is inside of the model forward pass for 01:43:55.600 |
us, we are going to be doing this. And then we don't want to be using torch.float16. Because if 01:44:01.840 |
we do that, we need to start using gradient scalers as well. So we are going to be using BFloat16. 01:44:07.120 |
This is only possible to do in Ampere. But this means that the changes are extremely minimal. 01:44:12.720 |
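(Roughly, the change to the training loop looks like the sketch below; the only new line is the with torch.autocast context, and model, x, y, optimizer are the names from the loop we've been building:)

```python
import torch

# Only the forward pass and the loss calculation go inside autocast;
# loss.backward() and optimizer.step() stay outside, per the PyTorch guidance.
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, loss = model(x, y)
loss.backward()
optimizer.step()
```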
Well, it's basically just this one line of code. Let me first break in to here, before we actually 01:44:21.440 |
run this. So right after logits. I'd like to show you that, different from the TF32 that we saw, 01:44:28.720 |
this is actually going to impact our tensors. So this logits tensor, if we now look at this, 01:44:36.720 |
and we look at the D type, we suddenly see that this is now BFloat16. It's not float32 anymore. 01:44:43.360 |
So our activations have been changed. The activations tensor is now BFloat16. But not 01:44:48.640 |
everything has changed. So model.transformer.wte. This is the weight token embedding table. It has 01:44:59.120 |
a dot weight inside it. And the D type of this weight, this parameter, is still torch.float32. 01:45:06.160 |
So our parameters seem to still be in float32, but our activations, the logits, are now in BFloat16. 01:45:11.760 |
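(That is, at the breakpoint, the inspection looks roughly like this:)

```python
print(logits.dtype)                        # torch.bfloat16: activations were autocast
print(model.transformer.wte.weight.dtype)  # torch.float32: parameters stay in float32
```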
So clearly, this is why we get the mixed precision. Some things PyTorch is keeping 01:45:17.040 |
in float32. Some things PyTorch is converting to lower precision. And what gets converted, 01:45:24.640 |
at what point, is not super clear. I remember scrolling down. Is it here? 01:45:31.920 |
Okay, I can't find it. I thought it was here. Okay, there we go. So there are a few docs on 01:45:44.160 |
when you're using this autocast, what gets converted to BFloat16 and when. So for example, 01:45:49.760 |
only these matrix multiply-like operations get converted to BFloat16. But a lot of operations 01:45:55.360 |
remain in float32. So in particular, a lot of normalizations, like layer norms and things like 01:45:59.920 |
that, not all of those layers might be converted. So only some layers selectively would be running 01:46:06.400 |
BFloat16. But things like softmax, layer norms, log softmax, so loss function calculations, 01:46:14.800 |
a lot of those things might remain in float32 because they are more susceptible to precision 01:46:19.040 |
changes. Matrix multiplies are fairly robust to precision changes. So some parts of the network 01:46:27.520 |
are impacted more or less by the precision change. So basically only some parts of the model are 01:46:35.520 |
running in reduced precision. Let's take it for a spin and let's actually see what kind of 01:46:41.680 |
improvement we achieve here. Okay, so we used to be 333 milliseconds. We're now at 300. 01:46:52.880 |
And we used to be somewhere around 50,000 tokens per second. We're now at 55,000. 01:46:56.960 |
So we're definitely running faster, but maybe not a lot faster. And that's because there are 01:47:02.960 |
still many, many bottlenecks in our GPT-2. We're just getting started. But we have dropped down 01:47:07.840 |
the precision as far as we can with my current GPU, which is A100. We're using PyTorch Autocast. 01:47:14.320 |
Unfortunately, I don't actually exactly know what PyTorch Autocast does. I don't actually know 01:47:20.000 |
exactly what's in BFloat16, what's in float32. We could go in and we could start to scrutinize it. 01:47:25.120 |
But these are the kinds of rules that PyTorch has internally. And unfortunately, they don't 01:47:30.400 |
document it very well. So we're not going to go into that in too much detail. But for now, 01:47:36.560 |
we are training in BFloat16. We do not need a gradient scaler. And the reason things are 01:47:41.200 |
running faster is because we are able to run the Tensor Cores in BFloat16 now. That means we are 01:47:49.360 |
in this row. But we are also paying in precision for this. So we expect slightly less accurate 01:47:57.840 |
results with respect to the original FP32. But empirically, in many cases, this is a worth it 01:48:04.080 |
trade-off because it allows you to run faster. And you could, for example, train longer and make 01:48:08.800 |
up for that precision decrease. So that's BFloat16 for now. Okay. So as we can see, 01:48:17.120 |
we are currently at about 300 milliseconds per iteration. And we're now going to reach for some 01:48:21.840 |
really heavy weapons in the PyTorch arsenal. And in particular, we're going to introduce 01:48:25.920 |
Torch.compile. So Torch.compile is really quite incredible infrastructure from the PyTorch team. 01:48:31.840 |
And it's basically a compiler for neural networks. It's almost like GCC for C and C++ code. This is 01:48:38.240 |
just the GCC of neural nets. So it came out a while ago and extremely simple to use. The way 01:48:46.800 |
to use Torch.compile is to do this. It's a single line of code to compile your model and return it. 01:48:53.200 |
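(That is, roughly, using the names from earlier in the video:)

```python
model = GPT(GPTConfig())
model.to(device)
# One line: compile the model, paying a one-time compilation cost up front.
model = torch.compile(model)
```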
Now, this line of code will cost you compilation time. But as you might guess, it's going to make 01:48:58.000 |
the code a lot faster. So let's actually run that. Because this will take some time to run. 01:49:02.960 |
But currently, remember, we're at 300 milliseconds. And we'll see what happens. 01:49:06.240 |
Now, while this is running, I'd like to explain a little bit of what Torch.compile does under 01:49:11.600 |
the hood. So feel free to read this page of PyTorch. But basically, there's no real good 01:49:17.680 |
reason for you to not use Torch.compile in your PyTorch. I kind of feel like you should be using 01:49:22.640 |
it almost by default unless you're debugging and you want your code to run really fast. 01:49:27.920 |
And there's one line here in Torch.compile that I found that actually kind of gets to 01:49:31.760 |
why this is faster. Speed up mainly comes from reducing Python overhead and GPU read/writes. 01:49:38.240 |
So let me unpack that a little bit. Okay. Here we are. Okay. So we went from 300 milliseconds. 01:49:44.400 |
We're now running at 129 milliseconds. So this is 300 divided by 129, about 2.3x 01:49:52.080 |
improvement from a single line of code in PyTorch. So quite incredible. So what is happening? What's 01:49:57.520 |
happening under the hood? Well, when you pass the model to Torch.compile, what we have here in this 01:50:03.840 |
NN module, this is really just the algorithmic description of what we'd like to happen in our 01:50:09.360 |
network. And Torch.compile will analyze the entire thing. And it will look at what operations you'd 01:50:15.840 |
like to use. And with the benefit of knowing exactly what's going to happen, it doesn't 01:50:21.040 |
have to run in what's called the eager mode. It doesn't have to just kind of like go layer by 01:50:25.680 |
layer. Like the Python interpreter normally would start at the forward. And the Python interpreter 01:50:33.280 |
will go, okay, let's do this operation. And then let's do that operation. And it kind of materializes 01:50:38.640 |
all the operations as it goes through. So these calculations are dispatched and run in this order. 01:50:45.360 |
And the Python interpreter and this code doesn't know what kind of operations are going to happen 01:50:49.680 |
later. But Torch.compile sees your entire code at the same time. And it's able to know what 01:50:55.200 |
operations you intend to run. And it will kind of optimize that process. The first thing it will do 01:51:00.640 |
is it will take out the Python interpreter from the forward pass entirely. And it will kind of 01:51:05.120 |
compile this entire neural net as a single object with no Python interpreter involved. So it knows 01:51:10.800 |
exactly what's going to run. It will just run that. And it's all going to be running in efficient code. 01:51:16.640 |
The second thing that happens is this read/write that they mentioned very briefly. So a good 01:51:23.120 |
example of that, I think, is the Gelu nonlinearity that we've been looking at. 01:51:26.240 |
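(For reference, here is a sketch of that broken-up version: the tanh approximation of GELU written out as individual elementwise operations. The class name TanhGELU here is just illustrative:)

```python
import math
import torch
import torch.nn as nn

class TanhGELU(nn.Module):
    def forward(self, input):
        # Each elementwise op (pow, mul, add, tanh, ...) would normally be
        # dispatched as its own kernel, round-tripping through GPU memory.
        return 0.5 * input * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
```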
So here we use the nn.GELU. Now, this here is me basically just breaking up the nn.GELU, 01:51:34.000 |
which you remember has this formula. So this here is the equivalent implementation to what's 01:51:39.840 |
happening inside Gelu. Algorithmically, it's identical. Now, by default, if we just were 01:51:46.240 |
using this instead of nn.GELU here, what would happen without Torch.compile? Well, the Python 01:51:51.840 |
interpreter would make its way here. And then it would be, okay, well, there's an input. Well, 01:51:56.160 |
let me first let me raise this input to the third power. And it's going to dispatch a kernel that 01:52:01.840 |
takes your input and raises it to the third power. And that kernel will run. And when this kernel 01:52:08.400 |
runs, what ends up happening is this input is stored in the memory of the GPU. So here's a 01:52:14.080 |
helpful example of the layout of what's happening, right? You have your CPU. This is in every single 01:52:20.240 |
computer. There's a few cores in there. And you have your RAM, your memory. And the CPU can talk 01:52:27.040 |
to the memory. And this is all well known. But now we've added the GPU. And the GPU is a slightly 01:52:31.840 |
different architecture, of course. They can communicate. And it's different in that it's got 01:52:35.920 |
a lot more cores than a CPU. All of those cores are individually a lot simpler, too. 01:52:41.200 |
But it also has memory, right? This high bandwidth memory. Sorry if I'm botching it. 01:52:48.480 |
HBM. I don't even know what that stands for. I'm just realizing now. But this is the memory. And 01:52:54.480 |
it's very equivalent to RAM, basically, in the computer. And what's happening is that input is 01:53:00.640 |
living in the memory. And when you do input cubed, this has to travel to the GPU, to the cores, 01:53:10.240 |
and to all the caches and registers on the actual chip of this GPU. And it has to calculate all 01:53:17.680 |
the elements of the third power. And then it saves the result back to the memory. And it's this travel 01:53:23.840 |
time that actually causes a lot of issues. So here, remember this memory bandwidth? We can 01:53:29.760 |
communicate about 2 terabytes per second, which is a lot. But also, we have to traverse this link, 01:53:35.760 |
and it's very slow. So here on the GPU, we're on chip, and everything is super fast within the 01:53:40.400 |
chip. But going to the memory is extremely expensive. It takes an extremely long amount of 01:53:44.400 |
time. And so we load the input, do the calculations, and load back the output. And this round trip 01:53:51.520 |
takes a lot of time. And now right after we do that, we multiply by this constant. So what happens 01:53:57.840 |
then is we dispatch another kernel. And then the result travels back. All the elements get 01:54:03.280 |
multiplied by a constant. And then the results travel back to the memory. And then we take the 01:54:08.960 |
result, and we add back input. And so this entire thing, again, travels to the GPU, adds the inputs, 01:54:16.240 |
and gets written back. So we're making all these round trips from the memory to actually where the 01:54:22.000 |
computation happens. Because all the tensor cores and the ALUs and everything like that is all 01:54:26.720 |
stored on the chip and the GPU. So we're doing a ton of round trips. And PyTorch, without using 01:54:32.480 |
Torch Compile, doesn't know to optimize this, because it doesn't know what kind of operations 01:54:37.200 |
you're running later. You're just telling it, raise the power to the third, then do this, then 01:54:42.160 |
do that. And it will just do that in that sequence. But Torch Compile sees your entire code. 01:54:46.480 |
It will come here, and it will realize, wait, all of these are element-wise operations. And actually, 01:54:51.440 |
what I'm going to do is I'm going to do a single trip of input to the GPU. 01:54:56.320 |
Then for every single element, I'm going to do all of these operations while that memory is on the 01:55:01.680 |
GPU, or chunks of it, rather. And then I'm going to write back a single time. So we're not going 01:55:07.520 |
to have these round trips. And that's one example of what's called kernel fusion, and is a major way 01:55:12.560 |
in which everything is sped up. So basically, if you have the benefit of hindsight, and you know 01:55:16.560 |
exactly what you're going to compute, you can optimize your round trips to the memory. And 01:55:21.600 |
you're not going to pay the memory bandwidth cost. And that's fundamentally what makes some of these 01:55:25.520 |
operations a lot faster, and what they mean by read/writes here. So let me erase this, because 01:55:32.320 |
we are not using it. And yeah, we should be using Torch Compile. And our code is now significantly 01:55:39.760 |
faster. And we're doing about 125,000 tokens per second. But we still have a long way to go. 01:55:46.160 |
Before we move on, I wanted to supplement the discussion a little bit with a few more figures. 01:55:50.640 |
Because this is a complicated topic, but it's worth understanding on a high level 01:55:55.040 |
what's happening here. And I could probably spend an entire video of like two hours on this, but 01:55:59.520 |
just a preview of that basically. So this chip here, that is the GPU, this chip is where all 01:56:06.640 |
the calculations happen mostly. But this chip also does have some memory in it. But most of 01:56:13.200 |
the memory by far is here in the high bandwidth memory, HBM, and is connected, they're connected. 01:56:20.240 |
But these are two separate chips, basically. Now, here, this is a zoom in of kind of this 01:56:26.800 |
cartoon diagram of a GPU. And we're seeing here is number one, you see this HBM, I realize it's 01:56:33.680 |
probably very small for you. But on the sides here, it says HBM. And so that's the links to the HBM. 01:56:39.360 |
Now the HBM is, again, off chip. On the chip, there are a large number of these streaming 01:56:45.840 |
multiprocessors. Every one of these is an SM, there's 120 of them in total. And this is where 01:56:52.080 |
a lot of the calculations happen. And this is a zoom in of a single individual SM. It has these 01:56:57.440 |
four quadrants. And see, for example, tensor core, this is where a lot of the matrix multiply stuff 01:57:01.600 |
happens. But there's all these other units to do all different kinds of calculations for FP64, 01:57:07.120 |
FP32, and for integers, and so on. Now, so we have all this logic here to do the calculations. 01:57:14.480 |
But in addition to that, on the chip, there is memory sprinkled throughout the chip. 01:57:18.880 |
So L2 cache is some amount of memory that lives on the chip. And then on the SMs themselves, 01:57:25.520 |
there's L1 cache. I realize it's probably very small for you, but this blue bar is L1. 01:57:30.400 |
And there's also registers. And so there is memory stored here. But the way this memory is stored is 01:57:37.840 |
very different from the way memory is stored in HBM. This is a very different implementation using 01:57:44.400 |
just in terms of like what the silicon looks like, it's a very different implementation. 01:57:49.520 |
So here, you would be using transistors and capacitors. And here, it's a very different 01:57:55.280 |
implementation with SRAM and what that looks like. But long story short is there is memory 01:58:04.240 |
inside the chip, but it's not a lot of memory. That's the critical point. So this is an example 01:58:10.320 |
diagram of a slightly different GPU, just like here, where it shows that, for example, typical 01:58:15.760 |
numbers for CPU DRAM memory, which is this thing here, you might have one terabyte of it, right? 01:58:22.640 |
But it would be extremely expensive to access, especially for a GPU. You have to go through the 01:58:26.320 |
CPU here. Now, next, we have the HBM. So we have tens of gigabytes of HBM memory on a typical GPU 01:58:32.960 |
here, but it's, as I mentioned, very expensive to access. And then on the chip itself, everything is 01:58:39.840 |
extremely fast within the chip, but we only have a couple of tens of megabytes of memory collectively 01:58:45.760 |
throughout the chip. And so there's just not enough space because the memory is very expensive 01:58:50.800 |
on the chip. And so there's not a lot of it, but it is lightning fast to access in relative terms. 01:58:56.000 |
And so basically, whenever we have these kernels, the more accurate picture of what's happening here 01:59:02.240 |
is that we take these inputs, which live by default on the global memory, and now we need 01:59:08.080 |
to perform some calculation. So we start streaming the data from the global memory to the chip. 01:59:15.680 |
We perform the calculations on the chip and then stream it back and store it back to the 01:59:19.840 |
global memory, right? And so if we don't have Torch Compile, we are streaming the data through 01:59:25.600 |
the chip doing the calculations and saving to the memory, and we're doing those round trips many, 01:59:29.520 |
many times. But if it's Torch Compiled, then we start streaming the memory as before, but then 01:59:35.600 |
while we're on the chip, we have a chunk of the data that we're trying to process. So that chunk 01:59:42.880 |
now lives on the chip. While it's on the chip, it's extremely fast to operate on. So if we have 01:59:47.360 |
kernel fusion, we can do all the operations right there in an element-wise fashion, and those are 01:59:52.560 |
very cheap. And then we do a single round trip back to the global memory. So operator fusion 01:59:58.960 |
basically allows you to keep your chunk of data on the chip and do lots of calculations on it 02:00:03.600 |
before you write it back, and that gives huge savings. And that's why Torch Compile ends up 02:00:09.280 |
being a lot faster, or that's one of the major reasons. So again, just a very brief intro to 02:00:15.120 |
the memory hierarchy and roughly what Torch Compile does for you. Now, Torch Compile is 02:00:19.920 |
amazing, but there are operations that Torch Compile will not find. And an amazing example 02:00:25.520 |
of that is FlashAttention, to which we turn next. So FlashAttention comes from this paper from 02:00:30.960 |
Stanford in 2022, and it's this incredible algorithm for performing attention and running 02:00:39.520 |
it a lot faster. So FlashAttention will come here, and we will take out these four lines, 02:00:45.920 |
and FlashAttention implements these four lines really, really quickly. And how does it do that? 02:00:52.880 |
Well, FlashAttention is a kernel fusion operation. So you see here we have in this diagram, 02:01:00.080 |
they're showing PyTorch, and you have these four operations. They're including dropout, 02:01:05.600 |
but we are not using dropout here. So we just have these four lines of code here, 02:01:09.680 |
and instead of those, we are fusing them into a single fused kernel of FlashAttention. 02:01:15.200 |
So it's a kernel fusion algorithm, but it's a kernel fusion that Torch Compile cannot find. 02:01:22.480 |
And the reason that it cannot find it is that it requires an algorithmic rewrite of how attention 02:01:28.160 |
is actually implemented here in this case. And what's remarkable about it is that FlashAttention, 02:01:33.920 |
actually, if you just count the number of flops, FlashAttention does more flops than this attention 02:01:40.080 |
here. But FlashAttention is actually significantly faster. In fact, they cite 7.6 times faster, 02:01:47.120 |
potentially. And that's because it is very mindful of the memory hierarchy, as I described it just 02:01:53.680 |
now. And so it's very mindful about what's in high bandwidth memory, what's in the shared memory, 02:01:59.280 |
and it is very careful with how it orchestrates the computation, such that we have fewer reads 02:02:05.120 |
and writes to the high bandwidth memory. And so even though we're doing more flops, 02:02:09.120 |
the expensive part is their load and store into HBM, and that's what they avoid. And so in 02:02:14.000 |
particular, they do not ever materialize this end-by-end attention matrix, this ATT here. 02:02:20.560 |
FlashAttention is designed such that this matrix never gets materialized at any point, 02:02:25.760 |
and it never gets read or written to the HBM. And this is a very large matrix, right? 02:02:30.960 |
So because this is where all the queries and keys interact, and we're sort of getting 02:02:35.280 |
for each head, for each batch element, we're getting a T-by-T matrix of attention, 02:02:43.120 |
which is a million numbers, even for a single head at a single batch index. 02:02:47.600 |
So basically, this is a ton of memory, and this is never materialized. And the way that 02:02:54.080 |
this is achieved is that basically the fundamental algorithmic rewrite here relies on this online 02:03:00.720 |
softmax trick, which was proposed previously, and I'll show you the paper in a bit. 02:03:04.240 |
And the online softmax trick, coming from a previous paper, shows how you can incrementally 02:03:11.360 |
evaluate a softmax without having to sort of realize all of the inputs to the softmax 02:03:17.360 |
to do the normalization. And you do that by having these intermediate variables m and l, 02:03:22.080 |
and there's an update to them that allows you to evaluate the softmax in an online manner. 02:03:26.560 |
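(A tiny toy sketch of that online update, not FlashAttention itself, just to illustrate the running maximum m and the running normalizer l:)

```python
import torch

def online_softmax(x):
    # One streaming pass: keep a running max m and running normalizer l,
    # rescaling l whenever a new maximum is encountered.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    for xi in x:
        m_new = torch.maximum(m, xi)
        l = l * torch.exp(m - m_new) + torch.exp(xi - m_new)
        m = m_new
    return torch.exp(x - m) / l  # matches the usual softmax

x = torch.randn(8)
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=0)))  # True
```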
Now FlashAttention, actually, so recently FlashAttention2 came out as well, so I have 02:03:33.120 |
that paper up here as well, that has additional gains to how it calculates FlashAttention. 02:03:38.480 |
And the original paper that this is based on, basically, is this online normalizing calculation 02:03:43.200 |
for softmax. And remarkably, it came out of NVIDIA, and it came out of it like really early, 02:03:48.240 |
2018. So this is four years before FlashAttention. And this paper says that we propose a way to 02:03:55.840 |
compute the classical softmax with fewer memory accesses and hypothesize that this reduction in 02:04:00.240 |
memory accesses should improve softmax performance on actual hardware. And so they are extremely 02:04:06.480 |
correct in this hypothesis, but it's really fascinating to me that they're from NVIDIA, 02:04:11.600 |
and that they had this realization, but they didn't actually take it to the actual FlashAttention 02:04:16.240 |
that had to come four years later from Stanford. So I don't fully understand the historical, 02:04:21.840 |
how this happened historically, but they do basically propose this online update to the softmax 02:04:27.280 |
right here. And this is fundamentally what they reuse here to calculate the softmax in a streaming 02:04:33.360 |
manner. And then they realized that they can actually fuse all the other operations 02:04:36.880 |
with the online softmax calculation into a single fused kernel, FlashAttention, 02:04:42.240 |
and that's what we are about to use. So a great example, I think, of being aware of memory 02:04:48.160 |
hierarchy, the fact that flops don't matter, the entire memory access pattern matters, 02:04:52.960 |
and that TorchCompile is amazing, but there are many optimizations that are still available to us 02:04:57.200 |
that potentially TorchCompile cannot find. Maybe one day it could, but right now it seems like a 02:05:03.200 |
lot to ask. So here's what we're going to do. We're going to use FlashAttention, and the way 02:05:09.200 |
to do that basically in PyTorch is we are going to comment out these four lines, and we're going 02:05:15.280 |
to replace them with a single line. And here we are calling this compound operation in PyTorch 02:05:20.960 |
called scaled_dot_product_attention. And PyTorch will call FlashAttention when you use it in this way. 02:05:30.880 |
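(Concretely, in the attention forward, where q, k, and v are the per-head query, key, and value tensors from earlier in the video, the replacement looks roughly like this:)

```python
import torch.nn.functional as F

# One fused call replaces the manual att = softmax(q @ k^T / sqrt(d)) @ v lines;
# on supported GPUs, PyTorch can dispatch this to a FlashAttention kernel.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```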
I'm not actually 100% sure why TorchCompile doesn't realize that these four lines should 02:05:34.720 |
just call FlashAttention in this exact way. We have to call it explicitly ourselves, which in my opinion 02:05:40.240 |
is a little bit odd, but here we are. So you have to use this compound op, and let's wait for a few 02:05:50.480 |
moments before TorchCompile gets around to it. And then let's remember that we achieved 6.05661. 02:05:58.400 |
I have it here. That's the loss we were expecting to see, and we took 130 milliseconds before this 02:06:04.640 |
change. So we're expecting to see the exact same result by iteration 49, but we expect to see 02:06:11.120 |
faster runtime because FlashAttention is just an algorithmic rewrite, and it's a faster kernel, 02:06:16.240 |
but it doesn't actually change any of the computation, and we should have the exact 02:06:19.120 |
same optimization. So okay, so we're a lot faster. We're at about 95 milliseconds, 02:06:24.720 |
and we achieved 6.058. Okay, so they're basically identical up to a floating-point 02:06:32.640 |
fudge factor. So it's the identical computation, but it's significantly faster going from 130 to 02:06:39.120 |
roughly 96, and so this is 96 divided by 130-ish, so this is maybe a 27-ish percent improvement. 02:06:50.240 |
So really interesting, and that is FlashAttention. Okay, we are now getting to one of my favorite 02:06:56.560 |
optimizations, and it is simultaneously the dumbest and the most brilliant optimization, 02:07:02.000 |
and it's always a little bit surprising to me. Anyway, so basically I mentioned a few minutes 02:07:07.760 |
ago that there are some numbers that are nice and some numbers that are ugly. So 64 is a beautiful 02:07:15.120 |
nice number. 128 is even nicer. 256 is beautiful. What makes these numbers beautiful is that there 02:07:21.120 |
are many powers of two inside them. You can divide by two many times, and examples of ugly numbers 02:07:27.360 |
are like 13 and 17 and something like that, prime numbers, numbers that are not even, and so on, 02:07:32.560 |
and so pretty much you always want to use nice numbers in all of your code that deals with neural 02:07:36.880 |
networks or CUDA because everything in CUDA works in sort of like powers of two, and lots of kernels 02:07:43.840 |
are written in terms of powers of two, and there are lots of blocks of sizes 16 and 64 and so on. 02:07:50.240 |
So everything is written in those terms, and you always have special case handling for all kinds of 02:07:54.880 |
logic for when your inputs are not made of nice numbers. So let's see what that looks like. 02:08:01.680 |
Basically, scan your code and look for ugly numbers is roughly the heuristic. So three times 02:08:08.000 |
is kind of ugly. I'm not 100% sure maybe this can be improved, but this is ugly and not ideal. 02:08:14.640 |
Four times is nice. So that's nice. 1024 is very nice. That's a power of two. 02:08:24.960 |
12 is a little bit suspicious. Not too many powers of two. 768 is great. 50,257 is a really, 02:08:32.800 |
really ugly number. First of all, it's odd, and there's not too many powers of two in there. 02:08:40.560 |
So this is a very ugly number, and it's highly suspicious. And then when we scroll down, 02:08:46.080 |
all these numbers are nice, and then here we have mostly nice numbers except for 25. 02:08:52.720 |
So in this configuration of GPT-2XL, the number of heads is 25. That's a really ugly number. That's 02:08:57.840 |
an odd number. Actually, this did cause a lot of headaches for us recently when we were trying to 02:09:02.960 |
optimize some kernels to run this fast and required a bunch of special case handling. 02:09:08.320 |
So basically, we have some ugly numbers, and some of them are easier to fix than others. 02:09:14.560 |
In particular, the vocab size being 50,257, that's a very ugly number, very suspicious, 02:09:19.360 |
and we want to fix it. Now, when you fix these things, one of the easy ways to do that is you 02:09:24.560 |
basically increase the number until it's the nearest power of two that you like. So here's 02:09:31.280 |
a much nicer number. It's 50,304. And why is that? Because 50,304 can be divided by 8, or by 16, 02:09:40.000 |
or by 32, 64. It can even be divided by 128, I think. Yeah. So it's a very nice number. 02:09:48.400 |
So what we're going to do here is this is the GPT config, and you see that we initialize 02:09:53.360 |
vocab size to 50,257. Let's override just that element to be 50,304. 02:10:00.960 |
So everything else stays the same. We're just increasing our vocabulary size. 02:10:07.600 |
So it's almost like we're adding fake tokens. So that vocab size has powers of two inside it. 02:10:14.560 |
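(In other words, at model creation time, something like this, using the GPTConfig from earlier and keeping everything else the same:)

```python
# Pad the vocabulary up to the nearest "nice" number: 50304 = 128 * 393,
# divisible by 8, 16, 32, 64, and 128, unlike the original 50257.
model = GPT(GPTConfig(vocab_size=50304))
```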
Now, actually, what I'm doing here, by the way, is I'm increasing the amount of computation 02:10:19.200 |
that our network will be doing. If you just count the flops on like, do the math of how many flops 02:10:23.840 |
we're doing, we're going to be doing more flops. And we still have to think through whether this 02:10:29.680 |
doesn't break anything. But if I just run this, let's see what we get. Currently, this ran in 02:10:35.760 |
maybe 96.5 milliseconds per step. I'm just kind of like eyeballing it. And let's see what kind of 02:10:43.840 |
result we're going to get. While this is compiling, let's think through whether our code actually 02:10:51.280 |
works okay when we increase the vocab size like this. Let's look at where vocab size is actually 02:10:56.000 |
used. So we swing up to the init, and we see that it's used inside the embedding table, of course, 02:11:02.560 |
so all the way at the bottom of the transformer. And it's used at the classifier layer, all the 02:11:06.400 |
way at the top of the transformer, so in two places. And let's take a look. And we're running 02:11:11.440 |
at 93. So 93 milliseconds instead of 96.5. So we are seeing a roughly 4% improvement here 02:11:20.400 |
by doing more calculations. And the reason for this is we've made an ugly number into a nice 02:11:28.400 |
number. I'm going to come into the explanation for that a little bit again. But for now, let's 02:11:33.920 |
just convince ourselves that we're not breaking anything when we do this. So first of all, we've 02:11:37.600 |
made the WTE, the embedding table for the tokens, we've made it larger. It's almost like we 02:11:43.120 |
introduced more tokens at the bottom. And these tokens are never used because the GPT tokenizer 02:11:49.600 |
only has tokens up to 50,256. And so we'll never index into the rows that we've added. So we're 02:11:57.120 |
wasting a little bit of space here by creating memory that's never going to be accessed, never 02:12:01.360 |
going to be used, etc. Now, that's not fully correct, because this WTE weight ends up being 02:12:07.440 |
shared and ends up being used in the classifier here at the end. So what is that doing to the 02:12:12.320 |
classifier right here? Well, what that's doing is we're predicting additional dimensions of the 02:12:17.440 |
classifier now. And we're predicting probabilities for tokens that will, of course, never be present 02:12:22.160 |
in the training set. And so therefore, the network has to learn that these probabilities have to be 02:12:28.880 |
driven to zero. And so the logits that the network produces have to drive those dimensions of the 02:12:34.320 |
output to negative infinity. But that's no different from all the other tokens that are already in our 02:12:39.760 |
data set, or rather that are not in our data set. So Shakespeare only probably uses, let's say 1000 02:12:46.400 |
tokens out of 50,257 tokens. So most of the tokens are already being driven to zero probability by 02:12:52.160 |
the optimization, we've just introduced a few more tokens now that in a similar manner will never be 02:12:57.120 |
used and have to be driven to zero in probability. So functionally, though, nothing breaks, we're 02:13:03.920 |
using a bit more extra memory. But otherwise, this is a harmless operation, as far as I can tell. 02:13:10.400 |
And we're adding computation, but it's running faster. And it's running faster, 02:13:15.280 |
because as I mentioned, in CUDA, so many kernels use block tiles, and these block tiles are usually 02:13:22.560 |
nice numbers. So powers of two, so calculations are done in like chunks of 64, or chunks of 32. 02:13:28.720 |
And when you're when your desired calculation doesn't neatly fit into those block tiles, 02:13:34.720 |
there are all kinds of boundary kernels that can kick in to like, do the last part. So basically, 02:13:42.720 |
in a lot of kernels, they will chunk up your input, and they will do the nice part first, 02:13:47.280 |
and then they have a whole second phase, where they come back to anything that remains. 02:13:52.640 |
And then they process the remaining part. And the kernels for that can be very inefficient. 02:13:58.000 |
And so you're basically spinning up all this extra compute, and it's extremely inefficient. 02:14:03.680 |
So you might as well pad your inputs and make it fit nicely. And usually that empirically ends up 02:14:10.160 |
actually running faster. So this is another example of a 4% improvement that we've added. 02:14:17.760 |
And this is something that also Torch Compile did not find for us. You would hope that Torch Compile 02:14:22.640 |
at some point could figure an optimization like this out. But for now, this is it. And I also have 02:14:28.320 |
to point out that we're using PyTorch nightly. So that's why we're only seeing 4%. If you're using 02:14:32.880 |
PyTorch 2.3.1, or earlier, you would actually see something like 30% improvement just from this 02:14:38.800 |
change, from changing the vocab size from 50,257 to 50,304. So again, one of my favorite examples also 02:14:48.000 |
of having to understand the under the hood and how it all works, and to know what kinds of things to 02:14:52.320 |
tinker with to push the performance of your code. Okay, so at this point, we have improved the 02:14:56.960 |
performance by about 11x, right? Because we started at about 1000 milliseconds per step, 02:15:01.840 |
and we're now down to like 93 milliseconds. So that's quite good. And we're doing a much better 02:15:08.240 |
job of utilizing our GPU resources. So I'm going to now turn to more algorithmic changes and 02:15:14.720 |
improvements to the actual optimization itself. And what we would like to do is we'd like to 02:15:18.640 |
follow the hyperparameters that are mentioned in the GPT-2 or GPT-3 paper. Now, sadly, GPT-2 02:15:25.360 |
doesn't actually say too much. It's very nice of them that they released the model weights and the 02:15:31.120 |
code, but the paper itself is extremely vague as to the optimization details. The code itself that 02:15:36.320 |
they released as well, the code we've been looking at, this is just the inference code. So there's no 02:15:41.520 |
training code here and very few hyperparameters. So this doesn't also tell us too much. So for that, 02:15:46.480 |
we have to turn to the GPT-3 paper. And in the appendix of the GPT-3 paper, they have a lot 02:15:54.400 |
more hyperparameters here for us to use. And the GPT-3 paper in general is a lot more detailed as 02:16:00.240 |
to all the small details that go into the model training, but GPT-3 models were never released. 02:16:08.000 |
So GPT-2, we have the weights, but no details, and GPT-3, we have lots of details, but no weights. 02:16:12.880 |
But roughly speaking, GPT-2 and GPT-3 architectures are very, very similar. 02:16:18.560 |
And basically, there are very few changes. The context length was expanded from 1024 to 2048, 02:16:25.760 |
and that's kind of like the major change. And some of the hyperparameters around the 02:16:29.520 |
transformer have changed. But otherwise, they're pretty much the same model. It's 02:16:32.880 |
just that GPT-3 was trained for a lot longer on a bigger dataset and has a lot more thorough 02:16:37.920 |
evaluations. And the GPT-3 model is 175 billion instead of 1.6 billion in the GPT-2. 02:16:47.440 |
So long story short, we're going to go to GPT-3 paper to follow along some of the hyperparameters. 02:16:52.160 |
So to train all the versions of GPT-3, we use Adam with beta1, beta2 of 0.9 and 0.95. 02:16:59.440 |
So let's swing over here and make sure that the betas parameter, which you can see here defaults 02:17:04.880 |
to 0.9 and 0.999, is actually set to 0.9 and 0.95. And then the epsilon parameter, 02:17:13.680 |
you can see the default is 1e-8, and this is also 1e-8. Let's just 02:17:19.680 |
put it in so that we're explicit. Now, next up, they say we clip the global norm of the gradient 02:17:27.120 |
at 1.0. So what this is referring to is that once we calculate the gradients right after loss dot 02:17:33.120 |
backward, we basically have the gradients at all the parameter tensors. And what people like to do 02:17:40.160 |
is basically clip them to have some kind of a maximum norm. So in PyTorch, this is fairly easy 02:17:46.160 |
to do. It's one line of code here that we have to insert right after we calculate the gradients. 02:17:51.440 |
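(Putting these two changes together, a rough sketch of the training loop as we have it so far, leaving out the autocast context from earlier for brevity:)

```python
import torch

# GPT-3-style Adam hyperparameters: betas of (0.9, 0.95) and eps of 1e-8.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

# ... inside the training loop ...
optimizer.zero_grad()
logits, loss = model(x, y)
loss.backward()
# Clip the global norm of all parameter gradients to 1.0, right after backward.
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
```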
And what this utility function is doing is it's calculating the global norm of the gradients. 02:17:59.040 |
So every single gradient on all the parameters, you square it and you add it all up and you take 02:18:05.120 |
a big square root of that. And that's the norm of the gradient vector, basically. It's the 02:18:12.800 |
length of it, if you'd like to look at it that way. And we are basically making sure that 02:18:17.040 |
its length is no more than 1.0. And we're going to clip it. And the reason that people like to 02:18:22.880 |
use this is that sometimes you can get unlucky during the optimization. Maybe it's a bad data 02:18:27.760 |
batch or something like that. And if you get very unlucky in the batch, you might get really high 02:18:32.560 |
loss and really high loss could lead to a really high gradient. And this could basically shock your 02:18:38.880 |
model and shock the optimization. So people like to use a gradient norm clipping to prevent the 02:18:45.680 |
model from basically getting too big of shocks in terms of the gradient magnitude, and to upper- 02:18:52.880 |
bound it in this way. It's a bit of a hacky solution. It's like a patch on top of 02:18:57.840 |
deeper issues, but people still do it fairly frequently. Now, the clip grad norm returns 02:19:04.400 |
the norm of the gradient, which I like to always visualize because it is useful information. And 02:19:11.440 |
sometimes you can look at the norm of the gradient and if it's well behaved, things are good. If it's 02:19:16.560 |
climbing, things are bad and they're destabilizing during training. Sometimes you could get a spike 02:19:21.040 |
in the norm, and that means there's some kind of an issue or instability. So the norm here will be 02:19:27.360 |
a variable called norm. And let's print it with a .4f format or something like that. And I believe this is just a float. 02:19:36.720 |
And so we should be able to print that. So that's global gradient clipping. 02:19:43.840 |
Now they go into the details of the learning rate scheduler. So they don't just use a fixed 02:19:50.560 |
learning rate like we do here for 3E negative 4, but there's actually basically a cosine decay 02:19:56.640 |
learning rate schedule. It's got a warmup and it's got a cosine decay to 10% over some horizon. 02:20:04.800 |
And so we're going to implement this in a second. I just like to see the norm printed here. Okay, 02:20:12.720 |
there we go. So what happened here is the norm is actually really high in the beginning, 02:20:17.760 |
30 or so. And you see that as we continue training, it kind of like stabilizes 02:20:24.800 |
at values below one. And this is not that crazy uncommon for the norm to be high in the very first 02:20:31.440 |
few stages. Basically what's happening here is the model is completely random. And so there's 02:20:35.520 |
a ton of learning happening very early in the network, but that learning is kind of like, 02:20:39.520 |
you know, it's mostly learning the biases of the output tokens. And so it's a bit of an unstable 02:20:45.920 |
time, but the network usually stabilizes in the very few iterations. So this looks relatively 02:20:50.880 |
reasonable to me, except it looks a little bit funky to me that we go from 02:20:55.360 |
28 to 6 to 2 and then to 10. It's not completely insane, but it's just kind of a little bit funky. 02:21:02.080 |
Okay, so let's now get to the learning rate scheduler. So the learning rate schedule 02:21:07.760 |
that's used here in GPT-3 is what's called a cosine decay learning schedule with warmup. 02:21:14.240 |
And the way this looks is that the learning rate is basically starts right at around zero, 02:21:19.840 |
linearly ramps up over some amount of time, and then comes down with this cosine sort of form and 02:21:27.120 |
comes down to some kind of a minimum learning rate that's up to you. So here the minimum learning 02:21:30.720 |
rate is zero. But here in the paper, they said that they use cosine decay for learning rate down 02:21:36.640 |
to 10% of its value over the first 260 billion tokens. And then training continues at 10% after. 02:21:44.800 |
And there's a linear warmup over the first 375 million tokens. So that's about the learning 02:21:50.560 |
rate. So let's now implement this. So I already implemented it here. And the way this works is, 02:21:56.560 |
let me scroll down first here. I changed our training loop a little bit. So this was a 4i 02:22:02.400 |
in max steps. I just change it to step now so that we have the notion of a step as a single 02:22:07.360 |
optimization step in the for loop. And then here, I get the LR for this step of the optimization 02:22:14.960 |
using a new function I call getLR. And then in PyTorch to set the learning rate, I think this 02:22:20.400 |
is the way to set the learning rate. It's a little bit gnarly, because there's 02:22:25.120 |
a notion of different parameter groups that could exist in the optimizer. And so you actually have 02:22:29.200 |
to iterate over them, even though we currently have a single param group only. And you have 02:22:34.080 |
to set the LR in this for loop kind of style, is my impression right now. So we have this local 02:22:40.400 |
LR, we set the learning rate, and then on the bottom, I'm also printing it. So that's all the 02:22:46.240 |
changes I made to this loop. And then of course, the getLR is my scheduler. Now it's worth pointing 02:22:51.200 |
out that PyTorch actually has learning rate schedulers, and you can use them. And I believe 02:22:55.680 |
there's a cosine learning rate schedule in PyTorch. I just don't really love using that 02:23:00.880 |
code, because honestly, it's like five lines of code. And I fully understand what's happening 02:23:06.720 |
inside these lines. So I don't love to use abstractions where they're kind of inscrutable, 02:23:12.160 |
and then I don't know what they're doing. So personal style. So the max learning rate here 02:23:17.360 |
is let's say 3e negative 4. But we're going to see that in GPT-3 here, they have a table of what the 02:23:25.280 |
maximum learning rate is for every model size. So for this one, basically 12 layer 768 GPT-3. 02:23:36.560 |
So the GPT-3 small is roughly like a GPT-2 124M. We see that here they use a learning rate of 6e 02:23:43.280 |
negative 4. So we could actually go higher. In fact, we may want to try to follow that and just 02:23:48.000 |
set the max LR here to 6e-4. Then that's the maximum learning rate. The min learning rate is 02:23:54.400 |
10% of that per description in the paper, some number of steps that we're going to warm up over, 02:24:01.600 |
and then the maximum steps of the optimization, which I now use also in the for loop down here. 02:24:06.080 |
And then you can go over this code if you like. It's not terribly insightful or interesting. 02:24:12.240 |
I'm just modulating based on the iteration number which learning rate there should be. 02:24:17.600 |
So this is the warmup region. This is the region after the optimization. And then this is the 02:24:24.240 |
region sort of in between. And this is where I calculate the cosine learning rate schedule. 02:24:28.960 |
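(Here is roughly what that getLR function and the loop change look like; the warmup_steps and max_steps values here are just small illustrative ones, and step and optimizer come from the surrounding loop:)

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1   # decay down to 10% of the max, per the GPT-3 paper
warmup_steps = 10
max_steps = 50

def get_lr(step):
    # 1) linear warmup for the first warmup_steps steps
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) after max_steps, hold at the minimum learning rate
    if step > max_steps:
        return min_lr
    # 3) in between, cosine decay from max_lr down to min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)

# In the training loop: set this step's learning rate on every param group.
lr = get_lr(step)
for param_group in optimizer.param_groups:
    param_group["lr"] = lr
```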
And you can step through this in detail if you'd like. But this is basically implementing this 02:24:33.040 |
curve. And I ran this already. And this is what that looks like. So when we now run, we start at 02:24:45.600 |
some very low number. Now, note that we don't start exactly at zero because that would be not 02:24:49.760 |
useful to update with a learning rate of zero. That's why there's an it plus one, so that on 02:24:54.320 |
the zeroth iteration, we are not using exactly zero. We're using something very, very low. 02:24:58.640 |
Then we linearly warm up to maximum learning rate, which in this case was 3e negative 4 when I ran 02:25:04.320 |
it. But now it would be 6e negative 4. And then it starts to decay all the way down to 3e negative 5, 02:25:12.640 |
which was at the time 10% of the original learning rate. Now, one thing we are not 02:25:16.880 |
following exactly is that they mentioned that -- let me see if I can find it again. 02:25:22.640 |
We're not exactly following what they did because 02:25:26.960 |
they mentioned that their training horizon is 300 billion tokens. And they come down to 10% 02:25:34.400 |
of the initial learning rate at 260 billion. And then they train after 260 with 10%. So basically, 02:25:41.440 |
their decay time is less than the max steps time, whereas for us, they're exactly equal. 02:25:46.560 |
So it's not exactly faithful, but it's an okay -- this is okay for us and for our purposes right 02:25:54.240 |
now. And we're just going to use this ourselves. I don't think it makes too big of a difference, 02:26:00.400 |
honestly. I should point out that what learning rate schedule you use is totally up to you. 02:26:05.440 |
There's many different types. Cosine learning rate has been popularized a lot by GPT-2 and 02:26:10.640 |
GPT-3, but people have come up with all kinds of other learning rate schedules. And this is 02:26:15.600 |
kind of like an active area of research as to which one is the most effective at training 02:26:20.320 |
these networks. Okay, next up, the paper talks about the gradual batch size increase. So there's 02:26:27.040 |
a ramp on the batch size that is linear. And you start with very small batch size, and you ramp up 02:26:32.400 |
to a big batch size over time. We're going to actually skip this, and we're not going to work 02:26:36.880 |
with it. And the reason I don't love to use it is that it complicates a lot of the arithmetic, 02:26:41.760 |
because you are changing the number of tokens that you're processing at every single step of 02:26:45.200 |
the optimization. And I like to keep that math very, very simple. Also, my understanding is 02:26:50.080 |
that this is not a major improvement. And also, my understanding is that this is not an algorithmic 02:26:56.560 |
optimization improvement. It's more of a systems and speed improvement. And roughly speaking, 02:27:01.360 |
this is because in the early stages of the optimization, again, the model is in a very 02:27:08.080 |
atypical setting. And what you're mostly learning is to ignore the 02:27:14.000 |
tokens that don't come up in your training set very often. You're learning very simple biases and 02:27:19.280 |
that kind of a thing. And so every single example that you put through your network 02:27:25.520 |
is basically just telling you, use these tokens and don't use these tokens. And so the gradients 02:27:30.080 |
from every single example are actually extremely highly correlated. They all look roughly the same 02:27:35.040 |
in the original parts of the optimization, because they're all just telling you that 02:27:39.040 |
these tokens don't appear and these tokens do appear. And so because the gradients are all 02:27:44.240 |
very similar, and they're highly correlated, then why are you doing batch sizes of millions, 02:27:49.280 |
when if you do a batch size of 32k, you're basically getting the exact same gradient 02:27:53.760 |
early on in the training. And then later in the optimization, once you've learned all the simple 02:27:59.040 |
stuff, that's where the actual work starts. And that's where the gradients become more 02:28:02.800 |
decorrelated per examples. And that's where they actually offer you sort of statistical power, 02:28:07.760 |
in some sense. So we're going to skip this just because it kind of complicates things. 02:28:13.200 |
And we're going to go to data are sampled without replacement during training. 02:28:19.520 |
So until an epoch boundary is reached. So without replacement means that they're not sampling from 02:28:25.600 |
some fixed pool, where they take a sequence, train on it, but then also return the sequence to the 02:28:32.480 |
pool; they are exhausting the pool. So when they draw a sequence, it's gone until the next 02:28:37.920 |
epoch of training. So we're already doing that because our data loader iterates over chunks of 02:28:44.560 |
data. So there's no replacement, they don't become eligible to be drawn again until the next epoch. 02:28:50.640 |
So we're basically already doing that. All models use a weight decay of point one to 02:28:57.280 |
provide a small amount of regularization. So let's implement a weight decay. And you see 02:29:02.400 |
here that I've already kind of made the changes. And in particular, instead of creating the 02:29:06.400 |
optimizer right here, I'm creating a new configure optimizers function inside the model. And I'm 02:29:13.680 |
passing in some of the hyper parameters instead. So let's look at the configure optimizers, 02:29:18.080 |
which is supposed to return the optimizer object. Okay, so it looks complicated, but it's actually 02:29:27.680 |
really simple. And it's just, we're just being very careful. And there's a few settings here 02:29:32.400 |
to go through. The most important thing with respect to this line is that you see there's 02:29:36.640 |
a weight decay parameter here. And I'm passing that into well, I'm passing that into something 02:29:44.960 |
called optim_groups that eventually ends up going into the AdamW optimizer. And the weight decay 02:29:50.720 |
that's by default used in AdamW here is 0.01. So it's 10 times lower than what's used in the GPT-3 02:29:58.160 |
paper here. So the weight decay basically ends up making its way into AdamW through the optimizer groups. 02:30:05.200 |
Now what else is going on here in this function? So the two things that are happening here that 02:30:09.600 |
are important is that I'm splitting up the parameters into those that should be weight 02:30:13.920 |
decayed and those that should not be weight decayed. So in particular, it is common to not 02:30:18.960 |
weight decay biases and any other sort of one-dimensional tensors. So the one-dimensional 02:30:25.840 |
tensors are in the node decay parameters. And these are also things like layer norm, 02:30:31.760 |
scales, and biases. It doesn't really make sense to weight decay those. You mostly want to weight 02:30:36.240 |
decay the weights that participate in matrix multiplications. And you want to potentially 02:30:41.920 |
weight decay the embeddings. And we've covered in a previous video why it makes sense to decay 02:30:47.760 |
the weights, because you can sort of think of it as a regularization. Because when you're pulling 02:30:52.000 |
down on all the weights, you're forcing the optimization to use more of the weights. And 02:30:57.200 |
you're not allowing any one of the weights individually to be way too large. You're 02:31:02.560 |
forcing the network to kind of distribute the work across more channels, because there's sort 02:31:07.600 |
of like a pull of gravity on the weights themselves. So that's why we are separating it 02:31:14.400 |
in those ways here. We're only decaying the embeddings and the matmul participating weights. 02:31:20.000 |
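(A rough sketch of what such a configure_optimizers method on the model class can look like; the fused check at the bottom is the extra optimization discussed just below:)

```python
import inspect
import torch

def configure_optimizers(self, weight_decay, learning_rate, device):
    # Split parameters: anything 2D or higher (matmul weights, embeddings) gets
    # weight decay; 1D tensors (biases, layernorm scales and biases) do not.
    params = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
    decay_params = [p for p in params.values() if p.dim() >= 2]
    nodecay_params = [p for p in params.values() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},   # e.g. 0.1
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    print(f"num decayed parameter tensors: {len(decay_params)}")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}")
    # Use the fused AdamW kernel when this PyTorch version supports it and we're on CUDA.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    extra_args = dict(fused=True) if (fused_available and "cuda" in device) else dict()
    return torch.optim.AdamW(optim_groups, lr=learning_rate,
                             betas=(0.9, 0.95), eps=1e-8, **extra_args)
```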
We're printing the number of parameters that we're decaying and not. Most of the parameters 02:31:25.120 |
will be decayed. And then one more thing that we're doing here is I'm doing another optimization here. 02:31:31.920 |
And previous versions of AdamW did not have this option, but later versions of PyTorch introduced it. 02:31:37.840 |
And that's why I'm guarding it with an inspect.signature, which is basically checking 02:31:42.320 |
if this fused kwarg is present inside AdamW. And then if it is present, I'm going to end up using 02:31:50.560 |
it and passing it in here. Because some earlier versions do not have the fused argument. So here's 02:31:57.520 |
AdamW's fused argument. It didn't use to exist and it was added later. And there's some docs here for 02:32:03.440 |
what's happening. And basically they say that by default, they do not use fused because it is 02:32:09.440 |
relatively new and we want to give it sufficient bake time. So by default, they don't use fused, 02:32:13.600 |
but fused is a lot faster when it is available and when you're running on CUDA. And what that 02:32:18.720 |
does is instead of iterating in a for loop over all the parameter tensors and updating them, 02:32:25.040 |
that would launch a lot of kernels, right? And so fused just means that all those kernels are 02:32:30.640 |
fused into a single kernel. You get rid of a lot of overhead, and you call a kernel a single time 02:32:36.160 |
on all the parameters to update them. And so it's just basically a kernel fusion for the 02:32:43.440 |
AdamW update instead of iterating over all the tensors. So that's the configure optimizers 02:32:49.680 |
function that I like to use. And we can rerun and we're not going to see any major differences from 02:32:55.120 |
what we saw before, but we are going to see some prints coming from here. So let's just take a look 02:33:00.400 |
at what they look like. So we see that the number of decayed tensors is 50, and it's most of the parameters, 02:33:07.440 |
and number of non-decay tensors is 98. And these are the biases and the layer norm parameters 02:33:11.840 |
mostly. And that's, there's only a hundred thousand of those. So most of it is decayed. 02:33:17.520 |
And then we are using the fused implementation of AdamW, which will be a lot faster. So if you 02:33:22.480 |
have it available, I would advise you to use it. I'm not actually a hundred percent sure why they 02:33:26.720 |
don't default to it. It seems fairly benign and harmless. And also because we are using the fused 02:33:31.920 |
implementation, I think this is why we have dropped, notice that the running time used to 02:33:38.000 |
be 93 milliseconds per step. And we're now down to 90 milliseconds per step because of using the 02:33:43.200 |
fused AdamW optimizer. So in a single commit here, we are introducing fused Adam, getting 02:33:50.240 |
improvements on the time, and we're adding or changing the weight decay, but we're only weight 02:33:56.080 |
decaying the two-dimensional parameters, the embeddings, and the matrices that participate 02:34:00.320 |
in the linear. So that is this, and we can take this out. And yeah, that is it for this line. 02:34:09.440 |
One more quick note before we continue here. I just want to point out that the relationship between 02:34:14.000 |
weight decay, learning rate, batch size, the Adam parameters, beta1, beta2, the epsilon, and so on, 02:34:19.920 |
these are very complicated mathematical relationships in the optimization literature. 02:34:25.120 |
And for the most part, in this video, I'm just trying to copy paste the settings that OpenAI 02:34:30.720 |
used. But this is a complicated topic, quite deep. And yeah, in this video, I just want to copy the 02:34:36.880 |
parameters because it's a whole different video to really talk about that in detail and give it 02:34:41.040 |
a proper justice instead of just high-level intuitions. Now, the next thing that I want 02:34:45.600 |
to move on to is that this paragraph here, by the way, we're going to turn back around to when we 02:34:51.360 |
improve our data loader. For now, I want to swing back around to this table, 02:35:02.000 |
where you will notice that for different models, we, of course, have different hyperparameters 02:35:07.760 |
for the transformer that dictate the size of the transformer network. We also have a different 02:35:12.160 |
learning rate. So we're seeing the pattern that the bigger networks are trained with slightly 02:35:15.600 |
lower learning rates. And we also see this batch size, where in the small networks, 02:35:21.840 |
they use a smaller batch size, and in the bigger networks, they use a bigger batch size. 02:35:26.240 |
Now, the problem for us is we can't just use 0.5 million batch size, 02:35:30.960 |
because if I just try to come in here and I try to set this B, where's my B? 02:35:41.840 |
Where do I call the data loader? Okay, B equals 16. If I try to set... 02:35:48.720 |
Well, we have to be careful. It's not 0.5 million, because this is the batch size 02:35:54.960 |
in the number of tokens. Every single one of our rows is 1,024 tokens. So 0.5E6, 02:36:01.840 |
0.5 million divided by 1,024. This would need a batch size of about 488. So the problem is I can't 02:36:09.440 |
come in here and set this to 488, because my GPU would explode. This would not fit for sure. 02:36:16.400 |
But we still want to use this batch size, because again, as I mentioned, the batch size is correlated 02:36:23.600 |
with all the other optimization hyperparameters and the learning rates and so on. So we want to 02:36:28.160 |
have a faithful representation of all the hyperparameters, and therefore we need to 02:36:31.760 |
use a batch size of 0.5 million, roughly. But the question is, how do we use 0.5 million if 02:36:38.720 |
we only have a small GPU? Well, for that, we need to use what's called gradient accumulation. 02:36:43.120 |
So we're going to turn to that next, and it allows us to simulate in a serial way 02:36:48.400 |
any arbitrary batch size that we set. And so we can do a batch size of 0.5 million. We just have 02:36:54.400 |
to run longer, and we have to process multiple sequences and basically add up all the gradients 02:37:00.720 |
from them to simulate a batch size of 0.5 million. So let's turn to that next. 02:37:05.440 |
Okay, so I started the implementation right here just by adding these lines of code. 02:37:08.800 |
And basically what I did is first I set the total batch size that we desire. So this is 02:37:14.880 |
exactly 0.5 million, and I used a nice number, a power of 2, because 2 to the 19 is 524288, 02:37:22.560 |
so it's roughly 0.5 million. It's a nice number. Now, our micro-batch size, as we call it now, 02:37:28.000 |
is 16. So this is going to be -- we still have B by T indices that go into the transformer and do 02:37:35.120 |
forward-backward, but we're not going to do an update, right? We're going to do many forward- 02:37:38.960 |
backwards. We're going to -- and those gradients are all going to plus equals on the parameter 02:37:43.920 |
gradients. They're all going to add up. So we're going to do forward-backward grad-accum-steps 02:37:48.960 |
number of times, and then we're going to do a single update once all that is accumulated. 02:37:53.200 |
So in particular, our micro-batch size is just now controlling how many tokens, 02:37:58.960 |
how many rows we're processing in a single go of a forward-backward. 02:38:01.840 |
So here we are doing 16 times 1,024. We're doing 16,384 02:38:11.440 |
tokens per forward-backward, and we are supposed to be doing 02:38:15.120 |
2 to the 19 -- whoops, what am I doing -- 2 to the 19 in total, so the grad-accum will be 32. 02:38:24.000 |
So therefore, grad-accum here will work out to 32, and we have to do 32 forward-backward 02:38:32.560 |
and then a single update. Now, we see that we have about 100 milliseconds for a single 02:38:38.320 |
forward-backward, so doing 32 of them will be -- will make every step roughly three seconds, 02:38:43.680 |
just napkin math. So that's grad-accum-steps, but now we actually have to implement that. 02:38:50.080 |
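As a sketch, that setup is just a few lines (variable names assumed, arithmetic as described):

```python
total_batch_size = 524288  # 2**19 tokens per optimization step, ~0.5M, matching the GPT-3 table
B = 16    # micro batch size (rows per forward/backward)
T = 1024  # sequence length
assert total_batch_size % (B * T) == 0, "total batch size should be divisible by B * T"
grad_accum_steps = total_batch_size // (B * T)  # 524288 / 16384 = 32
print(f"total desired batch size: {total_batch_size}")
print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
```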
So we're going to swing over to our training loop, because now this part here 02:38:55.920 |
and this part here, the forward and the backward, we have to now repeat this 02:39:01.200 |
32 times before we do everything else that follows. So let's see how we can implement that. 02:39:07.520 |
So let's come over here, and actually, we do have to load a new batch every single time, 02:39:11.520 |
so let me move that over here, and now this is where we have the inner loop. 02:39:15.280 |
So for micro-step in range grad-accum-steps, we do this. And remember that loss.backward 02:39:24.720 |
always deposits gradients, so we're doing -- inside loss.backward, there's always a 02:39:28.320 |
plus-equals on the gradients. So in every single loss.backward, gradients will add up on the 02:39:33.840 |
gradient tensors. So we call loss.backward, and then we get all the gradients over there, 02:39:41.120 |
and then we normalize, and everything else should just follow. So we're very close, 02:39:47.840 |
but actually, there's a subtle and deep issue here, and this is actually incorrect. 02:39:53.280 |
So I invite you to think about why this is not yet sufficient, and let me fix it then. 02:39:59.520 |
Okay, so I brought back the Jupyter Notebook, so we can think about this carefully 02:40:03.280 |
in a simple toy setting and see what's happening. So let's create a very simple 02:40:07.840 |
neural net that takes a 16 -- vector of 16 numbers and returns a single number. 02:40:12.320 |
And then here, I'm creating some random examples x and some targets y, and then we are using the 02:40:21.520 |
mean-squared loss here to calculate the loss. So basically, what this is, is four individual 02:40:28.640 |
examples, and we're just doing simple regression with the mean-squared loss over those four 02:40:33.600 |
examples. Now, when we calculate the loss and call loss.backward and look at the gradient, 02:40:39.360 |
this is the gradient that we achieve. Now, the loss objective here -- notice that in MSE loss, 02:40:45.600 |
the default for the loss function's reduction is mean. So we're calculating the average loss 02:40:52.800 |
here over the four examples. So this is the exact loss objective, and this is the average, 02:41:01.920 |
the 1/4, because there are four independent examples here. And then we have the four 02:41:07.280 |
examples and their mean-squared error -- the squared error, and then this makes it the 02:41:11.760 |
mean-squared error. So therefore, we calculate the squared error and then we normalize it to 02:41:18.240 |
make it the mean over the examples, and there's four examples here. So now, when we come to the 02:41:22.880 |
gradient accumulation version of it, this here is the gradient accumulation version of it, 02:41:31.200 |
where we have grad-accum steps of four, and I reset the gradient; with grad-accum steps of 02:41:36.320 |
four, I'm now evaluating all the examples individually instead and calling loss.backward 02:41:41.280 |
on them many times, and then we're looking at the gradient that we achieve from that. 02:41:44.640 |
So basically, now we forward our function, calculate the exact same loss, do a backward, 02:41:50.640 |
and we do that four times, and when we look at the gradient, you'll notice that the gradients 02:41:56.000 |
don't match. So here we did a single batch of four, and here we did four gradient accumulation 02:42:03.600 |
steps of batch size one, and the gradients are not the same. And basically, the reason that they're 02:42:09.600 |
not the same is exactly because this mean squared error gets lost. This one quarter in this loss 02:42:15.600 |
gets lost, because what happens here is the loss objective for every one of the loops 02:42:21.760 |
is just a mean squared error, which in this case, because there's only a single example, 02:42:26.800 |
is just this term here. So that was the loss in the zeroth iteration, the same in the first, 02:42:31.520 |
second, third, and so on. And then when you do loss.backward, we're accumulating gradients, 02:42:37.200 |
and what happens is that accumulation in the gradient is basically equivalent 02:42:42.160 |
to doing a sum in the loss. So our loss actually here is this without the factor of one quarter 02:42:51.840 |
outside of it. So we're missing the normalizer, and therefore our gradients are off. And so the 02:42:57.360 |
way to fix this, or one of them, is basically we can actually come here and we can say loss equals 02:43:02.480 |
loss divide four. And what happens now is that we're scaling our loss, we're introducing a one 02:43:10.720 |
quarter in front of all of these places. So all the individual losses are now scaled by one quarter, 02:43:18.160 |
and then when we backward, all of these accumulate with a sum, but now there's a one quarter inside 02:43:24.880 |
every one of these components, and now our losses will be equivalent. So when I run this, 02:43:31.680 |
you see that the gradients are now identical. So long story short, with this simple example, 02:43:37.600 |
when you step through it, you can see that basically the reason that this is not correct 02:43:42.160 |
is because in the same way as here in the MSE loss, the loss that we're calculating here 02:43:48.720 |
in the model is using a reduction of mean as well. So where is the loss? F dot cross entropy. 02:43:59.200 |
And by default, the reduction here in cross entropy is also, I don't know why they don't 02:44:04.000 |
show it, but it's the mean loss over all the B by T elements, right? So there's a reduction 02:44:12.880 |
by mean in there, and if we're just doing this gradient accumulation here, we're missing that. 02:44:17.040 |
And so the way to fix this is to simply compensate for the number of gradient accumulation steps, 02:44:22.480 |
and we can in the same way divide this loss. So in particular here, the number of steps that 02:44:26.720 |
we're doing is loss equals loss divided by gradient accumulation steps. So even Copilot gets the 02:44:34.800 |
modification. But in the same way exactly, we are scaling down the loss so that when we do 02:44:40.160 |
loss.backward, which basically corresponds to a sum in the objective, we are summing up 02:44:45.040 |
the already normalized loss. And therefore, when we sum up the losses divided by grad-accum steps, 02:44:52.160 |
we are recovering the additional normalizer. And so now these two will be, now this will be 02:44:58.960 |
equivalent to the original sort of optimization, because the gradient will come out the same. 02:45:03.680 |
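Here is a small, self-contained sketch of that toy experiment. The exact network and data in the notebook may differ, but the effect is the same:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
# a tiny regression net: 16 inputs -> 1 output
net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 1))
x = torch.randn(4, 16)  # 4 examples
y = torch.randn(4, 1)   # 4 targets

# version 1: a single batch of 4, MSE loss with the default reduction='mean'
net.zero_grad()
F.mse_loss(net(x), y).backward()
grad_full_batch = net[0].weight.grad.clone()

# version 2: gradient accumulation over 4 micro steps of batch size 1
net.zero_grad()
for i in range(4):
    loss = F.mse_loss(net(x[i:i+1]), y[i:i+1])
    loss = loss / 4  # <-- without this 1/4, the accumulated gradient comes out 4x too large
    loss.backward()
grad_accum = net[0].weight.grad.clone()

print(torch.allclose(grad_full_batch, grad_accum))  # True, up to floating point error
```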
Okay, so I had to do a few more touch-ups, and I launched the optimization here. So in particular, 02:45:09.840 |
one thing we want to do, because we want to print things nicely, is, well, first of all, we need to 02:45:14.880 |
create like an accumulator over the loss. We can't just print the loss, because we'd be printing only 02:45:18.960 |
the final loss at the final microstep. So instead, we have loss-accum, which I initialized at zero, 02:45:24.800 |
and then I accumulate the loss into it. And I'm using detach so that I'm detaching the tensor 02:45:32.560 |
from the graph, and I'm just trying to keep track of the values. So I'm making these leaf nodes when 02:45:38.720 |
I add them. So that's loss-accum, and then we're printing that here instead of loss. 02:45:43.840 |
And then in addition to that, I had to account for the grad-accum steps inside the tokens processed, 02:45:49.120 |
because now the tokens processed per step is b times t times gradient accumulation. 02:45:53.920 |
So long story short, here we have the optimization. It looks reasonable, right? We're starting at a 02:46:00.800 |
good spot. We calculated the grad-accum steps to be 32, and we're getting about three seconds here, 02:46:07.280 |
right? And so this looks pretty good. Now, if you'd like to verify that your optimization and 02:46:17.440 |
the implementation here is correct and you're working on a side, well, now because we have 02:46:21.120 |
the total batch size and the gradient accumulation steps, our setting of b is purely a performance 02:46:26.640 |
optimization kind of setting. So if you have a big GPU, you can actually increase this to 32, 02:46:32.080 |
and you'll probably go a bit faster. If you have a very small GPU, you can try 8 or 4. 02:46:36.800 |
But in any case, you should be getting the exact same optimization and the same answers 02:46:40.720 |
up to a floating point error, because the gradient accumulation kicks in and can handle everything 02:46:47.360 |
serially as necessary. So that's it for gradient accumulation, I think. 02:46:52.880 |
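Putting it together, the inner training loop now looks roughly like this. This is a sketch that assumes the objects defined earlier in the video (model, optimizer, train_loader, get_lr, the bfloat16 autocast, and the gradient clipping from previous sections):

```python
import time
import torch

for step in range(max_steps):
    t0 = time.time()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # cross entropy reduces with 'mean', so scale the loss to account
        # for the implicit sum that gradient accumulation performs
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()  # running total, for printing only
        loss.backward()              # gradients += on every micro step
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    optimizer.step()
    torch.cuda.synchronize()
    dt = time.time() - t0
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps
    print(f"step {step} | loss: {loss_accum.item():.6f} | norm: {norm:.4f} | "
          f"dt: {dt*1000:.2f}ms | tok/sec: {tokens_processed/dt:.2f}")
```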
Okay, so now is the time to bring out the heavy weapons. You've noticed that so far, 02:46:57.040 |
we've only been using a single GPU for training. But actually, I am paying for eight GPUs here, 02:47:02.560 |
and so we should be putting all of them to work. And in particular, they're all going to collaborate 02:47:07.360 |
and optimize over tokens at the same time and communicate so that they're all kind of 02:47:15.520 |
collaborating on the optimization. For this, we are going to be using the distributed data parallel 02:47:20.160 |
from PyTorch. There's also a legacy data parallel, which I recommend you not use, 02:47:24.320 |
and that's kind of like legacy. Distributed data parallel works in a very simple way. 02:47:29.680 |
We have eight GPUs, so we're going to launch eight processes, and each process is going to be 02:47:36.240 |
assigned a GPU. And for each process, the training loop and everything we've worked on so far is 02:47:41.520 |
going to look pretty much the same. Each GPU, as far as it's concerned, is just working on exactly 02:47:46.480 |
what we've built so far. But now secretly, there's eight of them, and they're all going to be 02:47:51.280 |
processing slightly different parts of the data. And we're going to add one more part, where once 02:47:57.440 |
they all calculate their gradients, there's one more part where we do an average of those gradients. 02:48:02.720 |
And so that's how they're going to be collaborating on the computational workload here. 02:48:08.720 |
So to use all eight of them, we're not going to be launching our script anymore with just 02:48:14.000 |
PyTorch-train-gpt2.py. We're going to be running it with a special command called "torch run" 02:48:20.720 |
in PyTorch. We'll see that in a bit. And torch run, when it runs our Python script, 02:48:26.720 |
will actually make sure to run eight of them in parallel. And it creates these environmental 02:48:33.120 |
variables where each of these processes can look up basically which one of the processes it is. 02:48:40.400 |
So for example, torch run will set rank, local rank, and world size, environmental variables. 02:48:46.320 |
And so this is a bad way to detect whether DDP is running. So if we're using torch run, 02:48:52.960 |
if DDP is running, then we have to make sure that CUDA is available, 02:48:57.920 |
because I don't know that you can run this on CPU anymore, or that that makes sense to do. 02:49:02.320 |
This is some setup code here. The important part is that there's a world size, which for us will 02:49:10.800 |
be eight. That's the total number of processes running. There's a rank, which is each process 02:49:16.960 |
will basically run the exact same code at the exact same time, roughly. But the only difference 02:49:23.920 |
between these processes is that they all have a different DDP rank. So the GPU 0 will have DDP 02:49:30.960 |
rank of 0, GPU 1 will have rank of 1, etc. So otherwise, they're all running the exact same 02:49:37.440 |
script. It's just that DDP rank will be a slightly different integer. And that is the way for us to 02:49:42.800 |
coordinate that they don't, for example, run on the same data. We want them to run on different 02:49:47.360 |
parts of the data, and so on. Now, local rank is something that is only used in a multi-node 02:49:53.600 |
setting. We only have a single node with eight GPUs. And so local rank is the rank of the GPU 02:49:59.920 |
on a single node. So from 0 to 7, as an example. But for us, we're mostly going to be running on 02:50:06.640 |
a single box. So the things we care about are rank and world size. This is 8, and this will be 02:50:12.400 |
whatever it is, depending on the GPU, that this particular instantiation of the script runs on. 02:50:18.400 |
Now, here, we make sure that according to the local rank, we are setting the device 02:50:27.200 |
to be cuda: followed by the local rank, and the number after the colon indicates which GPU to use if there is more than one GPU. So 02:50:34.640 |
depending on the local rank of this process, it's going to use just the appropriate GPU. So there's 02:50:40.640 |
no collisions on which GPU is being used by which process. And finally, there's a boolean variable 02:50:45.840 |
that I like to create, which is the DDP rank equals equals zero. So the master process is 02:50:51.840 |
arbitrarily process number zero, and it does a lot of the printing, logging, checkpointing, etc. 02:50:57.040 |
And the other processes are thought of mostly as compute processes that are assisting. 02:51:01.280 |
And so master process zero will have some additional work to do. All the other processes 02:51:05.680 |
will almost just be doing forward and backwards. And if we're not using DDP, and none of these 02:51:10.640 |
variables are set, we revert back to single GPU training. So that means that we only have rank 02:51:15.440 |
zero, the world size is just one. And we are the master process. And we try to auto detect the 02:51:22.400 |
device. And this is all as normal. So, so far, all we've done is we've initialized DDP. 02:51:28.480 |
And in the case where we're running with Torch run, which we'll see in a bit, 02:51:34.000 |
there's going to be eight copies running in parallel, each one of them will have a different 02:51:37.520 |
rank. And now we have to make sure that everything happens correctly afterwards. 02:51:42.880 |
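For reference, a sketch of that setup block, assuming a launch command along the lines of torchrun --standalone --nproc_per_node=8 train_gpt2.py (the exact variable names here are my own):

```python
import os
import torch
from torch.distributed import init_process_group

# torchrun sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables
ddp = int(os.environ.get("RANK", -1)) != -1  # is this a DDP run?
if ddp:
    assert torch.cuda.is_available(), "this DDP path assumes CUDA"
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])              # unique id across all processes
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this one does the logging, checkpointing, etc.
else:
    # vanilla, non-DDP run: single process, auto-detect the device
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
```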
So the tricky thing with running multiple processes is you always have to imagine that there's going 02:51:48.400 |
to be eight processes running in parallel. So as you read the code, now you have to imagine there's 02:51:54.240 |
eight, you know, eight Python interpreters running down these lines of code. And the only difference 02:52:00.160 |
between them is that they have a different DDP rank. So they all come here, they all pick the 02:52:05.120 |
exact same seed, they all make all of these calculations, completely unaware of the other 02:52:10.160 |
copies running, roughly speaking, right. So they all make the exact same calculations. And now we 02:52:15.680 |
have to adjust these calculations to take into account that there's actually like a certain 02:52:20.960 |
world size and certain ranks. So in particular, these micro batches and sequence lengths, these are 02:52:27.280 |
all just per GPU, right. So now there's going to be num processes of them running in parallel. 02:52:33.600 |
So we have to adjust this, right, because the grad accum steps now is going to be total batch 02:52:39.040 |
size divided by B times T times DDP world size, because each process will do B times T, and 02:52:49.600 |
there's this many of them. And so in addition to that, we want to make sure that this fits nicely 02:52:56.800 |
into total batch size, which for us it will, because 16 times 1,024 times eight GPUs is 131K. 02:53:04.480 |
And so with 524,288, this means that our grad accum will be four with the current settings, right. So 02:53:13.680 |
there's going to be 16 times 1,024 tokens in each GPU, and then there's eight GPUs. So we're 02:53:18.960 |
going to be doing 131,000 tokens in a single forward backward on the eight GPUs. 02:53:26.240 |
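Concretely, the adjustment is just this (a sketch; the master_process guard on the prints is what comes up a little further below):

```python
total_batch_size = 524288  # 2**19 tokens per optimization step
B, T = 16, 1024
assert total_batch_size % (B * T * ddp_world_size) == 0, \
    "total batch size should be divisible by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)  # 524288 / (16*1024*8) = 4
if master_process:
    print(f"total desired batch size: {total_batch_size}")
    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
```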
So we want to make sure that this fits nicely so that we can derive a nice gradient accumulation 02:53:31.520 |
steps. And yeah, let's just adjust the comments here times DDP world size. Okay. So each GPU 02:53:43.600 |
calculates this. Now this is where we start to get run into issues, right. So we are each process 02:53:49.280 |
going to come by a print, and they're all going to print. So we're going to have eight copies 02:53:54.240 |
of these prints. So one way to deal with this is exactly this master process variable that we have. 02:53:59.360 |
So if master process, then guard this. And that's just so that we just print this a single time, 02:54:05.760 |
because otherwise, all the processes would have computed the exact same variables. And there's 02:54:09.600 |
no need to print this eight times. Before getting into the data loader, and we're going to have to 02:54:15.520 |
refactor it, obviously, maybe at this point is we should do some prints and just take it out for a 02:54:22.560 |
spin and exit at this point. So import sys and sys.exit, and print "I am GPU", DDP rank, 02:54:40.960 |
and print "bye". So now let's try to run this and just see how this works. So let's take it for a 02:54:51.600 |
spin just so we see what it looks like. So normally we used to launch Python train gpt2.py like this. 02:54:57.440 |
Now we're going to run with Torch run, and this is what it looks like. So Torch run standalone, 02:55:02.400 |
number of processes, for example, is eight for us because we have eight GPUs, 02:55:05.840 |
and then train gpt2.py. So this is what the command would look like. And Torch run again, 02:55:12.480 |
we'll run eight of these. So let's just see what happens. So first, it gets a little busy. So 02:55:19.600 |
there's a lot going on here. So first of all, there's some warnings from distributed, and I 02:55:24.000 |
don't actually know that these mean anything. I think this is just like, the code is setting up 02:55:28.640 |
and the processes are coming online. And we're seeing some preliminary failure to collect while 02:55:33.440 |
the processes come up. I'm not 100% sure about that. But we start to then get into actual prints. 02:55:39.920 |
So all the processes went down. And then the first print actually comes from process five, 02:55:48.720 |
just by chance. And then it printed. So process five basically got here first, 02:55:53.360 |
it printed "I am GPU 5" and "bye", and then these prints come from the master process. 02:56:00.800 |
So process five just finished first, for whatever reason, it just depends on how 02:56:05.760 |
the operating system scheduled the processes to run. Then GPU zero ended, then GPU three and two. 02:56:11.600 |
And then probably process five or something like that has exited. 02:56:18.080 |
And DDP really doesn't like that, because we didn't properly dispose of the multi GPUs setting. 02:56:26.240 |
And so process group has not been destroyed before we destruct. So it really doesn't like that. And 02:56:32.640 |
in an actual application, we would want to call destroy process group, so that we clean up DDP 02:56:38.320 |
properly. And so it doesn't like that too much. And then the rest of the GPUs finish. And that's 02:56:44.320 |
it. So basically, we can't guarantee when these processes are running is totally arbitrary, 02:56:48.640 |
but they are running in parallel, we don't want that to be printing. And next up, let's erase this. 02:56:56.080 |
Next up, we want to make sure that when we create data loader light, we need to now make it aware 02:57:02.160 |
of this multi process setting, because we don't want all the processes to be loading the exact 02:57:08.800 |
same data, we want every process to get its own chunk of data, so that they're all working on 02:57:13.360 |
different parts of the data set, of course. So let's adjust that. So one particularly simple 02:57:19.040 |
and a naive way to do this is we have to make sure that we pass in the rank and the size 02:57:23.840 |
to the data loader. And then we come up here, we see that we now take rank and processes and we 02:57:30.160 |
save them. Now, the current position will not be zero. Because what we want is we want to stride 02:57:37.200 |
out all the processes. So one way to do this is we basically take self.b times self.t, and then 02:57:44.080 |
multiply it by the process rank. So process rank zero will start at zero, but process rank one 02:57:51.520 |
now starts at B times T, process rank two starts at two times B times T, etc. So that is the 02:57:58.080 |
initialization. Now, they still do this identically. But now when we advance, we don't 02:58:05.200 |
advance by b times t, we advance by b times t times number of processes. Right? So basically, 02:58:12.400 |
the total number of tokens that we're consuming is b times t times numProcesses. And they all go 02:58:20.240 |
off to a different rank. And the position has to advance by the entire chunk. And then here, if B 02:58:28.240 |
times T times self.num_processes, plus one, would exceed the number of tokens, then we're going to 02:58:35.120 |
loop. And when we loop, we want to of course, loop in the exact same way. So we sort of like reset 02:58:40.800 |
back. So this is the simplest change that I can find for kind of a very simple distributed data 02:58:47.920 |
loader lite. And you can notice that if process rank is zero, and num_processes is one, then the 02:58:54.320 |
whole thing will be identical to what we had before. But now we can have actually multiple 02:58:58.080 |
processes running and this should work fine. So that's the data loader. 02:59:06.320 |
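Here is a sketch of the multi-process data loader lite just described, assuming the tiny Shakespeare text file (input.txt) and the tiktoken tokenizer from earlier in the video:

```python
import tiktoken
import torch

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes):
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        # load and tokenize tiny shakespeare (still the dataset at this point)
        with open("input.txt", "r") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(enc.encode(text))
        # stride out the processes: each rank starts at a different offset
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets (inputs shifted by one)
        # advance by the chunk consumed by *all* processes
        self.current_position += B * T * self.num_processes
        # if the next chunk would run off the end, loop back to this rank's start
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_position = self.B * self.T * self.process_rank
        return x, y
```

Okay, so next up,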
once they've all initialized the data loader, they come here and they all create a GPT model. 02:59:11.280 |
So we create eight GPT models on eight processes. But because the seeds are fixed here, 02:59:17.360 |
they all create the same identical model. They all move it to the device of their rank, 02:59:22.560 |
and they all compile the model. And because the models are identical, there are eight identical 02:59:27.200 |
compilations happening in parallel, but that's okay. Now, none of this changes because that is 02:59:33.040 |
on a per step basis. And we're currently working kind of within step because we need to just all 02:59:39.520 |
the all the changes we're making are kind of like a within step changes. Now, the important thing 02:59:44.400 |
here is when we construct the model, we actually have a bit of work to do here. GetLogits is 02:59:49.680 |
deprecated. So create model. We need to actually wrap the model into the distributed data parallel 02:59:57.600 |
container. So this is how we wrap the model into the DDP container. And these are the docs for 03:00:05.440 |
DDP. And they're quite extensive. And there's a lot of caveats and a lot of things to be careful 03:00:10.000 |
with because everything complexifies times 10 when multiple processes are involved. But roughly 03:00:15.920 |
speaking, this device IDs I believe has to be passed in. Now, unfortunately, the docs for what 03:00:20.320 |
device IDs is, is extremely unclear. So when you actually like come here, this comment for what 03:00:27.360 |
device IDs is, is roughly nonsensical. But I'm pretty sure it's supposed to be the DDP local 03:00:34.400 |
rank. So not the DDP rank, the local rank. So this is what you pass in here. This wraps the model. 03:00:41.840 |
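As a sketch, the wrapping looks like this (GPT and GPTConfig are the classes built earlier in the video; the raw_model line comes up again later, when we need to reach the underlying module):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

model = GPT(GPTConfig(vocab_size=50304))  # the "nice number" vocab size from earlier
model.to(device)
model = torch.compile(model)
if ddp:
    # device_ids takes the *local* rank of this process's GPU
    model = DDP(model, device_ids=[ddp_local_rank])
# the underlying nn.Module; DDP stores it in .module
raw_model = model.module if ddp else model
```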
And in particular, what DDP does for you is in a forward pass, it actually behaves identically. So 03:00:46.880 |
my understanding of it is nothing should be changed in the forward pass. 03:00:51.120 |
But in the backward pass, as you are doing the backward pass, in the simplest setting, 03:00:56.720 |
once the backward pass is over on each independent GPU, each independent GPU has the gradient for 03:01:03.200 |
all the parameters. And what DDP does for you is once the backward pass is over, it will call 03:01:09.040 |
what's called all reduce. And it basically does an average across all the ranks of their gradients. 03:01:16.720 |
And then it will deposit that average on every single rank. So every single rank will end up 03:01:23.040 |
with the average on it. And so basically, that's the communication, it just synchronizes and 03:01:27.920 |
averages the gradients. And that's what DDP offers you. Now, DDP actually is a little bit more, 03:01:32.320 |
is a little bit more involved than that, because as you are doing the backward pass through the 03:01:38.240 |
layers of the transformer, it actually can dispatch communications for the gradient while 03:01:43.360 |
the backward pass is still happening. So there's overlap of the communication of the gradients and 03:01:48.160 |
the synchronization of them and the backward pass. And this is just more efficient to do it that way. 03:01:55.680 |
So that's what DDP does for you. Forward is unchanged, and backward is mostly unchanged. 03:02:02.000 |
And we're tacking on this average, as we'll see in a bit. Okay, so now let's go to the optimization. 03:02:09.120 |
Nothing here changes. Let's go to the optimization here, the inner loop, and think through the 03:02:13.680 |
synchronization of these gradients in the DDP. So basically, by default, what happens, as I 03:02:18.800 |
mentioned, is when you do loss dot backward here, it will do the backward pass, and then it will 03:02:23.760 |
synchronize the gradients. The problem here is because of the gradient accumulation steps loop 03:02:30.320 |
here, we don't actually want to do the synchronization after every single loss dot backward, 03:02:36.560 |
because we are just depositing gradients, and we're doing that serially, and we just want them 03:02:40.800 |
adding up, and we don't want to synchronize every single time. That would be extremely wasteful. 03:02:44.960 |
So basically, we want to add them up, and then on the very last — it's only on the very last step, 03:02:50.880 |
when microstep becomes grad-accum steps minus one, only at that last step do we want to actually do 03:02:57.840 |
the all-reduce to average up the gradients. So to do that, we come here, and the official 03:03:05.440 |
sanctioned way, by the way, is to do this no-sync context manager. So PyTorch says this is a context 03:03:12.560 |
manager to disable gradient synchronization across DDP processes. So within this context, 03:03:17.280 |
gradients will be accumulated, and basically, when you do no-sync, there will be no communication. 03:03:23.760 |
So they are telling us to do, with DDP no-sync, do the gradient accumulation, accumulate grads, 03:03:30.080 |
and then they are asking us to do DDP again with another input and dot backward. And I just really 03:03:35.600 |
don't love this. I just really don't like it, the fact that you have to copy-paste your code here 03:03:40.480 |
and use a context manager, and this is just super ugly. So when I went to the source code here, 03:03:44.720 |
you can see that when you enter, you simply toggle this variable, 03:03:50.880 |
this require backward grad sync, and this is being toggled around and changed. And this is the 03:03:58.560 |
variable that basically, if you step through it, is being toggled to determine if the gradient is 03:04:05.200 |
going to be synchronized. So I actually just kind of like to use that directly. So instead, what I 03:04:10.960 |
like to do is the following. Right here, before the loss.backward, if we are using DDP, then 03:04:19.680 |
we only want to synchronize, we only want this variable to be true when it is the final iteration. 03:04:29.600 |
In all the other iterations inside the microsteps, we want it to be false. So I just toggle it like 03:04:35.440 |
this. So require backward grad sync should only turn on when the microstep is the last step. 03:04:40.720 |
And so I'm toggling this variable directly, and I hope that that impacts loss.backward, 03:04:48.080 |
and this is a naughty thing to do because they could probably change the DDP and this variable 03:04:52.640 |
will go away. But for now, I believe this works, and it allows me to avoid the use of context 03:04:58.160 |
managers and code duplication. I'm just toggling the variable, and then loss.backward will not 03:05:02.560 |
synchronize most of the steps, and it will synchronize the very last step. And so once 03:05:06.480 |
this is over, and we come out, every single rank will suddenly magically have the average 03:05:17.120 |
of all the gradients that were stored on all the ranks. So now we have to think through whether 03:05:22.800 |
that is what we want, and also if this suffices, and how it works with the loss, 03:05:30.160 |
and what is loss_accum. So let's think through that now. And the problem I'm getting at is that 03:05:35.440 |
we've averaged the gradients, which is great, but the loss_accum has not been impacted yet, 03:05:41.200 |
and this is outside of the DDP container, so that is not being averaged. And so here, 03:05:47.360 |
when we are printing loss_accum, well, presumably we're only going to be printing on the master 03:05:51.680 |
process, rank 0, and it's just going to be printing the losses that it saw on its process. 03:05:56.480 |
But instead, we want it to print the loss over all the processes and the average of that loss, 03:06:02.240 |
because we did average of gradients, so we want the average of loss as well. 03:06:06.240 |
So simply here, after this, this is the code that I've used in the past, 03:06:11.200 |
and instead of loss_eff, we want loss_accum. So if DDP, again, then dist is a PyTorch distributed. 03:06:22.320 |
I import it. Where do I import it? Oh, gosh. So this file is starting to get out of control, huh? 03:06:32.080 |
So import torch.distributed as dist, and then dist.all_reduce, and we're doing the average on loss_accum. 03:06:41.920 |
And so this loss_accum tensor exists on all the ranks. When we call all_reduce of average, 03:06:46.960 |
it creates the average of those numbers, and it deposits that average on all the ranks. 03:06:52.000 |
So all the ranks after this call will now contain loss_accum averaged up. 03:06:58.880 |
And so when we print here on the master process, 03:07:01.120 |
the loss_accum is identical in all the other ranks as well. 03:07:03.280 |
So here, if master_process, oops, we want to print like this. 03:07:09.360 |
Okay, and finally, we have to be careful, because we're not processing even more tokens. 03:07:14.880 |
So times DDP_world_size, that's the number of tokens that we've processed up above. 03:07:27.760 |
The only other thing to be careful with is, as I mentioned, you want to destroy the process group 03:07:32.160 |
so that we are nice to NCCL and to DDP, 03:07:36.160 |
and it's not going to complain to us when we exit here. 03:07:39.600 |
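So, as a summary sketch, the DDP-related additions to the training step look roughly like this (assuming the names from the sketches above):

```python
import torch.distributed as dist
from torch.distributed import destroy_process_group

optimizer.zero_grad()
loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss = loss / grad_accum_steps
    loss_accum += loss.detach()
    if ddp:
        # only all-reduce (average) gradients on the very last micro step;
        # note this toggles a DDP internal instead of using the no_sync() context manager
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
if ddp:
    # average the printed loss across ranks too, so the master process reports a global number
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
if master_process:
    print(f"loss: {loss_accum.item():.6f} | tokens this step: {tokens_processed}")

# ... and at the very end of the script:
if ddp:
    destroy_process_group()
```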
So that should be it. Let's try to take it for a spin. 03:07:43.840 |
Okay, so I launched the script, and it should be printing here imminently. 03:07:47.840 |
We're now training with eight GPUs at the same time. 03:07:50.320 |
So the gradient accumulation steps is not 32; it is now divided by 8, so it's just 4. 03:07:56.000 |
So otherwise, this is what the optimization now looks like. 03:08:02.320 |
So we're processing 1.5 million tokens per second now. 03:08:10.720 |
And the tiny Shakespeare dataset is so tiny that we're just doing like 03:08:17.840 |
One thing that I had to fix, by the way, is that this was model.configure_optimizers, 03:08:24.320 |
which now doesn't work because model now is a DDP model. 03:08:27.440 |
So instead, this has to become raw_model.configure_optimizers, 03:08:35.120 |
So right after I wrap the model into DDP, I have to create the raw_model, 03:08:43.280 |
which is where DDP stores the raw nn.Module of GPT-2 as we have it, 03:08:48.320 |
which contains the configure_optimizers function that we want to call. 03:08:55.920 |
Now, one thing you'll notice is that when you actually compare this run 03:08:59.200 |
and the numbers in it to just running a single GPU, 03:09:01.920 |
you'll notice that the numbers don't agree exactly with a single GPU run with 32 grad accum steps. 03:09:08.560 |
And it's kind of a boring reason for why that happens. 03:09:12.800 |
The reason for that is that in the data loader, 03:09:15.360 |
we're basically just iterating through batches in a slightly different way, 03:09:18.320 |
because now we're striding over an entire chunk of data at a time, and 03:09:18.320 |
if that chunk exceeds the number of tokens, we just loop. 03:09:27.920 |
And so actually the single GPU run and the eight GPU run 03:09:31.280 |
will end up resetting in a slightly different manner. 03:09:38.640 |
But one way to convince yourself that this is okay 03:09:42.480 |
is to just make the total batch size much smaller, and the B and the T. 03:09:46.240 |
And then so I think I used 4 times 1,024 times 8. 03:10:01.600 |
And then you're reducing the boundary effects of the data loader. 03:10:07.120 |
So long story short, we're now going really, really fast. 03:10:10.560 |
The optimization is mostly consistent with the GPT-2 and GPT-3 hyperparameters. 03:10:14.960 |
And we have outgrown our tiny Shakespeare file. 03:10:21.760 |
So let's now take a look at what data sets were used by GPT-2 and GPT-3. 03:10:25.040 |
So GPT-2 used this web text data set that was never released. 03:10:29.520 |
There's an attempt at reproducing it called open web text. 03:10:33.840 |
So basically, roughly speaking, what they say here in the paper is that they scraped all outbound links from Reddit with at least 3 karma. 03:10:42.560 |
And that was kind of like their starting point. 03:10:44.160 |
And they collected all the web pages and all the text in them. 03:10:49.760 |
And this ended up being 40 gigabytes of text. 03:10:51.600 |
So that's roughly what GPT-2 says about its data set. 03:10:57.200 |
So it's basically outbound links from Reddit. 03:10:59.040 |
Now, when we go over to GPT-3, there's a training data set section. 03:11:03.360 |
And that's where they start to talk about Common Crawl, which is a lot more used. 03:11:08.720 |
Actually, I think even GPT-2 talked about Common Crawl. 03:11:13.600 |
But basically, it's not a very high quality data set all by itself, 03:11:18.000 |
because it's a completely random subset of the internet. 03:11:22.000 |
So people go to great lengths to filter Common Crawl, 03:11:26.480 |
But most of it is just like ad spam, random tables and numbers and stock tickers. 03:11:34.320 |
So that's why people like to train on these data mixtures 03:11:42.880 |
So a large chunk of these data mixtures typically will be Common Crawl. 03:11:46.400 |
Like, for example, 50% of the tokens will be Common Crawl. 03:11:49.440 |
But then here in GPT-3, they're also using WebText2 from before. 03:11:57.840 |
There's many other things you can decide to add. 03:11:59.840 |
Now, this data set for GPT-3 was also never released. 03:12:03.680 |
So today, some of the data sets that I'm familiar with that are quite good 03:12:06.640 |
and would be representative of something along these lines 03:12:12.400 |
Or more specifically, for example, the Slim Pajama subset of the Red Pajama data set, 03:12:17.520 |
which is a cleaned and deduplicated version of it. 03:12:20.160 |
And just to give you a sense, again, it's a bunch of Common Crawl. 03:12:23.440 |
C4, which is also, as far as I know, more Common Crawl, but processed differently. 03:12:29.200 |
And then we have GitHub, Books, Archive, Wikipedia, StackExchange. 03:12:33.600 |
These are the kinds of data sets that would go into these data mixtures. 03:12:36.160 |
Now, specifically, the one that I like that came out recently is FineWeb. 03:12:41.520 |
So this is an attempt to basically collect really high-quality Common Crawl data 03:12:47.760 |
and filter it, in this case, to 15 trillion tokens. 03:12:50.320 |
And then in addition to that, more recently, Hugging Face released this 03:12:54.080 |
FineWebEDU subset, which is 1.3 trillion of educational 03:12:59.120 |
and 5.4 trillion of high-educational content. 03:13:02.080 |
So basically, they're trying to filter Common Crawl down to the most educational content. 03:13:14.240 |
And they go into a ton of detail about how they process the data, 03:13:16.960 |
which is really fascinating reading, by the way. 03:13:19.040 |
And I would definitely recommend, if you're interested 03:13:20.880 |
in data mixtures and so on, and how data gets processed at these scales, that you give it a read. 03:13:25.760 |
And more specifically, we'll be working with the FineWebEDU, I think. 03:13:29.920 |
And it's basically educational content from the internet. 03:13:35.200 |
They show that training on educational content, in their metrics, works quite a bit better. 03:13:41.920 |
And we're going to use this 10 billion token sample subset of it. 03:13:48.640 |
Because we're not going to be training on trillions of tokens. 03:13:51.040 |
We're just going to train on a 10 billion sample of the FineWebEDU. 03:13:56.160 |
Because empirically, in my previous few experiments, 03:13:58.800 |
this actually suffices to really get close to GPT-2 performance. 03:14:07.600 |
So our goal will be to download it, process it, 03:14:11.440 |
and make sure that our data loader can work with it. 03:14:19.440 |
So I wrote a little script that will basically download FineWebEDU from Hugging Face datasets. 03:14:23.920 |
It will pre-process and pre-tokenize all of the data. 03:14:27.360 |
And it will save data shards to a folder on a local disk. 03:14:34.160 |
And so while this is running, I just wanted to briefly mention 03:14:39.040 |
that you can kind of look through the dataset viewer here 03:14:44.400 |
I mean, it basically looks like it's working fairly well. 03:14:48.080 |
Like it's talking about nuclear energy in France. 03:14:50.080 |
It's talking about Mexican America, some Mac Pi Js, et cetera. 03:14:57.200 |
So actually, it seems like their filters are working pretty well. 03:14:59.840 |
The filters here, by the way, were applied automatically using an LLM. 03:15:06.320 |
And so basically, LLMs are judging which content is educational 03:15:10.800 |
and that ends up making it through the filter. 03:15:17.840 |
because it's not as interesting and not as LLM-centric. 03:15:21.520 |
But when you run this, basically, number one, 03:15:25.520 |
which this is all Hugging Face code running this. 03:15:28.160 |
You're going to need to pip install datasets. 03:15:33.200 |
Then it is tokenizing all of the documents inside this dataset. 03:15:39.600 |
you'll notice that to tokenize a single document, 03:15:43.200 |
we first start the tokens with the end of text token. 03:15:48.480 |
And this is a special token in the GPT-2 tokenizer, as you know. 03:15:58.640 |
But this is the first token that begins a document. 03:16:01.520 |
Then we extend with all of the tokens of that document. 03:16:07.840 |
We make sure that all the tokens are between 0 and 2 to the 16. There was one small fix here; 03:16:16.720 |
it just had to do with me using a float division in Python, 03:16:20.960 |
so I cast it so that this is an int and everything is nice. 03:16:25.520 |
Okay, but basically, the tokenization here is relatively straightforward. 03:16:31.840 |
We're using uint16 to save a little bit of space, 03:16:39.120 |
So the GPT-2 max token ID is well below that. 03:16:42.080 |
And then here, there's a bunch of multiprocessing code. 03:16:48.320 |
But we're loading the dataset, we're tokenizing it, 03:16:58.640 |
and we're saving the token shards as NumPy arrays, which are very, very similar to Torch tensors. 03:17:02.000 |
And the first shard, 000, is a validation shard. 03:17:07.120 |
And all the other shards are training shards. 03:17:10.080 |
And as I mentioned, they all have 100 million tokens in them exactly. 03:17:13.760 |
And that just makes it easier to work with, to shard the files. 03:17:20.320 |
Because if we just have a single massive file, 03:17:22.080 |
sometimes they can be hard to work with on the disk. 03:17:24.720 |
And so sharding it just makes things less messy from that perspective. 03:17:36.400 |
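The core of the tokenization just described looks roughly like this (a sketch; the real preprocessing script also has the multiprocessing and shard-writing logic around it):

```python
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens["<|endoftext|>"]  # the end-of-text token delimits documents

def tokenize(doc):
    """Tokenize a single document and return a numpy array of uint16 tokens."""
    tokens = [eot]  # the special token comes first and marks the start of a new document
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    # GPT-2 token ids are all below 2**16, so uint16 is safe and saves space
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token ids out of uint16 range"
    return tokens_np.astype(np.uint16)
```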
And then we're gonna come back to actually train on this data. 03:17:39.040 |
And we're gonna be actually doing some legit pre-training in this case. 03:17:48.080 |
And so we're actually gonna be doing a serious training run. 03:17:54.000 |
So if we ls the edu fineweb folder, we see that there's now 100 shards in it. 03:17:54.000 |
And that makes sense because each shard is 100 million tokens. 03:18:04.400 |
So 100 shards of that is 10 billion tokens in total. 03:18:09.760 |
I made some adjustments to our data loader again. 03:18:12.480 |
And that's because we're not running with Shakespeare anymore. 03:18:18.640 |
And so you'll see some code here that additionally basically can load these shards. 03:18:29.520 |
and convert them to Torch long tensors, which is what a lot of the layers up top expect by default. 03:18:29.520 |
And then here, we're just enumerating all the shards. 03:18:38.800 |
So we can load the split train, but also the split val, which is the zeroth shard. 03:18:46.080 |
And then here, we also have not just the current position now, but also the current shard, 03:18:53.680 |
And then when we run out of tokens in a single shard, 03:18:57.120 |
we first advance the shard and loop if we need to. 03:19:00.560 |
And then we get the tokens and readjust the position. 03:19:03.360 |
So this data loader will now iterate all the shards as well. 03:19:09.600 |
And then the other thing that I did while the data was processing 03:19:13.520 |
is our train loader now has split train, of course. And I also set the max steps to 19,073, 03:19:33.440 |
because that's how many unique tokens we have: 10 billion. 03:19:36.400 |
So if we take 10 billion tokens and divide that by 2 to the 19, we get roughly 19,073 steps. 03:19:44.560 |
And then the GPT-3 paper says that they warm up the learning rate over 375 million tokens. 03:19:50.800 |
So I came here, and 375E6 tokens divided by 2 to the 19 is 715 steps. 03:20:03.040 |
So this will exactly match the warm up schedule that GPT-3 used. 03:20:11.600 |
And this could be made significantly more aggressive. 03:20:17.520 |
Let's leave it for now so that we have the exact hyperparameters of GPT-3. 03:20:32.000 |
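With those numbers plugged in, the warmup plus cosine decay scheduler from the earlier section looks something like this (the 6e-4 maximum and the 10% minimum follow the GPT-3 settings used in the video):

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715   # ~375M warmup tokens / 2**19 tokens per step
max_steps = 19073    # ~10B tokens / 2**19 tokens per step, i.e. roughly one epoch

def get_lr(it):
    # 1) linear warmup for the first warmup_steps steps
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) after max_steps, just return the minimum learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
```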
And actually, sorry, let me do one more thing. 03:20:39.360 |
For my GPU, I can actually fit more batch size. 03:20:44.240 |
And I believe I can fit 64 on my GPU as a micro-batch size. 03:21:01.280 |
So that means we would not even be doing gradient accumulation if this fits. 03:21:05.040 |
Because this just multiplies out to the full total batch size. 03:21:12.080 |
And that would run pretty quickly if that fits. 03:21:27.700 |
I mean, if this works, then this is basically a serious pre-training run. 03:21:37.680 |
So it's not-- we haven't crossed our Ts and dotted our Is. 03:21:41.040 |
But if we let this run for a while, we're going to actually get a pretty good model. 03:21:46.400 |
And the model might even be on par with or better than GPT-2 124M. 03:21:55.360 |
We're processing 1.5 million tokens per second. 03:22:06.880 |
And we have to do a total of-- where are we printing that? 03:22:12.800 |
So 19,073 steps times 0.33 seconds per step is this many seconds, this many minutes. 03:22:28.400 |
And we don't even have to use gradient accumulation, which is nice. 03:22:31.760 |
And you might not have that luxury in your GPU. 03:22:34.160 |
In that case, just start decreasing the batch size until things fit. 03:22:42.240 |
We're currently warming up the learning rate. 03:22:46.640 |
So this will ramp up over the next few steps all the way to 6e-4 here. 03:22:55.280 |
So now what I'd like to do is let's cross the Ts and dot our Is. 03:23:00.880 |
And let's try to figure out how we can run evals, how we can do logging, 03:23:04.640 |
how we can visualize our losses, and all the good stuff. 03:23:07.840 |
So let's get to that before we actually do the run. 03:23:10.960 |
OK, so I've adjusted the code so that we're evaluating on the validation split. 03:23:14.720 |
So creating the val loader just by passing in split equals val, 03:23:18.320 |
that will basically create a data loader just for the validation shard. 03:23:21.920 |
The other thing I did is in the data loader, I introduced a new function reset, 03:23:31.600 |
And that is very useful because when we come to the main training loop now-- 03:23:37.920 |
And basically, every 100th iteration, including the 0th iteration, 03:23:50.400 |
We're going to basically accumulate the loss over, say, 20 steps. 03:23:56.720 |
And then average it all up and print out the validation loss. 03:23:59.200 |
And so that basically is the exact same logic as the training loop, roughly. 03:24:14.320 |
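A sketch of that validation block (names assumed from the sketches above; the bfloat16 autocast mirrors the training loop):

```python
if step % 100 == 0:
    model.eval()
    val_loader.reset()  # the new reset() starts the loader back at shard 0, position 0
    with torch.no_grad():
        val_loss_accum = 0.0
        val_loss_steps = 20
        for _ in range(val_loss_steps):
            x, y = val_loader.next_batch()
            x, y = x.to(device), y.to(device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits, loss = model(x, y)
            loss = loss / val_loss_steps  # average over the 20 mini-batches
            val_loss_accum += loss.detach()
    if ddp:
        dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
    if master_process:
        print(f"validation loss: {val_loss_accum.item():.4f}")
    model.train()
```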
And so this will print the validation loss every 100th iteration, 03:24:22.160 |
That will tell us a little bit about how much we're overfitting. 03:24:29.760 |
So we're mostly expecting our train and val loss to be about the same. 03:24:32.880 |
But the other reason I'm interested in this is because we can take the GPT-2-124M 03:24:40.800 |
And we can basically see what kind of loss it achieves on the validation loss as well. 03:24:44.480 |
And that gives us an indication as to how well that model generalizes 03:24:49.840 |
to the FineWebEDU validation split. 03:24:54.080 |
That said, it's not a super fair comparison to GPT-2 03:24:56.640 |
because it was trained on a very different data distribution. 03:24:58.960 |
But it's still kind of like an interesting data point. 03:25:00.960 |
And in any case, you would always want to have a validation split in a training run like this 03:25:06.800 |
so that you can make sure that you are not overfitting. 03:25:10.960 |
And this is especially a concern if we were to make more epochs in our training data. 03:25:15.440 |
So for example, right now, we're just doing a single epoch. 03:25:19.120 |
But if we get to a point where we want to train on 10 epochs or something like that, 03:25:22.640 |
we would be really careful with-- maybe we are memorizing that data too much 03:25:28.560 |
And our validation split would be one way to tell whether that is happening. 03:25:32.800 |
OK, and in addition to that, if you remember, at the bottom of our script, 03:25:36.000 |
we had all of this orphaned code for sampling from way back when. 03:25:46.640 |
Once in a while, we sample-- we generate samples. 03:25:57.360 |
And I've been running this for 1,000 iterations. 03:26:10.640 |
And languages file you're learning about here is-- or is the beginning of a computer. 03:26:14.480 |
OK, so this is all pretty-- there's still a garble. 03:26:22.800 |
And we've only just barely reached the maximum learning rate. 03:26:27.280 |
We're about to get some more samples coming up in 1,100. 03:26:31.280 |
OK, this is-- the model is still a young baby. 03:26:39.680 |
OK, so basically, all of this sampling code that I've put here, 03:26:45.280 |
everything should be familiar to you and came from before. 03:26:48.080 |
The only thing that I did is I created a generator object in PyTorch 03:26:51.840 |
so that I have a direct control over the sampling of the random numbers. 03:26:55.920 |
Because I don't want to impact the RNG state of the random number generator that is used in training. 03:27:02.800 |
I want this to be completely outside of the training loop. 03:27:08.240 |
And then I make sure to seed it, that every single rank has a different seed. 03:27:13.280 |
And then I pass in here, where we sort of consume random numbers in multinomial, 03:27:19.760 |
I make sure to pass in the generator object there. 03:27:23.440 |
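A sketch of what that looks like inside the sampling code, assuming the top-k sampling loop from earlier in the video and a model whose forward returns (logits, loss):

```python
import torch
import torch.nn.functional as F

# a dedicated RNG so sampling never touches the global RNG state used by training
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42 + ddp_rank)  # a different seed on every rank

num_return_sequences, max_length = 4, 32
tokens = enc.encode("Hello, I'm a language model,")
xgen = torch.tensor(tokens, dtype=torch.long, device=device)
xgen = xgen.unsqueeze(0).repeat(num_return_sequences, 1)
while xgen.size(1) < max_length:
    with torch.no_grad():
        logits, _ = model(xgen)                      # (B, T, vocab_size)
        probs = F.softmax(logits[:, -1, :], dim=-1)  # probabilities at the last position
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # the generator goes here
        xgen = torch.cat((xgen, torch.gather(topk_indices, -1, ix)), dim=1)
for i in range(num_return_sequences):
    print(f"rank {ddp_rank} sample {i}: {enc.decode(xgen[i, :max_length].tolist())}")
```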
Now, the other thing is you'll notice that we're running a bit slower. 03:27:28.160 |
That's because I actually had to disable torch.compile to get this to sample. 03:27:35.120 |
So for some reason, it works with no torch.compile. 03:27:37.040 |
But when I torch.compile my model, I get a really scary error from PyTorch. 03:27:40.800 |
And I have no idea how to resolve it right now. 03:27:42.800 |
So probably by the time you see this code released or something like that, maybe it's fixed. 03:27:56.720 |
By the way, I will be releasing all this code. 03:28:00.480 |
And actually, I've been very careful about making git commits every time we add something. 03:28:04.960 |
And so I'm going to release the entire repo that starts completely from scratch, 03:28:12.560 |
And so everything should be exactly documented in the git commit history. 03:28:18.720 |
So hopefully, by the time you go to GitHub, this is removed and it's working. 03:28:25.600 |
And it's stepping and we're on step 6,000 or so. 03:28:30.960 |
Now, while this is training, I would like to introduce one evaluation 03:28:34.000 |
that we're going to use to supplement the validation set. 03:28:39.840 |
So Hellaswag comes from this paper back in 2019. 03:28:44.800 |
And the way Hellaswag works is there's basically a sentence completion data set. 03:28:51.520 |
For every one of these questions, we have basically a shared context, 03:28:56.080 |
like a woman is outside with a bucket and a dog. 03:28:58.960 |
The dog is running around trying to avoid a bath. 03:29:01.600 |
She, A, rinses the bucket off with soap and blow dry the dog's head. 03:29:07.360 |
B, uses a hose to keep it from getting soapy. 03:29:15.920 |
And so basically, the idea is that these multiple choice are constructed 03:29:21.040 |
so that one of them is a natural continuation of the sentence and the others are not. 03:29:29.120 |
And the others might not make sense, like uses the hose to keep it from getting soapy. 03:29:35.840 |
And so what happens is that models that are not trained very well can't pick out the correct continuation. 03:29:41.520 |
But models that have a lot of world knowledge and can tell a lot about the world will be able to. 03:29:50.960 |
And these sentences are sourced from ActivityNet and from Wikihow. 03:29:55.760 |
And at the bottom of the paper, there's kind of like a cool chart of the categories the sentences come from. 03:30:05.280 |
So there's a lot of sentences from computers and electronics and homes and garden. 03:30:09.840 |
And it has kind of a broad coverage of the kinds of things 03:30:13.040 |
you need to know about the world in order to find the most likely completion 03:30:21.120 |
One more thing that's kind of interesting about Hellaswag is the way it was constructed 03:30:25.920 |
is that the incorrect options are deliberately adversarially sourced. 03:30:36.400 |
They're actually sentences generated by language models. 03:30:39.120 |
And they're generated in such a way that language models basically find them difficult, but humans find them easy. 03:30:45.040 |
And so they mentioned that humans have a 95% accuracy on this set. 03:30:48.800 |
But at the time, the state-of-the-art language models had only 48%. 03:30:52.000 |
And so at the time, this was a good benchmark. 03:30:54.800 |
Now, you can read the details of this paper to learn more. 03:30:59.200 |
The thing to point out, though, is that this is five years ago. 03:31:02.880 |
And since then, what happened to Hellaswag is that it's been totally just solved. 03:31:13.040 |
So basically, the last 4% is probably errors in the data set, 03:31:19.360 |
And so basically, this data set is kind of crushed with respect to language models. 03:31:22.560 |
But back then, the best language model was only at about 50%. 03:31:29.440 |
But still, people like Hellaswag, and it's not used, by the way, in GPT-2, but it is used in GPT-3. 03:31:40.480 |
And so for GPT-3, we have results here that are cited. 03:31:47.440 |
So we know what percent accuracies GPT-3 attains 03:31:51.200 |
at all these different model checkpoints for Hellaswag eval. 03:31:54.320 |
And the reason people like it is because Hellaswag is a smooth eval. 03:31:59.360 |
And it is an eval that offers, quote, unquote, early signal. 03:32:02.640 |
So early signal means that even small language models 03:32:06.320 |
are going to start at the random chance of 25%. 03:32:11.120 |
And you're going to see 25, 26, 27, et cetera. 03:32:13.920 |
And you can see slow improvement, even when the models are very small, and it's very early. 03:32:29.440 |
Now, the way that we're going to evaluate this is as follows. 03:32:38.240 |
And this is kind of like a multiple choice task. 03:32:40.560 |
But instead of giving the model a multiple choice question 03:32:43.280 |
and asking it for A, B, C, or D, we can't do that. 03:32:46.960 |
Because these models, when they are so small, as we are seeing here, 03:32:50.320 |
the models can't actually do multiple choice. 03:32:52.320 |
They don't understand the concept of associating a label with an answer option. 03:32:59.120 |
So we have to give it to them in a native form. 03:33:05.520 |
We construct a batch of four rows and T tokens, whatever that T happens to be. 03:33:11.440 |
Then the shared context, that is basically the context for the four choices, 03:33:16.400 |
the tokens of that are shared across all of the rows, and then each row continues with one of the four options. 03:33:33.040 |
Now, these options might be of different lengths. 03:33:37.200 |
So what we do is we sort of like take the longest length and pad all the rows up to it. 03:33:42.400 |
And then some of these here are going to be padded dimensions. 03:33:52.560 |
And we need a mask that tells us which tokens are active. 03:33:56.480 |
And the mask is then 0 for these padded areas. 03:34:02.960 |
And then in order to get the language model to predict A, B, C, or D, 03:34:07.280 |
the way this works is basically we're just going to look at the tokens, 03:34:12.320 |
And we're going to pick the option that gets the lowest average loss, 03:34:16.080 |
or equivalently the highest average probability, for its tokens, 03:34:21.120 |
because that is the most likely completion according to the model. 03:34:27.200 |
So we're just going to look at the probabilities here 03:34:35.040 |
and pick the one with the highest probability, roughly speaking. 03:34:40.960 |
And this is, I believe, also how GPT-3 did it. 03:34:51.680 |
But you should note that some of the other evals 03:34:53.920 |
where you might see Hellaswag may not do it this way. 03:34:58.560 |
where you sort of give the context a single time and then list all four options after it. 03:35:03.680 |
And so the model is able to see all the four options 03:35:08.480 |
And that's actually an easier task for a model 03:35:10.880 |
because you get to see the other options when you're picking your choice. 03:35:14.240 |
But unfortunately, models at our size can't do that. 03:35:17.760 |
Only models at a bigger size are able to do that. 03:35:20.640 |
And so our models are actually slightly handicapped in this way 03:35:24.240 |
that they are not going to see the other options. 03:35:26.160 |
They're only going to see one option at a time 03:35:30.880 |
and the correct option has to win out in this metric. 03:35:33.200 |
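In code, picking the most likely completion looks roughly like this (a sketch; the helper name is my own):

```python
import torch
import torch.nn.functional as F

def most_likely_row(tokens, mask, logits):
    # tokens: (4, T) shared context + one candidate ending per row
    # mask:   (4, T) 1 on the ending tokens, 0 on the shared context / padding
    # logits: (4, T, vocab_size) from the model
    # shift so that the logits at position i predict the token at position i+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    flat_losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction="none",
    )
    losses = flat_losses.view(tokens.size(0), -1)  # (4, T-1)
    shift_mask = mask[..., 1:].contiguous()        # the mask must be shifted the same way
    masked_losses = losses * shift_mask
    # average the loss over just the ending region of each row
    avg_loss = masked_losses.sum(dim=1) / shift_mask.sum(dim=1)
    # the row with the lowest average loss is the predicted completion
    return avg_loss.argmin().item()
```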
All right, so let's now implement this very briefly 03:35:38.400 |
Okay, so what I've done here is I've introduced a new file 03:35:41.600 |
called hellaswag.py that you can take a look into. 03:35:50.880 |
It's kind of like a little bit tedious, honestly, 03:35:52.880 |
because what's happening is I'm downloading Hellaswag from GitHub 03:36:02.640 |
And so here, at the end of this render_example function, 03:36:12.080 |
we return the tokens, this 4-by-T array of tokens, 03:36:17.920 |
the mask, which tells us which parts are the options, and the label of the correct option. 03:36:25.360 |
And so that allows us to then iterate the examples, 03:36:36.480 |
and a helper then basically calculates, just as I described, 03:36:48.080 |
the average cross entropy loss over the completion tokens of each row, and finds the row with the lowest loss. 03:36:59.200 |
And that's the option that we pick as the prediction. 03:37:04.960 |
And then we do some stats and prints and stuff like that. 03:37:09.840 |
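The scoring step amounts to a masked cross entropy; a minimal sketch of that logic (not the exact helper from hellaswag.py, just the idea):

```python
import torch
import torch.nn.functional as F

def most_likely_row(logits, tokens, mask):
    # logits: (4, T, vocab), tokens/mask: (4, T)
    # standard next-token setup: logits at position t predict the token at t+1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_tokens = tokens[:, 1:].contiguous()
    losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction="none",
    ).view(tokens.size(0), -1)             # (4, T-1) per-token losses
    shift_mask = mask[:, 1:].contiguous()  # only count the completion tokens
    avg_loss = (losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    # lowest average loss == highest average probability for the completion
    return int(avg_loss.argmin().item())
```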
Now, if you go up here, I'm showing that for GPT-2-124m, 03:37:15.520 |
you're going to see that Hellaswag gets 29.55%. 03:37:40.960 |
And then there's one more thing, the EleutherAI evaluation harness, 03:37:44.000 |
which is a very common piece of infrastructure 03:37:50.160 |
And I'm not 100% sure what the discrepancy is for these. 03:37:52.720 |
It could be that they actually evaluate it in the multiple-choice style instead of the completion style. 03:38:07.920 |
And so that is the number that we'd like to beat 03:38:10.000 |
if we were training a GPT-2 124M from scratch ourselves. 03:38:13.600 |
So now I'm going to go into actually incorporating this eval into our main training script, 03:38:23.600 |
because we want to evaluate it periodically during the optimization, 03:38:28.880 |
so that we can track Hellaswag and how it evolves over time 03:38:32.000 |
and see when and if we cross this 29.55 region. 03:38:39.840 |
So let's now walk through some of the changes to train_gpt2.py. 03:38:44.960 |
The first one is that I actually made torch.compile optional, kind of. 03:38:54.080 |
The problem with torch.compile is that, while it does make our code faster, it currently breaks our evaluation and sampling code. 03:39:01.840 |
So hopefully, by the time you get to the code base, that will have been fixed. 03:39:07.360 |
But for now, I'm running without torch.compile, which is why you'll see it turned off here. 03:39:19.520 |
I also created a simple log file, which will record the train loss, validation loss, and the HellaSwag accuracy over time. 03:39:31.520 |
And I created a simple variable that helps tell us when we're on the very last step of the loop. 03:39:37.920 |
And then basically, periodically inside this loop, we evaluate the validation loss, evaluate HellaSwag, and sample from the model. 03:40:07.920 |
And so you should recognize this as our ancient sampling code from before. 03:40:14.080 |
And then finally here, whether or not we did any of that on a given step, 03:40:18.720 |
after we validate, sample, and evaluate HellaSwag, 03:40:27.760 |
we do a regular step of the optimization, and you should be pretty familiar with all of what this does. 03:40:30.320 |
And at the end here, once we get our training loss, we also write it to the log file. 03:40:34.480 |
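Putting those changes together, the overall shape of the loop is roughly this. It's only a structural sketch: use_compile, log_file, and the evaluate_*, sample_from_model, and training_step helpers are illustrative names standing in for the corresponding blocks of train_gpt2.py, and model, max_steps, and master_process are assumed to exist from the earlier setup:

```python
import os
import torch

use_compile = False  # torch.compile speeds things up but currently breaks eval/sampling
if use_compile:
    model = torch.compile(model)

log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "log.txt")
with open(log_file, "w"):
    pass  # start with an empty log for this run

for step in range(max_steps):
    last_step = (step == max_steps - 1)

    # once in a while: validation loss, HellaSwag accuracy, and a few samples
    if step % 250 == 0 or last_step:
        val_loss = evaluate_val_loss(model)
        hella_acc = evaluate_hellaswag(model)
        sample_from_model(model)
        if master_process:
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss:.4f}\n")
                f.write(f"{step} hella {hella_acc:.4f}\n")

    # every step: the usual forward/backward/update
    train_loss = training_step(model)
    if master_process:
        with open(log_file, "a") as f:
            f.write(f"{step} train {train_loss:.6f}\n")
```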
So the only thing I really added here is some code that allows 03:40:42.240 |
all the GPUs to collaborate on the HellaSwag evaluation. 03:40:46.640 |
And then each process only picks the examples that are assigned to it. 03:40:53.200 |
So we sort of take i and mod it by the world size. 03:40:58.480 |
And then we render an example, put it on a GPU. 03:41:03.920 |
Then I create a helper function that basically predicts the most likely row, just as we described. 03:41:11.120 |
And then if it's correct, we sort of keep count. 03:41:13.840 |
And then if multiple processes were collaborating 03:41:17.040 |
on all of this, then we need to synchronize their stats. 03:41:19.920 |
And so one way to do that is to package up these counts as tensors, 03:41:25.600 |
which we can then call all_reduce on, with a sum. 03:41:28.640 |
And then here we sort of unwrap them from tensors back into plain numbers so we can print them. 03:41:52.240 |
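In code, that collaboration plus the stat synchronization looks roughly like this. It's a sketch that assumes torch.distributed has already been initialized, that ddp, ddp_rank, ddp_world_size, and device come from the earlier DDP setup, and that iterate_examples, render_example, and most_likely_row are helpers along the lines described above:

```python
import torch
import torch.distributed as dist

num_correct = 0
num_total = 0
for i, example in enumerate(iterate_examples("val")):   # hypothetical iterator over HellaSwag examples
    if i % ddp_world_size != ddp_rank:
        continue                                         # each rank handles a strided subset of examples
    tokens, mask, label = render_example(example)
    tokens, mask = tokens.to(device), mask.to(device)
    with torch.no_grad():
        logits, _ = model(tokens)                        # our GPT returns (logits, loss)
    pred = most_likely_row(logits, tokens, mask)
    num_total += 1
    num_correct += int(pred == label)

if ddp:
    # package the counts as tensors so we can sum them across all processes
    stats = torch.tensor([num_total, num_correct], dtype=torch.long, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    num_total, num_correct = stats[0].item(), stats[1].item()

acc = num_correct / num_total
```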
And this is step 10,000 out of about 20,000, right? 03:41:56.800 |
And these are the kinds of samples that we are getting 03:42:03.840 |
So I'd like to use it to generate some kinds of output. 03:42:18.080 |
But certainly, the predictions are getting less and less random. 03:42:20.960 |
It seems like the model is a little bit more self-aware 03:42:24.000 |
and using language that is a bit more specific to it 03:42:32.560 |
And like the model, I'm going to use it to generate 03:42:45.200 |
And like how the language is used to communicate, 03:42:48.560 |
And I'm going to be speaking English and German. 03:42:52.960 |
So let's just wait until this optimization finishes. 03:42:57.760 |
And we're also going to look at the train, val, and HellaSwag curves, 03:43:03.120 |
and see how we're doing with respect to GPT-2. 03:43:07.840 |
So focusing for a moment on the Jupyter Notebook 03:43:12.000 |
I created a new cell that basically allows us 03:43:14.480 |
to visualize the train loss, the val loss, and the HellaSwag accuracy. 03:43:20.960 |
It basically like parses the log file that we are writing. 03:43:23.520 |
And a lot of this is just like boring matplotlib code. 03:43:27.840 |
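If the log lines have a simple "<step> <stream> <value>" format (an assumption on my part, matching the sketch above), the parsing and plotting really is only a handful of lines:

```python
import matplotlib.pyplot as plt

# parse the log file into one dict of {step: value} per stream
streams = {"train": {}, "val": {}, "hella": {}}
with open("log/log.txt") as f:
    for line in f:
        step, stream, value = line.split()
        streams[stream][int(step)] = float(value)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for name in ("train", "val"):
    xs = sorted(streams[name])
    ax1.plot(xs, [streams[name][x] for x in xs], label=f"{name} loss")
ax1.legend(); ax1.set_xlabel("step"); ax1.set_ylabel("loss")

xs = sorted(streams["hella"])
ax2.plot(xs, [streams["hella"][x] for x in xs], label="HellaSwag accuracy")
ax2.legend(); ax2.set_xlabel("step"); ax2.set_ylabel("accuracy")
plt.show()
```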
But basically, this is what our optimization looks like. 03:43:42.160 |
This is the full run over one epoch of the 10B-token sample of FineWebEDU. 03:43:55.440 |
As a baseline, we have the OpenAI GPT-2 124M model checkpoint, 03:43:59.360 |
when it's just evaluated on the validation set of this data set. 03:44:09.200 |
So we're surpassing its validation loss on this data set. 03:44:12.560 |
And like I mentioned, the dataset distribution 03:44:14.640 |
is very different from what GPT-2 trained on. 03:44:25.200 |
That's why it's nice to also have an eval like HellaSwag that is withheld, comparable, and somewhat standard. 03:44:32.320 |
And so over here, we see the HellaSwag progress we are making. 03:44:38.080 |
In red, we see the accuracy that the OpenAI GPT-2 124M model achieves, 03:44:49.360 |
and in green, the GPT-3 124M model, which was trained on 300 billion tokens. 03:44:57.280 |
And you see that we surpass the GPT-2 124M model right here, which is really nice. 03:45:09.280 |
Keep in mind we trained on only 10 billion tokens, while GPT-2 was trained on 100 billion tokens. 03:45:15.360 |
So we surpass its accuracy with significantly fewer tokens for training. 03:45:17.760 |
There are many possibilities as to why we could match 03:45:21.440 |
or surpass this accuracy with only 10 billion tokens of training. 03:45:29.840 |
For one, the original GPT-2 was trained on a much wider data distribution. 03:45:40.000 |
And so math and code and multilingual data could have been 03:45:44.240 |
stealing capacity from the original GPT-2 model. 03:45:48.080 |
And basically, that could be partially the reason why we're able to match it with fewer tokens. 03:45:55.120 |
Another possibility is contamination: the HellaSwag eval is fairly old at this point. 03:45:59.920 |
It is possible that aspects of HellaSwag, in some form, 03:46:03.200 |
or even verbatim, have made it into the training set. 03:46:09.680 |
But if that was the case, then we are basically looking 03:46:11.600 |
at the training curve instead of the validation curve. 03:46:13.440 |
So long story short, this is not a perfect eval. 03:46:19.920 |
But it does give us a good sense that we're not doing something completely wrong. 03:46:21.840 |
And it's probably the case that when people try 03:46:26.960 |
to create these data sets, they try to make sure that common evals don't leak into the training data. 03:46:32.720 |
For example, when HuggingFace created FineWebEDU, 03:46:37.120 |
I would hope that they made sure to deduplicate, 03:46:40.080 |
and that there's no HellaSwag in the training set. 03:46:43.680 |
The other thing I wanted to address briefly is this odd periodicity you see in the training loss. 03:46:52.960 |
And I suspect it's because the 10 billion token sample of FineWebEDU was not properly shuffled, 03:47:17.200 |
so I think we're inheriting some of the ordering that exists inside the data set. 03:47:24.800 |
But hopefully, by the time you get to this repo, some of these things will have been fixed. 03:47:34.800 |
And right now it looks a little ugly and preliminary. 03:47:37.840 |
So hopefully by the time you get here, it's nicer. 03:47:42.960 |
And I'm going to mention some of the things I'd like to improve a bit later, 03:47:47.200 |
and I expect that we will have fixed this small issue by then. 03:47:50.160 |
But for now, basically, this shows that our training is working as intended, 03:47:55.840 |
and that we're able to surpass the accuracy of GPT-2 124M with much less training data. 03:48:01.440 |
And possibly it could also be that data sets have simply improved since GPT-2 was trained. 03:48:10.880 |
It's possible that not a lot of care and attention went, back then, 03:48:18.160 |
into good practices around deduplication, filtering, and quality filtering, 03:48:23.360 |
and it's possible that the data set we're training on today is just of higher quality per token. 03:48:47.360 |
Next, overnight I re-ran the training for four epochs instead of one epoch, which took roughly two hours, 03:48:51.280 |
so that would take eight hours while I was sleeping. 03:48:53.680 |
And so we did four epochs, or roughly 40 billion tokens of training. 03:48:58.320 |
And I was trying to see how far we could get. 03:49:03.920 |
And when I point the notebook at the log file of this 40 billion token run, 03:49:13.520 |
we are seeing this issue here with the periodicity 03:49:16.560 |
through the different epochs, and something that looks really weird going on there. 03:49:22.720 |
But otherwise, we are seeing that the HellaSwag accuracy actually kept improving, 03:49:28.640 |
and we almost made it to the GPT-3 124M accuracy up here, but not quite. 03:49:36.320 |
So it's too bad that I didn't sleep slightly longer. 03:49:45.680 |
Now, one thing to point out is that with multiple epochs, we're seeing the data 03:49:57.360 |
in exactly the same format and exactly the same order on every epoch. 03:50:04.880 |
Ideally, you'd want a data loader where you actually permute the data randomly. 03:50:07.920 |
You permute the documents around in every single shard 03:50:10.960 |
on every single new epoch, and potentially even permute the shards themselves. 03:50:15.840 |
And that would go a long way into decreasing the periodicity. 03:50:21.600 |
You're then not seeing things in the identical format, 03:50:24.960 |
and you're introducing some randomness in how the documents follow each other. 03:50:29.600 |
Because you have to remember that every single row of a batch packs many documents together, 03:50:36.240 |
so right now the documents are glued together in the exact same identical order every epoch. 03:50:41.280 |
But we actually want to break up and shuffle the documents. 03:50:45.040 |
Because the order of the documents shouldn't matter. 03:50:52.720 |
And so our data loader is not currently doing that. 03:50:55.440 |
And that's one improvement you could think of making. 03:50:57.920 |
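A sketch of what that improvement could look like. Hedged: the shard format details here, like shards being .npy files of uint16 token ids with documents delimited by the <|endoftext|> token (id 50256), are assumptions about the data files rather than a description of the current loader:

```python
import numpy as np

def shuffled_shard_order(num_shards, epoch, seed=1337):
    # permute the order in which shards are visited, differently on every epoch
    rng = np.random.default_rng(seed + epoch)
    return rng.permutation(num_shards)

def load_shard_shuffled(filename, epoch, seed=1337):
    tokens = np.load(filename)               # one shard of token ids
    eot = 50256                               # assumed <|endoftext|> delimiter between documents
    boundaries = np.where(tokens == eot)[0]
    docs = np.split(tokens, boundaries + 1)   # split the shard into individual documents
    rng = np.random.default_rng(seed + epoch)
    rng.shuffle(docs)                         # permute documents within the shard on every epoch
    return np.concatenate(docs)
```

The point of re-seeding with seed + epoch is that every epoch sees a different document order, while the run as a whole stays reproducible.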
The other thing to point out is we're almost matching the GPT-3 124M accuracy with far fewer tokens. 03:51:07.840 |
So again, we're seeing roughly a 10x improvement here in terms of the tokens required. 03:51:13.360 |
And I don't actually know exactly what to attribute this to, 03:51:16.880 |
other than some of the things that I already mentioned previously. 03:51:19.120 |
The other thing I wanted to mention is the maximum learning rate. 03:51:27.600 |
I saw some people already play with this a little bit 03:51:32.000 |
And it turns out that you can actually almost 3x this. 03:51:36.000 |
So it's possible that the maximum learning rate can be a lot higher, 03:51:38.800 |
and that for some reason, the GPT-3 hyperparameters we copied here are too conservative, 03:51:43.760 |
And you can actually get away with a higher learning rate. 03:51:47.360 |
So a lot of these hyperparameters are quite tunable. 03:51:52.800 |
And they're probably not set precisely correctly. 03:52:00.800 |
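For example, trying a roughly 3x higher peak learning rate is a one-line experiment (just a suggestion to try, not a setting verified here):

```python
max_lr = 6e-4 * 3   # GPT-3 paper value for this model size is 6e-4; reportedly ~3x still trains stably
min_lr = max_lr * 0.1
```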
And if you wanted to exactly be faithful to GPT-3, 03:52:05.200 |
you would also want to make the following change: GPT-3 uses a context length of 2,048 instead of 1,024. 03:52:16.240 |
So you would come here, change this to 2,048 for t. 03:52:19.600 |
And then if you want the exact same number of tokens per step, you would also bring the micro-batch size down from 64 to 32. 03:52:29.440 |
That would give your model the GPT-3 sequence length, and otherwise the two models 03:52:38.800 |
would be roughly identical, as far as I'm aware. 03:52:42.560 |
Because again, GPT-2 and GPT-3 are very, very similar models. 03:52:46.240 |
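Concretely, the change would look something like this. Treat it as a sketch: the variable names and values mirror the earlier sections of the video as I remember them, and ddp_world_size is assumed from the DDP setup:

```python
total_batch_size = 524288   # 2**19 tokens per optimizer step, unchanged
B = 32                      # halve the micro-batch size...
T = 2048                    # ...because the sequence length doubled (GPT-3 context length)
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
```

Since 32 * 2048 equals 64 * 1024, each micro-step still processes the same number of tokens, and the gradient accumulation schedule works out identically.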
Now, we can also look at some of the samples here from the end of this longer run. 03:52:54.640 |
And you see that here, we stepped all the way to the end of the four epochs. 03:52:59.440 |
The HellaSwag accuracy we achieved was 33.24. 03:53:04.400 |
And these are some of the samples from the model. 03:53:07.680 |
And you can see that if you read through these, they're a bit more coherent, 03:53:14.160 |
and they're actually addressing the fact that they're a language model. 03:53:25.920 |
For example: "I'm a language model, not a programming language." 03:53:39.200 |
If you pause this and look at it, and then compare it 03:53:41.280 |
to the model that was only trained for 10 billion tokens, you'll see the difference. 03:53:48.240 |
One more thing I added to the code, by the way, is model checkpointing. 03:53:52.720 |
So basically, right after we evaluate the validation loss, 03:53:58.640 |
in addition to logging the validation loss, every 5,000 steps we also write out a model checkpoint, 03:54:03.440 |
which is really just the state dictionary of the model. 03:54:16.880 |
If we wanted to exactly resume optimization, we would have to also save the optimizer state dict, 03:54:21.680 |
because the optimizer has a few additional buffers because of Adam, like its moment estimates. 03:54:26.400 |
And you need to also resume the optimizer properly. 03:54:30.240 |
You have to be careful with the RNG seeds, random number generators, and so on. 03:54:34.080 |
So if you wanted to exactly be able to resume optimization, 03:54:37.040 |
you have to think through the state of the training process. 03:54:43.440 |
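A minimal sketch of that kind of checkpointing, with the extra state you would want for an exact resume. The names and fields here are illustrative, not the repo's exact format:

```python
import torch

def save_checkpoint(path, step, raw_model, optimizer, val_loss):
    checkpoint = {
        "model": raw_model.state_dict(),      # the plain nn.Module, not the DDP wrapper
        "config": raw_model.config,
        "step": step,
        "val_loss": val_loss,
        # for an *exact* resume you would also want, at least:
        "optimizer": optimizer.state_dict(),  # AdamW moment buffers
        "rng_state": torch.get_rng_state(),   # plus CUDA RNG state, data loader position, etc.
    }
    torch.save(checkpoint, path)

def load_checkpoint(path, raw_model, optimizer):
    checkpoint = torch.load(path, map_location="cpu")
    raw_model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    torch.set_rng_state(checkpoint["rng_state"])
    return checkpoint["step"]
```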
And one nice reason why you might want to do this 03:54:46.160 |
is because you may want to evaluate the model a lot more thoroughly. 03:54:49.460 |
So here, we are only kind of winging the HellaSwag eval, 03:54:57.040 |
but there are other frameworks, like, for example, the EleutherAI evaluation harness. 03:55:05.600 |
So this is a way to also evaluate language models. 03:55:13.840 |
And depending on your needs, you may want to use different evaluation infrastructure like that, 03:55:13.840 |
and compare your model to the OpenAI GPT-2 model on many other tasks beyond HellaSwag. 03:55:21.120 |
So this is a nice functionality to have as well. 03:55:44.720 |
Keep in mind that this is still just a pre-trained base model; you can't talk to it like you can talk to ChatGPT. 03:55:44.720 |
If we wanted to talk to it, we would have to fine-tune it into the chat format. 03:55:50.480 |
If you're looking at supervised fine-tuning or SFT, 03:55:58.000 |
really what that means is we're just swapping out the data set 03:56:00.720 |
into a data set that is a lot more conversational. 03:56:02.960 |
And there's a user-assistant, user-assistant kind of structure, 03:56:07.200 |
and then we basically fill in the user tokens and train on the assistant tokens, roughly as sketched below. 03:56:17.360 |
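A common way to implement that is to mask the loss on the user tokens so the model only learns to produce the assistant tokens. A rough sketch of that masking, with made-up tensor names rather than any particular SFT dataset's format:

```python
import torch
import torch.nn.functional as F

def sft_targets(token_ids, assistant_mask, ignore_index=-100):
    # next-token prediction targets: shift left by one, and ignore every
    # position whose *target* token belongs to the user part of the conversation
    targets = token_ids[:, 1:].clone()
    targets[assistant_mask[:, 1:] == 0] = ignore_index
    return targets

def sft_loss(logits, token_ids, assistant_mask):
    # logits: (B, T, vocab), token_ids/assistant_mask: (B, T), mask is 1 on assistant tokens
    targets = sft_targets(token_ids, assistant_mask)
    return F.cross_entropy(
        logits[:, :-1, :].reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
```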
But for now, we're going to stop at pre-training. 03:56:20.720 |
One more thing that I wanted to briefly show you 03:56:23.040 |
is that, of course, what we've built up today is basically NanoGPT. 03:56:30.080 |
But also, there's actually another NanoGPT implementation of sorts, and that is llm.c, 03:56:44.080 |
and it just directly uses CUDA; it is written in C/CUDA. 03:56:49.040 |
Now, the NanoGPT here acts as reference code in PyTorch, while llm.c is meant to be faster, 03:56:59.440 |
and, of course, currently that seems to be the case 03:57:01.280 |
because it is a direct optimized implementation. 03:57:04.400 |
So train_gpt2.py in llm.c is basically the NanoGPT equivalent. 03:57:12.000 |
And if you look through that repo, you'll find a lot of things that very much look like what we've built here, just at a lower level. 03:57:26.080 |
So there's a lot of MPI, NCCL, GPU, CUDA, C, and C++ code in there. 03:57:39.440 |
Here I'm going to run the two of them side by side, and they're going to produce the exact same results; 03:57:44.480 |
on the left, I have PyTorch, a NanoGPT-looking thing, and on the right, the llm.c C/CUDA code. 03:57:55.040 |
Both of these are gonna be running on a single GPU. 03:58:12.240 |
So basically, meanwhile, PyTorch is still compiling its kernels with torch.compile. 03:58:12.240 |
And so this program has already started running 03:58:27.520 |
and we're still waiting here for Torch compile. 03:58:30.400 |
Now, of course, llm.c is a very specific implementation of GPT-2 and GPT-3 training, 03:58:30.400 |
whereas PyTorch is a very general neural network framework, so they're not directly comparable. 03:58:38.960 |
But if you're only interested in training GPT-2 and GPT-3, 03:58:45.520 |
then llm.c is very lean: it's faster to start, and it's faster per step. 03:59:00.480 |
So PyTorch here is quite a bit slower, but I don't have full confidence that I've squeezed everything out of the PyTorch side. 03:59:16.880 |
The important thing, though, is that the losses that are printed between these two are identical. 03:59:21.920 |
So on the left is the PyTorch implementation, and on the right, this C code implementation, 03:59:24.640 |
and they're the same, except this one runs faster. 03:59:35.600 |
So llm.c is something you might want to play with or look at, and it's kind of interesting. 03:59:35.600 |
Okay, so at this point, I should probably start wrapping up the video. 03:59:39.520 |
We were looking at how you set up these training runs, and we saw that over the course 04:00:04.160 |
of either a two-hour training run or an overnight run, 04:00:07.280 |
we can actually match the 124 million parameter checkpoints of GPT-2 and GPT-3, 04:00:18.480 |
if you have the patience or the computing resources. 04:00:21.120 |
And so you could potentially think about training some of the bigger models in this series as well. 04:00:36.400 |
A few to-dos remain: torch.compile currently breaks generation and the HellaSwag eval. 04:00:40.240 |
In the data loader, we should probably be permuting our data between epochs. 04:00:46.400 |
And I expect to be documenting some of those over time in the build-nanogpt repo, 04:00:51.920 |
which I'm going to be releasing with this video. 04:00:57.200 |
If you have any questions, or would like to talk about anything that we covered, 04:00:59.600 |
please go to the Discussions tab so we can talk there, 04:01:08.720 |
or have a look at the Zero to Hero Discord, 04:01:13.120 |
where I'm going to be hanging out in the NanoGPT channel. 04:01:18.000 |
Otherwise, for now, I'm pretty happy about where we got,