
The Ultra-Scale Handbook by HuggingFace Nanotron



00:00:00.000 | a bit on LoRAs... okay, I'll restart, we just started recording. So basically Hugging Face put out this
00:00:07.020 | blog post like a week or two ago and it's a very in-depth estimated reading time is like two to
00:00:14.080 | three days really good visualizations and like little widgets on how to do pre-training of
00:00:19.480 | models so there's a quick overview at the start the plan is basically we're not covering a two-day
00:00:26.300 | reading in an hour we'll go through the main points of everything and try to do try to do some little
00:00:33.160 | you know double click into whatever people have questions on if anyone has thoughts comments you
00:00:39.200 | know open whenever a lot of it is pretty like mathematical guide on how to pre-train a language
00:00:44.960 | model so there's a lot of resources out there for if you're doing fine-tuning, LoRA, synthetic data
00:00:50.560 | generation like there's a lot of libraries that make it easy to do little post-training stuff or
00:00:55.800 | fine-tuning but they wanted to solve the question of okay everyone says like it's hard to train across
00:01:01.920 | multiple nodes multiple GPUs what is all this data parallelism pipeline parallelism how do we know about
00:01:08.180 | like you know like who's actually writing about this so Hugging Face has a cluster of 512 GPUs
00:01:14.940 | they ran a thousand different experiments and they basically put together uh hey if you want to go
00:01:20.180 | train pre-train a model here's kind of like your zero to eighty percent of what you should know
00:01:25.560 | about distributed training so this is when you have more than one GPU and kind of more than two GPUs
00:01:32.280 | how can you efficiently scale your training so all from you can't have a model that fits on one GPU
00:01:39.920 | to how do you have like let's say a 7b model and you want to efficiently use 10 GPUs versus how do you have
00:01:47.000 | like multi-node so if you have a train run what if you have multiple nodes um what makes up the most of
00:01:53.540 | like compute when you're training so activations back prop um all this stuff so TLDR they kind of kick
00:02:01.780 | off with a basic overview of this and they're like there's pretty good open source models but they don't
00:02:08.100 | really talk too much about these little nitty-gritty stuff um at a high level something that this blog
00:02:13.520 | post also kind of didn't explicitly state but does show is you see all these gray dots so they ran over a
00:02:20.260 | thousand experiments all these gray dots are failed runs so some of them were kind of force-failed runs
00:02:25.180 | like force-failed where you know you overdo your batch size and you know you'll run out of memory and guess
00:02:31.540 | what you ran out of memory but the other little interesting niche that um I've noticed is when you
00:02:39.400 | do a lot of like multi-node multi-GPU training is sometimes a random GPU fails sometimes one GPU has
00:02:47.160 | like a weird loss and your train run is kind of screwed so um it doesn't account for those little niches right
00:02:53.720 | so sure there's checkpointing sure there's stuff like this but uh in practice sometimes when you're
00:02:59.520 | training on like a thousand GPUs yeah one GPU randomly dies and it causes little issues but anyway
00:03:05.080 | this is kind of just the high level like here's what to do so um kicking off let's go over the
00:03:13.000 | fundamentals and background so they have this like cool interactive widget that's basically like here
00:03:19.940 | are the presets of models that exist for the llama family so there's a llama tiny llama 8b llama 70b and
00:03:26.920 | it kind of lets you visualize how much VRAM is required to do pre-training so if you want to like
00:03:32.520 | train the full thing at different precision and a different um sequence lengths right so as we scale
00:03:40.840 | up the sequence length we can kind of see what happens to um the VRAM required right so let's say we want to
00:03:46.680 | train... there's like a few key parameters that you should know here right so you've got batch size and
00:03:52.440 | sequence length and that kind of makes up how much GPU usage you're going to have so how many parameters
00:03:57.960 | like what's the weight count so this is a 70B how many batches are you doing so more batches more GPU
00:04:04.760 | needed more sequence length significantly more GPU needed and kind of like one of the takeaways is you know
00:04:11.480 | the attention blocks here quadratically increase right so as I double the batch size we quadratically
00:04:18.440 | or sorry double the sequence length we um quadratically increase attention so if we start
00:04:25.080 | with a sequence length of like 4 000 you can see how like the gradients optimizers these make up like
00:04:31.400 | you know, oh I typed 14,000, back at about 4,000 the attention is, sure, some of it. Like to train
00:04:39.320 | Llama 70B at 4,000 sequence length you know activation memory is about 80 gigs, parameters gradients
00:04:46.920 | optimization steps are about 300 gigs parameters are obviously you know it's about one to two if
00:04:53.400 | you're doing um mixed precision you can also toggle that off so one to two for mixed precision then
00:04:59.480 | gradients and optimizations they add a lot people don't realize that this um attention mechanism it
00:05:05.400 | actually doesn't take that much but as we scale up let's say to a hundred thousand context length or a
00:05:11.640 | million this attention mechanism really eats up all the all the um memory so cool little visualization
00:05:20.760 | um the other stuff I guess yeah like hidden dimension cool, batch size cool, it's cool how they at least just
00:05:26.840 | have these presets for you though you know um vocab size is kind of interesting if you want to kind
00:05:32.360 | of see what this is changing here um you could change different activations but basically you guys should
00:05:39.000 | play around with this like read through this intro, um, TLDR of this blog post they have like a high
00:05:45.400 | level overview if you want to get anything just read through this and you'll you'll kind of have
00:05:50.120 | a decent understanding of what's going on what takes memory when you train a foundation model from scratch
00:05:56.120 | and yeah um okay, moving on, so I guess they ran 4,000 experiments, not a thousand, and
00:06:04.360 | then they run them across different sizes so 1b 3b 8b to 70b here's what crashed here's more tokens I
00:06:11.160 | thought this chart was kind of weird to read I don't know they wanted it like pretty instead of functional
00:06:16.760 | but I didn't find the charts too useful by the way all this context is pretty much pure gold um so high
00:06:24.680 | level overview um they're going to look at three challenges when you do training right so memory usage this
00:06:31.480 | is kind of your hard limitation right how many gpus do you have and what can you fit so one gpu
00:06:37.720 | you got to adjust your batch size your sequence length two gpus uh how do we do compute efficiency
00:06:43.720 | right how do we share the model do we put data parallelism where we have the entire model on every
00:06:48.600 | single gpu or do we do something like uh shard the model weights if we can't even fit them like that
00:06:54.360 | then communication overhead this kind of becomes when you're gpu rich rich you know most people gpu poor you
00:07:00.440 | have one two gpus you can do some fine tuning some post training gpu rich rich is when you have
00:07:06.120 | multiple nodes so there's there's a thing called like gpu interconnect right so what's the communication
00:07:12.680 | bandwidth it's a buzzword that we hear around pretty often what that basically means is like
00:07:17.400 | when you're sharding stuff between gpus they need to talk to each other right so like if i have
00:07:23.080 | layer one of a model on one gpu and layer two on another gpu well they kind of have to sync to do
00:07:29.160 | their thing right so if they've got good communication, this is basically what NVLink is, if they have
00:07:35.240 | good communication you're kind of chilling there's stuff we can do ring attention is a way that you can
00:07:39.880 | like do really long context and you split this really like chunky sequence length you can do like a
00:07:47.720 | million context scaled across multiple gpus now that only works if you have low communication overhead
00:07:55.080 | meaning you know the gpus talk to each other very very like closely so you have one node of eight gpus
00:08:01.480 | all interconnected with high communication that's cool but then when you become gpu rich and you have
00:08:06.680 | multiple nodes how do these nodes talk together well now you have like another level of communication
00:08:12.120 | bottleneck right so you have nodes that have good interconnect but the node to node communication you
00:08:17.400 | probably want to do something different so these are kind of like the three main things um if anyone is
00:08:23.480 | really doing this you probably don't need to listen to this yapping from me, you probably know more than me
00:08:28.600 | but uh they give good formulas and like explanations as to how you want to mathematically figure out what's
00:08:35.080 | your batch size what's your sequence length based on um your available compute they have a little cheat
00:08:41.720 | sheet here we can go over pretty quick um so small models use single parallelism you know if you have eight
00:08:49.160 | gpus in a small model you can probably throw the full model on every gpu if you have a large model and
00:08:54.120 | you need more than eight gpus there's tensor parallelism if you have a bunch of gpus there's
00:08:59.800 | like MoE parallelism there's stuff like that um there's a section on optimizing throughput so tensor
00:09:07.160 | parallelism, data parallelism, they kind of have this little cheat sheet um then they have parallelization
00:09:12.360 | strategies right so these are kind of the ones that they go over data parallelism is pretty straightforward
00:09:18.040 | they give you pros and cons then they have zero one two three um tensor parallelism pipeline and context
00:09:25.080 | we'll we'll quickly go over all of them but this cheat sheet is interesting it seems more like uh
00:09:31.720 | if you're interviewing a research intern you would expect them to be able to yap about this but if you're
00:09:39.160 | really doing something in training and if you're giving someone access to 500 or a thousand gpus
00:09:45.160 | they better not be needing this cheat sheet but i guess it's useful you know um yeah so step zero this
00:09:53.400 | is still kind of that overview right what happens when you're training on one gpu so very very high
00:09:59.720 | level we know that lms are trained to predict the next token right so what happens is you do a forward
00:10:05.960 | pass of your data you do like gradient accumulation you know you kind of do a forward pass you accumulate
00:10:13.560 | gradients you do a backward propagation where you update these gradients you run an optimizer step
00:10:18.520 | and now we have new new parameters right so the interesting little diagram but this is kind of
00:10:25.640 | what's happening when you train on a single gpu forward pass backward pass to compute gradients
00:10:31.320 | optimization step to update gradients and parameters you basically keep doing this at scale then you can start
00:10:38.120 | batching your input so the interesting thing about this attention which is taking up all of this
00:10:44.440 | compute right out of this 300 terabytes of ram required to train an 8b at that scale or let's say
00:10:51.640 | 400 gigs of vram required to train uh 70b at 256 right all of this little attention blocks they're really
00:10:59.000 | good at running in parallel so you don't necessarily sequentially do attention attention attention you run all
00:11:05.720 | these in parallel that's what gpus are good at so when you're doing stuff in parallel you want to start
00:11:11.720 | doing batching right batching is basically where you do multiple sequences at once so you have eight
00:11:17.560 | sequences you throw them in one batch and then you run the whole batch in parallel so um that's kind of
00:11:24.040 | what your batch size is they're like forget all that stuff what we really think about is batch size tokens
00:11:29.960 | batch size tokens is batch size times your sequence length so if each sequence is a hundred tokens
00:11:35.640 | and you have a batch of eight of those hundred-token sequences you're actually basically just
00:11:41.720 | training on 800 tokens per batch
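A tiny illustration of that formula (the small numbers are just the example above; the 2048 x 2048 line is a rough LLaMA-1-style recipe for reference):

```python
# batch size in tokens = sequences per batch * sequence length
sequences_per_batch = 8
sequence_length = 100

batch_size_tokens = sequences_per_batch * sequence_length
print(batch_size_tokens)   # 800 tokens processed per batch

# roughly the LLaMA-1 recipe: ~2048 sequences of 2048 tokens per global batch
print(2048 * 2048)         # ~4.2M, the "four million batch size tokens" mentioned below
```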
00:11:48.680 | sweet, this is kind of a cool little line right, so for people not super read up on pre-training stuff, LLMs are typically trained on the order of four to 60 million tokens per batch so the
00:11:55.320 | batch size has been kind of increasing as we get bigger and bigger clusters and we get more efficient
00:12:00.040 | training so llama one was trained on four million token batch size tokens uh each batch had whatever
00:12:07.960 | many um batches times how many ever tokens so about four million token batch size tokens and uh one point
00:12:16.360 | wait sorry, sorry, it was with a batch size of four million tokens for 1.4 trillion tokens so yeah
00:12:22.040 | uh four million batch size tokens and they trained on a total of 1.4 trillion tokens. Llama 3 was
00:12:28.200 | trained on 15 trillion tokens and i don't know the batch size tokens off the top of my head deep seek
00:12:33.560 | was trained on 60 million batch size tokens for 14 trillion tokens so numbers go up more efficient training
00:12:40.040 | is very cool so first challenge is scaling our model to have these large batch sizes and running out of
00:12:48.040 | memory issues right so as much as i want to parallelize everything i'm going to run out of
00:12:52.760 | um i'm going to run out of memory so the question here we're trying to solve is what should we do when
00:12:58.040 | one gpu doesn't have enough memory to hold a full batch of target batch sizes well um basically you know
00:13:05.000 | make the batch size smaller and if you really can't fit it then let's start adding more gpus but what
00:13:10.760 | happens when you can't fit your one batch size of tokens is you well they want to kind of explain what
00:13:17.160 | that is right so let's look at what happens in the memory so when you train you kind of have a few
00:13:23.080 | things going on right you have um model weights which is basically the parameters so parameters are like
00:13:31.080 | depending on the precision they go into all the math about this of how many bytes it takes per
00:13:35.960 | parameter then you're computing gradients for backprop then you have an optimization state that you have
00:13:41.480 | to keep held, and activations. additionally you need to leave a little bit of a buffer:
00:13:48.760 | cuda kernels add one to two gigs of gpu memory so if you're kind of doing local stuff you know if you're
00:13:54.520 | trying to like do some training on a 7b model and you have like a 4090 or 5090, or if you have like a
00:14:01.320 | macbook with 16 gigs of ram well 16 or 24 gigs of ram doesn't mean that you can use it all right
00:14:08.040 | your cuda kernels take one to two gigs you're like screen in general like laptop processing takes a few
00:14:14.680 | gigs so keep that into account. quantization stuff is this crazy thing, you do mixed precision training where you
00:14:21.240 | can kind of cut the memory requirement for the model weights.
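As a rough back-of-the-envelope sketch of where that static memory goes, assuming bf16 weights and gradients plus fp32 Adam states and an fp32 master copy (one common mixed-precision setup; activations and the CUDA-kernel buffer mentioned above come on top of this):

```python
def training_memory_gb(n_params_billions):
    """Very rough static memory for mixed-precision Adam training, per model copy.
    bf16 params (2 B) + bf16 grads (2 B) + fp32 master weights (4 B)
    + fp32 Adam momentum (4 B) + fp32 Adam variance (4 B) = 16 bytes per parameter.
    Activations are NOT included; they scale with batch size and sequence length."""
    bytes_per_param = 2 + 2 + 4 + 4 + 4
    return n_params_billions * 1e9 * bytes_per_param / 1e9

print(training_memory_gb(8))    # ~128 GB for an 8B model, before activations
print(training_memory_gb(70))   # ~1120 GB for a 70B model, before activations
```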
00:14:30.360 | another thing is that you're not consistently using all the memory right, so this is kind of four training steps
00:14:36.760 | of llama 1b you see spikes right so at one second memory being used was up here 50 gigs then it dropped
00:14:45.800 | when it was doing other stuff right so gradient accumulation autograd all this stuff it drops
00:14:50.920 | then it does um activations and drop so it's like kind of little spiky like this um they go into a
00:14:58.920 | bit of math about how to calculate all this but i think if you're interested in this you should kind of
00:15:04.200 | read it on your own time it's pretty long um this is always interesting so they kind of have
00:15:10.600 | interesting little tables here that show you how much memory you would need for full precision
00:15:15.720 | or half precision training. then gradient accumulation, gradient accumulation is a pretty cool little thing
00:15:21.560 | where you can kind of you know add every couple steps accumulate gradients and then do a loss over the
00:15:27.560 | average of them activation memory is another interesting one um it's it's another thing that you know
00:15:35.320 | activations take memory uh we are running a little slow so i'm gonna go a little quicker but if anyone
00:15:43.160 | wants to double click into any of this stuff um you know we can always come back towards the end or just
00:15:48.520 | interrupt me and we can go into it now um yeah sorry just one quick question and um that uh that graph that
00:15:55.800 | you showed is is um on the um uh for epochs right like so that's the first epoch in uh by the end of
00:16:04.120 | the first epoch right this is just training steps so all the steps okay yeah so steps could be like you
00:16:11.480 | know per batch for whatever after you do the training after you accumulate your loss and stuff after you do a
00:16:17.800 | back prop there's just a little dip so this is per batch um yeah so it kind of depends epochs are just a
00:16:25.080 | broader term for this grouping of these right but similar concept yeah a lot of stuff is one epoch you
00:16:32.520 | don't train as much on the same data anymore kind of interesting but uh same same concept okay uh gradient
00:16:41.320 | accumulation this is a very fun one so as you accumulate these gradients um yeah so epoch meta
00:16:49.960 | might be shifting to four epochs, there's interesting little niches there, so high-quality data more epochs,
00:16:56.360 | pre-training data fewer epochs. one epoch is basically training on a full pass over all
00:17:03.160 | your training data right, so if i had a thousand samples and i train on all my one thousand samples, that
00:17:08.520 | is one epoch uh two epochs would be you know do the thousand samples and do it again and there's like
00:17:14.920 | there was this whole phase of one epoch on a trillion tokens is the new norm i guess new meta is now four
00:17:22.760 | epochs according to some paper okay so um gradient accumulation this is kind of interesting right so
00:17:30.360 | that gradient storing these gradients as you train is pretty important right as we do a forward pass we have
00:17:36.920 | to kind of do a backward pass and calculate the gradients right so we predict some tokens then we
00:17:42.280 | look at okay what was the actual token and how far off are we should we are we moving in the right
00:17:47.240 | direction or are we completely off for every parameter for all these gradients we kind of have
00:17:51.800 | to store them then we have multiple batches then we have to do an optimization now storing this little
00:17:58.040 | gradient um this little these gradients for every single parameter adds up it takes a lot a lot of memory so
00:18:05.400 | what we can do is this gradient accumulation right so instead of doing one entire batch we split our batch
00:18:12.840 | into micro batches we do a forward and backward pass on each micro batch and compute the gradients
00:18:19.560 | so if you have a batch of 16 instead cut it up into small micro batches of four do forward backward passes
00:18:31.560 | and kind of only compute smaller gradients and then kind of accumulate them average them out it really helps
00:18:39.960 | you reduce your memory footprint there's a bit of a overhead right now because instead of fast parallel
00:18:45.560 | computation we now have to take time to do all this gradient accumulation stuff but this is pretty
00:18:51.480 | straightforward, it's what you would expect right. um oh i haven't even looked at this chart but
00:18:57.240 | i guess it shows there's a bit of idle gpu when you're doing some of this accumulation merging.
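A minimal sketch of what gradient accumulation looks like in a training loop (model, dataloader, optimizer, and loss_fn are assumed to already exist; the key points are that each micro-batch loss is scaled down so the accumulated gradient matches the full-batch average, and optimizer.step() only runs every accumulation_steps micro-batches):

```python
accumulation_steps = 4   # e.g. a batch of 16 split into micro-batches of 4

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):
    loss = loss_fn(model(inputs), targets)
    (loss / accumulation_steps).backward()   # gradients accumulate in .grad across micro-batches

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                     # one optimizer update per "real" batch
        optimizer.zero_grad()
```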
00:19:06.360 | so that's kind of your one gpu right, you take your stuff, you figure out how much vram you have
00:19:11.960 | you take your sequences you optimize your batch size um if you can't fit that size then you can do
00:19:18.040 | gradient accumulation with micro batches um there's other little niches like yeah when you have too many
00:19:24.920 | micro batches your loss can get a little spiky because now you're you know you're doing more little
00:19:30.280 | batches so little little niches if you're actually doing it you should start to look into but otherwise
00:19:36.600 | at a high level that's kind of what's um going on here now that's all great one gpu stuff is pretty
00:19:43.960 | easy right there's a bunch of libraries that'll help us do this one gpu fine tuning like full parameter
00:19:50.680 | stuff a lot of companies will do it for us but what happens when i want to start to like use multiple
00:19:56.120 | gpus right there was a whole gpu shortage but now now it's kind of chill like now you can rent multiple
00:20:03.400 | gpus i can go to run pod i can go to prime intellect i can go to a bunch of these companies
00:20:09.080 | and i can rent like two four eight i think run pod announced they're soon gonna have on demand 64 h100s
00:20:15.800 | so let's say i've got like eight h100s or 64 h100s what do i do um i don't want to just you know
00:20:22.840 | do a whole batch size and just run it that's kind of inefficient so that's where stuff like
00:20:27.000 | data parallelism comes into play so data parallelism is where you basically um you put all of the model
00:20:34.840 | weights on every gpu or several gpus then you split up your data and you run different batches on different
00:20:43.400 | gpus so you know in parallel you have all the weights and you do training on different gpus on different
00:20:51.080 | batches then you kind of accumulate them and you do this optimization and you update them all in
00:20:56.440 | parallel there's a bunch of ways that we do this there's zero one zero two all this stuff and we'll
00:21:01.240 | we'll kind of talk about that in a bit but the the interesting other little note here with data
00:21:06.360 | parallelism is there's this whole crypto wave of um distributed training and everyone talks about you
00:21:14.040 | know everyone has one gpu everyone has one macbook, let's go train gpt-5 on everyone's distributed
00:21:21.960 | training. what they're actually doing is data parallelism, so like Nous Research, Prime
00:21:27.800 | Intellect, not to knock the research they did, it's still very interesting and very unique and like
00:21:32.200 | novel and hard to do, um they're doing this distributed training which is, yeah it's not a meme, it's basically
00:21:39.240 | where you have multiple data centers so like you have one node here one in australia one in europe
00:21:44.600 | one wherever and you put the entire weights on them and then you train batches across there it's still
00:21:51.000 | cool it still requires a lot of you know servers and different locations but it is still like the whole
00:21:57.400 | weights have to be on every gpu on you know every server so that's kind of what this one is it's it's
00:22:04.280 | basically as straightforward as you would expect um you throw all the weights and you do different
00:22:10.280 | batches on different stuff and then you have optimization sinking across right so um there's
00:22:19.240 | there's different types there's you can add gradient accumulation you can bucket gradients they talk about
00:22:24.920 | like efficiencies to do this which i think is not what we should cover in an hour you know if you're at
00:22:30.360 | a stage where you're doing this they kind of go over what are the main ways to do it but um yeah
00:22:35.400 | that's kind of the two so we've gone from single gpu training to now we have data parallel where
00:22:40.840 | let's say i have a node of four h100s i want to train a llama 70b a llama 7b what i can do is i can
00:22:47.320 | throw llama 7b on each gpu, split my one trillion tokens, so let's say like a hundred million tokens, split
00:22:54.760 | that into batches, each gpu trains on a batch, and then you kind of efficiently merge them together.
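A minimal sketch of that setup with PyTorch's built-in DistributedDataParallel, assuming one process per GPU launched with torchrun; MyModel and train_dataset are placeholders, not anything from the blog post:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

dist.init_process_group("nccl")                  # one process per GPU, e.g. launched via torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = MyModel().cuda()                         # a full copy of the weights lives on every GPU
model = DDP(model, device_ids=[local_rank])      # gradients are all-reduced across replicas each step

sampler = DistributedSampler(train_dataset)      # each replica sees a different slice of the data
loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)
```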
00:23:01.880 | so this is kind of a summary of what's happened so far, so if you're at this stage here's what you should
00:23:08.920 | do. first, determine the best global batch size by either running experiments or just, you know,
00:23:16.680 | consulting the literature, read their blog post to figure out how you should do this. then select the sequence
00:23:22.040 | length this kind of takes a lot of compute what's done in modern models is you train on a small
00:23:28.360 | sequence length for like 90 of the tokens and then the last 10 you kind of increase so if i'm not
00:23:33.800 | mistaken llama did like 14 trillion tokens at 4 000 sequence length and then the last trillion they did
00:23:40.360 | at like 100k then they you know update it's very common because it takes a lot of compute and this kind of
00:23:45.560 | works to generalize as well but two to eight thousand token length works well for the evaluations we have
00:23:51.720 | today i think this is something that people don't really look into enough right um sure you could do
00:23:57.880 | two to eight thousand tokens because that works for mmlu but if you're training something like a llm judge
00:24:03.960 | classifier you probably don't always need 8 000 tokens right so be smart and train on less tokens if you
00:24:09.400 | need less tokens um then we know the bat size we can find the minimum local bat size for a single gpu
00:24:17.720 | and then kind of increase that till you run out of memory, then determine the number of gpus you
00:24:23.160 | have for data parallelism and you kind of throw your stuff across them so cool then you have gradient
00:24:30.440 | accumulation if you need it but that's kind of basic data um parallelism now we've kind of got this whole
00:24:38.760 | series of how do we improve upon data parallelism. we've got zero one, zero two, zero three. um so what is zero? zero is about redundancy.
00:24:51.000 | actually let me also check chat real quick if there's anything. "there's map reduce to kind of do this
00:24:57.240 | syncing", uh, they have fancy versions of map reduce as well. "numbers are different for reasoning models
00:25:02.920 | with what they're putting out", yeah reasoning models are a little different but that's kind of rl, rl is
00:25:08.440 | at some level similar right you're still just training a transformer um what's happening in that
00:25:14.360 | transformer is somewhat irrelevant okay so zero what's going on with zero um so data parallelism is an
00:25:22.920 | efficient way to scale training the naive replication of optimizer states gradients and parameters across dp
00:25:30.920 | introduces significant memory redundancy right so what's happening here is as you in regular data
00:25:38.600 | parallelism as you split the entire model you also have to at every single gpu do all the optimizer states
00:25:47.320 | all the gradients and you know all the parameters on every gpu and that's kind of redundancy right so
00:25:53.560 | what zero one two and three look into is how can we shard more than just the weights so can we sort of
00:26:01.240 | shard this optimizer state can we start shard the gradient partitioning so and then there's uh zero three
00:26:08.120 | which is kind of interesting where you start to actually shard the parameters themselves so uh zero
00:26:14.760 | one zero two are more efficient ways of data sharding basically instead of just model like instead of
00:26:21.720 | everything on every gpu can we also shard this optimizer state a little bit and then zero two is
00:26:27.480 | can we also shard this gradient a little bit okay back into what's happening in memory so um
00:26:35.320 | in mixed precision training we have parameters we have gradients we have optimizer states which
00:26:41.000 | are stored in higher precision and then gradients in higher precision if we want gradients in higher
00:26:48.200 | precision. so zero one is partitioning the optimizer state. so in regular data parallelism
00:26:56.440 | all ranks gather uh gradients after the backwards pass simultaneously and perform identical optimization steps
00:27:04.840 | that's a lot of duplicated work, we can avoid it and reduce memory use at the same time. so
00:27:10.600 | they have this all-gather, which is kind of a communication step, it's kind of, yeah, let's shard this
00:27:17.800 | optimization state so forward passes happen at the same time backward passes at the same time
00:27:24.680 | perform a reduce-scatter on gradients, each replica performs an optimizer step on its local optimizer
00:27:31.480 | states, then you perform this all-gather to kind of send these missing pieces around. TLDR, they're now just
00:27:40.920 | sharding the optimizer state, and how they do it they go into more details here. now zero two is kind of the next
00:27:48.920 | step, where now not only does each replica shard the optimizer state but
00:27:57.240 | they're also calculating gradients right, so why don't we shard the gradients too. so during
00:28:04.120 | the backward pass, instead of performing an all-reduce we perform a reduce-scatter, and reduce-scatter means
00:28:10.040 | each replica only keeps the gradient shard it needs in memory, so it's more memory saving again. that's cool, they have a chart
00:28:17.400 | at the end of this that kind of shows how all of this works.
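To make the savings concrete, here is a rough sketch of the per-GPU memory math for the three stages, using the same assumption as before (bf16 params and grads, 12 bytes per parameter of fp32 optimizer state, activations and communication buffers ignored):

```python
def zero_memory_per_gpu_gb(n_params_billions, dp_degree, stage=0):
    """Rough per-GPU memory for params + grads + optimizer states under ZeRO.
    Assumes bf16 params (2 B), bf16 grads (2 B), and 12 B/param of fp32 optimizer
    state (master weights + Adam momentum + variance). Activations not included."""
    psi = n_params_billions * 1e9
    params, grads, optim = 2 * psi, 2 * psi, 12 * psi
    if stage >= 1: optim  /= dp_degree   # ZeRO-1: shard optimizer states
    if stage >= 2: grads  /= dp_degree   # ZeRO-2: also shard gradients
    if stage >= 3: params /= dp_degree   # ZeRO-3: also shard the parameters themselves
    return (params + grads + optim) / 1e9

for s in range(4):
    print(s, round(zero_memory_per_gpu_gb(8, dp_degree=8, stage=s)))  # 128, 44, 30, 16 GB
```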
00:28:23.240 | then we have zero three, zero three is the fun one. so if you can, or if you need to, if you're training something like let's say you're doing uh
00:28:30.680 | llama 70b or something post training and you want to crazily increase your sequence length right so you
00:28:37.800 | want to go up to a hundred thousand sequence length well now you're not fitting in even one node this is
00:28:43.800 | where um fsdp starts coming in where you have to shard your parameters themselves so instead of each gpu
00:28:51.160 | having every single all the model weights now you're sharding the model weights across different gpus
00:28:58.040 | sounds pretty interesting right well in zero three you still need good communication between the gpus
00:29:04.200 | there's kind of technical ways that they do this how they do these forward passes these backward passes
00:29:09.480 | later on we'll learn about sharding across different dimensions the way this works is basically
00:29:17.320 | training and all this attention and all of like math all all of this stuff is just matmos right
00:29:24.360 | it's matrix multiplication so it's row times column there's smart ways to do that and manipulate row times
00:29:31.400 | column and yeah they they kind of just make it really easy to do it all this explained what's going
00:29:37.720 | on so if interested read it if you really want to use it pretty much any training library whether that's
00:29:43.720 | pytorch or whatever you're using uh you can you know in a few lines of code just implement this and it
00:29:49.720 | handles all of this for you on the back end. now um yeah, that's kind of the three stages.
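For example, in plain PyTorch the ZeRO-3-style sharding is a thin wrapper; a minimal sketch, assuming the process group is already initialized as in the data-parallel example and MyModel is a placeholder (the wrapping policy and precision settings are things you would tune for a real run):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

model = MyModel().cuda()    # placeholder model, as before
model = FSDP(
    model,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
)
# parameters, gradients and optimizer states are now sharded across ranks;
# full layers are gathered on the fly for each forward/backward pass, then freed
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)   # built after wrapping
```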
00:29:58.360 | let's let's kind of take a break here and see if there's any questions
00:30:02.280 | what's the intuition... okay: for weights, activations, gradients,
00:30:08.280 | and optimizer states, what's the intuition for which ones we want in lower and higher precision?
00:30:15.320 | so um for weights you can it kind of also depends on what you're doing if you're doing training or
00:30:21.320 | inference right um what you're really doing here isn't higher or lower precision there's there's mixed
00:30:27.240 | precision for some stuff like gradients and activations you do them in higher precision
00:30:33.560 | there's more and more stuff coming out about training and lower precision and mixed precision
00:30:39.640 | but the general um you know general rule of thumb is you want to use as much high precision as you can
00:30:47.880 | afford like half precision is cool for inference it's different you can do a lot of quantization but for stuff like um
00:30:55.800 | gradients or stuff that has not a lot of compute uh not a lot of memory overhead you want to use higher
00:31:02.280 | precision right so gradients are small they don't need a lot of memory do them in high precision model
00:31:08.360 | weights uh let's try to lower the precision a little bit. but yeah, okay.
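On the practical side, a minimal sketch of mixed-precision training in PyTorch (bf16 compute via autocast while the fp32 weights underneath keep getting updated; model, dataloader, optimizer, and loss_fn are assumed to exist):

```python
import torch

for inputs, targets in dataloader:
    optimizer.zero_grad()
    # forward pass runs the matmuls in bf16; the parameters themselves stay fp32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model(inputs), targets)
    loss.backward()      # .grad on the fp32 parameters is fp32
    optimizer.step()
    # with fp16 instead of bf16 you would also use torch.cuda.amp.GradScaler for loss scaling
```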
00:31:17.160 | now we have tensor parallelism. tensor parallelism is a way of more splitting, it's kind of splitting up model weights across their hidden
00:31:24.040 | dimension so uh data parallelism was can we shard the weights this is a way of splitting it up across um
00:31:33.640 | tensors. so you know here's the math of what's happening, you're basically doing
00:31:40.120 | matmuls, rows and columns, here's a smart way to do it, here's a smart way to reduce it, here's what
00:31:47.080 | we're doing in a transformer block um a lot of this is kind of the background behind stuff like flash
00:31:53.080 | attention group query attention uh this is used in papers like llama 3 and whatnot but high level this
00:32:00.840 | is just kind of another way of sharding your model. it's done column-wise or row-wise, um there's
00:32:10.520 | there's two different ways to do it there's different communication overhead in both so whether you're
00:32:16.280 | doing column wise or row wise, um, there's different benefits to both, you should look into
00:32:23.320 | them. tensor parallelism in transformer layers is another interesting section that'll help you make progress.
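A toy single-process sketch of the underlying math, with no real GPUs or communication involved: it just shows that splitting the weight matrix column-wise or row-wise and recombining gives the same result as the full matmul, which is what tensor parallelism exploits.

```python
import torch

x = torch.randn(4, 8)     # (batch, hidden_in)
W = torch.randn(8, 16)    # full weight (hidden_in, hidden_out)

# column parallelism: each "GPU" holds half the output columns, outputs are concatenated
W0, W1 = W[:, :8], W[:, 8:]
y_col = torch.cat([x @ W0, x @ W1], dim=-1)

# row parallelism: each "GPU" holds half the input rows, partial outputs are summed (an all-reduce)
x0, x1 = x[:, :4], x[:, 4:]
R0, R1 = W[:4, :], W[4:, :]
y_row = x0 @ R0 + x1 @ R1

assert torch.allclose(y_col, x @ W, atol=1e-5)
assert torch.allclose(y_row, x @ W, atol=1e-5)
```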
00:32:30.600 | yeah that's in there. i think for time's sake we're gonna start going a little quicker. um, the
00:32:37.800 | interesting thing here is kind of this is where as you're all in one node all this stuff kind of works
00:32:43.400 | right you can do different layers on different gpus but the more you do this the more important the
00:32:49.560 | interconnect between gpus is right so if you're on a single node of highly interconnected gpus this stuff
00:32:57.160 | will start to be fine but the minute you're using pcie cards or multiple nodes you're kind of adding too
00:33:03.240 | much bottleneck for any of this to be worth it. um yeah, next we kind of have sequence parallelism, i believe, and
00:33:12.520 | context parallelism. context parallelism is... or, we went through sequence parallelism, sequence parallelism is
00:33:18.360 | kind of the same thing, um so instead of tensor parallelism split it across the sequence. this is where stuff
00:33:26.200 | like ring attention starts to come in i think a while ago we had a paper club on ring attention so
00:33:32.360 | for context and sequence parallelism if instead you know you can't scale your context length uh you want
00:33:39.480 | to take these 128 000 sequences and you want to split them across gpus you want to let's say have like you
00:33:47.000 | know first 10 000 tokens on gpu one next 10 000 on gpu two and so on um there's kind of this cool little
00:33:55.480 | way that you can do this so you can you know this goes into the math about what's happening but
00:34:01.400 | it's kind of it's kind of so conceptually you think what is attention attention is the relationship
00:34:07.560 | between all tokens right you still need to calculate this so if you have a hundred thousand tokens we need
00:34:13.880 | to know how does each token relate to every other token which you know you would kind of expect that to
00:34:19.480 | have to happen on one gpu you need to look at all tokens that wants to do attention instead you can kind of
00:34:25.560 | shard them and do this ring attention where you you have each gpu do part of the attention mechanism
00:34:31.880 | you kind of sync them you send qks between each other then you do this big query uh value calculation
00:34:38.040 | and now you've kind of done attention across gpus um there's there's ring attention there's zigzag
00:34:44.280 | attention there's little optimizations for this stuff but it's it's pretty cool it was a lot of like
00:34:49.960 | how do we get this infinite context how can we do um you know long context this is kind of what happened
00:34:56.360 | here um very very interesting stuff i think it's a little underrated so basically reading a little
00:35:03.160 | about it is pretty cool: context parallelism and sequence parallelism.
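A toy sketch of the core trick, everything on one device here: the point is just that attention for a local query block can be computed by streaming over K/V chunks with a running (online) softmax, which is what each GPU does as chunks arrive from its ring neighbours in ring attention.

```python
import math
import torch

def blockwise_attention(q, kv_chunks):
    """Attention for one local query block, consuming K/V one chunk at a time
    with a running softmax; numerically equal to full attention."""
    d = q.shape[-1]
    acc = torch.zeros(q.shape[0], kv_chunks[0][1].shape[-1])
    row_max = torch.full((q.shape[0], 1), float("-inf"))
    denom = torch.zeros(q.shape[0], 1)
    for k, v in kv_chunks:                      # in ring attention, chunks arrive from peer GPUs
        scores = q @ k.T / math.sqrt(d)
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)   # rescale previous partial results
        p = torch.exp(scores - new_max)
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v
        row_max = new_max
    return acc / denom

q = torch.randn(4, 16); k = torch.randn(32, 16); v = torch.randn(32, 16)
full = torch.softmax(q @ k.T / math.sqrt(16), dim=-1) @ v
chunked = blockwise_attention(q, list(zip(k.chunk(4), v.chunk(4))))
assert torch.allclose(full, chunked, atol=1e-5)
```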
00:35:11.480 | okay, next is pipeline parallelism. pipeline parallelism is a pretty fun one, so this is where you distribute consecutive layers, so
00:35:18.600 | instead of just sharding gradients or optimizations or sequences how about we start uh sharding different
00:35:25.160 | layers so let's have like you know layer one two three on this gpu and different layers on different
00:35:31.480 | gpus this once again it's like another approach to fitting really large models it it lets you reduce gpu
00:35:38.200 | overhead for everything but once again like now they go into as you do this how do you continue training
00:35:45.480 | right so they have these like all forward all backward so as you do your forward pass you need
00:35:52.040 | to do a backward pass compute gradients to optimization so once you split up layers they've kind of got
00:35:58.600 | here's what's happening as we want to do optimization right so all forward all backward um there's there's a
00:36:05.320 | few other ways of this one forward one backward this is what llama three did um very interesting little
00:36:12.120 | optimizations to make this stuff work, um, interleaving stages, so if you've got different slices there's
00:36:21.000 | just different recipes for all this, zero bubble, DualPipe.
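A naive sketch of the basic idea, just placing consecutive layers on different GPUs and shipping the activation across; real pipeline parallelism adds the micro-batch schedules mentioned above (all-forward-all-backward, 1F1B, interleaving, zero bubble) so the GPUs are not idle waiting for each other:

```python
import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    """Consecutive layers split across two GPUs: the 'vertical' cut behind pipeline parallelism."""
    def __init__(self, hidden=1024, layers_per_stage=8):
        super().__init__()
        block = lambda: nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_stage)])
        self.stage0 = block().to("cuda:0")
        self.stage1 = block().to("cuda:1")

    def forward(self, x):
        x = self.stage0(x.to("cuda:0"))
        x = self.stage1(x.to("cuda:1"))   # the activation is what crosses the GPU boundary
        return x
```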
00:36:29.560 | um yeah, expert parallelism is kind of a more straightforward one, where you have different experts in a mixture of experts model on different gpus,
00:36:35.800 | that one's pretty straightforward right so you kind of have a base router you have every gpu have its
00:36:41.720 | own expert in some cases you can have multiple gpu multiple experts on multiple gpus and you kind of
00:36:48.440 | have this efficient expert parallelism so once again now there's this kind of overview summary so
00:36:55.880 | what's data parallelism, it's along the batch dimension. tensor parallelism is along the hidden
00:37:00.920 | dimension. sequence and context parallelism are if you need to extend context length and
00:37:07.800 | you want to let's say train your model up to 500 000 token context how can we do that on a limited gpu
00:37:15.080 | pipeline parallelism is along model layers and then expert parallelism is around mixture of experts
00:37:21.880 | um then we've also got we can combine these with the zero stages right so zero one is sharding optimizers
00:37:28.920 | zero two is sharding optimizers and gradients zero three is sharding um parameters alongside the other two
00:37:35.880 | so this is kind of the high level if you have multiple gpus here are the things that you can do for
00:37:43.480 | each one they kind of have you know here's how you double click into what you want to do so let's say you're
00:37:49.240 | doing pipeline parallelism and you're sharding everything and you then want to extend your
00:37:54.920 | sequence length right so let's go back into that section um you can come in and find out what are
00:38:00.760 | what are different interleaving stages how do you want to optimize all this um what they go into here is
00:38:07.080 | kind of still all at the you're not really doing research here you're kind of still at the applied
00:38:12.760 | researcher stage right all this stuff is still available in libraries like pytorch so
00:38:17.720 | there are cuda kernels that optimize this stuff like flash attention is made that already uses all of this
00:38:17.720 | these are things that you can just kind of consume as you're doing your training. if you're pushing the
00:38:24.280 | boundary, like um xai now has a 200,000-GPU cluster right, or gpt-4.5 was a very chunky model, that's where they have to
00:38:39.480 | start to push on these ideas and develop even more novel um training strategies but yeah that's kind of
00:38:49.160 | their high level um what's going on in different parallelization strategies after that we've kind of
00:38:57.400 | got here's a recipe for finding the best um training config right so if you kind of skip everything
00:39:05.080 | the original intro is kind of important right how do we train on one gpu how do we train on two how do
00:39:10.200 | we train on a single node then the middle is kind of okay now your gpu rich you're probably smarter than
00:39:15.400 | me how do you train across nodes um some of the reasons for some of these as well by the way that we
00:39:20.680 | skipped over was if you have multiple nodes and you don't have interconnect some of these different
00:39:26.280 | parallelization strategies will work better they they better optimize for not having that interconnect
00:39:32.920 | between all nodes um so you know there's recipes for that as well but okay high level let's let's sum it
00:39:40.840 | all up here again so if you want to find your best training config how do you do it so step one uh
00:39:47.240 | fitting a training step into memory right let's figure out how we can fit a full model instance
00:39:53.240 | on the gpu. there's several use cases, whether you're gpu rich or gpu poor. if you're gpu poor, um you know
00:40:01.320 | there's full activation re-computation, there's gradient accumulation, there's stuff like that, you can train
00:40:06.840 | slower. um basically you want to increase your gradient accumulation steps to process larger batch sizes
00:40:13.800 | and you kind of do what you do for gpu rich stuff if you have models under 10b use single parallelization
00:40:21.400 | stuff so tensor parallel with um zero three across eight gpus this is kind of where you shard those
00:40:28.360 | model weights across eight gpus for stuff between 10 and 100 billion parameters you'll need more than
00:40:34.040 | eight gpus right they keep saying eight gpus because this is basically one node if you have stuff under 10b and
00:40:40.360 | one node use fsdp with zero three if you have 10 to 100b you're going to need multiple nodes so this is
00:40:47.960 | where you're going to start to do tensor parallelism and pipeline parallelism do tensor parallelism with data
00:40:53.240 | parallelism uh only use zero three then when you're rich with with 512 or a thousand gpus um you know you
00:41:02.920 | should not be reading this but there's more stuff that you can do here. um, gpu poor, we talked
00:41:08.840 | about it basically gradient accumulation find out what batch size can you support split batches do micro
00:41:16.120 | batches but don't overdo it um yeah step two achieve the target batch size so depending on where step one
00:41:23.400 | left us uh in terms of micro batches and our data parallelism our current batch size might be too small or too big
00:41:29.640 | right so what you want to do is kind of scale up data parallelization or accumulation step so you know
00:41:35.480 | keep doing gradient accumulation until you're no longer running out of memory and if you want to
00:41:41.480 | start to do long context add in context parallelism so this kind of ring attention stuff to decrease our
00:41:47.880 | global batch size we can reduce data parallelism and we can reduce context parallelism if you have more gpus
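As a rough sketch of the bookkeeping for this step (the numbers below are made up for illustration): the global batch size in tokens is micro-batch size x gradient-accumulation steps x data-parallel degree x sequence length, and those are exactly the knobs being traded off against each other here.

```python
micro_batch_size = 4      # sequences per GPU per forward/backward pass
grad_accum_steps = 8
dp_degree        = 64     # number of data-parallel replicas
sequence_length  = 4096

global_batch_tokens = micro_batch_size * grad_accum_steps * dp_degree * sequence_length
print(global_batch_tokens)   # 8_388_608, about 8M tokens per optimizer step
```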
00:41:54.920 | okay step three optimizing training throughput set up tensor parallelism using interconnect if you have it
00:42:03.080 | then there's uh increased data parallelism so you know if you have more batches if you have more gpus
00:42:10.760 | yeah use data parallelism and throw the whole weight on them keep the target batch size as big as you can
00:42:17.720 | um try scaling up different parallelism parallelisms one by one and just keep experimenting uh benchmarking
00:42:27.400 | they ran a lot of benchmarks i don't know if anyone found anything interesting in any of these that we
00:42:31.960 | want to dive into but i want to leave 10 minutes for questions so um i'm going to skip through a good
00:42:37.560 | bit of this uh diving into gpus fusing threading so what is gpu how does gpu work you know there's
00:42:45.080 | high bandwidth memory there's interconnect there's global stores there's caching there's kv caching
00:42:51.080 | there's cuda kernels and stuff that optimizes all this. um i don't think this is as relevant to the high
00:42:59.080 | level of if you want to train models across gpus, what do you do. so if you're interested read it, it's
00:43:06.360 | pretty useful stuff but at a high level i think it's a little too niche for our um here's kind of
00:43:14.360 | how matmuls happen, let's see, i think there's a conclusion at the end of this. flash attention
00:43:20.600 | is um you know cuda optimizations for all this ring attention is one mixed precision training is another
00:43:27.080 | interesting one um mixed precision training is one thing but mixed precision inference is another we have a
00:43:33.800 | lot more on uh mixed precision inference but for training um they kind of go into accumulation what
00:43:42.360 | to have different precision for different training stuff. fp8 pre-training is kind of interesting, some
00:43:48.040 | people are doing it, we want to see if it works. um recent research including fp8-lm, deepseek v3,
00:43:55.480 | which is an interesting one um it shows potential that's kind of huge so the the big thing with
00:44:02.040 | precision is basically as you cut your precision in half you also cut the vram needed to train in half so
00:44:09.480 | full precision versus half precision for a 600b model means you know i can now cut off 600 gigs of
00:44:16.760 | vram so pretty big stuff with precision okay conclusion uh we know we now know how to train
00:44:24.280 | we now know how stuff was efficiently trained, like llama 405b and deepseek v3, on multiple gpus. uh once again they
00:44:31.000 | kind of have their their cheat sheets if you want cheat sheets i think it's good to go over this again um
00:44:38.200 | there's a lot of stuff about fp8 versus int8 and how they're not the same, so
00:44:43.480 | quantization, yeah, um you should look a little bit deeper into stuff like that. um llama 3's 4d parallelism,
00:44:53.320 | what else um a lot of papers a lot of references um yeah pretty pretty solid interesting little
00:45:02.040 | high level overview of how to train it i think uh as a quick recap we should just remember that they
00:45:08.760 | have this you know here's how to fit stuff into memory pick your model size pick your gpus
00:45:14.760 | pick your batch size and optimize your training throughput until stuff doesn't crash here are the
00:45:21.000 | different parallelization strategies we talked about right so data parallel everything on one
00:45:25.400 | zero one, zero two, zero three is where you start sharding the optimizer states, gradients, and eventually the weights, tensor parallel across hidden
00:45:31.480 | dimensions pipeline parallel across model layers context parallel for sequence length expert parallel for
00:45:39.560 | mixture of experts um pretty nice overview i think you know for the ai engineer crowd even if you don't
00:45:46.920 | have five hundred or a thousand gpus just kind of familiarizing yourself with what's going on what
00:45:53.800 | training like what parts of model and pre-training takes how much gpu um you know like yeah it's not
00:46:00.920 | all attention, there's these feed-forward layers, there's these activation states,
00:46:05.880 | there's these gradients we have to calculate, how those actually take up more than half of your
00:46:10.040 | training memory. all that stuff is pretty interesting to know. um then the one thing
00:46:16.600 | that this blog might have missed, or it didn't do for me, was it didn't mention that you know when you
00:46:22.840 | actually do inference and you consume this stuff um you're actually not accumulating a lot of these
00:46:28.920 | gradients let me find this little diagram again so for example if we're doing let's say um
00:46:35.240 | um llama oh that's cool this doesn't want to work
00:46:41.400 | let me refresh so if we're doing
00:46:49.240 | llama
00:46:56.040 | where'd it go, there it is. if we're doing llama 3 8b at a sequence length of, say,
00:47:03.960 | a thousand, um, all these parameters and gradients that are taking a hundred gigs of memory, this optimization,
00:47:15.800 | these optimizers you know 70 gigs of vram we don't do this during inference we don't have to do all this
00:47:20.680 | back prop so the attention is you know some of it but all this other stuff for training is actually a
00:47:27.080 | lot less and we can quantize better so it didn't it didn't hit on that as much but yeah high level
00:47:32.520 | good read, kind of throw this thing in chatgpt or claude or whatever your llm of choice is,
00:47:40.680 | say you're reading this section, have it break down each section for you, spend an hour on it. it's useful to
00:47:46.120 | know for the ai engineer um kind of understanding the difference between training and inference is
00:47:51.800 | always useful and i think um you know a good rule of thumb is like if you're hiring a research intern
00:47:59.320 | would they be able to explain all of this to you hopefully pros and cons trade-offs different multi-node
00:48:07.800 | single node single gpu um if they could that's great and then you should be able to do it yourself
00:48:14.520 | too right it's just good general knowledge so pretend you're in an interview be able to learn all this
00:48:18.760 | stuff it's good background knowledge but yeah um that's kind of high level overview of this big big
00:48:25.640 | pre-training blog um questions thoughts comments we have five minutes i'm sure i skipped out a lot of
00:48:32.440 | stuff i'm sure people have a lot to add in uh out of curiosity what precision or state of the art model
00:48:39.720 | is trained in mostly uh higher precision for gradients so there there is mixed precision as of recently
00:48:46.600 | um you know we're cutting that down, i think it said that v3 was trained in fp8 which is kind of interesting
00:48:54.840 | but there's still a lot of precision. um yeah, that's the paper. thoughts, comments, questions?
00:49:08.120 | who wants to train 70b from scratch knowing all this are we prepared to blow thousands of dollars an hour
00:49:22.120 | i think it's a good experiment to run like uh realistically h100s you can rent at two dollars
00:49:30.760 | an hour right so rent an h100 train a 7b rent two h100s on different like you know non um non good
00:49:40.040 | interconnect try some of this stuff because it's very approachable from like common libraries to do this
00:49:45.560 | training then rent a node then rent two nodes and like realistically you only have like a few different
00:49:51.560 | parallelization strategies right um a precursor to this blog from hugging face is fine web it goes
00:49:57.800 | over like i think five trillion maybe 15 trillion pre-training tokens so yeah go go learn um very
00:50:04.200 | useful to learn and just try all this stuff and see what's different and burn burn not too much money
00:50:09.480 | right, two dollars an hour for h100s, eight h100s twenty bucks an hour, like we can afford this
00:50:21.000 | um papers that can help me do this for non-english languages um the presenter from last time works on
00:50:29.560 | a state-of-the-art german model not german some some multilingual model so go to paper club someone is
00:50:36.760 | apparently an expert in non non um english models cheapest h100 provider on the internet crypto
00:50:43.240 | bullshit. the crypto companies are raising a lot of money and subsidizing gpus to be very very cheap.
00:50:49.640 | deepinfra has them for two dollars, modal will give you money, a lot of people will also sponsor compute, but
00:50:55.640 | also um hyperbolic is a good one shout out hyperbolic i think uh it's also pretty cheap right like two
00:51:01.720 | dollars an hour, three dollars, four dollars. sf compute is not, um, serverless i believe, sf compute is bigger
00:51:09.640 | clusters for longer time slots. i like the guys. but um as runpod and stuff adds 32 h100s, 64, i don't know if we need
00:51:19.000 | but yeah cool um what other questions questions are on where to get gpus
00:51:28.440 | pay money that's not that expensive very subsidized right now like to a point where somewhat more
00:51:35.000 | subsidized than um
00:51:36.360 | than compute people don't like hyperbolic apparently so never mind no hyperbolic we don't like them i guess
00:51:45.960 | um but they are slightly cheaper than than the others but we don't like them i guess okay volunteers for
00:51:51.400 | next week um who wants to cover paper next week there is a cool pre-training uh post-training survey
00:51:59.240 | uh this is uh i don't know what is um i don't know if it's good uh qw32b came out i don't know if it's
00:52:05.720 | thinking the same name oh is is the is the paper good i haven't looked i don't think there's a paper if
00:52:11.240 | the paper comes out that would be cool i don't know no paper at all blog post which was shitty blog post
00:52:17.720 | 4496 word blog post you know thought they had a very sick title on this uh embracing the power of rl
00:52:17.720 | 4496 word blog post, you know, though they had a very sick title on this, "embracing the power of rl",
00:52:38.760 | survey paper on post training very cool little chart here um i think it's a good one if someone wants to
00:52:46.840 | cover it okay i'm gonna volunteer this one to someone if someone's down otherwise um yeah if
00:53:03.000 | there's no more questions oh actually this is pretty short so uh 20 page on post training very useful
00:53:10.040 | graph of thoughts wait these do cover mcts
00:53:14.520 | why not mcts is great dude you know how much of a time sink this was you know how much time was wasted
00:53:22.600 | on mcts. who remembers strawberry, who remembers multion saying that they are q-star. strawberry, mcts,
00:53:30.280 | waste of time but you know pretty good and i think it's good background right so mcts is kind of
00:53:39.960 | tree search throughout potential next tokens can we span out and do um inference time scaling test time
00:53:48.040 | compute instead of rl a bunch of people wasted a bunch of time on this uh some companies said that
00:53:54.200 | they are q star by by doing this but yeah okay another um another paper that someone has recommended
00:54:03.400 | that we don't know if it is credible the fft strikes back an efficient alternative to self-attention
00:54:10.200 | from one guy at usc is he plugging his own paper
00:54:14.360 | i don't know uh alternatives to attention are always interesting but okay um that's that's paper
00:54:23.960 | club this week guys super nicely done thank you so much you covered two days of reading in one hour
00:54:31.000 | always need to cover quick
00:54:34.040 | in general yeah i i feel like there needs to be more good papers like i i feel like we uh paper
00:54:42.680 | i feel like paper velocity has slowed somehow yeah i mean this isn't even a paper on anything of their own
00:54:51.800 | this is just like uh okay this counts this counts okay very very good survey okay well that's paper club
00:55:00.280 | guys all right take care thank you