
Let's reproduce GPT-2 (124M)


Chapters

0:00 intro: Let’s reproduce GPT-2 (124M)
3:39 exploring the GPT-2 (124M) OpenAI checkpoint
13:47 SECTION 1: implementing the GPT-2 nn.Module
28:08 loading the huggingface/GPT-2 parameters
31:00 implementing the forward pass to get logits
33:31 sampling init, prefix tokens, tokenization
37:02 sampling loop
41:47 sample, auto-detect the device
45:50 let’s train: data batches (B,T) → logits (B,T,C)
52:53 cross entropy loss
56:42 optimization loop: overfit a single batch
62:00 data loader lite
66:14 parameter sharing wte and lm_head
73:47 model initialization: std 0.02, residual init
82:18 SECTION 2: Let’s make it fast. GPUs, mixed precision, 1000ms
88:14 Tensor Cores, timing the code, TF32 precision, 333ms
99:38 float16, gradient scalers, bfloat16, 300ms
108:15 torch.compile, Python overhead, kernel fusion, 130ms
120:18 flash attention, 96ms
126:54 nice/ugly numbers. vocab size 50257 → 50304, 93ms
134:55 SECTION 3: hyperparameters, AdamW, gradient clipping
141:06 learning rate scheduler: warmup + cosine decay
146:21 batch size schedule, weight decay, FusedAdamW, 90ms
154:09 gradient accumulation
166:52 distributed data parallel (DDP)
190:21 datasets used in GPT-2, GPT-3, FineWeb (EDU)
203:10 validation data split, validation loss, sampling revive
208:23 evaluation: HellaSwag, starting the run
223:05 SECTION 4: results in the morning! GPT-2, GPT-3 repro
236:21 shoutout to llm.c, equivalent but faster code in raw C/CUDA
239:39 summary, phew, build-nanogpt github repo

Whisper Transcript

00:00:00.000 | Hi, everyone. So, today we are going to be continuing our Zero to Hero series,
00:00:04.320 | and in particular, today we are going to reproduce the GPT-2 model,
00:00:07.840 | the 124 million version of it. So, when OpenAI released GPT-2, this was 2019,
00:00:15.680 | and they released it with this blog post. On top of that, they released this paper,
00:00:20.480 | and on top of that, they released this code on GitHub. So, OpenAI/GPT-2.
00:00:24.560 | Now, when we talk about reproducing GPT-2, we have to be careful, because in particular,
00:00:29.600 | in this video, we're going to be reproducing the 124 million parameter model. So, the thing to
00:00:34.800 | realize is that there's always a miniseries when these releases are made. So, there is the GPT-2
00:00:41.040 | miniseries, made up of models at different sizes, and usually the biggest model is called "the" GPT-2.
00:00:46.800 | But basically, the reason we do that is because you can put the model sizes on the x-axis of
00:00:52.160 | plots like this, and on the y-axis, you put a lot of downstream metrics that you're interested in,
00:00:56.720 | like translation, summarization, question answering, and so on, and you can chart out
00:01:00.560 | these scaling laws. So, basically, as the model size increases, you're getting better and better
00:01:05.680 | at downstream metrics. And so, in particular for GPT-2, if we scroll down in the paper,
00:01:12.000 | there are four models in the GPT-2 miniseries, starting at 124 million, all the way up to
00:01:18.320 | 1,558 million. Now, the reason my numbers, the way I say them, disagree with this table is that
00:01:24.560 | this table is wrong. If you actually go to the GPT-2 GitHub repo, they sort of say that there
00:01:32.160 | was an error in how they added up the parameters. But basically, this is the 124 million parameter
00:01:36.400 | model, et cetera. So, the 124 million parameter model had 12 layers in the transformer, and it had 768
00:01:43.680 | channels in the transformer, 768 dimensions. And I'm going to be assuming some familiarity with
00:01:49.360 | what these terms mean, because I covered all of this in my previous video, "Let's build GPT:
00:01:53.520 | from scratch." So, I covered that in the previous video in this playlist.
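As a sanity check on the "124M" figure, the count can be reproduced by hand. This is a minimal sketch that assumes the layout of the Hugging Face GPT-2 checkpoint: biases on every linear layer except the output head, a 4x MLP expansion, a learned 1024-entry position table, a 50,257-token vocabulary, and an output head that shares its weight with the token embedding, so that matrix is counted only once. `gpt2_param_count` is a hypothetical helper written for this illustration.

```python
def gpt2_param_count(n_layer=12, n_embd=768, vocab_size=50257, block_size=1024):
    """Parameter count of a GPT-2-style model whose lm_head reuses the wte matrix."""
    C = n_embd
    wte = vocab_size * C                          # token embedding (also serves as lm_head)
    wpe = block_size * C                          # learned position embedding
    ln = 2 * C                                    # LayerNorm gain + bias
    attn = (C * 3 * C + 3 * C) + (C * C + C)      # c_attn (q, k, v) + c_proj, with biases
    mlp = (C * 4 * C + 4 * C) + (4 * C * C + C)   # c_fc (expand 4x) + c_proj, with biases
    block = ln + attn + ln + mlp                  # ln_1, attn, ln_2, mlp
    return wte + wpe + n_layer * block + ln       # ... plus the final ln_f

print(f"{gpt2_param_count():,}")           # 124,439,808
print(f"{gpt2_param_count(48, 1600):,}")   # 1,557,611,200
```

With 48 layers and 1,600 channels the same formula lands on the 1,558M figure quoted for the largest model in the miniseries.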
00:01:58.400 | Now, if we do everything correctly and everything works out well, by the end of this video,
00:02:03.360 | we're going to see something like this, where we're looking at the validation loss, which basically
00:02:08.160 | measures how good we are at predicting the next token in a sequence on some validation data that
00:02:14.720 | the model has not seen during training. And we see that we go from doing that task not very well,
00:02:20.320 | because we're initializing from scratch, all the way to doing that task quite well
00:02:23.760 | by the end of the training. And hopefully, we're going to beat the GPT-2 124M model.
00:02:29.920 | Now, previously, when they were working on this, this is already five years ago.
00:02:34.560 | So, this was probably a fairly complicated optimization at the time, and the GPUs and
00:02:38.560 | the compute were a lot smaller. Today, you can reproduce this model in roughly an hour,
00:02:44.160 | or probably even less, and it will cost you about 10 bucks if you want to do this on cloud
00:02:48.320 | compute, a computer that you can rent. And if you pay $10 for that computer,
00:02:54.400 | you wait about an hour or less, you can actually achieve a model that is as good as
00:02:58.960 | this model that OpenAI released. And one more thing to mention is, unlike many other models,
00:03:05.200 | OpenAI did release the weights for GPT-2. So, those weights are all available in this repository.
00:03:11.040 | But the GPT-2 paper is not always as good with all of the details of the training.
00:03:16.160 | So, in addition to the GPT-2 paper, we're going to be referencing the GPT-3 paper,
00:03:20.320 | which is a lot more concrete in a lot of the parameters and optimization settings and so on.
00:03:25.600 | And it's not a huge departure in the architecture from the GPT-2 version of the model. So,
00:03:31.920 | we're going to be referencing both GPT-2 and GPT-3 as we try to reproduce GPT-2 124M. So,
00:03:38.080 | let's go. So, the first thing I would like to do is actually start at the end, or at the target.
00:03:42.960 | So, in other words, let's load the GPT-2 124M model as it was released by OpenAI,
00:03:48.480 | and maybe take it for a spin. Let's sample some tokens from it. Now, the issue with that is,
00:03:52.640 | when you go to the code base of GPT-2 and you go into the source and you click in on the model.py,
00:03:58.160 | you'll realize that actually this is using TensorFlow. So, the original GPT-2 code here
00:04:02.880 | was written in TensorFlow, which is, you know, not, let's just say, not used as much anymore.
00:04:10.560 | So, we'd like to use PyTorch, because it's a lot friendlier, easier, and I just personally like it
00:04:15.680 | a lot more. The problem with that is the initial code is in TensorFlow. We'd like to use PyTorch.
00:04:20.160 | So, instead, to get the target, we're going to use the Hugging Face Transformers code,
00:04:26.320 | which I like a lot more. So, when you go into transformers, src/transformers/models/gpt2/
00:04:30.400 | modeling_gpt2.py, you will see that they have the GPT-2 implementation of that transformer
00:04:36.640 | here in this file. And it's, like, medium readable, but not fully readable. But what it does is it did
00:04:46.800 | all the work of converting all those weights from TensorFlow into a PyTorch-friendly format, and so it's much
00:04:52.480 | easier to load and work with. So, in particular, we can look at the GPT-2 model here, and we can
00:04:59.120 | load it using Hugging Face Transformers. So, swinging over, this is what that looks like.
00:05:03.440 | From transformers, import GPT2LMHeadModel, and then call from_pretrained("gpt2").
00:05:11.840 | Now, one awkward thing about this is that when you use "gpt2" as the model that we're loading,
00:05:17.920 | this actually is the 124 million parameter model. If you want the actual GPT-2, the 1.5 billion,
00:05:25.440 | then you actually want "gpt2-xl". So, this is the 124M, our target. Now, what we're doing is,
00:05:32.480 | when we actually get this, we're initializing the PyTorch nn.Module as defined here in this class.
00:05:38.160 | From it, I want to get just the state_dict, which is just the raw tensors. So, we just have
00:05:44.800 | the tensors of that file. And by the way, this is a Jupyter notebook, but it's a Jupyter
00:05:51.120 | notebook running inside VS Code, because I like to work with everything in a single interface,
00:05:57.200 | so this is the Jupyter notebook extension inside VS Code. So, when we get the
00:06:05.360 | state_dict, this is just a dict, so we can print each key and its value, which is a tensor,
00:06:11.280 | and let's just look at the shapes. So, these are the different parameters inside the GPT-2
00:06:18.080 | model and their shapes. So, the wte weight, the token embedding, is of size 50257 by 768. Where this is
00:06:30.400 | coming from is that we have 50257 tokens in the GPT-2 vocabulary, and the tokens, by the way,
00:06:39.120 | these are exactly the tokens that we've spoken about in my previous video on
00:06:43.520 | tokenization. So, in the previous video, just before this one, I go into a ton of detail on tokenization.
00:06:48.400 | The GPT-2 tokenizer happens to have this many tokens. For each token, we have a 768-dimensional
00:06:56.240 | embedding that is the distributed representation that stands in for that token. So, each token is
00:07:02.880 | a little string piece, and then these 768 numbers are the vector that represents that token.
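In PyTorch this lookup table is an `nn.Embedding`. The sketch below builds a randomly initialized table with the GPT-2 (124M) shape and shows that embedding a batch of token ids is just row indexing into the weight matrix; the token ids are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
wte = nn.Embedding(50257, 768)          # one learnable 768-dim row per token id
print(wte.weight.shape)                 # torch.Size([50257, 768])

idx = torch.tensor([[15496, 11, 314]])  # (B=1, T=3) arbitrary token ids
emb = wte(idx)                          # (1, 3, 768)

# the "embedding" is nothing more than selecting rows of the weight matrix
assert torch.equal(emb[0], wte.weight[idx[0]])
```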
00:07:09.760 | And so, this is just our lookup table for tokens, and then here, we have the lookup table for the
00:07:14.560 | positions. So, because GPT-2 has a maximum sequence length of 1024, we have up to 1024 positions that
00:07:23.360 | each token can be attending to in the past, and every one of those positions in GPT-2 has a fixed
00:07:29.920 | vector of 768 that is learned by optimization. And so, this is the position embedding and the
00:07:37.680 | token embedding, and then everything here is just the other weights and biases and everything else
00:07:43.760 | of this transformer. So, when you just take, for example, the positional embeddings, flatten them
00:07:49.760 | out, and take just the first 20 elements, you can see that these are just the parameters. These are
00:07:53.760 | weights, just floats, and we can take them and plot them. So, these are the position embeddings,
00:08:00.640 | and we get something like this, and you can see that this has structure, and it has structure
00:08:05.280 | because every row in this visualization is a different position,
00:08:12.560 | a fixed absolute position in the range from 0 to 1023, and each row here is the representation
00:08:20.640 | of that position. And so, it has structure because these positional embeddings end up learning these
00:08:26.800 | sines and cosines that sort of represent each of these positions, and each row here stands
00:08:34.960 | in for that position and is processed by the transformer to recover all the relative positions
00:08:40.160 | and sort of realize which token is where and attend to them depending on their position,
00:08:45.760 | not just their content. So, when we actually just look into an individual column inside these,
00:08:53.200 | and I just grabbed three random columns, you'll see that, for example, here we are focusing on
00:08:59.200 | one channel at a time, and we're looking at what that channel is doing as a function of position
00:09:09.040 | from 0 to 1023. And we can see that some of these channels basically respond
00:09:17.200 | more or less to different parts of the position spectrum. So, this green channel really likes to
00:09:23.200 | fire for everything after 200 up to 800, a lot less outside that range, and has a sharp drop-off
00:09:31.440 | here near 0. So, who knows what these embeddings are doing and why they are the way they are.
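For comparison with these learned curves: the original Transformer from "Attention Is All You Need" used a fixed, non-learned table in which even channels hold sin(pos / 10000^(2i/d)) and odd channels hold cos(pos / 10000^(2i/d)). Below is a minimal sketch of that table, sized like GPT-2's 1024 by 768 position embedding for comparison only; GPT-2 itself learns its table from a random initialization. `sinusoidal_positions` is a hypothetical helper name.

```python
import math
import torch

def sinusoidal_positions(block_size: int, n_embd: int) -> torch.Tensor:
    """Fixed (non-learned) positional encoding from 'Attention Is All You Need'."""
    pos = torch.arange(block_size, dtype=torch.float32).unsqueeze(1)  # (T, 1)
    # 1 / 10000^(2i/d) for each even channel index 2i
    inv_freq = torch.exp(torch.arange(0, n_embd, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / n_embd))            # (d/2,)
    pe = torch.zeros(block_size, n_embd)
    pe[:, 0::2] = torch.sin(pos * inv_freq)  # even channels: sine
    pe[:, 1::2] = torch.cos(pos * inv_freq)  # odd channels: cosine
    return pe

pe = sinusoidal_positions(1024, 768)  # same shape as GPT-2's learned wpe table
print(pe.shape)                       # torch.Size([1024, 768])
```

Plotting a few columns of `pe` gives perfectly smooth curves at different frequencies; the learned GPT-2 table recovers sinusoidal-like but noisier structure.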
00:09:36.160 | You can tell, for example, that because they're a bit jagged and kind of noisy,
00:09:39.600 | this model was not fully trained. And the more trained this model was,
00:09:44.640 | the smoother you would expect these curves to be. And so, this is telling you that this is a little bit
00:09:48.560 | of an under-trained model, but in principle, actually, these curves don't even have to be
00:09:54.160 | smooth. They could just be totally random noise. And in fact, in the beginning of the optimization,
00:09:58.720 | it is complete random noise, because this position embedding table is initialized completely at
00:10:03.600 | random. So, in the beginning, you have jaggedness, and the fact that you end up with something
00:10:08.160 | smooth is already kind of impressive, that that just falls out of the optimization,
00:10:13.280 | because in principle, you shouldn't even be able to get any single graph out of this that makes
00:10:17.280 | sense. But we actually get something that looks a little bit noisy, but for the most part looks
00:10:21.280 | sinusoidal-like. In the original transformer paper, the Attention Is All You Need paper,
00:10:29.120 | the positional embeddings are actually initialized and fixed, if I remember correctly, to sines
00:10:34.240 | and cosines of different frequencies. That's the positional encoding, and it's fixed. But in
00:10:39.520 | GPT-2, these are just parameters, and they're trained from scratch, just like any other
00:10:43.440 | parameter. And that seems to work about as well. And so what they do is they kind of recover these
00:10:48.720 | sinusoidal-like features during the optimization. We can also look at any of the other matrices
00:10:55.440 | here. So, here I took the first layer of the transformer, looked at one of its weights,
00:11:03.280 | and just the first block of 300 by 300, and you see some structure, but again, who knows what
00:11:10.560 | any of this is. If you're into mechanistic interpretability, you might get a real kick out
00:11:15.040 | of trying to figure out what is going on, what is this structure, and what does this all mean,
00:11:19.440 | but we're not going to be doing that in this video. But we definitely see that there's some
00:11:22.880 | interesting structure, and that's kind of cool. What we're most interested in is we've loaded
00:11:27.360 | the weights of this model that was released by OpenAI, and now, using Hugging Face Transformers,
00:11:32.800 | we can not only get all the raw weights, but also use what they call a pipeline
00:11:39.520 | and sample from it. So, this is the prefix, "Hello, I'm a language model," comma,
00:11:44.480 | and then we're sampling 30 tokens, and we're getting five sequences, and I ran this,
00:11:51.760 | and this is what it produced. "Hello, I'm a language model," but what I'm really doing
00:11:57.360 | is making a human-readable document. There are other languages, but those are dot, dot, dot,
00:12:02.400 | so you can read through these if you like, but basically, these are five different completions
00:12:05.840 | of the same prefix from this GPT-2 124M. Now, if I go here, I took this example from here,
00:12:14.560 | and sadly, even though we are fixing the seed, we are getting different generations
00:12:20.000 | from the snippet than what they got, so presumably the code changed, but what we see, though,
00:12:28.240 | at this stage that's important is that we are getting coherent text, so we've loaded the model
00:12:33.200 | successfully, we can look at all its parameters, and the keys tell us where in the model these
00:12:38.560 | come from, and we want to actually write our own GPT-2 class so that we have a full understanding
00:12:43.760 | of what's happening there. We don't want to be working with something like modeling_gpt2.py,
00:12:49.200 | because it's just too complicated. We want to write this from scratch ourselves,
00:12:52.400 | so we're going to be implementing the GPT model here in parallel, and as our first task,
00:12:57.120 | let's load GPT-2 124M into the class that we are going to develop here from scratch.
00:13:03.680 | That's going to give us confidence that we can load the OpenAI model, and therefore,
00:13:09.200 | there's a setting of weights that is exactly the 124M model, but then, of course,
00:13:13.680 | what we're going to do is we're going to initialize the model from scratch instead,
00:13:16.880 | and try to train it ourselves on a bunch of documents that we're going to get,
00:13:21.440 | and we're going to try to surpass that model, so we're going to get different weights,
00:13:25.520 | and everything's going to look different, hopefully better even, but we're going to
00:13:30.560 | have a lot of confidence that because we can load the OpenAI model, we are in the same model family
00:13:34.800 | and model class, and we just have to rediscover a good setting of the weights, but from scratch.
00:13:39.280 | So let's now write the GPT-2 model, and let's load the weights, and make sure that we can
00:13:45.360 | also generate text that looks coherent. Okay, so let's now swing over to the
00:13:49.440 | Attention Is All You Need paper that started everything, and let's scroll over to the model
00:13:53.280 | architecture, the original transformer. Now, remember that GPT-2 is slightly modified from
00:13:58.720 | the original transformer. In particular, we do not have the encoder. GPT-2 is a decoder-only
00:14:05.120 | transformer, as we call it, so this entire encoder here is missing, and in addition to that,
00:14:10.080 | this cross-attention here that was using that encoder is also missing, so we delete this entire
00:14:16.240 | part. Everything else stays almost the same, but there are some differences that we're going to
00:14:21.680 | sort of look at here. So there are two main differences. When we go to the GPT-2 paper under
00:14:29.680 | section 2.3, Model, we notice that first, there's a reshuffling of the layer norms, so they change
00:14:35.680 | place, and second, an additional layer normalization was added here after the final
00:14:42.320 | self-attention block. So basically, all the layer norms here, instead of being after the MLP or
00:14:48.000 | after the attention, they swing before it, and an additional layer norm gets added here right before
00:14:53.040 | the final classifier. So now let's implement some of the first skeleton nn.Modules
00:14:59.200 | here in our GPT nn.Module, and in particular, we're going to try to match up this schema here
00:15:05.760 | that is used by Hugging Face Transformers because that will make it much easier to load these weights
00:15:10.480 | from this state dict. So we want something that reflects this schema here. So here's what I came
00:15:17.120 | up with. Basically, we see that the main container here that has all the modules is called transformer,
00:15:25.920 | so I'm reflecting that with an nn.ModuleDict, and this is basically a module that allows you
00:15:30.400 | to index into the sub-modules using keys, just like a dictionary with strings. Within it, we have the
00:15:38.480 | weights of the token embeddings, wte, and that's an nn.Embedding, and the weights of the position
00:15:44.320 | embeddings, wpe, which is also just an nn.Embedding, and if you remember, nn.Embedding is really just
00:15:48.400 | a fancy little wrapper module around just a single array of numbers, a single block of
00:15:57.440 | numbers just like this. It's a single tensor, and nn.Embedding is a glorified wrapper around a tensor
00:16:04.640 | that allows you to access its elements by indexing into the rows. Now, in addition to that, we see
00:16:10.960 | here that we have a .h, and then this is indexed using numbers instead of indexed using strings,
00:16:17.920 | so there's a .h, .0, 1, 2, etc., all the way up till .h.11, and that's because there are 12 layers
00:16:25.680 | here in this transformer. So to reflect that, I'm creating also an h, I think that probably
00:16:31.440 | stands for hidden, and instead of a ModuleDict, this is a ModuleList, so we can index it using
00:16:36.480 | integers exactly as we see here, .0, .1, .2, etc., and the ModuleList has n_layer blocks, and the
00:16:45.680 | Block module is yet to be defined in a bit. In addition to that, following the GPT-2 paper,
00:16:51.840 | we need an additional final layer norm that we're going to put in there, and then we have the final
00:16:58.080 | classifier, the language model head, which projects from 768, the number of embedding
00:17:06.000 | dimensions in this GPT, all the way to the vocab size, which is 50,257, and GPT-2 uses no bias for
00:17:13.200 | this final sort of projection. So this is the skeleton, and you can see that it reflects this,
00:17:20.160 | so wte is the token embeddings; here it's called output embedding, but it's really the
00:17:26.000 | token embeddings. wpe is the positional encodings; those two pieces of information,
00:17:31.440 | as we saw previously, are going to be added together and then go into the transformer. The .h is all the blocks
00:17:37.120 | in gray, and ln_f is this new layer that gets added here by the GPT-2 model, and lm_head is this
00:17:45.040 | linear part here. So that's the skeleton of the GPT-2. We now have to implement the block.
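The skeleton described above can be sketched as follows. The aim is only to reproduce the parameter names of the Hugging Face checkpoint (`transformer.wte`, `transformer.wpe`, `transformer.h.N`, `transformer.ln_f`, `lm_head`); the `block_factory` argument and the `nn.Identity` placeholder blocks are hypothetical stand-ins until the real block is defined.

```python
from dataclasses import dataclass
import torch.nn as nn

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPT(nn.Module):
    def __init__(self, config, block_factory=lambda cfg: nn.Identity()):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            h=nn.ModuleList([block_factory(config) for _ in range(config.n_layer)]),  # blocks
            ln_f=nn.LayerNorm(config.n_embd),                    # GPT-2's extra final LayerNorm
        ))
        # final classifier: n_embd -> vocab_size, no bias as in GPT-2
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

# tiny config so the example is cheap; placeholder blocks have no parameters
model = GPT(GPTConfig(block_size=8, vocab_size=16, n_layer=2, n_head=2, n_embd=4))
print(list(model.state_dict().keys()))
# ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.ln_f.weight', 'transformer.ln_f.bias', 'lm_head.weight']
```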
00:17:52.480 | Okay, so let's now recurse to the block itself. So we want to define the block.
00:17:56.400 | So I'll start putting them here. So the block, I like to write out like this.
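A minimal sketch of the block's shape, assuming pre-normalization as in GPT-2: each sub-layer sees a LayerNorm-ed copy of the residual stream and adds its output back. Passing `attn` and `mlp` in as constructor arguments is a simplification for this illustration; the plain linear layers in the usage lines are stand-ins, not real attention or MLP modules.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """One GPT-2 transformer block in pre-norm form.

    attn and mlp can be any modules mapping (B, T, C) -> (B, T, C).
    """
    def __init__(self, n_embd: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = attn
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = mlp

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # normalize the branch input, add back to the residual
        x = x + self.mlp(self.ln_2(x))   # same pattern for the MLP
        return x

C = 16
block = Block(C, attn=nn.Linear(C, C), mlp=nn.Linear(C, C))
x = torch.randn(2, 5, C)              # (B, T, C)
print(block(x).shape)                 # torch.Size([2, 5, 16])
print(list(block.state_dict()))       # names follow ln_1 / attn / ln_2 / mlp
```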
00:18:03.680 | These are some of the initializations, and then this is the actual forward pass of what this block
00:18:09.040 | computes. And notice here that there's a change from the transformer again, that is mentioned
00:18:14.800 | in the GPT-2 paper. So here, in the original transformer, the layer normalizations come after the application of
00:18:20.480 | attention or feedforward. In addition to that, note that the normalizations are inside the
00:18:25.840 | residual stream. You see how feedforward is applied, and this arrow goes through and through
00:18:31.120 | the normalization. So that means that your residual pathway has normalizations inside it.
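To see why a clean residual path matters, here is a small experiment, assuming a branch that currently outputs zeros (a stand-in for a sub-layer that has not "kicked in" yet). In the pre-norm arrangement the upstream gradient reaches the input untouched, while in the post-norm arrangement it first has to pass through the LayerNorm's Jacobian. `post_norm`, `pre_norm` and `branch` are hypothetical names for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 8
ln = nn.LayerNorm(C)

def branch(t):
    # stand-in for an attention/MLP branch that currently outputs zeros
    return torch.zeros_like(t)

def post_norm(x):  # original Transformer: "Add & Norm", LayerNorm sits on the residual path
    return ln(x + branch(x))

def pre_norm(x):   # GPT-2: LayerNorm only on the branch input, residual path untouched
    return x + branch(ln(x))

grad_out = torch.randn(4, C)  # an arbitrary gradient arriving from the layers above
for f in (pre_norm, post_norm):
    x = torch.randn(4, C, requires_grad=True)
    f(x).backward(grad_out)
    print(f.__name__, torch.allclose(x.grad, grad_out))
# expected: pre_norm True (gradient passes through unchanged),
#           post_norm False (gradient is reshaped by the LayerNorm)
```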
00:18:36.800 | And this is not very good or desirable. You actually prefer to have a single clean residual
00:18:42.800 | stream, all the way from supervision, all the way down to the inputs, the tokens. And this is very
00:18:48.400 | desirable and nice, because the gradients that flow from the top, if you remember from your
00:18:54.240 | micrograd, addition just distributes gradients during the backward stage to both of its branches
00:19:00.480 | equally. So addition is a branch in the gradients. And so that means that the gradients from the top
00:19:07.360 | flow straight to the inputs, the tokens, through the residual pathways unchanged. But then in
00:19:13.280 | addition to that, the gradient also flows through the blocks, and the blocks, you know, contribute
00:19:17.520 | their own contribution over time, and kick in and change the optimization over time. But basically,
00:19:22.240 | clean residual pathway is desirable from an optimization perspective. And then this is the
00:19:28.880 | pre-normalization version, where you see that our x first goes through the layer normalization,
00:19:33.920 | and then the attention, and then goes back out to go to the layer normalization number two,
00:19:39.600 | and the multilayer perceptron, sometimes also referred to as a feedforward network, or an FFN.
00:19:46.160 | And then that goes into the residual stream again. And the one more thing that is kind of
00:19:50.800 | interesting to note is that recall that attention is a communication operation. It is where all the
00:19:56.000 | tokens, and there's 1024 tokens lined up in a sequence, and this is where the tokens communicate.
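One way to make the communicate/compute split concrete is to perturb a single token and watch which output positions change. In this sketch the attention is a single causal head whose queries, keys and values are all just the input (a simplification; real attention uses learned projections), and the MLP is a randomly initialized 4x-expansion MLP. `changed_positions` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 1, 5, 8
x = torch.randn(B, T, C)
x_perturbed = x.clone()
x_perturbed[:, 2] += 1.0  # change only the token at position 2

mlp = nn.Sequential(nn.Linear(C, 4 * C), nn.GELU(approximate="tanh"), nn.Linear(4 * C, C))

def causal_attention(t):
    # one head with q = k = v = t; enough to see which positions get mixed together
    return F.scaled_dot_product_attention(t, t, t, is_causal=True)

@torch.no_grad()
def changed_positions(f):
    diff = (f(x_perturbed) - f(x)).abs().amax(dim=-1)[0]  # max change per position
    return [i for i, d in enumerate(diff.tolist()) if d > 1e-6]

print("MLP (map):         ", changed_positions(mlp))               # [2]
print("attention (reduce):", changed_positions(causal_attention))  # [2, 3, 4]
```

The MLP output changes only where the input changed, while causal attention also propagates the change to every later position that can attend to it.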
00:20:01.760 | This is where they exchange information. So attention is an aggregation function. It's a
00:20:08.080 | pooling function. It's a weighted sum function. It is a reduce operation. Whereas MLP, this MLP
00:20:17.360 | here happens at every single token individually. There's no information being collected or
00:20:21.440 | exchanged between the tokens. So the attention is the reduce, and the MLP is the map. And what you
00:20:27.680 | end up with is that the transformer just ends up just being a repeated application of map reduce,
00:20:32.320 | if you want to think about it that way. So this is where they communicate, and this is where they
00:20:38.000 | think individually about the information that they gathered. And every one of these blocks
00:20:42.560 | iteratively refines the representation inside the residual stream. So this is our block,
00:20:49.520 | slightly modified from this picture. Okay, so let's now move on to the MLP. So the MLP block,
00:20:56.880 | I implemented it as follows. It is relatively straightforward. We basically have two linear
00:21:02.080 | projections here with the GELU non-linearity sandwiched in between. It's nn.GELU with approximate
00:21:09.520 | set to tanh. Now when we swing over to the PyTorch documentation, this is nn.GELU, and it has this
00:21:16.880 | format, and it has two versions, the original version of GELU, which we'll step into in a bit,
00:21:22.320 | and the approximate version of GELU, which we can request using approximate='tanh'. So as you can see, just as a
00:21:28.080 | preview here, Gelu is basically like a ReLU, except there's no flat, exactly flat tail here at exactly
00:21:36.400 | zero. But otherwise it looks very much like a slightly smoother ReLU. It comes from this paper
00:21:42.080 | here, Gaussian Error Linear Units, and you can step through this paper, and there's some mathematical
00:21:48.080 | kind of like reasoning that leads to an interpretation that leads to this specific
00:21:51.840 | formulation. It has to do with stochastic regularizers and the expectation of a modification to
00:21:57.360 | adaptive dropout, so you can read through all of that if you'd like here. And there's a little bit
00:22:02.320 | of a history as to why there's an approximate version of Gelu, and that comes from this issue
00:22:07.520 | here, as far as I can tell. And in this issue, Dan Hendrycks mentions that at the time when
00:22:14.320 | they developed this non-linearity, the erf function, which you need to evaluate the exact GELU,
00:22:20.800 | was very slow in TensorFlow, so they ended up basically developing this approximation.
00:22:24.720 | And this approximation then ended up being picked up by BERT and by GPT-2, etc.
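A quick numerical comparison of the exact GELU and the tanh approximation, using PyTorch's `nn.GELU(approximate="tanh")`, followed by a sketch of the MLP block described above (4x expansion, tanh GELU, projection back). The submodule names `c_fc` and `c_proj` are chosen to match the Hugging Face GPT-2 parameter names; the input range is arbitrary.

```python
import math
import torch
import torch.nn as nn

x = torch.linspace(-4.0, 4.0, steps=801)
exact = nn.GELU()(x)                     # x * Phi(x), computed via erf
approx = nn.GELU(approximate="tanh")(x)  # the approximation GPT-2 used

# the tanh approximation written out explicitly
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
print(torch.allclose(approx, manual, atol=1e-6))  # True
print((exact - approx).abs().max().item())        # small, but not zero

class MLP(nn.Module):
    """GPT-2 MLP: expand 4x, tanh-approximate GELU, project back."""
    def __init__(self, n_embd: int = 768):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):  # (B, T, C) -> (B, T, C), applied to each position independently
        return self.c_proj(self.gelu(self.c_fc(x)))
```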
00:22:29.200 | But today there's no real good reason to use the approximate version. You'd prefer to just use the
00:22:33.680 | exact version, because my expectation is that there's no big difference anymore, and this is
00:22:39.760 | kind of like a historical kind of quirk. But we are trying to reproduce GPT-2 exactly, and GPT-2
00:22:47.440 | used the tanh approximate version, so we prefer to stick with that. Now one other reason to actually
00:22:55.200 | just intuitively use Gelu instead of Relu is, previously in videos in the past, we've spoken
00:23:00.240 | about the dead Relu neuron problem, where in this tail of a Relu, if it's exactly flat at zero,
00:23:07.360 | any activations that fall there will get exactly zero gradient. There's no change,
00:23:11.440 | there's no adaptation, there's no development of the network if any of these activations
00:23:16.160 | end in this flat region. But the Gelu always contributes a local gradient, and so there's
00:23:21.680 | always going to be a change, always going to be an adaptation, and sort of smoothing it out
00:23:25.840 | ends up empirically working better in practice, as demonstrated in this paper, and also as
00:23:30.240 | demonstrated by it being picked up by the BERT paper, GPT-2 paper, and so on. So for that reason
00:23:35.440 | we adopt this non-linearity, with the tanh approximation, in the GPT-2 reproduction. Now in more modern networks,
00:23:41.760 | like Llama 3 and so on, this non-linearity further changes to SwiGLU and other variants like
00:23:48.160 | that, but for GPT-2 they use this approximate Gelu. Okay, and finally we have the attention
00:23:54.320 | operation. So let me paste in my attention. So I know this is a lot, so I'm gonna go through this
00:24:02.400 | a bit quickly, a bit slowly, but not too slowly, because we have covered this in the previous video
00:24:07.200 | and I would just point you there. So this is the attention operation. Now in the previous video
00:24:12.880 | you will remember this is not just attention, this is multi-headed attention, right? And so in the
00:24:19.520 | previous video we had this multi-headed attention module, and this implementation made it obvious
00:24:24.960 | that these heads are not actually that complicated. There's basically, in parallel, inside every
00:24:30.480 | attention block, there's multiple heads, and they're all functioning in parallel, and their
00:24:36.640 | outputs are just being concatenated, and that becomes the output of the multi-headed attention.
00:24:41.520 | So the heads are just kind of like parallel streams, and their outputs get concatenated.
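The claim that heads are just parallel streams whose outputs get concatenated can be checked numerically. This sketch runs causal attention once per head on slices of the channels and concatenates the results, then does the same computation with the heads folded into a batch dimension via `view`/`transpose`; the tensor sizes are arbitrary and q, k, v are random rather than learned projections.

```python
import torch

torch.manual_seed(0)
B, T, n_head, hs = 2, 6, 4, 8   # arbitrary sizes; hs is the head size
C = n_head * hs
q, k, v = (torch.randn(B, T, C) for _ in range(3))
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # causal: position t sees 0..t

def attend(q, k, v):
    # causal scaled dot-product attention over the last two dims (..., T, hs)
    att = (q @ k.transpose(-2, -1)) / hs ** 0.5
    att = att.masked_fill(~mask, float("-inf")).softmax(dim=-1)
    return att @ v

# 1) each head as a separate stream over its slice of channels, outputs concatenated
outs = []
for h in range(n_head):
    sl = slice(h * hs, (h + 1) * hs)
    outs.append(attend(q[..., sl], k[..., sl], v[..., sl]))
per_head = torch.cat(outs, dim=-1)                            # (B, T, C)

# 2) heads folded into a batch dimension: (B, T, C) -> (B, nh, T, hs)
def split(t):
    return t.view(B, T, n_head, hs).transpose(1, 2)

batched = attend(split(q), split(k), split(v))                # broadcasts over (B, nh)
batched = batched.transpose(1, 2).contiguous().view(B, T, C)  # re-concatenate the heads

print(torch.allclose(per_head, batched, atol=1e-6))           # True
```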
00:24:46.480 | And so it was very simple and made the head be kind of like fairly straightforward in terms of
00:24:52.640 | its implementation. What happens here is that instead of having two separate modules, and
00:24:58.320 | indeed many more modules that get concatenated, all of that is just put into a single self-attention
00:25:04.800 | module. And instead I'm being very careful and doing a bunch of transpose-split tensor gymnastics
00:25:12.800 | to make this very efficient in PyTorch, but fundamentally and algorithmically nothing is
00:25:16.720 | different from the implementation we saw before in this Git repository. So to remind you very
00:25:25.520 | briefly, since I don't want to spend too much time on this: we have these tokens lined up in
00:25:32.080 | a sequence, and there are up to 1024 of them. And then each token at this stage of the attention emits
00:25:38.320 | three vectors, the query, key, and the value. And first what happens here is that the queries and
00:25:45.840 | the keys have to multiply each other to get sort of the attention amount, like how interesting they
00:25:52.720 | find each other. So they have to interact multiplicatively. So what we're doing here is
00:25:56.720 | we're calculating the QKV while splitting it, and then there's a bunch of gymnastics as I mentioned
00:26:01.520 | here. And the way this works is that we're basically making the number of heads, nh,
00:26:07.600 | into a batch dimension. And so it's a batch dimension just like b, so that in these operations
00:26:13.360 | that follow, PyTorch treats b and nh as batches, and it applies all the operations on all of them
00:26:20.480 | in parallel, in both the batch and the heads. And the operations that get applied are, number one,
00:26:26.560 | the queries and the keys interact to give us our attention. This is the autoregressive mask
00:26:31.840 | that made sure that the tokens only attend to tokens before them and never to tokens in the
00:26:37.920 | future. The softmax here normalizes the attention, so it sums to one always. And then recall from
00:26:46.160 | the previous video that doing the attention matrix multiply with the values is basically
00:26:49.920 | a way to do a weighted sum of the values of the tokens that we found interesting at every single
00:26:55.200 | token. And then the final transpose contiguous and view is just reassembling all of that again,
00:27:01.520 | and this actually performs a concatenation operation. So you can step through this
00:27:06.240 | slowly if you'd like, but it is equivalent mathematically to our previous implementation,
00:27:12.240 | it's just more efficient in PyTorch, so that's why I chose this implementation instead.
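A sketch of the attention module along the lines described above, assuming the Hugging Face naming (`c_attn`, `c_proj`, and a causal-mask buffer called `bias`). It is written for clarity rather than speed, and the usage lines at the end check the causal property: changing the last token must leave all earlier outputs unchanged.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int = 768, n_head: int = 12, block_size: int = 1024):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.n_embd = n_head, n_embd
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # q, k, v for all heads in one matmul
        self.c_proj = nn.Linear(n_embd, n_embd)      # mixes the concatenated head outputs
        # causal mask; called "bias" to follow the OpenAI/HF naming, though it is a mask
        mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        self.register_buffer("bias", mask)

    def forward(self, x):
        B, T, C = x.size()
        hs = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, nh, T, hs): the head count becomes a batch dimension
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))            # (B, nh, T, T)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))  # no looking ahead
        att = F.softmax(att, dim=-1)                                        # each row sums to 1
        y = att @ v                                                         # weighted sum of values
        y = y.transpose(1, 2).contiguous().view(B, T, C)                    # concatenate heads
        return self.c_proj(y)

torch.manual_seed(0)
attn = CausalSelfAttention(n_embd=32, n_head=4, block_size=16)
x = torch.randn(2, 10, 32)
x2 = x.clone()
x2[:, -1] += 1.0                          # perturb only the last position
with torch.no_grad():
    y, y2 = attn(x), attn(x2)
print(torch.allclose(y[:, :-1], y2[:, :-1]))  # True: earlier outputs never see the future
```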
00:27:15.840 | Now in addition to that, I'm being careful with how I name my variables. So for example,
00:27:20.880 | c_attn here is the same as c_attn in HuggingFace, and so our keys should basically exactly follow the
00:27:26.880 | schema of the HuggingFace Transformers code, and that will make it very easy for us to now
00:27:31.040 | port over all the weights, because all of our variables
00:27:37.200 | are named the same thing. So at this point we have finished the GPT-2 implementation,
00:27:42.800 | and what that allows us to do is that we don't have to use this file from HuggingFace,
00:27:47.920 | which is fairly long. This is 2,000 lines of code; instead we have fewer than 100 lines of code,
00:27:59.280 | and this is the complete GPT-2 implementation. So at this stage we should just be able to
00:28:03.840 | take over all the weights, set them, and then do generation. So let's see what that looks like.
00:28:08.960 | Okay, so here I've also changed the GPT config so that the numbers here,
00:28:12.640 | the hyperparameters, agree with the GPT-2 124M model. So the maximum sequence length, which I
00:28:17.920 | call block size here, is 1,024. The number of tokens is 50,257, which, if you watched my tokenizer video,
00:28:26.400 | you know is 50,000 BPE merges, 256 byte tokens (the leaves of the BPE tree),
00:28:34.880 | and one special end-of-text token that delimits different documents and can start generation as
00:28:39.840 | well. And there are 12 layers, there are 12 heads in the attention, and the dimension of the
00:28:45.360 | transformer is 768. So here's how we can now load the parameters from HuggingFace to our code here,
00:28:53.200 | and initialize the GPT class with those parameters. So let me just copy-paste a bunch of code here.
00:28:58.400 | And I'm not going to go through this code too slowly, because honestly it's not that interesting,
00:29:07.920 | it's not that exciting. We're just loading the weights, so it's kind of dry. But as I mentioned,
00:29:11.680 | there are four models in this mini-series of GPT-2. This is some of the Jupyter code
00:29:16.400 | that we had here on the right. I'm just porting it over. These are the hyperparameters of the
00:29:22.240 | GPT-2 models. We're creating the config object and creating our own model. And then what's
00:29:27.600 | happening here is we're creating the state dict, both for our model and for the HuggingFace model.
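For reference, the configuration described above can be sketched as a small dataclass, together with a table of the four GPT-2 sizes. The field names are assumed to match the naming used in this code (block size, vocab size, layers, heads, embedding dimension); the per-size values are the published hyperparameters of the OpenAI checkpoints, and treating them as the lookup table used by from_pretrained is an assumption of this sketch.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # transformer (embedding) dimension

# Hyperparameters of the four GPT-2 model sizes
GPT2_SIZES = {
    "gpt2":        dict(n_layer=12, n_head=12, n_embd=768),   # 124M parameters
    "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),  # 350M parameters
    "gpt2-large":  dict(n_layer=36, n_head=20, n_embd=1280),  # 774M parameters
    "gpt2-xl":     dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M parameters
}
```

Every size keeps a head size of 64 (n_embd divided by n_head), and only depth, width, and head count change across the family.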
00:29:35.360 | And then what we're doing here is we're going over the HuggingFace model keys, and we're copying over
00:29:41.600 | those tensors. And in the process, we are ignoring a few of the buffers. They're not
00:29:47.440 | parameters, they're buffers. So for example, attn.bias, that's just used for the
00:29:51.600 | autoregressive mask. And so we are ignoring some of those masks, and that's it. And then one
00:29:58.240 | additional kind of annoyance is that this comes from the TensorFlow repo, and I'm not sure how...
00:30:03.280 | This is a little bit annoying, but some of the weights are transposed from what PyTorch would
00:30:07.440 | want. And so manually, I hardcoded the weights that should be transposed, and then we transpose
00:30:13.040 | them if that is so. And then we return this model. So from_pretrained is a constructor,
00:30:19.920 | or rather a class method in Python, that returns the GPT object if we just give it the model type,
00:30:27.680 | which in our case is GPT-2, the smallest model that we're interested in. So this is the code,
00:30:32.800 | and this is how you would use it. And we can pop open the terminal here
00:30:37.200 | in VS Code, and we can run python train_gpt2.py, and fingers crossed.
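The copying step can be illustrated with plain nested lists standing in for tensors. The key suffixes below (which weights arrive in the TensorFlow-style Conv1D layout and need a transpose, and which mask buffers are skipped) are assumptions modeled on common GPT-2 ports; check them against the actual checkpoint keys before relying on them. The function names are hypothetical.

```python
def shape(x):
    """Shape of a 1-D or 2-D nested list."""
    return (len(x), len(x[0])) if x and isinstance(x[0], list) else (len(x),)

def transpose(matrix):
    """Transpose a 2-D nested list."""
    return [list(col) for col in zip(*matrix)]

# Assumption: stored as (in, out) Conv1D weights, while nn.Linear expects (out, in).
TRANSPOSED = ("attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight")
# Assumption: mask buffers, not parameters, so they are not copied.
SKIPPED = (".attn.bias", ".attn.masked_bias")

def copy_weights(sd_hf, sd_ours):
    """Return a copy of sd_ours filled with the matching entries of sd_hf."""
    keys_hf = [k for k in sd_hf if not k.endswith(SKIPPED)]
    keys_ours = [k for k in sd_ours if not k.endswith(SKIPPED)]
    assert len(keys_hf) == len(keys_ours), f"mismatched keys: {len(keys_hf)} != {len(keys_ours)}"
    out = dict(sd_ours)
    for k in keys_hf:
        w = transpose(sd_hf[k]) if k.endswith(TRANSPOSED) else sd_hf[k]
        assert shape(w) == shape(sd_ours[k]), f"shape mismatch for {k}"
        out[k] = w
    return out
```

Note that ".attn.bias" (with the leading dot) does not match "c_attn.bias", so the real attention bias parameter is still copied while the mask buffer is skipped.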
00:30:44.960 | Okay, so we didn't crash. And so we can load the weights and the biases and everything else
00:30:53.680 | into our nn.Module. But now let's also get additional confidence that this is working,
00:30:58.080 | and let's try to actually generate from this model. Okay, now before we can actually generate
00:31:02.160 | from this model, we have to be able to forward it. We didn't actually write that code yet.
00:31:06.080 | So here's the forward function. So the input to the forward is going to be our indices,
00:31:12.640 | our token indices. And they are always of shape B by T. And so we have batch dimension of B,
00:31:20.720 | and then we have the time dimension of up to T. And the T can't be more than the block size.
00:31:27.120 | The block size is the maximum sequence length. So B by T indices are arranged in sort of like
00:31:32.720 | a two-dimensional layout. And remember that basically every single row of this is of size
00:31:38.000 | up to block size. And this is T tokens that are in a sequence. And then we have B independent
00:31:44.720 | sequences stacked up in a batch so that this is efficient. Now here we are forwarding the position
00:31:51.440 | embeddings and the token embeddings. And this code should be very recognizable from the previous
00:31:55.360 | lecture. So we basically use torch.arange, which is kind of like a version of range, but for PyTorch.
00:32:02.080 | And we're iterating from zero to T and creating this positions sort of indices.
00:32:08.880 | And then we are making sure that they're on the same device as IDX, because we're not going to
00:32:14.800 | be training on only CPU. That's going to be too inefficient. We want to be training on GPU,
00:32:18.640 | and that's going to come in a bit. Then we have the position embeddings and the token embeddings,
00:32:23.680 | and the addition operation of those two. Now notice that the position embeddings are going
00:32:28.240 | to be identical for every single row of input. And so there's broadcasting hidden inside this plus,
00:32:36.000 | where we have to create an additional dimension here, and then these two add up because the same
00:32:40.560 | position embeddings apply at every single row of our examples stacked up in a batch.
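A list-based sketch of this embedding step, assuming hypothetical tables wte (vocab size by C) and wpe (block size by C): the position vectors for positions 0 through T-1 are looked up once and added to every row of the batch, which is what the broadcasting inside the PyTorch plus does implicitly.

```python
def embed(idx, wte, wpe, block_size):
    """idx: B rows of T token ids. Returns B x T x C with tok_emb[b][t] + pos_emb[t]."""
    T = len(idx[0])
    assert T <= block_size, f"cannot forward sequence of length {T}, block size is {block_size}"
    pos = list(range(T))             # like torch.arange(0, T)
    pos_emb = [wpe[p] for p in pos]  # T x C, shared by every row in the batch
    return [
        [[te + pe for te, pe in zip(wte[tok], pos_emb[t])] for t, tok in enumerate(row)]
        for row in idx
    ]
```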
00:32:44.560 | Then we forward the transformer blocks, and finally the last layer norm and the lm_head.
00:32:51.120 | So what comes out after forward is the logits. And if the input was B by T indices,
00:32:57.120 | then at every single B by T, we will calculate the logits for what token comes next in the sequence.
00:33:05.200 | So what is the token at (b, t + 1), the one to the right of this token? And vocab size here
00:33:12.880 | is the number of possible tokens. And so therefore this is the tensor that we're going to obtain.
00:33:18.800 | And these logits are just a softmax away from becoming probabilities. So this is the forward
00:33:25.040 | pass of the network, and now we can get logits. And so we're going to be able to generate from
00:33:29.360 | the model imminently. Okay, so now we're going to try to set up the identical thing on the left here
00:33:35.440 | that matches HuggingFace on the right. So here we've sampled from the pipeline, and we sampled
00:33:41.120 | five times up to 30 tokens with the prefix of hello, I'm a language model. And these are the
00:33:46.560 | completions that we achieved. So we're going to try to replicate that on the left here.
00:33:49.920 | So num_return_sequences is five, max_length is 30. So the first thing we do, of course,
00:33:54.320 | is we initialize our model, then we put it into evaluation mode. Now, this is a good practice to
00:33:59.440 | put the model into eval when you're not going to be training it, you're just going to be using it.
00:34:03.200 | And I don't actually know if this is doing anything right now for the following reason.
00:34:08.720 | Our model up above here contains no modules or layers that actually have a different behavior
00:34:14.880 | at training or evaluation time. So for example, dropout, batch norm, and a bunch of other layers
00:34:19.600 | have this kind of behavior. But all of these layers that we've used here should be identical
00:34:23.520 | in both training and evaluation time. So potentially, model.eval does nothing. But
00:34:29.920 | then I'm not actually sure if this is the case. And maybe PyTorch internals do some clever things,
00:34:34.960 | depending on the evaluation mode inside here. The next thing we're doing here is we are moving
00:34:40.640 | the entire model to CUDA. So we're moving all of the tensors to GPU. So I'm SSH'd here to a
00:34:47.520 | cloud box, and I have a bunch of GPUs on this box. And here, I'm moving the entire model and all of
00:34:53.840 | its members and all of its tensors and everything like that. Everything gets shipped off to basically
00:34:59.280 | a whole separate computer that is sitting on the GPU. And the GPU is connected to the CPU,
00:35:05.040 | and they can communicate, but it's basically a whole separate computer with its own computer
00:35:08.240 | architecture. And it's really well catered to parallel processing tasks like those of
00:35:12.400 | running neural networks. So I'm doing this so that the model lives on the GPU, a whole separate
00:35:18.080 | computer. And it's just going to make our code a lot more efficient, because all of this stuff
00:35:22.720 | runs a lot more efficiently on the GPUs. So that's the model itself. Now, the next thing we want to
00:35:31.760 | do is we want to start with this as the prefix when we do the generation. So let's actually
00:35:37.920 | create those prefix tokens. So here's the code that I've written. We're going to import the
00:35:42.960 | tiktoken library from OpenAI, and we're going to get the GPT-2 encoding. So that's the tokenizer
00:35:49.120 | for GPT-2. And then we're going to encode this string and get a list of integers, which are the
00:35:55.680 | tokens. Now, these integers here should actually be fairly straightforward, because we can just
00:36:01.440 | copy/paste this string, and we can sort of inspect what it is in Tiktokenizer. So just pasting that
00:36:07.520 | in, these are the tokens that are going to come out. So this list of integers is what we expect
00:36:12.800 | tokens to become. And as you recall, if you saw my video, of course, all the tokens, they're just
00:36:18.400 | little string chunks, right? So this is the chunking of this string into GPT-2 tokens.
00:36:24.560 | So once we have those tokens, it's a list of integers, we can create a torch tensor out of it.
00:36:31.280 | In this case, it's eight tokens. And then we're going to replicate these eight tokens five
00:36:35.920 | times to get five rows of eight tokens. And that is our initial input X, as I call it here. And
00:36:45.120 | it lives on the GPU as well. So X now is this IDX that we can plug into forward to get our logits so
00:36:54.080 | that we know what comes as the ninth token in every one of these five
00:37:01.760 | rows. Okay, and we are now ready to generate. So let me paste in one more code block here.
00:37:06.880 | So what's happening here in this code block is we have this X, which is of size B by T, right? So
00:37:14.560 | batch by time. And in every iteration of this loop, we're going to be adding
00:37:19.600 | a column of new indices into each one of these rows, right? And so these are the new indices,
00:37:25.680 | and we're appending them to the sequence as we're sampling. So with each loop iteration,
00:37:30.560 | we get one more column into X. And all of the operations happen in the context manager of
00:37:35.920 | torch.no_grad(). This is just telling PyTorch that we're not going to be calling .backward() on
00:37:40.080 | any of this. So it doesn't have to cache all the intermediate tensors, it's not going to have to
00:37:44.480 | prepare in any way for a potential backward later. And this saves a lot of space and also possibly
00:37:50.560 | some time. So we get our logits, we get the logits at only the last location, we throw away all the
00:37:57.440 | other logits, we don't need them, we only care about the last column's logits. So this is being
00:38:03.360 | wasteful. But this is just kind of like an inefficient implementation of sampling. So it's
00:38:09.680 | correct but inefficient. So we get the last column of logits, pass it through softmax to get our
00:38:15.040 | probabilities. Then here I'm doing top-k sampling with k of 50. And I'm doing that because this is the
00:38:19.760 | HuggingFace default. So just looking at the HuggingFace docs here of a pipeline, there's a
00:38:26.880 | bunch of quirks that go into HuggingFace. And I mean, it's kind of a lot, honestly. But I guess
00:38:35.600 | the important one that I noticed is that they're using top-k by default, with k of 50, and
00:38:40.240 | that's being used here as well. And what top-k does is basically we want to
00:38:46.000 | take our probabilities, and we only want to keep the top 50 probabilities. And anything that is
00:38:51.280 | lower than the 50th probability, we just clamp to zero and renormalize. And so that way, we are
00:38:57.360 | never sampling very rare tokens. The tokens we're going to be sampling are always in the top 50 of
00:39:03.280 | most likely tokens. And this helps keep the model kind of on track, and it doesn't blabber on, and
00:39:08.080 | it doesn't get lost, and doesn't go off the rails as easily. And it kind of like sticks in the
00:39:13.520 | vicinity of likely tokens a lot better. So this is the way to do it in PyTorch. And you can step
00:39:18.480 | through it if you like. I don't think it's super insightful, so I'll speed through it. But roughly
00:39:22.480 | speaking, we get this new column of tokens. We append them on x. And basically the columns of x
00:39:29.280 | grow until this while loop gets tripped up. And then finally, we have an entire x of size
00:39:35.760 | 5 by 30, in this case, in this example. And we can just basically print all those individual
00:39:44.240 | rows. So I'm getting all the rows, I'm getting all the tokens that were sampled, and I'm using
00:39:49.760 | the decode function from tiktoken to get back the string, which we can print. So let me open a
00:39:56.480 | new terminal and run python train_gpt2.py.
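To make the sampling loop concrete, here is a row-by-row sketch on plain lists. The next_logits function is a hypothetical stand-in for "run the model and keep only the last column of logits"; in PyTorch the same top-k step is typically written with torch.topk and torch.multinomial over the whole batch at once, so treat this as an illustration of the logic rather than the code used here.

```python
import math
import random

def softmax(logits):
    """Turn a list of logits into probabilities that sum to one."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def top_k_sample(probs, k, rng):
    """Keep only the k most likely token ids, renormalize, and sample one of them."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in top)
    renormalized = [probs[i] / total for i in top]  # everything outside the top k is treated as zero
    return rng.choices(top, weights=renormalized, k=1)[0]

def generate(next_logits, prefix, num_return_sequences, max_length, k=50, seed=42):
    """Append one sampled token per row until every row has max_length tokens.

    next_logits(row) is a hypothetical stand-in for running the model on `row`
    and keeping only the logits at the last position.
    """
    rng = random.Random(seed)
    x = [list(prefix) for _ in range(num_return_sequences)]  # replicate the prefix per row
    while len(x[0]) < max_length:
        for row in x:
            row.append(top_k_sample(softmax(next_logits(row)), k, rng))
    return x
```

With a real model, prefix would be the eight token ids of "Hello, I'm a language model," and decoding each finished row with the tokenizer gives the printed completions.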
00:40:01.600 | Okay. So these are the generations that we're getting. Hello, I'm a language model. Not a
00:40:14.160 | program. New line, new line, et cetera. Hello, I'm a language model, and one of the main things
00:40:21.120 | that bothers me when they create languages is how easy it becomes to create something that -- I mean,
00:40:25.760 | so this will just like blabber on, right, in all these cases. Now, one thing you will notice is
00:40:29.680 | that these generations are not the same as the HuggingFace generations here. And I can't find the discrepancy,
00:40:36.640 | to be honest, and I didn't fully go through all these options, but probably there's something
00:40:40.480 | else hiding in addition to the top-k. So I'm not able to match it up. But just for correctness,
00:40:45.600 | down here below in the Jupyter Notebook, I'm using the HuggingFace model.
00:40:51.120 | So this is the HuggingFace model here. I replicated the code, and if I do this and I run that,
00:40:58.160 | then I am getting the same results. So basically, the model internals are not wrong. It's just I'm
00:41:05.520 | not 100% sure what the pipeline does in HuggingFace, and that's why we're not able to match
00:41:10.320 | them up. But otherwise, the code is correct, and we've loaded all the tensors correctly. So we're
00:41:16.720 | initializing the model correctly, and everything here works. So long story short, we've ported all
00:41:21.600 | the weights. We initialized the GPT-2. This is the exact OpenAI GPT-2, and it can generate
00:41:27.360 | sequences, and they look sensible. And now here, of course, we're initializing with GPT-2 model
00:41:33.040 | weights. But now we want to initialize from scratch, from random numbers, and we want to
00:41:37.680 | actually train the model that will give us sequences as good as or better than these ones
00:41:43.760 | in quality. And so that's what we turn to next. So it turns out that using the random model is
00:41:49.040 | actually fairly straightforward, because PyTorch already initializes our model randomly and by
00:41:53.520 | default. So when we create the GPT model in the constructor, all of these layers and modules
00:42:02.880 | have random initializers that are there by default. So when these linear layers get created and so on,
00:42:08.800 | there are default initializers, for example, using the Xavier initialization that we saw in the past
00:42:13.520 | to construct the weights of these layers. And so creating a random model instead of a GPT-2 model
00:42:19.920 | is actually fairly straightforward. And we would just come here, and instead we would create
00:42:25.120 | model = GPT(GPTConfig()), using the default config. And the default config uses
00:42:32.080 | the 124M hyperparameters. So this is the random model initialization, and we can run it.
00:42:43.120 | And we should be able to get results. Now, the results here, of course, are total garbage garble,
00:42:49.760 | and that's because it's a random model. And so we're just getting all these random token string
00:42:53.680 | pieces chunked up totally at random. So that's what we have right now. Now, one more thing I
00:42:59.360 | wanted to point out, by the way, is in case you do not have CUDA available, because you don't have
00:43:03.360 | a GPU, you can still follow along with what we're doing here to some extent. And probably not to the
00:43:10.320 | very end, because by the end, we're going to be using multiple GPUs and actually doing a serious
00:43:14.400 | training run. But for now, you can actually follow along decently okay. So one thing that I like to
00:43:19.440 | do in PyTorch is I like to auto detect the device that is available to you. So in particular, you
00:43:24.640 | could do that like this. So here we are trying to detect the device to run on that has the highest
00:43:31.680 | compute capability. You can think about it that way. So by default, we start with CPU, which of
00:43:36.560 | course is available everywhere, because every single computer will have a CPU. But then we can
00:43:41.280 | try to detect: do you have a GPU, so that you can use CUDA? And then if you don't have CUDA, do you
00:43:47.040 | at least have MPS? MPS is the backend for Apple Silicon. So if you have a MacBook that is fairly
00:43:52.720 | new, you probably have Apple Silicon on the inside. And then that has a GPU that is actually
00:43:56.960 | fairly capable, depending on which MacBook you have. And so you can use MPS, which will be
00:44:01.600 | potentially faster than CPU. And so we can print the device here. Now once we have the device,
00:44:06.640 | we can actually use it in place of CUDA. So we just swap it in. And notice that here when we
00:44:14.400 | call model on x, if this x here is on CPU instead of GPU, then it will work fine because here in the
00:44:22.640 | forward, which is where PyTorch will go, when we create pos, we are careful to use the device
00:44:28.880 | of IDX to create this tensor as well. And so there won't be any mismatch where one tensor is on CPU,
00:44:34.880 | one is on GPU, and you can't combine those. But here we are carefully initializing on the
00:44:40.960 | correct device as indicated by the input to this model. So this will auto detect device. For me,
00:44:48.080 | this will be, of course, GPU. So using device CUDA. But you can also run with, as I mentioned,
00:44:58.800 | another device. And it's not going to be too much slower. So if I override device here,
00:45:02.480 | if I override device equals CPU, then we'll still print CUDA, of course,
00:45:11.120 | but now we're actually using CPU. 1, 2, 3, 4, 5, 6. Okay, about six seconds. And actually,
00:45:21.280 | we're not using torch.compile and stuff like that, which will speed everything up a lot
00:45:24.640 | as well. But you can follow along even on a CPU, I think, to a decent extent.
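The fallback order described here (CPU by default, then CUDA, otherwise MPS) can be written as a small pure function, which keeps the selection logic testable on machines without a GPU. The function name is hypothetical; with PyTorch installed, the two flags would come from torch.cuda.is_available() and torch.backends.mps.is_available().

```python
def pick_device(cuda_available, mps_available):
    """Return the most capable available backend, falling back to the CPU."""
    device = "cpu"        # every machine has a CPU
    if cuda_available:
        device = "cuda"   # NVIDIA GPU
    elif mps_available:
        device = "mps"    # Apple Silicon GPU backend
    return device

# With PyTorch installed, a call might look like:
#   device = pick_device(torch.cuda.is_available(),
#                        hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
```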
00:45:29.280 | So that's a note on that. Okay, so I do want to loop around eventually into what it means to
00:45:35.440 | have different devices in PyTorch and what it is exactly that PyTorch does in the background for
00:45:40.080 | you when you do something like module.to(device), or where you take a torch tensor and do a
00:45:46.240 | .to(device), and what exactly happens and how that works. But for now, I'd like to get to training,
00:45:51.120 | and I'd like to start training the model. And for now, let's just say the device makes code go fast.
00:45:56.080 | And let's go into how we can actually train the model. So to train the model, we're going to need
00:46:01.360 | some data set. And for me, the best debugging simplest data set that I like to use is the
00:46:06.000 | tiny Shakespeare data set. And it's available at this URL, so you can wget it, or you can just
00:46:11.760 | search tiny Shakespeare data set. And so in my file system, it's just input.txt. So I already
00:46:20.160 | downloaded it. And here, I'm reading the data set, getting the first 1000 characters and printing the
00:46:25.600 | first 100. Now remember that the GPT-2 tokenizer has a compression
00:46:32.800 | ratio of roughly three to one. So 1000 characters is roughly 300 tokens here that will come out of
00:46:38.480 | this in the slice that we're currently getting. So this is the first few characters. And if you
00:46:46.320 | want to get a few more statistics on this, we can run wc on input.txt. So we can see that this
00:46:52.160 | is 40,000 lines, about 200,000 words in this data set, and about 1 million bytes in this file.
00:47:00.320 | And knowing that this file is only ASCII characters, there's no crazy Unicode here,
00:47:03.840 | as far as I know. And so every ASCII character is encoded with one byte. And so this is the same
00:47:09.360 | number, roughly a million characters inside this data set. So that's the data set size by default,
00:47:16.080 | very small and minimal data set for debugging. To get us off the ground, in order to tokenize
00:47:21.040 | this data set, we're going to get the tiktoken encoding for GPT-2, encode the data, the first
00:47:29.040 | 1000 characters, and then I'm only going to print the first 24 tokens. So these are the tokens as a
00:47:35.200 | list of integers. And if you can read GPT-2 tokens, you will see that 198 here, you'll recognize that
00:47:41.360 | as the \n character. So that is a new line. And then here, for example, we have two new lines,
00:47:46.080 | so that's 198 twice here. So this is just a tokenization of the first 24 tokens. So what
00:47:52.560 | we want to do now is we want to actually process these token sequences and feed them into a
00:47:57.760 | transformer. And in particular, we want them, we want to rearrange these tokens into this IDX
00:48:04.960 | variable that we're going to be feeding into the transformer. So we don't want a single very long
00:48:08.880 | one-dimensional sequence. We want an entire batch where each sequence is up to, is basically T
00:48:16.080 | tokens, and T cannot be larger than the maximum sequence length. And then we have these T long
00:48:23.680 | sequences of tokens, and we have B independent examples of sequences. So how can we create a
00:48:29.600 | B by T tensor that we can feed into the forward out of these one-dimensional sequences?
00:48:33.760 | So here's my favorite way to achieve this. So if we take Torch, and then we create a tensor object
00:48:41.280 | out of this list of integers and just the first 24 tokens, my favorite way to do this is basically
00:48:46.240 | you do a dot view of, for example, 4 by 6, which multiplied to 24. And so it's just a two-dimensional
00:48:54.720 | rearrangement of these tokens. And you'll notice that when you view this one-dimensional sequence
00:48:58.560 | as two-dimensional 4 by 6 here, the first six tokens up to here end up being the first row.
00:49:07.280 | The next six tokens here end up being the second row, and so on. And so basically, it's just going
00:49:13.200 | to stack up every six tokens, in this case, as independent rows, and it creates a batch of tokens
00:49:21.920 | in this case. And so for example, if we are token 25, in the transformer, when we feed this in and
00:49:28.160 | this becomes the IDX, this token is going to see these three tokens and is going to try to predict
00:49:34.000 | that 198 comes next. So in this way, we are able to create this two-dimensional batch. That's quite
00:49:41.440 | nice. Now, in terms of the label that we're going to need for the target to calculate the loss
00:49:46.720 | function, how do we get that? Well, we could write some code inside the forward pass because we know
00:49:51.760 | that the next token in a sequence, which is the label, is just to the right of us. But you'll
00:49:56.960 | notice that actually, for this token at the very end, 13, we don't actually have the next correct
00:50:02.960 | token because we didn't load it. So we actually didn't get enough information here. So I'll show
00:50:09.680 | you my favorite way of basically getting these batches. And I like to personally have not just
00:50:15.200 | the input to the transformer, which I like to call X, but I also like to create the labels tensor,
00:50:21.760 | which is of the exact same size as X but contains the targets at every single position. And so
00:50:27.360 | here's the way that I like to do that. I like to make sure that I fetch +1 token because we need
00:50:32.720 | the ground truth for the very last token, for 13. And then when we're creating the input, we take
00:50:39.920 | everything up to the last token, not including, and view it as 4 by 6. And when we're creating
00:50:45.360 | targets, we do the buffer, but starting at index 1, not index 0. So we're skipping the first element
00:50:52.720 | and we view it in the exact same size. And then when I print this,
00:50:58.000 | here's what happens, where we see that basically as an example for this token 25,
00:51:02.320 | its target was 198. And that's now just stored at the exact same position in the target tensor,
00:51:08.320 | which is 198. And also this last token 13 now has its label, which is 198. And that's just because
00:51:16.320 | we loaded this +1 here. So basically, this is the way I like to do it. You take long sequences,
00:51:23.040 | you view them in two-dimensional terms so that you get batch by time. And then we make sure
00:51:29.120 | to load one additional token. So we basically load a buffer of B times T plus one tokens. And then
00:51:36.080 | we sort of offset things and view them. And then we have two tensors. One of them is the input to
00:51:40.800 | the transformer, and the other exactly is the labels. And so let's now reorganize this code
00:51:47.120 | and create a very simple data loader object that tries to basically load these tokens and feed them
00:51:55.040 | to the transformer and calculate the loss. Okay, so I reshuffled the code here accordingly.
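The offset-by-one trick can be sketched without PyTorch: take B*T + 1 tokens, use everything but the last as inputs and everything but the first as targets, and cut both into B rows of T. In PyTorch this corresponds to buf[:-1].view(B, T) and buf[1:].view(B, T); the helper name get_batch is hypothetical.

```python
def get_batch(tokens, B, T, start=0):
    """Return (x, y) as B rows of T tokens each, with y shifted one position to the right."""
    buf = tokens[start:start + B * T + 1]  # the +1 supplies the label for the very last input
    assert len(buf) == B * T + 1, "not enough tokens for this batch"
    x_flat, y_flat = buf[:-1], buf[1:]
    x = [x_flat[i * T:(i + 1) * T] for i in range(B)]  # like buf[:-1].view(B, T)
    y = [y_flat[i * T:(i + 1) * T] for i in range(B)]  # like buf[1:].view(B, T)
    return x, y
```

Every y[b][t] is the token that follows x[b][t] in the original stream, including the last position of each row.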
00:51:59.600 | So as you can see here, I'm temporarily overriding to run on CPU. And importing tiktoken,
00:52:06.400 | and all of this should look familiar. We're loading 1,000 characters. I'm setting B and T to just
00:52:10.720 | be 4 and 32 right now, just because we're debugging. We just want to have a single batch that's very
00:52:15.680 | small. And all of this should now look familiar and follows what we did on the right. And then
00:52:21.040 | here we create the model and get the logits. And so here, as you see, I already ran this. It only
00:52:29.280 | runs in a few seconds. But because we have a batch of 4 by 32, our logits are now of size 4 by 32 by
00:52:37.440 | 50,257. So those are the logits for what comes next at every position. And now we have the labels,
00:52:44.480 | which are stored in Y. So now is the time to calculate the loss, and then do the backward
00:52:48.960 | pass, and then do the optimization. So let's first calculate the loss. Okay, so to calculate the loss,
00:52:55.040 | we're going to adjust the forward function of this NN module in the model. And in particular,
00:52:59.840 | we're not just going to be returning logits, but also we're going to return the loss.
00:53:03.040 | And we're going to not just pass in the input indices, but also the targets in Y. And now we
00:53:10.640 | will print not logits.shape anymore, we're actually going to print the loss function,
00:53:15.600 | and then sys.exit(0) so that we skip some of the sampling logic.
00:53:19.200 | So now let's swing up to the forward function, which gets called there, because now we also
00:53:25.920 | have these optional targets. And when we get the targets, we can also calculate the loss. And
00:53:32.800 | remember that we want to basically return logits, loss, and loss by default is None. But
00:53:38.720 | let's put this here. If targets is not None, then we want to calculate the loss. And Copilot is
00:53:50.320 | already getting excited here and calculating the what looks to be correct loss. It is using the
00:53:55.600 | cross-entropy loss as is documented here. So this is a function in PyTorch under torch.nn.functional.
00:54:03.440 | Now, what is actually happening here, because it looks a little bit scary,
00:54:07.280 | basically F.cross_entropy does not like multi-dimensional inputs. It can't take a b by t
00:54:13.280 | by vocab size. So what's happening here is that we are flattening out this three-dimensional tensor
00:54:18.720 | into just two dimensions. The first dimension is going to be calculated automatically, and it's
00:54:23.120 | going to be b times t. And then the last dimension is vocab size. So basically, this is flattening
00:54:29.760 | out this three-dimensional tensor of logits to just be two-dimensional, b times t, all individual
00:54:35.600 | examples, and vocab size in terms of the length of each row. And then it's also flattening out
00:54:42.640 | the targets, which are also two-dimensional at this stage, but we're going to just flatten them
00:54:47.120 | out so they're just a single tensor of b times t. And this can then pass into cross-entropy to
00:54:52.080 | calculate a loss, which we return. So this should basically, at this point, run, because it's not
00:54:58.000 | too complicated. So let's run it, and let's see if we should be printing the loss.
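What F.cross_entropy computes after the flattening can be spelled out by hand. This sketch assumes logits as B by T by V nested lists and integer targets as B by T; in PyTorch the call is usually written as F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)).

```python
import math

def cross_entropy(logits, targets):
    """Mean negative log likelihood of the targets.

    logits: B x T x V nested lists; targets: B x T token ids.
    """
    flat_logits = [row for seq in logits for row in seq]  # like logits.view(-1, V)
    flat_targets = [t for seq in targets for t in seq]    # like targets.view(-1)
    total = 0.0
    for row, t in zip(flat_logits, flat_targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        total += log_z - row[t]                                  # -log softmax(row)[t]
    return total / len(flat_targets)
```

With all-zero logits every token gets probability 1/V, so the loss is ln V, which is the sanity check discussed next.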
00:55:04.320 | And here we see that we printed 11, roughly. And notice that this is the tensor of a single
00:55:19.280 | element, which is this number 11. Now, we also want to be able to calculate a reasonable
00:55:23.840 | kind of starting point for a randomly initialized network. So we covered this in previous videos,
00:55:29.040 | but our vocabulary size is 50,257. At initialization of the network, you would hope
00:55:35.360 | that every vocab element is getting roughly a uniform probability, so that we're not favoring,
00:55:42.960 | at initialization, any token way too much. We're not confidently wrong at initialization. So we're
00:55:49.120 | hoping is that the probability of any arbitrary token is roughly 1/50,257. And now we can sanity
00:55:56.160 | check the loss, because remember that the cross-entropy loss is just basically the negative
00:56:00.320 | log likelihood. So if we now take this probability, and we take it through the natural logarithm,
00:56:07.120 | and then we do the negative, that is the loss we expect at initialization, and we covered this in
00:56:13.120 | previous videos. So I would expect something around 10.82, and we're seeing something around
00:56:18.080 | 11. So it's not way off. This is roughly the loss I expect at initialization.
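The sanity check above, written out as arithmetic for a uniform distribution over the 50,257-token vocabulary:

```latex
p_{\text{init}} \approx \frac{1}{50257}, \qquad
\mathcal{L}_{\text{init}} = -\ln p_{\text{init}} = \ln 50257 \approx 10.82
```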
00:56:22.640 | So that tells me that at initialization, our probability distribution is roughly diffuse,
00:56:27.520 | it's a good starting point, and we can now perform the optimization and tell the network which
00:56:33.120 | elements should follow correctly in what order. So at this point, we can do loss.backward(),
00:56:39.120 | calculate the gradients, and do an optimization. So let's get to that.
00:56:42.720 | Okay, so let's do the optimization now. So here we have the loss, this is how we get the loss.
00:56:51.120 | But now basically we want a little for loop here. So for i in range, let's do 50 steps or
00:56:56.560 | something like that. Let's create an optimizer object in PyTorch. And so here we are using the
00:57:04.320 | Adam optimizer, which is an alternative to the stochastic gradient descent optimizer, SGD,
00:57:10.320 | that we were using. So SGD is a lot simpler, Adam is a bit more involved. And I actually
00:57:14.320 | specifically like the AdamW variation, because in my opinion, it kind of just fixes a bug.
00:57:19.920 | So AdamW is a bug fix of Adam, is what I would say. When we go to the documentation for AdamW,
00:57:26.800 | oh my gosh, we see that it takes a bunch of hyperparameters, and it's a little bit more
00:57:34.080 | complicated than the SGD we were looking at before. Because in addition to basically updating
00:57:39.120 | the parameters with the gradient scaled by the learning rate, it keeps these buffers around,
00:57:44.000 | and it keeps two buffers, the m and the v, which it calls the first and the second moment.
00:57:49.280 | So something that looks a bit like momentum and something that looks a bit like RMSProp,
00:57:53.120 | if you're familiar with it. But you don't have to be, it's just kind of like a normalization
00:57:57.120 | that happens on each gradient element individually, and speeds up the optimization,
00:58:01.600 | especially for language models. But I'm not going to go into the detail right here.
00:58:05.360 | We're going to treat this as a bit of a black box, and it just optimizes the objective faster than
00:58:12.000 | SGD, which is what we've seen in the previous lectures. So let's use it as a black box in our
00:58:16.560 | case. Create the optimizer object, and then go through the optimization.
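For anyone who wants to peek inside the black box, here is a minimal scalar sketch of the AdamW update as described in the PyTorch documentation for `torch.optim.AdamW`: decoupled weight decay, then the first-moment (m) and second-moment (v) buffers with bias correction. The class is hypothetical and works on plain Python floats, with gradients passed in explicitly. `lr=3e-4` mirrors the learning rate used in this section; the betas, eps, and weight-decay defaults are assumed to match PyTorch's.

```python
import math

class AdamW:
    """Scalar AdamW sketch: decoupled weight decay plus Adam's m and v buffers."""

    def __init__(self, params, lr=3e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        self.params = list(params)            # plain floats, updated by step()
        self.lr, self.eps, self.wd = lr, eps, weight_decay
        self.b1, self.b2 = betas
        self.m = [0.0] * len(self.params)     # first moment: momentum-like running mean of g
        self.v = [0.0] * len(self.params)     # second moment: RMSProp-like running mean of g^2
        self.t = 0                            # step counter for bias correction

    def step(self, grads):
        self.t += 1
        for i, g in enumerate(grads):
            p = self.params[i] * (1 - self.lr * self.wd)   # decoupled weight decay
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)    # undo the zero-init bias
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            # per-element normalization: each parameter gets its own effective step size
            self.params[i] = p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

# Usage: minimize (x - 3)^2 starting from x = 0; the gradient is 2 * (x - 3).
opt = AdamW([0.0], lr=0.1, weight_decay=0.0)
for _ in range(500):
    x = opt.params[0]
    opt.step([2 * (x - 3.0)])
print(opt.params[0])
```

The division by sqrt(v_hat) is the per-element normalization mentioned above. It is why Adam-style optimizers are less sensitive than plain SGD to the raw scale of each gradient entry.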
00:58:23.680 | The first thing to always make sure of is that Copilot did not forget to zero the gradients.
00:58:33.440 | So always remember that you have to start with a zero gradient. Then when you get your loss,
00:58:38.880 | and you do a .backward(), .backward() adds to the gradients. So it deposits gradients. It always
00:58:44.800 | does a plus-equals on whatever the gradients are, which is why you must set them to zero.
00:58:48.640 | So this accumulates the gradient from this loss, and then we call the step function on the optimizer
00:58:55.280 | to update the parameters, and to decrease the loss. Then we print the step, and
00:59:03.760 | loss.item() is used here, because loss is a tensor with a single element. Calling .item() will actually
00:59:09.680 | convert that to a single float, and this float will live on the CPU. So this gets to some of
00:59:16.880 | the internals, again, of the devices, but loss is a tensor with a single element, and it lives on
00:59:22.800 | GPU for me, because I'm using GPUs. When you call .item(), PyTorch behind the scenes will take that
00:59:29.120 | single-element tensor, ship it back to the CPU memory, and convert it into a float that we can
00:59:34.240 | just print. So this is the optimization, and this should probably just work. Let's see what happens.
00:59:43.680 | Actually, sorry. Instead of using the CPU override, let me delete that so this is a bit faster for me,
00:59:51.600 | and it runs on CUDA. Oh, expected all tensors to be on the same device, but found at least two
01:00:02.080 | devices, cuda:0 and cpu. So cuda:0 is the 0th GPU, because I actually have eight GPUs on this box.
01:00:09.280 | So the 0th GPU on my box, and the CPU. And the model we have moved to the device, but when I was writing this
01:00:17.440 | code, I actually introduced a bug, because buf we never moved to the device. And you have to be
01:00:22.640 | careful, because you can't just do buf.to(device). It's not stateful. It doesn't convert
01:00:30.240 | the tensor in place to be on the device. It instead returns a new tensor, in new memory, which is on the device. So you
01:00:36.960 | see how we can just do model.to(device), but that does not apply to tensors. You have to do
01:00:41.360 | buf = buf.to(device), and then this should work. Okay. So what do we expect to see?
01:00:51.920 | We expect to see a reasonable loss in the beginning, and then we continue to optimize
01:00:55.600 | just a single batch. And so we want to see that we can overfit this single batch. We can crush
01:01:00.720 | this little batch, and we can perfectly predict the indices on just this little batch. And
01:01:05.520 | indeed, that is roughly what we're seeing here. So we started off at roughly 10.82, 11 in this case,
01:01:13.440 | and then as we continue optimizing on this single batch without loading new examples,
01:01:17.040 | we are making sure that we can overfit a single batch, and we are getting to very,
01:01:20.800 | very low loss. So the transformer is memorizing this single individual batch.
01:01:25.280 | And one more thing I didn't mention is the learning rate here is 3e-4,
01:01:30.240 | which is a pretty good default for most optimizations that you want to run at a very
01:01:35.360 | early debugging stage. So this is our simple inner loop, and we are overfitting a single batch,
01:01:42.560 | and this looks good. So now what comes next is we don't just want to overfit a single batch,
01:01:47.200 | we actually want to do an optimization. So we actually need to iterate these x, y batches
01:01:52.000 | and create a little data loader that makes sure that we're always getting a fresh batch,
01:01:56.560 | and that we're actually optimizing a reasonable objective. So let's do that next.
01:02:00.240 | Okay, so this is what I came up with, and I wrote a little DataLoaderLite.
01:02:03.360 | So what this data loader does is we're importing tiktoken up here,
01:02:08.320 | reading the entire text file from this single input.txt, tokenizing it, and then we're just
01:02:14.720 | printing the number of tokens in total, and the number of batches in a single epoch of iterating
01:02:20.800 | over this dataset. So how many unique batches do we output before we loop back around to the
01:02:25.840 | beginning of the document and start reading it again? So we start off at position 0, and then
01:02:31.600 | we simply walk the document in batches of B times T. So we take chunks of B times T, and then always
01:02:37.920 | advance by B times T. And it's important to note that we're always advancing our position by exactly
01:02:44.800 | B times T, but when we're fetching the tokens, we're actually fetching from the current position
01:02:49.840 | up to the current position plus B times T plus 1. And we need that plus 1 because remember, we need the target token for
01:02:56.960 | the last token in the current batch. And so that way we can do the x and y exactly as we did it before.
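The chunking logic can be sketched without PyTorch. This is a simplified stand-in, not the loader from the video: it assumes the tokens are already a flat Python list (the real loader tokenizes input.txt with tiktoken and returns tensors). The class name follows this section; the list-based batches and the `batches_per_epoch` helper are illustrative. Note the B*T + 1 slice and the wrap-around to position 0.

```python
class DataLoaderLite:
    """Walks a flat list of token ids in chunks of B*T, wrapping to the start."""

    def __init__(self, tokens, B, T):
        self.tokens, self.B, self.T = tokens, B, T
        self.current_position = 0

    def batches_per_epoch(self):
        return len(self.tokens) // (self.B * self.T)

    def next_batch(self):
        B, T = self.B, self.T
        # fetch B*T + 1 tokens: the extra one is the target for the last input
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = [buf[r * T : (r + 1) * T] for r in range(B)]          # inputs, B rows of T
        y = [buf[r * T + 1 : (r + 1) * T + 1] for r in range(B)]  # targets, shifted by one
        self.current_position += B * T  # advance by exactly B*T
        # if the next batch would run past the end, loop back around to position 0
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y

loader = DataLoaderLite(list(range(20)), B=2, T=4)
x, y = loader.next_batch()
print(x, y)
```

Each row of `y` is the matching row of `x` shifted left by one token, which is exactly the next-token target layout the loss expects.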
01:03:04.560 | And if we are to run out of data, we'll just loop back around to 0. So this is one way to write a
01:03:12.640 | very, very simple data loader that simply just goes through the file in chunks. And it's good
01:03:18.960 | enough for us for current purposes. And we're going to complexify it later. And now we'd like
01:03:24.800 | to come back around here, and we'd like to actually use our data loader. So the import
01:03:28.800 | tiktoken has moved up. And actually, all of this is now useless. So instead, we just want a train
01:03:34.960 | loader for the training data. And we want to use the same hyperparameters as before. So batch size was
01:03:41.600 | 4 and time was 32. And then here, we need to get the x, y for the current batch. So let's see if
01:03:49.120 | Copilot gets it, because this is simple enough. So we call the next batch. And then we make sure
01:03:56.080 | that we have to move our tensors from CPU to the device. So here, when I converted the tokens,
01:04:05.040 | notice that I didn't actually move these tokens to the GPU. I left them on the CPU, which is default.
01:04:11.600 | And that's just because I'm trying not to waste too much memory on the GPU. In this case,
01:04:16.320 | this is a tiny data set that would fit. But it's fine to just ship it to the GPU
01:04:22.320 | for our purposes right now. So we get the next batch. We keep the data loader a simple CPU class. And then
01:04:28.880 | here, we actually ship it to the GPU and do all the computation. And let's see if this runs.
01:04:35.920 | So python train_gpt2.py. And what do we expect to see before this actually happens?
01:04:42.240 | What we expect to see is now we're actually getting the next batch. So we expect to not
01:04:46.560 | overfit a single batch. And so I expect our loss to come down, but not too much. And
01:04:54.240 | I still expect it to come down because among the 50,257 tokens, many of those tokens never occur
01:05:00.480 | in our data set. So there are some very easy gains to be made here in the optimization by,
01:05:05.280 | for example, taking the biases of all the logits that never occur and driving them to negative
01:05:09.920 | infinity. That's because all of these crazy Unicode tokens, or tokens from different
01:05:14.800 | languages, never occur. So their probability should be very low. And so the gains
01:05:19.360 | that we should be seeing are along the lines of basically deleting the usage of tokens that never
01:05:24.880 | occur. That's probably most of the loss gain that we're going to see at this scale right now.
01:05:29.920 | But we shouldn't come down to zero because we are only doing 50 iterations. And I don't think that's
01:05:37.360 | enough to do an epoch right now. So let's see what we got. We have 338,000 tokens, which makes sense
01:05:46.320 | with our three to one compression ratio, because there are 1 million characters. So one epoch with
01:05:52.640 | the current setting of B and T will take 2,600 batches. And we're only doing 50 batches of
01:05:58.880 | optimization in here. So we start off in familiar territory as expected, and then we seem
01:06:05.600 | to come down to about 6.6. So basically, things seem to be working okay right now with respect
01:06:12.160 | to our expectations. So that's good. Okay, next, I want to actually fix a bug that we have in our
01:06:17.520 | code. It's not a major bug, but it is a bug with respect to how GPT-2 training should happen.
01:06:26.240 | So the bug is the following. We were not being careful enough when we were loading the weights
01:06:30.480 | from Hugging Face, and we actually missed a little detail. So if we come here,
01:06:34.400 | notice that the shape of these two tensors is the same. So this one here is the token embedding at
01:06:43.520 | the bottom of the transformer. And this one here is the language modeling head at the top of the
01:06:50.560 | transformer. And both of these are basically two-dimensional tensors, and their shape is
01:06:56.080 | identical. So here, the first one is the input embedding, the token embedding, and the second
01:07:02.240 | one is this linear layer at the very top, the classifier layer. Both of them are of shape
01:07:07.520 | 50,257 by 768. This one here is giving us our token embeddings at the bottom, and this one
01:07:16.080 | here is taking the 768 channels of the transformer and trying to upscale that to 50,257 to get the
01:07:23.680 | logits for the next token. So they're both the same shape, but more than that, actually, if you look
01:07:29.520 | at comparing their elements, in PyTorch, this is an element-wise equality. So then we use .all(),
01:07:37.360 | and we see that every single element is identical. And more than that, we see that if we actually
01:07:42.880 | look at the data_ptr(), this is a way in PyTorch to get the actual pointer to the data
01:07:49.280 | and the storage, we see that actually the pointer is identical. So not only are these two separate
01:07:54.480 | tensors that happen to have the same shape and elements, they're actually pointing to the
01:07:58.480 | identical tensor. So what's happening here is that this is a common weight-tying scheme
01:08:04.480 | that actually comes from the original "Attention is all you need" paper,
01:08:11.360 | and actually even the reference before it. So if we come here...
01:08:15.040 | In the "Embeddings and Softmax" section of the "Attention is all you need" paper, they mention that "in our model,
01:08:25.920 | we share the same weight matrix between the two embedding layers and the pre-softmax linear
01:08:30.720 | transformation, similar to [30]." So this is an awkward way to phrase that these two are shared
01:08:37.520 | and they're tied and they're the same matrix. And the 30 reference is this paper. So this came out
01:08:44.480 | in 2017. And you can read the full paper, but basically it argues for this weight-tying scheme.
01:08:50.400 | And I think intuitively the idea for why you might want to do this
01:08:54.880 | comes from this paragraph here. And basically, you can observe that
01:09:03.280 | you actually want these two matrices to behave similarly in the following sense. If two tokens
01:09:10.080 | are very similar semantically, like maybe one of them is all lowercase and the other one is
01:09:14.560 | all uppercase, or it's the same token in a different language or something like that,
01:09:18.240 | if you have similarity between two tokens, presumably you would expect that they are
01:09:22.160 | nearby in the token embedding space. But in the exact same way, you'd expect that if you
01:09:27.520 | have two tokens that are similar semantically, you'd expect them to get the same probabilities
01:09:33.120 | at the output of a transformer because they are semantically similar.
01:09:36.080 | And so both positions in the transformer at the very bottom and at the top have this property
01:09:43.680 | that similar tokens should have similar embeddings or similar weights. And so this is what motivates
01:09:50.400 | their exploration here. I don't want to go through the entire paper,
01:09:54.400 | and you can go through it yourself, but this is what they observe. They also observe that if you look at the
01:10:00.240 | output embeddings, they also behave like word embeddings. If you just kind of try to use those
01:10:06.880 | weights as word embeddings. So they kind of observe this similarity, they try to tie them,
01:10:12.800 | and they observe that they can get much better performance in that way. And so this was adopted
01:10:18.160 | in the "Attention is all you need" paper, and then it was used again in GPT-2 as well. So I couldn't find
01:10:25.760 | it in the Hugging Face transformers implementation. I'm not sure where they tie those embeddings,
01:10:30.080 | but I can find it in the original GPT-2 code released by OpenAI. So this is the OpenAI GPT-2
01:10:37.840 | src/model.py. And here is where they are forwarding this model, and this is in TensorFlow, but
01:10:42.880 | that's okay. We see that they get the wte token embeddings. And then here the input to the transformer is the sum of the
01:10:50.480 | token embeddings and the position embeddings. And then here at the bottom, they use the wte again to do the
01:10:57.040 | logits. So when they get the logits, it's a matmul of this output from the transformer and the wte
01:11:03.920 | tensor, so the wte tensor is reused. And so the wte tensor basically is used twice: at the bottom of the transformer
01:11:10.880 | and at the top of the transformer. And in the backward pass, we'll get gradient contributions
01:11:16.640 | from both branches, right? And these gradients will add up on the wte tensor. So we'll get a
01:11:23.760 | contribution from the classifier layer, and then, at the very end of the backward pass, we'll get a
01:11:27.760 | contribution from the bottom of the transformer, flowing again into the wte tensor. So we are currently not
01:11:36.960 | sharing wte in our code, but we want to do that. So, a weight sharing scheme. And one way to do this,
01:11:48.720 | let's see if Copilot gets it. Oh, it does. Okay. So this is one way to do it. Basically,
01:11:56.720 | relatively straightforward. What we're doing here is we're taking the wte.weight and we're simply
01:12:03.920 | redirecting it to point to the lm_head.weight. So this basically copies the data pointer,
01:12:12.240 | right? It copies the reference. And now the old value of wte.weight becomes orphaned,
01:12:18.800 | and Python will clean it up. And so we are only left with a single
01:12:25.440 | tensor, and it's going to be used twice in the forward pass. And this is, to my knowledge,
01:12:32.720 | all that's required. So we should be able to use this, and this should probably train.
01:12:37.440 | We're just going to basically be using this exact same tensor twice. And we weren't being careful
01:12:45.920 | with tracking the likelihoods, but according to the paper and according to the results,
01:12:50.240 | you'd actually expect slightly better results doing this. And in addition to that, one other
01:12:54.480 | reason that this is very, very nice for us is that this is a ton of parameters, right?
01:12:59.600 | What is the size here? It's 768 times 50,257, so this is about 38.6 million, roughly 40 million parameters.
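The saving is easy to check. Here is a quick sketch that takes the model size as a round 124 million, which is the nominal figure rather than an exact parameter count:

```python
vocab_size, n_embd = 50257, 768
nominal_params = 124_000_000          # "124M", treated as a round number here

shared = vocab_size * n_embd          # one (50257, 768) matrix used by both wte and lm_head
fraction = shared / nominal_params
print(f"{shared:,} shared parameters, about {fraction:.0%} of the model")
```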
01:13:07.200 | And this is a 124 million parameter model. So 40 divided by 124. So this is like 30% of the
01:13:14.000 | parameters are being saved using this weight tying scheme. And so this might be one of the
01:13:19.440 | reasons that this is working slightly better. If you're not training the model long enough,
01:13:23.120 | because of the weight tying, you don't have to train as many parameters. And so you become more
01:13:27.760 | efficient in terms of the training process, because you have fewer parameters and you're
01:13:33.200 | putting in this inductive bias that these two embeddings should share similarities between
01:13:38.560 | tokens. So this is the weight tying scheme, and we've saved a ton of parameters. And we expect
01:13:44.400 | our model to work slightly better because of this scheme. Okay, next, I would like us to be a bit
01:13:48.480 | more careful with the initialization and to try to follow the way GPT-2 initialized their model.
01:13:53.680 | Now, unfortunately, the GPT-2 paper and the GPT-3 paper are not very explicit about
01:13:58.640 | initialization. So we kind of have to read between the lines. And instead of going to the paper,
01:14:03.280 | which is quite vague, there's a bit of information in the code that OpenAI released. So when we go
01:14:09.040 | to the model.py, we see that when they initialize their weights, they are using the standard
01:14:14.560 | deviation of 0.02. So this is a normal distribution for the weights,
01:14:21.600 | and the standard deviation is 0.02. For the bias, they initialize that with zero.
01:14:26.560 | And then when we scroll down here, why is this not scrolling? The token embeddings are
01:14:35.920 | initialized at 0.02, and position embeddings at 0.01 for some reason. So those are the
01:14:42.320 | initializations, and we'd like to mirror that in our GPT-2 module here. So here's a snippet of
01:14:48.720 | code that I sort of came up with very quickly. So what's happening here is at the end of our
01:14:56.400 | initializer for the GPT module, we're calling the apply function of nn.Module, and that iterates all
01:15:01.920 | the sub-modules of this module, and applies the _init_weights function on them. And so what's
01:15:09.440 | happening here is that we're iterating all the modules here, and if they are an nn.Linear module,
01:15:16.560 | then we're going to make sure to initialize the weight using a normal with a standard deviation
01:15:20.640 | of 0.02. If there's a bias in this layer, we will make sure to initialize that to zero. Note that
01:15:27.200 | zero initialization for the bias is not actually the PyTorch default.
01:15:30.160 | By default, the bias here is initialized with a uniform distribution, so that's interesting. So we make sure
01:15:37.520 | to use zero. And for the embedding, we're just going to use 0.02 and keep it the same. So we're
01:15:44.080 | not going to change it to 0.01 for positional, because it's about the same. And then if you look
01:15:48.960 | through our model, the only other layer that requires initialization, and that has parameters,
01:15:53.520 | is the layer norm. And the PyTorch default initialization sets the scale in the layer norm
01:15:58.240 | to be one, and the offset in the layer norm to be zero. So that's exactly what we want,
01:16:02.560 | and so we're just going to keep it that way. And so this is the default initialization
01:16:08.480 | if we are following the GPT-2 source code that they released.
01:16:16.240 | I would like to point out, by the way, that typically the standard deviation here on this
01:16:21.760 | initialization, if you follow the Xavier initialization, would be one over the square
01:16:25.440 | root of the number of features that are incoming into this layer. But if you'll notice, actually,
01:16:31.120 | 0.02 is basically consistent with that, because the d_model sizes inside these transformers for
01:16:36.800 | GPT-2 are roughly 768, 1600, etc. So one over the square root of, for example, 768 gives us about 0.036.
01:16:44.720 | If we plug in 1600, we get 0.025. If we plug in three times that, 0.014, etc. So basically 0.02
01:16:55.520 | is roughly in the vicinity of reasonable values for these initializations anyway. So it's not
01:17:04.640 | completely crazy to be hard coding 0.02 here, but you'd like typically something that grows with the
01:17:12.160 | model size instead. But we will keep this because that is the GPT-2 initialization per their source
01:17:17.120 | code. But we are not fully done yet on initialization, because there's one more caveat
01:17:21.120 | here. So here it says: "A modified initialization which accounts for the accumulation on the residual
01:17:27.520 | path with model depth is used. We scale the weights of residual layers at initialization
01:17:32.160 | by a factor of 1/√N where N is the number of residual layers."
01:17:35.520 | So this is what the GPT-2 paper says. We have not implemented that yet, and we can do so now.
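The growth that this scaling compensates for can be simulated in plain Python. Below is a minimal sketch of the kind of experiment walked through next; the helper name is hypothetical. It starts a 768-wide residual stream at zero, adds 100 standard-normal contributions, and measures the spread, once as-is and once with each contribution scaled by n^-0.5.

```python
import random
import statistics

def residual_stream_std(n=100, width=768, scale=1.0, seed=0):
    """Std of a residual stream after n additions of scale * N(0, 1) noise."""
    rng = random.Random(seed)
    x = [0.0] * width
    for _ in range(n):
        x = [xi + scale * rng.gauss(0.0, 1.0) for xi in x]  # x = x + contribution
    return statistics.pstdev(x)

n = 100
print(residual_stream_std(n))                   # grows like sqrt(n), i.e. about 10
print(residual_stream_std(n, scale=n ** -0.5))  # stays about 1
```

Because the variances of independent contributions add, n unit-variance additions give a standard deviation of sqrt(n). Scaling each one by n^-0.5 cancels that growth.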
01:17:41.760 | Now, I'd like to actually kind of like motivate a little bit what they mean here, I think. So
01:17:47.440 | here's roughly what they mean. If you start out with zeros in your residual stream, remember that
01:17:54.320 | the residual stream is of this form, where we continue adding to it: x = x + something,
01:18:01.520 | some kind of contribution. So every single block of the residual network contributes some
01:18:07.040 | amount, and it gets added. And so what ends up happening is that the variance of the activations
01:18:15.600 | in the residual stream grows. So here's a small example. We start at zero:
01:18:21.840 | we have this residual stream of 768 zeros, and then 100 times, we add a random vector,
01:18:31.040 | drawn from a normal distribution with zero mean and unit standard deviation. If we keep adding to it, then by the
01:18:37.120 | end, the residual stream has grown to have standard deviation of 10. And that's just because we're
01:18:43.600 | always adding these numbers. And so this scaling factor that they use here exactly compensates for
01:18:52.080 | that growth. So if we take n, and we basically scale down every one of these contributions into
01:18:59.600 | the residual stream by one over the square root of n. So one over the square root of n is n to the
01:19:05.200 | negative 0.5, right? Because n to the 0.5 is the square root, and then one over the square root is
01:19:12.800 | n to the negative 0.5. If we scale it in this way, then we see that we actually get one. So this is a way
01:19:23.120 | to control the growth of activations inside the residual stream in the forward pass. And so we'd
01:19:29.040 | like to initialize in the same way, where these weights that are at the end of each block, so this
01:19:34.960 | c_proj layer, the GPT-2 paper proposes to scale down those weights by one over the square root of the
01:19:42.240 | number of residual layers. So one crude way to implement this is the following. I don't know if
01:19:48.240 | this is PyTorch-sanctioned, but it works for me, is that in the initialization we set
01:19:56.320 | self.c_proj.NANOGPT_SCALE_INIT = 1. So we're setting kind of like a flag for this module.
01:20:06.880 | There must be a better way in PyTorch, right? But I don't know. Okay, so we're basically
01:20:13.200 | attaching this flag and trying to make sure that it doesn't conflict with anything previously.
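The standard deviation this flag is meant to select can be sketched as a small pure-Python function. This is an illustration, not the module's actual `_init_weights`. The helper `init_std` and its `is_residual_proj` argument are hypothetical, standing in for the `NANOGPT_SCALE_INIT` check. The function applies the 0.02 base and the (2 * n_layer)^-0.5 factor for the projections that write into the residual stream. For comparison, `xavier_like_std` reproduces the 1/sqrt(d_model) values discussed earlier.

```python
import math

def init_std(n_layer, is_residual_proj, base_std=0.02):
    """Std for a Linear weight: GPT-2's 0.02, scaled down for residual projections."""
    std = base_std
    if is_residual_proj:
        # each block adds to the residual stream twice (attention, then MLP)
        std *= (2 * n_layer) ** -0.5
    return std

def xavier_like_std(d_model):
    """1/sqrt(fan_in), for comparison with the hard-coded 0.02."""
    return 1 / math.sqrt(d_model)

n_layer = 12  # GPT-2 (124M) has 12 layers
print(init_std(n_layer, False), init_std(n_layer, True))
print([round(xavier_like_std(d), 3) for d in (768, 1600, 4800)])
```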
01:20:17.920 | And then when we come down here, this std should be 0.02 by default.
01:20:24.880 | But then if hasattr(module, 'NANOGPT_SCALE_INIT'), then std *= something.
01:20:33.600 | Copilot is not guessing correctly. So we want one over the square root of the number of layers.
01:20:42.320 | So the number of residual layers here is 2 * self.config.n_layer, and then this is raised to the power of
01:20:54.560 | negative 0.5. So we want to scale down that standard deviation, and this should be correct
01:21:01.360 | and implement that. I should clarify, by the way, that the two times number of layers comes from the
01:21:06.000 | fact that every single one of our layers in the transformer actually has two blocks that add to
01:21:10.800 | the residual pathway, right? We have the attention and then the MLP. So that's where the two times
01:21:15.280 | comes from. And the other thing to mention is that what's slightly awkward, but we're not going to
01:21:21.520 | fix it, is that because we are weight sharing the wte and the lm_head, in this iteration over all our
01:21:29.600 | submodules, we're going to actually come around to that tensor twice. So we're going to first
01:21:33.920 | initialize it as an embedding with 0.02, and then we're going to come back around to it again in the
01:21:38.800 | linear and initialize it again using 0.02. And it's going to be 0.02 because the lm_head is, of
01:21:44.720 | course, not scaled. So it's not going to come here. It's just going to be basically initialized
01:21:49.680 | twice using the identical initialization, but that's okay. And then scrolling over here,
01:21:55.920 | I added some code here so that we have reproducibility to set the seeds. And now
01:22:03.680 | we should be able to run python train_gpt2.py and let it train. And as far as I know,
01:22:09.360 | this is the GPT2 initialization in the way we've implemented it right now. So this looks
01:22:17.360 | reasonable to me. Okay. So at this point, we have the GPT2 model. We have some confidence that it's
01:22:22.640 | correctly implemented. We've initialized it properly. And we have a data loader that's
01:22:26.400 | iterating through data batches, and we can train. So now comes the fun part. I'd like us to speed
01:22:31.440 | up the training by a lot. So we're getting our money's worth with respect to the hardware that
01:22:35.360 | we are using here. And we're going to speed up the training by quite a bit. Now, you always want to
01:22:42.400 | start with what hardware do you have? What does it offer? And are you fully utilizing it? So in my
01:22:47.520 | case, if we go to nvidia-smi, we can see that I have eight GPUs. And each one of those GPUs is an
01:22:57.760 | A100 SXM 80 gigabytes. So this is the GPU that I have available to me in this box. Now, when I want
01:23:07.680 | to spin up these kinds of boxes, by the way, my favorite place to go is Lambda Labs. They do
01:23:14.160 | sponsor my development and that of my projects. But this is my favorite place to go. And this is
01:23:20.320 | where you can spin up one of these machines, and you pay per hour, and it's very, very simple.
01:23:24.080 | So I like to spin them up and then connect VS Code to it, and that's how I develop.
01:23:28.320 | Now, when we look at the A100s that are available here, A100 80 gigabyte SXM is the
01:23:36.880 | GPU that I have here. And we have a bunch of numbers here for how many calculations you can
01:23:42.720 | expect out of this GPU. So I come over here and break in right after this point, in an interactive Python session. So
01:23:52.480 | I'm breaking in right after we calculate the logits and the loss. And the interesting thing I'd like
01:23:57.680 | you to note is when I do logits.dtype, this prints a torch.float32. So by default in PyTorch, when
01:24:06.960 | you create tensors, and this is the case for all the activations and for the parameters of the
01:24:11.280 | network and so on, by default, everything is in float32. That means that every single number,
01:24:17.680 | activation or weight and so on, is using a float representation that has 32 bits.
01:24:24.560 | And that's actually quite a bit of memory. And it turns out empirically that for deep learning
01:24:28.400 | as a computational workload, this is way too much. And deep learning and the training of
01:24:33.040 | these networks can tolerate significantly lower precisions. Not all computational workloads
01:24:38.960 | can tolerate small precision. So for example, if we go back to the data sheet, you'll see that
01:24:46.000 | actually these GPUs support up to FP64. And this is quite useful, I understand, for a lot of
01:24:52.080 | scientific computing applications. And there they really need this. But we don't need that much
01:24:56.720 | precision for deep learning training. So currently we are here, FP32. And with this code as it is
01:25:03.920 | right now, we expect to get at most 19.5 teraflops of performance. That means we're doing 19.5
01:25:11.680 | trillion floating point operations per second. So this is floating point multiply-add, most
01:25:19.520 | likely. And so these are the floating point operations. Now notice that if we are willing
01:25:27.360 | to go down in precision, so TF32 is a lower precision format we're going to see in a second,
01:25:33.360 | you can actually get an 8x improvement here. And if you're willing to go down to float16 or
01:25:38.160 | bfloat16, you can actually get 16x performance, all the way to 312 teraflops.
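To get a feel for what these formats throw away, you can emulate them by rounding a Python float through its bit pattern. This is a sketch with stated simplifications. float32 and float16 use the `struct` module's 'f' and 'e' formats. bfloat16 (the top 16 bits of float32) and TF32 (float32's 8 exponent bits with only 10 mantissa bits) are emulated here with round-to-nearest-even on the float32 bits, ignoring NaN handling. Real hardware rounding behavior may differ.

```python
import struct

def f32_bits(x):
    return struct.unpack('<I', struct.pack('<f', x))[0]

def bits_f32(b):
    return struct.unpack('<f', struct.pack('<I', b))[0]

def to_float32(x):
    return bits_f32(f32_bits(x))

def to_float16(x):
    return struct.unpack('<e', struct.pack('<e', x))[0]  # raises OverflowError past 65504

def _round_low_bits(x, n):
    """Round away the lowest n mantissa bits of float32 (nearest, ties to even)."""
    b = f32_bits(x)
    half = (1 << (n - 1)) - 1
    b = (b + half + ((b >> n) & 1)) & ~((1 << n) - 1)
    return bits_f32(b)

def to_tf32(x):      # 8 exponent bits, 10 mantissa bits
    return _round_low_bits(x, 13)

def to_bfloat16(x):  # 8 exponent bits, 7 mantissa bits
    return _round_low_bits(x, 16)

for name, fn in [("float32", to_float32), ("tf32", to_tf32),
                 ("bfloat16", to_bfloat16), ("float16", to_float16)]:
    print(f"{name:9s} {fn(0.1):.12f}")

try:
    to_float16(1e5)
except OverflowError:
    print("1e5 does not fit in float16; bfloat16 gives", to_bfloat16(1e5))
```

The last lines illustrate the trade-off between the two 16-bit formats. float16 keeps more mantissa bits but has a small exponent range. bfloat16 keeps float32's range at the cost of precision.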
01:25:45.520 | You see here that NVIDIA likes to cite numbers that have an asterisk here.
01:25:49.120 | This asterisk says with sparsity. But we are not going to be using sparsity in our code. And I
01:25:55.120 | don't know that this is very widely used in the industry right now. So most people look at this
01:25:59.840 | number here without sparsity. And you'll notice that we could have got even more here. But this
01:26:05.600 | is int8. And int8 is used for inference, not for training, because int8
01:26:12.880 | basically has uniform spacing. And we actually require a float so that we get a better match
01:26:24.320 | to the normal distributions that occur during training of neural networks, where both
01:26:30.160 | activations and weights are distributed as a normal distribution. And so floating points are
01:26:35.280 | really important to match that representation. So we're not typically using int8 for training,
01:26:42.160 | but we are using it for inference. And if we bring down the precision, we can get a lot more
01:26:47.680 | teraflops out of the tensor cores available in the GPUs. We'll talk about that in a second.
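The spacing argument can be made concrete. This sketch assumes a simple symmetric int8 scheme with one fixed scale covering [-1, 1]. That is an assumption for illustration; real int8 inference schemes typically choose scales per tensor or per channel. The sketch compares the int8 error against float16, whose grid gets finer near zero, where most normally distributed weights and activations sit.

```python
import struct

def int8_roundtrip(x, scale=1 / 127):
    """Symmetric int8: values snap to an evenly spaced grid k * scale, k in [-127, 127]."""
    k = max(-127, min(127, round(x / scale)))
    return k * scale

def float16_roundtrip(x):
    return struct.unpack('<e', struct.pack('<e', x))[0]

def rel_err(approx, x):
    return abs(approx - x) / abs(x)

for x in (0.9, 0.01):
    print(f"x={x}: int8 rel err {rel_err(int8_roundtrip(x), x):.4f}, "
          f"float16 rel err {rel_err(float16_roundtrip(x), x):.6f}")
```

With the same uniform grid everywhere, small values like 0.01 lose about a fifth of their magnitude in int8. float16 keeps roughly the same relative precision across magnitudes.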
01:26:52.400 | But in addition to that, if all of these numbers have fewer bits of representation, it's going to
01:26:58.080 | be much easier to move them around. And that's where we start to get into the memory bandwidth
01:27:02.640 | and the memory of the model. So not only do we have a finite capacity of the number of bits that
01:27:07.760 | our GPU can store, but in addition to that, there's a speed with which you can access this memory.
01:27:13.440 | And you have a certain memory bandwidth. It's a very precious resource. And in fact, many of the
01:27:20.080 | deep learning workloads for training are memory bound. And what that means is actually that the
01:27:25.520 | tensor cores that do all these extremely fast multiplications, most of the time they're waiting
01:27:30.400 | around, they're idle, because we can't feed them with data fast enough. We can't load the data
01:27:36.880 | fast enough from memory. As for typical utilization of your hardware, if you're getting 60%,
01:27:42.800 | you're actually doing extremely well. So even in a well-tuned application, a big chunk of the time,
01:27:48.400 | your tensor cores are not doing multiplies because the data is not available.
01:27:51.840 | So the memory bandwidth here is extremely important as well. And if we come down in
01:27:56.320 | the precision for all the floats, all the numbers, weights, and activations suddenly require less
01:28:01.920 | memory. So we can store more and we can access it faster. So everything speeds up and it's amazing.
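A rough roofline-style sketch of "memory bound" that uses the A100 figures quoted in this section: 312 TFLOPS for bf16 and roughly 2 TB/s of memory bandwidth. It ignores caches and real kernel details. It assumes each operand is read once and each output is written once. The matmul shape is GPT-2's MLP up-projection at B=16, T=1024.

```python
PEAK_BF16_FLOPS = 312e12  # dense bf16 tensor core peak quoted for the A100
HBM_BANDWIDTH = 2e12      # roughly 2 TB/s of memory bandwidth (approximate)

def matmul_intensity(m, k, n, bytes_per_elem):
    """FLOPs per byte for (m x k) @ (k x n): read both inputs once, write the output once."""
    flops = 2 * m * k * n
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved

def elementwise_intensity(bytes_per_elem, flops_per_elem=1):
    """FLOPs per byte for an op that reads one element and writes one element."""
    return flops_per_elem / (2 * bytes_per_elem)

# In the roofline model, an op below this intensity is memory bound.
ridge = PEAK_BF16_FLOPS / HBM_BANDWIDTH

# GPT-2 MLP up-projection for B=16, T=1024: (16384 x 768) @ (768 x 3072)
mlp_bf16 = matmul_intensity(16 * 1024, 768, 3072, bytes_per_elem=2)
mlp_fp32 = matmul_intensity(16 * 1024, 768, 3072, bytes_per_elem=4)
add_bf16 = elementwise_intensity(bytes_per_elem=2)

print(f"ridge point:    {ridge:.0f} FLOPs/byte")
print(f"MLP matmul:     {mlp_bf16:.0f} FLOPs/byte in bf16, {mlp_fp32:.0f} in fp32")
print(f"elementwise op: {add_bf16:.2f} FLOPs/byte in bf16")
```

Halving the bytes per element doubles the intensity of every op. The elementwise ops still stay orders of magnitude below the ridge. That is why they remain memory bound however fast the multiply units are.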
01:28:08.320 | And now let's reap the benefits of it. And let's first look at the TensorFloat-32 (TF32) format.
01:28:14.080 | Okay, so first of all, what are tensor cores? Well, a tensor core is basically just an instruction
01:28:20.800 | in the A100 architecture, right? So what it does is basically a little 4x4 matrix multiply.
01:28:28.000 | So this is just matrix multiplication here of 4x4 matrices. And there are multiple configurations
01:28:37.200 | as to what precision any of these matrices are, in what precision the internal accumulate happens,
01:28:43.520 | and then what is the output precision, input precision, etc. So there's a few switches,
01:28:48.560 | but it's basically a 4x4 multiply. And then any time we have any operations that require matrix
01:28:54.000 | multiplication, they get broken up into this instruction of a little 4x4 multiply. And so
01:29:01.040 | everything gets broken up into this instruction because it's the fastest way to multiply matrices.
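Here is a conceptual sketch of "breaking a big matmul into little 4x4 multiplies". The output is computed by accumulating small tile-by-tile products. This is plain Python for illustration only. Real tensor cores work on hardware-specific fragment shapes that vary by precision and GPU generation. The 4x4 tile size here just mirrors the picture described above.

```python
TILE = 4  # tile edge, mirroring the 4x4 tensor core picture described above

def matmul_naive(a, b):
    """Reference matmul: out[i][j] = sum_p a[i][p] * b[p][j]."""
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)] for i in range(n)]

def matmul_tiled(a, b, tile=TILE):
    """Compute a @ b as a sequence of small tile x tile block multiply-accumulates."""
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                # One small block op: out[i0:i0+tile, j0:j0+tile] += a_block @ b_block
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        acc = out[i][j]
                        for p in range(p0, min(p0 + tile, k)):
                            acc += a[i][p] * b[p][j]
                        out[i][j] = acc
    return out

a = [[(3 * i + j) % 7 for j in range(6)] for i in range(10)]  # 10 x 6
b = [[(i + 2 * j) % 5 for j in range(9)] for i in range(6)]   # 6 x 9
print(matmul_tiled(a, b) == matmul_naive(a, b))  # True
```

The `min(...)` bounds handle edge tiles when a dimension is not a multiple of 4. In the real kernels, awkward sizes like this are exactly what wastes hardware work.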
01:29:05.120 | And it turns out that most of the computational work that we're doing up above,
01:29:08.960 | all of it really is matrix multiplication. Most of the work computationally happens in
01:29:14.080 | the linear layers, linear, linear, etc. There's a few things sandwiched in between. So there's
01:29:22.400 | some additions in residuals, there's some GELU nonlinearities, there's some layer norms, etc.
01:29:27.760 | But if you just time them, you'll see that these are nothing. Like basically,
01:29:31.760 | the entire transformer is just a bunch of matrix multiplications, really.
01:29:35.680 | And especially at this small scale, 124 million parameter model, actually the biggest matrix
01:29:42.720 | multiplication by far is the classifier layer at the top. That is a massive matrix multiply
01:29:47.840 | of going from 768 to 50,257. And that matrix multiply dominates anything else that happens
01:29:54.880 | in that network, roughly speaking. So it's matrix multiplies that become a lot faster,
01:30:00.720 | which are hidden inside our linear layers. And they're accelerated through tensor cores.
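A back-of-the-envelope FLOP count supports the claim about the classifier layer. The sketch counts forward-pass FLOPs per token for GPT-2 (124M)'s linear layers. The layer names follow the `c_attn`/`c_proj`/`c_fc` naming of the GPT-2 checkpoint. It ignores the attention score matmuls (which grow with T), layer norms and elementwise ops.

```python
N_EMBD, N_LAYER, VOCAB = 768, 12, 50257  # GPT-2 (124M) config

def linear_flops(fan_in, fan_out):
    """Forward FLOPs per token for a linear layer: one multiply and one add per weight."""
    return 2 * fan_in * fan_out

block_matmuls = {
    "attn c_attn (768 -> 2304)": linear_flops(N_EMBD, 3 * N_EMBD),
    "attn c_proj (768 -> 768)": linear_flops(N_EMBD, N_EMBD),
    "mlp c_fc (768 -> 3072)": linear_flops(N_EMBD, 4 * N_EMBD),
    "mlp c_proj (3072 -> 768)": linear_flops(4 * N_EMBD, N_EMBD),
}
lm_head = linear_flops(N_EMBD, VOCAB)
all_blocks = N_LAYER * sum(block_matmuls.values())

print(f"lm_head:                     {lm_head / 1e6:.1f} MFLOPs/token")
print(f"largest matmul in a block:   {max(block_matmuls.values()) / 1e6:.1f} MFLOPs/token")
print(f"all 12 blocks' linear layers: {all_blocks / 1e6:.1f} MFLOPs/token")
```

By this count, the `lm_head` matmul is about 16x larger than any single matmul inside a block. The linear layers of all 12 blocks together still add up to more (about 170M vs 77M FLOPs per token).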
01:30:05.680 | Now, the best reference for tensor cores, I would say, is basically just the
01:30:11.680 | A100 architecture whitepaper. It's pretty detailed, but I think
01:30:17.760 | it's relatively readable, mostly, if you half understand what's happening.
01:30:21.280 | So, Figure 9, TensorFloat-32. This is basically the explanation for TF32 and what happens
01:30:29.760 | here. And you see that there's many configuration options here available. So the input operands,
01:30:35.840 | and what precisions are they in, the accumulator, and what basically the
01:30:41.680 | internal representation within the instruction when you do the accumulate of this matrix
01:30:47.840 | multiplication. So the intermediate plus-equals of the little vector multiplies here,
01:30:54.880 | all happens in FP32. And then this is an 8x improvement, as I mentioned, in the
01:31:01.520 | FLOPS that we get. So for TF32 specifically, we're looking at this row here. And the way this works is
01:31:06.960 | normally FP32 has 32 bits. TF32 has the exact same layout: we have one sign bit,
01:31:19.040 | we have eight exponent bits, except the mantissa bits get cropped in the float.
01:31:25.120 | And so basically, we end up with just 19 bits, instead of 32 bits, because the last 13 bits
01:31:32.640 | get truncated, they get dropped. And all this is internal to the instruction. So none of it is
01:31:39.120 | visible to anything in our PyTorch. None of our PyTorch code will change, all the numbers will
01:31:44.400 | look identical. It's just that when you call the tensor core instruction, internally in the hardware,
01:31:51.280 | it will crop out these 13 bits. And that allows it to calculate this little matrix multiply
01:31:58.320 | significantly faster, 8x faster. Now, of course, this speed up comes at a cost. And the cost is
01:32:04.800 | that we are reducing the precision. Our accumulate is still in FP32, our output is FP32, our inputs
01:32:10.880 | are FP32. But internally, things get truncated in the operands to perform the operation faster.
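The cropping described here can be emulated in a few lines of stdlib Python. The operands keep 1 sign bit, 8 exponent bits and the top 10 mantissa bits (19 bits in total), and the low 13 mantissa bits are zeroed. Plain truncation is an assumption here, since the hardware's exact rounding behaviour may differ. A Python float stands in for the wider FP32 accumulator.

```python
import struct

def to_tf32(x):
    """Emulate TF32 operand cropping: zero the low 13 of float32's 23 mantissa bits.

    Plain truncation is assumed; real hardware rounding may differ.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 bits, keeping sign, exponent and 10 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dot_tf32(a, b):
    """Dot product with TF32-cropped operands and a wider accumulator (a Python float here)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += to_tf32(x) * to_tf32(y)
    return acc

print(to_tf32(1 + 2**-10))  # 1.0009765625 (the 10th mantissa bit survives)
print(to_tf32(1 + 2**-11))  # 1.0 (the 11th mantissa bit is cropped away)
print(dot_tf32([0.1, 0.2, 0.3], [1.0, 1.0, 1.0]))  # close to, but not exactly, 0.6
```

The inputs and outputs are still ordinary floats. Only the precision of the operands inside the multiply changes, which matches the point that none of this is visible to the surrounding code.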
01:32:17.440 | And so our results are starting to be a bit more approximate. But empirically, when you actually
01:32:21.440 | train with this, you basically can't tell the difference. So the reason I like TF32 is because
01:32:26.080 | if you can tolerate a little bit of a precision fudge, then this is free, like none of your code
01:32:33.120 | sees this, it's fully internal to the operation, and the operation just goes 8x faster for you,
01:32:39.120 | and it's a bit more approximate. And so it's a pretty sweet spot, I would say in optimization.
01:32:44.480 | And let's see what that looks like first. So I've set up our code to just time the iterations.
01:32:50.720 | So I import time, and I changed the hyperparameters so that we have something that better reflects
01:32:56.240 | the kind of workload that we want to run, because we want to do a fairly large run at the end of this.
01:33:01.040 | So let's use batch size 16. And let's now use the actual GPT-2 maximum sequence length of 1024
01:33:08.000 | tokens. So this is the configuration. And then for 50 iterations, I'm just doing something very lazy
01:33:16.640 | here. I'm doing time.time to get the current time. And then this is the optimization loop.
01:33:22.240 | And now I want to time how long this takes. Now, one issue with working with GPUs is that
01:33:31.040 | when your CPU runs, it's just scheduling work on the GPU. It's ordering some work,
01:33:38.640 | right? And so it sends a request, and then it continues running. And so it can happen sometimes
01:33:44.400 | that we sort of speed through this, and we queue up a lot of kernels to run on the GPU,
01:33:51.360 | and then the CPU sort of gets here and takes time.time. But actually, the GPU is still running,
01:33:56.160 | because it takes time to actually get through the work that was scheduled to run.
01:34:00.560 | And so you're just building up a queue for the GPU. And so, if you need accurate timings,
01:34:06.080 | you want to wait with torch.cuda.synchronize(). This will wait for the GPU to finish all the work
01:34:11.680 | that was scheduled to run up above here. And then we can actually take the time. So basically,
01:34:17.840 | we're waiting for the GPU to stop this iteration, take the time, and then we're going to just print
01:34:23.040 | it. So here, I'm going to run the training loop. And here on the right, I'm watching nvidia-smi.
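A minimal PyTorch sketch of the timing pattern just described. The model is a hypothetical stand-in (a single `nn.Linear`), not the GPT-2 module from the video. The `torch.cuda.synchronize()` call is guarded so the sketch also runs on a CPU-only machine, where there is no asynchronous GPU queue to drain. As noted below, the first iteration is often slower, so discard it when averaging.

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
B, T, C = 16, 1024, 64  # B and T as in this section; C is kept small for the stand-in

# Hypothetical stand-in workload; in the video this is the GPT-2 training step.
model = torch.nn.Linear(C, C).to(device)
x = torch.randn(B, T, C, device=device)

for step in range(5):
    t0 = time.time()
    loss = model(x).pow(2).mean()
    loss.backward()
    if device == "cuda":
        torch.cuda.synchronize()  # block until every queued GPU kernel has finished
    t1 = time.time()
    dt = t1 - t0
    tokens_per_sec = B * T / dt
    print(f"step {step}: {dt * 1000:.2f} ms, {tokens_per_sec:.0f} tokens/s")
```

Without the synchronize, `t1` can be read while kernels are still queued, and the loop would report misleadingly short iteration times.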
01:34:30.400 | So we start off at 0. We're not using the GPU. And then by default, PyTorch will use GPU 0,
01:34:36.640 | so we see that it gets filled up. And we're using 35 gigabytes out of 80 gigabytes available.
01:34:42.160 | And then here on the left, we see that because we've cranked up the batch size,
01:34:49.840 | now it's only 20 batches to do a single epoch on our tiny Shakespeare.
01:34:53.120 | And we see that we're seeing roughly 1,000 milliseconds per iteration here, right?
01:34:58.160 | So the first iteration sometimes is slower. And that's because PyTorch might be doing a lot of
01:35:05.200 | initializations here on the very first iteration. And so it's probably initializing all these
01:35:09.760 | tensors and buffers to hold all the gradients. And I'm not 100% sure all the work that happens here,
01:35:14.960 | but this could be a slower iteration. When you're timing your logic, you always want
01:35:18.960 | to be careful with that. But basically, we're seeing 1,000 milliseconds per iteration.
01:35:23.360 | And so this will run for roughly 50 seconds as we have it right now.
01:35:28.000 | So that's our baseline in float32. One more thing I wanted to mention is that
01:35:33.760 | if this doesn't fit into your GPU and you're getting out of memory errors,
01:35:38.080 | then start decreasing your batch size until things fit. So instead of 16, try 8 or 4
01:35:43.120 | or whatever you need to fit the batch into your GPU. And if you have a bigger GPU, you can actually
01:35:49.040 | potentially get away with 32 and so on. By default, you want to basically max out the batch size that
01:35:55.600 | fits on your GPU. And you want to use nice numbers. So use numbers that have lots of powers
01:36:01.680 | of 2 in them. So 16 is a good number. 8, 24, 32, 48, these are nice numbers. But don't use something
01:36:09.600 | like 17, because that will run very inefficiently on the GPU. And we're going to see that a bit
01:36:14.560 | later as well. So for now, let's just stick with 16, 1,024. And the one thing that I added also
01:36:22.320 | here, and I ran it again, is I'm calculating tokens per second throughput during training.
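Tokens per second is the number of tokens in one batch divided by the iteration time. The sketch plugs in B=16, T=1024 and the per-iteration times reported in this section (about 1000, 333, 300 and 129 ms). This gives roughly 16k, 49k, 55k and 127k tokens/s.

```python
def tokens_per_sec(batch_size, seq_len, dt_seconds):
    """Training throughput: tokens processed in one optimizer step divided by step time."""
    return batch_size * seq_len / dt_seconds

measured_ms = {"fp32": 1000, "tf32": 333, "bf16 autocast": 300, "torch.compile": 129}
for name, ms in measured_ms.items():
    print(f"{name:>14}: {tokens_per_sec(16, 1024, ms / 1000):>9,.0f} tokens/s")
```

Because the metric is normalized per token, it stays comparable when the batch size or sequence length changes between runs.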
01:36:28.960 | Because we might end up changing the batch size around over time. But tokens per second is the
01:36:34.720 | objective measure that we actually really care about. How many tokens of data are we training
01:36:38.960 | on? And what is the throughput of tokens that we're getting in our optimization? So right now,
01:36:43.280 | we're processing and training on roughly 16,300 tokens per second. And that's a
01:36:49.360 | more objective metric. Okay, so let's now enable TF32. Now, luckily, PyTorch makes this fairly easy for
01:36:56.240 | us. And to enable TF32, you just need to do a single line. And it's this. And when we go to the
01:37:03.600 | PyTorch documentation here for this function, basically, this tells PyTorch what kind of
01:37:08.000 | kernels to run. And by default, I believe it is highest. Highest precision for matmul. And that
01:37:15.520 | means that everything happens in float32, just like it did before. But if we set it to high,
01:37:20.480 | as we do right now, matrix multiplications will now use TensorFloat-32 when it's available.
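As far as I know, the one-liner described here is `torch.set_float32_matmul_precision`. It accepts `"highest"` (the default), `"high"` and `"medium"`. The sketch below is minimal. On a CPU, or on a GPU without TF32 support, the matmuls simply keep running in full float32.

```python
import torch

print(torch.get_float32_matmul_precision())  # "highest" unless something changed it earlier
torch.set_float32_matmul_precision("high")   # allow TF32 tensor cores for float32 matmuls where available

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
c = a @ b  # still a float32 tensor; only the internal multiply may use TF32
print(c.dtype)
```

The tensors you see stay float32, consistent with the point above that TF32 is internal to the instruction.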
01:37:26.480 | My GPU is the A100, so it's the Ampere series, and therefore TF32 is available. If you have an older
01:37:34.640 | GPU, this might not be available for you. But for my GPU, it's available. And so what I expect
01:37:40.080 | PyTorch to do is that every single place where we see an nn.Linear, inside there, there's a
01:37:45.040 | matrix multiplication. And I expect that matrix multiplication now to be running on tensor cores,
01:37:51.200 | utilizing the TF32 precision. So this is the single line of change that is, I believe, necessary. And
01:37:59.200 | let's rerun this. Now, we saw that in terms of the throughput that is promised to us, we're supposed
01:38:05.680 | to be getting 8x, roughly. So let's see what happens. And that 8x came from here, right? 8x.
01:38:16.640 | And it also came from looking at it here, 156 TFLOPS instead of 19.5. Okay, so what actually
01:38:26.240 | happened? So we're seeing that our throughput went up roughly 3x, not 8x. So we are going from 1000
01:38:35.600 | milliseconds, we're going down to 300 milliseconds, and our throughput is now about 50,000 tokens per
01:38:40.560 | second. So we have a roughly 3x instead of 8x. So what happened? And basically, what's happening
01:38:46.240 | here is, again, a lot of these workloads are memory bound. And so even though the TF32 offers,
01:38:54.000 | in principle, a lot faster throughput, all of these numbers everywhere are still float32s.
01:39:01.280 | And it's float32 numbers that are being shipped all over the place through the memory system.
01:39:05.840 | And it's just costing us way too much time to shuttle around all this data. And so even though
01:39:10.160 | we've made the multiply itself much faster, we are memory bound, and we're not actually seeing
01:39:14.480 | the full benefit that would come from this napkin math here. That said, we are getting
01:39:21.200 | 3x faster throughput. And this is free. Single line of code in PyTorch. All your variables are
01:39:28.640 | still float32 everywhere. It just runs faster. And it's slightly more approximate, but we're
01:39:33.760 | not going to notice it, basically. So that's TF32. Okay, so let's now continue. So we've exercised
01:39:41.920 | this row. And we saw that we can crop out some of the precision inside the operation itself.
01:39:48.560 | But we saw that we're still memory bound. We're still moving around all these float32s
01:39:51.840 | everywhere else, right? And we're paying that cost. So let's now decrease the
01:39:56.400 | amount of stuff that we're going to be moving around. And we're going to do that by dropping
01:40:00.960 | down to bfloat16. So we're only going to be maintaining 16 bits per float. And we're going
01:40:07.440 | to use bfloat16, and I'll explain the difference from FP16 in a bit. And we're going to be in this
01:40:12.880 | row. So when we go back to the documentation here for the A100, we see here the precisions
01:40:23.120 | that are available. And this is the original FP32. The TF32 crops out the precision. And then here
01:40:29.760 | in bfloat16, you see that it is very similar to TF32. But it's even more aggressive in cropping
01:40:36.640 | off the precision, the mantissa, of this float. So the important thing with bfloat16 is that the
01:40:42.480 | exponent bits and the sign bit, of course, remain unchanged. So if you're familiar with your float
01:40:48.000 | numbers, and I think this should probably be an entire video by itself, the exponent sets the
01:40:55.200 | range of numbers that you can represent. And the mantissa sets how much precision you have
01:41:00.880 | for your numbers. And so the range of numbers is identical. But we have fewer
01:41:07.520 | possibilities within that range, because we are truncating the mantissa. So we have less precision
01:41:13.680 | in that range. What that means is that things are actually fairly nice, because we have the
01:41:19.600 | original range of numbers that are representable in float. But we just have less precision for it.
01:41:25.920 | And the difference with FP16 is that they actually touch and change the range. So FP16 cannot
01:41:32.640 | represent the full range of FP32. It has a reduced range. And that's where you start to
01:41:38.240 | actually run into issues, because now you need these gradient scalers and things like that.
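The range difference can be shown with stdlib Python. In this sketch, bfloat16 is emulated by keeping the top 16 bits of a float32, which is truncation rather than the round-to-nearest a real conversion would use. IEEE FP16 comes from `struct`'s `'e'` (binary16) format. Large values overflow FP16. Tiny values, like small gradients, underflow to zero. That is what gradient scaling exists to work around. bfloat16 keeps both, only more coarsely.

```python
import struct

def to_bf16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 pattern (truncation, not rounding)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0] & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp16(x):
    """Round to IEEE half precision via struct's 'e' format; raises OverflowError past ~65504."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

for value in (1e5, 1e-8):
    try:
        fp16 = to_fp16(value)
    except OverflowError:
        fp16 = "overflow"
    print(f"{value:g}: bf16 -> {to_bf16(value):g}, fp16 -> {fp16}")
```

bfloat16 represents 1e5 as 99840, which is coarse but in range. It also keeps 1e-8 as a nonzero value. FP16 overflows on the first value and flushes the second to 0.0.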
01:41:43.120 | And I'm not going to go into the detail of that in this video, because that's a whole video by
01:41:48.240 | itself. But FP16 actually historically came first. That was available in the Volta series before
01:41:54.400 | Ampere. And so FP16 came first, and everyone started to train in FP16. But everyone had to
01:41:59.680 | use all these gradient scaling operations, which are kind of annoying. And it's an additional source
01:42:04.240 | of state and complexity. And the reason for that was because the exponent range was reduced in FP16.
01:42:10.160 | So that's the IEEE FP16 spec. And then they came out with BF16 in Ampere. And they made it
01:42:17.680 | much simpler, because we're just truncating mantissa, we have the exact same range, and we do
01:42:21.680 | not need gradient scalers. So everything is much, much simpler. Now, when we do use BF16, though,
01:42:27.680 | we are impacting the numbers that we might be seeing in our PyTorch code.
01:42:32.080 | This change is not just local to the operation itself. So let's see how that works.
01:42:38.160 | There's some documentation here, and I think this is probably the best page to explain how to
01:42:45.600 | use mixed precision in PyTorch. Because there are many other tutorials and so on, even within
01:42:51.680 | PyTorch documentation, that are a lot more confusing. And so I recommend specifically
01:42:55.920 | this one. Because there's five other copies that I would not recommend. And then when we come here,
01:43:01.680 | ignore everything about everything. Ignore everything about gradient scalers.
01:43:06.560 | And only look at torch.autocast. And basically, also, this comes to a single line of code at
01:43:14.640 | the end. So this is the context manager that we want. And we want to use that in our network.
01:43:22.400 | When you click into torch.autocast, autocasting, it has a bit more
01:43:29.440 | guidance for you. So it's telling you, do not call .bfloat16() on any of your tensors.
01:43:35.200 | Just use autocast. And only surround the forward pass of the model and the loss calculation.
01:43:41.600 | Those are the only two things that you should be surrounding. Leave the backward and the
01:43:45.280 | optimizer step alone. So that's the guidance that comes from the PyTorch team. So we're going to
01:43:50.400 | follow that guidance. And for us, because the loss calculation is inside of the model forward pass for
01:43:55.600 | us, we are going to be doing this. And then we don't want to be using torch.float16. Because if
01:44:01.840 | we do that, we need to start using gradient scalers as well. So we are going to be using BFloat16.
01:44:07.120 | This is only possible on Ampere or newer GPUs. But this means that the changes are extremely minimal.
01:44:12.720 | Well, it's basically just this one line of code. Let me first break in here, before we actually
01:44:21.440 | run this, right after the logits. I'd like to show you that, different from the TF32 that we saw,
01:44:28.720 | this is actually going to impact our tensors. So this logits tensor, if we now look at this,
01:44:36.720 | and we look at the dtype, we suddenly see that this is now BFloat16. It's not float32 anymore.
01:44:43.360 | So our activations have been changed. The activations tensor is now BFloat16. But not
01:44:48.640 | everything has changed. So model.transformer.wte. This is the weight token embedding table. It has
01:44:59.120 | a .weight inside it. And the dtype of this weight, this parameter, is still torch.float32.
01:45:06.160 | So our parameters seem to still be in float32, but our activations, the logits, are now in BFloat16.
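A minimal `torch.autocast` sketch that reproduces this observation: the activations come out of the matmul in bfloat16, while the parameters stay float32. The "model" is a hypothetical stand-in, made of an embedding table plus a linear head. Only the forward pass and the loss are inside the context, following the guidance above. The explicit `.float()` before the loss is an extra precaution chosen for this sketch. The sketch falls back to CPU autocast when no GPU is present. On a GPU, it needs one that supports bfloat16 (Ampere or newer).

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the GPT model: an embedding table followed by a linear head.
vocab_size, n_embd, B, T = 100, 32, 4, 8
wte = torch.nn.Embedding(vocab_size, n_embd).to(device)
lm_head = torch.nn.Linear(n_embd, vocab_size, bias=False).to(device)
x = torch.randint(0, vocab_size, (B, T), device=device)
y = torch.randint(0, vocab_size, (B, T), device=device)

# Autocast only the forward pass and the loss calculation.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits = lm_head(wte(x))  # the linear layer's matmul runs in bf16
    # explicit upcast to fp32 before the loss (a choice made for this sketch)
    loss = F.cross_entropy(logits.float().view(-1, vocab_size), y.view(-1))

print(logits.dtype)      # torch.bfloat16: activations are lower precision
print(wte.weight.dtype)  # torch.float32: parameters are untouched
loss.backward()          # backward runs outside the autocast context
```

Gradients land on the float32 parameters, so the optimizer step also stays outside autocast and unchanged.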
01:45:11.760 | So clearly, this is why we get the mixed precision. Some things PyTorch is keeping
01:45:17.040 | in float32. Some things PyTorch is converting to lower precision. And what gets converted,
01:45:24.640 | at what point, is not super clear. I remember scrolling down. Is it here?
01:45:31.920 | Okay, I can't find it. I thought it was here. Okay, there we go. So there are a few docs on
01:45:44.160 | when you're using this autocast, what gets converted to BFloat16 and when. So for example,
01:45:49.760 | only these matrix multiply-like operations get converted to BFloat16. But a lot of operations
01:45:55.360 | remain in float32. So in particular, a lot of normalizations, like layer norms and things like
01:45:59.920 | that, not all of those layers might be converted. So only some layers selectively would be running
01:46:06.400 | BFloat16. But things like softmax, layer norms, log softmax, so loss function calculations,
01:46:14.800 | a lot of those things might remain in float32 because they are more susceptible to precision
01:46:19.040 | changes. Matrix multiplies are fairly robust to precision changes. So some parts of the network
01:46:27.520 | are impacted more or less by the precision change. So basically only some parts of the model are
01:46:35.520 | running in reduced precision. Let's take it for a spin and let's actually see what kind of
01:46:41.680 | improvement we achieve here. Okay, so we used to be 333 milliseconds. We're now at 300.
01:46:52.880 | And we used to be somewhere around 50,000 tokens per second. We're now at 55,000.
01:46:56.960 | So we're definitely running faster, but maybe not a lot faster. And that's because there are
01:47:02.960 | still many, many bottlenecks in our GPT-2. We're just getting started. But we have dropped down
01:47:07.840 | the precision as far as we can with my current GPU, which is A100. We're using PyTorch Autocast.
01:47:14.320 | Unfortunately, I don't actually exactly know what PyTorch Autocast does. I don't actually know
01:47:20.000 | exactly what's in BFloat16, what's in float32. We could go in and we could start to scrutinize it.
01:47:25.120 | But these are the kinds of rules that PyTorch has internally. And unfortunately, they don't
01:47:30.400 | document it very well. So we're not going to go into that in too much detail. But for now,
01:47:36.560 | we are training in BFloat16. We do not need a gradient scaler. And the reason things are
01:47:41.200 | running faster is because we are able to run tensor cores in BFloat16 now. That means we are
01:47:49.360 | in this row. But we are also paying in precision for this. So we expect slightly less accurate
01:47:57.840 | results with respect to the original FP32. But empirically, in many cases, this is a worthwhile
01:48:04.080 | trade-off because it allows you to run faster. And you could, for example, train longer and make
01:48:08.800 | up for that precision decrease. So that's BFloat16 for now. Okay. So as we can see,
01:48:17.120 | we are currently at about 300 milliseconds per iteration. And we're now going to reach for some
01:48:21.840 | really heavy weapons in the PyTorch arsenal. And in particular, we're going to introduce
01:48:25.920 | torch.compile. torch.compile is really quite incredible infrastructure from the PyTorch team.
01:48:31.840 | And it's basically a compiler for neural networks. It's almost like GCC for C and C++ code. This is
01:48:38.240 | just the GCC of neural nets. It came out a while ago, and it's extremely simple to use. The way
01:48:46.800 | to use torch.compile is to do this. It's a single line of code to compile your model and return it.
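A hedged sketch of the single line, `model = torch.compile(model)` (PyTorch 2.0 or newer), applied to a hypothetical stand-in model. The compile step is guarded to run only when a GPU is present. That keeps the sketch quick and runnable on CPU-only machines without a compiler toolchain, although `torch.compile` does also support CPU. The first call after compiling is slow because that is when compilation happens.

```python
import torch

# Hypothetical stand-in model; in the video this would be the GPT module.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.GELU(approximate="tanh"),
    torch.nn.Linear(256, 64),
)

if torch.cuda.is_available():
    model = model.to("cuda")
    model = torch.compile(model)  # returns an optimized module; compilation happens on first call

device = next(model.parameters()).device
out = model(torch.randn(8, 64, device=device))
print(out.shape)
```

The compiled module is used exactly like the original, so the rest of the training loop does not change.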
01:48:53.200 | Now, this line of code will cost you compilation time. But as you might guess, it's going to make
01:48:58.000 | the code a lot faster. So let's actually run that. Because this will take some time to run.
01:49:02.960 | But currently, remember, we're at 300 milliseconds. And we'll see what happens.
01:49:06.240 | Now, while this is running, I'd like to explain a little bit of what torch.compile does under
01:49:11.600 | the hood. So feel free to read this page of the PyTorch docs. But basically, there's no real good
01:49:17.680 | reason for you to not use torch.compile in your PyTorch. I kind of feel like you should be using
01:49:22.640 | it almost by default, unless you're debugging and you don't want to wait for compilation.
01:49:27.920 | And there's one line here in the torch.compile docs that I found that actually kind of gets to
01:49:31.760 | why this is faster. Speed up mainly comes from reducing Python overhead and GPU read/writes.
01:49:38.240 | So let me unpack that a little bit. Okay. Here we are. Okay. So we went from 300 milliseconds.
01:49:44.400 | We're now running at 129 milliseconds. So this is 300 divided by 129, about 2.3x
01:49:52.080 | improvement from a single line of code in PyTorch. So quite incredible. So what is happening? What's
01:49:57.520 | happening under the hood? Well, when you pass the model to torch.compile, what we have here in this
01:50:03.840 | nn.Module, this is really just the algorithmic description of what we'd like to happen in our
01:50:09.360 | network. And torch.compile will analyze the entire thing. And it will look at what operations you'd
01:50:15.840 | like to use. And with the benefit of knowing exactly what's going to happen, it doesn't
01:50:21.040 | have to run in what's called the eager mode. It doesn't have to just kind of like go layer by
01:50:25.680 | layer. Like the Python interpreter normally would start at the forward. And the Python interpreter
01:50:33.280 | will go, okay, let's do this operation. And then let's do that operation. And it kind of materializes
01:50:38.640 | all the operations as it goes through. So these calculations are dispatched and run in this order.
01:50:45.360 | And the Python interpreter and this code doesn't know what kind of operations are going to happen
01:50:49.680 | later. But torch.compile sees your entire code at the same time. And it's able to know what
01:50:55.200 | operations you intend to run. And it will kind of optimize that process. The first thing it will do
01:51:00.640 | is it will take out the Python interpreter from the forward pass entirely. And it will kind of
01:51:05.120 | compile this entire neural net as a single object with no Python interpreter involved. So it knows
01:51:10.800 | exactly what's going to run. It will just run that. And it's all going to be running in efficient code.
01:51:16.640 | The second thing that happens is this read/write that they mentioned very briefly. So a good
01:51:23.120 | example of that, I think, is the Gelu nonlinearity that we've been looking at.
01:51:26.240 | So here we use the nn.GELU. Now, this here is me basically just breaking up the nn.GELU,
01:51:34.000 | which you remember has this formula. So this here is the equivalent implementation to what's
01:51:39.840 | happening inside Gelu. Algorithmically, it's identical. Now, by default, if we just were
01:51:46.240 | using this instead of nn.GELU here, what would happen without torch.compile? Well, the Python
01:51:51.840 | interpreter would make its way here. And then it would be, okay, well, there's an input. Well,
01:51:56.160 | let me first raise this input to the third power. And it's going to dispatch a kernel that
01:52:01.840 | takes your input and raises it to the third power. And that kernel will run. And when this kernel
01:52:08.400 | runs, what ends up happening is this input is stored in the memory of the GPU. So here's a
01:52:14.080 | helpful example of the layout of what's happening, right? You have your CPU. This is in every single
01:52:20.240 | computer. There's a few cores in there. And you have your RAM, your memory. And the CPU can talk
01:52:27.040 | to the memory. And this is all well known. But now we've added the GPU. And the GPU is a slightly
01:52:31.840 | different architecture, of course. They can communicate. And it's different in that it's got
01:52:35.920 | a lot more cores than a CPU. All of those cores are individually a lot simpler, too.
01:52:41.200 | But it also has memory, right? This high bandwidth memory. Sorry if I'm botching it.
01:52:48.480 | HBM. I don't even know what that stands for. I'm just realizing now. But this is the memory. And
01:52:54.480 | it's very equivalent to RAM, basically, in the computer. And what's happening is that input is
01:53:00.640 | living in the memory. And when you do input cubed, this has to travel to the GPU chip, to the cores,
01:53:10.240 | and to all the caches and registers on the actual chip of this GPU. And it has to raise all
01:53:17.680 | the elements to the third power. And then it saves the result back to the memory. And it's this travel
01:53:23.840 | time that actually causes a lot of issues. So here, remember this memory bandwidth? We can
01:53:29.760 | communicate about 2 terabytes per second, which is a lot. But also, we have to traverse this link,
01:53:35.760 | and it's very slow. So here on the GPU, we're on chip, and everything is super fast within the
01:53:40.400 | chip. But going to the memory is extremely expensive. It takes an extremely long amount of
01:53:44.400 | time. And so we load the input, do the calculations, and write back the output. And this round trip
01:53:51.520 | takes a lot of time. And now right after we do that, we multiply by this constant. So what happens
01:53:57.840 | then is we dispatch another kernel. And then the result travels back to the chip. All the elements get
01:54:03.280 | multiplied by a constant. And then the results travel back to the memory. And then we take the
01:54:08.960 | result, and we add back input. And so this entire thing, again, travels to the GPU, adds the inputs,
01:54:16.240 | and gets written back. So we're making all these round trips from the memory to actually where the
01:54:22.000 | computation happens. Because all the tensor cores and the ALUs and everything like that is all
01:54:26.720 | on the GPU chip. So we're doing a ton of round trips. And PyTorch, without using
01:54:32.480 | torch.compile, doesn't know to optimize this, because it doesn't know what kind of operations
01:54:37.200 | you're running later. You're just telling it, raise the power to the third, then do this, then
01:54:42.160 | do that. And it will just do that in that sequence. But torch.compile sees your entire code.
01:54:46.480 | It will come here, and it will realize, wait, all of these are element-wise operations. And actually,
01:54:51.440 | what I'm going to do is I'm going to do a single trip of input to the GPU.
01:54:56.320 | Then for every single element, I'm going to do all of these operations while that memory is on the
01:55:01.680 | GPU, or chunks of it, rather. And then I'm going to write back a single time. So we're not going
01:55:07.520 | to have these round trips. And that's one example of what's called kernel fusion, and is a major way
01:55:12.560 | in which everything is sped up. So basically, if you have the benefit of hindsight, and you know
01:55:16.560 | exactly what you're going to compute, you can optimize your round trips to the memory. And
01:55:21.600 | you're not going to pay the memory bandwidth cost. And that's fundamentally what makes some of these
01:55:25.520 | operations a lot faster, and what they mean by read/writes here. So let me erase this, because
01:55:32.320 | we are not using it. And yeah, we should be using torch.compile. And our code is now significantly
01:55:39.760 | faster. And we're doing about 125,000 tokens per second. But we still have a long way to go.
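A toy, pure-Python model of the kernel-fusion idea. Python lists play the role of HBM, and each "kernel" reads all of its inputs and writes a full output list back. The sketch counts element reads plus writes, not real bytes or time, and it is not what `torch.compile` actually generates. The formula is the standard tanh approximation of GELU, the one `nn.GELU(approximate="tanh")` uses, broken into the same cube, scale and add steps described above.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

class Memory:
    """Counts element reads and writes between "HBM" (Python lists) and the "chip" (one pass of work)."""

    def __init__(self):
        self.traffic = 0

    def kernel(self, fn, *inputs):
        """One kernel launch: read every element of every input, apply fn, write every output element."""
        out = [fn(*vals) for vals in zip(*inputs)]
        self.traffic += len(out) * (len(inputs) + 1)
        return out

def gelu_unfused(xs, mem):
    """Eager-style tanh GELU: every step is its own kernel with a full round trip to memory."""
    cube = mem.kernel(lambda x: x ** 3, xs)
    scaled = mem.kernel(lambda c: 0.044715 * c, cube)
    inner = mem.kernel(lambda x, s: x + s, xs, scaled)
    t = mem.kernel(lambda v: math.tanh(SQRT_2_OVER_PI * v), inner)
    return mem.kernel(lambda x, tv: 0.5 * x * (1.0 + tv), xs, t)

def gelu_fused(xs, mem):
    """Fused tanh GELU: read each x once, do all the math in one pass, write each result once."""
    return mem.kernel(
        lambda x: 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3))), xs
    )

xs = [i / 10 for i in range(-30, 31)]
eager, fused = Memory(), Memory()
same = gelu_unfused(xs, eager) == gelu_fused(xs, fused)
print(same, eager.traffic, fused.traffic)  # True 732 122
```

Both versions do the same arithmetic in the same order, so the results match exactly. The eager chain moves 12 elements per input through memory, while the fused kernel moves only 2.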
01:55:46.160 | Before we move on, I wanted to supplement the discussion a little bit with a few more figures.
01:55:50.640 | Because this is a complicated topic, but it's worth understanding on a high level
01:55:55.040 | what's happening here. And I could probably spend an entire video of like two hours on this, but
01:55:59.520 | just a preview of that basically. So this chip here, that is the GPU, this chip is where all
01:56:06.640 | the calculations happen mostly. But this chip also does have some memory in it. But most of
01:56:13.200 | the memory by far is here in the high bandwidth memory, HBM, and is connected, they're connected.
01:56:20.240 | But these are two separate chips, basically. Now, here, this is a zoom in of kind of this
01:56:26.800 | cartoon diagram of a GPU. And we're seeing here is number one, you see this HBM, I realize it's
01:56:33.680 | probably very small for you. But on the sides here, it says HBM. And so that's the links to the HBM.
01:56:39.360 | Now the HBM is, again, off chip. On the chip, there are a large number of these streaming
01:56:45.840 | multiprocessors. Every one of these is an SM, there's 120 of them in total. And this is where
01:56:52.080 | a lot of the calculations happen. And this is a zoom in of a single individual SM. It has these
01:56:57.440 | four quadrants. And see, for example, tensor core, this is where a lot of the matrix multiply stuff
01:57:01.600 | happens. But there's all these other units to do all different kinds of calculations for FP64,
01:57:07.120 | FP32, and for integers, and so on. Now, so we have all this logic here to do the calculations.
01:57:14.480 | But in addition to that, on the chip, there is memory sprinkled throughout the chip.
01:57:18.880 | So L2 cache is some amount of memory that lives on the chip. And then on the SMs themselves,
01:57:25.520 | there's L1 cache. I realize it's probably very small for you, but this blue bar is L1.
01:57:30.400 | And there's also registers. And so there is memory stored here. But the way this memory is stored is
01:57:37.840 | very different from the way memory is stored in HBM. Just in terms of what the silicon looks like,
01:57:44.400 | it's a very different implementation.
01:57:49.520 | So in HBM, you would be using transistors and capacitors. And here on the chip, it's a very different
01:57:55.280 | implementation with SRAM. But long story short, there is memory
01:58:04.240 | inside the chip, but it's not a lot of memory. That's the critical point. So this is an example
01:58:10.320 | diagram of a slightly different GPU, just like here, where it shows that, for example, typical
01:58:15.760 | numbers for CPU DRAM memory, which is this thing here: you might have one terabyte of DRAM, right?
01:58:22.640 | But it would be extremely expensive to access, especially for a GPU. You have to go through the
01:58:26.320 | CPU here. Now, next, we have the HBM. So we have tens of gigabytes of HBM memory on a typical GPU
01:58:32.960 | here, but it's, as I mentioned, very expensive to access. And then on the chip itself, everything is
01:58:39.840 | extremely fast within the chip, but we only have a couple tens of megabytes of memory collectively
01:58:45.760 | throughout the chip. And so there's just not enough space because the memory is very expensive
01:58:50.800 | on the chip. And so there's not a lot of it, but it is lightning fast to access in relative terms.
01:58:56.000 | And so basically, whenever we have these kernels, the more accurate picture of what's happening here
01:59:02.240 | is that we take these inputs, which live by default on the global memory, and now we need
01:59:08.080 | to perform some calculation. So we start streaming the data from the global memory to the chip.
01:59:15.680 | We perform the calculations on the chip and then stream it back and store it back to the
01:59:19.840 | global memory, right? And so if we don't have Torch Compile, we are streaming the data through
01:59:25.600 | the chip doing the calculations and saving to the memory, and we're doing those round trips many,
01:59:29.520 | many times. But if it's Torch Compiled, then we start streaming the data as before, but then
01:59:35.600 | while we're on the chip, we have a chunk of the data that we're trying to process. So that chunk
01:59:42.880 | now lives on the chip. While it's on the chip, it's extremely fast to operate on. So if we have
01:59:47.360 | kernel fusion, we can do all the operations right there in an element-wise fashion, and those are
01:59:52.560 | very cheap. And then we do a single round trip back to the global memory. So operator fusion
01:59:58.960 | basically allows you to keep your chunk of data on the chip and do lots of calculations on it
02:00:03.600 | before you write it back, and that gives huge savings. And that's why Torch Compile ends up
02:00:09.280 | being a lot faster, or that's one of the major reasons. So again, just a very brief intro to
02:00:15.120 | the memory hierarchy and roughly what Torch Compile does for you. Now, Torch Compile is
02:00:19.920 | amazing, but there are optimizations that Torch Compile will not find. And an amazing example
02:00:25.520 | of that is FlashAttention, to which we turn next. So FlashAttention comes from this paper from
02:00:30.960 | Stanford in 2022, and it's this incredible algorithm for performing attention and running
02:00:39.520 | it a lot faster. So FlashAttention will come here, and we will take out these four lines,
02:00:45.920 | and FlashAttention implements these four lines really, really quickly. And how does it do that?
02:00:52.880 | Well, FlashAttention is a kernel fusion operation. So you see here we have in this diagram,
02:01:00.080 | they're showing PyTorch, and you have these four operations. They're including dropout,
02:01:05.600 | but we are not using dropout here. So we just have these four lines of code here,
02:01:09.680 | and instead of those, we are fusing them into a single fused kernel of FlashAttention.
02:01:15.200 | So it's a kernel fusion algorithm, but it's a kernel fusion that Torch Compile cannot find.
02:01:22.480 | And the reason that it cannot find it is that it requires an algorithmic rewrite of how attention
02:01:28.160 | is actually implemented here in this case. And what's remarkable about it is that FlashAttention,
02:01:33.920 | actually, if you just count the number of flops, FlashAttention does more flops than this attention
02:01:40.080 | here. But FlashAttention is actually significantly faster. In fact, they cite 7.6 times faster,
02:01:47.120 | potentially. And that's because it is very mindful of the memory hierarchy, as I described it just
02:01:53.680 | now. And so it's very mindful about what's in high bandwidth memory, what's in the shared memory,
02:01:59.280 | and it is very careful with how it orchestrates the computation, such that we have fewer reads
02:02:05.120 | and writes to the high bandwidth memory. And so even though we're doing more flops,
02:02:09.120 | the expensive part is the loads and stores to HBM, and that's what they avoid. And so in
02:02:14.000 | particular, they never materialize this N-by-N attention matrix, this att here.
02:02:20.560 | FlashAttention is designed such that this matrix never gets materialized at any point,
02:02:25.760 | and it never gets read or written to the HBM. And this is a very large matrix, right?
02:02:30.960 | That's because this is where all the queries and keys interact, and
02:02:35.280 | for each head and for each batch element, we're getting a T-by-T matrix of attention,
02:02:43.120 | which is a million numbers, even for a single head at a single batch index.
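As a quick size check on that matrix, here is a small helper. It assumes a training shape of B=16, 12 heads, T=1024 and 2-byte (bf16) entries. Those values are assumptions for illustration, and the result covers one layer's forward pass only.

```python
def attention_matrix_bytes(batch_size, n_head, seq_len, bytes_per_element):
    """Size of the full (B, n_head, T, T) attention matrix for one layer."""
    return batch_size * n_head * seq_len * seq_len * bytes_per_element


# One head at one batch index with T = 1024: about a million entries
entries = attention_matrix_bytes(1, 1, 1024, bytes_per_element=1)
# Assumed shape: B=16, 12 heads, T=1024, bf16 (2 bytes per entry)
per_layer = attention_matrix_bytes(16, 12, 1024, bytes_per_element=2)
print(entries, per_layer / 2**20, "MiB")  # 1048576 384.0 MiB
```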
02:02:47.600 | So basically, this is a ton of memory, and this is never materialized. And the way that
02:02:54.080 | this is achieved is that basically the fundamental algorithmic rewrite here relies on this online
02:03:00.720 | softmax trick, which was proposed previously, and I'll show you the paper in a bit.
02:03:04.240 | And the online softmax trick, coming from a previous paper, shows how you can incrementally
02:03:11.360 | evaluate a softmax without having to sort of realize all of the inputs to the softmax
02:03:17.360 | to do the normalization. And you do that by having these intermediate variables m and l,
02:03:22.080 | and there's an update to them that allows you to evaluate the softmax in an online manner.
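Here is a minimal scalar sketch of the online-softmax idea: a running max `m` and a running normalizer `l`, updated one element at a time. It is written in plain Python for clarity. FlashAttention applies the same kind of update on tiles and also rescales its output accumulator, and this sketch omits both. The `combine_stats` helper is an illustrative addition showing how statistics from two chunks can be merged. That merge is what lets a blockwise kernel avoid ever holding the whole row.

```python
import math


def online_softmax_stats(xs):
    """Single pass over xs, keeping a running max m and normalizer l = sum(exp(x - m))."""
    m, l = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new max; on the first step exp(-inf) == 0.0
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, l


def combine_stats(a, b):
    """Merge (m, l) statistics computed independently on two chunks."""
    (m_a, l_a), (m_b, l_b) = a, b
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)


def softmax_from_stats(xs, m, l):
    return [math.exp(x - m) / l for x in xs]


def naive_softmax(xs):
    """Reference version that needs all inputs up front (max pass, then sum pass)."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


xs = [1.0, 3.0, -2.0, 10.0, 0.5, 1000.0, 999.0]
m, l = online_softmax_stats(xs)
# The same statistics computed chunk by chunk, then merged
m2, l2 = combine_stats(online_softmax_stats(xs[:3]), online_softmax_stats(xs[3:]))
print(softmax_from_stats(xs, m, l))
```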
02:03:26.560 | Now, more recently FlashAttention-2 came out as well, so I have
02:03:33.120 | that paper up here too; it has additional gains in how it calculates attention.
02:03:38.480 | And the original paper that this is based on, basically, is this online normalizing calculation
02:03:43.200 | for softmax. And remarkably, it came out of NVIDIA, and it came out really early, in
02:03:48.240 | 2018. So this is four years before FlashAttention. And this paper says that they propose a way to
02:03:55.840 | compute the classical softmax with fewer memory accesses and hypothesize that this reduction in
02:04:00.240 | memory accesses should improve softmax performance on actual hardware. And so they are extremely
02:04:06.480 | correct in this hypothesis, but it's really fascinating to me that they're from NVIDIA,
02:04:11.600 | and that they had this realization, but they didn't actually take it all the way to FlashAttention,
02:04:16.240 | which had to come four years later from Stanford. So I don't fully understand
02:04:21.840 | how this happened historically, but they do basically propose this online update to the softmax
02:04:27.280 | right here. And this is fundamentally what they reuse here to calculate the softmax in a streaming
02:04:33.360 | manner. And then they realized that they can actually fuse all the other operations
02:04:36.880 | with the online softmax calculation into a single fused kernel, FlashAttention,
02:04:42.240 | and that's what we are about to use. So a great example, I think, of being aware of memory
02:04:48.160 | hierarchy, of the fact that flops alone don't matter, the entire memory access pattern matters,
02:04:52.960 | and that TorchCompile is amazing, but there are many optimizations that are still available to us
02:04:57.200 | that potentially TorchCompile cannot find. Maybe one day it could, but right now it seems like a
02:05:03.200 | lot to ask. So here's what we're going to do. We're going to use FlashAttention, and the way
02:05:09.200 | to do that basically in PyTorch is we are going to comment out these four lines, and we're going
02:05:15.280 | to replace them with a single line. And here we are calling this compound operation in PyTorch
02:05:20.960 | called scaled_dot_product_attention. And PyTorch will call FlashAttention when you use it in this way.
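Here is a sketch of that swap, assuming inputs already shaped (B, n_head, T, head_size). The causal mask is built inline to keep the example self-contained, and the shapes are small illustrative values. PyTorch chooses the backend for `F.scaled_dot_product_attention` itself. On supported GPUs and dtypes that can be a FlashAttention kernel; on CPU it falls back to another implementation. Either way, the outputs should match the explicit version up to floating-point noise.

```python
import math

import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v):
    """The 'four lines' version: materializes the full (T, T) attention matrix."""
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~causal, float("-inf"))
    att = F.softmax(att, dim=-1)
    return att @ v


def fused_causal_attention(q, k, v):
    """One compound call; the default scale is 1/sqrt(head_size)."""
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


torch.manual_seed(0)
B, n_head, T, head_size = 2, 4, 8, 16  # small illustrative shapes
q, k, v = (torch.randn(B, n_head, T, head_size) for _ in range(3))
y_manual = manual_causal_attention(q, k, v)
y_fused = fused_causal_attention(q, k, v)
print(torch.allclose(y_manual, y_fused, atol=1e-5))
```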
02:05:30.880 | I'm not actually 100% sure why TorchCompile doesn't realize that these four lines should
02:05:34.720 | just call FlashAttention in this exact way. We have to do it for it, which in my opinion
02:05:40.240 | is a little bit odd, but here we are. So you have to use this compound op, and let's wait for a few
02:05:50.480 | moments before TorchCompile gets around to it. And then let's remember that we achieved 6.05661.
02:05:58.400 | I have it here. That's the loss we were expecting to see, and we took 130 milliseconds before this
02:06:04.640 | change. So we're expecting to see the exact same result by iteration 49, but we expect to see
02:06:11.120 | faster runtime because FlashAttention is just an algorithmic rewrite, and it's a faster kernel,
02:06:16.240 | but it doesn't actually change any of the computation, and we should have the exact
02:06:19.120 | same optimization. So okay, so we're a lot faster. We're at about 95 milliseconds,
02:06:24.720 | and we achieved 6.058. Okay, so they're basically identical up to a floating-point
02:06:32.640 | fudge factor. So it's the identical computation, but it's significantly faster going from 130 to
02:06:39.120 | roughly 96, and so this is 96 divided by 130-ish, which is maybe a 27-ish percent improvement.
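Here is a quick way to sanity-check these figures. It assumes the run uses B=16 and T=1024 tokens per step, which is an assumption here. Under that assumption, 130 ms per step works out to roughly 126k tokens/s, consistent with the "about 125,000" quoted earlier. Going from 130 ms to 95 or 96 ms cuts step time by roughly 26 to 27 percent.

```python
def tokens_per_second(batch_size, seq_len, step_seconds):
    return batch_size * seq_len / step_seconds


def fraction_of_time_saved(before_ms, after_ms):
    """E.g. 0.25 means each step takes 25% less time than before."""
    return 1.0 - after_ms / before_ms


print(round(tokens_per_second(16, 1024, 0.130)))      # 130 ms/step, assumed B=16, T=1024
print(f"{fraction_of_time_saved(130, 95):.1%}")        # 130 ms -> ~95 ms
print(f"{fraction_of_time_saved(130, 96):.1%}")        # 130 ms -> ~96 ms
```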
02:06:50.240 | So really interesting, and that is FlashAttention. Okay, we are now getting to one of my favorite
02:06:56.560 | optimizations, and it is simultaneously the dumbest and the most brilliant optimization,
02:07:02.000 | and it's always a little bit surprising to me. Anyway, so basically I mentioned a few minutes
02:07:07.760 | ago that there are some numbers that are nice and some numbers that are ugly. So 64 is a beautiful
02:07:15.120 | nice number. 128 is even nicer. 256 is beautiful. What makes these numbers beautiful is that there
02:07:21.120 | are many powers of two inside them. You can divide by two many times, and examples of ugly numbers
02:07:27.360 | are like 13 and 17 and something like that, prime numbers, numbers that are not even, and so on,
02:07:32.560 | and so pretty much you always want to use nice numbers in all of your code that deals with neural
02:07:36.880 | networks or CUDA because everything in CUDA works in sort of like powers of two, and lots of kernels
02:07:43.840 | are written in terms of powers of two, and there are lots of blocks of sizes 16 and 64 and so on.
02:07:50.240 | So everything is written in those terms, and you always need special-case handling, with all kinds of
02:07:54.880 | extra logic, for when your inputs are not made of nice numbers. So let's see what that looks like.
02:08:01.680 | Basically, scan your code and look for ugly numbers is roughly the heuristic. So three times
02:08:08.000 | is kind of ugly. I'm not 100% sure maybe this can be improved, but this is ugly and not ideal.
02:08:14.640 | Four times is nice. So that's nice. 1024 is very nice. That's a power of two.
02:08:24.960 | 12 is a little bit suspicious. Not too many powers of two. 768 is great. 50,257 is a really,
02:08:32.800 | really ugly number. First of all, it's odd, and there's not too many powers of two in there.
02:08:40.560 | So this is a very ugly number, and it's highly suspicious. And then when we scroll down,
02:08:46.080 | all these numbers are nice, and then here we have mostly nice numbers except for 25.
02:08:52.720 | So in this configuration of GPT-2 XL, the number of heads is 25. That's a really ugly number. That's
02:08:57.840 | an odd number. Actually, this did cause a lot of headaches for us recently when we were trying to
02:09:02.960 | optimize some kernels to run this fast and required a bunch of special case handling.
02:09:08.320 | So basically, we have some ugly numbers, and some of them are easier to fix than others.
02:09:14.560 | In particular, the vocab size being 50,257, that's a very ugly number, very suspicious,
02:09:19.360 | and we want to fix it. Now, when you fix these things, one of the easy ways to do that is you
02:09:24.560 | basically increase the number until it's the nearest multiple of a power of two that you like. So here's
02:09:31.280 | a much nicer number. It's 50,304. And why is that? Because 50,304 can be divided by 8, or by 16,
02:09:40.000 | or by 32, 64. It can even be divided by 128, I think. Yeah. So it's a very nice number.
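The "how many powers of two are inside it" check can be written as two small helpers. The names are illustrative. `n & -n` isolates the lowest set bit of a positive integer, which is the largest power of two dividing it. Rounding 50,257 up to a multiple of 128 gives exactly 50,304, and 128 is the largest power of two dividing 50,304.

```python
def round_up_to_multiple(n, multiple):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def largest_power_of_two_divisor(n):
    """For a positive int, n & -n keeps only the lowest set bit."""
    return n & -n


for n in (50257, 50304, 768, 1024, 25):
    print(n, largest_power_of_two_divisor(n))
print(round_up_to_multiple(50257, 128))  # 50304
```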
02:09:48.400 | So what we're going to do here is this is the GPT config, and you see that we initialize
02:09:53.360 | vocab size to 50,257. Let's override just that element to be 50,304.
02:10:00.960 | So everything else stays the same. We're just increasing our vocabulary size.
02:10:07.600 | So it's almost like we're adding fake tokens. So that vocab size has powers of two inside it.
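Here is a sketch of that override. It assumes a dataclass config whose fields mirror the GPT-2 (124M) values mentioned in this section; the exact field names are assumptions. Only `vocab_size` changes. Because the token embedding is shared with the classifier, the 47 extra rows add 47 × 768 parameters in one place.

```python
from dataclasses import dataclass


@dataclass
class GPTConfig:
    # Field names assumed to mirror the GPT-2 (124M) config used in this series
    block_size: int = 1024   # max sequence length
    vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + <|endoftext|>
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768


default = GPTConfig()
padded = GPTConfig(vocab_size=50304)  # override only the vocab size

# Extra rows in the (shared) token embedding / classifier weight
extra_tokens = padded.vocab_size - default.vocab_size
extra_params = extra_tokens * padded.n_embd
print(extra_tokens, extra_params)  # 47 36096
```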
02:10:14.560 | Now, actually, what I'm doing here, by the way, is I'm increasing the amount of computation
02:10:19.200 | that our network will be doing. If you just count the flops, do the math of how many flops
02:10:23.840 | we're doing, we're going to be doing more flops. And we still have to think through whether this
02:10:29.680 | breaks anything. But if I just run this, let's see what we get. Currently, this ran at
02:10:35.760 | maybe 96.5 milliseconds per step. I'm just kind of like eyeballing it. And let's see what kind of
02:10:43.840 | result we're going to get. While this is compiling, let's think through whether our code actually
02:10:51.280 | works okay when we increase the vocab size like this. Let's look at where vocab size is actually
02:10:56.000 | used. So we swing up to the init, and we see that it's used inside the embedding table, of course,
02:11:02.560 | so all the way at the bottom of the transformer. And it's used at the classifier layer, all the
02:11:06.400 | way at the top of the transformer, so in two places. And let's take a look. And we're running
02:11:11.440 | at 93. So 93 milliseconds instead of 96.5. So we are seeing a roughly 4% improvement here
02:11:20.400 | by doing more calculations. And the reason for this is we've made an ugly number into a nice
02:11:28.400 | number. I'm going to come back to the explanation for that in a bit. But for now, let's
02:11:33.920 | just convince ourselves that we're not breaking anything when we do this. So first of all, we've
02:11:37.600 | made the WTE, the embedding table for the tokens, we've made it larger. It's almost like we
02:11:43.120 | introduced more tokens at the bottom. And these tokens are never used because the GPT tokenizer
02:11:49.600 | only has tokens up to 50,256. And so we'll never index into the rows that we've added. So we're
02:11:57.120 | wasting a little bit of space here by creating memory that's never going to be accessed, never
02:12:01.360 | going to be used, etc. Now, that's not fully correct, because this WTE weight ends up being
02:12:07.440 | shared and ends up being used in the classifier here at the end. So what is that doing to the
02:12:12.320 | classifier right here? Well, what that's doing is we're predicting additional dimensions of the
02:12:17.440 | classifier now. And we're predicting probabilities for tokens that will, of course, never be present
02:12:22.160 | in the training set. And so therefore, the network has to learn that these probabilities have to be
02:12:28.880 | driven to zero. And so the logits that the network produces have to drive those dimensions of the
02:12:34.320 | output to negative infinity. But that's no different from all the other tokens that are already in our
02:12:39.760 | data set, or rather that are not in our data set. So Shakespeare only probably uses, let's say 1000
02:12:46.400 | tokens out of 50,257 tokens. So most of the tokens are already being driven to zero probability by
02:12:52.160 | the optimization, we've just introduced a few more tokens now that in a similar manner will never be
02:12:57.120 | used and have to be driven to zero in probability. So functionally, though, nothing breaks, we're
02:13:03.920 | using a bit more extra memory. But otherwise, this is a harmless operation, as far as I can tell.
02:13:10.400 | And we're adding calculation, but it's running faster. And it's running faster,
02:13:15.280 | because as I mentioned, in CUDA, so many kernels use block tiles, and these block tiles are usually
02:13:22.560 | nice numbers. So powers of two, so calculations are done in like chunks of 64, or chunks of 32.
02:13:28.720 | And when your desired calculation doesn't neatly fit into those block tiles,
02:13:34.720 | there are all kinds of boundary kernels that can kick in to do the last part. So basically,
02:13:42.720 | in a lot of kernels, they will chunk up your input, and they will do the nice part first,
02:13:47.280 | and then they have a whole second phase, where they come back to anything that remains.
02:13:52.640 | And then they process the remaining part. And the kernels for that can be very inefficient.
02:13:58.000 | And so you're basically spinning up all this extra compute, and it's extremely inefficient.
02:14:03.680 | So you might as well pad your inputs and make it fit nicely. And usually that empirically ends up
02:14:10.160 | actually running faster. So this is another example of a 4% improvement that we've added.
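Here is a toy illustration of the two-phase structure just described. It is a plain-Python stand-in, not a CUDA kernel. The 64-element tile width is an assumption; real kernels pick tile shapes per GPU and per matrix shape. With 50,257 columns, a 64-wide tiling leaves a 17-element remainder that needs the boundary path. With 50,304, every tile is full.

```python
def tile_split(n, tile):
    """How a tiled kernel might split n elements: full tiles plus a leftover."""
    full_tiles, leftover = divmod(n, tile)
    return full_tiles, leftover


def tiled_sum(xs, tile):
    """Toy two-phase 'kernel': a uniform loop over full tiles, then a boundary pass."""
    full_tiles, leftover = tile_split(len(xs), tile)
    total = 0.0
    for t in range(full_tiles):                  # phase 1: only full tiles
        total += sum(xs[t * tile:(t + 1) * tile])
    boundary_passes = 0
    if leftover:                                 # phase 2: the ragged remainder
        total += sum(xs[full_tiles * tile:])
        boundary_passes = 1
    return total, boundary_passes


print(tile_split(50257, 64))  # (785, 17)
print(tile_split(50304, 64))  # (786, 0)
```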
02:14:17.760 | And this is something that also Torch Compile did not find for us. You would hope that Torch Compile
02:14:22.640 | at some point could figure an optimization like this out. But for now, this is it. And I also have
02:14:28.320 | to point out that we're using PyTorch nightly. So that's why we're only seeing 4%. If you're using
02:14:32.880 | PyTorch 2.3.1, or earlier, you would actually see something like 30% improvement just from this
02:14:38.800 | change, going from 50,257 to 50,304. So again, this is also one of my favorite examples
02:14:48.000 | of having to understand what's under the hood and how it all works, and to know what kinds of things to
02:14:52.320 | tinker with to push the performance of your code. Okay, so at this point, we have improved the
02:14:56.960 | performance by about 11x, right? Because we started at about 1000 milliseconds per step,
02:15:01.840 | and we're now down to like 93 milliseconds. So that's quite good. And we're doing a much better
02:15:08.240 | job of utilizing our GPU resources. So I'm going to now turn to more algorithmic changes and
02:15:14.720 | improvements to the actual optimization itself. And what we would like to do is we'd like to
02:15:18.640 | follow the hyperparameters that are mentioned in the GPT-2 or GPT-3 paper. Now, sadly, GPT-2
02:15:25.360 | doesn't actually say too much. It's very nice of them that they released the model weights and the
02:15:31.120 | code, but the paper itself is extremely vague as to the optimization details. The code itself that
02:15:36.320 | they released as well, the code we've been looking at, this is just the inference code. So there's no
02:15:41.520 | training code here and very few hyperparameters. So this also doesn't tell us too much. So for that,
02:15:46.480 | we have to turn to the GPT-3 paper. And in the appendix of the GPT-3 paper, they have a lot
02:15:54.400 | more hyperparameters here for us to use. And the GPT-3 paper in general is a lot more detailed as
02:16:00.240 | to all the small details that go into the model training, but GPT-3 models were never released.
02:16:08.000 | So GPT-2, we have the weights, but no details, and GPT-3, we have lots of details, but no weights.
02:16:12.880 | But roughly speaking, GPT-2 and GPT-3 architectures are very, very similar.
02:16:18.560 | And basically, there are very few changes. The context length was expanded from 1024 to 2048,
02:16:25.760 | and that's kind of like the major change. And some of the hyperparameters around the
02:16:29.520 | transformer have changed. But otherwise, they're pretty much the same model. It's
02:16:32.880 | just that GPT-3 was trained for a lot longer on a bigger dataset and has a lot more thorough
02:16:37.920 | evaluations. And the GPT-3 model is 175 billion parameters instead of 1.6 billion in GPT-2.
02:16:47.440 | So long story short, we're going to go to GPT-3 paper to follow along some of the hyperparameters.
02:16:52.160 | So to train all the versions of GPT-3, we use Adam with beta 1 and beta 2 of 0.9 and 0.95.
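Here is a sketch of passing these settings explicitly. It assumes the AdamW optimizer this series trains with, and it uses a stand-in linear layer instead of the GPT model. PyTorch's defaults for AdamW are `betas=(0.9, 0.999)` and `eps=1e-8`, so only the betas actually change. Passing eps anyway keeps the choice visible.

```python
import torch

model = torch.nn.Linear(768, 768)  # stand-in for the GPT model

# GPT-3 paper settings: beta1=0.9, beta2=0.95, eps=1e-8
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

group = optimizer.param_groups[0]
print(group["betas"], group["eps"])
```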
02:16:59.440 | So let's swing over here and make sure that the betas parameter, which you can see here defaults
02:17:04.880 | to 0.9 and 0.999, is actually set to 0.9 and 0.95. And then the epsilon parameter,
02:17:13.680 | you can see is the default is 1 and negative 8, and this is also 1 and negative 8. Let's just
02:17:19.680 | put it in so that we're explicit. Now, next up, they say we clip the global norm of the gradient
02:17:27.120 | at 1.0. So what this is referring to is that once we calculate the gradients right after loss dot
02:17:33.120 | backward, we basically have the gradients at all the parameter tensors. And what people like to do
02:17:40.160 | is basically clip them to have some kind of a maximum norm. So in PyTorch, this is fairly easy
02:17:46.160 | to do. It's one line of code here that we have to insert right after we calculate the gradients.
02:17:51.440 | And what this utility function is doing is calculating the global norm of the gradients.
02:17:59.040 | So every single gradient on all the parameters, you square it and you add it all up and you take
02:18:05.120 | a big square root of that. And that's the norm of the gradient vector, basically. It's the
02:18:12.800 | length of it, if you'd like to look at it that way. And we are basically making sure that
02:18:17.040 | its length is no more than 1.0. And we're going to clip it. And the reason that people like to
02:18:22.880 | use this is that sometimes you can get unlucky during the optimization. Maybe it's a bad data
02:18:27.760 | batch or something like that. And if you get very unlucky in the batch, you might get really high
02:18:32.560 | loss and really high loss could lead to a really high gradient. And this could basically shock your
02:18:38.880 | model and shock the optimization. So people like to use a gradient norm clipping to prevent the
02:18:45.680 | model from basically getting too big of shocks in terms of the gradient magnitude, and it's upper
02:18:52.880 | bounded in this way. It's a bit of a hacky solution. It's kind of like a patch on top of
02:18:57.840 | deeper issues, but people still do it fairly frequently. Now, clip_grad_norm_ returns
02:19:04.400 | the norm of the gradient, which I like to always visualize because it is useful information. And
02:19:11.440 | sometimes you can look at the norm of the gradient and if it's well behaved, things are good. If it's
02:19:16.560 | climbing, things are bad and they're destabilizing during training. Sometimes you could get a spike
02:19:21.040 | in the norm, and that means there's some kind of an issue or instability. So here we'll print
02:19:27.360 | the norm, and let's format it with .4f or something like that. And I believe this is just a float.
02:19:36.720 | And so we should be able to print that. So that's global gradient clipping.
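Here is a small sketch of global-norm clipping with `torch.nn.utils.clip_grad_norm_`. It uses a tiny stand-in model and a loss scaled up so the norm is likely above 1.0; both are assumptions for the demo. The clip call goes after `loss.backward()` and before the optimizer step. The manual helper recomputes the same quantity: square every gradient element, sum across all tensors, take one square root. `clip_grad_norm_` returns the total norm measured before clipping as a zero-dimensional tensor, and `.item()` turns it into a Python float for printing.

```python
import torch

torch.manual_seed(0)
# Tiny stand-in model; the loss is scaled up so the gradient norm is likely above 1.0
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 4))
loss = 100.0 * model(torch.randn(8, 16)).pow(2).mean()
loss.backward()


def global_grad_norm(parameters):
    # Square every gradient element, sum across all tensors, then one square root
    return torch.sqrt(sum((p.grad.detach() ** 2).sum() for p in parameters))


manual = global_grad_norm(model.parameters())
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # pre-clip norm
after = global_grad_norm(model.parameters())
print(f"norm: {norm.item():.4f} | after clipping: {after.item():.4f}")
```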
02:19:43.840 | Now they go into the details of the learning rate scheduler. So they don't just use a fixed
02:19:50.560 | learning rate like we do here of 3e-4, but there's actually basically a cosine decay
02:19:56.640 | learning rate schedule. It's got a warmup and it's got a cosine decay to 10% over some horizon.
02:20:04.800 | And so we're going to implement this in a second. I just like to see the norm printed here. Okay,
02:20:12.720 | there we go. So what happened here is the norm is actually really high in the beginning,
02:20:17.760 | 30 or so. And you see that as we continue training, it kind of like stabilizes
02:20:24.800 | at values below one. And it's not that uncommon for the norm to be high in the very first
02:20:31.440 | few stages. Basically what's happening here is the model is completely random. And so there's
02:20:35.520 | a ton of learning happening very early in the network, but that learning is kind of like,
02:20:39.520 | you know, it's mostly learning the biases of the output tokens. And so it's a bit of an unstable
02:20:45.920 | time, but the network usually stabilizes within a few iterations. So this looks relatively
02:20:50.880 | reasonable to me, except it looks a little bit funky that we go from
02:20:55.360 | 28 to 6 to 2 and then to 10. It's not completely insane, but it's just kind of a little bit funky.
02:21:02.080 | Okay, so let's now get to the learning rate scheduler. So the learning rate schedule
02:21:07.760 | that's used here in GPT-3 is what's called a cosine decay learning schedule with warmup.
02:21:14.240 | And the way this looks is that the learning rate is basically starts right at around zero,
02:21:19.840 | linearly ramps up over some amount of time, and then comes down with this cosine sort of form and
02:21:27.120 | comes down to some kind of a minimum learning rate that's up to you. So here the minimum learning
02:21:30.720 | rate is zero. But here in the paper, they said that they use cosine decay for learning rate down
02:21:36.640 | to 10% of its value over the first 260 billion tokens. And then training continues at 10% after that.
02:21:44.800 | And there's a linear warmup over the first 375 million tokens. So that's about the learning
02:21:50.560 | rate. So let's now implement this. So I already implemented it here. And the way this works is,
02:21:56.560 | let me scroll down first here. I changed our training loop a little bit. So this was a for i
02:22:02.400 | in range(max_steps). I just changed it to step now so that we have the notion of a step as a single
02:22:07.360 | optimization step in the for loop. And then here, I get the LR for this step of the optimization
02:22:14.960 | using a new function I call get_lr. And then in PyTorch, to set the learning rate, I think this
02:22:20.400 | is the way to set the learning rate. It's a little bit gnarly. Because you have to basically there's
02:22:25.120 | a notion of different parameter groups that could exist in the optimizer. And so you actually have
02:22:29.200 | to iterate over them, even though we currently have a single param group only. And you have
02:22:34.080 | to set the LR in this for loop kind of style, is my impression right now. So we have this local
02:22:40.400 | LR, we set the learning rate, and then on the bottom, I'm also printing it. So that's all the
02:22:46.240 | changes I made to this loop. And then of course, get_lr is my scheduler. Now it's worth pointing
02:22:51.200 | out that PyTorch actually has learning rate schedulers, and you can use them. And I believe
02:22:55.680 | there's a cosine learning rate schedule in PyTorch. I just don't really love using that
02:23:00.880 | code, because honestly, it's like five lines of code. And I fully understand what's happening
02:23:06.720 | inside these lines. So I don't love to use abstractions where they're kind of inscrutable,
02:23:12.160 | and then I don't know what they're doing. So personal style. So the max learning rate here
02:23:17.360 | is, let's say, 3e-4. But we're going to see that in GPT-3 here, they have a table of what the
02:23:25.280 | maximum learning rate is for every model size. So for this one, basically 12 layer 768 GPT-3.
02:23:36.560 | So GPT-3 Small is roughly like GPT-2 124M. We see that here they use a learning rate of
02:23:43.280 | 6e-4. So we could actually go higher. In fact, we may want to try to follow that and just
02:23:48.000 | set max_lr here to 6e-4. So that's the maximum learning rate. The min learning rate is
02:23:54.400 | 10% of that per description in the paper, some number of steps that we're going to warm up over,
02:24:01.600 | and then the maximum steps of the optimization, which I now use also in the for loop down here.
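The loop change described earlier writes the new value into every param group of the optimizer on every step. Here is a minimal sketch of that pattern. It uses a stand-in model, and a short hard-coded list replaces the real schedule; both are placeholders. PyTorch also ships ready-made schedulers in `torch.optim.lr_scheduler` (for example `LambdaLR` and `CosineAnnealingLR`) if you prefer those.

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


def set_lr(optimizer, lr):
    # An optimizer can hold several param groups, each with its own "lr" entry,
    # so the new value is written into every group.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


schedule = [1e-4, 2e-4, 3e-4]  # placeholder values standing in for the scheduler's output
for step, lr in enumerate(schedule):
    set_lr(optimizer, lr)
    optimizer.zero_grad()
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    print(f"step {step} | lr {lr:.1e} | loss {loss.item():.4f}")
```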
02:24:06.080 | And then you can go over this code if you like. It's not terribly insightful or interesting.
02:24:12.240 | I'm just modulating based on the iteration number which learning rate there should be.
02:24:17.600 | So this is the warmup region. This is the region after the optimization. And then this is the
02:24:24.240 | region sort of in between. And this is where I calculate the cosine learning rate schedule.
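Here is a sketch of a warmup-plus-cosine scheduler with the three regions just described. The values are assumptions: `max_lr=6e-4` follows the GPT-3 Small row, `min_lr` is 10% of that per the paper, and the short `warmup_steps=10` / `max_steps=50` horizon is illustrative. The `(it + 1)` keeps step 0 from using a learning rate of exactly zero.

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1   # decay down to 10% of the peak
warmup_steps = 10       # assumed, short demo horizon
max_steps = 50          # assumed


def get_lr(it):
    # 1) linear warmup; (it + 1) so that step 0 does not use a learning rate of 0
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) past the decay horizon: hold at the minimum
    if it > max_steps:
        return min_lr
    # 3) in between: cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 down to 0
    return min_lr + coeff * (max_lr - min_lr)


for it in (0, 9, 10, 30, 50, 60):
    print(it, f"{get_lr(it):.2e}")
```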
02:24:28.960 | And you can step through this in detail if you'd like. But this is basically implementing this
02:24:33.040 | curve. And I ran this already. And this is what that looks like. So when we now run, we start at
02:24:45.600 | some very low number. Now, note that we don't start exactly at zero because that would be not
02:24:49.760 | useful to update with a learning rate of zero. That's why there's an it plus one, so that on
02:24:54.320 | the zeroth iteration, we are not using exactly zero. We're using something very, very low.
02:24:58.640 | Then we linearly warm up to maximum learning rate, which in this case was 3e negative 4 when I ran
02:25:04.320 | it. But now it would be 6e negative 4. And then it starts to decay all the way down to 3e negative 5,
02:25:12.640 | which was at the time 10% of the original learning rate. Now, one thing we are not
02:25:16.880 | following exactly is that they mentioned that -- let me see if I can find it again.
02:25:22.640 | We're not exactly following what they did because
02:25:26.960 | they mentioned that their training horizon is 300 billion tokens. And they come down to 10%
02:25:34.400 | of the initial learning rate at 260 billion. And then they train after 260 with 10%. So basically,
02:25:41.440 | their decay time is less than the max steps time, whereas for us, they're exactly equal.
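To relate the paper's token horizons to optimizer steps, here is a quick conversion. It assumes a batch of 2**19 = 524,288 tokens per step, chosen as a power of two near GPT-3 Small's 0.5M-token batch; that batch size is an assumption here. Under it, the paper's decay horizon lands well short of its total horizon, whereas our schedule ends the decay exactly at `max_steps`.

```python
tokens_per_step = 2**19  # 524,288 tokens per step (assumed, ~0.5M)

warmup_tokens = 375_000_000        # linear warmup
decay_tokens = 260_000_000_000     # cosine decay down to 10%
total_tokens = 300_000_000_000     # full training horizon

warmup_steps = warmup_tokens // tokens_per_step
decay_steps = decay_tokens // tokens_per_step
total_steps = total_tokens // tokens_per_step
print(warmup_steps, decay_steps, total_steps)  # 715 495910 572204
```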
02:25:46.560 | So it's not exactly faithful, but it's an okay -- this is okay for us and for our purposes right
02:25:54.240 | now. And we're just going to use this ourselves. I don't think it makes too big of a difference,
02:26:00.400 | honestly. I should point out that what learning rate schedule you use is totally up to you.
02:26:05.440 | There's many different types. Cosine learning rate has been popularized a lot by GPT-2 and
02:26:10.640 | GPT-3, but people have come up with all kinds of other learning rate schedules. And this is
02:26:15.600 | kind of like an active area of research as to which one is the most effective at training
02:26:20.320 | these networks. Okay, next up, the paper talks about the gradual batch size increase. So there's
02:26:27.040 | a ramp on the batch size that is linear. And you start with very small batch size, and you ramp up
02:26:32.400 | to a big batch size over time. We're going to actually skip this, and we're not going to work
02:26:36.880 | with it. And the reason I don't love to use it is that it complicates a lot of the arithmetic,
02:26:41.760 | because you are changing the number of tokens that you're processing at every single step of
02:26:45.200 | the optimization. And I like to keep that math very, very simple. Also, my understanding is
02:26:50.080 | that this is not a major improvement. And also, my understanding is that this is not an algorithmic
02:26:56.560 | optimization improvement. It's more of a systems and speed improvement. And roughly speaking,
02:27:01.360 | this is because in the early stages of the optimization, again, the model is in a very
02:27:08.080 | atypical setting. And mostly what you're learning is to ignore the
02:27:14.000 | tokens that don't come up in your training set very often. You're learning very simple biases and
02:27:19.280 | that kind of a thing. And so every single example that you put through your network
02:27:25.520 | is basically just telling you, use these tokens and don't use these tokens. And so the gradients
02:27:30.080 | from every single example are actually extremely highly correlated. They all look roughly the same
02:27:35.040 | in the early parts of the optimization, because they're all just telling you that
02:27:39.040 | these tokens don't appear and these tokens do appear. And so because the gradients are all
02:27:44.240 | very similar, and they're highly correlated, then why are you doing batch sizes of millions,
02:27:49.280 | when if you do a batch size of 32k, you're basically getting the exact same gradient
02:27:53.760 | early on in the training. And then later in the optimization, once you've learned all the simple
02:27:59.040 | stuff, that's where the actual work starts. And that's where the gradients become more
02:28:02.800 | decorrelated across examples. And that's where they actually offer you sort of statistical power,
02:28:07.760 | in some sense. So we're going to skip this just because it kind of complicates things.
02:28:13.200 | And we're going to go on to "data are sampled without replacement during training
02:28:19.520 | (until an epoch boundary is reached)". So without replacement means that they're not sampling from
02:28:25.600 | some fixed pool, taking a sequence, training on it, and then returning the sequence to the
02:28:32.480 | pool; they are exhausting a pool. So when they draw a sequence, it's gone until the next
02:28:37.920 | epoch of training. So we're already doing that because our data loader iterates over chunks of
02:28:44.560 | data. So there's no replacement: chunks don't become eligible to be drawn again until the next epoch.
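Sampling without replacement until an epoch boundary falls out naturally from walking the token stream in consecutive chunks. A minimal sketch, assuming the training data is one flat list of token ids; `iterate_chunks` is a hypothetical helper, not the `DataLoaderLite` class from the video.

```python
from typing import Iterator, List, Tuple


def iterate_chunks(
    tokens: List[int], B: int, T: int, epochs: int
) -> Iterator[Tuple[List[List[int]], List[List[int]]]]:
    """Walk the token stream in consecutive (B, T) chunks. Within an epoch every chunk
    is visited exactly once (no replacement); positions only become eligible again
    after wrapping around at the epoch boundary."""
    chunk = B * T
    for _ in range(epochs):
        pos = 0
        # Each batch needs chunk + 1 tokens so the targets can be shifted by one.
        while pos + chunk + 1 <= len(tokens):
            buf = tokens[pos : pos + chunk + 1]
            x = [buf[i * T : (i + 1) * T] for i in range(B)]          # inputs
            y = [buf[i * T + 1 : (i + 1) * T + 1] for i in range(B)]  # next-token targets
            yield x, y
            pos += chunk  # advance past this chunk; it is not drawn again this epoch


for x, y in iterate_chunks(list(range(25)), B=2, T=4, epochs=1):
    print(x, y)
```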
02:28:50.640 | So we're basically already doing that. "All models use a weight decay of 0.1 to
02:28:57.280 | provide a small amount of regularization. So let's implement a weight decay. And you see
02:29:02.400 | here that I've already kind of made the changes. And in particular, instead of creating the
02:29:06.400 | optimizer right here, I'm creating a new configure optimizers function inside the model. And I'm
02:29:13.680 | passing in some of the hyperparameters instead. So let's look at the configure optimizers,
02:29:18.080 | which is supposed to return the optimizer object. Okay, so it looks complicated, but it's actually
02:29:27.680 | really simple. And it's just, we're just being very careful. And there's a few settings here
02:29:32.400 | to go through. The most important thing with respect to this line is that you see there's
02:29:36.640 | a weight decay parameter here. And I'm passing that into something
02:29:44.960 | called optim_groups that eventually ends up going into the AdamW optimizer. The weight decay
02:29:50.720 | that's used by default in AdamW is 0.01, so it's 10 times lower than what's used in the GPT-3
02:29:58.160 | paper here. So the weight decay basically ends up making its way into the AdamW optimizer groups.
02:30:05.200 | Now what else is going on here in this function? So the two things that are happening here that
02:30:09.600 | are important is that I'm splitting up the parameters into those that should be weight
02:30:13.920 | decayed and those that should not be weight decayed. So in particular, it is common to not
02:30:18.960 | weight decay biases and any other sort of one-dimensional tensors. So the one-dimensional
02:30:25.840 | tensors are in the nodecay parameters. And these are also things like layer norm
02:30:31.760 | scales and biases. It doesn't really make sense to weight decay those. You mostly want to weight
02:30:36.240 | decay the weights that participate in matrix multiplications. And you want to potentially
02:30:41.920 | weight decay the embeddings. And we've covered in a previous video why it makes sense to decay
02:30:47.760 | the weights, because you can sort of think of it as a regularization. Because when you're pulling
02:30:52.000 | down on all the weights, you're forcing the optimization to use more of the weights. And
02:30:57.200 | you're not allowing any one of the weights individually to be way too large. You're
02:31:02.560 | forcing the network to kind of distribute the work across more channels, because there's sort
02:31:07.600 | of like a pull of gravity on the weights themselves. So that's why we are separating it
02:31:14.400 | in those ways here. We're only decaying the embeddings and the matmul participating weights.
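To make the "pull of gravity" concrete: in AdamW the weight decay is decoupled from the gradient, and in PyTorch's implementation it amounts to multiplying each decayed parameter by `(1 - lr * weight_decay)` before the Adam step. The sketch below isolates just that term in pure Python. The Adam gradient step is deliberately left out, the weight values are made up, and the helper name is hypothetical; the learning rate 6e-4 and weight decay 0.1 are the values discussed in this section.

```python
def decoupled_weight_decay(weights, lr, weight_decay):
    """Only the weight-decay part of an AdamW-style update: w <- w * (1 - lr * weight_decay).
    The Adam gradient step is intentionally omitted, as if the gradient signal were zero."""
    factor = 1.0 - lr * weight_decay
    return [w * factor for w in weights]


weights = [2.0, -0.5, 0.01]   # made-up values
lr, weight_decay = 6e-4, 0.1  # learning rate and weight decay discussed in this section
for _ in range(1000):
    weights = decoupled_weight_decay(weights, lr, weight_decay)

# Each step shrinks every weight by the same factor (0.99994), so larger weights lose more
# in absolute terms: the pull toward zero is strongest on the biggest weights.
print(weights)
```

Left unopposed, this term shrinks everything toward zero; in real training the gradient has to keep pushing back on the weights it actually needs, which is what spreads the work across more channels.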
02:31:20.000 | We're printing the number of parameters that we're decaying and not. Most of the parameters
02:31:25.120 | will be decayed. And then one more thing that we're doing here is I'm doing another optimization here.
02:31:37.840 | Earlier versions of AdamW did not have this option, but later versions of PyTorch introduced it.
02:31:42.320 | And that's why I'm guarding it with an inspect.signature, which is basically checking
02:31:50.560 | if this fused kwarg is present inside AdamW. And then if it is present, I'm going to end up using
02:31:57.520 | it and passing it in here, because some earlier versions do not have fused=. So here's
02:32:03.440 | AdamW with fused=. It didn't use to exist and it was added later. And there's some docs here for
02:32:03.440 | what's happening. And basically they say that by default, they do not use fused because it is
02:32:09.440 | relatively new and we want to give it sufficient bake time. So by default, they don't use fused,
02:32:13.600 | but fused is a lot faster when it is available and when you're running on CUDA. And what that
02:32:18.720 | does is instead of iterating in a for loop over all the parameter tensors and updating them,
02:32:25.040 | that would launch a lot of kernels, right? And so fused just means that all those kernels are
02:32:30.640 | fused into a single kernel. You get rid of a lot of overhead, and you call a single kernel, a single time, on all the
02:32:36.160 | parameters to update them. And so it's just basically a kernel fusion for the
02:32:43.440 | AdamW update instead of iterating over all the tensors. So that's the configure optimizers
02:32:49.680 | function that I like to use. And we can rerun and we're not going to see any major differences from
02:32:55.120 | what we saw before, but we are going to see some prints coming from here. So let's just take a look
02:33:00.400 | at what they look like. So we see that the number of decayed tensors is 50, and that's most of the parameters,
02:33:07.440 | and the number of non-decayed tensors is 98. And these are the biases and the layer norm parameters
02:33:11.840 | mostly. And there are only about a hundred thousand of those. So most of it is decayed.
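The printed counts can be checked by hand. This sketch lists the GPT-2 (124M) parameter shapes in pure Python, assuming the padded vocabulary of 50,304, a 1,024-token context, 12 layers and 768 channels, with `lm_head.weight` tied to `wte` so it is only counted once. It then applies the rule described above: tensors with two or more dimensions are decayed. The helper names are hypothetical. In PyTorch the equivalent filter would be `p.dim() >= 2` over `model.named_parameters()`, with the two lists passed as param groups `[{"params": decay, "weight_decay": weight_decay}, {"params": nodecay, "weight_decay": 0.0}]`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def gpt2_param_shapes(vocab_size=50304, block_size=1024, n_layer=12, n_embd=768) -> Dict[str, Shape]:
    """Parameter shapes of a GPT-2 style model; lm_head.weight is tied to wte and not repeated."""
    shapes: Dict[str, Shape] = {
        "transformer.wte.weight": (vocab_size, n_embd),
        "transformer.wpe.weight": (block_size, n_embd),
    }
    for i in range(n_layer):
        p = f"transformer.h.{i}."
        shapes[p + "ln_1.weight"] = (n_embd,)
        shapes[p + "ln_1.bias"] = (n_embd,)
        shapes[p + "attn.c_attn.weight"] = (3 * n_embd, n_embd)
        shapes[p + "attn.c_attn.bias"] = (3 * n_embd,)
        shapes[p + "attn.c_proj.weight"] = (n_embd, n_embd)
        shapes[p + "attn.c_proj.bias"] = (n_embd,)
        shapes[p + "ln_2.weight"] = (n_embd,)
        shapes[p + "ln_2.bias"] = (n_embd,)
        shapes[p + "mlp.c_fc.weight"] = (4 * n_embd, n_embd)
        shapes[p + "mlp.c_fc.bias"] = (4 * n_embd,)
        shapes[p + "mlp.c_proj.weight"] = (n_embd, 4 * n_embd)
        shapes[p + "mlp.c_proj.bias"] = (n_embd,)
    shapes["transformer.ln_f.weight"] = (n_embd,)
    shapes["transformer.ln_f.bias"] = (n_embd,)
    return shapes


def numel(shape: Shape) -> int:
    n = 1
    for d in shape:
        n *= d
    return n


def split_for_weight_decay(shapes: Dict[str, Shape]):
    """2D+ tensors (matmul weights, embeddings) are decayed; 1D tensors (biases, norms) are not."""
    decay = {name: s for name, s in shapes.items() if len(s) >= 2}
    nodecay = {name: s for name, s in shapes.items() if len(s) < 2}
    return decay, nodecay


decay, nodecay = split_for_weight_decay(gpt2_param_shapes())
print(f"decayed tensors: {len(decay)}, params: {sum(numel(s) for s in decay.values()):,}")
# decayed tensors: 50, params: 124,354,560
print(f"non-decayed tensors: {len(nodecay)}, params: {sum(numel(s) for s in nodecay.values()):,}")
# non-decayed tensors: 98, params: 121,344
```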
02:33:17.520 | And then we are using the fused implementation of AdamW, which will be a lot faster. So if you
02:33:22.480 | have it available, I would advise you to use it. I'm not actually a hundred percent sure why they
02:33:26.720 | don't default to it. It seems fairly benign and harmless. And also because we are using the fused
02:33:31.920 | implementation, I think this is why the time dropped: notice that the running time used to
02:33:38.000 | be 93 milliseconds per step. And we're now down to 90 milliseconds per step because of using the
02:33:43.200 | fused AdamW optimizer. So in a single commit here, we are introducing fused Adam, getting
02:33:50.240 | improvements on the time, and we're adding or changing the weight decay, but we're only weight
02:33:56.080 | decaying the two-dimensional parameters, the embeddings, and the matrices that participate
02:34:00.320 | in the linear. So that is this, and we can take this out. And yeah, that is it for this line.
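The guard on `fused` is plain `inspect`: look up the constructor's signature and only pass the keyword if it exists. A self-contained sketch follows; `old_adamw` and `new_adamw` are hypothetical stand-ins for an older and a newer optimizer constructor so the example runs without PyTorch, and `supports_kwarg` and `build_optimizer` are illustrative names.

```python
import inspect


def supports_kwarg(fn, name: str) -> bool:
    """True if callable `fn` accepts a parameter called `name`."""
    try:
        return name in inspect.signature(fn).parameters
    except (TypeError, ValueError):  # some builtins have no retrievable signature
        return False


# Hypothetical stand-ins for an older and a newer optimizer constructor.
def old_adamw(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    return "old"


def new_adamw(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, fused=None):
    return "new"


def build_optimizer(ctor, params, device_type: str, **kwargs):
    # Only request the fused kernel if the constructor knows about it and we are on CUDA.
    use_fused = supports_kwarg(ctor, "fused") and device_type == "cuda"
    extra = {"fused": True} if use_fused else {}
    return ctor(params, **kwargs, **extra), use_fused


print(build_optimizer(new_adamw, [], "cuda", lr=6e-4))  # ('new', True)
print(build_optimizer(old_adamw, [], "cuda", lr=6e-4))  # ('old', False)
```

With PyTorch installed, the same check would be `supports_kwarg(torch.optim.AdamW, "fused")`, combined with a test that the device type is `"cuda"`.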
02:34:09.440 | One more quick note before we continue here. I just want to point out that the relationship between
02:34:14.000 | weight decay, learning rate, batch size, the Adam parameters, beta1, beta2, the epsilon, and so on,
02:34:19.920 | these are very complicated mathematical relationships in the optimization literature.
02:34:25.120 | And for the most part, in this video, I'm just trying to copy paste the settings that OpenAI
02:34:30.720 | used. But this is a complicated topic, quite deep. And yeah, in this video, I just want to copy the
02:34:36.880 | parameters because it's a whole different video to really talk about that in detail and give it
02:34:41.040 | proper justice instead of just high-level intuitions. Now, the next thing that I want
02:34:45.600 | to move on to: this paragraph here, by the way, we're going to come back to when we
02:34:51.360 | improve our data loader. For now, I want to swing back around to this table,
02:35:02.000 | where you will notice that for different models, we, of course, have different hyperparameters
02:35:07.760 | for the transformer that dictate the size of the transformer network. We also have a different
02:35:12.160 | learning rate. So we're seeing the pattern that the bigger networks are trained with slightly
02:35:15.600 | lower learning rates. And we also see this batch size, where in the small networks,
02:35:21.840 | they use a smaller batch size, and in the bigger networks, they use a bigger batch size.
02:35:26.240 | Now, the problem for us is we can't just use 0.5 million batch size,
02:35:30.960 | because if I just try to come in here and I try to set this B, where's my B?
02:35:36.800 | B equals...
02:35:41.840 | Where do I call the data loader? Okay, B equals 16. If I try to set...
02:35:48.720 | Well, we have to be careful. B is not 0.5 million, because that batch size is measured
02:35:54.960 | in number of tokens. Every single one of our rows is 1,024 tokens. So 0.5e6
02:36:01.840 | divided by 1,024 means we would need a batch size of about 488. So the problem is I can't
02:36:09.440 | come in here and set this to 488, because my GPU would explode. This would not fit for sure.
02:36:16.400 | But we still want to use this batch size, because again, as I mentioned, the batch size is correlated
02:36:23.600 | with all the other optimization hyperparameters and the learning rates and so on. So we want to
02:36:28.160 | have a faithful representation of all the hyperparameters, and therefore we need to
02:36:31.760 | use a batch size of 0.5 million, roughly. But the question is, how do we use 0.5 million if
02:36:38.720 | we only have a small GPU? Well, for that, we need to use what's called gradient accumulation.
02:36:43.120 | So we're going to turn to that next, and it allows us to simulate in a serial way
02:36:48.400 | any arbitrary batch size that we set. And so we can do a batch size of 0.5 million. We just have
02:36:54.400 | to run longer, and we have to process multiple sequences and basically add up all the gradients
02:37:00.720 | from them to simulate a batch size of 0.5 million. So let's turn to that next.
02:37:05.440 | Okay, so I started the implementation right here just by adding these lines of code.
02:37:08.800 | And basically what I did is first I set the total batch size that we desire. So this is
02:37:14.880 | exactly 0.5 million, and I used a nice number, a power of 2, because 2 to the 19 is 524288,
02:37:22.560 | so it's roughly 0.5 million. It's a nice number. Now, our micro-batch size, as we call it now,
02:37:28.000 | is 16. So this is going to be -- we still have B by T indices that go into the transformer and do
02:37:35.120 | forward-backward, but we're not going to do an update, right? We're going to do many forward-
02:37:38.960 | backwards, and those gradients are all going to plus-equals onto the parameter
02:37:43.920 | gradients. They're all going to add up. So we're going to do forward-backward grad-accum-steps
02:37:48.960 | number of times, and then we're going to do a single update once all that is accumulated.
02:37:53.200 | So in particular, our micro-batch size is just now controlling how many tokens,
02:37:58.960 | how many rows we're processing in a single go of a forward-backward.
02:38:01.840 | So here we are doing 16 times 1024. We're doing 16384
02:38:11.440 | tokens per forward-backward, and we are supposed to be doing
02:38:15.120 | 2 to the 19 -- whoops, what am I doing -- 2 to the 19 in total, so the grad-accum will be 32.
02:38:24.000 | So therefore, grad-accum here will work out to 32, and we have to do 32 forward-backward
02:38:32.560 | and then a single update. Now, we see that we have about 100 milliseconds for a single
02:38:38.320 | forward-backward, so doing 32 of them will be -- will make every step roughly three seconds,
02:38:43.680 | just napkin math. So that's grad-accum-steps, but now we actually have to implement that.
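The napkin math for the gradient accumulation steps, written out. The 100 ms per forward-backward is the rough figure quoted here, not something this snippet measures.

```python
total_batch_size = 2**19   # 524,288 tokens, roughly the 0.5M from the GPT-3 table
B, T = 16, 1024            # micro-batch: B rows of T tokens per forward-backward
tokens_per_fwdbwd = B * T  # 16,384
assert total_batch_size % tokens_per_fwdbwd == 0, "total batch must be divisible by B * T"
grad_accum_steps = total_batch_size // tokens_per_fwdbwd  # 32

# Rough step time, using the ~100 ms per forward-backward quoted here (an estimate, not measured).
ms_per_fwdbwd = 100
est_seconds_per_step = grad_accum_steps * ms_per_fwdbwd / 1000  # ~3.2 s
print(tokens_per_fwdbwd, grad_accum_steps, est_seconds_per_step)
```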
02:38:50.080 | So we're going to swing over to our training loop, because now this part here
02:38:55.920 | and this part here, the forward and the backward, we have to now repeat this
02:39:01.200 | 32 times before we do everything else that follows. So let's see how we can implement that.
02:39:07.520 | So let's come over here, and actually, we do have to load a new batch every single time,
02:39:11.520 | so let me move that over here, and now this is where we have the inner loop.
02:39:15.280 | So for micro_step in range(grad_accum_steps), we do this. And remember that loss.backward()
02:39:24.720 | always deposits gradients: inside loss.backward(), there's always a
02:39:28.320 | plus-equals on the gradients. So in every single loss.backward(), gradients will add up on the
02:39:33.840 | gradient tensors. So we call loss.backward(), and then we get all the gradients over there,
02:39:41.120 | and then we normalize, and everything else should just follow. So we're very close,
02:39:47.840 | but actually, there's a subtle and deep issue here, and this is actually incorrect.
02:39:53.280 | So I invite you to think about why this is not yet sufficient, and let me fix it then.
02:39:59.520 | Okay, so I brought back the Jupyter Notebook, so we can think about this carefully
02:40:03.280 | in a simple toy setting and see what's happening. So let's create a very simple
02:40:07.840 | neural net that takes a vector of 16 numbers and returns a single number.
02:40:12.320 | And then here, I'm creating some random examples x and some targets y, and then we are using the
02:40:21.520 | mean-squared loss here to calculate the loss. So basically, what this is, is four individual
02:40:28.640 | examples, and we're just doing simple regression with the mean-squared loss over those four
02:40:33.600 | examples. Now, when we calculate the loss and call loss.backward() and look at the gradient,
02:40:39.360 | this is the gradient that we achieve. Now, the loss objective here: notice that in MSE loss,
02:40:45.600 | the default reduction is mean. So we're calculating the average loss
02:40:52.800 | here over the four examples. So this is the exact loss objective, and this is the average,
02:41:01.920 | the 1/4, because there are four independent examples here. And then we have the four
02:41:07.280 | examples and their mean-squared error -- the squared error, and then this makes it the
02:41:11.760 | mean-squared error. So therefore, we calculate the squared error and then we normalize it to
02:41:18.240 | make it the mean over the examples, and there's four examples here. So now, when we come to the
02:41:22.880 | gradient accumulation version of it, this here is the gradient accumulation version of it,
02:41:31.200 | where we have grad_accum_steps of four, and I reset the gradient, with grad_accum_steps of
02:41:36.320 | four, and now I'm evaluating all the examples individually instead and calling loss.backward()
02:41:41.280 | on them many times, and then we're looking at the gradient that we achieve from that.
02:41:44.640 | So basically, now we forward our function, calculate the exact same loss, do a backward,
02:41:50.640 | and we do that four times, and when we look at the gradient, you'll notice that the gradients
02:41:56.000 | don't match. So here we did a single batch of four, and here we did four gradient accumulation
02:42:03.600 | steps of batch size one, and the gradients are not the same. And basically, the reason that they're
02:42:09.600 | not the same is exactly because the normalizer of this mean squared error gets lost. This one quarter in this loss
02:42:15.600 | gets lost, because what happens here is the loss objective for every one of the loops
02:42:21.760 | is just a mean squared error, which in this case, because there's only a single example,
02:42:26.800 | is just this term here. So that was the loss in the zeroth iteration, the same in the first,
02:42:31.520 | second, third, and so on. And then when you do loss.backward(), we're accumulating gradients,
02:42:37.200 | and what happens is that accumulation in the gradient is basically equivalent
02:42:42.160 | to doing a sum in the loss. So our loss actually here is this without the factor of one quarter
02:42:51.840 | outside of it. So we're missing the normalizer, and therefore our gradients are off. And so the
02:42:57.360 | way to fix this, or one of them, is basically we can actually come here and we can say loss equals
02:43:02.480 | loss divided by four. And what happens now is that we're scaling our loss, we're introducing a one
02:43:10.720 | quarter in front of all of these places. So all the individual losses are now scaled by one quarter,
02:43:18.160 | and then when we backward, all of these accumulate with a sum, but now there's a one quarter inside
02:43:24.880 | every one of these components, and now our losses will be equivalent. So when I run this,
02:43:31.680 | you see that the gradients are now identical. So long story short, with this simple example,
02:43:37.600 | when you step through it, you can see that basically the reason that this is not correct
02:43:42.160 | is because in the same way as here in the MSE loss, the loss that we're calculating here
02:43:48.720 | in the model is using a reduction of mean as well. So where is the loss? F.cross_entropy.
02:43:59.200 | And by default, the reduction here in cross entropy is also, I don't know why they don't
02:44:04.000 | show it, but it's the mean loss over all the B by T elements, right? So there's a reduction
02:44:12.880 | by mean in there, and if we're just doing this gradient accumulation here, we're missing that.
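The toy experiment can be reproduced without PyTorch. This sketch swaps the notebook's small network for a linear model with a hand-written mean-squared-error gradient, an assumption made so the example is self-contained; `mse_grad` and `accumulate` are hypothetical helpers, and `accumulate` mimics calling `backward()` once per example, with gradients summing across micro-steps. With no normalizer the accumulated gradient comes out four times too large; dividing each micro-step's loss by 4 recovers the batch gradient.

```python
import random


def mse_grad(w, b, xs, ys):
    """Gradient of mean((w . x + b - y) ** 2) over the given examples (reduction='mean')."""
    n = len(xs)
    gw = [0.0] * len(w)
    gb = 0.0
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) + b - y
        for j, xj in enumerate(x):
            gw[j] += 2.0 * err * xj / n
        gb += 2.0 * err / n
    return gw, gb


def accumulate(w, b, xs, ys, scale):
    """Mimic one backward() per example: each micro-step's gradient is added (+=) into
    the running total. `scale` multiplies each micro-step's loss before its backward."""
    gw = [0.0] * len(w)
    gb = 0.0
    for x, y in zip(xs, ys):
        g_w, g_b = mse_grad(w, b, [x], [y])  # batch of one: the "mean" is just that example
        gw = [acc + scale * g for acc, g in zip(gw, g_w)]
        gb += scale * g_b
    return gw, gb


random.seed(0)
w = [random.gauss(0, 1) for _ in range(16)]  # a random 16 -> 1 linear model
b = 0.0
xs = [[random.gauss(0, 1) for _ in range(16)] for _ in range(4)]  # four examples
ys = [random.gauss(0, 1) for _ in range(4)]

batch_gw, batch_gb = mse_grad(w, b, xs, ys)          # one batch of four
naive_gw, naive_gb = accumulate(w, b, xs, ys, 1.0)   # four micro-steps, no normalizer
fixed_gw, fixed_gb = accumulate(w, b, xs, ys, 0.25)  # loss = loss / 4 before each backward

print("max |naive - batch|:", max(abs(a - c) for a, c in zip(naive_gw, batch_gw)))
print("max |fixed - batch|:", max(abs(a - c) for a, c in zip(fixed_gw, batch_gw)))
```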
02:44:17.040 | And so the way to fix this is to simply compensate for the number of gradient accumulation steps,
02:44:22.480 | and we can in the same way divide this loss. So in particular here, we divide by the number of steps
02:44:26.720 | we're doing: loss equals loss divided by grad_accum_steps. So even Copilot gets the
02:44:34.800 | modification. But in the same way exactly, we are scaling down the loss so that when we do
02:44:40.160 | loss.backward(), which basically corresponds to a sum in the objective, we are summing up
02:44:45.040 | the already normalized loss. And therefore, when we sum up the losses divided by grad_accum_steps,
02:44:52.160 | we are recovering the missing normalizer. And so now this will be
02:44:58.960 | equivalent to the original sort of optimization, because the gradient will come out the same.
02:45:03.680 | Okay, so I had to do a few more touch-ups, and I launched the optimization here. So in particular,
02:45:09.840 | one thing we want to do, because we want to print things nicely, is, well, first of all, we need to
02:45:14.880 | create like an accumulator over the loss. We can't just print the loss, because we'd be printing only
02:45:18.960 | the final loss at the final microstep. So instead, we have loss-accum, which I initialized at zero,
02:45:24.800 | and then I accumulate the loss into it. And I'm using detach so that I'm detaching the tensor
02:45:32.560 | from the graph, and I'm just trying to keep track of the values. So I'm making these leaf nodes when
02:45:38.720 | I add them. So that's loss-accum, and then we're printing that here instead of loss.
02:45:43.840 | And then in addition to that, I had to account for the grad-accum steps inside the tokens processed,
02:45:49.120 | because now the tokens processed per step is b times t times gradient accumulation.
02:45:53.920 | So long story short, here we have the optimization. It looks reasonable, right? We're starting at a
02:46:00.800 | good spot. We calculated the grad-accum steps to be 32, and we're getting about three seconds here,
02:46:07.280 | right? And so this looks pretty good. Now, if you'd like to verify that your optimization and
02:46:17.440 | the implementation here is correct and you're following along on the side, well, now because we have
02:46:21.120 | the total batch size and the gradient accumulation steps, our setting of B is purely a performance
02:46:26.640 | optimization kind of setting. So if you have a big GPU, you can actually increase this to 32,
02:46:32.080 | and you'll probably go a bit faster. If you have a very small GPU, you can try 8 or 4.
02:46:36.800 | But in any case, you should be getting the exact same optimization and the same answers
02:46:40.720 | up to a floating point error, because the gradient accumulation kicks in and can handle everything
02:46:47.360 | serially as necessary. So that's it for gradient accumulation, I think.
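A pure-Python sketch of the shape of the accumulation loop: each micro-step's loss is divided by `grad_accum_steps` before its gradient is added in, a detached running `loss_accum` is kept for logging, and a single update happens at the end. The scalar model `y_hat = w * x`, the data, and the helper names are assumptions for illustration. Whatever micro-batch size is chosen, the update and the logged loss agree up to floating point error, which is why B is only a performance knob.

```python
from typing import List, Tuple


def micro_step(w: float, xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Mean-reduced squared-error loss and its gradient for y_hat = w * x on one micro-batch."""
    m = len(xs)
    errs = [w * x - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errs) / m
    grad = sum(2.0 * e * x for e, x in zip(errs, xs)) / m
    return loss, grad


def train_step(w: float, xs: List[float], ys: List[float], micro_batch: int, lr: float):
    """One optimizer step over all of xs, processed serially in micro-batches."""
    assert len(xs) % micro_batch == 0
    grad_accum_steps = len(xs) // micro_batch
    grad = 0.0        # plays the role of p.grad, which backward() adds into
    loss_accum = 0.0  # running total for logging only, like loss_accum += loss.detach()
    for k in range(grad_accum_steps):
        lo, hi = k * micro_batch, (k + 1) * micro_batch
        loss, g = micro_step(w, xs[lo:hi], ys[lo:hi])
        loss = loss / grad_accum_steps  # put back the normalizer the mean reduction can't see
        grad += g / grad_accum_steps    # gradient of the scaled loss
        loss_accum += loss
    return w - lr * grad, loss_accum    # a single update after all micro-steps


xs = [0.1 * i for i in range(32)]
ys = [3.0 * x + 0.5 for x in xs]
results = {mb: train_step(1.0, xs, ys, mb, lr=0.01) for mb in (1, 2, 4, 8, 16, 32)}
for mb, (new_w, loss) in results.items():
    print(f"micro_batch={mb:2d}  w={new_w:.12f}  loss={loss:.12f}")
```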
02:46:52.880 | Okay, so now is the time to bring out the heavy weapons. You've noticed that so far,
02:46:57.040 | we've only been using a single GPU for training. But actually, I am paying for eight GPUs here,
02:47:02.560 | and so we should be putting all of them to work. And in particular, they're all going to collaborate
02:47:07.360 | and optimize over tokens at the same time and communicate so that they're all kind of
02:47:15.520 | collaborating on the optimization. For this, we are going to be using the distributed data parallel
02:47:20.160 | from PyTorch. There's also an older DataParallel, which I recommend you not use,
02:47:24.320 | since it's basically legacy. Distributed data parallel works in a very simple way.
02:47:29.680 | We have eight GPUs, so we're going to launch eight processes, and each process is going to be
02:47:36.240 | assigned a GPU. And for each process, the training loop and everything we've worked on so far is
02:47:41.520 | going to look pretty much the same. Each GPU, as far as it's concerned, is just working on exactly
02:47:46.480 | what we've built so far. But now secretly, there's eight of them, and they're all going to be
02:47:51.280 | processing slightly different parts of the data. And we're going to add one more part, where once
02:47:57.440 | they all calculate their gradients, there's one more part where we do an average of those gradients.
02:48:02.720 | And so that's how they're going to be collaborating on the computational workload here.
02:48:08.720 | So to use all eight of them, we're not going to be launching our script anymore with just
02:48:14.000 | python train_gpt2.py. We're going to be running it with a special command called torchrun
02:48:20.720 | in PyTorch. We'll see that in a bit. And torchrun, when it runs our Python script,
02:48:26.720 | will actually make sure to run eight of them in parallel. And it creates these environment
02:48:33.120 | variables where each of these processes can look up basically which one of the processes it is.
02:48:40.400 | So for example, torchrun will set the RANK, LOCAL_RANK, and WORLD_SIZE environment variables.
02:48:46.320 | And so checking for these variables is how we detect whether DDP is running. So if we're using torchrun,
02:48:52.960 | if DDP is running, then we have to make sure that CUDA is available,
02:48:57.920 | because I don't know that you can run this on CPU anymore, or that that makes sense to do.
02:49:02.320 | This is some setup code here. The important part is that there's a world size, which for us will
02:49:10.800 | be eight. That's the total number of processes running. There's a rank, which is each process
02:49:16.960 | will basically run the exact same code at the exact same time, roughly. But the only difference
02:49:23.920 | between these processes is that they all have a different DDP rank. So the GPU 0 will have DDP
02:49:30.960 | rank of 0, GPU 1 will have rank of 1, etc. So otherwise, they're all running the exact same
02:49:37.440 | script. It's just that DDP rank will be a slightly different integer. And that is the way for us to
02:49:42.800 | coordinate that they don't, for example, run on the same data. We want them to run on different
02:49:47.360 | parts of the data, and so on. Now, local rank is something that is only used in a multi-node
02:49:53.600 | setting. We only have a single node with eight GPUs. And so local rank is the rank of the GPU
02:49:59.920 | on a single node. So from 0 to 7, as an example. But for us, we're mostly going to be running on
02:50:06.640 | a single box. So the things we care about are rank and world size. This is 8, and this will be
02:50:12.400 | whatever it is, depending on the GPU, that this particular instantiation of the script runs on.
02:50:18.400 | Now, here, we make sure that according to the local rank, we are setting the device
02:50:27.200 | to be cuda colon local rank, where the number after the colon indicates which GPU to use if there is more than one GPU. So
02:50:34.640 | depending on the local rank of this process, it's going to use just the appropriate GPU. So there's
02:50:40.640 | no collisions on which GPU is being used by which process. And finally, there's a boolean variable
02:50:45.840 | that I like to create, which is the DDP rank equals equals zero. So the master process is
02:50:51.840 | arbitrarily process number zero, and it does a lot of the printing, logging, checkpointing, etc.
02:50:57.040 | And the other processes are thought of mostly as compute processes that are assisting.
02:51:01.280 | And so master process zero will have some additional work to do. All the other processes
02:51:05.680 | will almost just be doing forward and backwards. And if we're not using DDP, and none of these
02:51:10.640 | variables are set, we revert back to single GPU training. So that means that we only have rank
02:51:15.440 | zero, the world size is just one. And we are the master process. And we try to auto detect the
02:51:22.400 | device. And everything else is as normal. So so far, all we've done is we've initialized DDP.
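The launch-detection logic can be sketched as a pure function of the environment. `ddp_setup` is a hypothetical helper; it reads the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` variables that torchrun sets. The calls a real script would make (`torch.distributed.init_process_group`, `torch.cuda.set_device`) are only noted in comments so the sketch runs without GPUs, and device auto-detection is simplified to CUDA or CPU.

```python
from typing import Mapping


def ddp_setup(env: Mapping[str, str], cuda_available: bool) -> dict:
    """Decide rank, device and master-process status from torchrun's environment variables.
    Falls back to a single process when RANK is not set."""
    ddp = int(env.get("RANK", -1)) != -1  # torchrun sets RANK; a plain `python` launch does not
    if ddp:
        if not cuda_available:
            raise RuntimeError("this DDP setup assumes CUDA")
        rank = int(env["RANK"])
        local_rank = int(env["LOCAL_RANK"])
        world_size = int(env["WORLD_SIZE"])
        device = f"cuda:{local_rank}"  # one GPU per process on this node
        # A real script would also call torch.distributed.init_process_group(backend="nccl")
        # and torch.cuda.set_device(device) here.
    else:
        rank, local_rank, world_size = 0, 0, 1
        device = "cuda" if cuda_available else "cpu"
    return {
        "ddp": ddp,
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "device": device,
        "master_process": rank == 0,  # rank 0 does the printing, logging, checkpointing
    }


print(ddp_setup({}, cuda_available=False))
print(ddp_setup({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "8"}, cuda_available=True))
```

In a training script this would be called as `ddp_setup(os.environ, torch.cuda.is_available())`.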
02:51:28.480 | And in the case where we're running with Torch run, which we'll see in a bit,
02:51:34.000 | there's going to be eight copies running in parallel, each one of them will have a different
02:51:37.520 | rank. And now we have to make sure that everything happens correctly afterwards.
02:51:42.880 | So the tricky thing with running multiple processes is you always have to imagine that there's going
02:51:48.400 | to be eight processes running in parallel. So as you read the code, now you have to imagine there's
02:51:54.240 | eight, you know, eight Python interpreters running down these lines of code. And the only difference
02:52:00.160 | between them is that they have a different DDP rank. So they all come here, they all pick the
02:52:05.120 | exact same seed, they all make all of these calculations, completely unaware of the other
02:52:10.160 | copies running, roughly speaking, right. So they all make the exact same calculations. And now we
02:52:15.680 | have to adjust these calculations to take into account that there's actually like a certain
02:52:20.960 | world size and certain ranks. So in particular, these micro batches and sequence lengths, these are
02:52:27.280 | all just per GPU, right. So now there's going to be num processes of them running in parallel.
02:52:33.600 | So we have to adjust this, right, because grad_accum_steps now is going to be total batch
02:52:39.040 | size divided by B times T times DDP world size, because each process will do B times T, and
02:52:49.600 | there's this many of them. And so in addition to that, we want to make sure that this fits nicely
02:52:56.800 | into total batch size, which for us it will, because 16 times 1024 times eight GPUs is 131K.
02:53:04.480 | And with 524288, this means that our grad_accum_steps will be four with the current settings, right. So
02:53:13.680 | there are going to be 16 times 1024 tokens on each GPU, and then there's eight GPUs. So we're
02:53:18.960 | going to be doing 131,000 tokens in a single forward backward on the eight GPUs.
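The adjusted accumulation arithmetic, with a guard that the total batch size divides evenly. `grad_accum_steps_for` is a hypothetical helper; B = 16, T = 1,024 and 2**19 tokens are the settings from this section.

```python
def grad_accum_steps_for(total_batch_size: int, B: int, T: int, world_size: int) -> int:
    """Micro-steps each process runs before one update, when world_size processes
    each do a B x T forward-backward in parallel."""
    tokens_per_fwdbwd = B * T * world_size
    if total_batch_size % tokens_per_fwdbwd != 0:
        raise ValueError("total_batch_size must be divisible by B * T * world_size")
    return total_batch_size // tokens_per_fwdbwd


total_batch_size = 2**19
print(16 * 1024 * 8)                                        # 131072 tokens per forward-backward
print(grad_accum_steps_for(total_batch_size, 16, 1024, 8))  # 4 on eight GPUs
print(grad_accum_steps_for(total_batch_size, 16, 1024, 1))  # 32 on a single GPU
```

Every process computes the same value here, so in the training script only the master process needs to print it.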
02:53:26.240 | So we want to make sure that this fits nicely so that we can derive a nice gradient accumulation
02:53:31.520 | steps. And yeah, let's just adjust the comments here times DDP world size. Okay. So each GPU
02:53:43.600 | calculates this. Now this is where we start to run into issues, right? Each process is
02:53:49.280 | going to come by a print, and they're all going to print. So we're going to have eight copies
02:53:54.240 | of these prints. So one way to deal with this is exactly this master process variable that we have.
02:53:59.360 | So if master process, then guard this. And that's just so that we just print this a single time,
02:54:05.760 | because otherwise, all the processes would have computed the exact same variables. And there's
02:54:09.600 | no need to print this eight times. Before getting into the data loader, and we're going to have to
02:54:15.520 | refactor it, obviously, maybe at this point we should do some prints and just take it out for a
02:54:22.560 | spin and exit at this point. So import sys and sys.exit, and print "I am GPU" with the DDP rank,
02:54:40.960 | and print "Bye". So now let's try to run this and just see how this works. So let's take it for a
02:54:51.600 | spin just so we see what it looks like. So normally we used to launch python train_gpt2.py like this.
02:54:57.440 | Now we're going to run with torchrun, and this is what it looks like: torchrun --standalone,
02:55:02.400 | then --nproc_per_node, the number of processes, which is eight for us because we have eight GPUs,
02:55:05.840 | and then train_gpt2.py. So this is what the command would look like. And torchrun, again,
02:55:12.480 | will run eight of these. So let's just see what happens. So first, it gets a little busy. So
02:55:19.600 | there's a lot going on here. So first of all, there's some warnings from distributed, and I
02:55:24.000 | don't actually know that these mean anything. I think this is just like, the code is setting up
02:55:28.640 | and the processes are coming online. And we're seeing some preliminary failure to collect while
02:55:33.440 | the processes come up. I'm not 100% sure about that. But we start to then get into actual prints.
02:55:39.920 | So all the processes went down. And then the first print actually comes from process five,
02:55:48.720 | just by chance. And then it printed. So process five basically got here first,
02:55:53.360 | it said "I am GPU 5" and "Bye", and then these prints come from the master process.
02:56:00.800 | So process five just finished first, for whatever reason, it just depends on how
02:56:05.760 | the operating system scheduled the processes to run. Then GPU zero ended, then GPU three and two.
02:56:11.600 | And then probably process five or something like that has exited.
02:56:18.080 | And DDP really doesn't like that, because we didn't properly dispose of the multi-GPU setting.
02:56:26.240 | And so "process group has not been destroyed before we destruct". So it really doesn't like that. And
02:56:32.640 | in an actual application, we would want to call destroy_process_group, so that we clean up DDP
02:56:38.320 | properly. And so it doesn't like that too much. And then the rest of the GPUs finish. And that's
02:56:44.320 | it. So basically, we can't guarantee the order in which these processes run; it's totally arbitrary,
02:56:48.640 | but they are running in parallel, and we don't want all of them to be printing. And next up, let's erase this.
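A minimal sketch of the per-process state behind these prints, assuming a launch such as "torchrun --standalone --nproc_per_node=8 train_gpt2.py". torchrun sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables for every process it starts, and the master process can be derived from them. The helper name ddp_setup_from_env is hypothetical. The torch calls are described only in comments, so the snippet runs without a GPU.

```python
import os


def ddp_setup_from_env(environ=os.environ):
    """Derive DDP settings from the variables torchrun sets (hypothetical helper).

    If RANK is absent, assume a plain `python train_gpt2.py` launch on one device.
    """
    ddp = int(environ.get("RANK", -1)) != -1
    if ddp:
        rank = int(environ["RANK"])              # global rank across all processes
        local_rank = int(environ["LOCAL_RANK"])  # GPU index on this machine, i.e. cuda:{local_rank}
        world_size = int(environ["WORLD_SIZE"])  # total number of processes
    else:
        rank, local_rank, world_size = 0, 0, 1
    return {
        "ddp": ddp,
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        # only rank 0 prints and logs, so each message appears once instead of eight times
        "master_process": rank == 0,
    }


# In a real training script (requires torch, not executed here), roughly:
#   torch.distributed.init_process_group(backend="nccl") at startup when ddp is True,
#   torch.cuda.set_device(f"cuda:{local_rank}"),
#   and torch.distributed.destroy_process_group() before exit, which avoids the
#   "process group has not been destroyed" complaint mentioned above.
```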
02:56:56.080 | Next up, when we create DataLoaderLite, we need to make it aware
02:57:02.160 | of this multi-process setting, because we don't want all the processes to be loading the exact
02:57:08.800 | same data; we want every process to get its own chunk of data, so that they're all working on
02:57:13.360 | different parts of the data set, of course. So let's adjust that. So one particularly simple
02:57:19.040 | and naive way to do this is to pass in the rank and the number of processes
02:57:23.840 | to the data loader. And then we come up here, and we see that we now take the rank and the number of processes and we
02:57:30.160 | save them. Now, the current position will not be zero, because what we want is to stride
02:57:37.200 | out all the processes. So one way to do this is we basically take self.B times self.T, and then
02:57:44.080 | multiply it by the process rank. So process rank zero will start at zero, process rank one
02:57:51.520 | now starts at B times T, process rank two starts at two times B times T, etc. So that is the
02:57:58.080 | initialization. They all still read batches identically. But now when we advance, we don't
02:58:05.200 | advance by B times T; we advance by B times T times the number of processes. Right? So basically,
02:58:12.400 | the total number of tokens that we're consuming is B times T times num_processes, and each rank
02:58:20.240 | gets its own slice of that. So the position has to advance by the entire chunk. And then here, if the position plus B
02:58:28.240 | times T times self.num_processes plus one would exceed the number of tokens, then we're going to
02:58:35.120 | loop. And when we loop, we want to, of course, loop in the exact same way, so we reset
02:58:40.800 | back to the rank's starting offset. So this is the simplest change that I can find for a very simple distributed
02:58:47.920 | DataLoaderLite. And you can notice that if the process rank is zero and num_processes is one, then the
02:58:54.320 | whole thing will be identical to what we had before. But now we can actually have multiple
02:58:58.080 | processes running, and this should work fine. So that's the data loader. Okay, so next up,
02:59:06.320 | once they've all initialized the data loader, they come here and they all create a GPT model.
02:59:11.280 | So we create eight GPT models on eight processes. But because the seeds are fixed here,
02:59:17.360 | they all create the same identical model. They all move it to the device of their rank,
02:59:22.560 | and they all compile the model. And because the models are identical, there are eight identical
02:59:27.200 | compilations happening in parallel, but that's okay. Now, none of this changes, because that is
02:59:33.040 | on a per-step basis. And we're currently working within the step, because
02:59:39.520 | all the changes we're making are kind of within-step changes. Now, the important thing
02:59:44.400 | here is when we construct the model, we actually have a bit of work to do here. GetLogits is
02:59:49.680 | deprecated. So create model. We need to actually wrap the model into the distributed data parallel
02:59:57.600 | container. So this is how we wrap the model into the DDP container. And these are the docs for
03:00:05.440 | DDP. And they're quite extensive. And there's a lot of caveats and a lot of things to be careful
03:00:10.000 | with because everything complexifies times 10 when multiple processes are involved. But roughly
03:00:15.920 | speaking, this device_ids, I believe, has to be passed in. Now, unfortunately, the docs for what
03:00:20.320 | device_ids is are extremely unclear. So when you actually come here, this comment for what
03:00:27.360 | device_ids is, is roughly nonsensical. But I'm pretty sure it's supposed to be the DDP local
03:00:34.400 | rank. So not the DDP rank, the local rank. So this is what you pass in here. This wraps the model.
03:00:41.840 | And in particular, what DDP does for you is in a forward pass, it actually behaves identically. So
03:00:46.880 | my understanding of it is nothing should be changed in the forward pass.
03:00:51.120 | But in the backward pass, as you are doing the backward pass, in the simplest setting,
03:00:56.720 | once the backward pass is over on each independent GPU, each independent GPU has the gradient for
03:01:03.200 | all the parameters. And what DDP does for you is once the backward pass is over, it will call
03:01:09.040 | what's called all-reduce. And it basically averages the gradients across all the ranks.
03:01:16.720 | And then it will deposit that average on every single rank. So every single rank will end up
03:01:23.040 | with the average on it. And so basically, that's the communication, it just synchronizes and
03:01:27.920 | averages the gradients. And that's what DDP offers you. Now, DDP actually
03:01:32.320 | is a little bit more involved than that, because as you are doing the backward pass through the
03:01:38.240 | layers of the transformer, it actually can dispatch communications for the gradient while
03:01:43.360 | the backward pass is still happening. So there's overlap of the communication of the gradients and
03:01:48.160 | the synchronization of them and the backward pass. And this is just more efficient to do it that way.
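To make the "average, then deposit on every rank" semantics concrete, here is a pure-Python simulation of an averaging all-reduce over per-rank gradient vectors. This is only an illustration. It is not how NCCL or DDP implement it: they use collective algorithms on the GPUs and, as just described, overlap communication with the backward pass. The function name is hypothetical.

```python
def all_reduce_avg(per_rank_grads):
    """Simulate an averaging all-reduce (hypothetical helper).

    per_rank_grads[r] is the flattened gradient computed on rank r. Afterwards every
    rank holds the same element-wise mean, as each rank's .grad does after DDP's sync.
    """
    world_size = len(per_rank_grads)
    mean = [sum(values) / world_size for values in zip(*per_rank_grads)]
    # every rank receives its own copy of the identical averaged gradient
    return [list(mean) for _ in range(world_size)]


grads = [[1.0, 2.0], [3.0, 6.0]]  # made-up gradients on 2 ranks
print(all_reduce_avg(grads))      # -> [[2.0, 4.0], [2.0, 4.0]]
```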
03:01:55.680 | So that's what DDP does for you. Forward is unchanged, and backward is mostly unchanged.
03:02:02.000 | And we're tacking on this average, as we'll see in a bit. Okay, so now let's go to the optimization.
03:02:09.120 | Nothing here changes. Let's go to the optimization here, the inner loop, and think through the
03:02:13.680 | synchronization of these gradients in the DDP. So basically, by default, what happens, as I
03:02:18.800 | mentioned, is when you do loss.backward() here, it will do the backward pass, and then it will
03:02:23.760 | synchronize the gradients. The problem here is that because of the gradient accumulation steps loop
03:02:30.320 | here, we don't actually want to do the synchronization after every single loss.backward(),
03:02:36.560 | because we are just depositing gradients, and we're doing that serially, and we just want them
03:02:40.800 | adding up, and we don't want to synchronize every single time. That would be extremely wasteful.
03:02:44.960 | So basically, we want to add them up, and it's only on the very last step,
03:02:50.880 | when micro_step becomes grad_accum_steps minus one, that we want to actually do
03:02:57.840 | the all-reduce to average up the gradients. So to do that, we come here, and the official
03:03:05.440 | sanctioned way, by the way, is to use this no_sync context manager. So PyTorch says this is a context
03:03:12.560 | manager to disable gradient synchronization across DDP processes. So within this context,
03:03:17.280 | gradients will be accumulated, and basically, when you use no_sync, there will be no communication.
03:03:23.760 | So they are telling us to do: with ddp.no_sync(), do the gradient accumulation, accumulate grads,
03:03:30.080 | and then they are asking us to call ddp again with another input and do .backward(). And I just really
03:03:35.600 | don't love this. I just really don't like it, the fact that you have to copy-paste your code here
03:03:40.480 | and use a context manager, and this is just super ugly. So when I went to the source code here,
03:03:44.720 | you can see that when you enter, you simply toggle this variable,
03:03:50.880 | this require_backward_grad_sync, and this is being toggled around and changed. And this is the
03:03:58.560 | variable that basically, if you step through it, is being toggled to determine if the gradient is
03:04:05.200 | going to be synchronized. So I actually just kind of like to use that directly. So instead, what I
03:04:10.960 | like to do is the following. Right here, before the loss.backward(), if we are using DDP, then
03:04:19.680 | we only want to synchronize, we only want this variable to be true when it is the final iteration.
03:04:29.600 | In all the other iterations inside the micro-steps, we want it to be false. So I just toggle it like
03:04:35.440 | this. So require_backward_grad_sync should only turn on when the micro-step is the last step.
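The toggle amounts to setting model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1) on the DDP-wrapped model before each backward call. As noted here, this is an internal attribute rather than a documented API. The sketch below checks the resulting sync schedule with a hypothetical stand-in class instead of a real DDP model, so it runs without torch.

```python
class FakeDDPModel:
    """Hypothetical stand-in for a DDP-wrapped model; it only records whether each
    backward call would have triggered a gradient all-reduce."""

    def __init__(self):
        self.require_backward_grad_sync = True  # default: sync on every backward
        self.sync_log = []

    def backward(self):
        # a real DDP model would all-reduce gradients here only when the flag is True
        self.sync_log.append(self.require_backward_grad_sync)


def run_micro_steps(model, grad_accum_steps, ddp=True):
    for micro_step in range(grad_accum_steps):
        if ddp:
            # only the last micro-step synchronizes; the earlier ones just accumulate locally
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        model.backward()  # stands in for loss.backward()
    return model.sync_log


print(run_micro_steps(FakeDDPModel(), 4))  # -> [False, False, False, True]
```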
03:04:40.720 | And so I'm toggling this variable directly, and I hope that that impacts loss.backward(),
03:04:48.080 | and this is a naughty thing to do, because they could change DDP at some point and this variable
03:04:52.640 | could go away. But for now, I believe this works, and it allows me to avoid the use of context
03:04:58.160 | managers and code duplication. I'm just toggling the variable, and then loss.backward() will not
03:05:02.560 | synchronize most of the steps, and it will synchronize the very last step. And so once
03:05:06.480 | this is over, and we come out, every single rank will suddenly magically have the average
03:05:17.120 | of all the gradients that were stored on all the ranks. So now we have to think through whether
03:05:22.800 | that is what we want, and also if this suffices, and how it works with the loss,
03:05:30.160 | and what is loss_accum. So let's think through that now. And the problem I'm getting at is that
03:05:35.440 | we've averaged the gradients, which is great, but the loss_accum has not been impacted yet,
03:05:41.200 | and this is outside of the DDP container, so that is not being averaged. And so here,
03:05:47.360 | when we are printing loss_accum, well, presumably we're only going to be printing on the master
03:05:51.680 | process, rank 0, and it's just going to be printing the losses that it saw on its process.
03:05:56.480 | But instead, we want it to print the loss over all the processes and the average of that loss,
03:06:02.240 | because we did average of gradients, so we want the average of loss as well.
03:06:06.240 | So simply here, after this, this is the code that I've used in the past,
03:06:11.200 | and instead of loss_eff, we want loss_accum. So if DDP, again, and dist here is torch.distributed.
03:06:22.320 | I import it. Where do I import it? Oh, gosh. So this file is starting to get out of control, huh?
03:06:32.080 | So import torch.distributed as dist, then dist.all_reduce, and we're doing the average on loss_accum.
03:06:41.920 | And so this loss_accum tensor exists on all the ranks. When we call all_reduce of average,
03:06:46.960 | it creates the average of those numbers, and it deposits that average on all the ranks.
03:06:52.000 | So all the ranks after this call will now contain loss_accum averaged up.
03:06:58.880 | And so when we print here on the master process,
03:07:01.120 | the loss_accum is identical in all the other ranks as well.
03:07:03.280 | So here, if master_process, oops, we want to print like this.
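Why the averaged loss_accum is the number we want: each rank's loss_accum is the mean of its micro-step losses, because each loss is divided by grad_accum_steps before accumulating. With equal-sized micro-batches, averaging those per-rank values gives the mean over every micro-batch on every rank. In PyTorch the call would be dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG). As far as I know, ReduceOp.AVG is only supported by the NCCL backend. The pure-Python simulation below uses hypothetical helper names and made-up loss values.

```python
def rank_loss_accum(micro_losses):
    """What one rank accumulates: each micro-step loss is divided by grad_accum_steps first."""
    grad_accum_steps = len(micro_losses)
    return sum(loss / grad_accum_steps for loss in micro_losses)


def all_reduce_avg_loss(per_rank_micro_losses):
    """Simulates the averaging all-reduce on loss_accum: the mean of the per-rank values."""
    per_rank = [rank_loss_accum(m) for m in per_rank_micro_losses]
    return sum(per_rank) / len(per_rank)


losses = [[1.0, 3.0], [5.0, 7.0]]   # 2 ranks x 2 micro-steps
print(rank_loss_accum(losses[0]))   # -> 2.0, what rank 0 alone would print
print(all_reduce_avg_loss(losses))  # -> 4.0, what every rank holds after the all-reduce
```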
03:07:09.360 | Okay, and finally, we have to be careful, because we're now processing even more tokens.
03:07:14.880 | So times DDP_world_size, that's the number of tokens that we've processed up above.
03:07:20.640 | And everything else should be fine.
03:07:27.760 | The only other thing to be careful with is, as I mentioned, you want to destroy the process group
03:07:32.160 | so that we are nice to NCCL and to DDP,
03:07:36.160 | and it's not going to complain to us when we exit here.
03:07:39.600 | So that should be it. Let's try to take it for a spin.
03:07:43.840 | Okay, so I launched the script, and it should be printing here imminently.
03:07:47.840 | We're now training with eight GPUs at the same time.
03:07:50.320 | So the gradient accumulation steps is not 32; it is now divided by 8, so it's just 4.
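The accumulation arithmetic as a small sketch. It assumes B=16 and T=1024, which is the micro-batch that yields 32 accumulation steps on a single GPU with a 2^19-token total batch. The same bookkeeping explains the tokens-per-second print, which now has to count every rank's tokens. Both helper names are hypothetical.

```python
def grad_accum_steps_for(total_batch_size, B, T, world_size):
    """Micro-steps each rank runs so that all ranks together see total_batch_size tokens."""
    tokens_per_micro_step = B * T * world_size  # tokens consumed by all ranks in one micro-step
    assert total_batch_size % tokens_per_micro_step == 0, "pick nice, divisible numbers"
    return total_batch_size // tokens_per_micro_step


def tokens_per_second(B, T, grad_accum_steps, world_size, dt):
    """Throughput of one optimizer step that took dt seconds, counting every rank's tokens."""
    return B * T * grad_accum_steps * world_size / dt


total_batch_size = 2**19  # 524,288 tokens per optimizer step
print(grad_accum_steps_for(total_batch_size, 16, 1024, 1))  # -> 32 (single GPU)
print(grad_accum_steps_for(total_batch_size, 16, 1024, 8))  # -> 4 (eight GPUs)
```

With a micro-batch of B=64 on eight GPUs, the same function gives 1, meaning no accumulation at all.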
03:07:56.000 | So otherwise, this is what the optimization now looks like.
03:08:00.800 | And wow, we're going really fast.
03:08:02.320 | So we're processing 1.5 million tokens per second now.
03:08:08.640 | So these are some serious numbers.
03:08:10.720 | And the tiny Shakespeare dataset is so tiny that we're just doing like
03:08:13.520 | so many epochs over it, most likely.
03:08:16.480 | But this is roughly what it looks like.
03:08:17.840 | One thing that I had to fix, by the way, is that this was model.configure_optimizers,
03:08:24.320 | which now doesn't work because model now is a DDP model.
03:08:27.440 | So instead, this has to become raw_model.configure_optimizers,
03:08:32.080 | where raw_model is something I create here.
03:08:35.120 | So right after I wrap the model into DDP, I have to create the raw_model,
03:08:39.760 | which in the case of DDP is model.module;
03:08:43.280 | that is where it stores the raw nn.Module of GPT-2 as we have it,
03:08:48.320 | which contains the configure_optimizers function that we want to call.
03:08:52.080 | So that's one thing that I had to fix.
03:08:54.320 | Otherwise, this seems to run.
03:08:55.920 | Now, one thing you'll notice is that when you actually compare this run
03:08:59.200 | and the numbers in it to just running a single GPU,
03:09:01.920 | and here is a single-GPU run with 32 grad accum steps,
03:09:06.640 | the numbers won't exactly match up.
03:09:08.560 | And it's kind of a boring reason for why that happens.
03:09:12.800 | The reason for that is that in the data loader,
03:09:15.360 | we're basically just iterating through batches in a slightly different way,
03:09:18.320 | because now we're looking for an entire page of data.
03:09:21.200 | And if that page for all the GPUs,
03:09:24.240 | if that chunk exceeds the number of tokens, we just loop.
03:09:27.920 | And so actually the single-GPU run and the multi-GPU processes
03:09:31.280 | will end up resetting in a slightly different manner.
03:09:35.200 | And so our batches are slightly different.
03:09:37.120 | And so we get slightly different numbers.
03:09:38.640 | But one way to convince yourself that this is okay
03:09:42.480 | is to just make the total batch size much smaller, along with B and T.
03:09:46.240 | And then so I think I used 4 times 1024 times 8.
03:09:52.000 | So I used 32,768 as a total batch size.
03:09:54.640 | And then so I made sure that the single GPU
03:09:58.240 | will do eight gradient accumulation steps.
03:10:00.400 | And then I did multi GPU.
03:10:01.600 | And then you're reducing the boundary effects of the data loader.
03:10:05.120 | And you'll see that the numbers match up.
03:10:07.120 | So long story short, we're now going really, really fast.
03:10:10.560 | The optimization is mostly consistent with GPT-2 and GPT-3 hyperparameters.
03:10:14.960 | And we have outgrown our tiny Shakespeare file.
03:10:18.800 | And we want to upgrade it.
03:10:20.160 | So let's move to that next.
03:10:21.760 | So let's now take a look at what data sets were used by GPT-2 and GPT-3.
03:10:25.040 | So GPT-2 used this WebText data set that was never released.
03:10:29.520 | There's an attempt at reproducing it called OpenWebText.
03:10:33.840 | So basically, roughly speaking, what they say here in the paper is that
03:10:36.880 | they scraped all outbound links from Reddit.
03:10:39.040 | And then with at least three karma.
03:10:42.560 | And that was kind of like their starting point.
03:10:44.160 | And they collected all the web pages and all the text in them.
03:10:47.520 | And so this was 45 million links.
03:10:49.760 | And this ended up being 40 gigabytes of text.
03:10:51.600 | So that's roughly what GPT-2 says about its data set.
03:10:57.200 | So it's basically outbound links from Reddit.
03:10:59.040 | Now, when we go over to GPT-3, there's a training data set section.
03:11:03.360 | And that's where they start to talk about Common Crawl, which is a lot more used.
03:11:08.720 | Actually, I think even GPT-2 talked about Common Crawl.
03:11:13.600 | But basically, it's not a very high quality data set all by itself,
03:11:16.640 | because it is extremely noisy.
03:11:18.000 | This is a completely random subset of the internet.
03:11:20.560 | And it's much worse than you think.
03:11:22.000 | So people go into great lengths to filter Common Crawl,
03:11:25.040 | because there's good stuff in it.
03:11:26.480 | But most of it is just like ad spam, random tables and numbers and stock tickers.
03:11:31.440 | And it's just a total mess.
03:11:34.320 | So that's why people like to train on these data mixtures
03:11:39.840 | that they curate and are careful with.
03:11:42.880 | So a large chunk of these data mixtures typically will be Common Crawl.
03:11:46.400 | Like, for example, 50% of the tokens will be Common Crawl.
03:11:49.440 | But then here in GPT-3, they're also using WebText2 from before.
03:11:52.880 | So that's Reddit outbound.
03:11:54.320 | But they're also adding, for example, books.
03:11:56.480 | And they're adding Wikipedia.
03:11:57.840 | There's many other things you can decide to add.
03:11:59.840 | Now, this data set for GPT-3 was also never released.
03:12:03.680 | So today, some of the data sets that I'm familiar with that are quite good
03:12:06.640 | and would be representative of something along these lines
03:12:09.840 | are, number one, the RedPajama data set.
03:12:12.400 | Or more specifically, for example, the SlimPajama subset of the RedPajama data set,
03:12:17.520 | which is a cleaned and deduplicated version of it.
03:12:20.160 | And just to give you a sense, again, it's a bunch of Common Crawl.
03:12:23.440 | C4, which is also, as far as I know, more Common Crawl, but processed differently.
03:12:29.200 | And then we have GitHub, Books, ArXiv, Wikipedia, StackExchange.
03:12:33.600 | These are the kinds of data sets that would go into these data mixtures.
03:12:36.160 | Now, specifically the one that I like that came out recently
03:12:39.680 | is called the FineWeb dataset.
03:12:41.520 | So this is an attempt to basically collect really high-quality Common Crawl data
03:12:47.760 | and filter it, in this case, to 15 trillion tokens.
03:12:50.320 | And then in addition to that, more recently, Hugging Face released this
03:12:54.080 | FineWeb-Edu subset, which comes as a 1.3 trillion token version of highly educational content
03:12:59.120 | and a looser 5.4 trillion token version.
03:13:02.080 | So basically, they're trying to filter Common Crawl
03:13:05.200 | to very high-quality educational subsets.
03:13:08.480 | And this is the one that we will use.
03:13:10.880 | There's a long web page here on FineWeb.
03:13:14.240 | And they go into a ton of detail about how they process the data,
03:13:16.960 | which is really fascinating reading, by the way.
03:13:19.040 | And I would definitely recommend, if you're interested
03:13:20.880 | into data mixtures and so on, and how data gets processed at these scales,
03:13:24.640 | look at this page.
03:13:25.760 | And more specifically, we'll be working with FineWeb-Edu, I think.
03:13:29.920 | And it's basically educational content from the internet.
03:13:35.200 | They show that training on educational content in their metrics
03:13:38.640 | works really, really well.
03:13:41.920 | And we're going to use its sample-10BT subsample, which is 10 billion tokens.
03:13:48.640 | Because we're not going to be training on trillions of tokens.
03:13:51.040 | We're just going to train on a 10-billion-token sample of FineWeb-Edu.
03:13:56.160 | Because empirically, in my previous few experiments,
03:13:58.800 | this actually suffices to really get close to GPT-2 performance.
03:14:02.400 | And it's simple enough to work with.
03:14:04.480 | And so let's work with sample-10BT.
03:14:07.600 | So our goal will be to download it, process it,
03:14:11.440 | and make sure that our data loader can work with it.
03:14:13.840 | So let's get to that.
03:14:15.440 | Okay, so I introduced another file here
03:14:19.440 | that will basically download FineWeb-Edu from Hugging Face datasets.
03:14:23.920 | It will pre-process and pre-tokenize all of the data.
03:14:27.360 | And it will save data shards to a folder on a local disk.
03:14:34.160 | And so while this is running, I just wanted to briefly mention
03:14:39.040 | that you can kind of look through the dataset viewer here
03:14:41.760 | just to get a sense of what's in here.
03:14:43.440 | And it's kind of interesting.
03:14:44.400 | I mean, it basically looks like it's working fairly well.
03:14:48.080 | Like it's talking about nuclear energy in France.
03:14:50.080 | It's talking about Mexican America, some Mac Pi Js, et cetera.
03:14:57.200 | So actually, it seems like their filters are working pretty well.
03:14:59.840 | The filters here, by the way, were applied automatically
03:15:03.280 | using Llama 3 70B, I believe.
03:15:06.320 | And so basically, LLMs are judging which content is educational
03:15:10.800 | and that ends up making it through the filter.
03:15:12.480 | So that's pretty cool.
03:15:14.320 | Now, in terms of the script itself,
03:15:15.680 | I'm not going to go through the full script
03:15:17.840 | because it's not as interesting and not as LLM-centric.
03:15:21.520 | But when you run this, basically, number one,
03:15:23.840 | we're going to load the dataset,
03:15:25.520 | and this is all Hugging Face code.
03:15:28.160 | You're going to need to pip install datasets.
03:15:31.200 | So it's downloading the dataset.
03:15:33.200 | Then it is tokenizing all of the documents inside this dataset.
03:15:37.600 | Now, when we tokenize the documents,
03:15:39.600 | you'll notice that to tokenize a single document,
03:15:43.200 | we first start the tokens with the end of text token.
03:15:48.480 | And this is a special token in the GPT-2 tokenizer, as you know.
03:15:51.360 | So 50,256 is the ID of the end of text.
03:15:55.440 | And this is what begins a document,
03:15:57.040 | even though it's called end of text.
03:15:58.640 | But this is the first token that begins a document.
03:16:01.520 | Then we extend with all of the tokens of that document.
03:16:05.360 | Then we create a NumPy array out of that.
03:16:07.840 | We make sure that all the tokens are between...
03:16:11.200 | Okay, let me debug this.
03:16:14.720 | Okay, so apologies for that.
03:16:16.720 | It just had to do with me using a float division in Python.
03:16:19.600 | It must be integer division,
03:16:20.960 | so that this is an int and everything is nice.
03:16:25.520 | Okay, but basically, the tokenization here is relatively straightforward.
03:16:29.200 | It returns tokens in np.uint16.
03:16:31.840 | We're using uint16 to save a little bit of space
03:16:34.960 | because 2 to the 16 minus 1 is 65,535.
03:16:39.120 | So the GPT-2 max token ID is well below that.
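The per-document step as a dependency-free sketch. It takes token ids that are assumed to be already produced by the GPT-2 tokenizer; the ids in the usage line are made-up examples. The <|endoftext|> token (50256) goes first to delimit documents, the range check guards the uint16 cast, and the result is packed 2 bytes per token. The script described here uses np.uint16. Here the standard library's array module with typecode "H" (unsigned short) stands in, so this is a swapped-in container rather than the original code.

```python
from array import array

EOT = 50256  # id of <|endoftext|> in the GPT-2 tokenizer


def tokenize_doc(token_ids):
    """Prefix a document's token ids with EOT and pack them as unsigned 16-bit ints."""
    tokens = [EOT]            # the "end of text" token begins each document
    tokens.extend(token_ids)
    # every id must fit in uint16, i.e. 0..65535; GPT-2's largest id, 50256, does
    if not all(0 <= t < 2**16 for t in tokens):
        raise ValueError("token id out of uint16 range")
    return array("H", tokens)  # 2 bytes per token on mainstream platforms


doc = tokenize_doc([15496, 995])  # example ids, assumed pre-tokenized
print(list(doc))                  # -> [50256, 15496, 995]
```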
03:16:42.080 | And then here, there's a bunch of multiprocessing code.
03:16:45.280 | And it's honestly not that exciting,
03:16:46.640 | so I'm not gonna step through it.
03:16:48.320 | But we're loading the dataset, we're tokenizing it,
03:16:51.360 | and we're saving everything to shards.
03:16:53.760 | And the shards are NumPy files.
03:16:55.600 | So just storing a NumPy array,
03:16:58.640 | which is very, very similar to Torch tensors.
03:17:02.000 | And the first shard, 000, is a validation shard.
03:17:07.120 | And all the other shards are training shards.
03:17:10.080 | And as I mentioned, they all have 100 million tokens in them exactly.
03:17:13.760 | And that just makes it easier to work with, to shard the files.
03:17:20.320 | Because if we just have a single massive file,
03:17:22.080 | sometimes they can be hard to work with on the disk.
03:17:24.720 | And so sharding it is just kind of nicer from that perspective.
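A minimal sketch of the shard layout just described: a stream of tokenized documents is packed into fixed-size shards, shard 0 is labelled validation, and the rest are training. In this sketch a document that crosses a shard boundary is split between two shards, and the last shard may be shorter. The function name is hypothetical, and the tiny shard size is only for illustration; the real shards hold 100 million tokens each.

```python
def shard_token_stream(docs, shard_size):
    """Pack tokenized documents into shards of shard_size tokens (hypothetical helper).

    Returns (split, tokens) pairs: shard 0 is "val", every later shard is "train".
    """
    shards, current = [], []
    for doc in docs:
        for tok in doc:
            current.append(tok)
            if len(current) == shard_size:  # shard full: close it and start a new one
                shards.append(current)
                current = []
    if current:
        shards.append(current)  # leftover tokens form a final, shorter shard
    return [("val" if i == 0 else "train", shard) for i, shard in enumerate(shards)]


print(shard_token_stream([[1, 2, 3], [4, 5, 6, 7], [8]], shard_size=3))
# -> [('val', [1, 2, 3]), ('train', [4, 5, 6]), ('train', [7, 8])]
```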
03:17:28.320 | And yeah, so we'll just let this run.
03:17:31.840 | This will be probably 30-ish minutes or so.
03:17:36.400 | And then we're gonna come back to actually train on this data.
03:17:39.040 | And we're gonna be actually doing some legit pre-training in this case.
03:17:41.840 | This is a good dataset.
03:17:43.680 | We're doing lots of tokens per second.
03:17:45.840 | We have eight GPUs, the code is ready.
03:17:48.080 | And so we're actually gonna be doing a serious training run.
03:17:50.720 | So let's get back in a bit.
03:17:52.320 | Okay, so we're back.
03:17:54.000 | So if we ls the edu_fineweb folder, we see that there are now 100 shards in it.
03:17:59.760 | And that makes sense because each shard is 100 million tokens.
03:18:04.400 | So 100 shards of that is 10 billion tokens in total.
03:18:07.360 | Now, swinging over to the main file,
03:18:09.760 | I made some adjustments to our data loader again.
03:18:12.480 | And that's because we're not running with Shakespeare anymore.
03:18:16.000 | We want to use the FineWeb-Edu shards.
03:18:18.640 | And so you'll see some code here that can additionally load these shards.
03:18:22.640 | We load the uint16 NumPy file.
03:18:26.800 | We convert it to a torch.long tensor,
03:18:29.520 | which is what a lot of the layers up top expect by default.
03:18:33.040 | And then here, we're just enumerating all the shards.
03:18:35.280 | I also added a split to DataLoaderLite.
03:18:38.800 | So we can load the split train, but also the split val, which is shard zero.
03:18:43.120 | And then we can load the shards.
03:18:46.080 | And then here, we also have not just the current position now,
03:18:49.440 | but also the current shard.
03:18:51.520 | So we have a position inside a shard.
03:18:53.680 | And then when we run out of tokens in a single shard,
03:18:57.120 | we first advance the shard and loop if we need to.
03:19:00.560 | And then we get the tokens and readjust the position.
03:19:03.360 | So this data loader will now iterate all the shards as well.
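Putting the rank striding and the shard iteration together, here is a pure-Python sketch of such a loader. Plain lists stand in for the uint16 shards that the real loader reads and converts to torch.long tensors. It also has the reset method that re-initializes the loader, which comes up again for the validation loader. The class name is hypothetical, and this is not the exact DataLoaderLite code.

```python
class ShardedLoaderSketch:
    """Rank-aware, multi-shard loader sketch; `shards` is a list of token lists."""

    def __init__(self, shards, B, T, process_rank=0, num_processes=1):
        self.shards = shards
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.reset()

    def reset(self):
        # back to shard 0, at this rank's own offset inside the shard
        self.current_shard = 0
        self.tokens = self.shards[self.current_shard]
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = [buf[i * T:(i + 1) * T] for i in range(B)]          # inputs
        y = [buf[i * T + 1:(i + 1) * T + 1] for i in range(B)]  # targets, shifted by one
        # all ranks together consume B*T*num_processes tokens per step
        self.current_position += B * T * self.num_processes
        # if the next read would run past the end, move on to the next shard (wrapping around)
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.shards[self.current_shard]
            self.current_position = B * T * self.process_rank
        return x, y


loader = ShardedLoaderSketch([list(range(0, 20)), list(range(100, 120))], B=2, T=4)
for _ in range(3):
    x, y = loader.next_batch()
print(x[0], loader.current_shard)  # -> [100, 101, 102, 103] 1
```

Because the end-of-shard check uses each rank's own position, a rank with a larger offset can move to the next shard one step before rank 0 when the shard length is not a multiple of the chunk size. That is harmless for training, but worth knowing.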
03:19:07.520 | So I changed that.
03:19:09.600 | And then the other thing that I did while the data was processing
03:19:13.520 | is our train loader now has split train, of course.
03:19:16.560 | And down here, I set up some numbers.
03:19:20.240 | So we are doing 2 to the 19 tokens per step.
03:19:28.960 | And we want to do roughly 10 billion tokens,
03:19:33.440 | because that's how many unique tokens we have.
03:19:36.400 | So if we did 10 billion tokens, then divide that by 2 to the 19,
03:19:40.080 | we see that this is 19,073 steps.
03:19:43.600 | So that's where that's from.
03:19:44.560 | And then the GPT-3 paper says that they warm up the learning rate over 375 million tokens.
03:19:50.800 | So I came here, and 375e6 tokens divided by 2 to the 19 is 715 steps.
03:20:00.320 | So that's why warm up steps is set to 715.
03:20:03.040 | So this will exactly match the warm up schedule that GPT-3 used.
03:20:08.400 | And I think 715, by the way, is very mild.
03:20:11.600 | And this could be made significantly more aggressive.
03:20:13.520 | Probably even like 100 is good enough.
03:20:15.280 | But it's okay.
03:20:17.520 | Let's leave it for now so that we have the exact hyperparameters of GPT-3.
03:20:20.880 | So I fixed that.
03:20:23.040 | And then that's pretty much it.
03:20:26.160 | We can run.
03:20:27.840 | So we have our script here.
03:20:29.200 | And we can launch.
03:20:32.000 | And actually, sorry, let me do one more thing.
03:20:38.000 | [COUGHS]
03:20:38.800 | Excuse me.
03:20:39.360 | On my GPUs, I can actually fit a bigger batch size.
03:20:44.240 | And I believe I can fit 64 on my GPU as a micro-batch size.
03:20:49.840 | So let me try that.
03:20:51.040 | I could be misremembering.
03:20:56.880 | But that means 64 times 1024 tokens per GPU.
03:20:59.600 | And then we have 8 GPUs.
03:21:01.280 | So that means we would not even be doing gradient accumulation if this fits.
03:21:05.040 | Because this just multiplies out to the full total batch size.
03:21:09.840 | So no gradient accumulation.
03:21:12.080 | And that would run pretty quickly if that fits.
03:21:14.880 | Let's go.
03:21:27.200 | Let's go.
03:21:27.700 | I mean, if this works, then this is basically a serious pre-training run.
03:21:31.840 | We're not logging.
03:21:33.920 | We're not evaluating the validation split.
03:21:35.680 | We're not running any evaluations yet.
03:21:37.680 | So it's not-- we haven't crossed our Ts and dotted our Is.
03:21:41.040 | But if we let this run for a while, we're going to actually get a pretty good model.
03:21:46.400 | And a model that might even be on par with or better than GPT-2 124M.
03:21:51.780 | So it looks like everything is going great.
03:21:55.360 | We're processing 1.5 million tokens per second.
03:21:58.160 | Everything here looks good.
03:22:03.360 | We're doing 330 milliseconds per iteration.
03:22:06.880 | And we have to do a total of-- where are we printing that?
03:22:11.440 | 19,073.
03:22:12.800 | So 19,073 times 0.33 is about 6,300 seconds, or about 105 minutes.
03:22:20.080 | So this will run for about 1.7 hours.
03:22:23.360 | So a bit under a two-hour run like this.
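The same back-of-the-envelope estimate in code, using the roughly 330 ms per step measured in this run. The throughput line shows where the roughly 1.5 million tokens per second comes from.

```python
max_steps = 19073
seconds_per_step = 0.33  # measured step time on 8 GPUs
tokens_per_step = 2**19

total_hours = max_steps * seconds_per_step / 3600        # about 1.75 hours
tokens_per_second = tokens_per_step / seconds_per_step   # about 1.59M tokens/s
print(round(total_hours, 2), round(tokens_per_second))
```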
03:22:28.400 | And we don't even have to use gradient accumulation, which is nice.
03:22:31.760 | And you might not have that luxury in your GPU.
03:22:34.160 | In that case, just start decreasing the batch size until things fit.
03:22:37.520 | But keep it to nice numbers.
03:22:38.800 | So that's pretty exciting.
03:22:42.240 | We're currently warming up the learning rate.
03:22:43.760 | So you see that it's still very low, 1e-4.
03:22:46.640 | So this will ramp up over the next few steps all the way to 6e-4 here.
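For reference, here is a sketch of the warmup plus cosine decay schedule with this run's numbers: linear warmup over 715 steps to a peak of 6e-4, then cosine decay over the remaining steps. Decaying to a floor of 10% of the peak (min_lr = 6e-5) is an assumption of this sketch.

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1  # assumption: decay to 10% of the peak learning rate
warmup_steps = 715
max_steps = 19073


def get_lr(it):
    # 1) linear warmup from max_lr / warmup_steps up to max_lr
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) past the end of the schedule, hold at the floor
    if it > max_steps:
        return min_lr
    # 3) in between, cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


print(get_lr(0), get_lr(714), get_lr(max_steps))
```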
03:22:53.120 | Very cool.
03:22:55.280 | So now what I'd like to do is let's cross the Ts and dot our Is.
03:22:58.640 | Let's evaluate on the validation split.
03:23:00.880 | And let's try to figure out how we can run evals, how we can do logging,
03:23:04.640 | how we can visualize our losses, and all the good stuff.
03:23:07.840 | So let's get to that before we actually do the run.
03:23:10.960 | OK, so I've adjusted the code so that we're evaluating on the validation split.
03:23:14.720 | So creating the val loader just by passing in split equals val,
03:23:18.320 | that will basically create a data loader just for the validation shard.
03:23:21.920 | The other thing I did is in the data loader, I introduced a new function reset,
03:23:27.920 | which is called at init.
03:23:29.440 | And it basically resets the data loader.
03:23:31.600 | And that is very useful because when we come to the main training loop now--
03:23:35.360 | so this is the code that I've added.
03:23:37.920 | And basically, every 100th iteration, including the 0th iteration,
03:23:42.640 | we put the model into evaluation mode.
03:23:45.120 | We reset the val loader.
03:23:46.480 | And then no gradients involved.
03:23:50.400 | We're going to basically accumulate the loss over, say, 20 steps.
03:23:56.720 | And then average it all up and print out the validation loss.
03:23:59.200 | And so that basically is the exact same logic as the training loop, roughly.
03:24:05.680 | But there's no loss.backward().
03:24:07.360 | It's only inference.
03:24:08.240 | We're just measuring the loss.
03:24:09.360 | We're adding it up.
03:24:10.400 | Everything else otherwise applies.
03:24:12.160 | And it's exactly as we've seen it before.
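Stripped of the PyTorch specifics, the validation step looks roughly like the sketch below. Every 100th step, including step 0, reset the validation loader and average the loss over a fixed number of batches, with no backward pass. In the real loop this would run after model.eval() and inside torch.no_grad(). The function names and the loss_fn argument are hypothetical stand-ins.

```python
def should_evaluate(step, eval_interval=100):
    # every eval_interval-th step, including step 0
    return step % eval_interval == 0


def estimate_val_loss(val_loader, loss_fn, val_loss_steps=20):
    """Average loss over a fixed number of validation batches; no loss.backward()."""
    val_loader.reset()  # always measure on the same validation tokens
    loss_accum = 0.0
    for _ in range(val_loss_steps):
        x, y = val_loader.next_batch()
        loss_accum += loss_fn(x, y) / val_loss_steps  # running mean, as in training
    return loss_accum


print([step for step in range(250) if should_evaluate(step)])  # -> [0, 100, 200]
```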
03:24:14.320 | And so this will print the validation loss every 100th iteration,
03:24:18.560 | including the very first iteration.
03:24:20.080 | So that's nice.
03:24:22.160 | That will tell us a little bit about how much we're overfitting.
03:24:26.560 | That said, we have essentially infinite data.
03:24:29.760 | So we're mostly expecting our train and val loss to be about the same.
03:24:32.880 | But the other reason I'm interested in this is because we can take the GPT-2-124M
03:24:37.760 | as OpenAI released it.
03:24:39.280 | We can initialize from it.
03:24:40.800 | And we can basically see what kind of loss it achieves on the validation split as well.
03:24:44.480 | And that gives us an indication as to how well that model would generalize
03:24:49.840 | to the FineWeb-EDU validation split.
03:24:54.080 | That said, it's not a super fair comparison to GPT-2
03:24:56.640 | because it was trained on a very different data distribution.
03:24:58.960 | But it's still kind of like an interesting data point.
03:25:00.960 | And in any case, you would always want to have a validation split in a training run like this
03:25:06.800 | so that you can make sure that you are not overfitting.
03:25:10.960 | And this is especially a concern if we were to do more epochs over our training data.
03:25:15.440 | So for example, right now, we're just doing a single epoch.
03:25:19.120 | But if we get to a point where we want to train on 10 epochs or something like that,
03:25:22.640 | we would have to be really careful that we are not memorizing that data too much,
03:25:27.360 | if we have a big enough model.
03:25:28.560 | And our validation split would be one way to tell whether that is happening.
03:25:32.800 | OK, and in addition to that, if you remember, at the bottom of our script,
03:25:36.000 | we had all of this orphaned code for sampling from way back when.
03:25:39.440 | So I deleted that code.
03:25:40.720 | And I moved it up to here.
03:25:43.520 | So once in a while, we evaluate the validation loss.
03:25:46.640 | Once in a while, we generate samples.
03:25:50.640 | And then we do that only every 100 steps.
03:25:53.920 | And we train on every single step.
03:25:55.920 | So that's how I have it structured right now.
03:25:57.360 | And I've been running this for 1,000 iterations.
03:26:00.000 | So here are some samples on iteration 1,000.
03:26:01.920 | Hello, I'm a language model.
03:26:06.800 | And I'm not able to get more creative.
03:26:08.320 | I'm a language model.
03:26:10.640 | And languages file you're learning about here is-- or is the beginning of a computer.
03:26:14.480 | OK, so this is all still pretty garbled.
03:26:20.640 | But we're only at iteration 1,000.
03:26:22.800 | And we've only just barely reached the maximum learning rate.
03:26:25.840 | So this is still learning.
03:26:27.280 | We're about to get some more samples coming up in 1,100.
03:26:31.280 | OK, this is-- the model is still a young baby.
03:26:39.680 | OK, so basically, all of this sampling code that I've put here,
03:26:45.280 | everything should be familiar to you and came from before.
03:26:48.080 | The only thing that I did is I created a generator object in PyTorch
03:26:51.840 | so that I have direct control over the sampling of the random numbers.
03:26:55.920 | Because I don't want to impact the RNG state of the random number generator
03:27:00.320 | that is the global one used for training.
03:27:02.800 | I want this to be completely outside of the training loop.
03:27:05.040 | And so I'm using a special sampling RNG.
03:27:08.240 | And then I make sure to seed it so that every single rank has a different seed.
03:27:13.280 | And then I pass in here, where we sort of consume random numbers in multinomial,
03:27:18.240 | where the sampling happens.
03:27:19.760 | I make sure to pass in the generator object there.
03:27:22.240 | Otherwise, this is identical.
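The dedicated sampling generator can be illustrated with Python's `random.Random`. Here it stands in for the `torch.Generator` that the script passes to `torch.multinomial(..., generator=...)`. The `42 + rank` seed is an assumption chosen for illustration. The property that matters is that drawing samples from the private generator leaves the global RNG stream used by training untouched.

```python
import random
from typing import List, Sequence


def make_sample_rng(rank: int, base_seed: int = 42) -> random.Random:
    # Private generator, seeded differently on every rank so that
    # each process produces different samples.
    return random.Random(base_seed + rank)


def sample_token(probs: Sequence[float], rng: random.Random) -> int:
    """Draw one index from a categorical distribution using only `rng`
    (the analogue of multinomial with num_samples=1)."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def generate(probs: Sequence[float], n: int, rng: random.Random) -> List[int]:
    # Toy "generation": n independent draws from a fixed distribution.
    return [sample_token(probs, rng) for _ in range(n)]
```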
03:27:23.440 | Now, the other thing is you'll notice that we're running a bit slower.
03:27:28.160 | That's because I actually had to disable torch.compile to get this to sample.
03:27:32.240 | And so we're running a bit slower.
03:27:35.120 | So for some reason, it works with no torch.compile.
03:27:37.040 | But when I torch.compile my model, I get a really scary error from PyTorch.
03:27:40.800 | And I have no idea how to resolve it right now.
03:27:42.800 | So probably by the time you see this code released or something like that, maybe it's fixed.
03:27:47.120 | But for now, I'm just going to add "and False" to disable it.
03:27:49.280 | And I'm going to bring back torch.compile.
03:27:52.400 | And you're not going to get samples.
03:27:53.760 | And I think I'll fix this later.
03:27:56.720 | By the way, I will be releasing all this code.
03:28:00.480 | And actually, I've been very careful about making git commits every time we add something.
03:28:04.960 | And so I'm going to release the entire repo that starts completely from scratch,
03:28:08.960 | all the way to now and after this as well.
03:28:12.560 | And so everything should be exactly documented in the git commit history.
03:28:15.440 | And so I think that will be nice.
03:28:18.720 | So hopefully, by the time you go to GitHub, this is removed and it's working.
03:28:22.160 | And I will have fixed the bug.
03:28:23.840 | OK, so I have the optimization running here.
03:28:25.600 | And it's stepping and we're on step 6,000 or so.
03:28:28.800 | So we're about 30% through training.
03:28:30.960 | Now, while this is training, I would like to introduce one evaluation
03:28:34.000 | that we're going to use to supplement the validation set.
03:28:36.480 | And that is the Hellaswag eval.
03:28:39.840 | So Hellaswag comes from this paper back in 2019.
03:28:43.360 | So it's a five-year-old eval now.
03:28:44.800 | And the way Hellaswag works is it's basically a sentence completion data set.
03:28:49.680 | So it's a multiple choice.
03:28:51.520 | For every one of these questions, we have basically a shared context,
03:28:56.080 | like a woman is outside with a bucket and a dog.
03:28:58.960 | The dog is running around trying to avoid a bath.
03:29:01.600 | She, A, rinses the bucket off with soap and blow dry the dog's head.
03:29:07.360 | B, uses a hose to keep it from getting soapy.
03:29:09.760 | C, gets the dog wet and it runs away again.
03:29:12.880 | Or D, gets into a bathtub with the dog.
03:29:15.920 | And so basically, the idea is that these multiple choice are constructed
03:29:21.040 | so that one of them is a natural continuation of the sentence and the others are not.
03:29:29.120 | And the others might not make sense, like uses the hose to keep it from getting soapy.
03:29:34.640 | That makes no sense.
03:29:35.840 | And so what happens is that models that are not trained very well
03:29:39.360 | are not able to tell these apart.
03:29:41.520 | But models that have a lot of world knowledge and can tell a lot about the world
03:29:48.400 | will be able to pick out the correct completions.
03:29:50.960 | And these sentences are sourced from ActivityNet and from Wikihow.
03:29:55.760 | And at the bottom of the paper, there's kind of like a cool chart
03:30:03.200 | of the kinds of domains in Wikihow.
03:30:05.280 | So there are a lot of sentences from computers and electronics and home and garden.
03:30:09.840 | And it has kind of a broad coverage of the kinds of things
03:30:13.040 | you need to know about the world in order to identify the most likely completion.
03:30:21.120 | One more thing that's kind of interesting about Hellaswag is the way it was constructed
03:30:25.920 | is that the incorrect options are deliberately adversarially sourced.
03:30:34.240 | So they're not just random sentences.
03:30:36.400 | They're actually sentences generated by language models.
03:30:39.120 | And they're generated in such a way that language models basically find them difficult,
03:30:43.040 | but humans find them easy.
03:30:45.040 | And so they mentioned that humans have a 95% accuracy on this set.
03:30:48.800 | But at the time, the state-of-the-art language models had only 48%.
03:30:52.000 | And so at the time, this was a good benchmark.
03:30:54.800 | Now, you can read the details of this paper to learn more.
03:30:59.200 | The thing to point out, though, is that this is five years ago.
03:31:02.880 | And since then, what happened to Hellaswag is that it's been totally just solved.
03:31:10.080 | And so now the language models here are 96%.
03:31:13.040 | So basically, the last 4% is probably errors in the data set,
03:31:17.120 | or the questions are really, really hard.
03:31:19.360 | And so basically, this data set is kind of crushed with respect to language models.
03:31:22.560 | But back then, the best language model was only at about 50%.
03:31:24.960 | So this is how far things have come.
03:31:29.440 | But still, people like Hellaswag. It's not used, by the way, in GPT-2,
03:31:35.360 | but GPT-3 does have a Hellaswag eval.
03:31:38.800 | And lots of people use Hellaswag.
03:31:40.480 | And so for GPT-3, we have results here that are cited.
03:31:47.440 | So we know what percent accuracies GPT-3 attains
03:31:51.200 | at all these different model checkpoints for Hellaswag eval.
03:31:54.320 | And the reason people like it is because Hellaswag is a smooth eval.
03:31:59.360 | And it is an eval that offers, quote, unquote, early signal.
03:32:02.640 | So early signal means that even small language models
03:32:06.320 | are going to start at the random chance of 25%.
03:32:09.120 | But they're going to slowly improve.
03:32:11.120 | And you're going to see 25, 26, 27, et cetera.
03:32:13.920 | And you can see slow improvement, even when the models are very small, and it's very early.
03:32:20.160 | So it's smooth.
03:32:21.200 | It has early signal.
03:32:24.080 | And it's been around for a long time.
03:32:26.800 | So that's why people kind of like this eval.
03:32:29.440 | Now, the way that we're going to evaluate this is as follows.
03:32:34.240 | As I mentioned, we have a shared context.
03:32:38.240 | And this is kind of like a multiple choice task.
03:32:40.560 | But instead of giving the model a multiple choice question
03:32:43.280 | and asking it for A, B, C, or D, we can't do that.
03:32:46.960 | Because these models, when they are so small, as we are seeing here,
03:32:50.320 | the models can't actually do multiple choice.
03:32:52.320 | They don't understand the concept of associating a label
03:32:55.520 | to one of the options of multiple choice.
03:32:57.360 | They don't understand that.
03:32:59.120 | So we have to give it to them in a native form.
03:33:01.360 | And the native form is a token completion.
03:33:03.920 | So here's what we do.
03:33:05.520 | We construct a batch of four rows and T tokens, whatever that T happens to be.
03:33:11.440 | Then the shared context, that is basically the context for the four choices,
03:33:16.400 | the tokens of that are shared across all of the rows.
03:33:20.080 | And then we have the four options.
03:33:22.240 | So we kind of like lay them out.
03:33:24.400 | And then only one of the options is correct.
03:33:26.000 | In this case, label 3, option 3.
03:33:28.000 | And so this is the correct option.
03:33:31.360 | And the other options are incorrect.
03:33:33.040 | Now, these options might be of different lengths.
03:33:37.200 | So what we do is we sort of like take the longest length.
03:33:39.760 | And that's the size of the batch, B by T.
03:33:42.400 | And then some of these here are going to be padded dimensions.
03:33:46.800 | So they're going to be unused.
03:33:48.400 | And so we need the tokens.
03:33:51.280 | We need the correct label.
03:33:52.560 | And we need a mask that tells us which tokens are active.
03:33:56.480 | And the mask is then 0 for these padded areas.
03:33:59.920 | So that's how we construct these batches.
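The batch layout above can be sketched in plain Python. Token ids are made-up integers, and `render_example` is a hypothetical stand-in for the function in the eval file. Padding with 0 is an assumption. Any pad value works, because the mask zeroes those positions out.

```python
from typing import List, Tuple


def render_example(ctx_tokens: List[int],
                   endings: List[List[int]],
                   label: int) -> Tuple[List[List[int]], List[List[int]], int]:
    """Build a 4 x T token grid plus a matching mask.

    Each row is the shared context followed by one candidate ending.
    The mask is 1 only on the ending tokens (the ones we score) and
    0 on the context and on the padding."""
    rows = [ctx_tokens + end for end in endings]
    mask_rows = [[0] * len(ctx_tokens) + [1] * len(end) for end in endings]
    T = max(len(r) for r in rows)  # the longest row sets T
    tokens = [r + [0] * (T - len(r)) for r in rows]
    mask = [m + [0] * (T - len(m)) for m in mask_rows]
    return tokens, mask, label
```

For example, a 3-token context with endings of lengths 1, 2, 3 and 2 gives T = 6, so the first row ends in two padding positions whose mask entries are 0.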
03:34:02.960 | And then in order to get the language model to predict A, B, C, or D,
03:34:07.280 | the way this works is basically we're just going to look at the tokens,
03:34:10.880 | their probabilities.
03:34:12.320 | And we're going to pick the option that gets the highest average probability
03:34:16.080 | for its tokens,
03:34:21.120 | because that is the most likely completion
03:34:25.920 | according to the language model.
03:34:27.200 | So we're just going to look at the probabilities here
03:34:31.440 | and average them up within each option,
03:34:35.040 | and pick the option with the highest average probability, roughly speaking.
03:34:38.000 | So this is how we're going to do Hellaswag.
03:34:40.960 | And this is, as far as I know, also how GPT-3 did it.
03:34:51.680 | But you should note that some of the other eval setups
03:34:53.920 | where you might see Hellaswag reported may not do it this way.
03:34:56.880 | They may do it in a multiple choice format
03:34:58.560 | where you sort of give the context a single time
03:35:02.080 | and then the four completions.
03:35:03.680 | And so the model is able to see all the four options
03:35:06.160 | before it picks the best possible option.
03:35:08.480 | And that's actually an easier task for a model
03:35:10.880 | because you get to see the other options when you're picking your choice.
03:35:14.240 | But unfortunately, models at our size can't do that.
03:35:17.760 | Only models at a bigger size are able to do that.
03:35:20.640 | And so our models are actually slightly handicapped in this way
03:35:24.240 | that they are not going to see the other options.
03:35:26.160 | They're only going to see one option at a time
03:35:28.880 | and they just have to assign probabilities
03:35:30.880 | and the correct option has to win out in this metric.
03:35:33.200 | All right, so let's now implement this very briefly
03:35:36.800 | and incorporate it into our script.
03:35:38.400 | Okay, so what I've done here is I've introduced a new file
03:35:41.600 | called hellaswag.py that you can take a look into.
03:35:44.880 | And I'm not going to step through all of it
03:35:46.560 | because this is not exactly like deep code.
03:35:50.880 | It's kind of like a little bit tedious, honestly,
03:35:52.880 | because what's happening is I'm downloading Hellaswag from GitHub
03:35:56.400 | and I'm rendering all of its examples.
03:35:58.160 | And there are a total of 10,000 examples.
03:36:00.160 | I am rendering them into this format.
03:36:02.640 | And so here at the end of this render example function,
03:36:08.800 | you can see that I'm returning the tokens.
03:36:12.080 | that is, this four-by-T array of tokens,
03:36:17.920 | the mask, which tells us which parts are the options
03:36:20.800 | and everything else is zero,
03:36:22.480 | and the label that is the correct label.
03:36:25.360 | And so that allows us to then iterate the examples
03:36:28.240 | and render them.
03:36:29.120 | And I have an evaluate function here,
03:36:30.720 | which can load a GPT-2 from Hugging Face
03:36:34.960 | and it runs the eval here.
03:36:36.480 | And basically just calculates, just as I described,
03:36:41.840 | it predicts the option that has
03:36:44.720 | the highest probability.
03:36:45.840 | And the way to do that actually
03:36:48.080 | is we can basically evaluate the cross entropy loss.
03:36:50.560 | So we're basically evaluating the loss
03:36:53.120 | of predicting the next token in a sequence.
03:36:55.280 | And then we're looking at the row
03:36:57.040 | that has the lowest average loss.
03:36:59.200 | And that's the option that we pick as the prediction.
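The scoring rule can be sketched without a model by taking per-position log-probabilities as given. `logprobs[i][t]` is assumed to be the model's log-probability of target token `t+1` in row `i`, meaning the logits are already shifted against the targets the same way as in the cross-entropy loss. The mask is shifted the same way. The numbers in the test are made up. Per-token loss is the negative log-probability. It is averaged over the completion positions only, and the row with the lowest average loss (equivalently, the highest average log-probability) is the prediction.

```python
from typing import List


def predict_option(logprobs: List[List[float]], mask: List[List[int]]) -> int:
    """Return the index of the row with the lowest average completion loss.

    logprobs[i][t]: log-prob of the target at shifted position t in row i.
    mask[i][t]: 1 if that target token belongs to the candidate ending.
    Both are already shifted, so they share the same shape."""
    avg_losses = []
    for lp_row, m_row in zip(logprobs, mask):
        # Cross-entropy per token is -log p(target); keep only ending tokens.
        losses = [-lp for lp, m in zip(lp_row, m_row) if m]
        avg_losses.append(sum(losses) / len(losses))
    return min(range(len(avg_losses)), key=avg_losses.__getitem__)
```

Averaging, rather than summing, keeps longer endings from being penalized just for having more tokens.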
03:37:04.960 | And then we do some stats and prints and stuff like that.
03:37:07.520 | So that is a way to evaluate Hellaswag.
03:37:09.840 | Now, if you go up here, I'm showing that for GPT-2 124M,
03:37:13.920 | if you run this script,
03:37:15.520 | you're going to see that Hellaswag gets 29.55%.
03:37:18.560 | So that's the performance we get here.
03:37:22.480 | Now, remember that random chance is 25%.
03:37:24.320 | So we haven't gone too far.
03:37:26.480 | And GPT-2 XL, which is the biggest GPT-2,
03:37:30.640 | gets all the way up to 49% roughly.
03:37:33.600 | So these are pretty low values,
03:37:35.840 | considering that today's state of the art
03:37:37.520 | is more like 95%.
03:37:39.040 | So these are definitely older models by now.
03:37:40.960 | And then there's one more thing called the EleutherAI evaluation harness,
03:37:44.000 | which is a very common piece of infrastructure
03:37:46.240 | for running evals for language models.
03:37:48.080 | And they get slightly different numbers.
03:37:50.160 | And I'm not 100% sure what the discrepancy is for these.
03:37:52.720 | It could be that they actually do the multiple choice
03:37:56.320 | instead of just the completions.
03:37:58.720 | And that could be the discrepancy.
03:38:02.080 | But I'm not 100% sure about that.
03:38:03.760 | I'd have to take a look.
03:38:04.880 | But for now, our script reports 29.55.
03:38:07.920 | And so that is the number that we'd like to beat
03:38:10.000 | if we were training a GPT-2 124M from scratch ourselves.
03:38:13.600 | So now I'm going to go into actually incorporating
03:38:20.480 | this eval into our main training script.
03:38:23.600 | And basically, because we want to evaluate it
03:38:27.600 | in a periodic manner
03:38:28.880 | so that we can track Hellaswag and how it evolves over time
03:38:32.000 | and see when and if we cross this 29.55 region.
03:38:39.840 | So let's now walk through some of the changes
03:38:42.000 | to train_gpt2.py.
03:38:43.280 | The first thing I did here
03:38:44.960 | is I actually made use_compile optional, kind of.
03:38:48.000 | And I disabled it by default.
03:38:49.920 | And the problem with compile
03:38:54.080 | is that, while it does make our code faster,
03:38:56.400 | it unfortunately breaks the evaluation code
03:38:58.320 | and the sampling code.
03:38:59.360 | It gives me a very gnarly message.
03:39:00.720 | And I don't know why.
03:39:01.840 | So hopefully, by the time you get to the code base
03:39:04.560 | when I put it up on GitHub,
03:39:05.600 | we're going to have fixed that by then.
03:39:07.360 | But for now, I'm running without torch.compile,
03:39:09.280 | which is why you see this be a bit slower.
03:39:11.280 | So we're running without torch.compile.
03:39:13.760 | I also created a log directory, log,
03:39:16.800 | where we can place our log.txt,
03:39:19.520 | which will record the train loss, validation loss,
03:39:22.720 | and the Hellaswag accuracies.
03:39:24.480 | So a very simple text file.
03:39:25.760 | And we're going to open for writing
03:39:28.240 | so that it sort of starts empty.
03:39:30.240 | And then we're going to append to it.
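The log-file pattern works like this: open the file once in mode `"w"` to truncate it, then reopen it in mode `"a"` for each record. The `step kind value` record format and the helper names below are assumptions for illustration. A plotting cell can then parse the file back and group the values by kind.

```python
import os
from collections import defaultdict
from typing import Dict, List, Tuple


def init_log(log_dir: str) -> str:
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "log.txt")
    with open(log_file, "w"):  # opening for writing truncates: start empty
        pass
    return log_file


def log_metric(log_file: str, step: int, kind: str, value: float) -> None:
    with open(log_file, "a") as f:  # append one line per measurement
        f.write(f"{step} {kind} {value:.4f}\n")


def parse_log(log_file: str) -> Dict[str, List[Tuple[int, float]]]:
    # Group (step, value) pairs by kind, e.g. "train", "val", "hella".
    series: Dict[str, List[Tuple[int, float]]] = defaultdict(list)
    with open(log_file) as f:
        for line in f:
            step, kind, value = line.split()
            series[kind].append((int(step), float(value)))
    return dict(series)
```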
03:39:31.520 | I created a simple variable that helps tell us
03:39:36.640 | when we are at the last step.
03:39:37.920 | And then basically, periodically inside this loop,
03:39:41.360 | every 250th iteration or at the last step,
03:39:45.280 | we're going to evaluate the validation loss.
03:39:47.040 | And then every 250th iteration,
03:39:50.320 | we are going to evaluate Hellaswag.
03:39:54.560 | But only if we are not using compile
03:39:57.360 | because compile breaks it.
03:39:59.360 | So I'm going to come back to this code
03:40:01.040 | for evaluating Hellaswag in a second.
03:40:02.640 | And then every 250th iteration as well,
03:40:06.000 | we're also going to sample from the model.
03:40:07.920 | And so you should recognize this as our ancient code
03:40:10.480 | from way back when we started the video.
03:40:12.720 | And we're just sampling from the model.
03:40:14.080 | And then finally here,
03:40:18.720 | after we validate, sample, and evaluate Hellaswag,
03:40:22.720 | we actually do a training step here.
03:40:25.040 | And so this is one step of training.
03:40:27.760 | And you should be pretty familiar with all of what this does.
03:40:30.320 | And at the end here, once we get our training loss,
03:40:33.440 | we write it to the file.
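The control flow of one loop iteration can be written as a small schedule function. The 250-step interval, the last-step check for validation, and the `use_compile` gate on Hellaswag and sampling follow the description above. `max_steps` and the action names are placeholders.

```python
from typing import List


def actions_for_step(step: int, max_steps: int, eval_every: int = 250,
                     use_compile: bool = False) -> List[str]:
    """Return, in order, what the training loop does at a given step."""
    last_step = step == max_steps - 1
    actions: List[str] = []
    if step % eval_every == 0 or last_step:
        actions.append("val_loss")
    if step % eval_every == 0 and not use_compile:
        # Skipped when compiled, since compile was breaking eval and sampling.
        actions += ["hellaswag", "sample"]
    actions.append("train_step")  # the optimizer step happens every iteration
    return actions
```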
03:40:34.480 | So the only thing that I really added
03:40:37.440 | is this entire section for Hellaswag eval.
03:40:39.360 | And the way this works is I'm trying to get
03:40:42.240 | all the GPUs to collaborate on the Hellaswag.
03:40:44.400 | And so we're iterating on the examples.
03:40:46.640 | And then each process only picks the examples
03:40:51.200 | that assigned to it.
03:40:53.200 | So we sort of take i and mod it by the world size.
03:40:55.680 | And we have to make it equal to rank.
03:40:57.520 | Otherwise, we continue.
03:40:58.480 | And then we render an example, put it on a GPU.
03:41:01.920 | We get the logits.
03:41:03.920 | Then I create a helper function that helps us basically
03:41:06.560 | predict the option with the lowest loss.
03:41:08.880 | So this comes here, the prediction.
03:41:11.120 | And then if it's correct, we sort of keep count.
03:41:13.840 | And then if multiple processes were collaborating
03:41:17.040 | on all of this, then we need to synchronize their stats.
03:41:19.920 | And so the one way to do that is to package up
03:41:22.480 | our statistics here into tensors.
03:41:25.600 | which we can then call dist.all_reduce on, with a sum.
03:41:28.640 | And then here we sort of unwrap them from tensors
03:41:33.920 | so that we just have ints.
03:41:35.040 | And then here, the master process will print
03:41:37.920 | and log the Hellaswag accuracy.
03:41:39.440 | So that's kind of it.
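The distributed bookkeeping can be simulated in a single process. Each rank keeps only the examples with `i % world_size == rank` and counts its own correct predictions. The per-rank counts are then summed, which is what `dist.all_reduce(..., op=dist.ReduceOp.SUM)` does across real processes. The `predict` callable and the toy data in the test are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple


def rank_stats(examples: Sequence[Tuple[object, int]], rank: int, world_size: int,
               predict: Callable[[object], int]) -> Tuple[int, int]:
    """Evaluate only this rank's share of the examples.

    Returns (num_total, num_correct) for this rank."""
    num_total = num_correct = 0
    for i, (example, label) in enumerate(examples):
        if i % world_size != rank:
            continue  # another rank handles this example
        num_total += 1
        num_correct += int(predict(example) == label)
    return num_total, num_correct


def all_reduce_sum(per_rank: List[Tuple[int, int]]) -> Tuple[int, int]:
    # Stand-in for an all-reduce with SUM: every rank ends up with the totals.
    return sum(t for t, _ in per_rank), sum(c for _, c in per_rank)
```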
03:41:44.800 | And that's what I'm running right here.
03:41:47.120 | So you see this optimization here.
03:41:48.640 | And we just had a generation.
03:41:52.240 | And this is step 10,000 out of about 20,000, right?
03:41:55.200 | So we are halfway done.
03:41:56.800 | And these are the kinds of samples that we are getting
03:41:59.680 | at this stage.
03:42:00.320 | So let's take a look.
03:42:01.040 | Hello, I'm a language model.
03:42:03.840 | So I'd like to use it to generate some kinds of output.
03:42:06.160 | Hello, I'm a language model.
03:42:08.160 | And I'm a developer for a lot of companies.
03:42:09.920 | Hello, I'm a language model.
03:42:12.160 | Let's see if I can find any fun one.
03:42:16.320 | I don't know.
03:42:16.880 | You can go through this yourself.
03:42:18.080 | But certainly, the predictions are getting less and less random.
03:42:20.960 | It seems like the model is a little bit more self-aware
03:42:24.000 | and using language that is a bit more specific to it
03:42:29.120 | being a language model.
03:42:30.080 | Hello, I'm a language model.
03:42:32.560 | And like the model, I'm going to use it to generate
03:42:34.800 | some kind of output.
03:42:44.160 | Hello, I'm a language model.
03:42:45.200 | And like how the language is used to communicate,
03:42:47.600 | I'm a language model.
03:42:48.560 | And I'm going to be speaking English and German.
03:42:51.520 | Okay, I don't know.
03:42:52.960 | So let's just wait until this optimization finishes.
03:42:55.680 | And we'll see what kind of samples we get.
03:42:57.760 | And we're also going to look at the train, val,
03:43:01.360 | and the Hellaswag accuracy
03:43:03.120 | and see how we're doing with respect to GPT-2.
03:43:05.040 | Okay, good morning.
03:43:07.840 | So focusing for a moment on the Jupyter Notebook
03:43:10.880 | here on the right,
03:43:12.000 | I created a new cell that basically allows us
03:43:14.480 | to visualize the train loss, the val loss, and the Hellaswag accuracy.
03:43:19.280 | And you can step through this.
03:43:20.960 | It basically like parses the log file that we are writing.
03:43:23.520 | And a lot of this is just like boring matplotlib code.
03:43:27.840 | But basically, this is what our optimization looks like.
03:43:30.480 | So we ran for 19,073 steps,
03:43:36.720 | which is roughly 10 billion tokens,
03:43:42.160 | which is one epoch of the 10B sample of FineWeb-EDU.
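Here is a quick check of the "roughly 10 billion tokens" figure. It assumes a total batch of 2**19 = 524,288 tokens per optimizer step, the ~0.5M-token batch set up earlier in the video. That batch size is not restated in this section.

```python
TOKENS_PER_STEP = 2**19  # assumed total batch size in tokens (~0.5M)
steps = 19_073
total_tokens = steps * TOKENS_PER_STEP
print(f"{total_tokens:,}")  # prints 9,999,745,024, i.e. ~10B tokens, one pass over the 10B sample
```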
03:43:45.600 | On the left, we have the loss.
03:43:47.680 | And in the blue, we have the training loss.
03:43:51.040 | In orange, we have the validation loss.
03:43:53.120 | And in red, as a horizontal line,
03:43:55.440 | we have the OpenAI GPT-2 124M model checkpoint,
03:43:59.360 | when it's just evaluated on the validation set
03:44:01.520 | of FineWeb-EDU.
03:44:04.640 | So you can see that we are surpassing it:
03:44:07.760 | this orange is below the red.
03:44:09.200 | So we're surpassing it on the validation set of this dataset.
03:44:12.560 | And like I mentioned, the dataset distribution
03:44:14.640 | is very different from what GPT-2 trained on.
03:44:16.640 | So this is not an exactly fair comparison,
03:44:18.880 | but it's a good cross-check to look at.
03:44:22.640 | Now, we would ideally like something
03:44:25.200 | that is withheld and comparable and somewhat standard.
03:44:28.960 | And so for us, that is Hellaswag.
03:44:32.320 | And so here, we see the Hellaswag progress
03:44:34.800 | we made from 25% all the way up here.
03:44:38.080 | In red, we see the OpenAI GPT-2 124M model.
03:44:43.200 | So it achieves this Hellaswag accuracy here.
03:44:45.360 | And the GPT-3 124M model,
03:44:49.360 | which was trained on 300 billion tokens, is shown in green.
03:44:53.040 | So that's over here.
03:44:55.520 | So you see that we basically surpassed
03:44:57.280 | the GPT-2 124M model right here, which is really nice.
03:45:03.680 | Now, interestingly, we were able to do so
03:45:07.280 | with only training on 10 billion tokens,
03:45:09.280 | while GPT-2 was trained on 100 billion tokens.
03:45:12.160 | So for some reason, we were able to get away
03:45:15.360 | with significantly fewer tokens for training.
03:45:17.760 | There are many possibilities as to why we could match
03:45:21.440 | or surpass this accuracy with only 10 billion training tokens.
03:45:25.760 | So number one, it could be that OpenAI GPT-2
03:45:29.840 | was trained on a much wider data distribution.
03:45:32.800 | So in particular, FineWeb-EDU is all English.
03:45:36.400 | It's not multilingual.
03:45:38.080 | And there's not that much math and code.
03:45:40.000 | And so math and code and multilingual could have been
03:45:44.240 | stealing capacity from the original GPT-2 model.
03:45:48.080 | And basically, that could be partially the reason
03:45:52.000 | why GPT-2 doesn't come out ahead here.
03:45:54.080 | There's many other reasons.
03:45:55.120 | So for example, the Hellaswag eval is fairly old,
03:45:58.560 | maybe five years or so.
03:45:59.920 | It is possible that aspects of Hellaswag in some way,
03:46:03.200 | or even identically, have made it into the training set
03:46:06.960 | of FineWeb. We don't know for sure.
03:46:09.680 | But if that was the case, then we are basically looking
03:46:11.600 | at the training curve instead of the validation curve.
03:46:13.440 | So long story short, this is not a perfect eval.
03:46:16.320 | And there's some caveats here.
03:46:17.440 | But at least we have some confidence
03:46:19.920 | that we're not doing something completely wrong.
03:46:21.840 | And it's probably the case that when people try
03:46:26.960 | to create these data sets, they try to make sure
03:46:28.640 | that test sets that are very common
03:46:30.960 | are not part of the training set.
03:46:32.720 | For example, when Hugging Face created FineWeb-EDU,
03:46:35.520 | they used Hellaswag as an eval.
03:46:37.120 | So I would hope that they made sure that they deduplicated
03:46:40.080 | and that there's no Hellaswag in the training set.
03:46:42.640 | But we can't be sure.
03:46:43.680 | The other thing I wanted to address briefly is,
03:46:47.040 | look at this loss curve.
03:46:48.160 | This looks really wrong here.
03:46:50.880 | I don't actually know 100% what this is.
03:46:52.960 | And I suspect it's because the 10 billion sample
03:46:59.760 | of FineWeb-EDU was not properly shuffled.
03:46:59.760 | And there's some issue here with the data
03:47:04.320 | that I don't fully understand yet.
03:47:05.680 | And there's some weird periodicity to it.
03:47:07.520 | And because we are in a very lazy way,
03:47:10.320 | sort of serializing all the tokens
03:47:11.840 | and just iterating on them from scratch
03:47:13.600 | without doing any permutations
03:47:15.040 | or any random sampling ourselves,
03:47:17.200 | I think we're inheriting some of the ordering
03:47:20.560 | that they have in the data set.
03:47:22.000 | So this is not ideal.
03:47:24.800 | But hopefully by the time you get to this repo,
03:47:27.840 | some of these things, by the way,
03:47:29.280 | will hopefully be fixed.
03:47:31.040 | And I will release this build-nanogpt repo.
03:47:34.800 | And right now it looks a little ugly and preliminary.
03:47:37.840 | So hopefully by the time you get here, it's nicer.
03:47:40.640 | But down here, I'm going to show errata.
03:47:42.960 | And I'm going to talk about some of the things
03:47:45.520 | that happened after the video.
03:47:47.200 | And I expect that we will have fixed the small issue.
03:47:50.160 | But for now, basically, this shows that our training
03:47:53.920 | is not completely wrong.
03:47:55.840 | And it shows that we're able to surpass the accuracy
03:47:59.840 | with only one-tenth of the token budget.
03:48:01.440 | And possibly it could be also that the data set
03:48:05.840 | may have improved.
03:48:07.200 | So the original GPT-2 data set was WebText.
03:48:10.880 | It's possible that not a lot of care and attention
03:48:13.120 | went into the data set.
03:48:14.320 | This was very early in LLMs.
03:48:16.400 | Whereas now there's a lot more scrutiny
03:48:18.160 | on good practices around deduplication, filtering,
03:48:21.920 | quality filtering, and so on.
03:48:23.360 | And it's possible that the data set we're training on
03:48:25.120 | is just of higher quality per token.
03:48:27.120 | And that could be giving us a boost as well.
03:48:29.600 | So a number of caveats to think about.
03:48:31.200 | But for now, we're pretty happy with this.
03:48:33.280 | And yeah.
03:48:35.360 | Now, the next thing I was interested in is,
03:48:37.600 | as you see, it's morning now.
03:48:39.200 | So I had a whole night to run something.
03:48:40.720 | And I wanted to basically see how far
03:48:42.480 | I could push the result.
03:48:43.840 | So to do an overnight run, I basically did,
03:48:47.360 | instead of one epoch, which took roughly two hours,
03:48:50.000 | I just did it times four.
03:48:51.280 | So that that would take eight hours while I was sleeping.
03:48:53.680 | And so we did four epochs or roughly 40 billion
03:48:56.800 | tokens of training.
03:48:58.320 | And I was trying to see how far we could get.
03:49:00.320 | And so this was the only change.
03:49:02.560 | And I re-ran the script.
03:49:03.920 | And when I point the notebook at the log file of the 40B run,
03:49:07.680 | this is what the curve looked like.
03:49:12.000 | So to narrate this, number one,
03:49:13.520 | we are seeing this issue here with the periodicity
03:49:16.560 | through the different epochs and something really weird
03:49:18.800 | with the FineWeb-EDU data set.
03:49:21.120 | And that is to be determined.
03:49:22.720 | But otherwise, we are seeing that the Hellaswag
03:49:26.560 | actually went up by a lot.
03:49:28.640 | And we almost made it to the GPT-3 124M accuracy up here,
03:49:34.880 | but not quite.
03:49:36.320 | So it's too bad that I didn't sleep slightly longer.
03:49:39.120 | And I think if this was a five epoch run,
03:49:44.320 | we may have gotten here.
03:49:45.680 | Now, one thing to point out is that if you're
03:49:48.240 | doing multi-epoch runs, we're not actually
03:49:51.200 | being very careful in our data loader.
03:49:52.800 | And this data loader goes through the data
03:49:57.360 | in exactly the same format and exactly the same order.
03:50:01.600 | And this is kind of suboptimal.
03:50:03.120 | And you would want to look into extensions
03:50:04.880 | where you actually permute the data randomly.
03:50:07.920 | You permute the documents around in every single shard
03:50:10.960 | on every single new epoch and potentially even permute
03:50:15.280 | the shards.
03:50:15.840 | And that would go a long way into decreasing the periodicity.
03:50:19.520 | And it's also better for the optimization
03:50:21.600 | so that you're not seeing things in the identical format.
03:50:24.960 | And you're introducing some of the randomness
03:50:27.680 | in how the documents follow each other.
03:50:29.600 | Because you have to remember that in every single row,
03:50:32.000 | these documents follow each other.
03:50:33.440 | And then there's the end of text token
03:50:34.880 | and then the next document.
03:50:36.240 | So the documents are currently glued together
03:50:38.960 | in the exact same identical manner.
03:50:41.280 | But we actually want to break up the documents
03:50:43.680 | and shuffle them around.
03:50:45.040 | Because the order of the documents shouldn't matter.
03:50:47.280 | And they shouldn't-- basically, we
03:50:49.360 | want to break up that dependence.
03:50:50.800 | Because it's kind of a spurious correlation.
03:50:52.720 | And so our data loader is not currently doing that.
03:50:55.440 | And that's one improvement you could think of making.
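Here is a minimal sketch of per-epoch shuffling. It assumes, hypothetically, that each shard can be held as a list of documents rather than as the single flat token array the lecture's loader reads. It permutes the shard order and the documents within each shard using an RNG seeded from the epoch, so every epoch sees a different but reproducible order. `epoch_order`, `pack`, and the seed value are illustrative names and values, not part of the lecture's code.

```python
import random


def epoch_order(shards, epoch, seed=1337):
    """Return [(shard_name, docs), ...] in a fresh, reproducible order for this epoch.

    `shards` maps shard name -> list of documents (each a list of token ids).
    """
    rng = random.Random(seed * 1_000_003 + epoch)  # different stream per epoch
    names = sorted(shards)
    rng.shuffle(names)  # permute the order of the shards
    ordered = []
    for name in names:
        docs = list(shards[name])  # copy so the caller's data is not mutated
        rng.shuffle(docs)  # permute the documents inside this shard
        ordered.append((name, docs))
    return ordered


def pack(ordered, eot):
    """Concatenate documents into one token stream, each preceded by the EOT token."""
    stream = []
    for _, docs in ordered:
        for doc in docs:
            stream.append(eot)
            stream.extend(doc)
    return stream


# Toy shards with made-up token ids; 50256 is GPT-2's <|endoftext|> id.
shards = {"shard_000": [[1, 2], [3], [4, 5, 6]], "shard_001": [[7], [8, 9]]}
for epoch in range(2):
    print(epoch, pack(epoch_order(shards, epoch), eot=50256))
```

The packed stream would still be cut into (B, T) rows as before. The only change is which document follows which, which is what breaks up the spurious ordering correlation.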
03:50:57.920 | The other thing to point out is we're almost matching
03:51:02.080 | GPT-3 accuracy with only 40 billion tokens.
03:51:04.880 | GPT-3 trained on 300 billion tokens.
03:51:07.840 | So again, we're seeing roughly an order-of-magnitude improvement here
03:51:11.360 | (300B versus 40B tokens, about 7.5x) with respect to learning efficiency.
03:51:13.360 | And I don't actually
03:51:16.880 | know exactly what to attribute this to, other than some
03:51:19.120 | of the things that I already mentioned
03:51:21.200 | for the previous run.
03:51:21.920 | The other thing I wanted to briefly mention
03:51:24.480 | is the max LR here.
03:51:27.600 | I saw some people already play with this a little bit
03:51:30.160 | in a previous related repository.
03:51:32.000 | And it turns out that you can actually almost 3x this.
03:51:36.000 | So it's possible that the maximum learning rate
03:51:37.600 | can be a lot higher.
03:51:38.800 | And for some reason, the GPT-3 hyperparameters
03:51:40.880 | that we are inheriting are actually
03:51:42.480 | extremely conservative.
03:51:43.760 | And you can actually get away with a higher learning rate.
03:51:45.680 | And it would train faster.
03:51:47.360 | So a lot of these hyperparameters are quite tunable.
03:51:51.360 | And feel free to play with them.
03:51:52.800 | And they're probably not set precisely correctly.
03:51:55.840 | And it's possible that you can get away
03:51:59.600 | with doing this, basically.
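To make the "just raise the max LR" experiment concrete, here is a sketch of a warmup plus cosine-decay schedule where the peak learning rate is a single argument. The defaults are assumptions chosen to mirror a one-epoch 10B-token run:

- `max_lr=6e-4`
- 715 warmup steps
- 19,073 total steps
- decay down to 10% of the peak

The roughly 3x variant simply passes a larger `max_lr`.

```python
import math


def get_lr(step, max_lr=6e-4, warmup_steps=715, max_steps=19073):
    """Linear warmup, then cosine decay down to 10% of max_lr (assumed defaults)."""
    min_lr = max_lr * 0.1
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps  # linear warmup
    if step > max_steps:
        return min_lr  # hold at the floor once the schedule ends
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


# Trying a ~3x higher peak changes one argument; the schedule's shape is unchanged.
checkpoints = (0, 715, 19073)
baseline = [get_lr(s) for s in checkpoints]
aggressive = [get_lr(s, max_lr=3 * 6e-4) for s in checkpoints]
print(baseline)
print(aggressive)
```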
03:52:00.800 | And if you wanted to exactly be faithful to GPT-3,
03:52:05.200 | you would also want to make the following difference.
03:52:09.760 | You'd want to come here.
03:52:11.120 | And the sequence length of GPT-3 is 2x.
03:52:13.760 | It's 2,048 instead of 1,024.
03:52:16.240 | So you would come here and change T to 2,048.
03:52:19.600 | And then if you want the exact same number of tokens,
03:52:21.520 | half a million per iteration or per step,
03:52:25.440 | you then want to decrease B to 32,
03:52:27.200 | so that they still multiply out to half a million.
03:52:29.440 | So that would give your model sequence length
03:52:33.120 | equal to that of GPT-3.
03:52:35.280 | And in that case, basically, the models
03:52:38.800 | would be roughly identical as far as I'm aware.
03:52:42.560 | Because again, GPT-2 and GPT-3 are very, very similar models.
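The token arithmetic behind this change can be sketched as follows. It assumes a per-GPU micro-batch of B=64 at T=1024 on 8 GPUs; these are assumed values. Halving B while doubling T keeps B * T unchanged, and with it the gradient-accumulation count and the ~0.5M tokens per step. The model's `block_size` (the length of its positional-embedding table) would also need to be at least 2,048 for this to work.

```python
def grad_accum_steps(total_batch_tokens, B, T, world_size):
    """Micro-steps per optimizer step so that each step sees total_batch_tokens."""
    tokens_per_micro_step = B * T * world_size
    if total_batch_tokens % tokens_per_micro_step != 0:
        raise ValueError("total batch must be divisible by B * T * world_size")
    return total_batch_tokens // tokens_per_micro_step


TOTAL = 2**19  # 524,288 tokens, the "half a million" per step
gpt2_setup = grad_accum_steps(TOTAL, B=64, T=1024, world_size=8)  # assumed setup
gpt3_setup = grad_accum_steps(TOTAL, B=32, T=2048, world_size=8)  # GPT-3 context length
print(gpt2_setup, gpt3_setup)  # 1 1
```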
03:52:46.240 | Now, we can also look at some of the samples here
03:52:48.080 | from the model that was trained overnight.
03:52:50.720 | So this is the optimization.
03:52:54.640 | And you see that here, we stepped all the way
03:52:56.320 | to 76,290 or so.
03:52:59.440 | And the HellaSwag accuracy we achieved was 33.24%.
03:53:04.400 | And these are some of the samples from the model.
03:53:07.680 | And you can see that if you read through this
03:53:10.240 | and pause the video briefly, you can see
03:53:11.840 | that they're a lot more coherent.
03:53:14.160 | And they're actually almost addressing the fact
03:53:17.440 | that it's a language model.
03:53:19.760 | So hello, I'm a language model.
03:53:24.400 | And I try to be as accurate as possible.
03:53:25.920 | I'm a language model, not a programming language.
03:53:30.400 | I know how to communicate.
03:53:33.760 | I use Python.
03:53:35.360 | I don't know.
03:53:39.200 | If you pause this and look at it and then compare it
03:53:41.280 | to the model that was only trained
03:53:43.440 | for 10 billion tokens, you will see that these
03:53:45.600 | are a lot more coherent.
03:53:46.720 | And you can play with this yourself.
03:53:48.240 | One more thing I added to the code, by the way,
03:53:51.040 | is this chunk of code here.
03:53:52.720 | So basically, right after we evaluate the validation loss,
03:53:56.240 | if we are the master process, in addition
03:53:58.640 | to logging the validation loss, every 5,000 steps,
03:54:01.520 | we're also going to save the checkpoint,
03:54:03.440 | which is really just the state dictionary of the model.
03:54:06.640 | And so checkpointing is nice just
03:54:08.240 | because you can save the model.
03:54:10.000 | And later, you can use it in some way.
03:54:12.320 | If you wanted to resume the optimization,
03:54:14.960 | then in addition to saving the model,
03:54:16.880 | we have to also save the optimizer state dict.
03:54:20.320 | Because remember that the optimizer
03:54:21.680 | has a few additional buffers because of Adam.
03:54:24.400 | So it's got the m and v.
03:54:26.400 | And you need to also resume the optimizer properly.
03:54:30.240 | You have to be careful with the RNG seeds, random number
03:54:32.800 | generators, and so on.
03:54:34.080 | So if you wanted to exactly be able to resume optimization,
03:54:37.040 | you have to think through the state of the training process.
03:54:40.800 | But if you just want to save the model,
03:54:42.080 | this is how you would do it.
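Here is a minimal sketch of the fuller, resumable checkpoint described above, assuming a PyTorch model and an AdamW optimizer. Beyond the model `state_dict`, it also stores the optimizer state (Adam's m and v buffers) and the CPU and CUDA RNG states. The function names `save_checkpoint` and `load_checkpoint`, the toy `nn.Linear` model, and the `val_loss` value are hypothetical. It does not capture the data loader's position, which an exact resume would also need.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, step, val_loss):
    """Save everything needed to resume training, not just the weights."""
    checkpoint = {
        "model": model.state_dict(),          # parameters and buffers
        "optimizer": optimizer.state_dict(),  # AdamW's m and v buffers and step counts
        "step": step,
        "val_loss": val_loss,
        "cpu_rng_state": torch.get_rng_state(),  # so random draws continue identically
    }
    if torch.cuda.is_available():
        checkpoint["cuda_rng_states"] = torch.cuda.get_rng_state_all()
    torch.save(checkpoint, path)


def load_checkpoint(path, model, optimizer):
    """Restore model, optimizer and RNG state; return (step, val_loss)."""
    # weights_only=False because the dict also holds plain Python values
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    torch.set_rng_state(checkpoint["cpu_rng_state"])
    if "cuda_rng_states" in checkpoint and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(checkpoint["cuda_rng_states"])
    return checkpoint["step"], checkpoint["val_loss"]


# Usage: one AdamW step on a toy model, checkpoint it, restore into fresh objects.
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
model(torch.randn(8, 4)).pow(2).mean().backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "model_05000.pt")
save_checkpoint(path, model, optimizer, step=5000, val_loss=3.0)
draw_after_save = torch.rand(3)  # consumes RNG state right after saving

model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=6e-4)
step, val_loss = load_checkpoint(path, model2, optimizer2)
draw_after_load = torch.rand(3)  # RNG state was restored, so this repeats the draw
```

In a multi-GPU run, only the master process would typically write the file, as in the code shown in the video.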
03:54:43.440 | And one nice reason why you might want to do this
03:54:46.160 | is because you may want to evaluate the model a lot more
03:54:48.960 | carefully.
03:54:49.460 | So here, we are only kind of winging the HellaSwag eval.
03:54:54.080 | But you may want to use something nicer,
03:54:57.040 | like, for example, the EleutherAI evaluation harness.
03:55:05.600 | So this is a way to also evaluate language models.
03:55:09.360 | And so it's possible that you may
03:55:13.840 | want to use basically different infrastructure
03:55:15.760 | to more thoroughly evaluate the models
03:55:18.320 | on different evaluations and compare it
03:55:21.120 | to the OpenAI GPT-2 model on many other tasks,
03:55:25.280 | like, for example, that involve math, code,
03:55:27.040 | or different languages, and so on.
03:55:28.320 | So this is a nice functionality to have as well.
03:55:30.640 | And then the other thing I wanted to mention
03:55:35.040 | is that everything we've built here,
03:55:36.960 | this is only the pre-training step.
03:55:39.280 | So the GPT here
03:55:42.080 | just dreams documents.
03:55:43.200 | It just predicts the next token.
03:55:44.720 | You can't talk to it like you can talk to ChatGPT.
03:55:48.080 | If you wanted to talk to the model,
03:55:50.480 | we have to fine-tune it into the chat format.
03:55:53.440 | And it's not actually that complicated.
03:55:55.120 | If you're looking at supervised fine-tuning or SFT,
03:55:58.000 | really what that means is we're just swapping out the data set
03:56:00.720 | into a data set that is a lot more conversational.
03:56:02.960 | And there's a user-assistant, user-assistant kind
03:56:04.960 | of structure.
03:56:05.840 | And we just fine-tune on it.
03:56:07.200 | And then we basically fill in the user tokens,
03:56:10.560 | and we sample the assistant tokens.
03:56:12.880 | It's not much deeper than that.
03:56:14.480 | But basically, we swap out the data set
03:56:16.560 | and continue training.
03:56:17.360 | But for now, we're going to stop at pre-training.
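Here is a minimal sketch of the "swap out the data set" idea, using hypothetical role markers. Conversations are rendered into plain documents for fine-tuning. At chat time, the user turns are filled in and the model is cued to continue as the assistant. GPT-2's tokenizer only defines `<|endoftext|>`, so the `<|user|>` and `<|assistant|>` markers here are assumptions. Many SFT setups also mask the loss so that only assistant tokens are trained on; this sketch leaves that out.

```python
# Hypothetical role markers: GPT-2's tokenizer only defines <|endoftext|>,
# so a real chat format would add its own special tokens.
MARKERS = {"user": "<|user|>", "assistant": "<|assistant|>"}
EOT = "<|endoftext|>"


def render_turns(turns):
    """Flatten (role, text) turns into one string, tagging each turn with its role."""
    parts = []
    for role, text in turns:
        if role not in MARKERS:
            raise ValueError(f"unknown role: {role!r}")
        parts.append(MARKERS[role] + text)
    return "".join(parts)


def render_training_example(turns):
    """A finished conversation becomes one document to fine-tune on."""
    return render_turns(turns) + EOT


def render_prompt(turns):
    """At chat time: fill in the conversation so far, then cue the assistant.

    The model would then sample until it emits EOT or starts a new user turn.
    """
    return render_turns(turns) + MARKERS["assistant"]


convo = [("user", "What is 2+2?"), ("assistant", "4.")]
print(render_training_example(convo))
# <|user|>What is 2+2?<|assistant|>4.<|endoftext|>
print(render_prompt([("user", "Hi!")]))
# <|user|>Hi!<|assistant|>
```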
03:56:20.720 | One more thing that I wanted to briefly show you
03:56:23.040 | is that, of course, what we've built up today
03:56:25.760 | was building towards NanoGPT,
03:56:27.520 | which is this repository from earlier.
03:56:30.080 | But also, there's actually another NanoGPT implementation,
03:56:32.880 | and it's hiding in a more recent project
03:56:35.200 | that I've been working on called llm.c.
03:56:37.360 | And llm.c is a pure C/CUDA implementation
03:56:42.160 | of GPT-2 or GPT-3 training.
03:56:44.080 | It directly uses CUDA and is written in C/CUDA.
03:56:49.040 | Now, the NanoGPT here acts as reference code in PyTorch
03:56:53.200 | to the C implementation.
03:56:54.560 | So we're trying to exactly match up the two,
03:56:56.800 | but we're hoping that the C/CUDA version is faster
03:56:59.440 | and, of course, currently that seems to be the case
03:57:01.280 | because it is a direct optimized implementation.
03:57:04.400 | So train_gpt2.py in llm.c is basically the NanoGPT.
03:57:09.600 | And when you scroll through this file,
03:57:12.000 | you'll find a lot of things that very much look like
03:57:15.200 | things that we've built up in this lecture.
03:57:19.360 | And then when you look at train_gpt2.cu,
03:57:22.480 | this is the C/CUDA implementation.
03:57:26.080 | So there's a lot of MPI, NCCL, GPU, CUDA, C, C++.
03:57:31.040 | And you have to be familiar with that.
03:57:33.040 | But when this is built up,
03:57:37.120 | we can actually run the two side by side
03:57:39.440 | and they're going to produce the exact same results,
03:57:41.600 | but llm.c actually runs faster.
03:57:43.760 | So let's see that.
03:57:44.480 | So on the left, I have PyTorch, a NanoGPT-looking thing.
03:57:49.680 | On the right, I have the llm.c call.
03:57:51.840 | And here I'm gonna launch the two.
03:57:55.040 | Both of these are gonna be running on a single GPU.
03:57:57.120 | And here I'm putting the llm.c on GPU one,
03:57:59.680 | and this one will grab GPU zero by default.
03:58:02.800 | And then we can see here that llm.c compiled
03:58:08.320 | and then allocated space, and now it's stepping.
03:58:12.240 | So basically, meanwhile, PyTorch is still compiling
03:58:18.320 | because torch.compile is a bit slower here
03:58:21.120 | than the llm.c nvcc C/CUDA compile.
03:58:25.280 | And so this program has already started running
03:58:27.520 | and we're still waiting here for torch.compile.
03:58:30.400 | Now, of course, this is a very specific implementation
03:58:33.440 | to GPT-2 and 3.
03:58:35.040 | PyTorch is a very general neural network framework,
03:58:37.440 | so they're not exactly comparable.
03:58:38.960 | But if you're only interested in training GPT-2 and 3,
03:58:41.280 | llm.c is very fast, it takes less space,
03:58:45.520 | it's faster to start, and it's faster per step.
03:58:49.360 | And so PyTorch started stepping here.
03:58:53.120 | And as you can see, we're running
03:58:54.560 | at about 223,000 tokens per second with llm.c,
03:58:57.520 | and about 185,000 tokens per second with PyTorch.
03:59:00.480 | So quite a bit slower, but I don't have full confidence
03:59:05.680 | that I exactly squeezed out all the juice
03:59:08.640 | from the PyTorch implementation.
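For a sense of scale, here is what the two quoted throughputs work out to. I read them as about 223,000 tokens/s for llm.c and about 185,000 tokens/s for PyTorch. The 10B-token, single-GPU epoch is a hypothetical yardstick, not a measured run.

```python
# Throughputs as read off the two terminals in the demo (tokens per second).
llmc_tok_per_s = 223_000
pytorch_tok_per_s = 185_000

speedup = llmc_tok_per_s / pytorch_tok_per_s
print(f"llm.c speedup: {speedup:.2f}x")  # llm.c speedup: 1.21x

# Hypothetical yardstick: hours to process 10B tokens on one GPU at each rate.
TOKENS = 10_000_000_000
for name, rate in [("llm.c", llmc_tok_per_s), ("PyTorch", pytorch_tok_per_s)]:
    print(f"{name}: {TOKENS / rate / 3600:.1f} h")
# llm.c: 12.5 h
# PyTorch: 15.0 h
```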
03:59:10.240 | But the important thing to notice here
03:59:12.320 | is that if I line up the steps,
03:59:14.800 | you will see that the losses and the norms
03:59:16.880 | printed by these two are identical.
03:59:19.200 | So on the left, we have the PyTorch,
03:59:21.920 | and on the right, this C code implementation,
03:59:24.640 | and they're the same, except this one runs faster.
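Lining up the steps by eye works for a demo, but a small script can run the same check across a whole run. This sketch assumes a hypothetical log format in which each line contains `step N`, `loss: X` and `norm: Y`. The real PyTorch and llm.c outputs may be formatted differently, so the regular expression would need adapting. The sample lines are made up.

```python
import re

# Hypothetical log format, e.g. "step    12 | loss: 5.000000 | norm: 1.0000"
LINE = re.compile(r"step\s+(\d+)\s*\|\s*loss:?\s*([0-9.]+).*?norm:?\s*([0-9.]+)")


def parse_log(text):
    """Map step -> (loss, norm) for every line matching LINE."""
    records = {}
    for line in text.splitlines():
        m = LINE.search(line)
        if m:
            records[int(m.group(1))] = (float(m.group(2)), float(m.group(3)))
    return records


def mismatched_steps(log_a, log_b, tol=1e-4):
    """Steps present in both logs whose loss or grad norm differ by more than tol."""
    a, b = parse_log(log_a), parse_log(log_b)
    bad = []
    for step in sorted(a.keys() & b.keys()):
        (loss_a, norm_a), (loss_b, norm_b) = a[step], b[step]
        if abs(loss_a - loss_b) > tol or abs(norm_a - norm_b) > tol:
            bad.append(step)
    return bad


# Made-up sample lines, purely to exercise the parser.
pytorch_log = "step     1 | loss: 5.000000 | norm: 1.0000\nstep     2 | loss: 4.500000 | norm: 0.9000"
llmc_log = "step     1 | loss: 5.000000 | norm: 1.0000\nstep     2 | loss: 4.600000 | norm: 0.9000"
print(mismatched_steps(pytorch_log, llmc_log))  # [2]
```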
03:59:27.600 | I also wanted to briefly show you llm.c,
03:59:32.000 | and this is a parallel implementation,
03:59:34.000 | and it's also something that you may want
03:59:35.600 | to play with or look at, and it's kind of interesting.
03:59:39.520 | Okay, so at this point, I should probably start
03:59:41.280 | wrapping up the video, because I think
03:59:42.800 | it's getting way longer than anticipated.
03:59:45.280 | But we did cover a lot of ground,
03:59:47.040 | and we built everything from scratch.
03:59:48.720 | So as a brief summary, we were looking
03:59:51.280 | at the GPT-2 and GPT-3 papers.
03:59:54.560 | We were looking at how you set up these training runs,
03:59:58.240 | and all the considerations involved.
04:00:00.640 | We wrote everything from scratch,
04:00:02.560 | and then we saw that over the duration
04:00:04.160 | of either a two-hour training run or an overnight run,
04:00:07.280 | we can actually match the 124 million parameter checkpoints
04:00:10.800 | of GPT-2 and GPT-3 to a very large extent.
04:00:13.760 | In principle, the code that we wrote
04:00:16.560 | would be able to train even bigger models
04:00:18.480 | if you have the patience or the computing resources,
04:00:21.120 | and so you could potentially think about training
04:00:23.200 | some of the bigger checkpoints as well.
04:00:24.480 | There are a few remaining issues to address.
04:00:28.480 | What's happening with the loss here,
04:00:30.080 | which I suspect has to do
04:00:31.280 | with the FineWeb-EDU data sampling.
04:00:33.440 | Why can't we turn on torch.compile?
04:00:36.400 | It currently breaks generation and HellaSwag.
04:00:39.040 | What's up with that?
04:00:40.240 | In the data loader, we should probably be permuting our data
04:00:42.640 | when we reach epoch boundaries.
04:00:44.800 | So there's a few more issues like that,
04:00:46.400 | and I expect to be documenting some of those over time
04:00:49.040 | in the build-nanogpt repository here,
04:00:51.920 | which I'm going to be releasing with this video.
04:00:55.120 | If you have any questions
04:00:57.200 | or would like to talk about anything that we covered,
04:00:59.600 | please go to the Discussions tab so we can talk there,
04:01:02.640 | or please go to issues or pull requests
04:01:06.400 | depending on what you'd like to contribute,
04:01:08.720 | or also have a look at the Zero to Hero Discord,
04:01:13.120 | and I'm going to be hanging out here on NanoGPT.
04:01:18.000 | Otherwise, for now, I'm pretty happy about where we got,
04:01:20.960 | and I hope you enjoyed the video,
04:01:24.160 | and I will see you later.