[Paper Club] GMT20251112 200413 Recording gallery 3440x1440

00:00:00.560 |
well yes i would but i was waiting for a screen share to start oh sorry it's okay slight editing 00:00:08.640 |
um okay screen is sharing thing is recording uh so we have a few things that came out 00:00:18.240 |
basically, I think about a week ago Kimi dropped Kimi Linear, which is a smaller model with different 00:00:27.520 |
hybrid attention stuff; Ted will go over that in a bit. And then this week Kimi K2 Thinking came out. 00:00:33.520 |
It's basically an extension of Kimi K2, but it now has reasoning and thinking, and it's pretty cool 00:00:39.440 |
because it basically popped the benchmarks in every stat, so big, big jump. It's open weight like 00:00:46.960 |
before. The notable results were on things like Humanity's Last Exam. I'll give a quick high-level overview before 00:00:54.560 |
we dive deep. Honestly, this Twitter post by Artificial Analysis is pretty good. So: open 00:01:01.840 |
weights, they use INT4 this time instead of FP8 so it's better quantized, they doubled the context length, it's 00:01:09.280 |
a lot more sparse, it's really good on Humanity's Last Exam, and similar to gpt-oss it's now trained to do 00:01:17.040 |
tool calls and reasoning. Yeah, it's very state-of-the-art. It's still a trillion parameters, it's 00:01:22.400 |
still huge. It's very, very verbose though, so there are two endpoints, a fast one and a regular one. But 00:01:29.360 |
one thing that they note here that you wouldn't otherwise see is that it's like the most expensive, 00:01:33.840 |
the most verbose: the thinking uses the highest number of tokens ever used in their 00:01:40.800 |
eval harness. In their eval harness they basically test the model across all benchmarks. It's nice that 00:01:46.800 |
they actually shipped INT4 and they tested INT4, but it's very, very verbose, it thinks a lot, 00:01:54.080 |
but that being said, it's still very good. There's a heavy version of it that uses even more tokens 00:02:00.400 |
and does even better on benchmarks. It's very, very state-of-the-art, so like on Humanity's 00:02:07.280 |
Last Exam it beats GPT-5 High and beats Grok 4; on other stuff it's pretty up there and on par. I of 00:02:15.200 |
course used GPT-5 to give a bit of a summary, and I think it honestly does pretty well since there's no 00:02:21.680 |
paper. What's more interesting is that there's some commentary: Nathan from Interconnects did a 00:02:28.480 |
pretty good "what are the five things that matter about it", someone posted something about their 00:02:34.640 |
inference on Reddit, and there's a little Q&A session about quantization. So let's just dig in pretty fast. So, 00:02:42.000 |
K2 Thinking: I don't know if anyone has thoughts, whether they actually use it or whatnot. The other aspect of 00:02:48.720 |
this is that their servers are super on fire, so they talk about the speeds that they get here. 00:02:58.800 |
Performance is good, here's how it performs, but basically they reported that the base 00:03:05.520 |
endpoint is very slow, getting about eight output tokens per second, while turbo is slightly faster. 00:03:10.080 |
That's not the worst case, right? It's still open weights, other providers will support it, 00:03:16.320 |
so, you know, it's slow, but Baseten is doing it pretty fast and still pretty cheap. 00:03:23.040 |
But will many people use it? Who knows. It's still MIT licensed; there are two clauses 00:03:28.960 |
in there, I think: if you make more than 20 mil a month or you have like a bunch of users, you have 00:03:34.240 |
to prominently say you're using Kimi K2. So yeah, now it reasons, it can plan and tool call and stuff between its 00:03:43.280 |
thinking. It's trained with quantization-aware training in INT4 during post-training, we'll talk 00:03:49.680 |
about that a bit later; that really helps speed and RL. They doubled the context length, which is pretty 00:03:55.280 |
exciting. Yeah, it's quite state-of-the-art stuff. Some stuff is slightly behind, but in 00:04:02.000 |
general it's quite good. It's like a little midpoint where open source was able to take quite a win 00:04:07.520 |
again. With it being INT4 compared to K2, we'll talk about quantization-aware training and stuff. 00:04:16.720 |
Someone's asking in the chat about it being INT4: the last K2 was FP8, I believe, 00:04:26.240 |
so it's quantized basically to take half the memory. So there's a deployment guide with vLLM 00:04:33.280 |
somewhere here. Deployment: basically now you can run it on a node, it's optimized for older hardware, 00:04:40.880 |
it's not like FP4. So, you know, cool stuff like that; it's basically a lot faster 00:04:47.280 |
and better um here's the clause 100 million monthly active users or 20 million monthly revenue 00:04:52.720 |
what else there's the official blog post which you can go through it's mostly just evals and hype 00:04:58.400 |
there's not much here agentic reasoning oh my mouse is frozen okay we're good we're good um what else 00:05:07.360 |
The API is there; it's slow on Kimi, but other people have it. It builds a lot on the old K2, 00:05:14.320 |
it's more sparse, there's a deployment guide. Let's see, what else? Yeah, it can reason and tool call a lot, 00:05:21.440 |
between 200 to 300 sequential tool calls. INT4 through quantization-aware training, 00:05:29.760 |
we'll talk about that through the Reddit post in a bit. Longer context. With this, similar to how 00:05:35.760 |
OpenAI had Harmony, which they open-sourced for gpt-oss, you want to make sure that you follow their 00:05:42.480 |
format parsing so you can actually properly parse these tool calls; there's like a whole section 00:05:48.080 |
on this i don't know if anyone here is actually renting a node to deploy it most of us will just 00:05:52.800 |
consume and use apis and you know that's that's pretty much the norm right now uh architecture wise 00:05:58.800 |
still a trillion parameters, 32B active, 384 experts, top-8, one shared expert, and the context length 00:06:06.800 |
has doubled. So, at 32B active this thing can get fast, it's just that I'm sure Kimi's servers are 00:06:12.400 |
quite cooked and on fire right now. "Sparse" in this context: there's a sparsity ratio, the 00:06:18.880 |
number of active versus total experts, and this is something that Kimi has really been pushing. 00:06:25.440 |
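A quick back-of-the-envelope with the numbers just mentioned (roughly 1T total and 32B active, 384 experts with top-8 routing; treat these as approximate, the exact config is on the model card):

# Rough sparsity arithmetic for K2 Thinking, using the figures quoted above.
total_params = 1.0e12       # ~1 trillion total parameters
active_params = 32e9        # ~32B active per token
total_experts = 384
active_experts = 8          # top-8 routing (plus one shared expert)

expert_sparsity = total_experts / active_experts   # 48x more experts than are used per token
active_fraction = active_params / total_params     # ~3.2% of the weights touched per token

print(f"expert sparsity ratio: {expert_sparsity:.0f}x")
print(f"fraction of params active per token: {active_fraction:.1%}")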
Is the sparsity ratio going down or is it going up? I thought it's more sparse than before, am I mistaken 00:06:32.000 |
on this? We'll double-check that in a sec, but I thought it's more sparse than 00:06:38.400 |
before; maybe someone else can check. Benchmarks: I don't know if there's 00:06:43.680 |
much of a point in looking at benchmarks; the interesting thing to note is that there is a heavy 00:06:47.760 |
equivalent. The benchmarks on Hugging Face are laid out quite well, honestly, depending on what you 00:06:55.280 |
care about one i don't know if many people are actually going to end up using it it is quite a bit 00:07:00.080 |
cheaper. But if you care about Humanity's Last Exam, it's pretty state-of-the-art. Artificial 00:07:06.720 |
Analysis has their aggregated benchmark, which still shows it right there after OpenAI. 00:07:12.400 |
the only things to be aware of that don't show up in benchmarks is that it is very very verbose so 00:07:19.840 |
uses a lot a lot of tokens to think and that's that's not optimal how to use it we don't need to know 00:07:26.320 |
any of this. Design and training notes: MoE with MLA, a very sparse MoE, 32B active. It 00:07:35.600 |
still uses MuonClip; they didn't share that much more. INT4 quantization-aware training: 00:07:40.800 |
they talk about this a bit in the system report and the Reddit post, we'll go over that. They have caching, they 00:07:46.560 |
have standard, they have a turbo. Okay, I think that's enough of a high level of what came out. Nathan 00:07:52.960 |
from Interconnects put out a pretty good post about how these releases come out from 00:08:00.800 |
Chinese labs: benchmarks first, user behavior second, China on the rise, so basically they release stuff after 00:08:09.040 |
they see people in the US doing it. This is new: the interleaved thinking with tool 00:08:14.640 |
calls is really good for agents, and not many open models did this except for gpt-oss, so 00:08:20.960 |
it's the first one after that, I think. Pressure on American labs, of course. I don't think there's 00:08:26.880 |
too much more to go over through Artificial Analysis. Okay, so what is this whole "quantization 00:08:33.760 |
is not a compromise, it's the next paradigm"? An infra engineer shares insights on why this choice matters. 00:08:41.280 |
Key idea: quantization is no longer a trade-off, it's actually more beneficial. So, in 00:08:48.720 |
traditional inference, there are two goals of optimization, right? High throughput: 00:08:54.240 |
as you quantize, you can get better utilization of your GPU. This mattered a lot for early 00:09:00.000 |
dense models on local hardware; whether you have a Mac or like a 3090, 4090, 5090, you could do a 00:09:07.520 |
four-bit quant on a 32B, fit it in RAM, and use it. And then of course lower latency: 00:09:16.240 |
the more quantized it is, the lower the latency you can get. So for their case, with such a sparse MoE, 00:09:24.320 |
their decoding is actually memory bound: the smaller the model weights, the faster it goes. So 00:09:31.360 |
previously, with FP8, it took about a terabyte of memory, right? They were at the 00:09:36.960 |
limit of what a single high-speed-interconnect GPU node can handle. Now they have weights at four 00:09:45.040 |
bits and activations at 16 bits. It's an interesting type of quantization: latency drops quite a bit, 00:09:53.760 |
it maintains quality, and it's in a sense lossless. 00:10:00.720 |
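To make the memory point concrete, a rough sketch (assuming ~1T parameters and ignoring activations, KV cache, and per-group scale overhead):

# Approximate weight-memory footprint of a ~1T-parameter model at different precisions.
params = 1.0e12

bytes_fp8 = params * 1.0     # FP8: 1 byte per weight  -> ~1 TB, roughly a full high-end node
bytes_int4 = params * 0.5    # INT4: 4 bits per weight -> ~0.5 TB, about half the memory

print(f"FP8 weights : {bytes_fp8 / 1e12:.1f} TB")
print(f"INT4 weights: {bytes_int4 / 1e12:.1f} TB")
# W4A16 means only the weights are 4-bit; activations stay in 16-bit (BF16/FP16) during compute.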
So why quantization-aware training over post-training quantization? Post-training quantization worked fine for shorter generations, but it didn't work over long chains. So 00:10:06.880 |
another big thing that they did here is, one, they doubled the context, and two, they do 00:10:12.160 |
tool-call reasoning, right? So this thing is very verbose, it's more verbose in thinking than any other 00:10:18.800 |
model, and it does tool calls interleaved in its thinking. And yeah, post-training quantization didn't work: 00:10:26.960 |
errors basically accumulated and then degraded precision over long contexts, but quantization-aware 00:10:34.560 |
training didn't have that problem. So how does it work? Dependence on calibration and expertise: these 00:10:40.480 |
are issues with post-training quantization, so K2 Thinking adopted quantization-aware training for minimal 00:10:46.320 |
loss and more stable long-context reasoning. How it works: it uses weight-only quantization-aware training, 00:10:52.400 |
i.e. fake quantization plus a straight-through estimator; the little details are here. 00:10:59.280 |
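A minimal sketch of what weight-only fake quantization with a straight-through estimator looks like in PyTorch. This is illustrative only, not Kimi's actual training code; the symmetric INT4 scheme and the group size of 32 are assumptions based on the numbers discussed here.

import torch

def fake_quant_int4(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Weight-only symmetric INT4 fake quantization with a straight-through estimator.

    Forward: weights are snapped to 16 integer levels, with one scale shared per
    group of 32 values. Backward: the rounding is treated as identity, so gradients
    flow straight through to the full-precision master weights.
    """
    orig_shape = w.shape
    groups = w.reshape(-1, group_size)
    scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7.0   # INT4 range [-8, 7]
    q = torch.clamp(torch.round(groups / scale), -8, 7)
    dq = (q * scale).reshape(orig_shape)
    # Straight-through estimator: forward returns dq, backward sees identity w.r.t. w.
    return w + (dq - w).detach()

# Usage inside a layer's forward pass (sketch):
w = torch.randn(4096, 4096, requires_grad=True)
x = torch.randn(8, 4096)
y = x @ fake_quant_int4(w).t()   # the model trains against the quantized weights
y.sum().backward()               # gradients reach the full-precision w via the STE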
INT4 is the hidden advantage in RL. A few people mentioned this; it's kind of what all the blog posts and tweets about it 00:11:04.880 |
show. INT4 doesn't just speed up inference, it accelerates RL training itself. If you think 00:11:11.280 |
about RL and all this verification, you have a lot of rollouts being done, and if you can speed 00:11:17.840 |
those up you're essentially accelerating all your RL in general too. RL rollouts suffer from long-tail 00:11:24.800 |
inefficiency, but this low-latency profile makes it much faster, and you probably have less overhead 00:11:31.840 |
where you're off-policy, per se, as rollouts are syncing and new weights are being updated. This 00:11:37.600 |
is pretty interesting, right? So once again, the title of this is "quantization is not a compromise, 00:11:42.640 |
it's the next paradigm". They basically say that in practice each RL iteration runs 10 to 20 percent 00:11:51.200 |
faster end to end with this INT4 quantization-aware training. Moreover, quantized 00:11:59.760 |
RL brings stability: the smaller representational space reduces accumulation error, improving learning robustness. 00:12:05.600 |
So why INT4 versus MXFP4 or any of these other fancier quantization formats? Basically, all 00:12:14.640 |
this new fancy stuff like FP4 is new to Blackwell, but it won't work with older hardware. So more 00:12:24.080 |
generality is what INT4 gives you: INT4 can work across older hardware, and that's a trade-off 00:12:29.680 |
they're willing to make. At a quantization scale of 1-to-32, INT4 matches the FP4 formats; it's 00:12:37.760 |
more hardware-adaptable even though it's not the latest fancy stuff, because it works on non-Blackwell hardware too. 00:12:43.200 |
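A toy illustration of why plain INT4 runs on pre-Blackwell GPUs: the 4-bit weights just get dequantized to 16-bit before an ordinary matmul (W4A16). The 1-in-32 scale granularity below mirrors the block size mentioned above; the storage scheme itself is simplified for illustration (real kernels pack two 4-bit codes per byte).

import numpy as np

GROUP = 32  # one scale per 32 weights, matching the 1:32 granularity discussed above

def quantize_int4(w: np.ndarray):
    g = w.reshape(-1, GROUP)
    scale = np.abs(g).max(axis=1, keepdims=True) / 7.0 + 1e-12
    q = np.clip(np.round(g / scale), -8, 7).astype(np.int8)   # int8 stands in for 4-bit storage here
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray, shape):
    return (q.astype(np.float16) * scale.astype(np.float16)).reshape(shape)

w = np.random.randn(256, 256).astype(np.float32)
q, scale = quantize_int4(w)                # stored as 4-bit codes + per-group scales
w_hat = dequantize(q, scale, w.shape)      # unpacked to FP16 on the fly (W4A16 style)
x = np.random.randn(4, 256).astype(np.float16)
y = x @ w_hat.T                            # the matmul itself is ordinary FP16 math, no FP4 units needed
print("max dequant error:", np.abs(w - w_hat.astype(np.float32)).max())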
Okay, that's a quick 10-to-15-minute overview of K2 Thinking. I'll pause here for 00:12:53.840 |
questions questions comments i see chat is quite active i honestly haven't had a chance to look through 00:12:59.520 |
it if anyone wants to you know interrupt or go over any parts in specific detail feel free you know 00:13:06.400 |
i'm gonna go through chat but feel free to unmute or whatever or we can always pass to ted and then do 00:13:12.240 |
do q a towards the end and keep it in in zoom 00:13:20.720 |
"Import restrictions force creativity" - yeah, a lot of people say that. The Chinese labs do really 00:13:26.000 |
good stuff, like MuonClip; a lot of DeepSeek's innovations were basically because they don't 00:13:32.240 |
have the best GPUs, so they're forced to get creative. And yeah, how many 5090s do you need to run 00:13:40.320 |
this? You need a lot of 5090s, you need the PewDiePie rig, you need the Chinese 48-gig 5090s, or you 00:13:47.680 |
know, there's always the RTX Pro 6000, where you pay twice as much for more VRAM: you 00:13:55.760 |
get 5090-level performance with 96 gigs of VRAM. You put four or six of those in a box and you can run 00:14:02.880 |
this thing locally, and it'll be slow. Or you could always just use the cloud, where this thing is 00:14:09.440 |
like two dollars per million tokens. But Kimi Linear, that's what, 48 billion 00:14:18.640 |
parameters if I'm not mistaken, so there's a version of that you can run. Yeah, but it's more of an 00:14:23.120 |
experimental thing, it's not really a state-of-the-art model yet. Yeah. 00:14:30.000 |
"A quirk is, if you ask it for code, it will generate code in thinking, see if it's correct, and produce 00:14:36.000 |
the same main output." I think Ankit is making an interesting point here. I think you can't look at 00:14:43.280 |
thinking tokens as quirks; thinking tokens are all basically random latent-space noise. Like, 00:14:53.840 |
models, even DeepSeek and Kimi, all of them will basically do random stuff in their thinking: 00:14:59.840 |
they'll switch languages, they'll write code for poetry, they'll do poetry for code, and then they'll output a 00:15:05.440 |
really reasonable answer. So it makes a pretty clear point that there's a reason why you 00:15:10.960 |
don't output your chain-of-thought reasoning to the end user. People were salty at 00:15:18.960 |
OpenAI's first o1 model for not giving chain-of-thought tokens, but in reality what we see now is 00:15:25.360 |
there's so much noise in there that summarizing it is probably the better UX, right? You don't want to 00:15:30.240 |
ask it for code and get a bunch of, you don't want to get like poetry in your code thinking; it doesn't 00:15:36.240 |
really help. Let's see, any other questions? Otherwise I will probably pass over to Ted, and 00:15:42.240 |
I'll keep answering in the Zoom chat. But yeah, it's better than MiniMax; trade-offs are there, 00:15:50.400 |
they don't really talk about them. "Locally it will be 10 tokens per second" - no, it'll probably be 00:15:55.280 |
worse than 10 tokens per second. Right now the Kimi servers are serving it at eight tokens per second, so 00:16:02.640 |
10 tokens per second would be great on your 5090, but no, these things are pretty raw. 00:16:07.840 |
Okay, I think, a quick question: so you mentioned a 10 to 20 percent increase end-to-end on RL 00:16:18.720 |
training, right? That doesn't seem to be a very large number, and I was wondering, is this speedup... 00:16:25.840 |
everything else in this article is about the bandwidth limit, right? So can we kind of correlate 00:16:33.760 |
those two things? Like, when you went to INT4, how much 00:16:39.680 |
more bandwidth are you allowed now for GPU transfer, and then how does that correlate with the speedup? 00:16:46.080 |
Yeah, I think there's a high-level map that you can work through, right? So quantization will be faster, 00:16:52.240 |
you're also taking half the RAM, so if you have X amount of GPUs... there's actually a 00:16:57.920 |
full breakdown of this somewhere, I can share a link at some point in the Zoom chat. But yeah, you 00:17:04.800 |
break down the factors that you have, right? So you have faster decode, you have half the VRAM required; they 00:17:09.280 |
did a scaling comparison to Chinchilla somewhere. But in my opinion 10 to 20 percent is quite a 00:17:15.760 |
bit, right? If you compound it out for your experiments plus your final train run; the final train run 00:17:22.080 |
to train a trillion-parameter model is in the millions of dollars, right? So if you're 20 percent 00:17:31.040 |
faster, you can use 20 percent fewer GPU hours, per se. So that's pretty substantial. I don't know 00:17:41.360 |
if that's exactly a fair way to put it, but 10 to 20 percent speedups are pretty uncommon, right? I 00:17:50.160 |
think that alone is pretty impactful. But then when you think about it, it's very rollout 00:17:55.600 |
heavy, and if you look at what the actual bottlenecks were, from bandwidth across their parallelism, 00:18:04.640 |
their pipeline parallelism and GPU sharding, having more of the model on fewer GPUs really helps break 00:18:12.720 |
that down. I feel like I saw a good post about this, I'll try to find it and share it. "Oh yeah, you could 00:18:18.560 |
spend more" - I think it's not just about spending more though, right? This optimization is 00:18:24.400 |
also at inference time, so you should see speedups during training and plain inference. And forget about 00:18:33.360 |
other providers: one of the big benefits of open models and open weights is 00:18:38.560 |
you can serve them yourself, right? You want to use K2, you need to find a terabyte of VRAM; you 00:18:43.760 |
want to use K2 Thinking, you can do it on half the memory. So that alone is a lot better than 00:18:49.840 |
a 10 to 20 percent speedup, right? It's also half the GPUs required. And then the other thing that they 00:18:56.560 |
mention is that going to INT4 makes RL training more robust, and just that end sentence 00:19:03.760 |
of that section, I don't quite understand why that would be. It says, oh, smaller representational space reduces 00:19:09.040 |
accumulation error. Why is that? It seems like you're more quantized, so wouldn't it 00:19:15.360 |
accumulate even worse? So what is the reasoning there? I think someone talked about it in the comments here, 00:19:23.440 |
maybe it's just having a robust quantization during compute. There's an answer that 00:19:29.280 |
someone gave about it, I'll share it in the chat in a bit. But, as a separate thing, I 00:19:36.320 |
remember reading somewhere an analysis that said that RL updates tend to be like rank one, 00:19:41.840 |
so INT4 shouldn't be a problem at all for RL; if it works for regular training it'll definitely 00:19:50.080 |
work for RL. Yeah, so I don't know if that answered your question, Frankie. Maybe it's just that it's faster 00:19:56.560 |
and it's just as good. Okay: "it brings stability, smaller representational space reduces accumulation 00:20:06.160 |
error, improving learning robustness." I'm sure there's more we can dig into here, 00:20:10.240 |
but it's a fair question. Yeah, the quantization-aware training maybe is what they're 00:20:15.760 |
saying makes it better, versus post-training quantization. Okay, I hope I didn't take 00:20:25.040 |
up too much of your time, Ted, there wasn't that much in this one. No, that was great. So let me talk about 00:20:31.840 |
the Kimi Linear paper. Let me share my screen. I'm always impressed with how much information Vibhu can 00:20:43.200 |
cover so fast; I don't know that I can match that. But so we have this Kimi Linear 00:20:51.040 |
paper, let me just do the quick highlights. It's a hybrid linear attention architecture, so 00:20:57.040 |
one quarter of the layers are regular attention and three quarters of the layers are going to be 00:21:01.840 |
their new Kimi Linear thing, and the linear thing they call Kimi Delta Attention, KDA. 00:21:07.760 |
We'll get into it, but they say it's an expressive linear attention module that extends Gated DeltaNet 00:21:14.480 |
with a finer-grained gating mechanism, and then they have a bespoke chunkwise algorithm that achieves 00:21:22.560 |
high hardware efficiency. We'll talk about that too, remind me if I don't, but that's an important 00:21:28.320 |
element for any linear model um so they pre-trained this 3 billion active 48 billion total so this is 00:21:36.480 |
this is not quite like state-of-the-art size and i think some of the experiments in the paper were done 00:21:41.920 |
even smaller so um just know that this is this is like a proof of concept maybe more than it is the 00:21:49.040 |
model you're going to start using daily but the key thing is because the linear has so much less memory 00:21:55.200 |
requirements if only a quarter of the layers are regular attention then you're reducing your kv cache 00:22:01.040 |
size by 75 percent, and, like, if Vibhu's saying 10-20 percent is a pretty good increase, 75 percent is a 00:22:09.360 |
very notable decrease in your memory utilization they have some decoding throughput numbers i i didn't 00:22:16.080 |
understand they have some graphs that differ but it's like 2.9x 3x 6x um but basically um because decoding 00:22:25.920 |
is memory bound um bandwidth bound then when you reduce the kv cache by by three quarters that of 00:22:34.960 |
course speeds up your decoding and that's where that speed up comes from whatever the exact number is 00:22:39.200 |
And the cool thing is they open-sourced the KDA kernel and the vLLM implementation. 00:22:47.360 |
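A schematic of the hybrid stack described above: three KDA layers for every one full-attention layer, so only the full-attention layers pay for a growing KV cache. The module names and the layer count are placeholders, not the actual Kimi Linear code.

import torch.nn as nn

class KDALayer(nn.Module):
    """Placeholder for a linear-attention (KDA) layer: fixed-size recurrent state, no KV cache."""

class MLALayer(nn.Module):
    """Placeholder for a full (latent) attention layer: keeps a growing KV cache."""

def build_hybrid_stack(num_layers: int = 48, ratio: int = 3) -> nn.ModuleList:
    layers = nn.ModuleList()
    for i in range(num_layers):
        # every (ratio+1)-th layer is full attention; the rest are KDA
        layers.append(MLALayer() if (i + 1) % (ratio + 1) == 0 else KDALayer())
    return layers

stack = build_hybrid_stack()
print(sum(isinstance(l, MLALayer) for l in stack), "full-attention layers out of", len(stack))
# -> 12 out of 48, i.e. only 1/4 of the layers keep a KV cache, which is where the ~75% KV-cache saving comes from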
okay so um let me jump ahead to what this uh what this model looks like okay so they have uh layers 00:23:02.720 |
where they do kda instead of regular attention and then they have layers where they do regular attention 00:23:07.760 |
And, TLDR, in the end they're going to say: we did ablations and we found that three of these 00:23:13.040 |
for every one normal one gave us the best bang for our buck um the the attention layers are using 00:23:21.520 |
latent attention like DeepSeek, but that's not particularly noteworthy. It's a mixture-of- 00:23:27.200 |
experts model and the MoE is just completely standard. So I feel like the really interesting thing to talk 00:23:34.640 |
about in Kimi Linear is their new KDA layer that's replacing the attention, and it's doing 00:23:43.120 |
sort of a uh uh local attention it's not it's not there to be able to do um you know in context uh 00:23:54.320 |
retrieval of a token that's you know 10 000 100 000 tokens ago that's what the one quarter of the layers 00:24:01.600 |
that are full attention are there for but if you just need something like the previous token 00:24:06.320 |
two tokens ago um then this guy is is going to take care of it for you and it's actually really strong 00:24:13.040 |
and it can do more than that. All right, so I'm trying to follow the chat, but you guys are typing 00:24:20.160 |
faster than I can read every time I look at it, so shout out if there's anything I'm going too fast on, 00:24:27.520 |
anything you want me to answer, or whatever. All right, so I thought that since what we're going to focus on is 00:24:32.960 |
just this kda thing then it actually makes sense to start with their related work section um down in 00:24:39.680 |
section seven and then we'll get into the details and i think we'll have just enough time to talk a 00:24:44.160 |
little bit about the math. Okay, so it starts with linear attention, and the idea is that attention, as I 00:24:52.720 |
think you guys know, is quadratic. So it reformulates the quadratic attention map into kernelized feature 00:24:57.760 |
interactions, replacing the softmax with a positive feature map so that attention can be computed through 00:25:03.280 |
two associative matrix products. This is a little bit like the kernel trick that you see in SVMs, which 00:25:10.000 |
can take something non-linear and turn it into something linearly separable so ultimately you don't 00:25:15.680 |
necessarily really need to know that but um so you get rid of this explicit um um quadratic uh attention 00:25:24.720 |
similarity matrix because you you you hope for it learning some pattern that that this kernel makes 00:25:31.440 |
linearly separable and so then we just have uh uh linear uh calculations uh and then let's see here 00:25:38.320 |
Subsequent work strengthens the vanilla linear attention significantly through more refined memory 00:25:43.680 |
control, shifting from data-independent decay to more adaptive, data-dependent mechanisms. So for example, 00:25:50.400 |
Mamba just had this decay thing, and then in Mamba-2 they introduced decay that was a function 00:25:59.680 |
of what the current token is; that's what they mean when they say data-dependent decay mechanisms, 00:26:06.000 |
and I'll talk more about how this decay stuff works. And refining the decay granularity from coarse 00:26:12.000 |
headwise to precise channel-wise decay: that's one of the big things here, that they're actually 00:26:17.440 |
going to say, my model dimension is 2048, and instead of just decaying all 2048 of those floats by the same 00:26:25.520 |
factor, 0.9, I'm actually going to choose channel-wise which ones I'm going to keep and which ones I'm going to 00:26:31.680 |
decay. So gated linear attention generalizes these approaches with diagonal channel-wise gates; that's the decay 00:26:39.680 |
thing. Table 7 summarizes this. Collectively, these methods cast attention as a compact recurrent memory 00:26:46.960 |
updated with parallel prefix-scan operators and fused matrix multiplies, aligning with modern accelerators. 00:26:53.600 |
That's a whole mouthful. Basically, what linear attention and KDA and Mamba and RWKV and Gated DeltaNet, all of these 00:27:05.200 |
things what they have in common is that at decode time they act like an rnn so you push one thing in 00:27:12.320 |
it updates the state you get one thing out that's not enough in order to have a good linear attention 00:27:19.280 |
you need three things you need the the easy one which is that decoding goes um uh really easily and really 00:27:27.600 |
fast the second thing you need is that training and or pre-fill is done really fast and that's where 00:27:36.400 |
you see this business of parallel prefix scans fused matrix multiply what they have in here this chunked 00:27:43.200 |
operator that's the thing that does the prefix the pre-fill operation really fast or the training operation 00:27:50.880 |
really fast otherwise if you just had a vanilla rnn and you give it a big amount of code and you have 00:27:58.240 |
a hundred thousand tokens it would need to process those hundred thousand sequentially and it'd be really 00:28:03.120 |
slow the decoding tokens per second would be really fast but the pre-fill would suck so you need those two 00:28:09.120 |
things you need fast pre-fill you need fast decode and then the third thing you need is you need the 00:28:14.080 |
attention to actually um uh be really smart be able to to um learn very complex patterns so what i'm mostly 00:28:22.720 |
going to talk about is this complementary view: the same linear attention connects to 00:28:29.600 |
this idea of fast-weight memory, where the state is a low-capacity associative table, and that's 00:28:36.880 |
the view that I find most intuitive for understanding these things. So then they also talk about linear 00:28:42.560 |
attention with the gating mechanism, and it's really the same idea here, and they say a primary distinction 00:28:49.120 |
Yeah, sorry, someone had a question that I think is at the right time: they asked if you can go into the 00:28:54.960 |
prefill a little bit more, what makes prefill so fast, and in general what prefilling 00:29:02.560 |
is. So, when we train models, typically what we do is, let's talk about early pre-training: you 00:29:12.400 |
start with a context length of say four thousand you present all four thousand you have teacher forcing 00:29:17.360 |
which means that um no matter what the model predicts you're always just putting in the four thousand 00:29:22.320 |
tokens from your original ground truth corpus and then you calculate cross entropy loss uh in parallel 00:29:29.040 |
across all four thousand token positions you're really doing the same thing in pre-fill so if 00:29:34.480 |
somebody has a very short question like what is the capital of france this is no big deal it's like eight 00:29:39.680 |
tokens whether you did them in parallel you give them sequentially no biggie but when i uh am working 00:29:45.040 |
on code and i and i give you um uh the files and let's say you need all of these files this 100 000 tokens 00:29:53.040 |
then in order to start decoding and whether you're thinking whether you do tools whatever you're doing 00:29:59.440 |
if you're going to start decoding you first need to process all 100 000 of these um and we've come to name 00:30:07.280 |
this portion of the process pre-fill because its characteristics are very different from the one 00:30:13.040 |
token at a time decoding that happens after that so under pre-fill we have all the data we have all 00:30:19.120 |
the hundred thousand we don't care about any of the predictions so we can we can do this in a highly 00:30:24.240 |
parallelist fashion and for regular attention um you can you can do this super parallel all at once 00:30:31.680 |
and your gpus will be compute bound because they're doing so many matrix multiplies calculating your 00:30:37.920 |
your keys queries multiplying them together all that stuff okay um as well as the mlps uh during decoding 00:30:45.840 |
you only get one new token, one new query, and you're multiplying that against all your old keys 00:30:52.240 |
and then calculating attention and so it ends up that once your context gets long the the keys and values 00:31:00.320 |
that you're loading into the gpu are more memory bandwidth limited than the amount of compute you 00:31:08.000 |
actually need to do the matrix multiplies actually relatively fast and tiny and so um linear if you 00:31:16.320 |
imagine just a regular old-fashioned rnn um it's really great because you just have this small state 00:31:21.360 |
you put in one token and you do a little bit of operations and you get out the next token so they're 00:31:26.000 |
really great for decoding but there is no in an old-fashioned rnn there is no way to do pre-fill 00:31:33.520 |
in parallel you would have to just do them one at a time and if i give you a code base it's 100 000 tokens 00:31:39.440 |
you're kind of screwed so everybody has to if you want to have a competitive rnn style linear mechanism 00:31:47.360 |
which will save on memory you also need to have a second mode where you can in parallel process 00:31:55.040 |
these hundred thousand tokens not one at a time there has to be some trick and so this chunked method 00:32:00.880 |
in the KDA paper, and similarly in other papers, Mamba, whatever, they have some way of 00:32:08.160 |
transforming the math to say there's an equivalence: if I do it this chunked way, I get the same answer in parallel 00:32:14.000 |
versus having to process every token one at a time. Did that, for whoever asked that - it already scrolled off 00:32:21.600 |
my screen - did that answer the question? All right, cool, thank you. 00:32:32.080 |
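A tiny demonstration of that equivalence for the simplest (no-decay) linear attention: the sequential RNN-style loop and one batched matmul land on exactly the same state, which is why prefill can be parallel. The real kernels are chunkwise and also handle the decay/gating terms; this is just the core idea.

import numpy as np

T, d_k, d_v = 1000, 64, 64
K = np.random.randn(T, d_k)
V = np.random.randn(T, d_v)

# Decode-style: feed tokens one at a time, S_t = S_{t-1} + k_t v_t^T
S_seq = np.zeros((d_k, d_v))
for t in range(T):
    S_seq += np.outer(K[t], V[t])

# Prefill-style: all T tokens at once, one big matmul (fully parallel on a GPU)
S_par = K.T @ V

print("same state:", np.allclose(S_seq, S_par))   # True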
So let me just read this real quick: "The primary distinction among various gated linear attention mechanisms lies in the 00:32:38.160 |
parameterization of the forget gate." We haven't talked about this yet, so maybe reading this first 00:32:43.760 |
is a little bit out of order. For instance, RetNet uses a data-independent scalar decay alpha, 00:32:50.240 |
and Mamba-2 employs a data-dependent scalar; gated linear attention specifically uses a diagonalized fine-grained matrix, 00:32:58.880 |
um offering an effective trade-off between uh efficiency and performance and that's what we're going to get into 00:33:05.120 |
here note also that aside from linear attention another way you can speed things up is you can do sparse 00:33:11.200 |
attention so if we talk about the new um deep seek i don't even know what it's called but the new deep seek 00:33:17.440 |
attention thing they're kind of doing a sparse kind of a thing where um uh they're not looking at all of the 00:33:26.640 |
the tokens at once to calculate uh the attention but this paper is a linear form so i'm just gonna kind of focus on that 00:33:33.760 |
all right any questions if not then i was thinking i would jump into um uh the the actual linear attention stuff 00:33:54.480 |
all right so section 2.2 is the key section for understanding what they did in kda and um they 00:34:02.960 |
didn't actually say this um because they have these different sections here what the heck are they 00:34:07.760 |
doing what they're actually doing is they're they're doing a history of linear attention and they're 00:34:13.040 |
building up from the simplest to the most complex, their latest thing, which is Kimi Delta Attention. Okay, 00:34:21.440 |
so section 2.2 none of this stuff literally is used in kda this is all the the history that builds up to 00:34:29.440 |
we built our our kda by making a a slight tweak on on the last thing here 00:34:35.120 |
so um to start with we're going to look at this first equation but this is just a basic linear attention 00:34:42.400 |
and what i was thinking i would do is i would just like scribble a little bit so you guys can sort of 00:34:47.840 |
understand the idea of the the associative memory so let me see if um do we have whiteboards turned on 00:34:58.320 |
here is there something i can enable uh it might not be enabled during the meeting so no i don't see it 00:35:13.680 |
okay okay this is this is new this is fancy this is everyone drawing all right people don't draw on 00:35:30.080 |
this because i need to draw on this if you want if you want me to show you something you need to stop 00:35:33.680 |
messing around okay um all right how do i erase it all right thanks guys 00:35:44.320 |
um okay okay awesome so so let's start with just a regular kv cache okay and and it looks something 00:35:59.040 |
all right so oh i just lost the whiteboard i think that was me i i thought i xed out of it for myself 00:36:10.640 |
but i i did it for everybody apparently sorry that's okay i can never 00:36:14.880 |
uh i can do it i think that you should be able to start a new whiteboard yeah yeah yeah yeah sorry 00:36:21.440 |
all right nobody nobody mess with it all right all right do you guys see or do i need to share my 00:36:25.760 |
screen we see it we see it okay all right so so let's start with a um a regular kv cache 00:36:36.080 |
and if i can do this really quickly okay it looks something like this and basically um this here sorry 00:36:46.480 |
i'm trying to do this fast this here is um is our sequence length 00:36:53.680 |
okay and over here we're going to have our um our key dimension i'll call it d sub k 00:37:06.560 |
and we're going to have our value dimension d sub v all right so so hopefully this looks familiar 00:37:13.920 |
to you guys you have keys and values and it's as long as a single place and so basically um 00:37:19.200 |
what you do in here is for every token that you've ever seen you store a copy of the the key and value 00:37:26.960 |
after you've taken that token and run it through the embedding matrix w sub k w sub v all right and 00:37:32.800 |
then you don't have to store these you could just recompute them and that's what the the sort of 00:37:37.200 |
original attention thinking is, but it's actually better if you think of it in terms of the KV cache. 00:37:44.000 |
all right so now what what what happens when we train the model is we learn different patterns that we care 00:37:51.280 |
about and we have the ability to recall them on demand so let's say the llm wants to know if the current 00:37:58.480 |
token is a noun okay then what it could do is it could have one of the head keys put a one in a 00:38:05.920 |
certain place if the current token is a noun then if the query also has a one in that place then when 00:38:14.640 |
you dot the query in the key your dot product is a is a large value you get a one and you're basically 00:38:20.080 |
going to recall that so imagine that uh for simplicity now this is well okay let me just start 00:38:28.480 |
this way so imagine that you just basically divided up your keys okay and you basically said hey i'm 00:38:34.800 |
gonna put um a one here anytime it's a noun and i'm gonna put uh a one here anytime it's a verb 00:38:42.560 |
okay so that's a way in which you could build this this key value memory and then anytime 00:38:50.400 |
the head wants to match on prior words that are nouns you just pass in a query that looks like one zero zero 00:38:57.760 |
zero zero whatever okay and then every word that's not a noun is going to have zeros here 00:39:02.640 |
and so any word that's not a noun is going to have a very low value is going to have a zero value 00:39:07.760 |
and so you're only going to get matches on nouns and then if everything is a zero except for verbs 00:39:13.040 |
then you put in zero one zero zero zero as your query and you're going to get all the verbs now 00:39:17.920 |
obviously you can design whatever you can learn whatever pattern you want okay but this is the 00:39:21.920 |
idea so what this means is now we've built an associative memory okay so a regular array in 00:39:30.080 |
python you you access it by index you say i want index 23 but here you can say i want nouns and you 00:39:37.200 |
don't need to know what position they're in and you don't need to know how many there are you just 00:39:41.360 |
simply say if my query is one zero zero zero zero i will get all nouns and if there's only one noun 00:39:48.400 |
then if you say give me one zero zero you will get that exact uh right entry 00:39:53.440 |
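Here's that noun/verb example as a few lines of code: attention over a KV cache acting as an associative table, where a one-hot query pulls back the values stored under that "slot" regardless of position. The features and values are made up purely for illustration.

import numpy as np

NOUN = np.array([1.0, 0, 0, 0])   # key slot meaning "this token is a noun"
VERB = np.array([0, 1.0, 0, 0])   # key slot meaning "this token is a verb"

# KV cache after seeing 4 tokens: "dogs chase cats quickly"
keys   = np.stack([NOUN, VERB, NOUN, np.zeros(4)])                 # (seq_len, d_k)
values = np.array([[1., 0, 0], [0, 1, 0], [0, 0, 1], [9, 9, 9]])   # (seq_len, d_v)

query = NOUN                      # "give me the nouns"
scores = keys @ query             # dot products: 1 for nouns, 0 for everything else
weights = scores / scores.sum()   # (real attention would apply a softmax here)
out = weights @ values            # blend of the noun values only
print(out)                        # -> average of rows 0 and 2; the adverb row is ignored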
okay so the problem we have where it gets expensive is that in this formulation every time you decode and 00:40:04.240 |
you get a new token you need to lengthen this sucker and so then you're basically adding another another 00:40:10.320 |
column here and then you do decoding and you're adding another column here and this thing this thing 00:40:15.680 |
just keeps growing uh in size and i and i it's not fancy my my columns lines aren't lining up but i but 00:40:23.920 |
you get the gist so then by the time you get to hundred thousand tokens this thing is very large it's 00:40:28.320 |
it's hundred thousand so the the promise of linear attention is let's not have this thing grow like crazy 00:40:36.240 |
let's assign a fixed amount of memory to it and so if you said i'm going to assign um a fixed size of 00:40:46.480 |
1024 to this thing for the first 1024 tokens you're great right um or the first 1024 concepts that you want 00:40:58.000 |
to store in here um you can basically just have a one you know um in the first row for the first thing 00:41:05.280 |
and a one in the second row and you can just one hot and code it okay but what happens when you get to 00:41:11.200 |
the next one the 1025th if your update rule um 00:41:19.840 |
it is simply sort of something along the lines of uh of cycle through these one hot encodings 00:41:25.840 |
and then you get to 1024 you're going to be back to a one in the first position and you're going to end 00:41:30.480 |
up uh um uh writing something in here that's going to collide with this thing that you already had in 00:41:37.440 |
there and so so this is where in the literature you'll see some stuff about collisions okay now the 00:41:43.120 |
next thing we're going to do is we're going to so so we have to find a way to solve that um and so 00:41:48.080 |
what you could do is you could say i'm going to erase this entry um with the with the one in the first 00:41:54.480 |
row and then that'll create room for my 1025th thing so ultimately what you have to do is you have 00:41:59.920 |
to have a rule that says what am i going to erase when i want to put something in and so that's where 00:42:05.680 |
as you get to these more complicated linear models people are are having these decay rules these 00:42:12.800 |
are the erase rules for when do I get rid of stuff in my now very limited, fixed-size KV 00:42:21.280 |
cache okay so then the next thing we're going to do is we're going to do a little bit of linear algebra 00:42:26.880 |
okay and so um if you if you recall in linear algebra um yes i could put a one in the first row one 00:42:36.320 |
in the second row and this will give me this nice orthogonal basis in this case they'll also be length one 00:42:41.600 |
so it'll be orthonormal and that's the the principal basis the one that you know is easy on the eyes 00:42:47.280 |
but if this thing is d dimensional um i can have any d um orthogonal vectors and they will work perfectly 00:42:57.440 |
well they do not have to be all one and all zeros um this first vector could be one one one one one and 00:43:04.240 |
then the next one as long as it it's orthogonal to that first one um it could it could be anything and 00:43:10.720 |
so that's where if you look at the keys in an lm you're not going to see one zero zero zero zero one 00:43:15.680 |
zero zero but in fact they very well may be orthogonal to each other okay and we benefit from this concept 00:43:23.760 |
that in high dimensionality - the curse of dimensionality - all random vectors are very close to orthogonal to 00:43:31.120 |
each other so the dot product of two random vectors is likely to have a value close to zero all right and 00:43:38.480 |
So in fact we don't use one-zero-zero-zero, we just use learned vectors. Okay, so then the next 00:43:48.720 |
first thing to understand is if we go back to i i think i can toggle back and forth but uh hold on a 00:43:57.280 |
second where's my zoom menu here uh sharing will close the whiteboard uh okay i didn't want to do that 00:44:18.160 |
okay so this is the equation that's at the top of section 2.2 okay it says that here's my hidden state 00:44:28.800 |
so so this is going to be my my new rectangle here okay and we initialize it with all zeros and um 00:44:37.120 |
what we do is we when we process the first token we add in this thing k times v transpose so this is 00:44:45.840 |
an outer product so if you're used to seeing dot products k transpose times v would be a scalar 00:44:52.240 |
quantity okay but k times v transpose is going to give us a rank one and if k and v are our same 00:45:00.080 |
dimensionality a rank one square matrix all right um and then it says 00:45:07.040 |
The way that you get your output, equivalent to your softmax attention, is you take 00:45:14.800 |
the transpose of this matrix and you multiply it by your query. So I want to just show you super quickly 00:45:19.840 |
how this ends up being a really clean associative memory. I'm going to drop the t subscripts 00:45:29.760 |
and just do one time step. Okay, so we start with S as all zeros, and then the first S is just k times v transpose. 00:45:52.400 |
All right. Then, how do we read the output from that? The output, we say, is S 00:46:03.360 |
transpose times the query. So what I'm going to do is rewrite S using the equation up above 00:46:14.080 |
and transpose it. When I transpose it, I'm going to get v times k transpose, 00:46:24.080 |
if you know how your transpose of a matrix multiply works, times the query. 00:46:33.840 |
But we can change the parentheses, because matrix multiplication is associative, and so I can actually write v 00:46:47.040 |
times (k transpose times the query). Okay. And so, if you are familiar with the softmax attention formula, 00:47:06.960 |
you have this this k transpose times query thing i think it's written the other way around q query 00:47:13.040 |
transpose times key whatever but you're multiplying your queries and your keys and then ultimately you're 00:47:17.760 |
multiplying that by your values and then it's got a scaling thing square root of d blah blah blah whatever 00:47:22.720 |
but so in this formulation we've basically got the same functionality as um we've got this query times 00:47:30.960 |
key business just like in regular softmax attention and so uh i'm gonna try and go a little bit fast 00:47:37.600 |
here just for the sake of time but hopefully this gives you at least a little bit of intuition that um 00:47:42.800 |
that we can store these things and and we can recall them and it's performing something very similar to 00:47:51.840 |
regular attention okay so the difference is that in our in our uh in our linear attention we don't have 00:47:59.200 |
these two boxes okay we now just have one box and this is this is initialized to zeros and then when 00:48:08.000 |
we get the first key in value we do this outer product it gives us a rank one square matrix and we 00:48:13.280 |
sum it in here and so it's just going to be zero plus that it's going to be our original one the key 00:48:18.160 |
thing from linear algebra is that if the second key is orthogonal to the first key and you and you do this 00:48:26.240 |
thing again regardless of what the value is and you take that rank one matrix and you add it in here 00:48:32.480 |
you will then get a rank two matrix and when you multiply by either um key in this in this this output 00:48:42.080 |
formulation you will get if the two are perfectly orthogonal zero dot product you will get an exact 00:48:48.160 |
recall of the values with no loss whatsoever now if they're if they're like mostly orthogonal you'll get 00:48:54.320 |
99 98 recall whatever that's that's good enough to the extent that you put in a query that's kind of a mix 00:49:01.920 |
between different keys then you're going to get a weighted blended average of the different keys which is 00:49:08.880 |
exactly what we see in softmax attention: you're getting a blend of the 00:49:14.880 |
different values, a weighted sum of values based on your attention scores. Okay, so the shape 00:49:21.760 |
of this is a little different in that this one is two times d assuming that that dk equals dv the height of 00:49:29.840 |
this is 2d and the height of this is only 1d so this is actually even a more efficient representation 00:49:36.560 |
because we're taking advantage of this linear algebra stuff we don't actually have to store 00:49:40.960 |
the keys and the values in separate sections; by multiplying them together and taking 00:49:48.080 |
advantage of orthogonality, we just slam them all in here. All right, any questions before I move on to the rest of section 2.2? 00:49:55.600 |
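And the same thing with the fixed-size state from the whiteboard: S accumulates rank-one outer products k v^T, and reading with a key that is orthogonal to the others recovers its value exactly. Note the associativity from the derivation above: S^T q = V (K^T q). The dimensions and random keys here are just for the demo.

import numpy as np

d = 512
rng = np.random.default_rng(0)

# Random high-dimensional keys are nearly orthogonal; orthogonalize them here so recall is exact.
K, _ = np.linalg.qr(rng.standard_normal((d, 3)))   # three orthonormal keys, shape (d, 3)
V = rng.standard_normal((3, d))                    # three values to remember

S = np.zeros((d, d))                               # fixed-size associative state
for i in range(3):
    S += np.outer(K[:, i], V[i])                   # S += k_i v_i^T

q = K[:, 1]                                        # query with the second key
out = S.T @ q                                      # readout: S^T q  ( = V (K^T q) )
print("exact recall:", np.allclose(out, V[1]))     # True, because the keys are orthonormal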
All right, cool. So let me share my screen. It says the whiteboard's going to go 00:50:04.880 |
away i don't know if it can be saved brought back whatever but one sec one sec let me uh let me save it or 00:50:13.760 |
add an option to save it as a pdf or something yeah there is a way to save them yeah i got it 00:50:20.400 |
i got it whiteboard i can save it as a template and i can also just export it as a p yeah just export it 00:50:27.840 |
to pdf if you would be so kind yep i got one png i'll also do a pdf and then we can figure out how to 00:50:35.920 |
get it all right it's not that beautiful so people all right it's still there still there okay go ahead 00:50:40.960 |
and get rid of it i'll bring it back if you need it all right um 00:50:45.520 |
yeah and i think i can bring it back but anyway all right so 00:50:52.400 |
hopefully you see my screen so we just talked about this first equation 00:50:58.720 |
okay this is the super simple uh way of storing things in an array and so s is simply a sum of rank 00:51:07.600 |
one outer products of approximately orthogonal keys all right this thing will die as soon as you 00:51:16.240 |
have one more thing than whatever the dimensionality is so if it's if it's you know 1024 it'll work 00:51:22.320 |
great for the first 1024 things and then it's not going to work um once you try to put something else in 00:51:29.280 |
so uh one of the things you can do is um you can have this this delta rule which basically says that 00:51:40.800 |
i'm going to subtract a little bit out in order to make room for the new thing that i'm adding um that's 00:51:48.240 |
what this business that gets you to this equation that's what this business is about i'm not going to 00:51:53.760 |
go into like super super detail but ultimately um what what then people evolve to is this idea which for 00:52:05.840 |
people here will probably be more familiar is this idea of weight decay so if you imagine that that 00:52:11.920 |
square matrix we have that's our memory store if you just apply a little bit of weight decay every time 00:52:16.960 |
you get a new token you multiply that whole matrix by 0.95 okay then what's going to happen is something 00:52:23.040 |
that's very old will have been multiplied by 0.9 over and over and over and over and over again and it will 00:52:28.240 |
eventually get close to zeroed out without you having to explicitly do anything okay and so that's where you 00:52:35.680 |
basically have this scalar alpha that's like your weight decay in in our in our weight matrices that we use 00:52:42.640 |
when we're doing gradient descent, when we're doing like AdamW, right? So every time you get 00:52:49.200 |
a new token you just multiply it by this um and and the old ones uh will decay and so the idea here is that 00:52:58.160 |
uh i is the identity matrix so our update after we get a new token is the identity times the old 00:53:06.080 |
matrix so that's just the old matrix minus um this particular thing which is the which is the making 00:53:14.720 |
room forgetting component here okay plus this is our brand new key and our brand new value so this is the 00:53:22.960 |
thing that we're adding into the memory all right and then we're going to take this whole old thing 00:53:28.320 |
and we're going to decay it by some you know whatever 0.95 some some you know value like this 00:53:33.840 |
and this thing actually worked pretty well it was pretty stable so that if you got past a thousand 00:53:40.320 |
tokens and you went to fifteen hundred two thousand three thousand it kind of mostly remembered the last 00:53:46.800 |
one thousand tokens so it works a lot like sliding window attention okay so i i saw the comment early 00:53:54.240 |
in the chat why not just do sliding window okay it works like that however it is a little bit smarter 00:54:00.560 |
than that where you can selectively say if a concept is super important i will keep it even though it's 00:54:06.560 |
more than a thousand tokens ago all right um so hypothetically if if you train a model on stories 00:54:13.840 |
and the first sentence says our main character is named bob it could decide that that based on 00:54:19.760 |
context it's gonna like never forget that thing and it's gonna really really make sure that bob stays 00:54:25.680 |
in the in the memory because it's super important if you use a sliding window it it's it's a hard-coded 00:54:31.040 |
rule it's only the last thousand tokens and so you you can't have that selectivity whether or not 00:54:36.640 |
your model is smart enough to learn that rule to know that it should remember bob is a whole separate 00:54:41.440 |
thing but that's why they they train these models to try to get them to work so finally um if you have 00:54:48.960 |
this idea of weight decay we're decaying all of the entries in there simultaneously and equally um now 00:54:57.440 |
you can say i'm going to decay them uh uh uh based on what the current token is and then the final thing that they do 00:55:06.880 |
in Kimi Linear's KDA, they say: this old entry is a vector of length 2048, and I'm only 00:55:20.560 |
going to decay, of the 2048 floats, the ones that I really want to; I can selectively choose. 00:55:28.560 |
So instead of having a scalar up here, a 0.95 that multiplies all the weights, I'm now 00:55:35.200 |
actually going to say this is not a scalar, it's a vector, and they make it a diagonal 00:55:42.000 |
matrix, but it's effectively just saying: each component I will choose. So maybe the first 10 floats get 0.95 and 00:55:48.080 |
the rest of the floats get 1, I'm not going to decay them at all. At the end of the day, that's the trick. 00:55:53.680 |
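A decode-step sketch of how the forget gates differ across this family. This follows the general shape of the updates described above, not the paper's exact parameterization; in the real models the gates (alpha, the per-channel vector a, and beta) are learned functions of the current token, whereas here they are fixed numbers for illustration.

import numpy as np

d = 2048
S = np.outer(np.random.randn(d), np.random.randn(d)) / d   # pretend some memories are already stored
k = np.random.randn(d); k /= np.linalg.norm(k)
v = np.random.randn(d)

# 1) Vanilla linear attention: no forgetting at all.
S_vanilla = S + np.outer(k, v)

# 2) Scalar decay (RetNet / Mamba-2 style): shrink the whole state uniformly.
alpha = 0.95
S_scalar = alpha * S + np.outer(k, v)

# 3) Channel-wise decay (GLA / KDA style): a per-channel gate, so some channels
#    are kept (gate = 1) and others forgotten (gate < 1), like the "first 10 floats" example.
a = np.ones(d); a[:10] = 0.95
S_channel = a[:, None] * S + np.outer(k, v)   # equivalent to Diag(a) @ S

# 4) Delta-rule flavour (DeltaNet family): explicitly erase what was stored
#    under this key before writing the new value, to make room.
beta = 1.0
S_delta = S - beta * np.outer(k, k @ S) + beta * np.outer(k, v)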
And so then they say: we have this efficient chunked algorithm for doing it in parallel when we have prefill 00:55:59.760 |
or training, when we have all the tokens at once; and it's an RNN, so it's always really fast when you're 00:56:04.720 |
doing one token at a time. And then what we can do is compare it to other linear things that came 00:56:14.480 |
before so for example you may remember uh mamba uh you may remember a gated delta net which is a little 00:56:23.040 |
more recent and so this thing actually if you look at these curves mamba 2 is the orange line and on 00:56:30.080 |
accuracy um it smokes it and then gated delta nets newer and it's it's similar but like a little bit 00:56:36.720 |
better um in performance so then they get into uh i know we're out of time so then they get into um 00:56:44.400 |
some simple experiments to show that their linear attention alone in a toy model with two layers 00:56:49.680 |
outperforms older things: it outperforms Linformer, Mamba, Mamba-2, RWKV, Gated DeltaNet; we outperform all 00:56:58.880 |
of those on a few things that linear models tend to have difficulty with, like reversing a string 00:57:04.960 |
um and doing like multiple recall um so now that we're confident that we have a better linear 00:57:11.680 |
component let's actually stick it in a model and then we're going to do some ablations with um how 00:57:18.160 |
much full attention do we need they come up with one quarter full attention is good if you do more than 00:57:23.520 |
a quarter full attention you get almost identical performance at more cost if you do less than a 00:57:28.720 |
quarter, if you do like one-eighth full attention, they found that the performance drops. So that's why they 00:57:34.320 |
settled on the one-quarter full attention, three-quarters KDA split. 00:57:45.040 |
And I'm not sure if they did their ablations completely right, I'm not sure how they handled RoPE 00:57:53.040 |
for the baselines, but basically what they found was that if they took away positional embeddings for 00:58:00.400 |
the quarter of the layers that were full attention it actually performed better than if they gave it 00:58:05.840 |
positional information and what they argued is that attention is so regular attention is so good at 00:58:12.000 |
positional information that it it trumps the kda and says i will take care of short-term dependencies 00:58:20.400 |
when they took attention when they took a rope away from the full attention layers it was like 00:58:25.760 |
well i can't do any positional information because i don't have anything and so it fully delegated all 00:58:31.680 |
the responsibility for previous token two tokens ago to the kda layers that allowed the regular attention 00:58:40.560 |
layers to 100 percent focus on long context semantic tasks give me the nouns give me the last time i saw 00:58:48.000 |
the name bob give me all of those things and not focus any of its energy on short-term things there 00:58:55.200 |
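As a rough illustration of the layout being described, a hypothetical helper, not the released config: three linear KDA layers for every full-attention layer, with the full-attention layers deliberately run without any positional encoding so they specialize in global retrieval while the recurrent layers carry order.

```python
from dataclasses import dataclass

@dataclass
class LayerSpec:
    kind: str        # "kda" or "full_attention"
    use_rope: bool   # whether rotary embeddings are applied in this layer

def build_schedule(n_layers: int, full_every: int = 4) -> list[LayerSpec]:
    layers = []
    for i in range(n_layers):
        if (i + 1) % full_every == 0:
            # global softmax attention layer, deliberately left position-blind (NoPE)
            layers.append(LayerSpec(kind="full_attention", use_rope=False))
        else:
            # linear-attention layer; its recurrence encodes order implicitly
            layers.append(LayerSpec(kind="kda", use_rope=False))
    return layers

print(build_schedule(8))   # 3:1 interleave of kda and full attention
```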
there may be some other dynamics going on here but i'm personally of the opinion that regular rope is 00:59:01.840 |
kind of stupid and we should be using at least truncated rope with mla they separate out the rope so that it's not 00:59:08.480 |
actually adding noise with normal rope you multiply it in and so you're noising 00:59:15.120 |
your token signal whereas for mla you concatenate it so you're not noising your signal 00:59:21.120 |
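A small numpy sketch of the contrast being drawn, with a toy 2-D rotation standing in for RoPE and made-up dimensions: standard RoPE rotates the content vector itself, while an MLA-style decoupled design leaves the content untouched and concatenates a small, separately rotated positional slice.

```python
import numpy as np

def rot2d(x, theta):
    """Rotate consecutive pairs of dims by angle theta (a toy stand-in for RoPE)."""
    c, s = np.cos(theta), np.sin(theta)
    x = x.reshape(-1, 2)
    return np.stack([c * x[:, 0] - s * x[:, 1],
                     s * x[:, 0] + c * x[:, 1]], axis=-1).reshape(-1)

d_content, d_rope = 8, 4
q_content = np.random.randn(d_content)
q_rope    = np.random.randn(d_rope)
pos_angle = 0.3   # stand-in for the position-dependent rotation angle

# standard RoPE: the rotation is applied to (multiplied into) the content itself
q_standard = rot2d(np.concatenate([q_content, q_rope]), pos_angle)

# MLA-style decoupled RoPE: content stays as-is, and a small dedicated slice
# carries the rotated positional signal; the two are concatenated
q_decoupled = np.concatenate([q_content, rot2d(q_rope, pos_angle)])
```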
so my personal opinion is we're going to learn some things about how rope is 00:59:27.840 |
awesome it beats all the other positional embeddings but there is a cost to it and we can be a little 00:59:32.400 |
bit smarter about it and that's why they found a gain by taking it away from the regular attention 00:59:39.680 |
layers so that for me was the most interesting thing the rest of it is like hey we ran some 00:59:47.120 |
experiments and did really well they make these strong claims that they outperformed a model with 00:59:54.320 |
the same training recipe and same number of tokens but all full attention layers instead of a quarter full and 01:00:01.520 |
three quarters kda i would believe it but i think it's possible that they may have messed up 01:00:10.080 |
their ablations but if they do in fact beat it my personal take is it's not because the kda is so 01:00:18.320 |
awesome it's because rope has a cost and they actually managed to find a way to measure that cost 01:00:26.080 |
and by removing rope from the full attention layers they let the attention do what it does 01:00:32.080 |
really well without any noise but that's just ted's hot take all right i can stay longer if 01:00:40.720 |
people have questions or want to talk about it but with respect for the people who may need to go i think 01:00:45.440 |
that's the most important thing to know about it so there are other models out there that are doing the 01:00:49.520 |
hybrid i saw some stuff in the chat so yeah three quarters something faster and one quarter full attention 01:00:56.960 |
means you get a three quarters smaller kv cache and much faster decoding because decoding is memory 01:01:03.440 |
bound and as long as it doesn't totally kill you on in-context learning then it's sort of an 01:01:15.040 |
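Back-of-envelope KV-cache arithmetic with hypothetical model sizes (fp16 cache, numbers chosen only for illustration): keeping a KV cache in just one quarter of the layers shrinks it by roughly 4x, since the KDA layers carry a fixed-size recurrent state instead.

```python
# Hypothetical sizes: 64 layers, 8 KV heads of dim 128, 128k-token sequence, fp16 cache.
n_layers, n_kv_heads, head_dim, seq_len, bytes_per = 64, 8, 128, 128_000, 2

per_layer  = 2 * n_kv_heads * head_dim * seq_len * bytes_per   # K and V
full_model = n_layers * per_layer
hybrid     = (n_layers // 4) * per_layer   # only 1/4 of layers keep a KV cache;
                                           # the KDA layers keep a constant-size state instead

print(f"all full attention: {full_model / 1e9:.1f} GB per sequence")
print(f"1/4 full attention: {hybrid / 1e9:.1f} GB per sequence")
```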
thank you i understand a lot more about kda after your step-by-step walkthrough just wanted to say that 01:01:35.520 |
all right so i don't know if that's a good sign that it was clear or if i just completely lost 01:01:44.320 |
everybody so if you go back to this diagram they didn't label it but on the left these are the queries 01:01:49.680 |
and the keys and as usual you take your token and run it through a weight 01:01:57.440 |
matrix w sub q or w sub k and if you remember mamba they have some 1d convolutions that they run it through 01:02:04.880 |
and those are helpful a 1d convolution can very easily find simple patterns 01:02:10.320 |
then there are normalization layers everywhere this is a normalization layer qk norm like 01:02:17.600 |
qwen has and these are the values totally normal again except they also run through a convolution 01:02:23.600 |
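A minimal numpy sketch of that part of the diagram, with toy sizes, random weights, and the ordering of conv and norm as an assumption: project the tokens with W_q, W_k, W_v, run each through a short causal depthwise 1D convolution like Mamba's, and apply a QK-norm-style normalization.

```python
import numpy as np

T, d_model, d_head = 16, 32, 16         # toy sizes
x  = np.random.randn(T, d_model)        # token representations for one sequence
Wq = np.random.randn(d_model, d_head) * 0.1
Wk = np.random.randn(d_model, d_head) * 0.1
Wv = np.random.randn(d_model, d_head) * 0.1

def causal_depthwise_conv1d(h, kernel_size=4):
    """Short per-channel convolution over the time axis (the Mamba-style conv)."""
    T, d = h.shape
    w = np.random.randn(kernel_size, d) * 0.1     # one small filter per channel
    padded = np.concatenate([np.zeros((kernel_size - 1, d)), h], axis=0)
    return np.stack([(padded[t:t + kernel_size] * w).sum(axis=0) for t in range(T)])

def rmsnorm(h, eps=1e-6):
    return h / np.sqrt((h ** 2).mean(axis=-1, keepdims=True) + eps)

# queries/keys/values: linear projection -> short causal conv -> normalization
q = rmsnorm(causal_depthwise_conv1d(x @ Wq))
k = rmsnorm(causal_depthwise_conv1d(x @ Wk))   # QK-norm style normalization
v = causal_depthwise_conv1d(x @ Wv)
```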
one trick they do is on their alpha they use something akin to lora where they take two 01:02:32.560 |
rectangular low rank matrices and multiply them together to make a square matrix just because 01:02:37.680 |
they don't really need that full d squared number of parameters and they do a 01:02:44.800 |
similar trick on the final output gate 01:02:54.160 |
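A small sketch of the low-rank gate parameterization being described; the sizes and the exact nonlinearity are assumptions. Instead of a full d-by-d weight, two skinny matrices produce the per-channel decay gate, and the same trick gives a sigmoid output gate that scales the layer's contribution to the residual stream.

```python
import numpy as np

d, r = 2048, 64                    # full width vs. a small rank r

# Instead of a full d x d weight for the gate (d^2 is ~4.2M params here),
# parameterize it LoRA-style as the product of two skinny matrices (2*d*r params).
A = np.random.randn(d, r) * 0.02
B = np.random.randn(r, d) * 0.02

x = np.random.randn(d)             # the token's hidden state

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# per-channel decay/forget gate: a length-d vector in (0, 1), one value per channel
alpha = sigmoid(x @ A @ B)

# same trick for the output gate: scales the layer's contribution to the residual stream
out_gate = sigmoid(x @ A @ B)      # in the real model this would be a second A/B pair
layer_output = np.random.randn(d)  # stand-in for the attention/KDA output
residual_contribution = out_gate * layer_output
```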
and then this is their beta the forget factor so everything here is channel aware there's 01:03:04.320 |
nothing that applies equally to all 2048 floats anything they do whether adding subtracting or multiplying 01:03:11.120 |
it's always going to be at least dimensionality d and then this is the gate that you see in gated attention i don't know 01:03:21.200 |
i don't remember which other models have gated attention now the latest qwen models have it but 01:03:27.840 |
anyway normally what you do is you calculate all your attention and just add that to the residual stream 01:03:33.600 |
whereas this sucker has a sigmoid on it that says hey if you're layer three and for some reason you're 01:03:40.080 |
sort of less important i'm going to actually multiply your output by 0.5 and so you're going to have a 01:03:45.760 |
quieter addition to the residual stream and if you're a super important layer then you're going to get the 01:03:51.600 |
full 1.0 and if you guys know about attention sinks and massive 01:04:01.520 |
activations um that's one of the things that researchers seem to be trying to fight and so 01:04:09.200 |
they think that if they add some of these normalizations in they don't get as large spiking activations 01:04:15.680 |
which i think helps them when they're trying to go to smaller and smaller quantization 01:04:20.560 |
because now they don't have such extreme values that they need to support and so i think the holy 01:04:25.520 |
grail is you figure out some kind of normalizations and gatings such that all your activations stay in a 01:04:33.120 |
nice little gaussian distribution and then you can totally nail it with something like fp4 or 01:04:40.560 |
int4 or whatever but right now if you have these activations and then occasionally you have an 01:04:45.040 |
activation that's a hundred times larger it's just very difficult to support that in low precision 01:04:51.600 |
so that's why in these models you still see that the weights are small but 01:04:56.720 |
the activations are still 16-bit it's because you have these huge outlier activations 01:05:02.400 |
if you really quash those you can go down to eight bit activations and maybe eventually 01:05:08.480 |
four bit activations so that's a little more detail if you actually go through 01:05:16.000 |
what the layers look like you see there's just a few extra normalizations in here 01:05:20.480 |
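A quick numpy demonstration of why one massive activation hurts low-precision quantization: with symmetric per-tensor scaling the step size is set by the largest value, so a single 100x outlier makes the rounding error on all the ordinary activations far worse for the same bit budget.

```python
import numpy as np

def fake_quantize(x, n_bits=8):
    """Symmetric per-tensor quantization: scale determined by the max absolute value."""
    levels = 2 ** (n_bits - 1) - 1
    scale = np.abs(x).max() / levels
    return np.round(x / scale) * scale

rng = np.random.default_rng(0)
acts = rng.normal(0, 1, size=10_000)

# well-behaved activations quantize fine
err_clean = np.abs(fake_quantize(acts, 8) - acts).mean()

# add one massive-activation outlier ~100x larger than everything else
acts_outlier = acts.copy()
acts_outlier[0] = 100.0
err_outlier = np.abs(fake_quantize(acts_outlier, 8) - acts_outlier).mean()

print(f"mean abs error, no outlier:   {err_clean:.4f}")
print(f"mean abs error, with outlier: {err_outlier:.4f}")   # much worse for the same bit budget
```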
and i had a question you mentioned at the beginning you drew a parallel between linear 01:05:33.360 |
attention and kernel methods and i think i understand kernel methods reasonably 01:05:40.400 |
well i just wanted to understand the parallel you're drawing there i don't know if i can do a great 01:05:47.040 |
job explaining it so i gave the associative memory analogy for how to understand this 01:05:54.480 |
okay the other way of thinking about it is that if you have this n by n attention 01:06:07.360 |
pattern well it may be lower triangular because of causal attention or whatever 01:06:14.720 |
but still if you think about it the pattern can look like anything and the high values 01:06:20.880 |
and the low values are not necessarily linearly separable so if the high values 01:06:29.280 |
and the low values look like some sort of xor pattern like high low high then if instead of 01:06:36.000 |
using a full n squared thing with a softmax you try to squash this to just a one-dimensional 01:06:44.240 |
linear representation you can only be linearly separable in a way where the highs 01:06:52.080 |
all have to be on one side and the lows have to be on the other side and that is the somewhat mathematical 01:06:58.480 |
interpretation for why linear attention whether linformer or performer or mamba one 01:07:06.480 |
cannot be as expressive why they hit a glass ceiling that softmax full attention doesn't hit when you 01:07:12.880 |
start training these really large models so if you imagine that you then take those attention 01:07:22.080 |
patterns and you run them through some other function so the classic example if you want 01:07:28.960 |
i can pull it up real quick is you have high and low values centered on zero and 01:07:35.680 |
then if you do x squared and turn it into a parabola then the high values are all up here and 01:07:41.920 |
the low values are all down here in the bottom of the parabola and now they're linearly separable 01:07:48.480 |
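The parabola picture in a few lines of numpy: a high/low/high pattern on a line cannot be split by one threshold in 1-D, but after the classic x -> (x, x^2) lift a single linear rule on x^2 separates it.

```python
import numpy as np

# a "high, low, high" pattern on a line: not separable by one threshold in 1-D
x      = np.array([-2.0, -1.8, -0.1, 0.0, 0.2, 1.9, 2.1])
labels = np.array([  1,    1,    0,   0,   0,   1,   1 ])   # 1 = high attention, 0 = low

# lift to 2-D with the classic x -> (x, x^2) map: the parabola
lifted = np.stack([x, x ** 2], axis=1)

# now a single linear rule (here: x^2 > 1) separates high from low perfectly
pred = (lifted[:, 1] > 1.0).astype(int)
print((pred == labels).all())   # True
```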
so right in the higher dimension they're separable and that's one of the interpretations of these linear 01:07:56.640 |
attentions is that they do some sort of transform to get into a different dimension in which things 01:08:03.920 |
are linearly separable and if you could do that perfectly theoretically you would get the same 01:08:13.120 |
performance as full n squared softmax attention i don't know the math well enough to know if the 01:08:22.880 |
problem is that you can't do that or i think the problem might just be that we're not learning super 01:08:31.200 |
complex transforms they're just using these very fast to compute transforms and that fast to 01:08:39.520 |
compute thing might not cover every possible scenario and in some scenarios even though it gets to 01:08:45.600 |
learn the queries and the keys to fit its purposes it's still not finding something that's 01:08:51.280 |
linearly separable probably just because there are so many millions of different concepts and you have 01:08:55.920 |
so many heads and at the end of the day it wants to do something and it can't make everything it 01:09:01.440 |
wants to do linearly separable right it's a glass ceiling and the non-linearity that it's using 01:09:07.520 |
is basically a polynomial one through the recurrence relation is that the idea the math is getting 01:09:15.440 |
beyond me pretty quickly the original transformers are rnns paper that's like 01:09:24.240 |
four years old i think if you look at that i think that's the best reference and you'll see 01:09:29.440 |
they have this phi function in there so like support vector machines what they do 01:09:36.400 |
is they say we're going to have this kernel that transforms it but we want something that's super 01:09:41.360 |
fast to compute and so the kernel trick allows you to compute these dot products super fast for svms 01:09:47.920 |
without actually explicitly doing the transform so they make a similar argument there where it's like 01:09:54.640 |
yes there may be all sorts of different transforms you could use we're just going to pick this one 01:09:59.360 |
because it's fast and i think ultimately that's where you get bottlenecked because you're going to lose some 01:10:06.000 |
expressivity and also you're not training that function bespoke for whatever the llm's trying to do so 01:10:18.800 |
if somehow you had a way of picking the function per head per layer then maybe you could 01:10:26.320 |
actually get these linear models to perform even better but the good news is that since that paper 01:10:32.240 |
since mamba one mamba two rwkv it looks like the kimmy people are saying that their linear attention 01:10:40.400 |
which has slightly different update rules and slightly different forget rules outperforms all 01:10:45.840 |
prior linear attention it's still not good enough you notice they didn't build it 100 percent kda 01:10:52.960 |
they still use one quarter regular attention but it's still better than if you did three quarters 01:10:58.320 |
mamba layers and better than if you did three quarters gated delta net layers yeah okay got it 01:11:05.280 |
i'm gonna check out that paper you said it was transformers are rnns yeah okay yep something 01:11:12.640 |
like that let me see if i can find the exact thing 01:11:37.600 |
yeah i think i found it it was let's see icml 2020 01:11:43.920 |
okay yeah and there's multiple so i'm getting them confused i just found one that's the 01:11:50.720 |
fast autoregressive transformers with linear attention 01:11:54.480 |
oh that's what i found too yeah so maybe that is the one i was thinking of 01:12:04.640 |
yep that's the one you see these phi functions that they talk about for these transforms 01:12:11.520 |
yep so you can see they wrap a function phi around all of their queries and their keys 01:12:22.800 |
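Following that paper's recipe, a minimal numpy sketch of kernelized linear attention, with the feature map phi(x) = elu(x) + 1 as in Katharopoulos et al. and toy sizes: wrapping phi around the queries and keys turns causal attention into running sums that can be updated one token at a time with no n-by-n matrix.

```python
import numpy as np

def phi(x):
    """Feature map from the linear-attention paper: elu(x) + 1 (always positive)."""
    return np.where(x > 0, x + 1.0, np.exp(x))

T, d = 16, 8
Q, K, V = (np.random.randn(T, d) for _ in range(3))

# causal linear attention as a recurrence: S accumulates phi(k) v^T, z accumulates phi(k)
S = np.zeros((d, d))
z = np.zeros(d)
outputs = []
for t in range(T):
    q_t, k_t, v_t = phi(Q[t]), phi(K[t]), V[t]
    S += np.outer(k_t, v_t)
    z += k_t
    outputs.append((q_t @ S) / (q_t @ z + 1e-6))   # O(d^2) per token, no n x n matrix
out = np.stack(outputs)
```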
all right awesome thank you the silence is scaring me i think i lost people but 01:12:32.560 |
hopefully this gives you a little better insight into kda 01:12:41.360 |
thanks ted yeah this is amazing thank you really appreciate it i have a 01:12:48.960 |
i have a question actually this is like my last question i was wondering if 01:12:55.920 |
this means that inference for linear attention might be better like does this mean that we're 01:13:07.280 |
going to move to distilling into linear models after training with full attention 01:13:17.120 |
i was just reading some papers on that and it seemed to make sense given what you're saying like 01:13:22.720 |
i don't know if what i'm saying makes sense but yeah so the kimmy people are saying 01:13:30.480 |
this can be a full drop-in replacement so it will train faster it will pre-fill faster i don't know 01:13:38.880 |
if it's pre-fill faster but it'll decode faster than regular full attention so according to the paper's 01:13:48.160 |
claims you can just completely scrap your regular attention models and replace them with one quarter 01:13:55.040 |
attention models with kda i think that claim is a bit strong and you know time will tell but 01:14:03.600 |
yes you don't necessarily need to do anything fancy in terms of distillation you 01:14:07.680 |
can just train this instead of a regular full attention model is what they're claiming 01:14:10.880 |
all right i just realized that my chat was not auto scrolling and so i thought there were no 01:14:19.040 |
messages but there's like 80 bazillion messages i didn't read all right on the other side if anyone wants 01:14:34.800 |
to volunteer for next week i think we have a few that were posted in discord otherwise you know we'll 01:14:40.640 |
continue the discussion and everything down there in discord next week is aie new york so if any of you guys are 01:14:48.160 |
in new york come to the conference and you know we'll do something latent space in person where's 01:14:54.320 |
the discord you can find it on the luma or yeah just search latent space discord and there's a 01:15:00.400 |
paper club channel you'll find it in there cool cool awesome yeah luma slash ls and then you can 01:15:09.440 |
find it latent space all right thanks ted and same for the ai engineer new york stuff there's a channel in 01:15:18.000 |
the discord with people talking about it cool take care guys