
Stanford CS25: V1 I Mixture of Experts (MoE) paradigm and the Switch Transformer


Chapters

0:07 Scaling Transformers through Sparsity
0:25 Overall Motivation
1:00 Scaling Laws for Neural Language Models
5:05 Switch Transformer
7:23 Improved Training Methodology
8:27 Differentiable Load Balancing
8:55 Selective Precision
10:25 The Initialization Scale
16:23 Multi-Stage Routing Procedure
20:35 What Is the Research Question
29:50 Perplexity versus Training Time
30:53 Sparse Scaling Laws
37:10 Data Parallelism
37:36 Model Parallelism
38:51 Expert and Data Parallelism
39:31 Model Partitioning
40:37 Mesh Abstraction
46:18 Fine-Tuning Properties of Sparse Models
47:08 Multilingual Training
47:45 Distillation

Whisper Transcript

00:00:00.000 | Today, Erwin and I are going to be giving a talk on scaling transformers through sparsity.
00:00:10.460 | And the kind of sparsity we're going to be talking about today is the kind where each
00:00:13.820 | input can get either a different set of weights or have a different amount of computation
00:00:18.720 | applied to it.
00:00:19.720 | Erwin, do you want to start it off?
00:00:22.960 | So, I guess the overall motivation for this line of work is that the community has realized
00:00:32.360 | that scale is perhaps one of the most important axes to focus on for obtaining strong performance.
00:00:40.240 | And there's almost like this ongoing arms race right now with different labs and different
00:00:45.720 | institutions competing for training the largest models.
00:00:53.440 | And so, maybe this dates back from early 2020 with a paper from OpenAI called Scaling Laws
00:01:01.360 | for Neural Language Models, where they find that model performance follows a predictable
00:01:08.600 | power law, scale as a power law with model size in terms of either compute or just parameters.
00:01:22.080 | And so, this scaling law generalizes over multiple orders of magnitude, and that gives
00:01:28.620 | us the confidence that if we are to train very large models, we can expect a certain
00:01:35.060 | performance just by extrapolating these scaling laws.
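(For reference, the power-law form reported in that paper, in Kaplan et al.'s notation, looks roughly like the following; the critical constants and exponents are empirical fits and are only quoted here schematically.)

```latex
% Schematic form of the scaling laws (Kaplan et al., 2020): test loss falls as a
% power law in non-embedding parameters N, dataset size D, and compute C, when
% the other two factors are not the bottleneck.
L(N) \approx \left(\tfrac{N_c}{N}\right)^{\alpha_N}, \qquad
L(D) \approx \left(\tfrac{D_c}{D}\right)^{\alpha_D}, \qquad
L(C) \approx \left(\tfrac{C_c}{C}\right)^{\alpha_C}
```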
00:01:40.120 | So, in that paper, they also find the interesting observation that basically larger models are
00:01:50.240 | more sample efficient.
00:01:52.200 | And so, if you have a fixed compute budget, you can predict what is the size, what is
00:02:03.120 | the optimal model size for a fixed compute budget.
00:02:06.440 | And the overall observation is that you'd rather train very large models for fewer steps
00:02:15.640 | than train smaller models for more training steps.
00:02:19.680 | And so, the paper focuses on dense models, which are scaled by just
00:02:28.640 | increasing the model dimensions, but they're not looking at sparsity.
00:02:34.480 | And so, sparsity is a new dimension that you can use to scale architectures, and this is
00:02:40.540 | sort of the focus of the talk.
00:02:45.240 | And so, the sparsity we're mentioning here is basically you will have sparsely activated
00:02:52.960 | weights based on the network inputs.
00:02:56.560 | So, every input will get a roughly similar amount of computation, but will have
00:03:02.640 | different weights applied.
00:03:05.840 | And so, this dates back to 1991 with a paper called Adaptive Mixtures of Local Experts,
00:03:14.040 | and was recently revisited by Noam Shazeer and colleagues at Google Brain with LSTMs,
00:03:21.720 | where they replaced sort of the feed-forward networks in LSTMs with a mixture of experts.
00:03:30.800 | And so, the way this works there roughly is that you will have multiple experts each implementing
00:03:37.880 | a small network, or in that case, I think just a dense matrix multiplication.
00:03:46.040 | And so, you have an additional gating network shown in green here that outputs a probability
00:03:54.160 | distribution over experts that each token should be sent to.
00:04:00.600 | So, this probability distribution is computed as a softmax, and once you have it, you select
00:04:09.080 | a few experts.
00:04:11.080 | So, there are different strategies, maybe we'll talk about it later on.
00:04:16.120 | And the output is simply sort of the weighted mixture of all selected experts' outputs.
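(A minimal sketch of the gating just described, in NumPy-style Python; the function and variable names are illustrative, not from the original implementation, and some variants renormalize the selected probabilities before mixing.)

```python
import numpy as np

def moe_layer(x, router_W, expert_Ws, k=2):
    """Route one token x of shape [d_model] through a mixture of experts.

    router_W:  [d_model, num_experts] gating network weights.
    expert_Ws: list of per-expert weight matrices, each [d_model, d_model].
    Returns the gate-weighted mixture of the top-k experts' outputs.
    """
    logits = x @ router_W                           # [num_experts]
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()                     # softmax over experts
    top_k = np.argsort(probs)[-k:]                  # indices of the k highest-probability experts
    out = np.zeros_like(x)
    for e in top_k:
        out += probs[e] * (x @ expert_Ws[e])        # each selected expert's output, weighted by its gate
    return out
```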
00:04:22.440 | So, they've been pretty successful primarily in translation, but there were some complexities
00:04:38.160 | that hindered their broader use in NLP.
00:04:41.880 | And so, the Switch Transformer paper addresses some of those, and we'll be discussing how
00:04:49.880 | to fix training instabilities or reduce communication costs and reduce model complexity.
00:04:56.960 | All right, Barrett, do you want to go?
00:05:01.920 | Yeah.
00:05:02.920 | So, one kind of approach that we're going to have for sparsity is the Switch Transformer,
00:05:07.520 | which is kind of like a simplified mixture of expert variant along with some other improved
00:05:13.080 | training and fine-tuning techniques that allow it to be stably trained and also perform better
00:05:19.560 | when fine-tuned on a lot of downstream tasks.
00:05:22.880 | And so, yeah, so the Switch Transformer kind of model works as the following.
00:05:27.320 | So, you have some transformer model that has self-attention and feed-forward layers.
00:05:32.880 | And the idea is that we replace maybe one every two or one every four feed-forward layers
00:05:37.240 | with a Switch Transformer layer.
00:05:39.960 | So, you can see on the left is like one kind of layer block, which is self-attention, then
00:05:46.360 | add normalize, then a feed-forward layer, then add normalize.
00:05:49.400 | And in this case, we're replacing the normal feed-forward layer with the Switch layer.
00:05:53.920 | And we can see an illustration of this on the right.
00:05:57.000 | So, on the right, we can see that the layer has two inputs.
00:06:01.320 | One is the token "More", the other is the token "Parameters".
00:06:04.480 | And we can see that these embedding representations will get sent to a router, which is exactly
00:06:09.400 | how it works in the mixture of expert.
00:06:10.960 | So, the router is basically just going to be getting a distribution over all of the
00:06:15.240 | experts.
00:06:16.240 | So, in this case, we can see that the highest probability is going to the expert number
00:06:20.760 | two out of the four experts.
00:06:23.080 | And then the token on the right is actually having the most probability on the first feed-forward
00:06:27.960 | weight, which is like the first expert.
00:06:30.040 | So, yeah, we can see here that what we're going to do is in the Switch Transformer,
00:06:33.760 | which is very simple.
00:06:34.760 | It's just send it to the highest probability expert.
00:06:38.040 | And so, here we can see where the adaptive computation lies, where we'll have four sets
00:06:42.120 | of weights.
00:06:43.480 | There's some shared weights and computation across all the tokens.
00:06:46.560 | For example, the self-attention layer is computed exactly the same for the "More" token and for
00:06:50.800 | the "Parameters" token.
00:06:52.920 | But in the sparse Switch layer, we can see that actually the inputs are, while having
00:06:56.720 | the same amount of floating point operations applied to them, actually have different weight
00:07:00.360 | matrices.
00:07:03.680 | Next slide.
00:07:05.160 | Yeah.
00:07:06.160 | So, that's the kind of high-level idea with Switch Transformer, is that instead of sending
00:07:10.920 | a token to multiple different experts, which can also increase the communication costs,
00:07:14.940 | as I'll go into a little bit later, it also just significantly simplifies the algorithm
00:07:18.920 | by just only sending it to one expert.
00:07:22.120 | So, for the improved training methodology, we focused on three different things to help
00:07:26.940 | improve the training of sparse models.
00:07:29.560 | The first was selective precision, which allows these sparse models to be trained in lower
00:07:33.400 | precision formats, which is incredibly important.
00:07:36.560 | Most of the models we train, we really don't want to be using float 32, because it's just
00:07:40.640 | slower to compute.
00:07:41.640 | And also, when you're communicating tensors across different processes and stuff, it's
00:07:45.680 | twice as slow, just because there's twice as many things.
00:07:49.120 | Also, we have some initialization tricks and some training tricks as well for allowing
00:07:53.160 | them to be trained more stably, especially as the models grow in size, which is like
00:07:57.200 | a new initialization method, along with a change to the learning rate schedule.
00:08:02.800 | And third, since our models have so many more parameters, we do notice definitely different
00:08:07.260 | overfitting dynamics, especially once we fine tune these models that have been pre-trained
00:08:12.160 | on all of the internet on these small tasks with maybe only 50 to 100,000 examples, that
00:08:16.880 | they can be much more prone to overfitting.
00:08:19.080 | So we also look at some custom regularization to help prevent some of the overfitting that
00:08:24.720 | we observe.
00:08:26.400 | And finally, we also talk about this differentiable load balancing technique we make, which kind
00:08:31.520 | of allows each expert to roughly get the same amount of tokens.
00:08:36.200 | Because this is very important, especially given that we want the stuff to be efficient
00:08:39.480 | on hardware.
00:08:40.480 | We want roughly each expert to have similar amounts of tokens sent to it.
00:08:44.080 | And so to kind of encourage this, we tack on an additional load balancing loss along
00:08:49.040 | with our cross-entropy loss that we're training with.
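(A minimal sketch of that auxiliary loss; this follows the form in the Switch Transformer paper: the number of experts times the dot product between the fraction of tokens dispatched to each expert and the mean router probability for that expert, scaled by a small coefficient. The NumPy code and names are illustrative.)

```python
import numpy as np

def load_balancing_loss(router_probs, expert_index, alpha=0.01):
    """router_probs: [num_tokens, num_experts] softmax outputs of the router.
    expert_index:  [num_tokens] integer id of the expert each token was sent to.
    Returns a scalar that is smallest when routing is uniform across experts.
    """
    num_tokens, num_experts = router_probs.shape
    # f_i: fraction of tokens actually dispatched to expert i (not differentiable).
    f = np.bincount(expert_index, minlength=num_experts) / num_tokens
    # P_i: mean router probability assigned to expert i (the differentiable path).
    P = router_probs.mean(axis=0)
    return alpha * num_experts * float(np.dot(f, P))
```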
00:08:52.580 | Next slide.
00:08:54.580 | So here, I'm going to go into selective precision.
00:08:56.520 | So yeah, again, so when we're training large models, it's really important that we should
00:09:00.080 | be able to train them in lower precision formats.
00:09:02.080 | So instead of each weight being an activation, being 32 bits, we want to shrink it down to
00:09:07.400 | 16 bits.
00:09:08.400 | And we use the bfloat16 representation.
00:09:11.220 | And what we found out of the gate is that these models are just unstable, especially
00:09:16.240 | the sparse models are much more unstable than the dense models in terms of you'll train
00:09:19.700 | it for 10,000, 20,000 steps, and then the losses would just diverge.
00:09:23.040 | This was something that we frequently encountered.
00:09:25.620 | And so one key thing that we found is that basically, you need to be casting a part of
00:09:30.740 | the computation in float32 for these models to be able to be trained stably.
00:09:38.180 | And the key component that we found that you need to cast is the router computation.
00:09:42.840 | And essentially, we can go into the technical details a little bit more later.
00:09:47.060 | But basically, any time that there's these exponentiation functions, it's very important
00:09:51.140 | that we are having higher and higher precision because of round off errors that can then
00:09:56.720 | drastically change the output of some kind of exponentiation function.
00:10:01.480 | So for example, if you have an exponentiation function and you change it by 0.1 or 0.2 or
00:10:06.580 | 0.3, this can drastically change the output of exponentiating it, especially depending
00:10:11.860 | on how large the input is.
00:10:14.080 | So yeah, so this was a very important thing.
00:10:15.940 | And it basically doesn't change the compute at all and allows the models to just be significantly
00:10:19.440 | more stable.
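(A sketch of what "cast only the router" means in practice; illustrative NumPy pseudocode, with the rest of the network assumed to run in bfloat16.)

```python
import numpy as np

def router_probs_float32(token_embeddings, router_W):
    """token_embeddings: [num_tokens, d_model] activations (normally kept in bfloat16).
    router_W: [d_model, num_experts] router weights.
    The logits and the softmax exponentiation are done in float32, since small
    round-off errors get amplified by exp(); the result can be cast back down.
    """
    x32 = np.asarray(token_embeddings, dtype=np.float32)
    w32 = np.asarray(router_W, dtype=np.float32)
    logits = x32 @ w32                                  # tiny matmul, negligible cost
    logits -= logits.max(axis=-1, keepdims=True)        # numerically stable softmax
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)
```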
00:10:20.440 | Next slide.
00:10:24.060 | So the second thing we looked at is also the initialization scale.
00:10:27.220 | So like the standard way that we were initializing these models, we found to also just make the
00:10:31.980 | models much more prone to being unstable and/or just performing worse.
00:10:36.260 | So one thing that we did that we found was very effective was to just simply make the
00:10:40.380 | initialization scale much smaller.
00:10:42.460 | And when we did this, we found that the quality just drastically improved, and it was like
00:10:46.300 | a very simple fix.
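(As a sketch: the fix amounts to drawing weights from a truncated normal whose scale parameter is cut down, roughly by a factor of 10 in the paper; the helper below is illustrative, not the paper's code.)

```python
import numpy as np

def truncated_normal_init(fan_in, fan_out, scale=0.1, seed=0):
    """Truncated-normal initializer with std = sqrt(scale / fan_in).

    scale=1.0 is the usual Transformer default; shrinking it (e.g. to 0.1)
    was found to make large sparse models train much more stably.
    """
    rng = np.random.default_rng(seed)
    std = np.sqrt(scale / fan_in)
    w = rng.normal(0.0, std, size=(fan_in, fan_out))
    return np.clip(w, -2.0 * std, 2.0 * std)  # crude truncation at two standard deviations
```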
00:10:48.940 | Next slide.
00:10:51.840 | And the third thing I mentioned, where since we noticed that these models are much more
00:10:55.300 | prone to overfitting, since they just have significantly more parameters, is that we
00:10:59.600 | also use much more dropout for the expert layers only.
00:11:03.100 | So here we can see we have the T5 base, which is a dense model, and then we have a bunch
00:11:08.320 | of different switch variants on that.
00:11:10.120 | And we found to be the most effective on these four different fine-tuning tasks was just
00:11:13.920 | to really significantly increase the dropout rate inside the expert layers.
00:11:17.800 | And we found that this was pretty effective for combating the overfitting.
00:11:21.160 | Next slide.
00:11:22.160 | Better?
00:11:23.160 | Yeah.
00:11:24.160 | We have a question.
00:11:25.160 | Oh, awesome.
00:11:27.160 | For one of the students.
00:11:28.160 | Yeah.
00:11:29.160 | OK, let me take a look.
00:11:30.160 | Do you want to go ahead?
00:11:31.160 | Yeah, I can ask.
00:11:32.160 | It was just in reference to the previous table where you have throughput and precision.
00:11:36.880 | It just seems surprising to me that you could match this 1390 number using selective precision.
00:11:43.320 | It seems like I would expect it to be something in between.
00:11:46.960 | Yeah.
00:11:47.960 | So it essentially comes down to the fact that there's maybe a little bit of noise sampled
00:11:52.320 | with the speed.
00:11:53.440 | And the only part we're casting is the router, which is maybe such an insignificant portion
00:11:58.720 | of the computation.
00:12:00.240 | And there's zero communication there.
00:12:01.600 | That is essentially like a free operation in the network.
00:12:04.140 | So whether you cast it to bfloat16 or float32, it doesn't actually impact the speed at all
00:12:08.720 | within the precision that we can actually measure the speed.
00:12:12.320 | And also, these architectures only use a sparse layer once every four layers.
00:12:19.240 | And so, yeah, essentially, the float32 part is kind of very negligible in the entire architecture.
00:12:27.320 | It's like, for example, I think off the top of my head, it's like 1/40th the computation
00:12:32.360 | that would cost for you to do the first weight matrix multiply in a dense, ReLU dense layer
00:12:38.040 | or something.
00:12:39.040 | So it's a very, very small part.
00:12:40.040 | And yeah, we're not using them very frequently, like Erwin mentioned as well.
00:12:44.000 | Got it, OK, thanks.
00:12:48.800 | Yeah, and then just a quick point on this.
00:12:51.480 | I won't go into some of the technical details, but yeah, we definitely-- since we're training
00:12:54.960 | these things on hardware, we really-- I think a big part of the mixture of experts paradigm
00:12:58.660 | is that these things are designed such that it maps really efficiently to hardware.
00:13:03.680 | So we want to be doing dense matrix multiplies.
00:13:06.360 | And for this to work really well, we also want to be able to have roughly equal amount
00:13:10.280 | of tokens going to each of the different experts.
00:13:13.740 | And I think this isn't that sensitive to the load balancing formulation.
00:13:17.800 | Like, we tried a few things.
00:13:18.800 | A lot of them worked.
00:13:19.800 | But yeah, essentially, you definitely want some kind of load balancing loss added on
00:13:23.640 | when using sparsity.
00:13:25.040 | Yeah, next slide.
00:13:27.640 | Yeah, Erwin, go ahead.
00:13:31.960 | Yeah, so the frameworks, the libraries we use rely on static shapes -- OK, yeah, so XLA,
00:13:55.020 | the compiler for TensorFlow, and Mesh TensorFlow expect static shapes for tensors.
00:14:03.220 | However, the computations in switch transformers are dynamic because of the router, right?
00:14:11.980 | Different inputs will be routed to different experts.
00:14:15.700 | And so we need to specify ahead of time how many tokens will be sent to each expert.
00:14:22.580 | And so we will introduce this expert capacity hyperparameter to specify that.
00:14:29.120 | And that's going to be a static number which says how many tokens each expert can process.
00:14:37.960 | And so in practice, we instead parametrize this by having a quantity called the capacity
00:14:43.300 | factor.
00:14:44.300 | And so we have an example here.
00:14:47.940 | So the bottom row is a bunch of tokens on one device.
00:14:56.660 | And then you need to sort of route those tokens to multiple devices or multiple experts.
00:15:02.880 | So if too many tokens are routed to a single expert, some tokens will be dropped because,
00:15:08.740 | as we said, experts have a fixed capacity.
00:15:14.000 | So that's the example on the left where the capacity factor is one, and that basically
00:15:17.900 | means that there's no extra buffer for routing tokens.
00:15:27.180 | So instead of that, we can use the capacity factor that's larger than one.
00:15:30.780 | So on the right, you have an example with 1.5.
00:15:34.740 | So that means that now each expert has three slots that can process three tokens.
00:15:42.580 | And so that prevents token dropping because we have more capacity.
00:15:46.920 | But the issue is that this means more expensive communication across devices.
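(A small illustrative sketch of the capacity bookkeeping just described; in the paper, expert capacity = (tokens per batch / number of experts) x capacity factor, and tokens that overflow an expert's buffer are "dropped", i.e. passed along the residual connection with no expert computation. The Python below is a toy version of that logic.)

```python
def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    """Static number of token slots allocated to each expert."""
    return int((tokens_per_batch / num_experts) * capacity_factor)

def dispatch(chosen_expert, num_experts, capacity):
    """chosen_expert: list with the selected expert id for each token.
    Returns (slot_of_token, dropped): slot_of_token[i] is (expert, slot) or None
    if that expert's buffer was already full, in which case the token is dropped
    and just flows through the residual connection.
    """
    slots_used = [0] * num_experts
    slot_of_token, dropped = [], []
    for tok, e in enumerate(chosen_expert):
        if slots_used[e] < capacity:
            slot_of_token.append((e, slots_used[e]))
            slots_used[e] += 1
        else:
            slot_of_token.append(None)
            dropped.append(tok)
    return slot_of_token, dropped

# In the slide's example: 6 tokens over 3 experts gives 2 slots per expert at
# capacity factor 1.0, and 3 slots per expert at 1.5 (more buffer, more communication).
```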
00:16:02.860 | One thing that we also experimented with was this method called no token left behind.
00:16:07.820 | And the idea was the following.
00:16:09.000 | So since we have to have a fixed batch size for each expert, and there can be token dropping,
00:16:15.700 | we're thinking that, hey, yeah, having tokens dropped or having some tokens not having any
00:16:19.860 | computation applied to it is probably hurting the model performance.
00:16:23.740 | So what if we do a multistage routing procedure?
00:16:26.060 | So first, you do the normal routing where it's like you send each token to its highest
00:16:29.500 | probability expert.
00:16:30.900 | But then any dropped tokens, you then send to their second highest probability expert,
00:16:36.460 | and so forth and so on.
00:16:37.460 | Or you can basically repeat this process to guarantee that no tokens are being dropped.
00:16:41.540 | Interestingly, actually, this approach didn't empirically improve model performance.
00:16:46.020 | If anything, it actually kind of hurt it.
00:16:48.420 | And we thought that was actually very interesting.
00:16:51.060 | And I think the intuition is that, you know, once the model learns it wants to send a token
00:16:54.380 | to one expert, like it really wants to have that computation applied to it.
00:16:58.060 | And just applying some other computation doesn't, you know, have at all the same property, along
00:17:03.460 | with it actually maybe being potentially detrimental.
00:17:06.400 | So yeah, we thought that was pretty interesting, as we were very optimistic this would potentially,
00:17:10.240 | you know, get improved performance, but it ended up not really making a difference.
00:17:13.340 | And we found this quite surprising.
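(For concreteness, a toy sketch of the multi-stage re-routing just described, the variant that did not end up helping; illustrative Python, reusing the capacity idea from the earlier snippet.)

```python
def no_token_left_behind(router_probs, capacity):
    """router_probs: [num_tokens][num_experts] router probabilities per token.
    Stage 0 sends every token to its top-choice expert; each later stage retries
    the still-unassigned tokens on their next-best expert, so (capacity permitting)
    no token is left without an expert.
    """
    num_tokens, num_experts = len(router_probs), len(router_probs[0])
    slots_used = [0] * num_experts
    assignment = [None] * num_tokens
    for stage in range(num_experts):
        for tok in range(num_tokens):
            if assignment[tok] is not None:
                continue
            ranked = sorted(range(num_experts), key=lambda e: -router_probs[tok][e])
            e = ranked[stage]                 # this token's (stage+1)-th choice
            if slots_used[e] < capacity:
                assignment[tok] = e
                slots_used[e] += 1
    return assignment
```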
00:17:15.380 | We have a question from…
00:17:21.220 | I think it will actually kind of like address literally the last point that you brought
00:17:26.380 | I think when I think about like a mixture of experts, usually like they specialize in
00:17:31.340 | like different things, right?
00:17:33.140 | So I think it was like, just like a lot, like I was just wondering, like if you send it
00:17:41.740 | to like the second best or whatever, like what if like all of your tokens would be particularly
00:17:49.140 | good for like one expert, and then you only like process, let's say, like 20% of your
00:17:55.000 | tokens.
00:17:57.000 | So that ends up being better than rerouting them to anything else.
00:18:02.080 | Exactly.
00:18:03.080 | Yeah.
00:18:04.080 | So yeah, even if you're dropping a lot of tokens, it's not beneficial to be sending
00:18:06.760 | them to the second, third or fourth best thing.
00:18:09.440 | And one actually interesting property that we, you know, noticed about these models is
00:18:12.380 | they're surprisingly robust to token dropping, especially during fine tuning.
00:18:16.980 | So yeah.
00:18:17.980 | So in the standard paradigm, what we'll do is we'll pre-train this thing, we'll have
00:18:19.680 | some load balancing loss, which makes the tokens pretty balanced actually.
00:18:24.480 | But then during fine tuning, where it's like, we really want to fine tune it on a specific
00:18:27.640 | task.
00:18:28.640 | We actually studied this exact question and we were studying, does it help to have a load
00:18:31.780 | balancing loss during fine tuning or not?
00:18:34.360 | And so if you have the load balancing loss, yeah, that kind of is encouraging, you know,
00:18:37.560 | for the specific task, we want to try to have, you know, all the experts be used versus turning
00:18:41.240 | it off.
00:18:42.240 | Whereas there's definitely some, you know, prior specialization and it's actually much
00:18:45.600 | better to just turn the auxiliary loss off.
00:18:47.760 | And even if it's like, you know, 60 to 70% of the tokens are being dropped, that actually
00:18:51.760 | performs much better than, you know, having all the tokens balanced.
00:18:55.520 | But doesn't a load balancing loss encourage basically all the experts to learn very similar
00:19:00.920 | weights and then just randomly assign your tokens?
00:19:05.320 | Because then it doesn't matter to which expert stuff is being sent to.
00:19:08.980 | So when we use the load balancing loss, like the routing mechanism is definitely learned.
00:19:12.120 | So the model definitely is encouraged to, you know, choose an expert that it wants to
00:19:16.520 | send it to for good.
00:19:17.520 | Right.
00:19:18.520 | But like if all the experts learn the same weights, then the router learns basically,
00:19:23.760 | oh, it doesn't matter where I send it to.
00:19:26.720 | So if you encourage load balancing, you encourage technically that like you want any token to
00:19:32.760 | fit with any expert.
00:19:33.760 | Right.
00:19:34.760 | I mean, that's maybe the extreme behavior if you have a very high sort of load balancing
00:19:39.320 | loss coefficient, but in practice that coefficient is kind of tuned and we observe that for,
00:19:44.520 | you know, small enough values, the router still learns like semantic, like meaningful
00:19:51.040 | routing.
00:19:52.040 | Yeah.
00:19:53.040 | Because it's like a balance between this, like, you know, cross entropy loss and this
00:19:55.920 | load balancing loss.
00:19:57.440 | And so on one hand, yeah, you definitely want to encourage the model to be balanced.
00:20:00.880 | Then on the other hand, you also want to just get good empirical performance.
00:20:04.760 | And yeah, the model is able to definitely like on one hand, learn and specialize the
00:20:08.680 | experts where they have different weights such that it's like, you know, definitely
00:20:11.640 | it expects certain tokens to be sent to certain experts, but on the other hand, still be reasonably
00:20:14.960 | balanced so that the models are efficiently run on like modern hardware.
00:20:18.800 | Exactly.
00:20:19.800 | We also have a question from the classroom.
00:20:24.080 | So the question that I want to ask is, it seems to me like this is a very experimental
00:20:28.480 | talk.
00:20:29.480 | We're talking about floating point precision.
00:20:30.480 | We're talking about different approaches and currently work well.
00:20:31.480 | And whenever we're dealing with clients, there's a question of what is the research question?
00:20:37.440 | And I feel like I missed that.
00:20:39.040 | So what are we trying to answer with all these experiments?
00:20:43.000 | Yeah.
00:20:44.000 | I think the, I think the high level of research question is like, you know, can we, you know,
00:20:48.800 | create models that are, you know, like doing adaptive computation from the standpoint of
00:20:53.920 | like, no, can we try to make models more simulate the dynamics that we think models should most
00:20:58.760 | naturally use, which is different inputs to have different amounts of computation applied,
00:21:03.040 | have different weights applied to them, you know, and basically all of this, basically
00:21:06.080 | we're trying to research and like figure out how can we create like a new framework for
00:21:09.560 | these models to be trained as opposed to their dense counterparts that, you know, for every
00:21:13.440 | input are always having the same exact computation applied.
00:21:17.040 | So that's interesting because when you say the same exact computation applied, one might
00:21:21.040 | imagine that like, to me, the immediate thing is about how long to deliberate about something.
00:21:26.960 | What I mean by that is if we want to have variable length computation, you could imagine
00:21:31.360 | that I could have a short amount of computation or it could have much longer computation, but
00:21:35.280 | there's like, you have like, why then do we instead consider the dimension of different
00:21:39.480 | computation?
00:21:40.480 | I mean, assuming of course that these experts do indeed learn different things, which I
00:21:43.480 | think you'll get to in a minute.
00:21:45.320 | Yeah.
00:21:46.320 | So why do we immediately jump to thinking about specialized experts as opposed to thinking
00:21:50.880 | about variable length computation?
00:21:52.640 | So, yeah, so this is actually, we actually go into some variable length computation stuff
00:21:57.240 | later in the talk.
00:21:58.240 | And I feel like they're both actually just important axes that should both be pushed
00:22:02.720 | I think, I guess, yeah, I guess it's kind of, you know, related to my question,
00:22:07.800 | but what I'm trying to understand is why did you decide to attack
00:22:10.760 | this one first?
00:22:11.760 | I want to understand why your team chose to go this direction first.
00:22:14.560 | Yeah, absolutely.
00:22:15.720 | So I think that one empirically, it seems that sparsity has led to better empirical
00:22:20.880 | results in the field of deep learning than adaptive computation so far.
00:22:23.960 | And I think the way that we use these things maps really well to our modern hardware, which
00:22:28.320 | is also very promising.
00:22:29.820 | And I think the way we were kind of looking at it as like sparsity is like a first step
00:22:32.840 | towards doing more interesting and general adaptive computation where, and we're, and
00:22:37.320 | you know, cause I think it's like, you know, this stuff is complicated and typically starting
00:22:41.000 | from something that works well is better than necessarily like, you know, you know, trying
00:22:45.720 | something that's not necessarily as proven out and then trying to like get it to work
00:22:49.080 | really well.
00:22:50.080 | So I think we're kind of starting from sparsity, which like, you know, Noam Shazeer and others
00:22:53.360 | got to work really well in the context of LSTMs.
00:22:55.600 | We were kind of interested in, you know, let's port some of this to transformers.
00:22:58.920 | Let's get it working really well.
00:22:59.920 | And then let's slowly start expanding towards a lot of the other natural questions that
00:23:03.080 | you mentioned.
00:23:04.080 | Whereas like, okay, whereas instead of, you know, different weights per core, let's also
00:23:07.640 | maybe have different computation per core and all of this.
00:23:10.320 | So that's, I guess how we were kind of building the natural, like, you know, buildup and progression
00:23:14.240 | of our research.
00:23:15.240 | Got it.
00:23:16.240 | Cool.
00:23:17.240 | Thank you.
00:23:18.240 | Yeah.
00:23:19.240 | What do you think Erwin?
00:23:21.280 | Anything else to add?
00:23:22.280 | Yeah.
00:23:23.280 | I mean, I guess I kind of see adaptive computation and sparsity as, you know, related, but separate
00:23:31.520 | things.
00:23:32.520 | So, you know, sparsity is more like different parameters for each example and adaptive computation
00:23:36.960 | might be more different amount of flops and we have some of that with the token dropping,
00:23:42.560 | but that's kind of, you know, that's not the main motivation.
00:23:49.200 | Definitely as Barrett mentioned, I would say, you know, no one really has figured out adaptive
00:23:56.440 | computation yet for deep learning.
00:23:59.200 | And one reason is because we have these, you know, accelerators, right.
00:24:05.640 | Expect like sort of, you know, we need to work with like batch, like data parallelism,
00:24:11.600 | right.
00:24:12.600 | So, and all of our accelerators and our frameworks use this SPMD paradigm where we're kind of
00:24:19.320 | supposed to apply the same computation to examples.
00:24:25.200 | And so if you look at the literature, you have, you know, works like universal transformers
00:24:30.880 | where they replace the feed forward in the transformer by just a recurrent weight.
00:24:36.240 | And so it's kind of like an LSTM on each token and the LSTM can stop at different times based
00:24:42.760 | on some criteria.
00:24:45.440 | But the way these things are implemented is just through masking because it needs to be
00:24:51.320 | implemented in the SPMD programming style.
00:24:55.520 | And so definitely sparsity was kind of like easier to get to work first.
00:24:59.800 | And also there were some prior results with LSTM, so yeah.
00:25:06.600 | In terms of like the first question, you know, sort of what's the research question here
00:25:09.880 | is just like, oh, can we design more efficient models?
00:25:13.400 | And sparsity is this new axis that hasn't been explored that much.
00:25:17.240 | And yeah, I think that, you know, I'm happy with just that being the research question.
00:25:23.640 | Yeah.
00:25:24.640 | Great.
00:25:25.640 | Okay.
00:25:26.640 | Yeah.
00:25:27.640 | So next slide.
00:25:29.640 | Oops.
00:25:30.640 | Yeah.
00:25:31.640 | Again, so kind of putting it all together.
00:25:35.440 | So the switch transformer layer selects an expert, like just the top expert, and then
00:25:40.160 | incorporates a bunch of the general sparse model improvements to, you know, allow it
00:25:44.360 | to fine tune better, allow it to, you know, be more regularized, allow it to, you know,
00:25:49.840 | be trained with lower precision formats and a lot of like technical details to just get
00:25:52.520 | them training and working well.
00:25:55.680 | Yeah.
00:25:58.120 | So one thing that we also wanted to do was a comparison between like top one and top
00:26:02.200 | two routing since top two routing was kind of the, you know, most popular technique.
00:26:07.760 | And so here we can see we have two different dense models trained of different sizes.
00:26:10.400 | And we're going to be looking at like the, the pre-training like negative log perplexity.
00:26:14.640 | So yeah, the bigger the number, the better.
00:26:19.300 | So next slide.
00:26:20.900 | So, so, and what we're going to be doing is we're going to be studying them at different
00:26:24.560 | capacity factors.
00:26:26.140 | So a capacity factor of 2.0 basically means that there's enough buffer for two tokens
00:26:30.000 | to be sent to every single expert.
00:26:32.560 | And we're going to be comparing like top one versus top two routing and also comparing
00:26:36.280 | their speeds along with their like time to get some like threshold quality.
00:26:40.560 | Okay.
00:26:41.560 | Yeah.
00:26:42.560 | So here we can see in the capacity factor 2.0 case that the MOE models outperform switch
00:26:48.800 | transformer, which makes a lot of sense, like since switch transformer is only, you know,
00:26:52.800 | sending like a top one token to each expert, the mixture of expert is sending, you know,
00:26:57.920 | two tokens.
00:26:58.920 | So that makes sense that this extra buffer will be like disproportionately beneficial
00:27:01.960 | for the mixture of expert models.
00:27:04.120 | And so we noticed that. And next slide -- now the really interesting part
00:27:11.020 | for top-1 routing comes when we lower the capacity factors.
00:27:14.940 | So having a high capacity factor is bad for many reasons.
00:27:17.800 | One of which is it really incurs more of these, you know, communication costs for sending
00:27:22.320 | tokens to the correct experts.
00:27:24.280 | It also incurs more compute costs and also incurs like a lot of memory overhead.
00:27:28.280 | So if you can get this lower, it's, it's usually like a very, very good thing.
00:27:32.760 | And so what we see here is that switch transformer actually outperforms mixture of experts when
00:27:37.460 | you have like a lower capacity factor.
00:27:40.400 | And we can see that for the time to quality threshold, we, you know, yeah, we get there much quicker.
00:27:46.500 | And so even across the 2.0 and the 1.25 capacity factors, like the kind of Pareto optimal thing
00:27:51.460 | we saw in our setup is to use switch transformer at a lower capacity factor, just due to the
00:27:56.600 | fact that while the quality is worse, a little bit worse on a step basis, it's just like
00:28:00.440 | much faster to run.
00:28:01.520 | So it's kind of the Pareto optimal decision.
00:28:05.220 | Next slide.
00:28:06.220 | And we can also be seeing that like for capacity factor 1.0, again, we can see that this really
00:28:11.660 | disproportionately benefits switch transformer and is even better on a Pareto standpoint
00:28:17.020 | than the 1.25 capacity factors.
00:28:20.060 | And interestingly, since, you know, MOE also does like a little bit more computation, we
00:28:24.000 | can also just increase the amount of compute done elsewhere in the model.
00:28:28.100 | And we can see that that's like a much more efficient allocation of compute.
00:28:31.740 | So yeah, overall, our takeaway is that, yeah, lower capacity factors using top-1 routing
00:28:37.240 | is more Pareto efficient than, you know, using like the top-2 routing at higher capacity
00:28:42.220 | factors.
00:28:43.220 | Next slide.
00:28:44.220 | Erwin, you can take it over.
00:28:47.980 | Okay.
00:28:48.980 | So next we'll look at how Switch Transformer scales as a function of the number of experts
00:28:55.700 | in the switch layers.
00:28:59.220 | And so on the right side here, you see a plot that shows perplexity versus training steps
00:29:06.140 | for different switch architectures, ranging from T5 base, which is basically no expert
00:29:12.860 | or a single expert, up to 128 experts.
00:29:17.940 | And so you see that as we increase the number of experts, which also increases the number
00:29:21.820 | of parameters, of sparse parameters, you get sort of speed ups, you know, you get increasing
00:29:30.840 | speed ups over the dense baseline.
00:29:32.940 | And there are like sort of diminishing returns to, you know, increasing
00:29:39.000 | the number of experts as well.
00:29:44.580 | So the previous figure was looking at perplexity versus training steps.
00:29:50.780 | Here we look at perplexity versus training time.
00:29:54.780 | So that includes, you know, all the, you know, additional communication costs when you have
00:30:00.900 | more experts or, you know, comparing to the dense baseline.
00:30:07.700 | And so this is for switch base or T5 base, and we observe up to 7x speedups over T5 base.
00:30:17.580 | And so, you know, just to maybe contextualize these numbers, like, you know, 7x speedups
00:30:24.740 | in deep learning are pretty hard to obtain.
00:30:26.620 | And so I think this is one of the, you know, one of the results that, you know, can spark
00:30:34.300 | a lot of interest in sparse models, even if it's only for pre-training for now, like just
00:30:39.260 | having that number is like, you know, maybe there's a significant, there's something significant
00:30:47.660 | that can be obtained here.
00:30:50.660 | Okay, so sparse scaling laws.
00:30:55.660 | So here we'll look at sort of loss versus sparse model parameters, which are increased
00:31:03.580 | by increasing the number of experts.
00:31:06.980 | And so similarly to the sort of, you know, normal scaling law paper, we observed that
00:31:13.380 | as you increase the parameters, that is the sparse parameters, and keep the flops fixed,
00:31:23.020 | you get consistent gains, but diminishing gains.
00:31:27.900 | Okay, so now we're going to compare expert parallelism and model parallelism.
00:31:34.700 | So we introduced sparsity or expert parallelism as a new dimension to scale models.
00:31:42.340 | But of course, there's the other one for dense models, which is simply model parallelism where,
00:31:48.660 | you know, model weights are partitioned across cores once they are above the maximum size
00:31:54.900 | that you can fit on a single core.
00:31:57.900 | All right, so, yeah, Barrett, I assume to the left is expert parallelism here?
00:32:05.980 | Yeah, so essentially what we're doing is, yeah, we're kind of comparing a switch-based
00:32:11.020 | model versus the dense base.
00:32:13.660 | And we're also comparing against a larger dense model that has used model parallelism.
00:32:18.820 | And we can see that, you know, because basically when we want to scale a model size, we kind
00:32:22.380 | of have two axes that we can either go through.
00:32:24.340 | We can either increase the number of flops by scaling through model parallelism or increase
00:32:28.820 | the number of parameters by scaling through sparsity.
00:32:31.620 | And so we can see that, you know, even compared to like, you know, a dense model that's been
00:32:35.500 | scaled up through model parallelism, that sparsity is still at the scale, a more effective
00:32:39.020 | way to scale up the model by, you know, still getting 2.5x speedups over this larger dense
00:32:45.180 | model that was using model parallelism.
00:32:47.460 | Cool, so next slide.
00:32:50.060 | Yeah, basically here, T5 large is the dense model that uses model parallelism.
00:32:55.740 | Yeah, right, go ahead.
00:32:58.860 | Okay.
00:32:59.860 | Yeah, and so one thing that we also wanted to look at is like, you know, are these expert
00:33:05.740 | models effective if you have like, you know, really small amount of computer, just a small
00:33:09.220 | amount of experts?
00:33:10.220 | So typically when we're designing these models, like we have one expert per core.
00:33:14.580 | But if you don't have like a large cluster to run these things on, let's say you just
00:33:17.140 | have like a GPU with two cores or something, is having two experts more effective
00:33:22.140 | than just like a dense model.
00:33:24.060 | And the answer is yes.
00:33:25.220 | So we can see even pretty good scaling properties, even with like a tiny amount of experts, which
00:33:29.420 | is very, very promising for these models to be used even in like much lower compute regimes.
00:33:35.420 | Next slide.
00:33:36.420 | Or when you want to go ahead.
00:33:40.060 | Okay.
00:33:41.060 | Yeah.
00:33:42.060 | And so look at, you know, what things look like when we use different types of parallelism,
00:33:51.140 | namely expert parallelism to add experts, model parallelism to shard model weights across
00:33:57.060 | cores and also data parallelism, which is sort of the dominant paradigm in deep learning
00:34:02.700 | at the moment.
00:34:06.300 | And so, you know, I guess, you know, in the previous slides, we mostly talked about expert
00:34:12.100 | parallelism, but of course, you know, dense models and large scale dense models use model
00:34:17.220 | parallelism.
00:34:18.220 | So GPT-3 and these other large models, what they do is that they will simply shard model
00:34:23.220 | weights across different cores.
00:34:25.300 | Yeah.
00:34:26.300 | We have a question.
00:34:29.500 | Oh yeah.
00:34:31.740 | I just wanted to know, because I think there was like, I don't know if you're going to
00:34:35.420 | address later, but I think somewhere in a paper, it said that the more experts you have,
00:34:40.580 | the more sample efficient it gets.
00:34:43.380 | And I was just like hoping, hoping that you could give us some intuition about that, because
00:34:47.700 | I don't understand why that would be the case.
00:34:52.260 | So I guess, yeah, maybe, so I guess like, you know, there's all of this work on larger
00:34:59.700 | models are more sample efficient and larger in the context of the scaling law works means
00:35:04.760 | like more parameters and more flops.
00:35:06.980 | As you increase the number of experts, there's more parameters, but not more flops.
00:35:09.900 | But the model is still like, you know, larger and like, you know, a similar sense.
00:35:14.260 | So I guess like building on the intuition that larger models are more sample efficient
00:35:18.100 | in my mind, it's not necessarily that surprising that these models with more experts that have
00:35:24.020 | more parameters are more sample efficient.
00:35:27.500 | I guess that's my like kind of high level intuition for it.
00:35:30.100 | Yeah, I would say that's kind of expected that, you know, more experts leads to better
00:35:37.480 | sample efficiency, especially if you look at training step, right, in a training time.
00:35:46.760 | Okay, cool.
00:35:52.040 | So where were we?
00:35:53.700 | Yeah, so, yeah, so, okay, so we'll look at how model weights are split over cores for
00:35:59.360 | different scenarios.
00:36:01.560 | So data parallelism is the first one.
00:36:04.580 | So that's kind of the typical setup that deep learning uses, especially for not so large
00:36:11.840 | networks which don't require model parallelism.
00:36:16.040 | And so let me, yeah, let me explain how, yeah, I'll just go to the final figure and I'll
00:36:22.920 | explain how to look at this figure.
00:36:27.360 | Okay, so we have 16 processes which are organized in the four by four mesh, right?
00:36:33.980 | So each dotted line, each four by four dotted line here represents a different core.
00:36:41.160 | And the first row studies how the model weights are split over cores.
00:36:45.720 | And the second row illustrates how data, so literally examples and tokens, are split over
00:36:52.080 | cores.
00:36:55.360 | And yeah, and then the final thing that's required to understand this figure is that
00:37:00.000 | each, yeah, each color of the shaded squares here identifies the unique weight matrix.
00:37:09.200 | Okay, so let's start with data parallelism.
00:37:13.140 | So for data parallelism, the same model weights are replicated across all cores.
00:37:20.800 | And the data is simply partitioned over cores.
00:37:23.920 | And so that's what this corresponds to, using the description of the caption, the explanation
00:37:34.800 | of the caption I just gave.
00:37:36.120 | So next we have model parallelism.
00:37:39.520 | That's kind of just like a theoretical example because in practice, people always use model
00:37:44.440 | parallelism in conjunction with data parallelism.
00:37:48.000 | But so if you were to do only model parallelism, now you would have a single model weight that
00:37:52.440 | is partitioned over all cores, and your data would just be replicated over all cores instead.
00:38:01.880 | So now we have model and data parallelism, and that's kind of the typical scenario for
00:38:05.920 | large dense networks.
00:38:08.040 | So in that case, model weights are partitioned among a subset of the cores, the subset of
00:38:14.240 | cores that process different batches of data.
00:38:16.920 | And so in that example here, we have sort of four, so the first sub-square here means
00:38:22.760 | that the model weights are partitioned across four cores.
00:38:31.080 | And this is replicated sort of four times for the data parallelism dimension.
00:38:38.320 | On the data side, for model and data parallelism, yeah, the data here is replicated across model
00:38:46.720 | parallel cores and partitioned across data parallel cores.
00:38:51.240 | So next we have expert and data parallelism.
00:38:55.320 | So in that scenario, that's kind of similar to data parallelism, but now each core will
00:38:59.520 | hold a different model weight, which is illustrated by the different colors.
00:39:06.260 | And for the data side, the data is simply replicated, sorry, the data is partitioned
00:39:12.400 | across all cores, just like in the data parallelism scenario.
00:39:18.160 | And so finally, we have the rightmost column, which is, I guess, yeah, that's the setup
00:39:25.800 | used in the switch transformer paper for the larger models.
00:39:30.920 | And so here for the model partitioning, each expert is partitioned across multiple cores.
00:39:37.020 | So in that example, we have four experts, each partitioned across four cores, and the
00:39:42.820 | data is replicated across model parallel cores and partitioned across data parallel cores.
00:39:48.500 | So that's a little bit complex to understand, really, but the switch transformer paper has
00:39:54.780 | a nice, the same figure with a nice caption to explain it.
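(To make the figure concrete, here is an illustrative per-core summary for a hypothetical 4x4 mesh of 16 cores, a batch of B tokens, and a feed-forward weight W of shape [d_model, d_ff]; the shapes only mirror the figure's layout and are not a specific configuration from the paper.)

```python
# What a single core holds in each regime (4x4 mesh, 16 cores) -- illustrative only.
per_core = {
    "data parallelism": {
        "weights": "full W [d_model, d_ff], identical copy on every core",
        "tokens":  "B / 16 of the tokens",
    },
    "model parallelism": {
        "weights": "one shard of W, e.g. [d_model, d_ff / 16]",
        "tokens":  "all B tokens, replicated on every core",
    },
    "model + data parallelism": {
        "weights": "[d_model, d_ff / 4] shard, replicated across the 4 data-parallel groups",
        "tokens":  "B / 4 of the tokens, replicated across the 4 model-parallel cores",
    },
    "expert + data parallelism": {
        "weights": "one whole expert's W (a different weight matrix on each core)",
        "tokens":  "B / 16 of the tokens, then routed between cores",
    },
    "expert + model + data parallelism": {
        "weights": "[d_model, d_ff / 4] shard of one of 4 experts",
        "tokens":  "B / 4 of the tokens, replicated across model-parallel cores",
    },
}
```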
00:40:00.360 | Yeah, maybe we can, about it, we can add something quickly about how this is implemented in practice.
00:40:09.820 | So there's this paper called Mesh TensorFlow, which kind of extends batch or data parallelism
00:40:18.780 | to more general purpose SPMD style programming.
00:40:23.900 | And so different labs have different frameworks, but this paper kind of lays the foundation
00:40:28.420 | for general SPMD distributed computing, which is required for training large scale models.
00:40:37.080 | And so under the mesh abstraction, basically we have a mesh of processes, and so that mesh
00:40:45.460 | has dimensions, name dimensions, and these name dimensions specify how the tensor dimensions
00:40:53.740 | will be partitioned or replicated across the mesh dimensions.
00:40:58.400 | And so just that simple abstraction sort of supports data parallelism, also model parallelism,
00:41:04.700 | and especially expert parallelism at once.
00:41:07.980 | And so I invite whoever is interested to also check that paper, because that kind of lays
00:41:16.060 | the foundation for understanding these things.
00:41:18.300 | All right, Barrett, do you want to go?
00:41:22.180 | Cool.
00:41:23.180 | So next we are going to kind of talk about like how we take these parallelism strategies
00:41:26.300 | and like kind of combine them together to make like a 1.6 trillion parameter sparse
00:41:31.140 | model.
00:41:32.140 | So next slide.
00:41:35.220 | So what we ended up doing in this work was we trained two different very large sparse
00:41:41.820 | models, and we compared them to the largest T5 model.
00:41:44.540 | So we can see the T5 XXL, which is a dense model, and it was the largest one trained
00:41:48.820 | in the T5 paper, and it has around 13 billion parameters.
00:41:52.740 | And here we list a lot of the model dimensions like D model, DFF, which are just like the
00:41:56.140 | various sizes and shapes of the tensors and stuff, the number of layers, the number of
00:42:00.900 | heads.
00:42:01.900 | And importantly, we also mentioned the negative log perplexity at step 250k and at 500k.
00:42:08.980 | And so yeah, so we designed two sparse models to test like how scaling
00:42:16.740 | sparsity versus scaling flops works.
00:42:19.600 | So first, let me talk about switch XXL.
00:42:21.660 | So that has the same amount of flops per token as T5 XXL, but has 64 experts.
00:42:27.200 | And this leads it to have around 400 billion parameters.
00:42:31.020 | And we can see that on a step basis, it actually performs quite well and outperforms the T5
00:42:35.620 | XXL by like quite a good margin.
00:42:37.700 | Interestingly, though, the third model we designed is Switch-C, which has 1.6 trillion
00:42:42.340 | parameters, but has significantly fewer flops, almost 10x fewer flops per token than
00:42:47.700 | either of the above two models.
00:42:48.900 | So it's really trading by reducing flops to have way more sparse parameters.
00:42:54.660 | And we can see on a step basis, the Switch-C model works well, but not as well as
00:42:59.780 | the higher flop model. But on a kind of Pareto axis where we are
00:43:04.860 | looking at TPU hours on the X axis and not steps, the Switch-C model actually outperforms
00:43:09.500 | them both by like a pretty large margin.
00:43:12.060 | So for pre-training performance, we're seeing that actually just like having a lot of sparsity
00:43:15.900 | and less flops is actually, um, can be quite good.
00:43:20.060 | Next slide.
00:43:21.060 | Yeah.
00:43:22.060 | And so, yeah, this, so again, those two sparse models are kind of really trying to get at
00:43:25.700 | this hypothesis that actually Noam Shazeer had, which is, you know, that, you know, parameters
00:43:30.060 | are good for more knowledge, reasoning and compute AKA flops is good for intelligence.
00:43:37.020 | And so we're going to kind of try to get at that by taking these different sparse models
00:43:39.900 | and then fine tuning them on, uh, different tasks, some of which require more like knowledge
00:43:44.140 | and then others, which require more of like reasoning, um, for whatever, like hand wavy
00:43:48.100 | definition we want to give that.
00:43:51.060 | So yeah.
00:43:52.060 | So for a fixed, Oh, go back.
00:43:53.380 | So yeah.
00:43:54.380 | So for a fixed, Oh, can you go back to the previous slide?
00:43:56.980 | Oh yeah.
00:43:57.980 | Sorry.
00:43:58.980 | Okay.
00:43:59.980 | So for a fixed quality on an upstream pre-training task, um, yeah.
00:44:03.060 | Do parameters independently matter?
00:44:05.360 | So we're going to look at two tasks here.
00:44:06.980 | One of which is super glue, which is kind of our like reasoning task.
00:44:09.860 | And then another is like trivia QA, which is like some knowledge task where it's like,
00:44:12.860 | you just give it a question, you have it output an answer.
00:44:16.860 | Okay.
00:44:19.300 | And so here we're going to take a look at super glue quality.
00:44:21.860 | So we can see on the X axis is the pre-training performance and the Y axis is the super glue
00:44:26.620 | score after fine tuning.
00:44:28.980 | And interestingly, we can see definitely that the sparse models definitely are for a fixed,
00:44:34.260 | um, pre-training perplexity do worse on fine tuning.
00:44:37.660 | This can be especially noticed at like the upper right portion of the plot where the
00:44:41.340 | dense models are definitely fine tuning better than the, their sparse counterpart.
00:44:47.140 | Next slide.
00:44:48.140 | Interestingly, when we study it on the more knowledge, heavy tasks, the sparse model for
00:44:52.940 | a fixed, uh, pre-training perplexity does disproportionately well.
00:44:57.100 | So, you know, for a model that roughly has the same perplexity, we're getting like really
00:45:00.260 | large boosts for these knowledge, heavy tasks.
00:45:03.300 | So this is pretty interesting.
00:45:04.300 | And it also really, you know, show some of the dangers of comparing only on your pre-training
00:45:09.220 | metrics.
00:45:10.220 | So dense models, you know, can have the same exact pre-training metric, but very different,
00:45:13.580 | um, you know, properties when fine tuning them on different tasks.
00:45:18.620 | Next slide.
00:45:21.060 | And interestingly, so yeah, all of the switch models here are the, um, are, are just like,
00:45:26.220 | you know, various models that have still a good amount of flops, but the red model is
00:45:30.780 | actually the 1.6 trillion parameter, uh, sparse model that has, you know, very few flops,
00:45:37.500 | but a lot, a lot of parameters.
00:45:38.500 | And we can see that as the red dot here, and it does actually disproportionately bad compared
00:45:42.580 | to other sparse models that also have pretty good perplexities.
00:45:46.180 | And so, yeah, it's, uh, it's definitely very interesting and it shows that, you know, for
00:45:49.540 | models during pre-training that have a lot of sparsity, they definitely suffer on some
00:45:53.340 | of these more reasoning heavy metrics, but do disproportionately well for more of these
00:45:57.960 | knowledge, heavy tasks.
00:46:00.060 | Next slide.
00:46:01.060 | Yeah.
00:46:02.060 | And so here we can see it as just like a huge outlier for a pre-training perplexity doing
00:46:07.380 | like just incredibly well on this, uh, downstream question answering task.
00:46:13.420 | Next slide.
00:46:14.420 | Yeah.
00:46:15.420 | Okay.
00:46:16.420 | So also, you know, one thing that we were going to do is just look at the fine tuning
00:46:19.760 | properties of sparse models across like a few scales and just see how they perform.
00:46:24.900 | Next slide.
00:46:25.900 | Yeah.
00:46:26.900 | And so here we try two different models.
00:46:29.220 | One is, um, T5 base, and then we make a flop match sparse counterpart.
00:46:33.460 | And when they say flop match, it's like, you know, each token will have the same amount
00:46:36.420 | of flops, but now we just have experts.
00:46:38.600 | So we do this for both base and large, and we see that actually across almost all tasks
00:46:42.140 | besides two arc tasks, the sparse models perform quite well, which is, which is definitely
00:46:47.460 | promising.
00:46:48.460 | So we are seeing that these models are pretty robust, they pre-train well, and then they
00:46:51.380 | also fine tune well when scaled appropriately by scaling up both the flops and sparsity.
00:46:57.020 | Whereas, you know, the negative results we've really seen are like, yeah, when you just
00:47:00.460 | have a huge amount of sparsity and not too many flops.
00:47:04.180 | Next slide.
00:47:05.180 | Yeah.
00:47:06.180 | And one also thing we wanted to look at was, uh, the multilingual training.
00:47:10.620 | So we were previously studying all of this on like English only, and we also wanted to
00:47:14.300 | see how sparsity helps in the multilingual setting because, you know, we also felt like
00:47:18.060 | this would be a very natural place for sparsity to work well, or potentially experts could
00:47:21.540 | specialize across languages.
00:47:23.060 | Um, and we do see strong results.
00:47:25.500 | So on 91% of the languages, I think of like around a hundred languages, we see over like
00:47:30.500 | at least a 4x speedup over the MT5, um, dense model.
00:47:36.180 | Next slide.
00:47:37.180 | Erwin, you want to go ahead?
00:47:40.100 | Uh, no, go ahead.
00:47:42.860 | Okay.
00:47:43.860 | Yeah.
00:47:44.860 | So another thing we wanted to talk about was, um, distillation.
00:47:47.300 | So one downside of these sparse models is that they'll have a lot more parameters, which
00:47:51.940 | means that, you know, if you're serving these things or something, you either need like
00:47:55.140 | high throughput use cases, or you need to maybe distill it back down into like a smaller
00:47:59.300 | dense model.
00:48:00.700 | So here, what we do is we look at like the T5 base and switch base, and we look at its
00:48:03.980 | pre-training performance.
00:48:05.620 | And then we go through, um, some ablations of different distillation techniques and find
00:48:09.060 | that like with the best techniques, we can keep around 30% of the quality improvements
00:48:14.180 | of sparsity while distilling it back down into its, uh, dense, um, counterpart.
00:48:20.460 | So next slide.
00:48:22.420 | Yeah.
00:48:24.420 | And then we kind of study this across multiple scales.
00:48:26.380 | And again, we see like around like 30 to 40% of the gains can be, um, like, you know, kept
00:48:32.340 | when going from a dense model and going from, you know, a sparse model and distilling it
00:48:35.980 | back down until it gets flop match dense model.
00:48:38.500 | So you can get, you know, get rid of up to 99% of the parameters and still keep like
00:48:42.620 | around 30% of the improvements, which is very promising.
00:48:46.060 | Next slide.
00:48:47.060 | Wait, I'm sorry.
00:48:48.060 | Yeah.
00:48:49.060 | All right.
00:48:50.060 | Sorry about that.
00:48:51.060 | Can you say that last sentence again?
00:48:52.940 | You said that you can keep 30% of the teacher's benefit?
00:48:56.900 | Yeah, basically. Exactly.
00:49:01.820 | So we're looking at: you train a sparse model and then you distill it back
00:49:06.260 | down to a dense model, versus training a dense model from scratch.
00:49:10.700 | And you look at the gap between the sparse model and the dense model trained from scratch,
00:49:15.380 | versus the gap between that dense model and the distilled dense model.
00:49:18.660 | Yeah.
00:49:20.660 | Oh yeah.
00:49:22.660 | Maybe let me just do a quick high-level summary again.
00:49:29.420 | So what we'll do for our comparisons is: we'll train a dense model from scratch,
00:49:34.580 | we'll train a sparse model from scratch, and then we'll also run a third experiment where
00:49:38.740 | we distill that sparse model down into a dense model.
00:49:43.940 | What does distilling mean?
00:49:45.740 | We're basically trying to match the teacher's logits, the kind of standard
00:49:50.660 | thing of matching either the logits or the soft probabilities
00:49:54.940 | for each token.
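As a rough illustration of what matching the teacher's logits or soft probabilities looks like in practice, here is a minimal soft-target distillation loss in PyTorch. This is a generic sketch, not necessarily the exact recipe used in these experiments; the temperature, tensor shapes, and random inputs are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Push the student's per-token distribution toward the (frozen) teacher's.
    Generic soft-target distillation; any mixing with hard labels used in the
    actual experiments may differ."""
    t = temperature
    teacher_probs = F.softmax(teacher_logits / t, dim=-1)          # teacher soft targets
    student_log_probs = F.log_softmax(student_logits / t, dim=-1)
    # KL(teacher || student), averaged over tokens; t^2 keeps the gradient scale stable.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (t * t)

# Hypothetical shapes: 8 sequences x 128 tokens, 32k-token vocabulary.
student_logits = torch.randn(8 * 128, 32000)   # from the dense student
teacher_logits = torch.randn(8 * 128, 32000)   # from the sparse teacher (no grad needed)
loss = distillation_loss(student_logits, teacher_logits)
```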
00:49:56.980 | Okay.
00:49:57.980 | If I can jump in with my question.
00:50:05.420 | So what I'm struggling with is how do I interpret the numbers listed as percent of teacher performance?
00:50:11.140 | Yeah.
00:50:12.140 | Okay.
00:50:13.140 | So it's basically looking at the gap between the dense and sparse model.
00:50:17.820 | So we'll have the dense model get some performance, we'll have the sparse model get some performance,
00:50:21.500 | and then the dense model distilled from the sparse model will be somewhere in
00:50:25.920 | between, within that range.
00:50:28.020 | And we're basically saying it's 30% through that range.
00:50:31.280 | So on a zero-to-one interval, it's 0.3 of the way from the dense to
00:50:35.340 | the sparse model.
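In other words, the reported number is just where the distilled model lands on the dense-to-sparse quality gap. A tiny helper with made-up quality numbers, assuming higher scores are better:

```python
def percent_of_teacher(dense_quality, sparse_quality, distilled_quality):
    """0.0 = no better than the from-scratch dense model; 1.0 = fully matches the
    sparse teacher. Assumes larger quality numbers are better."""
    return (distilled_quality - dense_quality) / (sparse_quality - dense_quality)

# Illustrative numbers only: the distilled model recovers ~30% of the gap.
print(percent_of_teacher(dense_quality=80.0, sparse_quality=84.0, distilled_quality=81.2))  # ~0.3
```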
00:50:36.340 | I see.
00:50:37.340 | So this is not saying that, if
00:50:40.540 | we use the teacher's guesses or predictions as the ground truth,
00:50:45.420 | the distilled model matches the teacher 33% of the
00:50:50.540 | time.
00:50:51.540 | No, no, exactly.
00:50:52.540 | It's basically saying you get like 30% of the quality improvements.
00:50:54.740 | Yeah, exactly.
00:50:55.740 | Okay, cool.
00:50:56.740 | And then if we can back up a slide, I had a different question, but I didn't want to
00:51:00.660 | interrupt.
00:51:01.660 | When we were talking about all of these different T5 bases, and then also on a few slides before
00:51:04.900 | this, I don't know that much about T5.
00:51:06.980 | I'm curious, you know, when T5 is trained, is there a weight penalty in the loss function?
00:51:13.860 | Is there a weight decay term?
00:51:15.580 | No, there was no weight decay used in training any of those sparse or dense models.
00:51:19.780 | I see.
00:51:20.780 | So out of curiosity then, how do dense models perform compared to the switch model?
00:51:25.500 | If you add some sort of weight regularization that incentivizes getting rid of useless weights?
00:51:31.180 | Oh, so some kind of L1 term or something like that?
00:51:35.620 | Yeah.
00:51:36.620 | So here we're talking about the benefits of sparsity,
00:51:39.180 | and I'm wondering how much of this benefit from sparsity is due to the fact that,
00:51:43.740 | effectively, what the switch model is doing, if I understand correctly,
00:51:46.900 | and maybe I don't, is that in the feed-forward layer it's
00:51:50.500 | just like you're fixing some of the weights to be zero.
00:51:53.460 | That's what it means to be sparse.
00:51:55.860 | Well, actually, we're really trying to inject more weights.
00:51:59.220 | So it's maybe a little bit paradoxical, because
00:52:02.140 | we're saying switch transformer is sparse, but our idea is: hey, we actually want to
00:52:05.700 | have significantly more weights, not fewer weights.
00:52:08.580 | It's as if you zeroed out weights, but within a much larger weight matrix, if
00:52:13.660 | that makes sense.
00:52:14.660 | I see.
00:52:16.660 | And so to me, it seems like a relevant baseline to just ask what happens if I have the dense
00:52:18.380 | matrix, but I incentivize it with, say, an L1 or L2 penalty on the weights.
00:52:21.860 | And I'd be curious to know how that compares.
00:52:24.740 | Yeah, we didn't run this, but also that kind of gets rid of weights for the dense model.
00:52:28.580 | So if any-
00:52:29.580 | Sure, sure.
00:52:30.580 | Yeah.
00:52:35.580 | Also, to me, it's like, if you just add an L1 penalty loss, you're not going to have
00:52:39.460 | structured sparsity, whereas here it's not random weights in your
00:52:46.580 | giant weight matrix that are zeroed out, right?
00:52:48.660 | It's really blocks, blocks corresponding to each expert.
00:52:53.020 | Right.
00:52:54.020 | And so-
00:52:55.020 | So that structure allows the whole communication setup, and that-
00:53:01.820 | That leverages the fact that you have multiple cores and so on, right?
00:53:05.420 | I totally agree about the block structure, and that's what I'm trying to say:
00:53:08.860 | the switch model is not just sparse, it also has this rich structure.
00:53:12.780 | And what I'm trying to do in my mind is disentangle whether it's the sparsity that's offering the advantage,
00:53:17.420 | or whether this additional structure that you built in is what gives the performance gain.
00:53:21.860 | So that's why I'm asking.
00:53:23.740 | So the block structure is what enables you to leverage the fact that you have multiple cores.
00:53:32.340 | Like if you didn't have that block structure, you'd still have to route to everything.
00:53:36.980 | And so you'd have more communication costs and so on.
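For reference, the unstructured baseline raised in the question would look something like the sketch below: an L1 penalty added to a dense model's training loss. This baseline was not run in the work being discussed, and the coefficient and usage line are hypothetical; the comments contrast it with the structured, per-expert sparsity just described.

```python
import torch

def l1_penalty(model: torch.nn.Module, coeff: float = 1e-5) -> torch.Tensor:
    """Hypothetical unstructured-sparsity baseline (not run in this work):
    penalizing every weight's magnitude tends to zero out scattered individual
    entries rather than whole expert-sized blocks."""
    return coeff * sum(p.abs().sum() for p in model.parameters())

# total_loss = task_loss + l1_penalty(dense_model)
#
# In a Switch-style layer, by contrast, the "zeros" are implicit and structured:
# for a given token, every expert except the routed one contributes nothing, so
# each expert's weight block can sit on its own core and be skipped entirely,
# which is what keeps the communication costs manageable.
```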
00:53:40.740 | And then your first question was what, sorry?
00:53:43.020 | I'm not actually sure if there was a question; I guess what I'm trying to say is I'm trying to-
00:53:46.780 | Yeah.
00:53:47.780 | Yeah, anyways.
00:53:48.780 | But I agree, it's a little bit weird because there's a spectrum of meanings
00:53:53.980 | for sparsity, right?
00:53:54.980 | So, for example, compression and model pruning are a form of sparsity,
00:54:00.220 | but switch transformer and MoE are also referred to as sparsity, and those are related,
00:54:07.340 | but they're definitely aiming at different things.
00:54:09.940 | This is a really interesting idea, that it's sparse but you have more parameters.
00:54:13.260 | I'll have to think about it more.
00:54:14.900 | Thank you.
00:54:15.900 | Yeah.
00:54:16.900 | So you have sparsity within this giant weight matrix, which is-
00:54:20.620 | Exactly.
00:54:21.620 | Yeah.
00:54:22.620 | Yeah, yeah, yeah.
00:54:23.620 | I hadn't appreciated that.
00:54:24.620 | So I appreciate you pointing that out.
00:54:27.220 | Thank you.
00:54:28.220 | I have a follow-up question on the distillation part.
00:54:31.980 | Yeah, of course.
00:54:33.340 | Okay.
00:54:34.340 | So if you distill it back down, now technically you're back to the dense
00:54:38.700 | layer architecture, right?
00:54:42.620 | So now the entire idea of experts is that certain tokens would be sent to different experts
00:54:48.140 | because those experts are more specialized at figuring something out
00:54:51.940 | about those tokens.
00:54:52.940 | So now if you go back to this dense layer, aren't you basically only serving whichever
00:55:04.140 | expert you based this dense layer on? Those tokens will probably perform well and
00:55:09.020 | all the other tokens are kind of left behind, right?
00:55:13.140 | Sorry, I don't think I'm fully understanding your question.
00:55:19.500 | So are you getting at the fact that we're distilling this on a specific data set?
00:55:24.060 | So that-
00:55:25.060 | No, I'm thinking of how to use that, like why-
00:55:27.620 | Yeah.
00:55:28.620 | Yeah.
00:55:29.620 | Yeah.
00:55:30.620 | So maybe concretely, let's take SuperGLUE, right?
00:55:31.620 | Let's say you want to serve a model that does SuperGLUE well.
00:55:34.380 | I think the idea is that you distill the sparse model into a dense model on
00:55:38.040 | SuperGLUE.
00:55:39.040 | So then you kind of get this compressed dense model that now performs better than if you
00:55:42.740 | were to just train it from scratch or start from a pre-trained dense model.
00:55:46.900 | So then it's like-
00:55:47.900 | Which model did you use though?
00:55:48.900 | Say that again?
00:55:50.900 | You have to pick one expert, right?
00:55:52.740 | No, no, no.
00:55:53.740 | You can just distill all of it, because you're just matching the model outputs.
00:55:57.380 | So you can just treat the sparse model as kind of a black box.
00:56:00.180 | All we're doing is just trying to have the dense model match the final
00:56:03.540 | token predictions.
00:56:04.540 | Oh, got it. Okay.
00:56:11.540 | Sorry, I was not familiar with the idea of distillation,
00:56:13.540 | so I think that was my confusion.
00:56:15.540 | Okay. Thanks.
00:56:15.540 | Yeah, of course.
00:56:16.540 | Yeah.
00:56:17.540 | I guess one motivation here is that having experts can make serving a little bit
00:56:24.660 | more difficult because it requires bigger topologies.
00:56:29.300 | Let's say you have eight experts; well, I guess you can have multiple
00:56:35.140 | experts on fewer cores, but let's just say they're a little bit harder
00:56:41.500 | to serve.
00:56:42.500 | And so if we can get the benefits from sparsity at pre-training and then use
00:56:48.700 | distillation to a dense model for serving, that can be beneficial.
00:56:54.740 | So I think that was sort of the motivation for that experiment, right, Barrett?
00:57:01.340 | Yeah, exactly.
00:57:03.100 | Yeah.
00:57:04.100 | Okay.
00:57:05.100 | Well, are we, yeah.
00:57:06.100 | Yeah.
00:57:07.100 | So kind of just wrapping up.
00:57:08.100 | Yeah, go ahead.
00:57:09.100 | No, go ahead.
00:57:10.100 | I just said, I think there's one more question.
00:57:11.100 | So yeah.
00:57:12.100 | Oh yeah.
00:57:13.100 | Go ahead.
00:57:14.100 | Feel free to ask it now.
00:57:15.100 | Oh yeah.
00:57:16.100 | Yeah.
00:57:17.100 | Sounds good.
00:57:18.100 | Yeah, thanks guys for the talk so far.
00:57:20.100 | Just a quick question.
00:57:24.100 | I was wondering if you think there are any interesting directions around building models that
00:57:29.220 | are explicitly optimized for parallel training.
00:57:33.660 | I guess the MoE model seems like it does a really good job
00:57:38.420 | here.
00:57:39.420 | And also at inference time, it's very useful to have fewer
00:57:43.820 | flops per forward pass.
00:57:49.720 | But do you think that there are any interesting directions around distributed
00:57:54.180 | training where you might have models that are explicitly architected to have
00:58:00.060 | a lot of parallel heads, or other features that are kind of
00:58:06.500 | embarrassingly parallelizable? Or does the standard approach of just scaling up the
00:58:12.420 | models by adding more layers, and then getting away with using model
00:58:17.300 | and data parallelism, work well enough?
00:58:19.940 | Yeah.
00:58:20.940 | So let me just make sure I'm fully understanding.
00:58:23.440 | I think right now even our models are definitely very
00:58:26.780 | co-designed with the hardware, the shapes, and things like that.
00:58:29.900 | So at a high level, yes, I think there's a ton of interesting
00:58:33.760 | research on co-designing the hardware, the partitioning algorithms, and the models.
00:58:39.020 | I think given that we have this kind of SPMD mesh-style partitioning,
00:58:43.540 | we are already kind of designing our models in ways that fit it really well.
00:58:47.020 | So for example, when we want to scale up our model, one of the first dimensions we go to
00:58:50.380 | scale up is the internal hidden dimension.
00:58:52.980 | Because there are some really nice properties of scaling up this dimension:
00:58:55.140 | it basically becomes kind of independent of some of the communication costs,
00:58:59.040 | and it's really good when looking at the compute-to-memory ratio of operations on these
00:59:02.820 | compute devices.
00:59:04.940 | So yeah, exactly.
00:59:06.620 | I think when we're even designing these models, we're really setting dimensions
00:59:09.620 | such that they map well onto hardware.
00:59:11.620 | So it's almost like, given that we have this model and data parallelism, we're
00:59:15.340 | actually designing models more for it.
00:59:18.300 | But I also think that there's a ton of new, interesting distributed algorithms and work
00:59:22.100 | like that, which makes designing models very interesting.
00:59:24.020 | I think one thing that's really cool is the Microsoft ZeRO partitioning,
00:59:27.740 | too, which also has some really nice new implications for how to design and
00:59:31.780 | scale models.
00:59:32.900 | So yeah, I think this is a very fruitful research direction.
00:59:36.380 | If that kind of answered your question? Yeah, no, that was super helpful.
00:59:41.340 | Interesting.
00:59:42.340 | Yeah.
00:59:43.340 | Yeah, definitely.
00:59:44.340 | Like I'm very optimistic on the future of us designing the hardware, the model,
00:59:48.020 | and the partitioning strategies all together, because really to get it to work well, you kind of
00:59:50.820 | have to know about all three and kind of intertwine the development
00:59:54.380 | of them.
00:59:55.380 | Yeah.
00:59:56.380 | Yeah.
00:59:57.380 | That sounds awesome.
00:59:58.380 | Cool.
00:59:59.380 | Yeah.
01:00:00.380 | So just to summarize: switch transformer is a nice simplification
01:00:04.460 | over mixture of experts.
01:00:05.740 | And we're seeing that we get really strong speed-up improvements on pre-training over
01:00:09.900 | a lot of the T5 models, which are very strong baselines.
01:00:13.140 | We're seeing that we can efficiently distill the sparse models back to dense ones,
01:00:17.420 | and that we get improved pre-training and fine-tuning through some of these newer
01:00:21.580 | techniques we talked about.
01:00:23.500 | And we're also seeing that the models work on multilingual data, and that we can
01:00:27.260 | now successfully train up to 1.6 trillion parameter models,
01:00:31.700 | which is pretty promising. Next slide.
01:00:35.500 | And so we also wanted to go through two slides about some newer work on actually
01:00:38.540 | using these kinds of models for computer vision, and also a little bit about how they
01:00:42.340 | can be used to do some level of adaptive computation, where not only does each
01:00:47.060 | input get different weights, but sometimes different inputs also have different amounts
01:00:51.240 | of compute applied to them.
01:00:53.420 | And so there was some really great work doing this out of the Google Zurich team.
01:00:58.340 | They're doing it for image classification, and they're basically
01:01:01.420 | seeing a lot of the same types of scaling properties, where scaling up the
01:01:04.780 | number of experts and using sparsity allows them to get good performance on image classification.
01:01:11.420 | Next slide.
01:01:14.780 | And interestingly, one of the things they do relates to the capacity
01:01:17.740 | factor we talked about. So we were talking about values like 1.0, 1.25, 2.0, which means that at a
01:01:22.100 | value of 2.0, there's buffer for two tokens per expert, but they actually study
01:01:27.260 | going less than one.
01:01:28.420 | So that means that at 0.5, there's only room for half the number
01:01:31.660 | of tokens.
01:01:33.140 | And the nice part is that they did this for image classification,
01:01:35.900 | and in images there's just a lot of redundancy, and they noticed that you can actually
01:01:39.700 | get really good performance by only allowing up to one tenth of the parts
01:01:45.320 | of the image to be processed by a sparse layer.
01:01:47.340 | So yeah, we think this is a really nice direction too, in terms of combining sparsity
01:01:51.540 | with adaptive computation.
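To make those capacity-factor numbers concrete, here is the usual way a per-expert buffer size is computed, as a sketch; the exact rounding used in the vision work may differ. Below a factor of 1.0 the buffers cannot hold every token, and overflowing tokens simply skip the expert computation (passing through via the residual connection).

```python
import math

def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    """Buffer slots per expert. At capacity_factor=1.0 there is on average one
    slot per token; below 1.0 some tokens are guaranteed to be dropped by the
    sparse layer."""
    return math.ceil(capacity_factor * tokens_per_batch / num_experts)

tokens, experts = 1024, 8  # illustrative numbers
for cf in (2.0, 1.25, 1.0, 0.5, 0.1):
    cap = expert_capacity(tokens, experts, cf)
    routable = min(cap * experts, tokens)
    print(f"capacity_factor={cf:<4}: {cap:>3} slots/expert, "
          f"at most {routable}/{tokens} tokens processed by experts")
```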
01:01:54.660 | And yeah, thanks so much for having us.
01:01:57.660 | That's the talk.
01:01:59.660 | So thank you, Barrett and Erwin, for coming here.
01:02:08.300 | So I will just ask a bunch of questions, and then after the class we can have
01:02:17.180 | an open question panel for the students.
01:02:20.020 | So one thing is: have you tried using more linear attention mechanisms,
01:02:23.700 | like Reformer and other such approaches, to scale the computation?
01:02:31.020 | I haven't personally done this.
01:02:34.620 | Yeah.
01:02:35.620 | So I guess we can maybe comment on how the cost
01:02:44.540 | coming from the attention maps isn't the dominant cost in these large transformers.
01:02:52.640 | The motivation for using linear attention, like Performers, is that it reduces
01:02:58.260 | the quadratic cost of the attention maps, right?
01:03:03.460 | But so far, at least in typical NLP setups like SuperGLUE,
01:03:09.060 | C4 and so on, as you scale the models, most of the memory comes from the model weights
01:03:17.260 | as opposed to the attention maps.
01:03:20.460 | That's also because using very long context or sequence lengths doesn't prove
01:03:27.180 | that fruitful.
01:03:28.580 | And so working with the vanilla self-attention mechanism
01:03:34.220 | is a very strong baseline already.
01:03:37.020 | Got it.
01:03:38.020 | Okay.
01:03:39.020 | So another question is: do you think this mechanism is even more scalable?
01:03:43.660 | Like, can you go on to 10 trillion parameter models, things like that?
01:03:47.560 | What do you think?
01:03:48.560 | Yeah, definitely.
01:03:50.560 | I think, honestly, one of the biggest constraints, and this
01:03:55.300 | isn't even necessarily a constraint, is just that you have to fit the parameters somewhere,
01:03:59.860 | and there's just limited storage on devices.
01:04:01.780 | But if you get enough devices, you can just partition the weights.
01:04:05.420 | So yeah, I don't see anything stopping it.
01:04:07.940 | Got it.
01:04:08.940 | So what do you think, personally, is the direction that
01:04:14.120 | scaling of transformers will go in? Will there be more works
01:04:18.480 | that just try to use these mechanisms, like mixture of experts, or do you
01:04:23.340 | think there are going to be other things that the community needs?
01:04:25.860 | Yeah.
01:04:26.860 | I mean, I definitely think mixture of experts, or at least
01:04:29.780 | sparse layers like the switch transformer, will find their
01:04:32.580 | way into the future of large models.
01:04:34.120 | I think they really confer a lot of benefits, and they're also very good in high-throughput
01:04:38.780 | applications.
01:04:39.780 | The one downside of sparsity is that if you look
01:04:43.300 | at the performance per model weight, they're always going to be worse than dense models.
01:04:47.700 | So if you're really constrained, like, I want to design the best model I
01:04:51.120 | can to fit on as small a device as I can, then they're probably not going to be the
01:04:55.140 | best solution, because the sparse weights just aren't as good as a dense weight that's
01:04:59.320 | being used for everything.
01:05:01.020 | So I think it really depends on the application, but I'm very optimistic for when we're training
01:05:04.660 | these models with lots of data parallelism during pre-training, and then serving
01:05:07.740 | them in medium- to higher-throughput settings.
01:05:10.220 | I feel like they could actually be a pretty big win.
01:05:13.800 | So those are kind of my thoughts on how I think sparsity will be used. In terms
01:05:17.300 | of other things,
01:05:18.300 | I don't know, there's a ton of exciting research, everything from a lot of
01:05:21.740 | the linear attention stuff to adaptive computation to new pre-training objectives.
01:05:27.260 | It's hard to know what the future will look like, but there are a lot of exciting things to
01:05:30.940 | look forward to.
01:05:31.940 | Great.
01:05:32.940 | Sounds good.
01:05:33.940 | Okay.
01:05:34.940 | So we can now have a round of student questions, so we'll just stop the recording.
01:05:37.780 | Okay.
01:05:38.780 | Okay.
01:05:39.780 | Great.
01:05:40.780 | Thank you.