
Stanford CS25: V1 | Mixture of Experts (MoE) paradigm and the Switch Transformer


Chapters

0:07 Scaling Transformers through Sparsity
0:25 Overall Motivation
1:00 Scaling Laws for Neural Language Models
5:05 Switch Transformer
7:23 Improved Training Methodology
8:27 Differentiable Load Balancing
8:55 Selective Precision
10:25 The Initialization Scale
16:23 Multi-Stage Routing Procedure
20:35 What Is the Research Question?
29:50 Perplexity versus Training Time
30:53 Sparse Scaling Laws
37:10 Data Parallelism
37:36 Model Parallelism
38:51 Expert and Data Parallelism
39:31 Model Partitioning
40:37 Mesh Abstraction
46:18 Fine-Tuning Properties of Sparse Models
47:08 Multilingual Training
47:45 Distillation

Transcript

Today, Erwin and I are going to be giving a talk on scaling transformers through sparsity. And the kind of sparsity we're going to be talking about today is the kind where each input can get either a different set of weights or have a different amount of computation applied to it.

Erwin, do you want to start it off? So, I guess the overall motivation for this line of work is that the community has realized that scale is perhaps one of the most important axes to focus on for obtaining strong performance. And there's almost like this ongoing arms race right now with different labs and different institutions competing for training the largest models.

And so, maybe this dates back to early 2020 with a paper from OpenAI called Scaling Laws for Neural Language Models, where they find that model performance scales as a predictable power law with model size, measured either in compute or just in parameters. And this scaling law generalizes over multiple orders of magnitude, which gives us confidence that if we are to train very large models, we can expect a certain performance just by extrapolating these scaling laws.

So, in that paper, they also make the interesting observation that larger models are more sample efficient. And so, for a fixed compute budget, you can predict the optimal model size. The overall observation is that you'd rather train very large models for fewer steps than train smaller models for more training steps.

And so, the paper focuses on dense models, which are scaled by just increasing the model dimensions; they're not looking at sparsity. So, sparsity is a new dimension that you can use to scale architectures, and that is the focus of this talk.

And so, the sparsity we're mentioning here means you have sparsely activated weights based on the network inputs. So, every input will have a roughly similar amount of computation applied to it, but with different weights. And so, this dates back to 1991 with a paper called Adaptive Mixtures of Local Experts, and was recently revisited by Noam Shazeer and colleagues at Google Brain with LSTMs, where they replaced the feed-forward networks in LSTMs with a mixture of experts.

And so, the way this works there roughly is that you will have multiple experts each implementing a small network, or in that case, I think just a dense matrix multiplication. And so, you have an additional gating network shown in green here that outputs a probability distribution over experts that each token should be sent to.

So, this probability distribution is computed as a softmax, and once you have it, you select a few experts. So, there are different strategies, maybe we'll talk about it later on. And the output is simply sort of the weighted mixture of all selected experts' outputs. So, they've been pretty successful primarily in translation, but there were some complexities that hindered their broader use in NLP.
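As a rough illustration of the gating mechanism just described, here is a minimal top-2 mixture-of-experts layer sketched in JAX. The names, shapes, and the dense "run every expert and mask" trick are illustrative assumptions for this example only; the actual implementations (in Mesh TensorFlow) dispatch tokens to experts rather than computing all of them.

```python
import jax
import jax.numpy as jnp

def moe_layer(params, x, k=2):
    """Minimal top-k mixture-of-experts layer (illustrative sketch).

    params["router"]:  [d_model, num_experts] gating weights
    params["experts"]: [num_experts, d_model, d_ff], one dense matmul per expert
    x: [num_tokens, d_model] token representations
    """
    # Gating network: a softmax distribution over experts for each token.
    router_logits = x @ params["router"]                    # [tokens, experts]
    probs = jax.nn.softmax(router_logits, axis=-1)

    # Select the top-k experts per token (k=2 in the original MoE work).
    topk_probs, topk_idx = jax.lax.top_k(probs, k)          # [tokens, k]

    # For clarity only: run every expert on every token and mask afterwards.
    all_outputs = jnp.einsum("td,edf->tef", x, params["experts"])  # [tokens, experts, d_ff]
    mask = jax.nn.one_hot(topk_idx, probs.shape[-1])         # [tokens, k, experts]
    weights = jnp.sum(topk_probs[..., None] * mask, axis=1)  # [tokens, experts]

    # Output is the probability-weighted mixture of the selected experts' outputs.
    return jnp.einsum("te,tef->tf", weights, all_outputs)

# Tiny usage example with random weights.
key = jax.random.PRNGKey(0)
k_router, k_experts, k_x = jax.random.split(key, 3)
d_model, d_ff, num_experts, num_tokens = 8, 16, 4, 10
params = {
    "router": jax.random.normal(k_router, (d_model, num_experts)),
    "experts": jax.random.normal(k_experts, (num_experts, d_model, d_ff)),
}
x = jax.random.normal(k_x, (num_tokens, d_model))
print(moe_layer(params, x).shape)  # (10, 16)
```

Real implementations also typically renormalize the selected top-k probabilities and add the load-balancing machinery discussed later in the talk.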

And so, the Switch Transformer paper addresses some of those, and we'll be discussing how to fix training instabilities or reduce communication costs and reduce model complexity. All right, Barrett, do you want to go? Yeah. So, one kind of approach that we're going to have for sparsity is the Switch Transformer, which is kind of like a simplified mixture of expert variant along with some other improved training and fine-tuning techniques that allow it to be stably trained and also perform better when fine-tuned on a lot of downstream tasks.

And so, yeah, so the Switch Transformer kind of model works as the following. So, you have some transformer model that has self-attention and feed-forward layers. And the idea is that we replace maybe one every two or one every four feed-forward layers with a Switch Transformer layer. So, you can see on the left is like one kind of layer block, which is self-attention, then add normalize, then a feed-forward layer, then add normalize.

And in this case, we're replacing the normal feed-forward layer with the Switch layer. And we can see an illustration of this on the right. So, on the right, we can see that the layer has two inputs. One is the token "More", the other is the token "Parameters". And we can see that these embedding representations will get sent to a router, which is exactly how it works in a mixture of experts.

So, the router is basically just going to be producing a distribution over all of the experts. So, in this case, we can see that the highest probability is going to expert number two out of the four experts. And the token on the right actually puts the most probability on the first feed-forward weights, which is the first expert.

So, yeah, what we do in the Switch Transformer is very simple: we just send each token to its highest-probability expert. And so, here we can see where the adaptive computation lies: we'll have four sets of expert weights, while some weights and computation are shared across all the tokens.

For example, the self-attention layer is computed exactly the same way for the "More" token and for the "Parameters" token. But in the sparse Switch layer, we can see that the inputs, while having the same number of floating point operations applied to them, actually go through different weight matrices. Next slide.

Yeah. So, that's the high-level idea with Switch Transformer: instead of sending a token to multiple different experts, which can also increase the communication costs, as I'll go into a little bit later, we significantly simplify the algorithm by sending it to only one expert.
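Continuing the sketch above, top-1 (Switch) routing reduces to picking the argmax expert and scaling its output by the router probability, which is what keeps the router trainable. Again, this is a hedged illustration reusing the made-up params layout from the previous snippet, and it ignores capacity limits and the actual dispatch and communication.

```python
import jax
import jax.numpy as jnp

def switch_layer(params, x):
    """Illustrative top-1 (Switch) routing: each token goes to a single expert."""
    probs = jax.nn.softmax(x @ params["router"], axis=-1)   # [tokens, experts]
    expert_idx = jnp.argmax(probs, axis=-1)                 # [tokens]
    expert_prob = jnp.max(probs, axis=-1)                   # [tokens]

    # Gather each token's chosen expert weight matrix and apply it.
    chosen_w = params["experts"][expert_idx]                # [tokens, d_model, d_ff]
    y = jnp.einsum("td,tdf->tf", x, chosen_w)

    # Scaling by the router probability is what lets gradients reach the
    # router, since the hard argmax itself is not differentiable.
    return expert_prob[:, None] * y
```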

So, for the improved training methodology, we focused on three different things to help improve the training of sparse models. The first was selective precision, which allows these sparse models to be trained in lower precision formats, which is incredibly important. For most of the models we train, we really don't want to be using float32, because it's just slower to compute.

And also, when you're communicating tensors across different processes and stuff, it's twice as slow, just because there's twice as many things. Also, we have some initialization tricks and some training tricks as well for allowing them to be trained more stably, especially as the models grow in size, which is like a new initialization method, along with a change to the learning rate schedule.

And third, since our models have so many more parameters, we definitely notice different overfitting dynamics, especially once we fine-tune these models, which have been pre-trained on all of the internet, on small tasks with maybe only 50,000 to 100,000 examples; they can be much more prone to overfitting.

So we also look at some custom regularization to help prevent some of the overfitting that we observe. And finally, we also talk about this differentiable load balancing technique we make, which kind of allows each expert to roughly get the same amount of tokens. Because this is very important, especially given that we want the stuff to be efficient on hardware.

We want roughly each expert to have a similar number of tokens sent to it. And so to encourage this, we tack on an additional load balancing loss along with the cross-entropy loss that we're training with. Next slide. OK. So here, I'm going to go into selective precision. So yeah, again, when we're training large models, it's really important that we're able to train them in lower precision formats.

So instead of each weight and activation being 32 bits, we want to shrink it down to 16 bits, and we use the bfloat16 representation. And what we found out of the gate is that these models are just unstable; the sparse models especially are much more unstable than the dense models, in the sense that you'll train for 10,000 or 20,000 steps and then the loss will just diverge.

This was something that we frequently encountered. And so one key thing that we found is that basically, you need to be casting a part of the computation in float32 for these models to be able to be trained stably. And the key component that we found that you need to cast is the router computation.

And essentially, we can go into the technical details a little bit more later. But basically, any time there are these exponentiation functions, it's very important to have higher precision, because round-off errors can drastically change the output of an exponentiation.

So for example, if you change the input to an exponential by 0.1 or 0.2 or 0.3, this can drastically change the output, especially depending on how large the input is. So yeah, this was a very important thing. And it basically doesn't change the compute at all and allows the models to be significantly more stable.
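A small numerical illustration of the point (assuming JAX's bfloat16 dtype; the specific logit values are made up): bfloat16 keeps only about 8 bits of mantissa, so the rounding error on a router logit grows with its magnitude, and the exponential inside the softmax turns that input error into a multiplicative error on the routing probabilities.

```python
import jax.numpy as jnp

# Rounding a logit to bfloat16 perturbs it by an amount that grows with its
# magnitude; exp() then amplifies that perturbation into a relative error of
# roughly exp(perturbation) on the routing probabilities.
for logit in [1.3, 5.3, 21.3]:
    x32 = jnp.asarray(logit, dtype=jnp.float32)
    x16 = jnp.asarray(logit, dtype=jnp.bfloat16)          # rounded representation
    err = float(x32 - x16.astype(jnp.float32))
    ratio = float(jnp.exp(x32) / jnp.exp(x16.astype(jnp.float32)))
    print(f"logit={logit:5.1f}  bfloat16 rounding error={err:+.4f}  exp ratio={ratio:.4f}")

# The fix described in the talk: compute just the router (logits, exp, softmax)
# in float32 and keep the rest of the model in bfloat16.
```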

Next slide. So the second thing we looked at is the initialization scale. The standard way we were initializing these models, we found, also made them much more prone to being unstable and/or just performing worse. So one thing we did that was very effective was to simply make the initialization scale much smaller.

And when we did this, we found that the quality just drastically improved, and it was a very simple fix. Next slide. And the third thing I mentioned: since we noticed that these models are much more prone to overfitting, because they have significantly more parameters, we also use much more dropout, but for the expert layers only.

So here we can see we have T5 base, which is a dense model, and then we have a bunch of different Switch variants on top of that. And what we found to be the most effective on these four different fine-tuning tasks was to really significantly increase the dropout rate inside the expert layers.

And we found that this was pretty effective for combating the overfitting. Next slide. Better? Yeah. We have a question. Oh, awesome. OK. For one of the students. Yeah. OK, let me take a look. Do you want to go ahead? Yeah, I can ask. It was just in reference to the previous table where you have throughput and precision.

It just seems surprising to me that you could match this 1390 number using selective precision. It seems like I would expect it to be something in between. Yeah. So it essentially comes down to the fact that there's maybe a little bit of noise in how the speed is measured. And the only part we're casting is the router, which is such an insignificant portion of the computation.

And there's zero communication there. That is essentially a free operation in the network. So whether you cast it to bfloat16 or float32, it doesn't actually impact the speed at all within the precision that we can measure the speed. And also, these architectures only use a sparse layer once every four layers.

And so, yeah, essentially, the float32 part is very negligible in the entire architecture. For example, I think off the top of my head, it's like 1/40th of the compute it would cost to do the first weight matrix multiply in a dense ReLU layer or something.

So it's a very, very small part. And yeah, we're not using them very frequently, like Erwin mentioned as well. Got it, OK, thanks. Yeah, and then just a quick point on this. I won't go into some of the technical details, but yeah, we definitely-- since we're training these things on hardware, we really-- I think a big part of the mixture of experts paradigm is that these things are designed such that it maps really efficiently to hardware.

So we want to be doing dense matrix multiplies. And for this to work really well, we also want to be able to have roughly equal amount of tokens going to each of the different experts. And I think this isn't that sensitive to the load balancing formulation. Like, we tried a few things.

A lot of them worked. But yeah, essentially, you definitely want some kind of load balancing loss added on when using sparsity. Yeah, next slide. Yeah, Erwin, go ahead. Yeah, so the frameworks and libraries we use rely on static shapes. OK, yeah, so XLA, the compiler for TensorFlow and Mesh TensorFlow, expects static shapes for tensors.

However, the computations in switch transformers are dynamic because of the router, right? Different inputs will be routed to different experts. And so we need to specify ahead of time how many tokens will be sent to each expert. And so we will introduce this expert capacity hyperparameter to specify that.

And that's going to be a static number which says how many tokens each expert can process. And so in practice, we instead parametrize this by having a quantity called the capacity factor. And so we have an example here. So the bottom row is a bunch of tokens on one device.

And then you need to sort of route those tokens to multiple devices or multiple experts. So if too many tokens are routed to a single expert, some tokens will be dropped because, as we said, experts have a fixed capacity. So that's the example on the left where the capacity factor is one, and that basically means that there's no extra buffer for routing tokens.

So instead of that, we can use the capacity factor that's larger than one. So on the right, you have an example with 1.5. So that means that now each expert has three slots that can process three tokens. And so that prevents token dropping because we have more capacity. But the issue is that this means more expensive communication across devices.

One thing that we also experimented with was this method called no token left behind. And the idea was the following. So since we have to have a fixed batch size for each expert, and there can be token dropping, we're thinking that, hey, yeah, having tokens dropped or having some tokens not having any computation applied to it is probably hurting the model performance.

So what if we do a multistage routing procedure? So first, you do the normal routing where it's like you send each token to its highest probability expert. But then any dropped tokens, you then send to their second highest probability expert, and so forth and so on. Or you can basically repeat this process to guarantee that no tokens are being dropped.

Interestingly, actually, this approach didn't empirically improve model performance. If anything, it actually kind of hurt it. And we thought that was actually very interesting. And I think the intuition is that, you know, once the model learns it wants to send a token to one expert, like it really wants to have that computation applied to it.

And just applying some other computation doesn't, you know, have at all the same property, along with it actually maybe being potentially detrimental. So yeah, we thought that was pretty interesting, as we were very optimistic this would potentially, you know, get improved performance, but it ended up not really making a difference.

And we found this quite surprising. We have a question from… I think it will actually address the last point that you brought up. When I think about a mixture of experts, the experts usually specialize in different things, right? So I was just wondering: if you send a token to the second-best expert or whatever, what if all of your tokens would be particularly good for one expert, and then you only process, let's say, 20% of your tokens?

So that ends up being better than rerouting them to anything else. Exactly. Yeah. So yeah, even if you're dropping a lot of tokens, it's not beneficial to be sending them to the second, third or fourth best thing. And one actually interesting property that we, you know, noticed about these models is they're surprisingly robust to token dropping, especially during fine tuning.

So yeah. So in the standard paradigm, what we'll do is we'll pre-train this thing, we'll have some load balancing loss, which makes the tokens pretty balanced actually. But then during fine tuning, where it's like, we really want to fine tune it on a specific task. We actually studied this exact question and we were studying, does it help to have a load balancing loss during fine tuning or not?

And so if you keep the load balancing loss, that kind of encourages all the experts to be used for the specific task, versus turning it off. Whereas there's definitely some prior specialization, and it's actually much better to just turn the auxiliary loss off.

And even if it's like, you know, 60 to 70% of the tokens are being dropped, that actually performs much better than, you know, having all the tokens balanced. But doesn't a load balancing loss encourage basically all the experts to learn very similar weights and then just randomly assign your tokens?

Because then it doesn't matter to which expert stuff is being sent to. So when we use the load balancing loss, like the routing mechanism is definitely learned. So the model definitely is encouraged to, you know, choose an expert that it wants to send it to for good. Right. But like if all the experts learn the same weights, then the router learns basically, oh, it doesn't matter where I send it to.

So if you encourage load balancing, you technically encourage any token to fit with any expert. Right. I mean, that's maybe the extreme behavior if you have a very high load balancing loss coefficient, but in practice that coefficient is tuned, and we observe that for small enough values, the router still learns semantically meaningful routing.

Yeah. Because it's like a balance between this, like, you know, cross entropy loss and this load balancing loss. And so on one hand, yeah, you definitely want to encourage the model to be balanced. Then on the other hand, you also want to just get good empirical performance. And yeah, the model is able to definitely like on one hand, learn and specialize the experts where they have different weights such that it's like, you know, definitely it expects certain tokens to be sent to certain experts, but on the other hand, still be reasonably balanced so that the models are efficiently run on like modern hardware.
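For reference, the auxiliary load-balancing loss being discussed here is, in the Switch Transformer paper's formulation, roughly the dot product of the fraction of tokens routed to each expert and the mean router probability for that expert, scaled by the number of experts and a small coefficient (on the order of 1e-2). A minimal sketch, with argument names assumed for illustration:

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(router_probs, expert_idx, alpha=1e-2):
    """Auxiliary loss that encourages a uniform token-to-expert distribution.

    Roughly the Switch Transformer formulation: alpha * N * sum_i f_i * P_i,
    where f_i is the fraction of tokens routed to expert i (hard top-1
    assignment) and P_i is the mean router probability assigned to expert i.
    The loss is minimized when both are uniform (1/N each), and it stays
    differentiable through P_i even though f_i is not.
    """
    num_experts = router_probs.shape[-1]
    f = jnp.mean(jax.nn.one_hot(expert_idx, num_experts), axis=0)  # token fraction per expert
    p = jnp.mean(router_probs, axis=0)                             # mean router probability
    return alpha * num_experts * jnp.sum(f * p)

# total_loss = cross_entropy_loss + load_balancing_loss(router_probs, expert_idx)
```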

Exactly. We also have a question from the classroom. So the question that I want to ask is: it seems to me like this is a very experimental talk. We're talking about floating point precision. We're talking about different approaches and what currently works well. And whenever we're dealing with experiments like these, there's the question of what is the research question? And I feel like I missed that. So what are we trying to answer with all these experiments?

And I feel like I missed that. So what are we trying to answer with all these experiments? Yeah. I think the, I think the high level of research question is like, you know, can we, you know, create models that are, you know, like doing adaptive computation from the standpoint of like, no, can we try to make models more simulate the dynamics that we think models should most naturally use, which is different inputs to have different amounts of computation applied, have different weights applied to them, you know, and basically all of this, basically we're trying to research and like figure out how can we create like a new framework for these models to be trained as opposed to their dense counterparts that, you know, for every input are always having the same exact computation applied.

So that's interesting, because when you say the same exact computation applied, the immediate thing I imagine is about how long to deliberate about something. What I mean by that is, if we want to have variable-length computation, you could imagine that an input could get a short amount of computation or a much longer amount of computation. So why then do we instead consider the dimension of different computation?

I mean, assuming of course that these experts do indeed learn different things, which I think you'll get to in a minute. Yeah. So why do we immediately jump to thinking about specialized experts as opposed to thinking about variable length computation? So, yeah, so this is actually, we actually go into some variable length computation stuff later in the talk.

And I feel like they're both actually just important axes that should both be pushed on. I think, I guess, yeah, that's not quite my question; what I'm trying to understand is why you decided to attack this one first.

I want to understand why your team chose to go this direction first. Yeah, absolutely. So I think that one empirically, it seems that sparsity has led to better empirical results in the field of deep learning than adaptive computation so far. And I think the way that we use these things maps really well to our modern hardware, which is also very promising.

And I think the way we were kind of looking at it is that sparsity is a first step towards doing more interesting and general adaptive computation. Because, you know, this stuff is complicated, and typically starting from something that works well is better than trying something that's not as proven out and then trying to get it to work really well.

So I think we're kind of starting from sparsity, which, you know, Noam Shazeer and others got to work really well in the context of LSTMs. We were kind of interested in, you know, let's port some of this to transformers, let's get it working really well, and then let's slowly start expanding towards a lot of the other natural questions that you mentioned.

Whereas, okay, instead of, you know, different weights per core, let's also maybe have different computation per core, and all of this. So that's, I guess, how we were building the natural progression of our research. Got it. Cool. Thank you. Yeah.

What do you think Erwin? Anything else to add? Yeah. I mean, I guess I kind of see adaptive computation and sparsity as, you know, related, but separate things. So, you know, sparsity is more like different parameters for each example and adaptive computation might be more different amount of flops and we have some of that with the token dropping, but that's kind of, you know, that's not the main motivation.

Definitely, as Barrett mentioned, I would say, you know, no one has really figured out adaptive computation yet for deep learning. And one reason is because of the accelerators we have, right: we need to work with batching, with data parallelism. All of our accelerators and our frameworks use this SPMD paradigm, where we're kind of supposed to apply the same computation to every example.

And so if you look at the literature, you have, you know, works like universal transformers where they replace the feed forward in the transformer by just a recurrent weight. And so it's kind of like an LSTM on each token and the LSTM can stop at different times based on some criteria.

But the way these things are implemented is just through masking because it needs to be implemented in the SPMD programming style. And so definitely sparsity was kind of like easier to get to work first. And also there were some prior results with LSTM, so yeah. In terms of like the first question, you know, sort of what's the research question here is just like, oh, can we design more efficient models?

And sparsity is this new axis that hasn't been explored that much. And yeah, I think that, you know, I'm happy with just that being the research question. Yeah. Great. Okay. Yeah. So next slide. Yep. Oops. Yeah. Again, so kind of putting it all together. So the switch transformer layer selects an expert, like just the top expert, and then incorporates a bunch of the general sparse model improvements to, you know, allow it to fine tune better, allow it to, you know, be more regularized, allow it to, you know, be trained with lower precision formats and a lot of like technical details to just get them training and working well.

Yeah. So one thing that we also wanted to do was a comparison between like top one and top two routing since top two routing was kind of the, you know, most popular technique. And so here we can see we have two different dense models trained of different sizes. And we're going to be looking at like the, the pre-training like negative log perplexity.

So yeah, the bigger the number, the better. So next slide. So, so, and what we're going to be doing is we're going to be studying them at different capacity factors. So a capacity factor of 2.0 basically means that there's enough buffer for two tokens to be sent to every single expert.

And we're going to be comparing like top one versus top two routing and also comparing their speeds along with their like time to get some like threshold quality. Okay. Yeah. So here we can see in the capacity factor 2.0 case that the MOE models outperform switch transformer, which makes a lot of sense, like since switch transformer is only, you know, sending like a top one token to each expert, the mixture of expert is sending, you know, two tokens.

So it makes sense that this extra buffer would be disproportionately beneficial for the mixture-of-experts models. Next slide. Now, the really interesting part for top-1 routing comes when we lower the capacity factor. So having a high capacity factor is bad for many reasons.

One of which is it really incurs more of these, you know, communication costs for sending tokens to the correct experts. It also incurs more compute costs and also incurs like a lot of memory overhead. So if you can get this lower, it's, it's usually like a very, very good thing.

And so what we see here is that Switch Transformer actually outperforms mixture of experts when you have a lower capacity factor. And we can see that for the time to reach a quality threshold, we get there much quicker. And so even across the 2.0 and the 1.25 capacity factors, the kind of Pareto-optimal thing we saw in our setup is to use Switch Transformer at a lower capacity factor, just due to the fact that while the quality is a little bit worse on a step basis, it's just much faster to run.

So it's kind of the Pareto optimal decision. Next slide. And we can also be seeing that like for capacity factor 1.0, again, we can see that this really disproportionately benefits switch transformer and is even better on a Pareto standpoint than the 1.25 capacity factors. And interestingly, since, you know, MOE also does like a little bit more computation, we can also just increase the amount of compute done elsewhere in the model.

And we can see that that's like a much more efficient allocation of compute. So yeah, overall, our takeaway is that, yeah, lower capacity factors using op one routing is more Pareto efficient than, you know, using like the top two routing at higher capacity factors. Next slide. Erwin, you can take it over.

Okay. So next we'll look at how Switch Transformer scales as a function of the number of experts in the switch layers. And so on the right side here, you see a plot that shows perplexity versus training steps for different switch architectures, ranging from T5 base, which has basically no experts, or a single expert, up to 128 experts.

And so you see that as we increase the number of experts, which also increases the number of sparse parameters, you get increasing speedups over the dense baseline. And there are sort of diminishing returns to increasing the number of experts as well.

So the previous figure was looking at perplexity versus training steps. Here we look at perplexity versus training time. So that includes all the additional communication costs when you have more experts, or when comparing to the dense baseline. And so this is for switch base versus T5 base, and we observe up to 7x speedups over T5 base.

And so, you know, just to maybe contextualize these numbers, like, you know, 7x speedups in deep learning are pretty hard to obtain. And so I think this is one of the, you know, one of the results that, you know, can spark a lot of interest in sparse models, even if it's only for pre-training for now, like just having that number is like, you know, maybe there's a significant, there's something significant that can be obtained here.

Okay, so sparse scaling laws. So here we'll look at loss versus sparse model parameters, which are increased by increasing the number of experts. And so similarly to the normal scaling law paper, we observe that as you increase the parameters, the sparse parameters, and keep the flops fixed, you get consistent but diminishing gains.

Okay, so now we're going to compare expert parallelism and model parallelism. So we introduced sparsity, or expert parallelism, as a new dimension to scale models. But of course, there's the other one for dense models, which is simply model parallelism, where model weights are partitioned across cores once they are above the maximum size that you can fit on a single core.

All right, so, yeah, Barrett, I assume the left is expert parallelism here? Yeah, so essentially what we're doing is, yeah, we're kind of comparing a switch-based model versus the dense base. And we're also comparing against a larger dense model that has used model parallelism. And we can see that, you know, basically when we want to scale model size, we kind of have two axes that we can go through.

We can either increase the number of flops by scaling through model parallelism or increase the number of parameters by scaling through sparsity. And so we can see that, you know, even compared to a dense model that's been scaled up through model parallelism, sparsity is still, at this scale, a more effective way to scale up the model, still getting 2.5x speedups over this larger dense model that was using model parallelism.

Cool, so next slide. Yeah, basically here, T5 large is the dense model that uses model parallelism. Yeah, right, go ahead. Okay. Yeah, and so one thing that we also wanted to look at is, you know, are these expert models effective if you have a really small amount of compute, just a small number of experts?

So typically when we're designing these models, we have one expert per core. But if you don't have a large cluster to run these things on, let's say you just have a GPU with two cores or something, is having two experts more effective than just a dense model?

And the answer is yes. So we can see even pretty good scaling properties, even with like a tiny amount of experts, which is very, very promising for these models to be used even in like much lower compute regimes. Next slide. Or when you want to go ahead. Okay. Yeah.

And so let's look at what things look like when we use different types of parallelism, namely expert parallelism to add experts, model parallelism to shard model weights across cores, and also data parallelism, which is sort of the dominant paradigm in deep learning at the moment. And so, you know, in the previous slides, we mostly talked about expert parallelism, but of course, dense models, and large-scale dense models in particular, use model parallelism.

So GPT-3 and these other large models, what they do is simply shard model weights across different cores. Yeah. We have a question. Oh yeah. I just wanted to know, because I think there was, I don't know if you're going to address it later, but I think somewhere in the paper it said that the more experts you have, the more sample efficient it gets.

And I was just hoping that you could give us some intuition about that, because I don't understand why that would be the case. So I guess, yeah, maybe, so I guess, you know, there's all of this work showing larger models are more sample efficient, and larger in the context of the scaling law work means more parameters and more flops.

As you increase the number of experts, there are more parameters but not more flops. But the model is still, you know, larger in a similar sense. So building on the intuition that larger models are more sample efficient, in my mind it's not necessarily that surprising that these models with more experts, which have more parameters, are more sample efficient.

I guess that's my kind of high-level intuition for it. Yeah, I would say that's kind of expected, that more experts leads to better sample efficiency, especially if you look at training steps rather than training time. Okay, cool. So where were we? Yeah, so, okay, so we'll look at how model weights are split over cores for different scenarios.

So data parallelism is the first one. So that's kind of the typical setup that deep learning uses, especially for not so large networks which don't require model parallelism. And so let me, yeah, let me explain how, yeah, I'll just go to the final figure and I'll explain how to look at this figure.

Okay, so we have 16 cores which are organized in a four-by-four mesh, right? So each dotted square in the four-by-four grid here represents a different core. And the first row shows how the model weights are split over cores. And the second row illustrates how data, so literally examples and tokens, are split over cores.

And yeah, and then the final thing that's required to understand this figure is that each color of the shaded squares here identifies a unique weight matrix. Okay, so let's start with data parallelism. So for data parallelism, the same model weights are replicated across all cores. And the data is simply partitioned over cores.

And so that's what this corresponds to, using the description of the caption, the explanation of the caption I just gave. So next we have model parallelism. That's kind of just like a theoretical example because in practice, people always use model parallelism in conjunction with data parallelism. But so if you were to do only model parallelism, now you would have a single model weight that is partitioned over all cores, and your data would just be replicated over all cores instead.

So now we have model and data parallelism, and that's kind of the typical scenario for large dense networks. So in that case, model weights are partitioned among a subset of the cores, the subset of cores that process different batches of data. And so in that example here, we have sort of four, so the first sub-square here means that the model weights are partitioned across four cores.

And this is replicated sort of four times for the data parallelism dimension. On the data side, for model and data parallelism, yeah, the data here is replicated across model parallel cores and partitioned across data parallel cores. So next we have expert and data parallelism. So in that scenario, that's kind of similar to data parallelism, but now each core will hold a different model weight, which is illustrated by the different colors.

And for the data side, the data is simply replicated, sorry, the data is partitioned across all cores, just like in the data parallelism scenario. And so finally, we have the rightmost column, which is, I guess, yeah, that's the setup used in the switch transformer paper for the larger models.

And so here, for the model partitioning, each expert is partitioned across multiple cores. So in that example, we have four experts, each partitioned across four cores, and the data is replicated across model-parallel cores and partitioned across data-parallel cores. So that's a little bit complex to follow, but the Switch Transformer paper has the same figure with a nice caption to explain it.

Yeah, maybe we can quickly add something about how this is implemented in practice. So there's this paper called Mesh TensorFlow, which kind of extends batch or data parallelism to more general-purpose SPMD-style programming. And so different labs have different frameworks, but this paper kind of lays the foundation for general SPMD distributed computing, which is required for training large-scale models.

And so under the mesh abstraction, basically we have a mesh of processes, and so that mesh has dimensions, name dimensions, and these name dimensions specify how the tensor dimensions will be partitioned or replicated across the mesh dimensions. And so just that simple abstraction sort of supports data parallelism, also model parallelism, and especially expert parallelism at once.
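As a rough modern analogue of that abstraction (this is JAX's named-mesh sharding API rather than Mesh TensorFlow itself, and the axis names, shapes, and the forced CPU device count are invented for the example), the idea of naming mesh dimensions and mapping tensor dimensions onto them looks something like this:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # fake 8 CPU devices

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2 x 4 mesh of devices with named axes, in the spirit of Mesh TensorFlow.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "expert"))

# Expert + data parallelism: the expert weight tensor [num_experts, d_model, d_ff]
# is split along its expert dimension, so different cores hold different weights.
expert_weights = jax.device_put(
    np.ones((4, 8, 32), dtype=np.float32),
    NamedSharding(mesh, P("expert", None, None)),
)

# The batch of tokens [batch, d_model] is split along the data axis,
# as in plain data parallelism.
tokens = jax.device_put(
    np.ones((16, 8), dtype=np.float32),
    NamedSharding(mesh, P("data", None)),
)

print(expert_weights.sharding)
print(tokens.sharding)
```

Here the expert weights are split along the "expert" mesh axis (expert parallelism) while the batch is split along the "data" axis (data parallelism); adding a third named axis and splitting the d_ff dimension along it would correspond to the expert-plus-model-parallelism layout in the rightmost column of the figure.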

And so I invite whoever is interested to also check out that paper, because it kind of lays the foundation for understanding these things. All right, Barrett, do you want to go? Cool. So next we're going to talk about how we take these parallelism strategies and combine them together to make a 1.6 trillion parameter sparse model.

So next slide. So what we ended up doing in this work was we trained two different very large sparse models, and we compared them to the largest T5 model. So we can see the T5 XXL, which is a dense model, and it was the largest one trained in the T5 paper, and it has around 13 billion parameters.

And here we list a lot of the model dimensions like D model, DFF, which are just like the various sizes and shapes of the tensors and stuff, the number of layers, the number of heads. And importantly, we also mentioned the negative log perplexity at step 250k and at 500k.

And so yeah, we designed two sparse models to test how scaling sparsity alone compares with scaling both sparsity and flops. So first, let me talk about Switch-XXL. That has the same amount of flops per token as T5-XXL, but has 64 experts. And this leads it to have around 400 billion parameters.

And we can see that on a step basis, it actually performs quite well and outperforms T5-XXL by quite a good margin. Interesting, though, is the third model we designed, Switch-C, which has 1.6 trillion parameters but significantly fewer flops, almost 10x fewer flops per token than either of the above two models.

So it's really trading off: reducing flops to have way more sparse parameters. And we can see on a step basis, the Switch-C model works well, but not as well as the higher-flop model. But on a kind of Pareto axis, where we're looking at TPU hours on the x-axis and not steps, the Switch-C model actually outperforms them both by a pretty large margin.

So for pre-training performance, we're seeing that just having a lot of sparsity and fewer flops can actually be quite good. Next slide. Yeah. And so, yeah, those two sparse models are really trying to get at this hypothesis that Noam Shazeer had, which is that parameters are good for knowledge, and compute, aka flops, is good for reasoning and intelligence.

And so we're going to kind of try to get at that by taking these different sparse models and then fine tuning them on, uh, different tasks, some of which require more like knowledge and then others, which require more of like reasoning, um, for whatever, like hand wavy definition we want to give that.

So yeah. Oh, can you go back to the previous slide? Oh yeah. Sorry. Okay. So for a fixed quality on an upstream pre-training task, do parameters independently matter? So we're going to look at two tasks here.

One of which is SuperGLUE, which is kind of our reasoning task. And then another is TriviaQA, which is a knowledge task where you just give the model a question and have it output an answer. Okay. And so here we're going to take a look at SuperGLUE quality.

So on the x-axis is the pre-training performance, and the y-axis is the SuperGLUE score after fine-tuning. And interestingly, we can see that the sparse models, for a fixed pre-training perplexity, definitely do worse on fine-tuning. This can be especially noticed at the upper right portion of the plot, where the dense models are definitely fine-tuning better than their sparse counterparts.

Next slide. Interestingly, when we study it on the more knowledge, heavy tasks, the sparse model for a fixed, uh, pre-training perplexity does disproportionately well. So, you know, for a model that roughly has the same perplexity, we're getting like really large boosts for these knowledge, heavy tasks. So this is pretty interesting.

And it also really shows some of the dangers of comparing only on your pre-training metrics. So a dense and a sparse model can have the same exact pre-training metric, but very different properties when fine-tuning them on different tasks. Next slide. And interestingly, all of the switch models here are various models that still have a good amount of flops, but the red model is actually the 1.6 trillion parameter sparse model that has very few flops but a lot of parameters.

And we can see that as the red dot here, and it does actually disproportionately bad compared to other sparse models that also have pretty good perplexities. And so, yeah, it's, uh, it's definitely very interesting and it shows that, you know, for models during pre-training that have a lot of sparsity, they definitely suffer on some of these more reasoning heavy metrics, but do disproportionately well for more of these knowledge, heavy tasks.

Next slide. Yeah. And so here we can see it as just like a huge outlier for a pre-training perplexity doing like just incredibly well on this, uh, downstream question answering task. Next slide. Yeah. Okay. So also, you know, one thing that we were going to do is just look at the fine tuning properties of sparse models across like a few scales and just see how they perform.

Next slide. Yeah. And so here we try two different models. One is, um, T5 base, and then we make a flop match sparse counterpoint. And when they say flop match, it's like, you know, each token will have the same amount of flops, but now we just have experts. So we do this for both base and large, and we see that actually across almost all tasks besides two arc tasks, the sparse models perform quite well, which is, which is definitely promising.

So we are seeing that these models are pretty robust, they pre-train well, and then they also fine tune well when scaled appropriately by scaling up both the flops and sparsity. Whereas, you know, the negative results we've really seen are like, yeah, when you just have a huge amount of sparsity and not too many flops.

Next slide. Yeah. And one also thing we wanted to look at was, uh, the multilingual training. So we were previously studying all of this on like English only, and we also wanted to see how sparsity helps in the multilingual setting because, you know, we also felt like this would be a very natural place for sparsity to work well, or potentially experts could specialize across languages.

And we do see strong results. So on 91% of the languages, out of around a hundred languages, we see at least a 4x speedup over the mT5 dense model. Next slide. Erwin, you want to go ahead? Uh, no, go ahead. Okay. Yeah. So another thing we wanted to talk about was distillation.

So one downside of these sparse models is that they'll have a lot more parameters, which means that, you know, if you're serving these things or something, you either need like high throughput use cases, or you need to maybe distill it back down into like a smaller dense model. So here, what we do is we look at like the T5 base and switch base, and we look at its pre-training performance.

And then we go through, um, some ablations of different distillation techniques and find that like with the best techniques, we can keep around 30% of the quality improvements of sparsity while distilling it back down into its, uh, dense, um, counterpart. So next slide. Yeah. And then we kind of study this across multiple scales.

And again, we see around 30 to 40% of the gains can be kept when going from a sparse model and distilling it back down into its flop-matched dense model. So you can get rid of up to 99% of the parameters and still keep around 30% of the improvements, which is very promising.

Next slide. Wait, I'm sorry. Yeah. All right. Sorry about that. Can you say that last sentence again? You said that you can keep 30% of the teacher's benefit. Yeah. Basically. So yeah, exactly. So we're looking at: you train a sparse model and then you just distill it back down to a dense model, versus training a dense model from scratch.

And you look at the gap between the sparse and dense models trained from scratch, versus the gap between the dense model and the distilled dense model. Yeah. Oh yeah. Maybe let me just do a quick high-level summary again. So what we'll do for our comparisons is we'll train a dense model from scratch.

We'll train a sparse model from scratch, and then we'll also run a third experiment where we distill that sparse model down into a dense model. What does distilling mean? We're basically trying to match the teacher's logits, the kind of standard thing of matching either the logits or the soft probabilities for each token, or something like that.
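A minimal sketch of what that objective can look like (this is a generic distillation loss, not necessarily the exact recipe used in the paper; the mixing weight and function names are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, labels, mix=0.5):
    """Generic distillation objective (illustrative, not the paper's exact recipe).

    The student is trained to match the teacher's soft token distribution
    (cross-entropy against softmax(teacher_logits)), blended with the usual
    cross-entropy against the hard ground-truth labels.
    """
    log_p_student = jax.nn.log_softmax(student_logits, axis=-1)
    p_teacher = jax.nn.softmax(teacher_logits, axis=-1)

    # Soft-target term: match the teacher's per-token distribution.
    soft_loss = -jnp.mean(jnp.sum(p_teacher * log_p_student, axis=-1))
    # Hard-target term: standard cross-entropy on the true labels.
    hard_loss = -jnp.mean(
        jnp.take_along_axis(log_p_student, labels[:, None], axis=-1)
    )
    return mix * soft_loss + (1.0 - mix) * hard_loss
```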

Okay. If I can jump in with my question. So what I'm struggling with is how do I interpret the line items labeled "percent of teacher performance"? Yeah. Okay. So it's basically looking at the gap between the dense and sparse models. So the dense model gets some performance, the sparse model gets some performance, and then the dense model that's distilled from the sparse model will be somewhere in between that range.

And we're basically saying it's 30% through that range. So in a zero-to-one interval, it's 0.3 of the way from the dense to the sparse model. I see. So "percent of teacher performance" does not mean that, if we use the teacher's predictions as the ground truth, the distilled model matches the teacher 33% of the time.

No, no, exactly. It's basically saying you get like 30% of the quality improvements. Yeah, exactly. Okay, cool. And then if we can back up a slide, I had a different question, but I didn't want to interrupt. When we were talking about all of these different T5 bases, and then also on a few slides before this, I don't know that much about T5.

I'm curious, you know, when T5 is trained, is there a weight penalty in the loss function? Is there a weight decay term? No, there's no weight decay used in training any of those sparse or dense models. I see. So out of curiosity then, how do dense models perform compared to the switch model?

If you add some sort of weight regularization that incentivizes getting rid of useless weights? Oh, so some kind of like maybe like L1 term or something like that? Yeah. So I'm wondering like how much of, because here we're talking about the benefits of sparsity, and I'm wondering how much of this benefit from sparsity is due to the fact that just some of this, I mean, effectively what the switch model is doing, if I understand correctly, maybe I don't, what I understand is that the switch model, the feed forward layer, it's just like you fixing the weights to be zero.

That's what it means to be sparse. Well, actually, we're kind of really trying to like inject more weights. So we're actually kind of trying to do, it's a little bit maybe like paradoxical, because we're saying switch transformer, but our idea is to be like, hey, we actually want to just have significantly more weights, not less weights.

It's kind of like you would zero out weights, but within a much larger weight matrix, if that makes sense. I see. Yes. And so to me, it seems like a relevant baseline to just ask what happens if I have the dense matrix, but I incentivize it with, say, an L1 or L2 penalty on the weights.

And I would, I'd be curious to know how that compares. Yeah, we didn't run this, but also that kind of gets rid of weights for the dense model. So if any- Sure, sure. Yeah. So, yeah. Also- Yeah. Yeah. Also, to me, it's like, if you just add like an L1 penalty loss, you're not going to have structured sparsity, whereas like here we, you know, it's not random weights in your giant weight matrix that are zeroed out, right?

It's really blocks corresponding to each expert. Right. And so- So that structure allows the whole communication stuff and that's- Yes. That leverages the fact that you have multiple cores and so on, right? I totally agree with that block structure, and that's what I'm trying to say: the switch has this very rich structure, it's not just sparse.

And what I'm trying to do in my mind is disentangle: is the sparsity what's offering an advantage, or is it this additional structure that you built in that gives the performance gain? So that's why I'm asking. So the block structure is what enables us to leverage the fact that you have multiple cores.

Yes. Like if you didn't have that block structure, you'd still have to route to everything. And so you have more communication costs and so on. And then your first question was what, sorry? I'm not actually sure if there was a question, I guess what I'm trying to say is I'm trying to- Yeah.

Yeah, anyways. But I agree, it's a little bit weird because sparsity kind of, there's a spectrum of meaning for sparsity, right? So it's like, for example, compression and like model pruning is a form of sparsity, but also a switch transformer and MOE also referred to as sparsity and that kind of related, but definitely they're aiming at different things, so.

This is a really interesting idea of it's sparse, but you have more parameters. I'll have to think about it more. Thank you. Yeah. So you have like sparse within this like giant weight matrix, which is- Exactly. Yeah. Yeah, yeah, yeah. I hadn't appreciated that. So I appreciate you pointing that out.

Thank you. I have a follow up question on distillation part. Yeah, of course. Okay. So if you distill it back down, now you have like one technically, you're back to the dense layer architecture, right? So now the entire idea of expert is that certain tokens would be sent to different experts because they just like, I don't know, are more specialized in figuring something out about this token.

So now if you go back to this like dense layer, aren't you like basically only serving whichever expert you base this dense layer on, like these tokens will probably perform well and all the other tokens are kind of like left behind, right? I'm actually, sorry, I don't think I'm fully understanding your question.

So are you kind of getting at like we're distilling this on a specific data set? So that- No, I'm thinking of how to use that, like why- Yeah. Yeah. Yeah. So maybe concretely, like let's, so like for super glue, right? Like let's say you want to serve a model that does super glue well.

I think the idea is that like you distill the sparse model into a dense model on super glue. So then you kind of get this compressed dense model that now performs better than if you were to just train it from scratch or train it from like a pre-trained dense model.

So then it's like- Which model did you use though? Say that again? You have to pick one expert, right? No, no, no. You can just distill all of the, again, because you're just matching the model outputs. So you can just treat the sparse model as kind of like a black box thing.

All we're doing is just trying to have the dense model match the actual like final like token predictions. Oh God. Okay. Got it. Okay. Yeah. Okay. Sorry. I was not, I was not familiar with the idea of distillation. So I think that was like my current confusion. Okay. Thanks.

Yeah, of course. Yeah. Um, I guess one motivation here is that having experts can make serving a little bit more difficult, because it requires bigger topologies. Let's say you have eight experts; you need, well, I guess you can have multiple experts on fewer cores, but, you know, let's just say they're a little bit harder to serve.

And so if we can get the benefits from sparsity at pre-training and then use distillation to a dense model for serving, that can be beneficial. So I think that was sort of the motivation for that experiment, right, Barrett? Yeah, exactly. Yeah. Okay.

Well, are we, yeah. Yeah. So kind of just wrapping up. Yeah, go ahead. No, go ahead. I just said, I think there's one more question from the stream. So yeah. Oh yeah. Go ahead. Feel free to ask it now. Oh yeah. Yeah. Sounds good. Um, yeah. Thanks guys for the talk so far.

Uh, just a quick question. I was wondering if you think there are any interesting directions around building models that are explicitly optimized for parallel training. I guess the MoE model seems like it does a really good job here. And also at inference time, it's very useful to have fewer flops per computation, or per forward pass.

But, um, I guess, do you think that there are any interesting directions around distributed training where you might have like models that are explicitly are architected to have a lot of, uh, parallel heads or, or other like features that are, you know, kind of embarrassingly parallelizable or does just using like standard, you know, scale up the models by adding more layers, uh, and then just, you know, get away with using model and data parallelism work well enough.

Yeah. So I think, so yeah. So let me just make sure I'm fully understanding. So yeah, I think also like, you know, right now, like even our models are definitely very co-designed with the hardware and like the shapes and things, you know? Um, so yeah, I, I, I think at a high level, like, yes, I think there's a ton of interesting research on like co-designing the hardware, the partitioning algorithms and the models.

I think given, you know, that we have this kind of like SPMD mesh style partitioning, we are already kind of designing our models in ways that fit it really well. So for example, when we want to scale up our model, one of the first dimensions we go to scale up is the internal hidden dimension.

Because there's some really nice properties of scaling up this dimension. It basically becomes like, kind of, you know, independent to some of the communication costs. It's really good when looking at the compute to memory operations on these, you know, like, uh, compute devices and stuff. So yeah, exactly. Like I think when we're even designing these models, we're like really setting dimensions such that it maps well into hardware.

Um, so it's almost like, you know, given that we have this model and data parallelism, we're actually designing models more for it. But I also think that there's a ton of new, interesting distributed algorithms and stuff like that, which makes designing models very interesting. Like I think one thing that's really cool is the Microsoft ZeRO partitioning, which also has some really nice implications for how to design and scale models and stuff.

So yeah, I think there's like, this is a very fruitful research direction. Um, if that, if that kind of answered your question, yeah, no, that was super helpful. Interesting. Yeah. Yeah, definitely. Like I'm very optimistic on the future of us, like designing the hardware, the model, the partitioning strategies altogether, because really to get it to work well, you kind of have to know about all three and like kind of, you know, intertwined the development of them.

Yeah. Yeah. That sounds awesome. Cool. Yeah. So just to summarize, it's like, yeah, so switch transformer is like a nice simplification over a mixture of experts. And we're seeing that we get really strong speed up improvements on pre-training over like a lot of the T5 models, which are very strong baselines.

We're seeing that we can, you know, efficiently distill the sparse models back to dense ones and, you know, get improved both pre-training and fine tuning through some of these newer techniques we talked about. And we're also seeing that the models are working on multilingual data and that we can, you know, now easily successfully train up to, you know, 1.6 trillion parameter models, which is pretty promising and, um, next slide.

And so we also wanted to go into two slides about some like newer work about actually using these kinds of models for computer vision, and actually also a little bit of how they can be used to actually do some level of like adaptive computation where not only now each input gets different weights, but also sometimes different inputs will have different amounts of compute applied to it.

And so there was some really great work of doing this out of the Google Zurich team. And yeah, there's just doing it for image classification and, you know, they're basically seeing a lot of the similar types of scaling properties where, you know, scaling up the number of experts and using sparsity allows them to get good performances on image classification.

Next slide. And interestingly, one of the things they do is like, as we talk about the capacity factor, so we were talking about values of like one, 1.25, 2.0, which means like at a value of 2.0, there's buffer for, you know, two tokens per expert, but they actually study it going less than one.

So that means that like at 0.5, that means there's only like room for half the number of tokens. And the nice part is, is that they did this for image classification. And also in images, there's just a lot of redundancy and they noticed that you can actually get really good performance by only allowing like, you know, up to one 10th of the parts of the image to be processed by a sparse layer.

So yeah, we think this is a really nice direction too, in terms of combining sparsity along with adaptive computation. And yeah. Thanks so much for having us. That's the talk. So thank you, Barrett and Erwin, for coming here. So I will just ask a bunch of questions, and then we can have, after the class, an open question panel for the students.

So one thing is: have you tried using more linear attention mechanisms, like Reformers and other such methods, to scale the computation? I personally haven't done this. Yeah. So, oh, you know, I guess we can maybe comment on how the cost coming from the attention maps isn't the dominant cost in these large transformers.

So you know, the motivation for using linear attention, like Performers, is that it reduces the quadratic cost of the attention maps, right. But so far, I mean, at least in typical NLP setups, like SuperGLUE, C4 and so on, as you scale the models, most of the memory comes from the model weights as opposed to the attention maps.

That's also because, you know, using very long context or sequence length doesn't prove that fruitful. And so, you know, just, you know, working with the vanilla self-attention mechanism is, is a very strong baseline already. Got it. Okay. So another question is like, do you think this like mechanism is even more scalable?

Like, can you go on and be like 10 trillion parameter models, stuff like that? Like what do you think? Yeah, definitely. I think, yeah, totally. I think, honestly, the, one of the biggest constraints is that like, you know, and this isn't even necessarily a constraint, it's just like, you have to fit the parameter somewhere and there's just limited storage on devices.

But if you get enough devices such that, you know, yeah, you can just partition the weights. It's like, yeah, I don't see anything stopping it. Got it. So what do you think, like, personally, is your, like, the thing, like, with the direction, like, like scaling of transformers will go into, like, will there be more like works that are trying to just like use this transformer, like mechanisms, like Mr.

Experts, or do you think there's like, you're going to be other things that the community needs? Yeah. I mean, I definitely think mixture of experts should find its way, or at least, you know, sparse players like switch transformer and stuff will definitely, I think, find their way into like the future of large models.

I think they really confer a lot of benefits, and they're also very good in high-throughput applications. So I think the one downside of sparsity is that if you look at the performance per model weight, they're always going to be worse than dense models.

So it's like, if you really are constrained on like, I want to design the best model I can to fit on as small of a device as I can, then they're probably not going to be the best solution because the sparse weights just aren't as good as just the dense weight that's being used for everything.

So I think it really depends on the application, but I'm very optimistic for when we're training these models during pre-training with lots of data parallelism, and then we're serving them in like medium to higher throughput examples. I feel like they could actually just be a pretty big win. So that that's kind of my thoughts on, on how I think sparsity will be used in terms of other things.

Yeah, I think, I don't know. There's a ton of exciting research, you know, from everything from, yeah, like a lot of the linear attention stuff, adaptive computation, new pre-training objectives, you know, yeah, it's hard to know what the future will look like, but yeah, a lot of exciting things to look forward to.

Great. Sounds good. Okay. So we can just now have like a round of student questions, so we'll just stop the recording. Okay. Okay. Great. Thank you.