The Ultra-Scale Handbook by HuggingFace Nanotron

...a bit on LoRA. Okay, I'll restart, we just started recording. So basically Hugging Face put out this blog post a week or two ago. It's very in-depth, the estimated reading time is two to three days, and it has really good visualizations and little interactive widgets on how to do pre-training of models. There's a quick overview at the start. The plan is basically this: we're not covering a two-day read in an hour, so we'll go through the main points of everything and try to double-click into whatever people have questions on. If anyone has thoughts or comments, jump in whenever.

A lot of it is a pretty mathematical guide on how to pre-train a language model. There are plenty of resources out there if you're doing fine-tuning, LoRA, or synthetic data generation, and a lot of libraries make the little post-training and fine-tuning jobs easy. What they wanted to tackle is the other question: everyone says it's hard to train across multiple nodes and multiple GPUs, so what is all this data parallelism and pipeline parallelism, and who's actually writing about it? Hugging Face has a cluster of 512 GPUs, they ran thousands of experiments on it, and they basically put together: hey, if you want to go pre-train a model, here's the zero-to-eighty-percent of what you should know about distributed training.
how can you efficiently scale your training so all from you can't have a model that fits on one GPU 00:01:39.920 |
to how do you have like let's say a 7b model and you want to efficiently use 10 GPUs versus how do you have 00:01:47.000 |
like multi-node so if you have a train run what if you have multiple nodes um what makes up the most of 00:01:53.540 |
like compute when you're training so activations back prop um all this stuff so TLDR they kind of kick 00:02:01.780 |
off with a basic overview of this and they're like there's pretty good open source models but they don't 00:02:08.100 |
really talk too much about these little nitty-gritty stuff um at a high level something that this blog 00:02:13.520 |
post also kind of didn't explicitly state but does show is you see all these gray dots so they ran over a 00:02:20.260 |
thousand experiments all these gray dots are failed runs so some of them were kind of failed force runs 00:02:25.180 |
like force failed where you know you overdo your bad size and you know you'll run out of memory and guess 00:02:31.540 |
what you ran out of memory but the other little interesting niche that um I've noticed is when you 00:02:39.400 |
do a lot of like multi-node multi-GPU training is sometimes a random GPU fails sometimes one GPU has 00:02:47.160 |
like a weird loss and your train run is kind of screwed so um it doesn't account for those little niches right 00:02:53.720 |
so sure there's checkpointing sure there's stuff like this but uh in practice sometimes when you're 00:02:59.520 |
training on like a thousand GPUs yeah one GPU randomly dies and it causes little issues but anyway 00:03:05.080 |
this is kind of just the high level like here's what to do so um kicking off let's go over the 00:03:13.000 |
Kicking off, let's go over the fundamentals and background. They have this cool interactive widget with presets for the Llama family of models: TinyLlama, Llama 8B, Llama 70B. It lets you visualize how much VRAM is required to do pre-training, so if you want to train the full thing at different precisions and different sequence lengths, you can see what happens to the VRAM required as you scale the sequence length up.

There are a few key parameters to know here: batch size and sequence length, which together largely determine how much GPU memory you'll use, plus how many parameters the model has (this preset is a 70B). More batches means more GPU memory needed; more sequence length means significantly more GPU memory needed. One of the takeaways is that the attention blocks grow quadratically: as I double the batch size, we quadratically, sorry, as I double the sequence length, we quadratically increase the attention cost. If we start with a sequence length of around 4,000 (oops, I set 14,000; let's do about 4,000), you can see how the pieces break down: to train Llama 70B at a 4,000 sequence length, activation memory is about 80 gigs, and parameters, gradients, and optimizer state are about 300 gigs, with the parameters themselves at roughly a one-to-two ratio if you're doing mixed precision, which you can also toggle off. Gradients and optimizer state add a lot, and people don't realize that the attention mechanism itself actually doesn't take that much at this length. But as we scale up to, say, a hundred thousand or a million tokens of context, the attention mechanism really eats up all the memory. It's a cool little visualization.

The other knobs are there too: hidden dimension, batch size, and it's nice that they give you the presets. Vocab size is interesting if you want to see what changing it does, and you can swap in different activations. Basically, you should play around with this and read through the intro; the TLDR of the blog post gives a high-level overview, and if you only read that, you'll still come away with a decent understanding of what takes memory when you train a foundation model from scratch.
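To make the widget's numbers concrete, here is a rough back-of-the-envelope calculator in the same spirit (not the blog's exact formula). It assumes mixed-precision training with Adam, roughly 2 bytes each for bf16 weights and gradients and 12 bytes per parameter of fp32 optimizer state, and uses a crude activation term that ignores recomputation and flash-attention savings; all the example numbers are hypothetical knobs to play with.

```python
def training_memory_gb(n_params_billion, batch_size, seq_len, hidden, n_layers,
                       weight_bytes=2, grad_bytes=2, optim_bytes=12):
    """Very rough per-replica VRAM estimate in GB; see assumptions above."""
    n_params = n_params_billion * 1e9
    weights = n_params * weight_bytes        # bf16 weights
    grads = n_params * grad_bytes            # bf16 gradients
    optim = n_params * optim_bytes           # fp32 master weights + Adam m and v
    # Crude activation term: batch * seq * hidden per layer, with a fudge factor
    # for attention/MLP intermediates; ignores activation recomputation.
    activations = batch_size * seq_len * hidden * n_layers * 16 * weight_bytes
    return (weights + grads + optim + activations) / 1e9

# Hypothetical Llama-8B-ish shape, just to play with the knobs:
print(training_memory_gb(n_params_billion=8, batch_size=1, seq_len=4096,
                         hidden=4096, n_layers=32))
```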
Okay, moving on. I guess they actually ran 4,000 experiments, not a thousand, across different model sizes: 1B, 3B, 8B, up to 70B, with here's what crashed, here's more tokens. I thought this chart was kind of weird to read; they went for pretty instead of functional, and I didn't find the charts too useful, but the surrounding context is pretty much pure gold.

High-level overview: they look at three challenges in training. First, memory usage. This is your hard limitation: how many GPUs do you have and what can you fit? With one GPU you have to adjust your batch size and sequence length; with two, the next questions start. Second, compute efficiency: how do we split the model? Do we do data parallelism, where we have the entire model on every single GPU, or do we shard the model weights because we can't even fit them? Third, communication overhead. This mostly matters when you're GPU-rich-rich. Most people are GPU-poor: you have one or two GPUs and you can do some fine-tuning, some post-training. GPU-rich-rich is when you have multiple nodes. There's this thing called GPU interconnect, or communication bandwidth, a buzzword we hear pretty often. What it basically means is that when you're sharding stuff between GPUs, they need to talk to each other. If I have layer one of a model on one GPU and layer two on another GPU, they have to sync to do their thing. If they've got good communication, which is basically what NVLink gives you, you're kind of chilling, and there's stuff we can do. Ring attention, for example, lets you do really long context: you split a really chunky sequence, say a million tokens of context, across multiple GPUs. But that only works if you have low communication overhead, meaning the GPUs talk to each other very closely. One node of eight GPUs all interconnected with high-bandwidth links is great, but when you become GPU-rich and have multiple nodes, how do those nodes talk to each other? Now you have another level of communication bottleneck: the GPUs within a node have good interconnect, but for node-to-node communication you probably want to do something different. Those are the three main things. If anyone here is really doing this, you probably don't need to listen to me yap about it; you probably know more than I do.
But they give good formulas and explanations for how to mathematically figure out your batch size and sequence length based on your available compute. They have a little cheat sheet we can go over quickly: small models use a single parallelism technique; with eight GPUs and a small model you can probably just put the full model on every GPU. If you have a large model and need more than eight GPUs, there's tensor parallelism; if you have a whole bunch of GPUs, there's MoE parallelism and things like that. There's a section on optimizing throughput with tensor parallelism and data parallelism, all summarized in the cheat sheet. Then they cover the parallelization strategies themselves: data parallelism is pretty straightforward and they give you pros and cons, then there's ZeRO-1, 2, and 3, tensor parallelism, pipeline parallelism, and context parallelism. We'll quickly go over all of them. The cheat sheet is interesting, though it reads more like something you'd expect a research intern to be able to talk through in an interview. If you're really training something and you're giving someone access to 500 or a thousand GPUs, they'd better not need this cheat sheet. But I guess it's useful.
Okay, step zero. This is still part of the overview: what happens when you train on one GPU? At a very high level, we know LLMs are trained to predict the next token. You do a forward pass over your data, you accumulate gradients, you do backpropagation to compute those gradients, you run an optimizer step, and now you have new parameters. That's the little diagram: forward pass, backward pass to compute gradients, optimizer step to update gradients and parameters, and you just keep doing this.

At scale, you start batching your input. The interesting thing about the attention that eats so much of this compute, whether it's the 300 terabytes of memory the widget showed to train an 8B at that setting or the 400 gigs of VRAM to train a 70B at the other one, is that these attention blocks are really good at running in parallel. You don't have to run them strictly sequentially; GPUs are good at exactly this. And when you're doing things in parallel, you want to batch: run multiple sequences at once, so eight sequences go into one batch and the whole batch runs in parallel. That's your batch size. The blog says: forget all that, what we really think about is batch size in tokens, which is batch size times sequence length. So if each sequence is a hundred tokens and you have a batch of eight sequences, you're really training on 800 tokens per batch.
Here's a cool line for people not super read up on pre-training: LLMs are typically trained with batches on the order of 4 to 60 million tokens. The batch size has been increasing as we get bigger clusters and more efficient training. Llama 1 was trained with a batch size of 4 million tokens for a total of 1.4 trillion tokens. Llama 3 was trained on 15 trillion tokens, and I don't know its batch size in tokens off the top of my head. DeepSeek was trained with a 60-million-token batch size for 14 trillion tokens. So the numbers keep going up, and more efficient training is very cool.
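As a tiny sanity check of that arithmetic, using the numbers mentioned above:

```python
# Batch size in tokens = sequences per batch * tokens per sequence.
batch_size, seq_len = 8, 100
print(batch_size * seq_len)                # 800 tokens per step in the toy example

# Llama 1 scale: ~4M-token batches over ~1.4T total training tokens.
tokens_per_batch, total_tokens = 4_000_000, 1_400_000_000_000
print(total_tokens // tokens_per_batch)    # ~350,000 optimizer steps
```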
So the first challenge is scaling our model to these large batch sizes without running into out-of-memory issues. As much as I'd like to parallelize everything, I'm going to run out of memory. The question is: what should we do when one GPU doesn't have enough memory to hold a full batch at our target batch size? Basically, make the batch size smaller, and if you really can't fit it, start adding more GPUs. But to understand what happens when you can't fit your batch of tokens, they first explain what actually sits in memory when you train. You have the model weights, which are the parameters; depending on the precision, they go through the math on how many bytes each parameter takes. Then you're computing gradients for backprop, you have an optimizer state that has to be kept around, and you have activations. Additionally, you need to leave a bit of a buffer: CUDA kernels add one to two gigs of GPU memory. So if you're doing local stuff, say trying to train a 7B model on a 4090 or 5090, or on a MacBook with 16 gigs of RAM, having 16 or 24 gigs doesn't mean you can use all of it: CUDA kernels take one to two gigs, and your screen and general laptop processes take a few more. Keep that in mind. Quantization-adjacent tricks help too: with mixed-precision training you can cut the memory requirement for the model weights.
Another thing is that you're not consistently using all the memory. This chart shows four training steps of Llama 1B, and you see spikes: at one moment the memory in use is up around 50 gigs, then it drops while other work happens, gradient accumulation, autograd, then activations spike and drop again, so it's kind of spiky like that. They go into a bit of math about how to calculate all of this, but if you're interested, you should read it on your own time; it's pretty long. The tables are always interesting: they show how much memory you'd need for full-precision or half-precision training, with gradient accumulation. Gradient accumulation is a neat trick where you accumulate gradients over a few steps and then update using their average. Activation memory is another interesting one; activations take real memory too. We're running a little slow, so I'm going to go a bit quicker, but if anyone wants to double-click into any of this, we can always come back towards the end.
Or just interrupt me and we can go into it now. Someone asked: that graph you showed, is it over epochs, like is that by the end of the first epoch? No, this is just training steps, all the steps. Steps here are basically per batch: after you do the forward pass, accumulate your loss, and do a backprop, there's a little dip. So this is per batch. Epochs are just a broader term for a grouping of those steps, but it's a similar concept. A lot of pre-training these days is a single epoch; you don't train over the same data as many times anymore, which is kind of interesting. Although, on epochs: Meta might be shifting to four epochs, and there are interesting niches there, like more epochs for high-quality data and fewer for ordinary training data. One epoch is one full pass over all your training data: if I have a thousand samples and I train on all one thousand of them, that's one epoch; two epochs means doing those thousand samples and then doing them again. There was a whole phase where one epoch over a trillion tokens was the new norm; I guess the new meta is four epochs, according to some paper.
Okay, gradient accumulation. This is an interesting one. Storing gradients as you train matters: as we do a forward pass, we then have to do a backward pass and calculate gradients. We predict some tokens, then look at what the actual tokens were and how far off we are, whether we're moving in the right direction or completely off, and for every parameter we have to store a gradient. Then we have multiple batches, then we have to do an optimizer step. Storing these gradients for every single parameter adds up; it takes a lot of memory. So what we can do is gradient accumulation: instead of running one entire batch, we split the batch into micro-batches, do a forward and backward pass on each micro-batch, and compute its gradients. If you have a batch of 16, cut it up into micro-batches of four, do the forward and backward passes on each, compute those smaller gradients, and then accumulate and average them. It really helps you reduce your memory footprint. There's a bit of overhead, because instead of one fast parallel computation we now take extra steps for all this accumulation, and I haven't studied this chart closely, but it shows there's a bit of idle GPU time during the accumulation and merging. Still, it's pretty straightforward; it's what you'd expect.

So that's your one GPU: you figure out how much VRAM you have, you take your sequences, you optimize your batch size, and if you can't fit that size, you do gradient accumulation with micro-batches. There are other little niches, like when you use too many micro-batches your loss can get a little spiky because you're doing more small batches, things you should look into if you're actually doing this, but at a high level that's what's going on.
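Here is a minimal PyTorch-style sketch of gradient accumulation, assuming `model`, `dataloader`, `loss_fn`, and `optimizer` are already defined elsewhere; the sizes mirror the batch-of-16, micro-batch-of-4 example above.

```python
accum_steps = 4  # global batch of 16 = 4 micro-batches of size 4

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):      # dataloader yields micro-batches
    loss = loss_fn(model(inputs), targets) / accum_steps   # scale so the sum averages out
    loss.backward()                                        # gradients accumulate in param.grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()        # one optimizer update per accumulated global batch
        optimizer.zero_grad()
```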
Now, that's all great, but single-GPU stuff is pretty easy: there are a bunch of libraries that will help with single-GPU fine-tuning, even full-parameter stuff, and a lot of companies will just do it for us. What happens when I want to start using multiple GPUs? There was a whole GPU shortage, but now it's pretty chill: I can go to RunPod, Prime Intellect, a bunch of these companies, and rent two, four, eight GPUs; I think RunPod announced they'll soon have 64 H100s on demand. So say I've got eight H100s, or 64 H100s, what do I do? I don't want to just run one whole batch at a time; that's inefficient. That's where something like data parallelism comes into play. Data parallelism is where you put all of the model weights on every GPU, then you split up your data and run different batches on different GPUs, so in parallel you have all the weights everywhere and each GPU trains on a different batch; then you accumulate the gradients, run the optimization, and update all the replicas in sync. There are a bunch of ways to do this, ZeRO-1, ZeRO-2, and so on, and we'll talk about that in a bit.

The other interesting note on data parallelism: there's this whole crypto wave around "distributed training", where everyone says, everyone has one GPU, everyone has one MacBook, let's train GPT-5 on everyone's machines. What those projects are actually doing is data parallelism. Nous Research, Prime Intellect, and not to dump on their research, it's still very interesting, unique, novel, and hard to do, are doing distributed training where you have multiple data centers: one node here, one in Australia, one in Europe, one wherever, you put the entire set of weights on each of them, and you train different batches across them. It's still cool, and it still requires servers in a lot of different locations, but the whole set of weights still has to sit on every GPU in every location. So data parallelism is basically as straightforward as you'd expect: all the weights everywhere, different batches on different replicas, and optimizer syncing across them. There are different flavors: you can add gradient accumulation, you can bucket gradients, and they talk about efficiency tricks, which I don't think we should cover in an hour; if you're at the stage where you're doing this, they go over the main approaches.

So we've gone from single-GPU training to data parallel. Say I have a node of four H100s and I want to train a Llama 7B: I can put Llama 7B on each GPU, take my tokens, say a hundred million of them, split them into batches, have each GPU train on a batch, and then efficiently combine the results.
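A minimal sketch of plain data parallelism using PyTorch's DistributedDataParallel; `MyModel` and `dataset` are hypothetical placeholders you would define yourself, and the script would be launched with something like `torchrun --nproc_per_node=4 train.py`.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group("nccl")                    # one process per GPU
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = MyModel().to(local_rank)                   # full copy of the weights on every GPU
model = DDP(model, device_ids=[local_rank])        # gradients are all-reduced across replicas

sampler = DistributedSampler(dataset)              # each rank sees a different shard of the data
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
```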
Here's a summary of what's happened so far. If you're at this stage, here's what you should do. First, determine the best global batch size, either by running experiments or by consulting the literature; read their blog post to figure out how. Then select the sequence length. Long sequences take a lot of compute, and what's done in modern models is to train at a smaller sequence length for roughly 90 percent of the tokens and then increase it for the last stretch; if I'm not mistaken, Llama did something like 14 trillion tokens at a 4,000 sequence length and then the last trillion at around 100k before updating. It's very common because it saves a lot of compute and still generalizes. Two to eight thousand tokens works well for the evaluations we have today, and I think this is something people don't look into enough: sure, you can do two to eight thousand tokens because that works for MMLU, but if you're training something like an LLM-judge classifier, you probably don't always need 8,000 tokens, so be smart and train on fewer tokens if fewer is all you need. Then, once we know the batch size, find the local batch size a single GPU can handle and increase it until you run out of memory. Then determine the number of GPUs you have for data parallelism and spread your batches across them. Finally, add gradient accumulation if you need it. That's basic data parallelism.
Now there's a whole series on how to improve on data parallelism: ZeRO-1, ZeRO-2, ZeRO-3. So what is ZeRO? It's the Zero Redundancy Optimizer. Actually, let me check chat real quick. "There's MapReduce to do this sync", yeah, and they have fancier versions of that kind of reduction too. "Are the numbers different for reasoning models?" Reasoning models are a little different, but that's RL, and RL is at some level similar: you're still just training a transformer, and what's happening inside that transformer is somewhat irrelevant here.

Okay, so what's going on with ZeRO? Data parallelism is an efficient way to scale training, but the naive replication of optimizer states, gradients, and parameters across the data-parallel ranks introduces significant memory redundancy. In regular data parallelism, every single GPU holds the entire model plus all the optimizer states and all the gradients, and that's redundant. What ZeRO-1, 2, and 3 look at is whether we can shard more than just the data: can we shard the optimizer state, can we shard the gradients, and then there's ZeRO-3, which is the really interesting one, where you start sharding the parameters themselves. So ZeRO-1 and ZeRO-2 are more memory-efficient flavors of data parallelism: instead of everything living on every GPU, ZeRO-1 also shards the optimizer state, and ZeRO-2 additionally shards the gradients.
Okay, back to what's happening in memory. In mixed-precision training we have parameters, we have gradients, and we have optimizer states, which are stored in higher precision, plus a higher-precision copy of the gradients if we want that. ZeRO-1 partitions the optimizer state. In regular data parallelism, all ranks gather the gradients after the backward pass and then perform identical optimizer steps; that's a lot of duplicated work and redundant memory, and we can avoid it. With ZeRO-1, forward passes happen as usual and backward passes happen as usual, then you perform a reduce-scatter on the gradients, each replica performs an optimizer step on its local shard of the optimizer state, and then you perform an all-gather to send the updated pieces back to everyone. TLDR: they're now just sharding the optimizer state; the post goes into more detail on how.

ZeRO-2 is the next step. Each replica is also computing gradients, so why not shard the gradients too? During the backward pass, instead of performing an all-reduce, we perform a reduce-scatter, which spreads the gradients across ranks and saves even more memory. They have a chart at the end of this section that shows how all of it fits together.
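In practice you don't implement ZeRO by hand; you pick a stage in your framework's config. Here is a minimal sketch of what that looks like with a DeepSpeed-style config dict; the batch and accumulation numbers are placeholders.

```python
# DeepSpeed-style ZeRO config (placeholder values); "stage" selects what gets sharded.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,   # 1: shard optimizer states, 2: also gradients, 3: also parameters
    },
}
# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
```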
Then we have ZeRO-3, which is the fun one. If you need it, say you're doing Llama 70B post-training and you want to crank up your sequence length to something like a hundred thousand, you're no longer fitting on even one node. This is where FSDP comes in: you shard the parameters themselves, so instead of each GPU holding all of the model weights, the model weights are sharded across different GPUs. Sounds pretty interesting, right? Well, with ZeRO-3 you still need good communication between the GPUs. There are technical details of how the forward and backward passes work, and later on the post covers sharding across different dimensions. The way this works is basically that training, the attention, all of the math, is just matmuls, matrix multiplication, rows times columns, and there are smart ways to split and manipulate rows times columns; the libraries make it easy, and the post explains what's going on. So read it if you're interested, and if you actually want to use it, pretty much any training stack, PyTorch or whatever you're using, lets you turn this on in a few lines of code and handles all of it for you on the backend. So that's the three ZeRO stages.
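For ZeRO-3-style sharding in plain PyTorch, FSDP is the "few lines of code" path being described. A minimal sketch, assuming a hypothetical `MyLargeModel` and a job launched with `torchrun` so the process group can initialize:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(
    MyLargeModel().cuda(),                                   # hypothetical model class
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    # Parameters, gradients, and optimizer states are sharded across ranks;
    # each layer's full weights are gathered just in time for forward/backward.
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
```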
Let's take a quick break here and see if there are any questions. "For weights, activations, gradients, and optimizer states, what's the intuition for which ones we want in lower versus higher precision?" For weights it partly depends on whether you're doing training or inference. What's really going on isn't simply higher or lower precision; it's mixed precision, where some things, like gradients and activations, are kept in higher precision and others aren't. There's more and more work coming out on training in lower precision and mixed precision, but the general rule of thumb is to use as much high precision as you can afford. Half precision is cool for inference, which is a different game where you can do a lot of quantization, but for the pieces that aren't your main memory overhead, keep the higher precision: so keep gradients in high precision, and for the model weights, try to lower the precision a bit.
Okay, now we have tensor parallelism. Tensor parallelism is a way of splitting things further: it splits the model weights across their hidden dimension. Data parallelism (and ZeRO) was about sharding the data and the training states; this splits the tensors themselves. The post walks through the math of what's happening: you're doing matmuls, rows times columns, here's a smart way to split them, here's a smart way to reduce them, and here's what that looks like inside a transformer block. A lot of this is the background behind things like FlashAttention and grouped-query attention, and it's used in papers like Llama 3. At a high level, though, it's just another way of sharding your model. It can be done column-wise or row-wise; the two have different communication overheads and different benefits, so look into both. Tensor parallelism inside transformer layers is another interesting part of the section. For time's sake we're going to move a little quicker, but the interesting thing here is that as long as you're within one node, this all works: you can put different shards on different GPUs. The more you do it, though, the more the interconnect between GPUs matters. On a single node of highly interconnected GPUs this will be fine, but the minute you're using PCIe cards or multiple nodes, you're adding too much of a bottleneck for any of this to be worth it.
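Here is a toy, single-device illustration of the column-wise split: each shard holds a slice of a linear layer's output columns and computes its piece of the matmul, and concatenating the pieces (the all-gather in a real setup) reproduces the full result. It only shows the math, not an actual multi-GPU implementation.

```python
import torch

hidden, out_features, tp_degree = 1024, 4096, 4
x = torch.randn(8, hidden)                      # a batch of activations
w = torch.randn(out_features, hidden)           # full weight of one linear layer

shards = w.chunk(tp_degree, dim=0)              # column-parallel: split the output dim
partials = [x @ shard.t() for shard in shards]  # each "rank" computes its slice
y = torch.cat(partials, dim=-1)                 # the all-gather along the hidden dimension

assert torch.allclose(y, x @ w.t(), atol=1e-3)  # same result as the unsharded matmul
```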
Next we have sequence parallelism and context parallelism, which are roughly the same idea: instead of splitting across the hidden dimension like tensor parallelism, split across the sequence. This is where things like ring attention come in; a while ago we had a paper club on ring attention. With context and sequence parallelism, if you can't otherwise scale your context length, you take these 128,000-token sequences and split them across GPUs: say the first 10,000 tokens on GPU one, the next 10,000 on GPU two, and so on. There's a neat way to make that work, and the post goes into the math, but conceptually, think about what attention is: it's the relationship between all tokens, and you still need to compute it. If you have a hundred thousand tokens, you need to know how each token relates to every other token, which you'd expect to require having all the tokens on one GPU. Instead, you can shard them and do ring attention, where each GPU does part of the attention computation, the key/value blocks get passed between GPUs to keep everyone in sync, and the full query-key-value calculation comes together, so you've effectively done attention across GPUs. There's ring attention, there's zigzag attention, there are little optimizations on top of this stuff, but it's pretty cool. A lot of the "how do we get infinite context, how do we do long context" work is what happened here. Very interesting stuff, and I think it's a little underrated, so reading a bit about context parallelism and sequence parallelism is worthwhile.
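A toy sketch of the splitting step in context parallelism: the long sequence is cut into contiguous chunks, one per GPU, and each rank keeps only its chunk of queries, keys, and values; ring attention then circulates the key/value blocks so every query eventually attends over every key. The sizes here are purely illustrative.

```python
import torch

seq_len, n_ranks = 32_768, 4
token_ids = torch.arange(seq_len)          # stand-in for a long tokenized sequence
chunks = token_ids.chunk(n_ranks)          # contiguous 8,192-token slice per rank

for rank, chunk in enumerate(chunks):
    # In a real setup this chunk lives on GPU `rank`, and its key/value block is
    # sent around the ring so the other ranks can attend over it too.
    print(f"rank {rank}: tokens {chunk[0].item()}..{chunk[-1].item()}")
```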
Okay, next is pipeline parallelism, which is a pretty fun one. This is where you distribute consecutive layers: instead of sharding gradients or optimizer states or sequences, you start sharding the layers themselves, so layers one, two, and three sit on this GPU and different layers sit on different GPUs. Once again it's another approach to fitting really large models, and it reduces the per-GPU memory for everything. But once you split up the layers, how do you keep training? As you do your forward pass, you still need a backward pass, gradients, and an optimizer step, so they walk through the schedules that make this work: all-forward-all-backward; one-forward-one-backward, which is what Llama 3 used; interleaved stages if you've got different slices; zero bubble; DualPipe. There are just different recipes for all of this, with some very interesting little optimizations.
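A naive sketch of the layer placement in pipeline parallelism, putting the first half of a toy transformer on `cuda:0` and the second half on `cuda:1` (it needs two GPUs to actually run). Real schedules like 1F1B or interleaving keep both stages busy with micro-batches; this only shows the split and the activation hand-off.

```python
import torch
import torch.nn as nn

layers = [nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
          for _ in range(8)]
stage0 = nn.Sequential(*layers[:4]).to("cuda:0")   # layers 1-4 on the first GPU
stage1 = nn.Sequential(*layers[4:]).to("cuda:1")   # layers 5-8 on the second GPU

def pipelined_forward(x):                # x: (batch, seq, 512), starting on cuda:0
    h = stage0(x)
    h = h.to("cuda:1")                   # activations cross the GPU boundary here
    return stage1(h)
```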
Expert parallelism is a more straightforward one: you put the different experts of a mixture-of-experts model on different GPUs. You have a router, each GPU hosts its own expert, and in some cases multiple experts sit across multiple GPUs, and you get efficient expert parallelism.

So, once again, the overview summary: data parallelism is along the batch dimension; tensor parallelism is along the hidden dimension; sequence and context parallelism are for when you need to extend the context length, say you want to train your model up to a 500,000-token context on limited GPUs; pipeline parallelism is along the model layers; and expert parallelism is across the experts of a mixture-of-experts model. You can also combine these with the ZeRO stages: ZeRO-1 shards optimizer states, ZeRO-2 shards optimizer states and gradients, and ZeRO-3 shards the parameters alongside the other two. That's the high-level picture: if you have multiple GPUs, these are the things you can do for each situation, and the post shows you how to double-click into whichever one you need. Say you're doing pipeline parallelism, sharding everything, and you then want to extend your sequence length; you can go back into that section and find the different interleaving stages and how to optimize it all.

What they cover here is still not really research; it's the applied-researcher stage. All of this is available in libraries like PyTorch, and there are CUDA kernels that already optimize this stuff; FlashAttention, for example, already builds on it. These are things you can just consume while doing your training. If you're pushing the boundary, xAI now has a cluster on the order of 200,000 GPUs, and GPT-4.5 was a very chunky model, that's where you have to start pushing on these ideas and developing even more novel training strategies. But that's their high-level tour of the different parallelization strategies.
After that, there's a recipe for finding the best training config. Even if you skip everything else, the original intro is important: how do we train on one GPU, how do we train on two, how do we train on a single node. The middle is, okay, now you're GPU-rich and probably smarter than me, how do you train across nodes. By the way, one reason for some of these strategies, which we skipped over, is that if you have multiple nodes without fast interconnect, certain parallelization strategies work better; they're optimized for not having that interconnect between nodes, and there are recipes for that too.

Okay, high level, let's sum it all up again. If you want to find your best training config, here's how. Step one: fit a training step into memory. Figure out how you can fit a full model instance on the GPU. There are several regimes depending on whether you're GPU-rich or GPU-poor. If you're GPU-poor, there's full activation recomputation, there's gradient accumulation, things like that; you can train more slowly, and basically you increase your gradient accumulation steps to process larger batch sizes. On the GPU-rich side: for models under 10B parameters, use a single parallelism technique, for example tensor parallelism or ZeRO-3 across eight GPUs, which is where you shard the model weights across the eight GPUs of one node. For models between 10 and 100 billion parameters you'll need more than eight GPUs (they keep saying eight GPUs because that's basically one node): if you're under 10B on one node, FSDP with ZeRO-3 is fine, but for 10 to 100B you're going to need multiple nodes, and this is where you start combining tensor parallelism with pipeline parallelism, or tensor parallelism with data parallelism, or using only ZeRO-3. And when you're rich, with 512 or a thousand GPUs, you really shouldn't be getting this from me, but there's more you can do there. GPU-poor we already talked about: gradient accumulation, find what batch size you can support, split batches into micro-batches, and don't overdo it.

Step two: achieve the target batch size. Depending on where step one left us in terms of micro-batch size and data parallelism, our current global batch size might be too small or too big. To scale it up, increase data parallelism or the gradient accumulation steps, keeping an eye on memory, and if you want long context, add context parallelism, the ring-attention-style stuff. To decrease the global batch size, you can reduce data parallelism, or shift GPUs toward context parallelism if you have more of them.
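The bookkeeping in steps one and two boils down to one identity: global batch size in tokens equals data-parallel degree times micro-batch size times gradient-accumulation steps times sequence length. A small sketch of solving for the accumulation steps, with made-up numbers:

```python
target_gbs_tokens = 4_000_000       # e.g. a Llama-1-style 4M-token global batch
dp_degree         = 64              # number of data-parallel replicas
micro_batch_size  = 2               # sequences per GPU per forward/backward
seq_len           = 4096

tokens_per_micro_step = dp_degree * micro_batch_size * seq_len
grad_accum_steps = round(target_gbs_tokens / tokens_per_micro_step)
print(grad_accum_steps)             # ~8 accumulation steps to hit ~4M tokens per update
```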
Step three: optimize training throughput. Set up tensor parallelism over fast interconnect where you have it, then increase data parallelism: if you have more GPUs, use data parallelism and put full replicas on them, keep the target batch size as big as you can, and try scaling the different parallelisms up one by one, experimenting and benchmarking as you go. They ran a lot of benchmarks; I don't know if anyone found anything in them worth diving into, but I want to leave ten minutes for questions, so I'm going to skip through a good bit of this. There's a deep dive into GPUs, kernel fusing, threading: what is a GPU, how does a GPU work, high-bandwidth memory, interconnect, global memory, caching, KV caching, and the CUDA kernels that optimize all of it. I don't think that's as relevant to the high-level question of how to train models across GPUs; if you're interested, read it, it's pretty useful, but it's a bit too niche for us. There's a section on how matmuls actually happen, and a conclusion at the end. FlashAttention is a set of CUDA-level optimizations for all of this, and ring attention is another.

Mixed-precision training is another interesting one. Mixed-precision training is one thing and mixed-precision inference is another; there's a lot more out there on mixed-precision inference, but for training they go into accumulation and which precision to use for which parts. FP8 pre-training is interesting: some people are doing it and we want to see how well it works, and recent research, including FP8-LM and DeepSeek-V3, which is an interesting one, shows potential. That's kind of huge, because the big thing with precision is that as you cut the precision in half, you also cut the VRAM those tensors need roughly in half; for a 600B-parameter model, that's on the order of hundreds of gigs of VRAM. So precision is a pretty big deal.
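A minimal sketch of mixed-precision training in PyTorch with autocast and a gradient scaler (fp16 here; bf16 setups typically drop the scaler). `model`, `dataloader`, `loss_fn`, and `optimizer` are assumed to be defined elsewhere.

```python
import torch

scaler = torch.cuda.amp.GradScaler()             # rescales the loss to avoid fp16 underflow

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)   # matmuls run in fp16 under autocast
    scaler.scale(loss).backward()
    scaler.step(optimizer)                       # optimizer states stay in full precision
    scaler.update()
```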
Okay, conclusion. We now know how to train, and how things like Llama 405B and DeepSeek-V3 were efficiently trained on multiple GPUs. Once again, they have their cheat sheets, and if you want cheat sheets, it's worth going over them again. There's also a fair bit about FP8 versus int8 and how they're not the same, quantization stuff you should look a little deeper into, plus Llama's 4D parallelism, a lot of papers, a lot of references. All in all, a pretty solid, interesting high-level overview of how to train.

As a quick recap, remember the shape of it: here's how to fit stuff into memory, pick your model size, pick your GPUs, pick your batch size, and optimize your training throughput until stuff doesn't crash. And here are the parallelization strategies we talked about: data parallel, where everything sits on every GPU, with ZeRO-1, ZeRO-2, and ZeRO-3 progressively sharding the optimizer states, gradients, and finally the weights; tensor parallel across the hidden dimension; pipeline parallel across model layers; context parallel for sequence length; and expert parallel for mixture of experts. It's a nice overview, and for the AI-engineer crowd, even if you don't have five hundred or a thousand GPUs, it's worth familiarizing yourself with what's going on and which parts of a model and of pre-training take how much GPU. It's not all attention: there are the feed-forward layers, the activation states, the gradients we have to calculate, and those actually take up more than half of your training memory. All of that is pretty interesting to know.

The one thing this blog maybe missed, or at least didn't do for me, is pointing out that when you actually do inference and consume these models, you're not accumulating all of these gradients. Let me find that little diagram again. For example, if we're doing, say, Llama... oh, fun, this doesn't want to work.
There it is. If we're doing Llama 3 8B at a sequence length of, say, a thousand, all these parameters and gradients that are taking a hundred gigs of memory, and these optimizer states taking 70 gigs of VRAM, we don't carry any of that during inference; there's no backprop to do. The attention is still some of it, but all the training-only stuff goes away, and we can quantize better on top of that. The post didn't hit on that as much, but at a high level it's a good read. Throw it into ChatGPT or Claude or whatever your LLM of choice is while you're reading a section and have it break each part down for you, spend an hour on it; it's useful for the AI engineer, and understanding the difference between training and inference is always useful. I think a good rule of thumb is: if you were hiring a research intern, would they be able to explain all of this to you, pros and cons, trade-offs, multi-node versus single node versus single GPU? If they could, that's great, and you should be able to do it yourself too; it's just good general knowledge. So pretend you're in an interview and be able to walk through all this stuff; it's good background knowledge.

That's the high-level overview of this big pre-training blog. Questions, thoughts, comments? We have five minutes; I'm sure I skipped a lot of stuff and people have plenty to add. "Out of curiosity, what precision are state-of-the-art models mostly trained in?" Mostly higher precision for gradients, though there is mixed precision, and as of recently we're cutting that down; I think it said DeepSeek-V3 was trained in FP8, which is kind of interesting, but there's still a lot of higher precision in there. Yeah, that's the paper. Thoughts, comments, questions?
"Who wants to train a 70B from scratch knowing all this? Are we prepared to blow thousands of dollars an hour?" I think it's a good experiment to run. Realistically you can rent H100s at about two dollars an hour, so rent one H100 and train something small, then rent two H100s on boxes without good interconnect and try some of this stuff, because it's very approachable from the common libraries; then rent a node, then rent two nodes. Realistically you only have a few different parallelization strategies to try. A precursor to this blog from Hugging Face is FineWeb, which covers something like five, maybe fifteen trillion pre-training tokens. So go learn; it's very useful to try all this, see what's different, and not burn too much money: at two dollars an hour per H100, eight H100s is on the order of twenty dollars an hour, and we can afford that.

"Papers that can help me do this for non-English languages?" The presenter from last time works on a state-of-the-art multilingual model (not German, some multilingual model), so come to paper club; someone here is apparently an expert in non-English models. "Cheapest H100 provider on the internet?" Crypto nonsense: the crypto companies are raising a lot of money and subsidizing GPUs to be very, very cheap. DeepInfra has them for around two dollars, Modal will give you credits, and a lot of people will sponsor compute. Hyperbolic is a good one too, shout-out Hyperbolic; it's also pretty cheap, around two to four dollars an hour. SF Compute is not serverless, I believe; it's bigger clusters for blocks of time, and I like those guys. And as RunPod and others add 32 or 64 on-demand H100s, I don't know if we even need more. Cool, what other questions? The questions are all about where to get GPUs: pay money, it's not that expensive, it's very subsidized right now, to the point where it's arguably cheaper than the underlying compute. Oh, people don't like Hyperbolic apparently, so never mind, no Hyperbolic, we don't like them, I guess, even though they are slightly cheaper than the others.
Okay, volunteers for next week: who wants to cover a paper? There's a cool post-training survey. QwQ-32B came out, and I don't know if there's a paper; I don't think there is. If a paper comes out, that would be cool, but right now there's no paper at all, just a blog post, and a thin one: they had a very sick title, "Embracing the Power of RL", and then a roughly 700-word blog post that's nothing but here's code on how to use it. Okay, LLM post-training: there's a survey paper on post-training with a very cool chart; I think it's a good one if someone wants to cover it. I'm going to volunteer this one to someone if anyone's down; otherwise, if there are no more questions... actually, this one is pretty short, about 20 pages on post-training, very useful.

"Why not MCTS?" MCTS is great, dude, but do you know how much of a time sink it was, how much time was wasted on MCTS? Who remembers Strawberry, who remembers MultiOn saying that they are Q*? Strawberry, MCTS, kind of a waste of time, but it's good background: MCTS is basically tree search over potential next tokens, can we fan out and do inference-time scaling, test-time compute, instead of RL? A bunch of people spent a bunch of time on it, and some companies claimed they had Q* by doing it. Another paper someone recommended, and we don't know if it's credible: "The FFT Strikes Back: An Efficient Alternative to Self-Attention", from one guy at USC. Is he plugging his own paper? I don't know; alternatives to attention are always interesting. Okay, that's paper club this week, guys. "Super nicely done, thank you so much, you covered two days of reading in one hour." In general, I feel like there need to be more good papers; paper velocity seems to have slowed somehow. I mean, this isn't even a paper on anything of their own, it's just... okay, it counts, it counts, a very, very good survey. Okay, well, that's paper club.