2024 in Post-Transformer Architectures: State Space Models, RWKV [Latent Space LIVE! @ NeurIPS 2024]

00:00:08.620 |
So this is gonna be a little bit of a two-part presentation. 00:00:14.240 |
and I'll be joining UCSD as faculty in about a year. 00:00:42.900 |
and then afterwards, Eugene will tell us a little bit 00:00:47.440 |
and the latest frontier models in this space. 00:00:54.280 |
So this is probably a figure or something like this 00:01:00.920 |
we've seen models really scale up in parameter size, 00:01:03.640 |
and that's brought with it a bunch of new capabilities, 00:01:05.640 |
like the ability to talk to you and tell you sometimes 00:01:15.480 |
especially recently, is scaling in context length. 00:01:18.680 |
So this can mean just having more text inputs 00:01:28.000 |
image inputs to your models, or generating lots of outputs. 00:01:34.080 |
over the last few months or so is that we're seeing scaling, 00:01:37.680 |
not only during training time, but also during test time. 00:01:39.920 |
So this is the iconic image 00:01:45.280 |
Not only are we starting to scale train time compute, 00:01:47.840 |
but we're also starting to scale test time compute. 00:01:55.820 |
this graph on the right might look a little bit scary. 00:01:58.640 |
And one of the reasons is that the implications 00:02:09.960 |
bigger, bigger data centers, spending more flops? 00:02:13.320 |
Is this little DALL-E 3 "we need more flops" guy, 00:02:20.720 |
Or is there a better way, another path forward? 00:02:27.900 |
but for a lot less compute, a lot less flops. 00:02:30.160 |
And one of the things that we're gonna talk about today 00:02:33.300 |
is specifically looking at that core attention operator 00:02:44.100 |
but attention has compute that scales quadratically 00:02:50.020 |
like test time compute, and you want to spend a bunch 00:03:02.380 |
One of the questions that we're interested in is, 00:03:10.260 |
Can we scale and let's say N to the three halves 00:03:22.860 |
and ideas that have shown promise over the past few years 00:03:29.140 |
that this might actually be possible, 00:03:32.080 |
that you can actually get potentially the same quality 00:03:42.480 |
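As a rough back-of-the-envelope illustration (not from the talk), here is what that kind of sub-quadratic scaling would buy at long context; the sequence lengths below are arbitrary examples:

```python
# Rough arithmetic (hypothetical sequence lengths) showing what moving from
# N^2 to N^(3/2) scaling buys: the savings factor works out to sqrt(N).
for n in [8_192, 131_072, 1_048_576]:
    print(f"N={n:>9,}  N^2 / N^1.5 = {n**2 / n**1.5:,.0f}x less work")
```

At a million tokens of context, that ratio is roughly a thousandfold.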
that we're gonna look is we're gonna start to see how, 00:03:45.020 |
so this is a basic graph of just the past couple of years 00:03:48.220 |
of progress in perplexity, where that blue line 00:03:52.580 |
is your basic transformer, full dense attention. 00:03:55.100 |
And then the dots coming down are some of the methods 00:04:00.640 |
We're gonna turn the clock back all the way to 2020. 00:04:04.460 |
So this question of, can we make attention sub-quadratic? 00:04:09.180 |
Basically, as soon as we said, attention is all you need, 00:04:13.420 |
So we have this quadratic attention operator, 00:04:17.460 |
I'll briefly talk about why attention is quadratic. 00:04:19.860 |
And the basic thing that happens if you're not familiar 00:04:23.020 |
is that you have these inputs, these keys and queries, 00:04:27.740 |
this S matrix over here is that you're using, 00:04:33.860 |
So when I try to do something like upload a whole book 00:04:36.340 |
to Gemini, what happens beyond the, or maybe not Gemini, 00:04:39.620 |
'cause we don't necessarily know what the architecture is, 00:04:43.580 |
what happens behind the scenes is that it's gonna take 00:04:59.620 |
And what attention does in particular is the, 00:05:03.020 |
and then what attention, sorry, don't wanna, okay. 00:05:10.180 |
instead of always operating in this quadratic thing, 00:05:15.700 |
and then multiplies it by this values matrix. 00:05:17.660 |
So one of the key points to notice is that the output size is much smaller than that intermediate N-by-N matrix. 00:05:26.340 |
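To make the quadratic cost concrete, here is a minimal, non-causal sketch of the attention computation being described (dimensions are arbitrary; causal masking and multiple heads are omitted):

```python
import torch

def full_attention(q, k, v):
    # q, k, v: (batch, seq_len N, dim d)
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5  # S matrix: (batch, N, N) -- quadratic in N
    probs = torch.softmax(scores, dim=-1)
    return probs @ v                                       # output: (batch, N, d) -- only linear in N

out = full_attention(*(torch.randn(1, 1024, 64) for _ in range(3)))
```

The N-by-N score matrix is the part that blows up with context length, while the final output stays N-by-d.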
So one of the first things that folks tried to do 00:05:28.340 |
around 2020 is this thing called linear attention, 00:05:30.500 |
which is just noticing that if we take out this softmax 00:05:41.160 |
you actually never hit this quadratic bottleneck. 00:05:54.060 |
or try to approximate this overall attention computation. 00:05:57.620 |
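A minimal sketch of that reordering trick (non-causal, with one common choice of feature map; the exact feature map varies across papers):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    # Drop the softmax, apply a cheap elementwise feature map, and regroup:
    #   (phi(Q) phi(K)^T) V  ==  phi(Q) (phi(K)^T V)
    # The right-hand grouping never materializes the N x N matrix.
    q, k = F.elu(q) + 1, F.elu(k) + 1                    # phi(x) = elu(x) + 1, one common choice
    kv = k.transpose(-1, -2) @ v                          # (batch, d, d): independent of N
    normalizer = q @ k.sum(dim=1, keepdim=True).transpose(-1, -2) + 1e-6  # (batch, N, 1)
    return (q @ kv) / normalizer

out = linear_attention(*(torch.randn(1, 1024, 64) for _ in range(3)))
```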
But some of this work sort of started to hit a wall in 2020 00:06:05.600 |
Back then it was kind of hard to get good quality 00:06:11.620 |
The other one was actually hardware efficiency. 00:06:13.460 |
So this feature map that was just shown, simplified here, 00:06:18.260 |
actually ends up being quite computationally expensive 00:06:28.820 |
but also they're actually just wall clock slower. 00:06:30.740 |
So you kind of end up getting the worst of both worlds. 00:06:36.620 |
So that kind of sets the stage for four years ago. 00:06:46.260 |
But one of the works that started kicking off 00:06:48.340 |
this mini revolution in post-transformer architectures 00:06:54.680 |
So here the seminal work is one of ours from 2022. 00:06:59.500 |
And this piece of work really brought together a few ideas 00:07:03.420 |
from some long running research lines of work. 00:07:10.460 |
The first one was, and this is really one of the keys 00:07:28.360 |
with how we model dynamical systems in signal processing 00:07:32.660 |
and then using those ideas to model the inputs, 00:07:39.060 |
a transformer-like next token prediction architecture. 00:07:42.300 |
So some of those early state-space model papers 00:07:44.780 |
were looking at this relatively simple recurrent update 00:07:55.680 |
about how you should do that recurrent update 00:08:01.560 |
out of your hidden state, out of your sequence. 00:08:19.980 |
there was stuff in time series analysis. 00:08:24.420 |
You started to see the quality tick up in meaningful ways. 00:08:29.300 |
But the other key thing that was so influential 00:08:36.780 |
about how you can compute these things efficiently. 00:08:41.020 |
So if you go back to your machine learning 101 class, 00:08:46.020 |
is that they don't parallelize as well as attention, 00:08:51.060 |
you have to do this kind of sequential update 00:08:56.840 |
you can process all the tokens in parallel at one time. 00:09:04.220 |
you could take them and you could also formulate them 00:09:09.780 |
you could, instead of using a PyTorch conv1d operation, 00:09:20.820 |
that was relatively well optimized for modern hardware. 00:09:24.460 |
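A tiny numerical sketch (sizes are arbitrary) of those two equivalent views of a linear state space layer: the step-by-step recurrence and the convolution whose kernel is built from powers of A. In practice the convolution is applied with an FFT rather than this naive loop:

```python
import torch

N, T = 16, 128                       # state size, sequence length (hypothetical)
A = torch.randn(N, N) * 0.05         # small so powers of A stay stable
B, C = torch.randn(N, 1), torch.randn(1, N)
x = torch.randn(T)

# 1) Recurrent view: h_t = A h_{t-1} + B x_t,  y_t = C h_t
h = torch.zeros(N, 1)
y_rec = []
for t in range(T):
    h = A @ h + B * x[t]
    y_rec.append((C @ h).item())

# 2) Convolutional view: y_t = sum_k (C A^k B) x_{t-k}, i.e. a long 1D convolution
K = torch.stack([(C @ torch.matrix_power(A, k) @ B).squeeze() for k in range(T)])
y_conv = [sum(K[k] * x[t - k] for k in range(t + 1)).item() for t in range(T)]

print(torch.allclose(torch.tensor(y_rec), torch.tensor(y_conv), atol=1e-3))  # True
```

The recurrent form is what you run at inference; the convolutional form is what lets you train over all tokens in parallel on modern hardware.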
So those are really, I'd say the two key ideas in 2022 00:09:28.380 |
that started allowing these breakthroughs to happen 00:09:33.700 |
So these ideas about how to model in a principled way, 00:09:53.580 |
so afterwards, we started putting out some work 00:09:58.700 |
So just like we have flash attention for transformers, 00:10:05.620 |
oftentimes whenever you see a new architecture, 00:10:12.740 |
so that you can actually get wall clock speed up? 00:10:14.780 |
So by 2022, 2023, we were starting to have these models 00:10:24.980 |
where they were better than transformers in meaningful ways. 00:10:27.980 |
That being said, there was still sometimes a quality gap, 00:10:33.580 |
And because language is so core to what we do 00:10:45.940 |
so you have this recurrent state that you're keeping around 00:10:48.600 |
that just summarizes everything that came before, 00:10:53.620 |
one of the things that you really need to be able to do 00:11:04.800 |
in a line of work called H3, Hungry, Hungry Hippos, 00:11:15.580 |
So versions of these ideas have been around for decades. 00:11:29.420 |
and then you can see quality start to pick up. 00:11:35.940 |
this also takes the selection to the next level 00:11:47.620 |
but also you can actually make the ABCD matrices 00:11:54.860 |
which will allow you to even better select out 00:12:02.420 |
if you look at the bottom right of this figure, 00:12:03.980 |
there's this little triangle with a GPU SRAM, GPU HBM, 00:12:16.940 |
that it can be hardware efficient on modern hardware. 00:12:34.320 |
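A heavily simplified sketch of the selection idea (this is not the actual Mamba code: the real model uses a carefully parameterized diagonal A, a proper discretization, and a fused hardware-aware scan kernel, but the input-dependent B, C, and step size are the point here):

```python
import torch, torch.nn as nn

class TinySelectiveSSM(nn.Module):
    # Hypothetical, loop-based sketch of input-dependent ("selective") state updates.
    def __init__(self, d_model=32, d_state=16):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_state))        # diagonal, negative for stability
        self.to_B = nn.Linear(d_model, d_state)
        self.to_C = nn.Linear(d_model, d_state)
        self.to_dt = nn.Linear(d_model, 1)

    def forward(self, x):                                  # x: (T, d_model)
        h = torch.zeros(self.A.shape[0], x.shape[-1])      # (d_state, d_model)
        ys = []
        for t in range(x.shape[0]):
            dt = torch.nn.functional.softplus(self.to_dt(x[t]))  # input-dependent step size
            A_bar = torch.exp(dt * self.A)                        # discretized decay, (d_state,)
            B_t, C_t = self.to_B(x[t]), self.to_C(x[t])           # selection: depends on x[t]
            h = A_bar[:, None] * h + B_t[:, None] * x[t][None, :]
            ys.append(C_t @ h)                                    # read out, (d_model,)
        return torch.stack(ys)

y = TinySelectiveSSM()(torch.randn(64, 32))
```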
linear attention actually started to come back. 00:12:38.120 |
there's a model called BASED from Simran Arora 00:12:44.600 |
a more principled version of linear attention 00:12:54.600 |
combined that with a simple sliding window attention 00:12:57.200 |
and was starting to be able to expand the Pareto frontier 00:13:01.540 |
of how much data can you recall from your sequence 00:13:04.820 |
versus how small is your recurrent state size. 00:13:35.020 |
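For a sense of what "a more principled version of linear attention" looks like, here is a sketch of a second-order Taylor feature map of the kind BASED builds on (exact scaling constants in the paper may differ): phi(q)·phi(k) approximates exp(q·k), so it can be dropped into the linear-attention grouping shown earlier and interleaved with small sliding-window attention layers.

```python
import torch

def taylor_feature_map(x):
    # phi(x) = [1, x, vec(x x^T)/sqrt(2)], so phi(q).phi(k) ~= 1 + q.k + (q.k)^2 / 2,
    # a 2nd-order Taylor approximation of exp(q.k).
    outer = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2) / 2 ** 0.5
    return torch.cat([torch.ones(*x.shape[:-1], 1), x, outer], dim=-1)

q, k = torch.randn(4, 8) * 0.1, torch.randn(4, 8) * 0.1
approx = taylor_feature_map(q) @ taylor_feature_map(k).T
exact = torch.exp(q @ k.T)
print((approx - exact).abs().max())   # small for small-norm inputs
```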
So this was a really cool paper called Just Read Twice 00:13:45.700 |
that they can sometimes have unfair advantages 00:14:00.060 |
and then you're gonna ask some question about it. 00:14:03.060 |
One problem you might imagine for a recurrent model 00:14:11.580 |
and you're trying to ask about some really niche thing. 00:14:14.900 |
You can imagine it might be hard for the model 00:14:17.540 |
what information to put into the hidden state. 00:14:26.940 |
write down the document, write down the question, 00:14:47.140 |
of the more efficient architectures that we're having here. 00:14:50.680 |
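A toy illustration (the exact prompt format in the paper may differ) of the "just read twice" idea: repeat the document so that, on the second pass, the recurrent model already knows which details it needs to keep in its state.

```python
def just_read_twice_prompt(document: str, question: str) -> str:
    # Hypothetical template, for illustration only.
    return (
        f"Document:\n{document}\n\n"
        f"Question: {question}\n\n"
        f"Here is the document again:\n{document}\n\n"
        f"Question: {question}\nAnswer:"
    )

print(just_read_twice_prompt(
    "The meeting was moved to Room 204 at 3pm on Thursday.",
    "Which room is the meeting in?",
))
```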
So one of the other, I think, influential ideas 00:14:54.580 |
if you change the fundamental compute capabilities 00:15:00.260 |
you can actually start to query it at test time differently. 00:15:04.260 |
goes back to those slides on test time compute. 00:15:09.020 |
test time compute for big transformer models, 00:15:12.340 |
I think potentially a really interesting research question 00:15:35.800 |
instead, take ideas that we know from other fields, 00:15:44.760 |
Another key idea throughout all these lines of work 00:15:47.240 |
is you really want hardware and kernel support from day one. 00:15:51.160 |
So even if your model is theoretically more efficient, 00:15:54.960 |
if somebody goes and runs it and it's two times slower, 00:16:03.520 |
So you want to be designing your architectures 00:16:13.840 |
is just making sure that you encode different ways 00:16:18.720 |
and really focus on that as a key decider of quality. 00:16:22.200 |
And finally, I think one of the emerging new things 00:16:29.560 |
is what are the right test time paradigms for these models? 00:16:32.960 |
How do they change relative to what you might do 00:16:41.880 |
So I've labeled this slide where we are yesterday 00:16:45.440 |
because Eugene is gonna talk about some new models 00:16:49.840 |
But as of yesterday, some of the really cool results 00:16:52.080 |
out of these efficient alternative models were, 00:17:08.720 |
put out this new diffusion model called SANA recently 00:17:21.800 |
and then that lets you scale to much larger images, 00:17:30.720 |
And one thing that I don't think anybody would have called 00:17:36.320 |
is that one of those gated SSMs, gated state space models 00:17:56.920 |
where these non-transformer, post-transformer architectures 00:18:26.920 |
is what's the difference between RWKV and state space? 00:18:30.200 |
So I think one of the key things to really understand, 00:18:33.560 |
right, the difference between the two groups, right, 00:18:38.680 |
an open-source, random-internet-meets-academia 00:18:45.040 |
but we basically look at RNNs and linear attention 00:19:02.600 |
And we do all this actively in Discord, GitHub, et cetera. 00:19:17.360 |
Great, our H-index is now three, apparently. 00:19:35.000 |
how does RWKV handle its own attention mechanism 00:19:38.520 |
and achieve the same goals of like O(n) compute, 00:19:41.600 |
respectively, in the context of our overall goal 00:19:56.120 |
And our goal is to train on even 200 languages 00:20:00.040 |
But at the same time, we work on this architecture 00:20:08.600 |
So how did RWKV break the dependency of the LSTM token flow? 00:20:13.600 |
Because I think to understand architecture, right, 00:20:16.120 |
it's probably easier to understand it from the RNN lens, 00:20:21.680 |
Whereas state space kind of, like, tried to start anew. 00:20:28.200 |
And, AKA, this is our version of linear attention. 00:20:31.320 |
So to take a step back, all foundation models, 00:20:37.440 |
at a very high level, right, comes in a token, 00:20:45.800 |
whether KV cache or RNN states or RWKV states, 00:20:50.360 |
and outputs an embedding, a layer norm, and so on, 00:20:52.680 |
and we just take more layers and more embeddings, 00:21:07.000 |
the general idea is that you have the embedding information 00:21:09.360 |
from all the way up, and you take that information 00:21:13.920 |
and then you process it as part of your LSTM layers. 00:21:34.360 |
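A bare-bones sketch of that high-level picture (using an off-the-shelf GRU cell purely as a stand-in for whatever the per-layer mixer actually is): a token comes in, each layer updates its own carried state, and the result is layer-normed and passed up the stack.

```python
import torch, torch.nn as nn

class TinyRecurrentLM(nn.Module):
    # Hypothetical sketch: per-layer recurrent state is the RNN analogue of a KV cache.
    def __init__(self, vocab=256, d=64, layers=4):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.cells = nn.ModuleList(nn.GRUCell(d, d) for _ in range(layers))
        self.norms = nn.ModuleList(nn.LayerNorm(d) for _ in range(layers))
        self.head = nn.Linear(d, vocab)

    def step(self, token, states):
        x = self.emb(token)
        new_states = []
        for cell, norm, h in zip(self.cells, self.norms, states):
            h = cell(x, h)
            x = norm(x + h)          # residual + layer norm, then feed the next layer
            new_states.append(h)
        return self.head(x), new_states

model = TinyRecurrentLM()
states = [torch.zeros(1, 64) for _ in range(4)]
logits, states = model.step(torch.tensor([5]), states)
```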
So you can have an H100, and you can't even use 1% of it. 00:21:38.280 |
So that's kind of why RNNs didn't really take off 00:21:42.640 |
like billions of parameters when it comes to training. 00:21:56.360 |
It trained, it sucked, but it kind of worked. 00:22:02.800 |
because the loss was crap, but how do we improve that? 00:22:12.080 |
you can actually get your GPU saturated quickly 00:22:24.200 |
you start to cascade your compute all the way 00:22:28.760 |
So we worked on it and we started going along 00:22:34.960 |
this general architecture where we can cascade 00:22:38.040 |
and be highly efficient with our architecture, 00:22:45.680 |
In fact, if you ask me to explain some things 00:22:48.920 |
in the paper, right, officially in the paper, 00:22:51.160 |
I'll say we had this idea and we wrote it this way. 00:22:55.760 |
we tested it, it worked, and then we rationalized it. 00:23:03.200 |
we generally have two major blocks that we do. 00:23:08.080 |
And TimeMix generally handles long-term memory states 00:23:12.520 |
where essentially where we apply the matrix multiplication 00:23:17.520 |
and SILU activation functions into processing 00:23:22.200 |
I'm oversimplifying it because this calculation 00:23:25.120 |
changed every version and we have version seven right now. 00:23:36.680 |
or the token before it, 'cause there's a shift 00:23:41.480 |
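A heavily simplified sketch of that token-shift idea (this is not the full TimeMix computation, which changes between RWKV versions): each position works on a learned blend of the current token's channels and the previous token's channels.

```python
import torch, torch.nn as nn

class TokenShiftMix(nn.Module):
    # Hypothetical, stripped-down illustration of RWKV-style token shift + mixing.
    def __init__(self, d=64):
        super().__init__()
        self.mu = nn.Parameter(torch.rand(d))   # per-channel interpolation weights
        self.proj = nn.Linear(d, d)

    def forward(self, x):                        # x: (T, d)
        x_prev = torch.cat([torch.zeros(1, x.shape[-1]), x[:-1]], dim=0)  # shift by one token
        mixed = self.mu * x + (1 - self.mu) * x_prev
        return torch.nn.functional.silu(self.proj(mixed))

y = TokenShiftMix()(torch.randn(32, 64))
```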
I don't really want to go too much into the papers itself 00:23:46.240 |
Basically, "RWKV: Reinventing RNNs for the Transformer Era," 00:23:52.040 |
This is the updated version five, version six. 00:23:54.680 |
And GoldFinch is our hybrid model, respectively. 00:24:08.480 |
all our architectures are codenamed after a bird. 00:24:21.760 |
and to be clear, most of this research is done 00:24:28.000 |
that was his experiment budget for a single researcher. 00:24:40.120 |
was how do we convert transformer models instead? 00:24:43.440 |
Because someone already paid that million dollars 00:24:46.200 |
onto training, so why don't we take advantage 00:24:49.560 |
And I believe Together AI worked on the LoLCATs 00:24:59.920 |
And that led to Q-RWKV6, which we just dropped today, 00:25:07.400 |
where we took the Qwen 32B Instruct model, 00:25:24.440 |
But once we do that, we train the RWKV layer. 00:25:28.600 |
What's important is that the feedforward layer needs to be frozen, 00:25:41.040 |
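A hypothetical sketch (class and attribute names here are stand-ins, not a real API) of the conversion recipe being described: swap each attention block for an RWKV-style mixer, freeze the feedforward layers and everything else inherited from the donor transformer, and train only the new mixers.

```python
import torch.nn as nn

def convert_and_freeze(model: nn.Module, make_mixer) -> nn.Module:
    # `model.layers`, `layer.attention`, and `layer.hidden_size` are illustrative
    # stand-ins for whatever the donor transformer actually exposes.
    for layer in model.layers:
        layer.attention = make_mixer(layer.hidden_size)   # replace attention with the new mixer
    for name, param in model.named_parameters():
        param.requires_grad = ".attention." in name       # only freshly inserted mixers train
    return model

# usage (names are illustrative):
# model = convert_and_freeze(PretrainedTransformer.load("qwen-32b-instruct"), RWKVTimeMix)
# optimizer = torch.optim.AdamW(p for p in model.parameters() if p.requires_grad)
```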
The end result, surprisingly, and to be honest, 00:25:46.760 |
which ended up releasing the model on the same day, 00:25:49.240 |
was that with just a few hours of training on two nodes, 00:26:06.640 |
who kind of leads most of our research coordination, 00:26:28.680 |
we were essentially like Frankensteining this thing, 00:26:31.440 |
and we did brain damage to the feedforward network layer 00:26:40.760 |
We didn't even spend three days training this, 00:27:01.000 |
We are already planning to do our version seven 00:27:08.720 |
And the other thing that is uncomfortable to say 00:27:12.080 |
is that, because we are doing right now the SMPB, 00:27:14.920 |
is that if this scales correctly to 128k context length, 00:27:30.360 |
That means if this works and the benchmark matches it, 00:27:39.240 |
And then, sorry, can someone give us more GPUs, 00:27:41.560 |
because we don't need the VRAM for super long context, sadly. 00:27:54.320 |
I don't think it's going to be exclusive to RWKV, 00:28:18.520 |
with a state-based model, be it RWKV or state space, 00:28:18.520 |
plus four teams, that a lot more needs to be done. 00:28:42.760 |
But these are things that excite me, essentially, 00:29:12.800 |
is continued hardware-model co-design for these models. 00:29:12.800 |
So one of the things that we've put out recently 00:29:25.320 |
And one of the things that we found frustrating 00:29:27.760 |
is every time that we built one of these new architectures, 00:29:30.280 |
and I'm sure you had the exact same experience, 00:29:32.680 |
we'd have to go and spend two months in CUDA land, 00:29:37.640 |
And if we decided to change one thing in PyTorch, 00:29:45.000 |
So one of our goals with a library like ThunderKittens, 00:29:48.440 |
so we just broke down what are the key principles, 00:30:05.840 |
So you really want your operation to be able to split 00:30:08.760 |
into a relatively small matrix-matrix multiply operation. 00:30:13.560 |
So, like, multiplying two 64 by 64 matrices, for example. 00:30:25.880 |
how you set the update function. 00:30:36.280 |
should not be a float, but it should be a matrix, 00:30:38.800 |
and everything should just be matrix compute. 00:30:41.280 |
And we've been using that to try to both re-implement 00:30:44.160 |
some existing architectures and also start to design 00:30:48.880 |
with this core, with a tensor core primitive in mind. 00:30:52.720 |
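To illustrate the tile-centric principle (in plain PyTorch, not the ThunderKittens API), here is a matrix multiply written as nothing but small 64-by-64 fragment multiplies, the shape tensor cores are built to consume; the dimensions are assumed to divide evenly by the tile size.

```python
import torch

def tiled_matmul(a, b, tile=64):
    # Decompose a large matmul into many small tile-by-tile matrix multiplies.
    M, K = a.shape
    _, N = b.shape
    out = torch.zeros(M, N)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = torch.zeros(tile, tile)
            for k in range(0, K, tile):
                acc += a[i:i + tile, k:k + tile] @ b[k:k + tile, j:j + tile]  # 64x64 fragment
            out[i:i + tile, j:j + tile] = acc
    return out

a, b = torch.randn(256, 256), torch.randn(256, 256)
assert torch.allclose(tiled_matmul(a, b), a @ b, atol=1e-2)
```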
Another thing that we're, at least I'm excited about, 00:31:03.640 |
But if you've been paying attention to Twitter, 00:31:06.000 |
there's been a bunch of new next generation models 00:31:16.080 |
that are supported by your mouse and your keyboard, 00:31:41.320 |
or some of these new video generation models that came out. 00:31:43.680 |
So Sora came out, I don't know, two days ago now, 00:31:51.040 |
So that's probably a quadratic attention operation 00:31:55.120 |
What if we could remove that and get the same quality, 00:32:00.320 |
Or some of the demos that we saw from Paige earlier today. 00:32:04.040 |
If I have a super long conversation with my Gemini bot, 00:32:14.120 |
I mean, maybe you don't for personal reasons, 00:32:26.040 |
I think we were supposed to have some hot takes, 00:32:28.480 |
but I honestly don't remember what our hot takes were. 00:32:35.480 |
- I think the big one on Twitter that we saw, 00:32:56.960 |
I'll say I found it was a little bit challenging 00:33:02.480 |
because we had this experience over and over again 00:33:06.240 |
where you could have an embedding model of any quality. 00:33:10.760 |
So you could have a really, really bad embedding model 00:33:25.360 |
I know it doesn't actually answer the question, but. 00:33:29.600 |
So I think a lot of folks are like extremely excited 00:33:41.760 |
we just mean a different kind of infinite context 00:33:48.480 |
So think of it more along the lines of the human. 00:33:51.160 |
Like, I don't remember what I ate for breakfast 00:33:57.440 |
And we humans are not quadratic transformers. 00:34:01.600 |
If we did, if, let's say, we increased our brain size 00:34:06.360 |
we would have exploded by the time we were five years old 00:34:09.440 |
And I think basically fundamentally for us, right, 00:34:13.160 |
be it, regardless of whether it's RWKV, 00:34:18.560 |
our general idea is that instead of that expanding state, 00:34:34.120 |
Like, RWKV is running at 40 megabytes for a state. 00:34:39.120 |
Its future version might run into 400 megabytes. 00:34:49.280 |
It's just that I guess we are all more inefficient about it. 00:34:53.560 |
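Some rough, illustrative arithmetic (the transformer configuration below is hypothetical, not a measurement of any specific model; the ~40 MB figure is the one quoted in the talk) comparing how a KV cache grows with context versus a fixed recurrent state:

```python
# Illustrative KV-cache sizing: 32 layers, 8 KV heads of dim 128, fp16 (2 bytes).
layers, kv_heads, head_dim, bytes_per = 32, 8, 128, 2

def kv_cache_mb(tokens):
    return 2 * layers * kv_heads * head_dim * tokens * bytes_per / 1e6  # K and V

for t in [8_192, 131_072, 1_048_576]:
    print(f"{t:>9,} tokens: KV cache ~{kv_cache_mb(t):>9,.0f} MB vs fixed ~40 MB state")
```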
and that's kind of like the work we are doing 00:34:57.760 |
And that's where the models will start differing: 00:35:06.280 |
some element of it, right, but it may not be the same, right? 00:35:09.920 |
And it's like, hmm, I can't remember that article. 00:35:16.360 |
when we can't remember the article in a company, 00:35:19.800 |
- Yeah, I think something that would be really interesting 00:35:25.680 |
so right now the one intuition about language models 00:35:33.640 |
And this intuition comes from the observation 00:35:35.840 |
that if you take a really small language model, 00:35:39.800 |
or it kind of has like the style of conversation 00:35:49.640 |
about things that it knows or that it can do. 00:35:52.960 |
But that points to all those weights that we're spending, 00:35:57.360 |
all that SGD that we're spending to train these models 00:36:04.720 |
So I think one thing that would be really interesting 00:36:06.560 |
is if we could actually have some sort of outside data store 00:36:13.600 |
that maybe has some sort of gradient descent in it, 00:36:21.600 |
And then maybe you could edit it, delete facts, 00:36:23.680 |
change who's president so that it doesn't get lost. 00:36:28.440 |
- Can we open up Q&A and hot takes to the audience? 00:36:43.320 |
who's throwing in 2 million token questions, hot takes? 00:36:43.320 |
- The who's throwing in 2 million token question 00:36:52.400 |
So I actually, I was gonna offer that as a hot take. 00:37:06.680 |
But I think one of the, so I think for both of us, 00:37:12.960 |
was just from the first principle of questions 00:37:18.920 |
Clearly intelligence doesn't need to be quadratic. 00:37:23.440 |
You know, since then it's kind of turned into a race, 00:37:32.560 |
Nobody is actually putting in a 2 million context prompt 00:37:37.120 |
And, you know, if they are, maybe we can go, you know, 00:37:41.400 |
design a better model to do that particular thing. 00:37:51.840 |
How many of you remember the news of Google Gemini 00:38:24.560 |
because I think the big labs may have a bigger role in this 00:38:50.880 |
reuse the VRAM consumption in the training time space. 00:39:01.000 |
But then putting it back to another paradigm, right, 00:39:08.760 |
might be actually pushing that direction downwards. 00:39:14.520 |
is that if, let's say you have a super big 400B model, 00:39:28.200 |
and this is even for transformer or non-transformer, right, 00:39:31.080 |
will take less resources than that 400B model, 00:39:35.920 |
even if it did double the amount of thinking. 00:39:39.320 |
and we're still all trying to figure this out, 00:39:50.520 |
to just reason it out over larger and larger context length. 00:40:08.560 |
where you run on a much longer context length 00:40:20.120 |
I think you guys probably had tweets along these lines too. 00:40:31.560 |
And at the very least it won't, like, error out or crash on you. 00:40:31.560 |
There's another question of whether it can actually use 00:40:44.600 |
and architectures ran faster than other research 00:40:56.080 |
Can we actually build some benchmarks for that, 00:40:59.720 |
and then ask the question, can the models do it? 00:41:05.000 |
Yeah, I think that if I were to turn back the clock to 2022, 00:41:11.200 |
which would have been to actually get some long context 00:41:11.200 |
as we started pushing context length on all these models. 00:41:25.640 |
and the model needs to be able to learn inside. 00:41:28.880 |
I think this also fits the state space model, 00:41:36.240 |
is that the model doesn't suddenly become crazy 00:41:38.280 |
when you go past the 8K training context length 00:41:45.520 |
It's still able to run, it's still able to rationalize. 00:41:50.000 |
But some of these things are still there in latent memory. 00:41:53.120 |
Some of these things are still somewhat there. 00:41:54.520 |
That's the whole point of why reading twice works, 00:41:58.680 |
And one of the biggest push in this direction 00:42:05.920 |
where they use this architecture for time series data, 00:42:09.640 |
So you're not asking what was the weather five days ago. 00:42:18.600 |
as long as this earth and the computer will keep running. 00:42:21.320 |
So, and they found that it is better than existing, 00:42:32.320 |
I'm quite sure there are people with larger models. 00:42:33.920 |
So there are things that in this case, right,