Stanford CS25: V1 | Mixture of Experts (MoE) paradigm and the Switch Transformer
Chapters
0:07 Scaling Transformers through Sparsity
0:25 Overall Motivation
1:00 Scaling Laws for Neural Language Models
5:05 Switch Transformer
7:23 Improved Training Methodology
8:27 Differentiable Load Balancing
8:55 Selective Precision
10:25 The Initialization Scale
16:23 Multi-Stage Routing Procedure
20:35 What Is the Research Question
29:50 Perplexity versus Training Time
30:53 Sparse Scaling Laws
37:10 Data Parallelism
37:36 Model Parallelism
38:51 Expert and Data Parallelism
39:31 Model Partitioning
40:37 Mesh Abstraction
46:18 Fine-Tuning Properties of Sparse Models
47:08 Multilingual Training
47:45 Distillation
00:00:00.000 |
Today, Erwin and I are going to be giving a talk on scaling transformers through sparsity. 00:00:10.460 |
And the kind of sparsity we're going to be talking about today is the kind where each 00:00:13.820 |
input can get either a different set of weights or have a different amount of computation applied to it. 00:00:22.960 |
So, I guess the overall motivation for this line of work is that the community has realized 00:00:32.360 |
that scale is perhaps one of the most important axes to focus on for obtaining strong performance. 00:00:40.240 |
And there's almost like this ongoing arms race right now with different labs and different 00:00:45.720 |
institutions competing to train the largest models. 00:00:53.440 |
And so, maybe this dates back to early 2020 with a paper from OpenAI called Scaling Laws 00:01:01.360 |
for Neural Language Models, where they find that model performance follows a predictable 00:01:08.600 |
power law with model size, measured either in compute or just in parameters. 00:01:22.080 |
And so, this scaling law generalizes over multiple orders of magnitude, and that gives 00:01:28.620 |
us the confidence that if we are to train very large models, we can expect a certain 00:01:35.060 |
performance just by extrapolating these scaling laws. 00:01:40.120 |
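(As a rough reference for the power law being described: the Kaplan et al. (2020) scaling-law paper reports a parameter-count law approximately of the form below; the constants are quoted from that paper and are approximate.)

```latex
L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N},
\qquad \alpha_N \approx 0.076,\quad N_c \approx 8.8 \times 10^{13}\ \text{non-embedding parameters}
```

Analogous power laws are reported for compute and dataset size, which is what makes the extrapolation mentioned here possible.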
So, in that paper, they also find the interesting observation that basically larger models are more sample efficient. 00:01:52.200 |
And so, if you have a fixed compute budget, you can predict what is 00:02:03.120 |
the optimal model size for that compute budget. 00:02:06.440 |
And the overall observation is that you'd rather train very large models for fewer steps 00:02:15.640 |
than train smaller models for more training steps. 00:02:19.680 |
And so, these models are scaled, well, basically the paper focuses on dense models, where you 00:02:28.640 |
just increase the model dimensions, but they're not looking at sparsity. 00:02:34.480 |
And so, sparsity is a new dimension that you can use to scale architectures, and this is what we'll focus on today. 00:02:45.240 |
And so, the sparsity we're mentioning here is basically that you will have sparsely activated weights. 00:02:56.560 |
So, every input will go through a roughly similar amount of computation, but will have different weights applied to it. 00:03:05.840 |
And so, this dates back to 1991 with a paper called Adaptive Mixtures of Local Experts, 00:03:14.040 |
and was recently revisited by Noam Shazeer and colleagues at Google Brain with LSTMs, 00:03:21.720 |
where they replaced sort of the feed-forward networks in LSTMs with a mixture of experts. 00:03:30.800 |
And so, the way this works there roughly is that you will have multiple experts each implementing 00:03:37.880 |
a small network, or in that case, I think just a dense matrix multiplication. 00:03:46.040 |
And so, you have an additional gating network shown in green here that outputs a probability 00:03:54.160 |
distribution over experts that each token should be sent to. 00:04:00.600 |
So, this probability distribution is computed as a softmax, and once you have it, you select the top expert or top few experts. 00:04:11.080 |
So, there are different strategies, maybe we'll talk about it later on. 00:04:16.120 |
And the output is simply sort of the weighted mixture of all selected experts' outputs. 00:04:22.440 |
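(For concreteness, a minimal sketch, not the speakers' actual code, of this kind of gating for a single token: a softmax router picks the top-k experts and the output is the probability-weighted mixture of their outputs. Each expert here is just a dense matrix multiply, and all names and shapes are illustrative.)

```python
import jax
import jax.numpy as jnp

def moe_layer(x, gate_w, expert_ws, k=2):
    """Top-k mixture-of-experts for a single token vector x (illustrative only).

    x:          [d_model] token embedding
    gate_w:     [d_model, num_experts] gating network weights
    expert_ws:  [num_experts, d_model, d_model] one dense matrix per expert
    """
    probs = jax.nn.softmax(x @ gate_w)             # router distribution over experts
    top_idx = jnp.argsort(probs)[-k:]              # indices of the k most likely experts
    top_p = probs[top_idx] / probs[top_idx].sum()  # renormalize over selected experts
    expert_outs = jnp.einsum('d,kdf->kf', x, expert_ws[top_idx])
    return (top_p[:, None] * expert_outs).sum(axis=0)  # weighted mixture of expert outputs

# Toy usage with random weights.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
d_model, num_experts = 8, 4
x = jax.random.normal(k1, (d_model,))
gate_w = jax.random.normal(k2, (d_model, num_experts))
expert_ws = jax.random.normal(k3, (num_experts, d_model, d_model))
y = moe_layer(x, gate_w, expert_ws)   # mixture of the two selected experts' outputs
```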
So, they've been pretty successful, primarily in translation, but there were some complexities that limited broader adoption. 00:04:41.880 |
And so, the Switch Transformer paper addresses some of those, and we'll be discussing how 00:04:49.880 |
to fix training instabilities or reduce communication costs and reduce model complexity. 00:05:02.920 |
So, one kind of approach that we're going to have for sparsity is the Switch Transformer, 00:05:07.520 |
which is kind of like a simplified mixture of expert variant along with some other improved 00:05:13.080 |
training and fine-tuning techniques that allow it to be stably trained and also perform better 00:05:19.560 |
when fine-tuned on a lot of downstream tasks. 00:05:22.880 |
And so, yeah, so the Switch Transformer kind of model works as the following. 00:05:27.320 |
So, you have some transformer model that has self-attention and feed-forward layers. 00:05:32.880 |
And the idea is that we replace maybe one every two or one every four feed-forward layers with a sparse Switch layer. 00:05:39.960 |
So, you can see on the left is like one kind of layer block, which is self-attention, then 00:05:46.360 |
add normalize, then a feed-forward layer, then add normalize. 00:05:49.400 |
And in this case, we're replacing the normal feed-forward layer with the Switch layer. 00:05:53.920 |
And we can see an illustration of this on the right. 00:05:57.000 |
So, on the right, we can see that the layer has two inputs. 00:06:01.320 |
One is the token "More", the other is the token "Parameters". 00:06:04.480 |
And we can see that these embedding representations will get sent to a router, which is exactly the gating network we mentioned before. 00:06:10.960 |
So, the router is basically just going to be getting a distribution over all of the experts. 00:06:16.240 |
So, in this case, we can see that the highest probability is going to the second expert for the left token. 00:06:23.080 |
And then the right token is actually having the most probability on the first feed-forward expert. 00:06:30.040 |
So, yeah, we can see here that what we're going to do in the Switch Transformer 00:06:34.760 |
is just send each token to its highest probability expert. 00:06:38.040 |
And so, here we can see where the adaptive computation lies, where we'll have four sets of expert weights. 00:06:43.480 |
There's some shared weights and computation across all the tokens. 00:06:46.560 |
For example, the self-attention layer is computed exactly the same for the "More" token and for 00:06:52.920 |
the "Parameters" token. But in the sparse Switch layer, we can see that actually the inputs, while having 00:06:56.720 |
the same amount of floating point operations applied to them, actually have different weight matrices applied to them. 00:07:06.160 |
So, that's the kind of high-level idea with Switch Transformer: instead of sending 00:07:10.920 |
a token to multiple different experts, which can also increase the communication costs, 00:07:14.940 |
as I'll go into a little bit later, we send each token to only a single expert, which also just significantly simplifies the algorithm and implementation. 00:07:22.120 |
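(And a matching sketch of the top-1 Switch routing idea for a batch of tokens: each token is sent to exactly one expert feed-forward network, and that expert's output is scaled by the selected router probability. This dense-gather version is for clarity only; a real implementation dispatches tokens into per-expert buffers, and the names here are mine.)

```python
import jax
import jax.numpy as jnp

def switch_ffn(tokens, router_w, w_in, w_out):
    """Top-1 (Switch) routing over a batch of tokens (illustrative sketch).

    tokens:   [num_tokens, d_model]
    router_w: [d_model, num_experts]
    w_in:     [num_experts, d_model, d_ff]   first dense layer of each expert FFN
    w_out:    [num_experts, d_ff, d_model]   second dense layer of each expert FFN
    """
    probs = jax.nn.softmax(tokens @ router_w, axis=-1)               # [T, E]
    expert_idx = jnp.argmax(probs, axis=-1)                          # top-1 expert per token
    gate = jnp.take_along_axis(probs, expert_idx[:, None], axis=-1)  # selected probability

    # Apply only the chosen expert's FFN to each token.
    h = jax.nn.relu(jnp.einsum('td,tdf->tf', tokens, w_in[expert_idx]))
    out = jnp.einsum('tf,tfd->td', h, w_out[expert_idx])
    return gate * out  # scale by the router probability of the chosen expert
```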
So, for the improved training methodology, we focused on three different things to help these models train well. 00:07:29.560 |
The first was selective precision, which allows these sparse models to be trained in lower 00:07:33.400 |
precision formats, which is incredibly important. 00:07:36.560 |
Most of the models we train, we really don't want to be using float32, because it's just twice as expensive. 00:07:41.640 |
And also, when you're communicating tensors across different processes and stuff, it's 00:07:45.680 |
twice as slow, just because there's twice as many things. 00:07:49.120 |
Also, we have some initialization tricks and some training tricks as well for allowing 00:07:53.160 |
them to be trained more stably, especially as the models grow in size, which is like 00:07:57.200 |
a new initialization method, along with a change to the learning rate schedule. 00:08:02.800 |
And third, since our models have so many more parameters, we do notice definitely different 00:08:07.260 |
overfitting dynamics, especially once we fine-tune these models that have been pre-trained 00:08:12.160 |
on all of the internet on these small tasks with maybe only 50 to 100,000 examples, where they can overfit. 00:08:19.080 |
So we also look at some custom regularization to help prevent some of the overfitting that occurs. 00:08:26.400 |
And finally, we also talk about this differentiable load balancing technique we use, which kind 00:08:31.520 |
of allows each expert to roughly get the same amount of tokens. 00:08:36.200 |
Because this is very important, especially given that we want this stuff to be efficient on hardware. 00:08:40.480 |
We want roughly each expert to have similar amounts of tokens sent to it. 00:08:44.080 |
And so to kind of encourage this, we tack on an additional load balancing loss along 00:08:49.040 |
with our cross-entropy loss that we're training with. 00:08:54.580 |
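(The auxiliary loss described in the Switch Transformer paper has the form alpha * num_experts * sum_i f_i * P_i, where f_i is the fraction of tokens dispatched to expert i and P_i is the mean router probability assigned to expert i; it is minimized when both are uniform. A small sketch, with variable names of my choosing:)

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(router_probs, expert_idx, num_experts, alpha=0.01):
    """Auxiliary loss that encourages a roughly uniform token-to-expert split.

    router_probs: [num_tokens, num_experts] softmax outputs of the router
    expert_idx:   [num_tokens] argmax expert chosen for each token
    """
    dispatch = jax.nn.one_hot(expert_idx, num_experts)   # hard assignments
    f = dispatch.mean(axis=0)       # fraction of tokens sent to each expert
    p = router_probs.mean(axis=0)   # mean router probability per expert (differentiable)
    return alpha * num_experts * jnp.sum(f * p)
```

This just gets added to the cross-entropy loss, with alpha kept small (on the order of 1e-2) so it nudges the router toward balance without dominating the language-modeling objective.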
So here, I'm going to go into selective precision. 00:08:56.520 |
So yeah, again, so when we're training large models, it's really important that we should 00:09:00.080 |
be able to train them in lower precision formats. 00:09:02.080 |
So instead of each weight and activation being 32 bits, we want to shrink it down to 16 bits with bfloat16. 00:09:11.220 |
And what we found out of the gate is that these models are just unstable, especially 00:09:16.240 |
the sparse models are much more unstable than the dense models in terms of you'll train 00:09:19.700 |
it for 10,000, 20,000 steps, and then the losses would just diverge. 00:09:23.040 |
This was something that we frequently encountered. 00:09:25.620 |
And so one key thing that we found is that basically, you need to be casting a part of 00:09:30.740 |
the computation in float32 for these models to be able to be trained stably. 00:09:38.180 |
And the key component that we found that you need to cast is the router computation. 00:09:42.840 |
And essentially, we can go into the technical details a little bit more later. 00:09:47.060 |
But basically, any time that there's these exponentiation functions, it's very important 00:09:51.140 |
that we are having higher precision because of round-off errors that can then 00:09:56.720 |
drastically change the output of some kind of exponentiation function. 00:10:01.480 |
So for example, if you have an exponentiation function and you change the input by 0.1 or 0.2 or 00:10:06.580 |
0.3, this can drastically change the output of exponentiating it, especially depending on the magnitude of the input. 00:10:15.940 |
And it basically doesn't change the compute at all and allows the models to just be significantly more stable. 00:10:24.060 |
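(A sketch of the selective-precision idea as described here: keep the model in bfloat16, but locally upcast the router so the softmax and its exponentiation happen in float32, then cast the result back down. Illustrative only.)

```python
import jax
import jax.numpy as jnp

def router_probs_selective_precision(token_inputs_bf16, router_w_bf16):
    """Router computed locally in float32 while the rest of the model stays bfloat16."""
    # The exponentiation inside the softmax is sensitive to round-off error,
    # so the logits are upcast before it.
    logits = (token_inputs_bf16 @ router_w_bf16).astype(jnp.float32)
    probs = jax.nn.softmax(logits, axis=-1)    # computed in full precision
    # Cast back so everything communicated / consumed downstream is bfloat16 again.
    return probs.astype(jnp.bfloat16)
```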
So the second thing we looked at is also the initialization scale. 00:10:27.220 |
So like the standard way that we were initializing these models, we found to also just make the 00:10:31.980 |
models much more prone to being unstable and/or just performing worse. 00:10:36.260 |
So one thing that we did that we found was very effective was to just simply make the initialization scale 10 times smaller. 00:10:42.460 |
And when we did this, we found that the quality just drastically improved, and training was much more stable. 00:10:51.840 |
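(A sketch of the smaller-initialization trick, assuming the usual truncated-normal scheme with standard deviation sqrt(s / fan_in) and the scale s reduced by a factor of 10; the helper name is mine.)

```python
import jax
import jax.numpy as jnp

def init_weight(key, fan_in, fan_out, scale=0.1):
    """Truncated-normal init with std = sqrt(scale / fan_in).

    The default Transformer setting would be scale=1.0; dropping it to 0.1
    (10x smaller) was reported to make the sparse models noticeably more stable.
    """
    std = jnp.sqrt(scale / fan_in)
    return std * jax.random.truncated_normal(key, -2.0, 2.0, (fan_in, fan_out))
```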
And the third thing I mentioned, where since we noticed that these models are much more 00:10:55.300 |
prone to overfitting, since they just have significantly more parameters, is that we 00:10:59.600 |
also use much more dropout for the expert layers only. 00:11:03.100 |
So here we can see we have the T5 base, which is a dense model, and then we have a bunch of sparse variants. 00:11:10.120 |
And what we found to be the most effective on these four different fine-tuning tasks was just 00:11:13.920 |
to really significantly increase the dropout rate inside the expert layers. 00:11:17.800 |
And we found that this was pretty effective for combating the overfitting. 00:11:32.160 |
It was just in reference to the previous table where you have throughput and precision. 00:11:36.880 |
It just seems surprising to me that you could match this 1390 number using selective precision. 00:11:43.320 |
It seems like I would expect it to be something in between. 00:11:47.960 |
So it essentially comes down to the fact that there's maybe a little bit of noise in these speed measurements. 00:11:53.440 |
And the only part we're casting is the router, which is maybe such an insignificant portion of the total compute 00:12:01.600 |
that it is essentially like a free operation in the network. 00:12:04.140 |
So whether you cast it to bfloat16 or float32, it doesn't actually impact the speed at all 00:12:08.720 |
within the precision that we can actually measure the speed. 00:12:12.320 |
And also, these architectures only use a sparse layer once every four layers. 00:12:19.240 |
And so, yeah, essentially, the float32 part is kind of very negligible in the entire architecture. 00:12:27.320 |
It's like, for example, I think off the top of my head, it's like 1/40th the computation 00:12:32.360 |
that it would cost for you to do the first weight matrix multiply in a dense ReLU layer. 00:12:40.040 |
And yeah, we're not using them very frequently, like Erwin mentioned as well. 00:12:51.480 |
I won't go into some of the technical details, but yeah, we definitely-- since we're training 00:12:54.960 |
these things on hardware, we really-- I think a big part of the mixture of experts paradigm 00:12:58.660 |
is that these things are designed such that it maps really efficiently to hardware. 00:13:03.680 |
So we want to be doing dense matrix multiplies. 00:13:06.360 |
And for this to work really well, we also want to be able to have roughly equal amount 00:13:10.280 |
of tokens going to each of the different experts. 00:13:13.740 |
And I think this isn't that sensitive to the load balancing formulation. 00:13:19.800 |
But yeah, essentially, you definitely want some kind of load balancing loss added on top of the cross-entropy loss. 00:13:31.960 |
Yeah, so the frameworks, the libraries we use rely on static shapes. OK, yeah, so XLA, 00:13:55.020 |
the compiler for TensorFlow, and Mesh TensorFlow expect static shapes for tensors. 00:14:03.220 |
However, the computations in switch transformers are dynamic because of the router, right? 00:14:11.980 |
Different inputs will be routed to different experts. 00:14:15.700 |
And so we need to specify ahead of time how many tokens will be sent to each expert. 00:14:22.580 |
And so we will introduce this expert capacity hyperparameter to specify that. 00:14:29.120 |
And that's going to be a static number which says how many tokens each expert can process. 00:14:37.960 |
And so in practice, we instead parametrize this by having a quantity called the capacity factor. 00:14:47.940 |
So the bottom row is a bunch of tokens on one device. 00:14:56.660 |
And then you need to sort of route those tokens to multiple devices or multiple experts. 00:15:02.880 |
So if too many tokens are routed to a single expert, some tokens will be dropped because 00:15:14.000 |
the expert's buffer is full. So that's the example on the left where the capacity factor is one, and that basically 00:15:17.900 |
means that there's no extra buffer for routing tokens. 00:15:27.180 |
So instead of that, we can use the capacity factor that's larger than one. 00:15:30.780 |
So on the right, you have an example with 1.5. 00:15:34.740 |
So that means that now each expert has three slots that can process three tokens. 00:15:42.580 |
And so that prevents token dropping because we have more capacity. 00:15:46.920 |
But the issue is that this means more expensive communication across devices. 00:16:02.860 |
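(To make the capacity bookkeeping concrete, a simplified single-device sketch of how the expert capacity is derived from the capacity factor and how overflow tokens end up dropped; names are mine.)

```python
import jax
import jax.numpy as jnp

def expert_capacity_and_drops(expert_idx, num_experts, capacity_factor=1.25):
    """Return the static per-expert buffer size and a mask of tokens that fit.

    expert_idx: [num_tokens] expert chosen for each token under top-1 routing.
    Tokens that overflow their expert's buffer are "dropped": they skip the
    expert computation and just pass through the residual connection.
    """
    num_tokens = expert_idx.shape[0]
    capacity = int(jnp.ceil(num_tokens / num_experts * capacity_factor))

    one_hot = jax.nn.one_hot(expert_idx, num_experts)            # [T, E]
    # 1-indexed position of each token within its chosen expert's buffer.
    position_in_expert = jnp.cumsum(one_hot, axis=0) * one_hot   # [T, E]
    keep = (position_in_expert <= capacity).all(axis=-1)         # token fits its buffer
    return capacity, keep
```

So a larger capacity factor means fewer dropped tokens but bigger static buffers, which is exactly the extra memory and communication cost being described here.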
One thing that we also experimented with was this method called no token left behind. 00:16:09.000 |
So since we have to have a fixed batch size for each expert, and there can be token dropping, 00:16:15.700 |
we're thinking that, hey, yeah, having tokens dropped or having some tokens not having any 00:16:19.860 |
computation applied to it is probably hurting the model performance. 00:16:23.740 |
So what if we do a multistage routing procedure? 00:16:26.060 |
So first, you do the normal routing where you send each token to its highest probability expert. 00:16:30.900 |
But then any dropped tokens, you then send to their second highest probability expert, 00:16:37.460 |
Or you can basically repeat this process to guarantee that no tokens are being dropped. 00:16:41.540 |
Interestingly, actually, this approach didn't empirically improve model performance. 00:16:48.420 |
And we thought that was actually very interesting. 00:16:51.060 |
And I think the intuition is that, you know, once the model learns it wants to send a token 00:16:54.380 |
to one expert, like it really wants to have that computation applied to it. 00:16:58.060 |
And just applying some other computation doesn't, you know, have at all the same property, along 00:17:03.460 |
with it actually maybe being potentially detrimental. 00:17:06.400 |
So yeah, we thought that was pretty interesting, as we were very optimistic this would potentially, 00:17:10.240 |
you know, get improved performance, but it ended up not really making a difference. 00:17:21.220 |
I think it will actually kind of address literally the last point that you brought up. 00:17:26.380 |
I think when I think about a mixture of experts, usually they specialize in different things. 00:17:33.140 |
So I was just wondering, if you send a token 00:17:41.740 |
to the second best expert or whatever, what if all of your tokens would be particularly 00:17:49.140 |
good for one expert, and then you only process, let's say, 20% of your tokens? 00:17:57.000 |
So that ends up being better than rerouting them to anything else. 00:18:04.080 |
So yeah, even if you're dropping a lot of tokens, it's not beneficial to be sending 00:18:06.760 |
them to the second, third or fourth best thing. 00:18:09.440 |
And one actually interesting property that we, you know, noticed about these models is 00:18:12.380 |
they're surprisingly robust to token dropping, especially during fine tuning. 00:18:17.980 |
So in the standard paradigm, what we'll do is we'll pre-train this thing, we'll have 00:18:19.680 |
some load balancing loss, which makes the tokens pretty balanced actually. 00:18:24.480 |
But then during fine tuning, where it's like, we really want to fine tune it on a specific task. 00:18:28.640 |
We actually studied this exact question and we were studying, does it help to have a load balancing loss during fine-tuning? 00:18:34.360 |
And so if you have the load balancing loss, yeah, that kind of is encouraging, you know, 00:18:37.560 |
for the specific task, we want to try to have, you know, all the experts be used versus turning it off. 00:18:42.240 |
Whereas there's definitely some, you know, prior specialization, and it's actually much better to just turn that loss off. 00:18:47.760 |
And even if it's like, you know, 60 to 70% of the tokens are being dropped, that actually 00:18:51.760 |
performs much better than, you know, having all the tokens balanced. 00:18:55.520 |
But doesn't a load balancing loss encourage basically all the experts to learn very similar 00:19:00.920 |
weights and then just randomly assign your tokens? 00:19:05.320 |
Because then it doesn't matter to which expert stuff is being sent to. 00:19:08.980 |
So when we use the load balancing loss, like the routing mechanism is definitely learned. 00:19:12.120 |
So the model definitely is encouraged to, you know, choose an expert that it wants to send each token to. 00:19:18.520 |
But like if all the experts learn the same weights, then the router basically learns that it doesn't matter where tokens are sent. 00:19:26.720 |
So if you encourage load balancing, aren't you technically encouraging that any token can go to any expert? 00:19:34.760 |
I mean, that's maybe the extreme behavior if you have a very high sort of load balancing 00:19:39.320 |
loss coefficient, but in practice that coefficient is kind of tuned and we observe that for, 00:19:44.520 |
you know, small enough values, the router still learns semantically meaningful routing. 00:19:53.040 |
Because it's like a balance between this, like, you know, cross entropy loss and this load balancing loss. 00:19:57.440 |
And so on one hand, yeah, you definitely want to encourage the model to be balanced. 00:20:00.880 |
Then on the other hand, you also want to just get good empirical performance. 00:20:04.760 |
And yeah, the model is able to definitely like on one hand, learn and specialize the 00:20:08.680 |
experts where they have different weights such that it's like, you know, definitely 00:20:11.640 |
it expects certain tokens to be sent to certain experts, but on the other hand, still be reasonably 00:20:14.960 |
balanced so that the models are efficiently run on like modern hardware. 00:20:24.080 |
So the question that I want to ask is, it seems to me like this is a very experimental line of work. 00:20:29.480 |
We're talking about floating point precision. 00:20:30.480 |
We're talking about different approaches and what currently works well. 00:20:31.480 |
And whenever we're dealing with clients, there's a question of what is the research question? 00:20:39.040 |
So what are we trying to answer with all these experiments? 00:20:44.000 |
I think the high-level research question is like, you know, can we 00:20:48.800 |
create models that are doing adaptive computation, from the standpoint of, 00:20:53.920 |
like, can we try to make models more closely simulate the dynamics that we think models should most 00:20:58.760 |
naturally use, which is different inputs to have different amounts of computation applied, 00:21:03.040 |
have different weights applied to them, you know, and basically all of this, basically 00:21:06.080 |
we're trying to research and like figure out how can we create like a new framework for 00:21:09.560 |
these models to be trained as opposed to their dense counterparts that, you know, for every 00:21:13.440 |
input are always having the same exact computation applied. 00:21:17.040 |
So that's interesting because when you say the same exact computation applied, one might 00:21:21.040 |
imagine that like, to me, the immediate thing is about how long to deliberate about something. 00:21:26.960 |
What I mean by that is if we want to have variable length computation, you could imagine 00:21:31.360 |
that I could have a short amount of computation or it could have much longer computation, but 00:21:35.280 |
then, you have like, why do we instead consider the dimension of different weights? 00:21:40.480 |
I mean, assuming of course that these experts do indeed learn different things, which I assume they do. 00:21:46.320 |
So why do we immediately jump to thinking about specialized experts as opposed to thinking about variable amounts of computation? 00:21:52.640 |
So, yeah, so this is actually, we actually go into some variable length computation stuff a little bit later. 00:21:58.240 |
And I feel like they're both actually just important axes that should both be pushed on. 00:22:02.720 |
I think, I guess, yeah, I guess it's kind of, you know, maybe I'm not phrasing my question well, 00:22:07.800 |
but what I'm trying to understand is, you're thinking about why did you decide to attack this dimension first. 00:22:11.760 |
I want to understand why your team chose to go this direction first. 00:22:15.720 |
So I think that, one, empirically, it seems that sparsity has led to better empirical 00:22:20.880 |
results in the field of deep learning than adaptive computation so far. 00:22:23.960 |
And I think the way that we use these things maps really well to our modern hardware, which likes dense matrix multiplies. 00:22:29.820 |
And I think the way we were kind of looking at it is like sparsity is like a first step 00:22:32.840 |
towards doing more interesting and general adaptive computation, 00:22:37.320 |
because I think it's like, you know, this stuff is complicated, and typically starting 00:22:41.000 |
from something that works well is better than necessarily, you know, trying 00:22:45.720 |
something that's not necessarily as proven out and then trying to get it to work from scratch. 00:22:50.080 |
So I think we're kind of starting from sparsity, which, you know, Noam Shazeer and others 00:22:53.360 |
got to work really well in the context of LSTMs. 00:22:55.600 |
We were kind of interested in, you know, let's port some of this to transformers. 00:22:59.920 |
And then let's slowly start expanding towards a lot of the other natural questions that come up. 00:23:04.080 |
Whereas like, okay, whereas instead of, you know, different weights per core, let's also 00:23:07.640 |
maybe have different computation per core and all of this. 00:23:10.320 |
So that's, I guess, how we were kind of building the natural, you know, buildup and progression of this research. 00:23:23.280 |
I mean, I guess I kind of see adaptive computation and sparsity as, you know, related, but separate things. 00:23:32.520 |
So, you know, sparsity is more like different parameters for each example and adaptive computation 00:23:36.960 |
might be more different amount of flops and we have some of that with the token dropping, 00:23:42.560 |
but that's kind of, you know, that's not the main motivation. 00:23:49.200 |
Definitely as Barret mentioned, I would say, you know, no one really has figured out adaptive computation yet. 00:23:59.200 |
And one reason is because we have these, you know, accelerators, right. 00:24:05.640 |
They expect, sort of, you know, that we work with like batch, like data parallelism. 00:24:12.600 |
So, and all of our accelerators and our frameworks use this SPMD paradigm where we're kind of 00:24:19.320 |
supposed to apply the same computation to examples. 00:24:25.200 |
And so if you look at the literature, you have, you know, works like universal transformers 00:24:30.880 |
where they replace the feed forward in the transformer by just a recurrent weight. 00:24:36.240 |
And so it's kind of like an LSTM on each token, and the LSTM can stop at different times based on the token. 00:24:45.440 |
But the way these things are implemented is just through masking, because it needs to be the same static computation for every token. 00:24:55.520 |
And so definitely sparsity was kind of like easier to get to work first. 00:24:59.800 |
And also there were some prior results with LSTM, so yeah. 00:25:06.600 |
In terms of like the first question, you know, sort of what's the research question here 00:25:09.880 |
is just like, oh, can we design more efficient models? 00:25:13.400 |
And sparsity is this new axis that hasn't been explored that much. 00:25:17.240 |
And yeah, I think that, you know, I'm happy with just that being the research question. 00:25:35.440 |
So the switch transformer layer selects an expert, like just the top expert, and then 00:25:40.160 |
incorporates a bunch of the general sparse model improvements to, you know, allow it 00:25:44.360 |
to fine tune better, allow it to, you know, be more regularized, allow it to, you know, 00:25:49.840 |
be trained with lower precision formats, and a lot of technical details to just get these models to work well. 00:25:58.120 |
So one thing that we also wanted to do was a comparison between like top one and top 00:26:02.200 |
two routing since top two routing was kind of the, you know, most popular technique. 00:26:07.760 |
And so here we can see we have two different dense models trained of different sizes. 00:26:10.400 |
And we're going to be looking at like the, the pre-training like negative log perplexity. 00:26:20.900 |
So, and what we're going to be doing is we're going to be studying them at different capacity factors. 00:26:26.140 |
So a capacity factor of 2.0 basically means that there's buffer for two times as many tokens as an even split would send to each expert. 00:26:32.560 |
And we're going to be comparing like top one versus top two routing and also comparing 00:26:36.280 |
their speeds along with their like time to get some like threshold quality. 00:26:42.560 |
So here we can see in the capacity factor 2.0 case that the MoE models outperform switch 00:26:48.800 |
transformer, which makes a lot of sense, since switch transformer is only, you know, 00:26:52.800 |
sending each token to its top-1 expert, while the mixture of experts is sending each token to two experts. 00:26:58.920 |
So that makes sense that this extra buffer will be disproportionately beneficial for the top-2 models. 00:27:04.120 |
And so we noticed that, and, next slide, the really interesting part 00:27:11.020 |
for the top-1 routing comes when we lower the capacity factors. 00:27:14.940 |
So having a high capacity factor is bad for many reasons. 00:27:17.800 |
One of which is it really incurs more of these, you know, communication costs for sending tokens to experts. 00:27:24.280 |
It also incurs more compute costs and also incurs like a lot of memory overhead. 00:27:28.280 |
So if you can get this lower, it's, it's usually like a very, very good thing. 00:27:32.760 |
And so what we see here is that switch transformer actually outperforms mixture of experts when we lower the capacity factor. 00:27:40.400 |
And we can see that the time to quality threshold, we you know, yeah, we, we get there much quicker. 00:27:46.500 |
And so even across the 2.0 and the 1.25 capacity factors, like the kind of Pareto optimal thing 00:27:51.460 |
we saw in our setup is to use switch transformer at a lower capacity factor, just due to the 00:27:56.600 |
fact that while the quality is a little bit worse on a step basis, it's just much faster per step. 00:28:06.220 |
And we can also be seeing that like for capacity factor 1.0, again, we can see that this really 00:28:11.660 |
disproportionately benefits switch transformer and is even better from a Pareto standpoint. 00:28:20.060 |
And interestingly, since, you know, MOE also does like a little bit more computation, we 00:28:24.000 |
can also just increase the amount of compute done elsewhere in the model. 00:28:28.100 |
And we can see that that's like a much more efficient allocation of compute. 00:28:31.740 |
So yeah, overall, our takeaway is that, yeah, lower capacity factors using top-1 routing 00:28:37.240 |
is more Pareto efficient than, you know, using the top-2 routing at higher capacity factors. 00:28:48.980 |
So next we'll look at how a switch transformer scales as a function of the number of experts. 00:28:59.220 |
And so on the right side here, you see a plot that shows perplexity versus training steps 00:29:06.140 |
for different switch architectures, ranging from T5 base, which has basically no experts, up to switch models with many more experts. 00:29:17.940 |
And so you see that as we increase the number of experts, which also increases the number 00:29:21.820 |
of parameters, of sparse parameters, you get sort of speedups, you know, increasing sample efficiency. 00:29:32.940 |
And there are sort of diminishing returns to, you know, increasing the number of experts. 00:29:44.580 |
So the previous figure was looking at perplexity versus training steps. 00:29:50.780 |
Here we look at perplexity versus training time. 00:29:54.780 |
So that includes, you know, all the, you know, additional communication costs when you have 00:30:00.900 |
more experts, or, you know, comparing to the dense baseline. 00:30:07.700 |
And so this is for switch base or T5 base, and we observe up to 7x speedups over T5 base. 00:30:17.580 |
And so, you know, just to maybe contextualize these numbers, like, you know, 7x speedups are pretty significant. 00:30:26.620 |
And so I think this is one of the, you know, one of the results that, you know, can spark 00:30:34.300 |
a lot of interest in sparse models, even if it's only for pre-training for now, like just 00:30:39.260 |
having that number is like, you know, a sign there's something significant here. 00:30:55.660 |
So here we'll look at sort of loss versus sparse model parameters, which are increased by adding more experts. 00:31:06.980 |
And so similarly to the sort of, you know, normal scaling law paper, we observe that 00:31:13.380 |
as you increase the parameters, that is the sparse parameters, and keep the flops fixed, 00:31:23.020 |
you get consistent gains, but diminishing gains. 00:31:27.900 |
Okay, so now we're going to compare expert parallelism and model parallelism. 00:31:34.700 |
So we introduced sparsity, or expert parallelism, as a new dimension to scale models. 00:31:42.340 |
But of course, there's the other one for dense models, which is simply model parallelism, where, 00:31:48.660 |
you know, model weights are partitioned across cores once they are above the maximum size that fits on a single core. 00:31:57.900 |
All right, so, yeah, Barret, I assume to the left is expert parallelism here? 00:32:05.980 |
Yeah, so essentially what we're doing is, yeah, we're kind of comparing a switch-base model against its dense counterpart. 00:32:13.660 |
And we're also comparing against a larger dense model that has used model parallelism. 00:32:18.820 |
And we can see that, you know, because basically when we want to scale a model size, we kind 00:32:22.380 |
of have two axes that we can either go through. 00:32:24.340 |
We can either increase the number of flops by scaling through model parallelism or increase 00:32:28.820 |
the number of parameters by scaling through sparsity. 00:32:31.620 |
And so we can see that, you know, even compared to like, you know, a dense model that's been 00:32:35.500 |
scaled up through model parallelism, sparsity is still, at this scale, a more effective 00:32:39.020 |
way to scale up the model, you know, still getting 2.5x speedups over this larger dense model. 00:32:50.060 |
Yeah, basically here, T5 large is the dense model that uses model parallelism. 00:32:59.860 |
Yeah, and so one thing that we also wanted to look at is like, you know, are these expert 00:33:05.740 |
models effective if you have, you know, a really small amount of compute, just a small number of experts? 00:33:10.220 |
So typically when we're designing these models, like we have one expert per core. 00:33:14.580 |
But if you don't have like a large cluster to run these things on, let's say you just 00:33:17.140 |
have like a GPU with two cores or something, is having two experts still more effective? 00:33:25.220 |
So we can see even pretty good scaling properties, even with like a tiny amount of experts, which 00:33:29.420 |
is very, very promising for these models to be used even in like much lower compute regimes. 00:33:42.060 |
And so look at, you know, what things look like when we use different types of parallelism, 00:33:51.140 |
namely expert parallelism to add experts, model parallelism to shard model weights across 00:33:57.060 |
cores and also data parallelism, which is sort of the dominant paradigm in deep learning 00:34:06.300 |
And so, you know, I guess, you know, in the previous slides, we mostly talked about expert 00:34:12.100 |
parallelism, but of course, you know, dense models and large scale dense models use model parallelism. 00:34:18.220 |
So GPT-3 and these other large models, what they do is that they will simply shard model weights across cores. 00:34:31.740 |
I just wanted to know, because I think there was like, I don't know if you're going to 00:34:35.420 |
address it later, but I think somewhere in the paper, it said that the more experts you have, the more sample efficient the model is. 00:34:43.380 |
And I was just like hoping, hoping that you could give us some intuition about that, because 00:34:47.700 |
I don't understand why that would be the case. 00:34:52.260 |
So I guess, yeah, maybe, so I guess like, you know, there's all of this work on larger 00:34:59.700 |
models are more sample efficient, and larger in the context of the scaling law works means more parameters. 00:35:06.980 |
As you increase the number of experts, there's more parameters, but not more flops. 00:35:09.900 |
But the model is still, you know, larger in a similar sense. 00:35:14.260 |
So I guess, building on the intuition that larger models are more sample efficient, 00:35:18.100 |
in my mind, it's not necessarily that surprising that these models with more experts, that have more parameters, are more sample efficient. 00:35:27.500 |
I guess that's my like kind of high level intuition for it. 00:35:30.100 |
Yeah, I would say that's kind of expected that, you know, more experts leads to better 00:35:37.480 |
sample efficiency, especially if you look at training steps, right, rather than training time. 00:35:53.700 |
Yeah, so, okay, so we'll look at how model weights are split over cores, starting with data parallelism. 00:36:04.580 |
So that's kind of the typical setup that deep learning uses, especially for not so large 00:36:11.840 |
networks which don't require model parallelism. 00:36:16.040 |
And so let me, yeah, let me explain how, yeah, I'll just go to the final figure and I'll 00:36:27.360 |
Okay, so we have 16 processes which are organized in the four by four mesh, right? 00:36:33.980 |
So each dotted line, each four by four dotted line here represents a different core. 00:36:41.160 |
And the first row shows how the model weights are split over cores. 00:36:45.720 |
And the second row illustrates how data, so literally examples and tokens, are split over cores. 00:36:55.360 |
And yeah, and then the final thing that's required to understand this figure is that 00:37:00.000 |
each, yeah, each color of the shaded squares here identifies a unique weight matrix. 00:37:13.140 |
So for data parallelism, the same model weights are replicated across all cores. 00:37:20.800 |
And the data is simply partitioned over cores. 00:37:23.920 |
And so that's what this corresponds to, using the explanation in the caption. 00:37:39.520 |
That's kind of just like a theoretical example because in practice, people always use model 00:37:44.440 |
parallelism in conjunction with data parallelism. 00:37:48.000 |
But so if you were to do only model parallelism, now you would have a single model weight that 00:37:52.440 |
is partitioned over all cores, and your data would just be replicated over all cores instead. 00:38:01.880 |
So now we have model and data parallelism, and that's kind of the typical scenario for 00:38:08.040 |
So in that case, model weights are partitioned among a subset of the cores, the subset of 00:38:14.240 |
cores that process different batches of data. 00:38:16.920 |
And so in that example here, we have sort of four, so the first sub-square here means 00:38:22.760 |
that the model weights are partitioned across four cores. 00:38:31.080 |
And this is replicated sort of four times for the data parallelism dimension. 00:38:38.320 |
On the data side, for model and data parallelism, yeah, the data here is replicated across model 00:38:46.720 |
parallel cores and partitioned across data parallel cores. 00:38:55.320 |
So in that scenario, that's kind of similar to data parallelism, but now each core will 00:38:59.520 |
hold a different model weight, which is illustrated by the different colors. 00:39:06.260 |
And for the data side, the data is simply replicated, sorry, the data is partitioned 00:39:12.400 |
across all cores, just like in the data parallelism scenario. 00:39:18.160 |
And so finally, we have the rightmost column, which is, I guess, yeah, that's the setup 00:39:25.800 |
used in the switch transformer paper for the larger models. 00:39:30.920 |
And so here for the model partitioning, each expert is partitioned across multiple cores. 00:39:37.020 |
So in that example, we have four experts, each partitioned across four cores, and the 00:39:42.820 |
data is replicated across model parallel cores and partitioned across data parallel cores. 00:39:48.500 |
So that's a little bit complex to understand, really, but the switch transformer paper has 00:39:54.780 |
the same figure with a nice caption to explain it. 00:40:00.360 |
Yeah, maybe we can quickly add something about how this is implemented in practice. 00:40:09.820 |
So there's this paper called Mesh TensorFlow, which kind of extends batch or data parallelism 00:40:18.780 |
to more general purpose SPMD style programming. 00:40:23.900 |
And so different labs have different frameworks, but this paper kind of lays the foundation 00:40:28.420 |
for general SPMD distributed computing, which is required for training large scale models. 00:40:37.080 |
And so under the mesh abstraction, basically we have a mesh of processes, and so that mesh 00:40:45.460 |
has dimensions, name dimensions, and these name dimensions specify how the tensor dimensions 00:40:53.740 |
will be partitioned or replicated across the mesh dimensions. 00:40:58.400 |
And so just that simple abstraction sort of supports data parallelism, also model parallelism, 00:41:07.980 |
And so I invite whoever is interested to also check that paper, because that kind of lays 00:41:16.060 |
the foundation for understanding these things. 00:41:23.180 |
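(The same named-mesh idea survives in current frameworks; here is a small sketch of it in JAX rather than Mesh TensorFlow itself, assuming 16 accelerators arranged as a 4x4 mesh with axes named "data" and "model".)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Assume 16 devices; arrange them into a 4x4 mesh with named axes.
devices = np.array(jax.devices()).reshape(4, 4)
mesh = Mesh(devices, axis_names=("data", "model"))

# Activations: shard the batch dimension over the "data" axis, replicate over "model".
x = jnp.zeros((128, 1024))
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))

# A feed-forward weight: shard its output dimension over the "model" axis
# (model parallelism) and replicate it over the "data" axis.
w = jnp.zeros((1024, 4096))
w = jax.device_put(w, NamedSharding(mesh, PartitionSpec(None, "model")))

# Expert parallelism just adds another named axis: a stacked weight of shape
# [num_experts, d_model, d_ff] whose leading dimension is split across an
# "expert" mesh axis, so different cores hold different experts' weights.
```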
So next we are going to kind of talk about how we take these parallelism strategies 00:41:26.300 |
and kind of combine them together to make a 1.6 trillion parameter sparse model. 00:41:35.220 |
So what we ended up doing in this work was we trained two different very large sparse 00:41:41.820 |
models, and we compared them to the largest T5 model. 00:41:44.540 |
So we can see the T5 XXL, which is a dense model, and it was the largest one trained 00:41:48.820 |
in the T5 paper, and it has around 13 billion parameters. 00:41:52.740 |
And here we list a lot of the model dimensions like D model, DFF, which are just like the 00:41:56.140 |
various sizes and shapes of the tensors and stuff, the number of layers, the number of 00:42:01.900 |
And importantly, we also mentioned the negative log perplexity at step 250k and at 500k. 00:42:08.980 |
And so yeah, so we designed two sparse models to test how scaling flops versus scaling sparsity compares. 00:42:21.660 |
So the first, Switch-XXL, has the same amount of flops per token as T5 XXL, but has 64 experts. 00:42:27.200 |
And this leads it to have around 400 billion parameters. 00:42:31.020 |
And we can see that on a step basis, it actually performs quite well and outperforms the T5 XXL. 00:42:37.700 |
Interestingly, though, the third model we designed is Switch-C, which has 1.6 trillion 00:42:42.340 |
parameters, but has significantly fewer flops, almost 10x fewer flops per token than the XXL models. 00:42:48.900 |
So it's really trading: reducing flops to have way more sparse parameters. 00:42:54.660 |
And we can see on a step basis, the Switch-C model works well, but not as well as 00:42:59.780 |
actually the higher flop model. But on a kind of a Pareto axis, where we are 00:43:04.860 |
looking at TPU hours on the X axis and not steps, the Switch-C model actually outperforms it. 00:43:12.060 |
So for pre-training performance, we're seeing that actually just having a lot of sparsity 00:43:15.900 |
and fewer flops can actually be quite good. 00:43:22.060 |
And so, yeah, so again, those two sparse models are kind of really trying to get at 00:43:25.700 |
this hypothesis that actually Noam Shazeer had, which is, you know, that parameters 00:43:30.060 |
are good for knowledge, and compute, aka flops, is good for intelligence and reasoning. 00:43:37.020 |
And so we're going to kind of try to get at that by taking these different sparse models 00:43:39.900 |
and then fine tuning them on, uh, different tasks, some of which require more like knowledge 00:43:44.140 |
and then others, which require more of like reasoning, um, for whatever hand-wavy definition of those terms. 00:43:54.380 |
So for a fixed, Oh, can you go back to the previous slide? 00:43:59.980 |
So for a fixed quality on an upstream pre-training task, we're going to compare two downstream tasks. 00:44:06.980 |
One of which is SuperGLUE, which is kind of our reasoning task. 00:44:09.860 |
And then another is TriviaQA, which is like a knowledge task where it's like, 00:44:12.860 |
you just give it a question, you have it output an answer. 00:44:19.300 |
And so here we're going to take a look at super glue quality. 00:44:21.860 |
So we can see on the X axis is the pre-training performance and the Y axis is the super glue 00:44:28.980 |
And interestingly, we can see definitely that the sparse models, for a fixed 00:44:34.260 |
pre-training perplexity, do worse on fine tuning. 00:44:37.660 |
This can be especially noticed at like the upper right portion of the plot where the 00:44:41.340 |
dense models are definitely fine tuning better than their sparse counterparts. 00:44:48.140 |
Interestingly, when we study it on the more knowledge, heavy tasks, the sparse model for 00:44:52.940 |
a fixed, uh, pre-training perplexity does disproportionately well. 00:44:57.100 |
So, you know, for a model that roughly has the same perplexity, we're getting like really 00:45:00.260 |
large boosts for these knowledge, heavy tasks. 00:45:04.300 |
And it also really, you know, shows some of the dangers of comparing only on your pre-training metrics. 00:45:10.220 |
So a dense and a sparse model, you know, can have the same exact pre-training metric, but very different, 00:45:13.580 |
um, you know, properties when fine tuning them on different tasks. 00:45:21.060 |
And interestingly, so yeah, all of the switch models here are the, um, are, are just like, 00:45:26.220 |
you know, various models that have still a good amount of flops, but the red model is 00:45:30.780 |
actually the 1.6 trillion parameter, uh, sparse model that has, you know, very few flops, 00:45:38.500 |
And we can see that as the red dot here, and it does actually disproportionately bad compared 00:45:42.580 |
to other sparse models that also have pretty good perplexities. 00:45:46.180 |
And so, yeah, it's, uh, it's definitely very interesting and it shows that, you know, for 00:45:49.540 |
models during pre-training that have a lot of sparsity, they definitely suffer on some 00:45:53.340 |
of these more reasoning heavy metrics, but do disproportionately well for more of these knowledge-heavy tasks. 00:46:02.060 |
And so here we can see it is just like a huge outlier: for its pre-training perplexity it does 00:46:07.380 |
like just incredibly well on this, uh, downstream question answering task. 00:46:16.420 |
So also, you know, one thing that we were going to do is just look at the fine tuning 00:46:19.760 |
properties of sparse models across like a few scales and just see how they perform. 00:46:29.220 |
One is, um, T5 base, and then we make a flop-matched sparse counterpart. 00:46:33.460 |
And when we say flop-matched, it's like, you know, each token will have the same amount of computation applied to it. 00:46:38.600 |
So we do this for both base and large, and we see that actually across almost all tasks 00:46:42.140 |
besides two arc tasks, the sparse models perform quite well, which is, which is definitely 00:46:48.460 |
So we are seeing that these models are pretty robust, they pre-train well, and then they 00:46:51.380 |
also fine tune well when scaled appropriately by scaling up both the flops and sparsity. 00:46:57.020 |
Whereas, you know, the negative results we've really seen are like, yeah, when you just 00:47:00.460 |
have a huge amount of sparsity and not too many flops. 00:47:06.180 |
And one other thing we wanted to look at was, uh, multilingual training. 00:47:10.620 |
So we were previously studying all of this on like English only, and we also wanted to 00:47:14.300 |
see how sparsity helps in the multilingual setting because, you know, we also felt like 00:47:18.060 |
this would be a very natural place for sparsity to work well, or potentially experts could specialize across languages. 00:47:25.500 |
So on 91% of the languages, I think of like around a hundred languages, we see 00:47:30.500 |
at least a 4x speedup over the mT5, um, dense model. 00:47:44.860 |
So another thing we wanted to talk about was, um, distillation. 00:47:47.300 |
So one downside of these sparse models is that they'll have a lot more parameters, which 00:47:51.940 |
means that, you know, if you're serving these things or something, you either need like 00:47:55.140 |
high throughput use cases, or you need to maybe distill it back down into like a smaller dense model. 00:48:00.700 |
So here, what we do is we look at like the T5 base and switch base, and we look at their quality. 00:48:05.620 |
And then we go through, um, some ablations of different distillation techniques and find 00:48:09.060 |
that like with the best techniques, we can keep around 30% of the quality improvements 00:48:14.180 |
of sparsity while distilling it back down into its, uh, dense, um, counterpart. 00:48:24.420 |
And then we kind of study this across multiple scales. 00:48:26.380 |
And again, we see like around like 30 to 40% of the gains can be, um, like, you know, kept 00:48:32.340 |
when taking, you know, a sparse model and distilling it 00:48:35.980 |
back down into its flop-matched dense model. 00:48:38.500 |
So you can get, you know, get rid of up to 99% of the parameters and still keep like 00:48:42.620 |
around 30% of the improvements, which is very promising. 00:48:52.940 |
You said that you can keep 30% of the teacher's benefit. 00:49:01.820 |
So we're looking at like, yeah, you train a sparse model and then you just distill it back 00:49:06.260 |
down to a dense model and you're versus training a dense model from scratch. 00:49:10.700 |
And like you look at the gap between the sparse and dense model from scratch versus the, the, 00:49:15.380 |
the gap between the dense and then the distilled dense model. 00:49:22.660 |
Maybe let me just do like a quick high level summary again. 00:49:29.420 |
So what we're, what we'll do is for our comparisons is we'll train a dense model from scratch. 00:49:34.580 |
We'll train a sparse model from scratch and then we'll also run a third experiment where 00:49:38.740 |
we distill that sparse model down into a dense model. 00:49:45.740 |
Like we're basically trying to match the teacher's logits, like the kind of standard 00:49:50.660 |
thing of like, you know, matching either the logits or the soft probabilities. 00:50:05.420 |
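(A minimal sketch of that kind of distillation objective, blending a soft teacher-matching term with the usual hard-label cross-entropy; the mixing weight here is illustrative, not the paper's exact recipe.)

```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, labels, mix=0.75):
    """Blend of soft-target (teacher-matching) loss and hard-label cross-entropy.

    student_logits, teacher_logits: [batch, num_classes]
    labels:                         [batch] integer ground-truth targets
    """
    teacher_probs = jax.nn.softmax(teacher_logits)       # sparse teacher's soft outputs
    log_p = jax.nn.log_softmax(student_logits)
    soft_loss = -(teacher_probs * log_p).sum(axis=-1).mean()
    hard_loss = -jnp.take_along_axis(log_p, labels[:, None], axis=-1).mean()
    return mix * soft_loss + (1.0 - mix) * hard_loss
```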
So what I'm struggling with is how do I interpret the line that says percent of teacher performance? 00:50:13.140 |
So it's, it's basically looking at the, like the gap between the dense and sparse model. 00:50:17.820 |
So we'll have the dense model gets some performance, we'll have the sparse model gets some performance 00:50:21.500 |
and then the, the dense model that's still from the sparse model would be somewhere in 00:50:28.020 |
And we're basically saying it's 30% through that range. 00:50:31.280 |
So it's like in like a zero to one interval, it's like 0.3 of the way from the dense to 00:50:37.340 |
So this is not saying that the percent of teacher performance does not mean that if 00:50:40.540 |
the teacher say gets, if we use the teacher's guesses or predictions as the ground truth, 00:50:45.420 |
this is not saying that the distilled model matches the teacher 30% of the time. 00:50:52.540 |
It's basically saying you get like 30% of the quality improvements. 00:50:56.740 |
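(To make that metric concrete, a tiny worked sketch of "percent of teacher performance" as just described; the numbers are invented for illustration.)

```python
def percent_of_teacher(dense_quality, sparse_quality, distilled_quality):
    """Fraction of the dense-to-sparse quality gap that the distilled model keeps."""
    return (distilled_quality - dense_quality) / (sparse_quality - dense_quality)

# Invented example: dense baseline 80.0, sparse teacher 83.0, distilled student 80.9.
# (80.9 - 80.0) / (83.0 - 80.0) = 0.3, i.e. about 30% of the teacher's gains are kept.
print(percent_of_teacher(80.0, 83.0, 80.9))
```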
And then if we can back up a slide, I had a different question, but I didn't want to interrupt. 00:51:01.660 |
When we were talking about all of these different T5 bases, and then also on a few slides before 00:51:06.980 |
I'm curious, you know, when T5 is trained, is there a weight penalty in the loss function? 00:51:15.580 |
No, there's no weight decay used in training any of those sparse or dense models. 00:51:20.780 |
So out of curiosity then, how do dense models perform compared to the switch model? 00:51:25.500 |
If you add some sort of weight regularization that incentivizes getting rid of useless weights? 00:51:31.180 |
Oh, so some kind of like maybe like L1 term or something like that? 00:51:36.620 |
So I'm wondering like how much of, because here we're talking about the benefits of sparsity, 00:51:39.180 |
and I'm wondering how much of this benefit from sparsity is due to the fact that just 00:51:43.740 |
some of this, I mean, effectively what the switch model is doing, if I understand correctly, 00:51:46.900 |
maybe I don't, what I understand is that in the switch model, the feed forward layer is effectively sparsified, like most of the weights are zeroed out for each token. 00:51:55.860 |
Well, actually, we're kind of really trying to like inject more weights. 00:51:59.220 |
So we're actually kind of trying to do, it's a little bit maybe like paradoxical, because 00:52:02.140 |
we're saying switch transformer, but our idea is to be like, hey, we actually want to just 00:52:05.700 |
have significantly more weights, not less weights. 00:52:08.580 |
It's kind of like you would zero out weights, but within a much larger weight matrix, if you want to think of it that way. 00:52:16.660 |
And so to me, it seems like a relevant baseline to just ask what happens if I have the dense 00:52:18.380 |
matrix, but I incentivize it with, say, an L1 or L2 penalty on the weights. 00:52:21.860 |
And I would, I'd be curious to know how that compares. 00:52:24.740 |
Yeah, we didn't run this, but also that kind of gets rid of weights for the dense model. 00:52:35.580 |
Also, to me, it's like, if you just add like an L1 penalty loss, you're not going to have 00:52:39.460 |
structured sparsity, whereas like here we, you know, it's not random weights in your 00:52:46.580 |
giant weight matrix that are zeroed out, right? 00:52:48.660 |
It's like really like blocks, like blocks corresponding to each expert. 00:52:55.020 |
So that structure allows the whole like communication stuff and that's- 00:53:01.820 |
That leverages the fact that you have multiple cores and so on, right? 00:53:05.420 |
I totally agree with that block structure and that's what I'm trying to say, is that 00:53:08.860 |
the switch has this very rich, it's not just sparse, it also has this rich structure. 00:53:12.780 |
And what I'm trying to do in my mind is disentangle, is the sparsity what's offering an advantage 00:53:17.420 |
or is this additional structure that you built in, is that what is the performance gain? 00:53:23.740 |
So the block structure is what enables us to leverage the fact that you have multiple cores. 00:53:32.340 |
Like if you didn't have that block structure, you'd still have to route to everything. 00:53:36.980 |
And so you have more communication costs and so on. 00:53:40.740 |
And then your first question was what, sorry? 00:53:43.020 |
I'm not actually sure if there was a question, I guess what I'm trying to say is I'm trying 00:53:48.780 |
But I agree, it's a little bit weird because sparsity kind of, there's a spectrum of meaning 00:53:54.980 |
So it's like, for example, compression and like model pruning is a form of sparsity, 00:54:00.220 |
but also a switch transformer and MOE also referred to as sparsity and that kind of related, 00:54:07.340 |
but definitely they're aiming at different things, so. 00:54:09.940 |
This is a really interesting idea of it's sparse, but you have more parameters. 00:54:16.900 |
So you have like sparse within this like giant weight matrix, which is- 00:54:28.220 |
I have a follow up question on distillation part. 00:54:34.340 |
So if you distill it back down, now you have like one expert technically, you're back to the dense model. 00:54:42.620 |
So now the entire idea of expert is that certain tokens would be sent to different experts 00:54:48.140 |
because they just like, I don't know, are more specialized in figuring something out 00:54:52.940 |
So now if you go back to this like dense layer, aren't you like basically only serving whichever 00:55:04.140 |
expert you base this dense layer on, like these tokens will probably perform well and 00:55:09.020 |
all the other tokens are kind of like left behind, right? 00:55:13.140 |
I'm actually, sorry, I don't think I'm fully understanding your question. 00:55:19.500 |
So are you kind of getting at like we're distilling this on a specific data set? 00:55:25.060 |
No, I'm thinking of how to use that, like why- 00:55:30.620 |
So maybe concretely, like let's, so like for super glue, right? 00:55:31.620 |
Like let's say you want to serve a model that does super glue well. 00:55:34.380 |
I think the idea is that like you distill the sparse model into a dense model on SuperGLUE. 00:55:39.040 |
So then you kind of get this compressed dense model that now performs better than if you 00:55:42.740 |
were to just train it from scratch or train it from like a pre-trained dense model. 00:55:53.740 |
You can just distill all of the, again, because you're just matching the model outputs. 00:55:57.380 |
So you can just treat the sparse model as kind of like a black box thing. 00:56:00.180 |
All we're doing is just trying to have the dense model match the actual final output probabilities of the sparse model. 00:56:11.540 |
I was not, I was not familiar with the idea of distillation. 00:56:12.540 |
So I think that was like my current confusion. 00:56:17.540 |
Um, I guess one motivation here is that, um, having experts can make serving a little bit 00:56:24.660 |
more difficult because, um, it requires bigger topologies. 00:56:29.300 |
Let's say you have eight experts, um, you need like, well, I guess you can have multiple 00:56:35.140 |
experts on fewer cores, but, um, you know, let's just say they're a little bit harder to serve. 00:56:42.500 |
And so if we can, you know, get the benefits from sparsity at pre-training and then use 00:56:48.700 |
distillation to a dense model for serving, uh, that can be, that can be beneficial. 00:56:54.740 |
So I think that was sort of the motivation for that, uh, experiment, right, Barret? 00:57:10.100 |
I just said, I think one more string kind of question. 00:57:24.100 |
I was wondering if you think there are any interesting directions around building models that 00:57:29.220 |
are explicitly optimized for parallel training. 00:57:33.660 |
I guess the MoE model seems like it does a really good job of this. 00:57:39.420 |
And also at inference time, it's very useful to have fewer 00:57:43.820 |
FLOPs per computation, or per forward pass. 00:57:49.720 |
But do you think there are any interesting directions around distributed 00:57:54.180 |
training where you might have models that are explicitly architected to have 00:58:00.060 |
a lot of parallel heads, or other features that are kind of 00:58:06.500 |
embarrassingly parallelizable? Or do you just scale up the 00:58:12.420 |
models in the standard way by adding more layers and then get away with using model parallelism? 00:58:21.940 |
So let me just make sure I'm fully understanding. 00:58:23.440 |
So yeah, right now even our models are definitely very 00:58:26.780 |
co-designed with the hardware, the shapes of things, and so on. 00:58:29.900 |
So at a high level, yes, I think there's a ton of interesting 00:58:33.760 |
research on co-designing the hardware, the partitioning algorithms, and the models. 00:58:39.020 |
Given that we have this kind of SPMD mesh-style partitioning, 00:58:43.540 |
we are already designing our models in ways that fit it really well. 00:58:47.020 |
So for example, when we want to scale up our model, there's one dimension we usually go to first, 00:58:52.980 |
because there are some really nice properties of scaling up that dimension. 00:58:55.140 |
It basically becomes kind of independent of some of the communication costs, 00:58:59.040 |
and it's really good in terms of the ratio of compute to memory operations on these accelerators. 00:59:06.620 |
Even when we're designing these models, we're really setting the dimensions with the partitioning in mind. 00:59:11.620 |
So it's almost like, given that we have this model and data parallelism, we're designing the models around it. 00:59:18.300 |
But I also think there's a ton of new, interesting distributed algorithms and things 00:59:22.100 |
like that, which makes designing models very interesting. 00:59:24.020 |
One thing I think is really cool is Microsoft's ZeRO partitioning, 00:59:27.740 |
which also has some really nice implications for how to design and train these models. 00:59:32.900 |
So yeah, I think this is a very fruitful research direction. 00:59:36.380 |
If that kind of answered your question? Yeah, no, that was super helpful. 00:59:44.340 |
I'm very optimistic about a future where we design the hardware, the model, 00:59:48.020 |
and the partitioning strategies all together, because really, to get it to work well, you kind of 00:59:50.820 |
have to know about all three and intertwine their development. 01:00:00.380 |
So just to summarize: Switch Transformer is a nice simplification of mixture of experts, 01:00:05.740 |
and we're seeing really strong speed-up improvements on pre-training over 01:00:09.900 |
a lot of the T5 models, which are very strong baselines. 01:00:13.140 |
We're seeing that we can efficiently distill the sparse models back to dense ones, 01:00:17.420 |
and get improvements in both pre-training and fine-tuning through some of the newer training techniques. 01:00:23.500 |
We're also seeing that the models work on multilingual data, and that we can 01:00:27.260 |
now successfully train models with up to 1.6 trillion parameters, 01:00:31.700 |
which is pretty promising. Next slide. 01:00:35.500 |
We also wanted to spend two slides on some newer work about 01:00:38.540 |
using these kinds of models for computer vision, and a little bit about how they 01:00:42.340 |
can be used to do some level of adaptive computation, where not only does each 01:00:47.060 |
input get different weights, but different inputs can also get different amounts of computation. 01:00:53.420 |
There was some really great work doing this out of the Google Zurich team. 01:00:58.340 |
They're doing it for image classification, and they're basically 01:01:01.420 |
seeing a lot of the same kinds of scaling properties, where scaling up the 01:01:04.780 |
number of experts and using sparsity allows them to get good performance on image classification. 01:01:14.780 |
And interestingly, one of the things they do relates to the capacity 01:01:17.740 |
factor we talked about. We were talking about values like 1.0, 1.25, 2.0, where at a 01:01:22.100 |
value of 2.0 there's buffer space for two tokens per expert, but they actually study values below 1.0. 01:01:28.420 |
So at 0.5, there's only room for half the number of tokens. 01:01:33.140 |
And the nice part is that they did this for image classification, 01:01:35.900 |
where there's just a lot of redundancy in images, and they noticed that you can actually 01:01:39.700 |
get really good performance by only allowing up to one tenth of the parts 01:01:45.320 |
of the image to be processed by a sparse layer. 01:01:47.340 |
So yeah, we think this is a really nice direction, too, in terms of combining sparsity and adaptive computation. 01:01:59.660 |
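As a rough illustration of what a capacity factor below 1.0 means, here is a minimal NumPy sketch of top-1 routing with a fixed per-expert buffer; it is not the actual Switch or V-MoE implementation, and the sizes in the example are made up. Tokens that overflow an expert's buffer are simply dropped and, in the real models, just pass through the residual connection.

```python
import numpy as np

def expert_capacity(num_tokens, num_experts, capacity_factor):
    # Buffer size per expert; capacity_factor=0.5 means the buffers can only
    # hold half of the tokens in the batch across all experts combined.
    return int(np.ceil(capacity_factor * num_tokens / num_experts))

def top1_route(router_probs, capacity_factor):
    """Greedy top-1 routing with fixed-size expert buffers.

    router_probs: [num_tokens, num_experts] routing probabilities
    Returns (kept, dropped): token indices assigned to an expert vs. dropped
    because their chosen expert's buffer was already full.
    """
    num_tokens, num_experts = router_probs.shape
    cap = expert_capacity(num_tokens, num_experts, capacity_factor)
    counts = np.zeros(num_experts, dtype=int)
    kept, dropped = [], []
    for tok, expert in enumerate(router_probs.argmax(axis=-1)):
        if counts[expert] < cap:
            counts[expert] += 1
            kept.append(tok)
        else:
            dropped.append(tok)
    return kept, dropped

# Example: 64 image patches, 8 experts, capacity factor 0.5 ->
# at most 4 patches per expert, so at most 32 patches get expert compute.
probs = np.random.dirichlet(np.ones(8), size=64)
kept, dropped = top1_route(probs, capacity_factor=0.5)
print(len(kept), "patches processed by experts,", len(dropped), "dropped")
```

The redundancy argument is that, for images, the dropped patches often carry little extra information, so shrinking the buffers mostly saves compute rather than hurting accuracy.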
So thank you, Barrett and Erwin, for coming here. 01:02:08.300 |
I will just ask a bunch of questions, and then we can have more of a discussion after the class. 01:02:20.020 |
So one thing is, have you tried using more linear attention mechanisms, 01:02:23.700 |
like Reformer and other such approaches, to scale the computation? 01:02:31.020 |
I haven't personally done this. 01:02:35.620 |
I guess we can comment on how the cost 01:02:44.540 |
coming from the attention maps isn't the dominant cost in these large transformers. 01:02:52.640 |
The motivation for using linear attention, like Performers, is that it reduces the cost of the attention maps. 01:03:03.460 |
But so far, at least in typical NLP setups like SuperGLUE, 01:03:09.060 |
C4 and so on, as you scale the models, most of the memory comes from the model weights 01:03:17.260 |
as opposed to the attention maps. 01:03:20.460 |
That's also because using very long context or sequence lengths hasn't proved that useful in these setups. 01:03:28.580 |
And so just working with the vanilla self-attention mechanism works fine in practice. 01:03:39.020 |
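As a rough back-of-the-envelope check on that point, the sketch below compares per-layer weight memory with the memory of the attention maps for one configuration. The sizes (d_model=4096, d_ff=16384, 64 heads, sequence length 512, per-core batch 4, fp32 weights, bf16 activations) are illustrative assumptions, not numbers from the talk. At this kind of scale and sequence length the weights dominate; note, though, that the attention-map term grows quadratically with sequence length, so the picture changes for very long contexts.

```python
def block_weight_bytes(d_model, d_ff, bytes_per_param=4):
    # Dense transformer block: Q, K, V, O projections plus two feed-forward matrices.
    return (4 * d_model * d_model + 2 * d_model * d_ff) * bytes_per_param

def attn_map_bytes(batch, n_heads, seq_len, bytes_per_act=2):
    # One (seq_len x seq_len) attention matrix per head per example.
    return batch * n_heads * seq_len * seq_len * bytes_per_act

weights = block_weight_bytes(d_model=4096, d_ff=16384)   # ~768 MiB per block
attn = attn_map_bytes(batch=4, n_heads=64, seq_len=512)  # ~128 MiB per block
print(f"weights per block: {weights / 2**20:.0f} MiB")
print(f"attention maps:    {attn / 2**20:.0f} MiB")
```

On top of that, training keeps gradients and optimizer state per weight, which tilts the balance further toward the weights.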
So another question is, do you think this mechanism is even more scalable? 01:03:43.660 |
Like, can you go on to 10 trillion parameter models, things like that? 01:03:50.560 |
I think, honestly, one of the biggest constraints, and this 01:03:55.300 |
isn't even necessarily a constraint, is just that you have to fit the parameters somewhere. 01:04:01.780 |
But if you get enough devices, then you can just partition the weights. 01:04:05.420 |
So yeah, I don't see anything stopping it. 01:04:08.940 |
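One way to see why there is no obvious ceiling: with top-1 routing, the per-token compute of a sparse layer is fixed while its parameter count grows linearly with the number of experts, so the main requirement is enough device memory to hold (and partition) the weights. The sketch below uses illustrative dimensions, not the actual 1.6-trillion-parameter configuration.

```python
def moe_layer_params(d_model, d_ff, num_experts):
    # Each expert is its own two-matrix feed-forward block, plus a small router.
    return num_experts * 2 * d_model * d_ff + d_model * num_experts

def moe_flops_per_token(d_model, d_ff):
    # Top-1 routing: every token still visits exactly one expert,
    # so per-token compute does not grow with the number of experts.
    return 2 * (2 * d_model * d_ff)  # two matmuls, ~2 FLOPs per multiply-add

for num_experts in (1, 64, 2048):
    params = moe_layer_params(d_model=4096, d_ff=16384, num_experts=num_experts)
    flops = moe_flops_per_token(d_model=4096, d_ff=16384)
    print(f"{num_experts:5d} experts: {params / 1e9:7.2f}B params/layer, "
          f"{flops / 1e6:.0f} MFLOPs/token")
```

Going from 64 to 2048 experts multiplies the parameters per layer by 32x while leaving FLOPs per token unchanged; what grows is the memory footprint, which is exactly what gets sharded across devices.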
So what do you personally think is the direction 01:04:14.120 |
that scaling of transformers will go in? Will there be more work 01:04:18.480 |
that tries to use these kinds of mechanisms, like mixture of experts, or do you 01:04:23.340 |
think there are going to be other things that the community needs? 01:04:26.860 |
I mean, I definitely think mixture of experts, or at least 01:04:29.780 |
sparse layers like Switch Transformer, will definitely find their way into future models. 01:04:34.120 |
I think they really confer a lot of benefits, and they're also very good in high-throughput serving settings. 01:04:39.780 |
The one downside of sparsity is that, if you look 01:04:43.300 |
at the performance per model weight, they're always going to be worse than dense models. 01:04:47.700 |
So if you're really constrained, like, I want to design the best model I 01:04:51.120 |
can to fit on as small a device as I can, then they're probably not going to be the 01:04:55.140 |
best solution, because a sparse weight just isn't as good as a dense weight that gets used for every input. 01:05:01.020 |
So I think it really depends on the application, but I'm very optimistic about pre-training 01:05:04.660 |
these models with lots of data parallelism and then serving 01:05:07.740 |
them in medium to higher throughput settings. 01:05:10.220 |
I feel like they could actually be a pretty big win there. 01:05:13.800 |
So those are kind of my thoughts on how I think sparsity will be used going forward. 01:05:19.300 |
There's a ton of exciting research on everything from 01:05:21.740 |
linear attention to adaptive computation to new pre-training objectives. 01:05:27.260 |
It's hard to know what the future will look like, but there are a lot of exciting things to look forward to. 01:05:34.940 |
So now we can have a round of student questions, so we'll just stop the recording.