How does 4o ImageGen work? Visual Autoregressive Modeling paper - Best Paper @ NeurIPS

Alright, you guys see my screen? Okay, so for the VAR paper, a quick overview. Previously people said, hey, LLMs work great, let's try just doing autoregression on images. You can do this at the pixel level or at the patch level, and the most obvious first thing you might do is raster order: start at the top-left corner, go across the top, and keep going. You can imagine that this feels natural because there's an obvious sequencing for language, but for images it feels a little wrong: if you're in the middle of the image, your context is all the pixels above you, but you have no context for the pixels below you.

What the VAR paper suggests is that going from low res to high res provides a good inductive bias. This is just my overview slide; obviously we'll get into the details of how they do all this. That's the big change: rather than just switching from raster order to a spiral or a space-filling curve — which all still fundamentally have the problem that when you're halfway done, half the pixels are in your context and half are not — if you go from low res to high res, then all of the pixels, but only at low res, are in your context. Early on you get global features: it's a picture of a dog, the background is blue sky above green grass below, with no high-frequency fine details — just the high-level stuff.
To implement this they use variational autoencoders, and actually vector quantization, so it's a VQ-VAE. And they purposefully decided to use a GPT-2-like LLM, which I think is actually one of the strengths of the paper: they didn't go for the best possible transformer architecture. They wanted to show that what worked was the low-to-high-scale technique, not that they had a really awesome transformer. They did make a few changes, to the normalization and so on, but they tried not to chase the best possible model.

One consequence of the way they go from low scale to high scale is that if I'm at a medium scale and I'm going to predict the next, higher scale, that higher-scale image is multiple tokens. So even though we're still going left to right linearly, we're not going one token at a time — we're going multiple tokens at a time. Again, this is just the overview; we'll see the details in a minute. I'll pause to see if there are any burning questions before we get into how they do this.

Someone mentioned they're trying to explain VAEs in the chat for those who are newer to the concepts. I do have a slide on it, although it's just one slide and doesn't really explain them in depth, so feel free to add to it — it helps to see it with images.
Okay, so again, I think I kind of covered this, but you can imagine: we have LLMs, and the diagram from the paper shows this nice, clear ordering — "the cat sat by..." and so on. You could take an image, break it into these nine patches, and go in raster order: one, two, three, four, five. The problem is that when you've seen patches one through five and you're predicting patch six, patch three is in your context but patch eight is not, and that's a little unnatural. You can change the ordering, like I said, but any 1D sequence is still fundamentally going to have this problem.

So what they say is: let's have lower-res and higher-res images, and now there's a logical ordering. At what they call r3 here, maybe you can tell it's a bird — a parrot — and maybe you can't. If you knew it was a parrot and I asked whether it's facing to the right or the left, at r3 you could probably say, "I think this is the beak, so it's facing to my right." At r4 you're getting global context: what color is the main body of the bird? Probably blue. There are things you can learn; it's just not working out the fine details yet. As you keep going, more and more local information becomes available, so when you're actually trying to predict specific pixels on a portion of the beak, you have a lot of local context in addition to the general global context.

This is a nice concept, but it still begs the question: how do we actually implement this idea of going from low scale to high scale? By the way, just stop at any point and ask questions. I can try to monitor the chat, but it's not always easy when I'm screen sharing, so I apologize — somebody else please give a shout if there's a question in the chat that I'm missing.
I have a question. I can understand what a token is in a sentence, but what would a token be in an image? Some pixels? Because you can have lots of different values — it would be really huge.

That's the perfect question. I have this idea that I want to go from low scale to high res, but what the heck is the token? How do I tokenize this — what are the pieces? That's what we need to get into next; that's the problem we need to solve. This was just a high-level description, and it doesn't actually make it obvious at all how you pull this off. So, perfect question.
Okay, so there are a couple of concepts we're going to use to explain the tokenization scheme. One is the autoencoder. Here we have a simple autoencoder: you take an image, pass it through some CNN layers, and you have a bottleneck in the middle, which can be considered your embedding. This is a deterministic embedding, so you can't use a simple autoencoder as a generator.

Then we have variational autoencoders, which take the input — maybe with CNN layers up front — but ultimately, instead of producing a single embedding, you're creating a probability distribution. Generally speaking, we're transforming the data into something simple; in this slide it's a multi-dimensional Gaussian distribution with diagonal covariance. We want the encoder to transform our data distribution into this simple Gaussian distribution, which we know how to sample from. Then, when we do generation, the real power of the variational autoencoder is that the decoder knows how to do the reverse transform: take something sampled from this multi-dimensional Gaussian and map it back into the data distribution. That's all I have — I wasn't really planning to cover VAEs in depth, so add any color you'd like.

It looks like Anton's covering it; people are talking about VQ-VAEs in the chat, but my understanding stops at VAEs. The only thing I'd add is that this is basically what led to latent diffusion.
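Since it's just one slide, here is a minimal sketch of the VAE idea in PyTorch. This is purely illustrative — the architecture, layer sizes, 64×64 input, and 256-dimensional latent are assumptions for the example, not anything from the VAR paper. The encoder produces the mean and log-variance of a diagonal Gaussian, you sample with the reparameterization trick, and the decoder maps that sample back toward the data distribution.

```python
import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    """Minimal VAE for 64x64 RGB images: encode to a diagonal Gaussian,
    sample with the reparameterization trick, decode back to pixels."""
    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.ReLU(),   # 64 -> 32
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),  # 32 -> 16
            nn.Flatten(),
        )
        enc_dim = 64 * 16 * 16
        self.to_mu = nn.Linear(enc_dim, latent_dim)
        self.to_logvar = nn.Linear(enc_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, enc_dim), nn.ReLU(),
            nn.Unflatten(1, (64, 16, 16)),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),  # 16 -> 32
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),              # 32 -> 64
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return self.decoder(z), mu, logvar

# At generation time you skip the encoder entirely: sample z ~ N(0, I) and decode.
sample = TinyVAE().decoder(torch.randn(1, 256))
```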
Yeah, so VAEs were used for image generation standalone — I don't know, four years ago or whatever. They weren't very good at the time; they were very interesting, but ultimately you got blurry images, and you can do things to clean that up. So next we go to vector quantization, because with VAEs, if I go back, this latent vector that you sample is continuous: if the dimensionality is 256, you have 256 continuous, real-valued floats. For anything where we're going to do a GPT-style thing, we need a vocabulary — a discrete distribution — so we're going to use vector quantization.

The way I describe vector quantization is that it's just a multi-dimensional version of rounding, or quantization. Super quickly: if I said you have floating-point numbers between zero and one and you round to the nearest hundredth, you'd create a hundred buckets; you look at the hundredths place, truncate the rest, and drop things into the buckets. That's very simple rounding into 100 equal-sized buckets. But suppose my data is not distributed uniformly — say it's heights of people, roughly bell-curve shaped — and I want good use of my buckets, so each bucket holds roughly equal numbers of people. Then you might have narrow buckets near the middle and wider buckets at the ends, and if you do it right, approximately one percent of the people fall in every bucket. That takes you from simple rounding to unevenly sized buckets. The last step is: rather than predetermining those buckets, take some data and learn, from that data, how to apportion the buckets so they all end up with roughly equal samples.

Good — now we have a learned hundred-bucket scheme for one-dimensional data. How do you do that for multi-dimensional data? If you have a vector of length 256, you do exactly the same thing as in the one-dimensional case, except you use, say, L2 distance to figure out which bucket you're closest to, and you shift the buckets around until you get roughly equal counts in every bucket. So for me, vector quantization is just learned rounding on multi-dimensional vectors; that's my mental intuition — it's not something super fancy. Yeah, I see you raised your hand.
I have a question, and I think it's also being asked in the chat: why do we have to go to vector quantization? One of the reasons for doing a VAE is that you have a continuous distribution, so a subtle change in the input leads to a subtle change in the output rather than dramatic changes, whereas once you move into the quantization world you lose that subtlety — the proportional-change benefit you get from VAEs.

I don't know that I can answer the question fully — I'm not the author of the paper — but what I will say is that if we want to use a GPT-style transformer, we're going to have a fixed vocabulary, so we want to discretize this. If you're familiar with Whisper or these other audio models, what they do is residual vector quantization: you vector quantize, and there's a rounding error — you're losing information — so you take the error between the original and your first quantization and quantize that error, then take the leftover from that, and after doing this, say, four times you have a much more accurate approximation of your original than if you had just done one round of quantization. We'll see that they don't explicitly use residual vector quantization in the VAR paper, but the process they use ends up imitating it: as they go through the scales from low to high res, they're working on residuals, so it ends up being a lot like RVQ.

Got it, thank you. And one other question while I'm here: would you be able to make this presentation available to us? Yeah, I have a PDF on GitHub; I can share the link. Great, thanks.
So again, if you just think of this as multi-dimensional rounding, done so that we get equal numbers in every bucket and make good use of the buckets — then what do we call the list of buckets? If you use equal-sized buckets, you just have a formula for computing them, but if you're learning them, you have to write them down and store them somewhere. By convention, the list of buckets is called a codebook. And rather than anything fancy, I think what they store is just the center of each bucket — the centroid, or whatever you want to call it. So you're not defining the upper and lower bounds of a bucket; you're just saying here's the middle of this bucket, here's the middle of that one. When you want to quantize something, you compare it to every bucket, see which has the lowest L2 distance, and that's the bucket you quantize it to — the one you round to. It is, I don't know, slightly expensive: if you have 4096 buckets, it's on the order of 4096 comparisons to quantize something. You compare to every one of them, and then you say, yeah, it was closest to bucket 17, so I'm going to throw it in bucket 17.
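To make the "learned rounding" concrete, here is a small sketch of that nearest-centroid lookup, plus the residual-quantization idea mentioned a few minutes ago. The 4096-entry codebook, 256 dimensions, and four residual stages just echo the numbers used in the talk, and a random codebook stands in for a learned one — this is an illustration, not the paper's code.

```python
import torch

def quantize(vectors, codebook):
    """vectors: (N, D), codebook: (K, D) -> nearest indices (N,) and entries (N, D)."""
    dists = torch.cdist(vectors, codebook)   # (N, K) pairwise L2 distances
    idx = dists.argmin(dim=1)                # "closest to bucket 17" -> 17
    return idx, codebook[idx]

codebook = torch.randn(4096, 256)            # stand-in for learned centroids
x = torch.randn(8, 256)

token_ids, x_q = quantize(x, codebook)       # one round: token id = bucket index

# Residual quantization (RVQ-style): each stage encodes what the previous
# stages missed, and the decoder sums the stages back up. With codebooks that
# are actually trained for this, each stage shrinks the reconstruction error.
approx, residual = torch.zeros_like(x), x.clone()
stage_ids = []
for _ in range(4):
    ids, q = quantize(residual, codebook)
    stage_ids.append(ids)
    approx = approx + q
    residual = residual - q
```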
Alright, now that we have a bit of background on VAEs and VQ-VAEs, the question is: how do we tokenize — how do we fit images into an LLM? The solution in the VAR paper is to break the image into patches, and each patch goes through our vector quantizer, our VQ-VAE. But a low-res version gets fewer patches and a high-res version gets more patches. Specifically, in the VAR paper they worked on ImageNet, 256-by-256 color images, and the highest res in their case is 16 by 16 — 256 patches — while the lowest res is one patch. So the low-res end is very, very low res: it's really just the average color of the entire image, just the mean. And if you think about it, 16 by 16 isn't actually that high either, but it's a small image, only 256 by 256.
They actually sped it up a little: they didn't use every single possible size in between. You have a one-by-one as your lowest res, then two-by-two, three-by-three, four-by-four, and at some point they get to around eight and jump to ten, then to twelve and so on — the exact schedule they used isn't particularly important. If you did it naively you'd have 16 steps, from one up to sixteen; they had, I believe, ten steps if you look at the code for the VAR paper.

So basically that means the lowest-res image gets one token as its embedding, and the token ID is just the bucket number from our vector quantizer. Like I was saying, if you compare it to all 4096 entries and 17 is the closest one, then literally 17 is your token number. And then it's GPT-2: the token goes through the embedding layer, which turns it into a dense vector, goes through however many transformer layers — twelve, I don't remember which GPT-2 size they used — and out pops your next-token prediction. You all know how LLMs work.
The point at the bottom here is that normally our LLM predicts the next token, but if we're predicting the next higher resolution, that's multiple patches. When I start with just one patch, I predict the next resolution, which means I'm actually predicting four tokens at once. Then, when I have the first and second scales — five tokens in my context — I predict the next nine tokens all at once. And when I have 14 tokens in my context, I predict the next 16 tokens all at once.
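To make that bookkeeping concrete, here is a tiny script using a scale schedule roughly like the one in the released code (quoted from memory, so treat it as approximate; as noted above, the exact jumps are not the important part).

```python
# At each scale the model emits an s x s grid of tokens in a single step,
# conditioned on all earlier scales.
scales = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]

context = 0
for step, s in enumerate(scales, start=1):
    new_tokens = s * s
    print(f"step {step:2d}: {context:3d} tokens in context, predict {new_tokens:3d} at once")
    context += new_tokens

# ~680 tokens in total, but only 10 autoregressive steps; a raster-order model
# over the final 16x16 grid would need 256 steps, one patch at a time.
print("total tokens:", context)
```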
What you can do is change the shape of your autoregressive mask: instead of being purely causal along the diagonal, you make it block-wise. So when I have five tokens in my context and I'm predicting nine more — positions six through fourteen — I change my mask so that tokens seven through fourteen can still only see the first five; they cannot see token six. Tokens six through fourteen all attend to the same number of keys, the same amount of information; the fact that they happen to sit in a particular order gives them no extra information, because we're building the mask block-wise. The very last token, the fourteenth, still only attends to tokens one through five; it cannot see any of the other tokens from its own scale.

In practice, if you're familiar with prefill versus decoding: in the prefill stage we inference multiple positions at once, and here there's no reason not to do that as well. You could do the positions one at a time, but it would just be less efficient — that's purely an implementation issue. Whether you inference the nine tokens one at a time or all at once, because of the attention mask it's exactly the same.
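Here is a standalone sketch of that block-wise mask, following the description in the talk (the released implementation may differ in details, such as whether positions within a scale can see each other's inputs), with a single class/conditioning token up front.

```python
import torch

def block_causal_mask(scales, cond_tokens=1):
    """Boolean attention mask: True means 'may attend'."""
    sizes = [cond_tokens] + [s * s for s in scales]     # conditioning block + scale blocks
    total = sum(sizes)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in sizes:
        mask[start:start + n, :start] = True            # see everything before this block
        start += n
    mask |= torch.eye(total, dtype=torch.bool)          # let each position see itself
    return mask

m = block_causal_mask([1, 2, 3])   # class token, then 1x1, 2x2, 3x3 scales -> 15 positions
# The first and the last position of the 3x3 block attend to the same prefix
# (class token + the 5 earlier-scale tokens), so their ordering inside the
# block carries no extra information -- which is why the whole scale can be
# predicted in one parallel step.
print(m.int())
```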
Alright, now the thing we're going to do is modify our vector quantizer a little. We're not just using a vanilla VQ-VAE; there's a special VQ-VAE trained specifically for VAR. Ultimately they tried it both ways. To get a low-res version of the image they just use linear interpolation — nothing fancy here; whether you use the cv2 function or the PyTorch built-in, it's literally just interpolation. I start with my image, and if I want a one-by-one, I resize or interpolate down to one by one, which gives you the average; if I want a two-by-two, it's the same simple thing. There's no fancy learning going on there.

So first they tried the straightforward version: one-by-one, two-by-two, three-by-three, and so on, running the patches at every scale through the vector quantizer — and it worked. What they found worked a little better, however, is this: after you have the one-by-one, you project it back to the full-size image, 256 by 256. It's a bit more complicated than this, but you can imagine you'll just get a solid color for the whole thing, because you only have one data point, so when you upscale it you get a flat, uniform image. Then, for the two-by-two, instead of also predicting the original image, they subtract what the one-by-one predicted the full image would be — subtract that mean color from the image — and that difference is what gets quantized for the two-by-two. This is the part I mentioned earlier about it being a bit like residual vector quantization. After the two-by-two has been vector quantized, you upscale it back to 256 and subtract it from what you have, and you keep doing this successively. So when they do it in ten steps, the last step — the 16-by-16 patches — is not predicting the original image; it's predicting the leftover after all the previous nine quantizations have done their job. You can imagine that by that point you're really just predicting fine details.
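Here is a sketch of that per-scale residual loop. As clarified later in the talk, this all happens in the VQ-VAE's latent space, so "full resolution" below means the 16x16 latent grid rather than 256x256 pixels. The shapes and interpolation modes are illustrative, and details of the real model (for example, the small convolution I believe is applied after upsampling) are omitted.

```python
import torch
import torch.nn.functional as F

def encode_multiscale(f, codebook, scales):
    """f: (C, H, W) latent feature map -> list of per-scale token-index grids."""
    C, H, W = f.shape
    residual = f.clone()
    all_tokens = []
    for s in scales:
        # downsample what is still unexplained to s x s (simple interpolation)
        small = F.interpolate(residual[None], size=(s, s), mode="area")[0]
        # quantize each of the s*s latent vectors to its nearest codebook entry
        vecs = small.permute(1, 2, 0).reshape(s * s, C)
        idx = torch.cdist(vecs, codebook).argmin(dim=1)
        q = codebook[idx].reshape(s, s, C).permute(2, 0, 1)
        # upsample the quantized map back and remove it from the residual,
        # so the next scale only has to encode what is left over
        up = F.interpolate(q[None], size=(H, W), mode="bilinear", align_corners=False)[0]
        residual = residual - up
        all_tokens.append(idx.reshape(s, s))
    return all_tokens

tokens = encode_multiscale(torch.randn(32, 16, 16), torch.randn(4096, 32), [1, 2, 3, 4])
```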
Looking at the chat, the message count keeps going up, and I apologize if I should be answering questions or if you guys are just super chatty, but there are like 44 messages. If it's important someone will interrupt — that's all okay. There was a question about positional encodings, which are always fun for images: I don't actually remember what they did for positional encodings off the top of my head, sorry — I could go look at the paper real quick — but yes, there have to be positional encodings.

And I think you mentioned that in the last step they're predicting the deltas from the previous step, right? I wonder, do you have an intuition for why a VQ-VAE would work better here as opposed to a residual VAE? It seems like a residual VAE would fit predicting residuals, but I don't know.
I'm not familiar with the residual VAE you're describing. The idea here is that you have the VAE, but in this case we're quantizing it, so you have two forms of error, so to speak. One is that you've downscaled this a lot and then upscaled it back, so you've lost a lot of information there. The other is that when you downscaled it, you then quantized it, so you moved it a little, and then you upscaled that moved version back. So I don't know how to answer your question, because I don't know that other residual VAE, but I can tell you that the combination of those two errors is what the next iteration is trying to encode. Gotcha, thank you.

So the two-by-two is going to downscale, quantize, and create both kinds of error when you upscale it back; then the three-by-three looks only at the leftover errors from levels one and two; the four-by-four looks at the leftover errors after those three; and so on and so forth.
so forth so um so this this vqvae one of the key things though is that the code book used 00:26:29.460 |
when you do level one versus the last level the 16 by 16 it's using the exact same code book 00:26:37.700 |
so uh i've seen different people ask you you potentially could have a more accurate quantizer if 00:26:46.740 |
you use customized code books for the one by one level versus the 16 by 16 level but since we want to 00:26:53.540 |
feed all of these into our same gpt model that's why we we're forced to use um same code book for for 00:27:01.300 |
all the different levels and and so basically then uh um uh the the vqvae when it's learning its code book 00:27:13.540 |
has to decide on uh uh codes that work well both at the high res and and at the low res and it has 00:27:23.060 |
to come up with some sort of compromise because these these these codes are going to be shared 00:27:27.780 |
everywhere in my head my intuition is that even though the the code book has to be shared and it may 00:27:34.660 |
not be optimized for low and high res still what's happening is you're going to get these broad global 00:27:40.980 |
features at lower resolutions you're going to get the sort of low frequency information um encoded 00:27:48.820 |
and as you get to the last and that's where the high frequency features so the the details of the grass 00:27:55.140 |
you know the the texture of the fur that's going to not happen until until probably quite late in the in 00:28:02.740 |
This next slide I wasn't planning to spend a lot of time on, but what you can see from this algorithm is that you input an image and it loops — in this case ten times, and it doesn't matter whether it's ten or sixteen — going from the one-by-one up to the 16-by-16. The key thing is that they have this queue they keep appending to, so you get all of the embeddings at the different scales at once. This isn't a separate procedure per scale: you do the multi-resolution quantization in one fell swoop and get all the resolutions simultaneously out of that for loop. It's a package deal — a dedicated multi-scale VQ-VAE that gives you all your resolutions at once.

Reconstruction also has to be an iterative process. You cannot say "let me reconstruct at the eight-by-eight level" without the information from the other scales, because this is a residual process; you need all the earlier ones, and you sum the predictions from all the earlier scales with your current prediction to get the final one. It's just like in an LLM: you can say what the output of just the eighth layer is, but you can't say what the residual stream looks like after the eighth layer unless you also have the first seven layers. Same thing here — you can't reconstruct the eighth resolution unless you also have the earlier ones.
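And here is the matching reconstruction sketch — again illustrative rather than the paper's code: look up each scale's tokens in the shared codebook, upsample each quantized map to the full latent grid, and sum them all before running the frozen VQ-VAE decoder. The sum is why you need every scale (the "residual stream" point above): any single scale on its own only holds part of the latent.

```python
import torch
import torch.nn.functional as F

def reconstruct_latent(all_tokens, codebook, latent_hw=16):
    """all_tokens: list of (s, s) token-index grids, coarse to fine -> (C, H, W) latent."""
    C = codebook.shape[1]
    f_hat = torch.zeros(C, latent_hw, latent_hw)
    for idx in all_tokens:
        s = idx.shape[0]
        q = codebook[idx.reshape(-1)].reshape(s, s, C).permute(2, 0, 1)
        f_hat = f_hat + F.interpolate(q[None], size=(latent_hw, latent_hw),
                                      mode="bilinear", align_corners=False)[0]
    return f_hat   # the frozen VQ-VAE decoder then maps f_hat back to pixels
```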
Just to clarify what the training process looks like: the multi-scale VQ-VAE is trained separately. In this case it was ImageNet, so you give it a bunch of ImageNet images and it tries to encode them so that when they're decoded, the L2 loss is as low as possible. It does that as best it can, and then it's frozen; the actual LLM isn't involved in any of that training. Then, given the fixed VQ-VAE, you train the actual LLM on this scheme: give it one patch and it predicts four; give it one plus four patches and it predicts nine; then sixteen, then twenty-five, until at the very end, for the 16-by-16, it has the context of all the earlier resolutions and predicts 256 tokens in one fell swoop.

It was trained on ImageNet, so if you're wondering what the prompt is at the very beginning: there's one token that encodes the class. There are a thousand classes in ImageNet, and — I don't remember exactly, I think it's embedded — you start with a class token and it predicts the one-by-one; then you have the class plus the one-by-one and it predicts the two-by-two, and so on. This repeats until you get the final 256 tokens, and then for your actual image you need to sum all of this up and have the decoder turn it into pixel values, because this is still a very special VQ-VAE and you still need its decoder. All of this multi-scale stuff happens in latent space: we're not predicting pixels, we're predicting latents. The tokens are entries in our VQ-VAE vocabulary, and that vocabulary lives in the latent space.
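Tying that together, here is a rough sketch of the generation loop — not the released code. The helpers `predict_scale`, `reconstruct`, and `decode` are placeholders I'm assuming for the trained transformer step, the latent summation sketched above, and the frozen VQ-VAE decoder; their exact signatures are inventions for this illustration.

```python
def generate(class_id, predict_scale, reconstruct, decode,
             scales=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)):
    """predict_scale(context, s) -> an (s, s) grid of token ids, sampled in one
    parallel step; reconstruct(grids) -> latent feature map; decode(latent) -> image."""
    context = [("class", class_id)]          # the prompt is a single class token
    token_grids = []
    for s in scales:
        grid = predict_scale(context, s)     # s*s tokens come out in one step
        token_grids.append(grid)
        context.append(("scale", grid))      # becomes context for the next scale
    return decode(reconstruct(token_grids))  # sum in latent space, decode once
```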
And that's pretty much it. You can see some of the results here: they got really good image quality, if you're familiar with FID and Inception Score. They also point out that because they can do inference in chunks, they need many fewer steps than naive autoregression: one patch at a time would be 256 inference steps, whereas they go from one-by-one to 16-by-16 in just ten steps. Maybe those individual steps are a little more expensive each — I haven't really done the math — but you can imagine that ten inference steps is still going to be a lot cheaper than 256.

They also talked a little about scaling. If you look at some of the other models, performance improves as they get bigger and then sort of hits a wall. They say theirs didn't hit a wall, but they didn't run this over a very large number of orders of magnitude, so it's still TBD whether that holds as you really scale up. I don't blame them for not having the resources — they're not a Google — but while the fact that it didn't hit a wall is good, it's not really proven until you get to something more like the seven-billion-parameter scale. Compare to DiT, for example: if you looked at only the first three data points you'd say this thing scales great, and then you get the fourth data point and it just fundamentally, architecturally hits a wall. So they're saying VAR hasn't hit a wall, but we still don't know that it won't on the next data point. The ultimate proof is in the pudding: until you try to build a GPT-4-sized thing — maybe for images it's not GPT-4 size, but the same concept — you just don't know whether you'll hit a wall. It's good to be asking that question; I just don't think they did it over a large enough spread.
Here are some sample images. The left is earlier in training, the right is later in training, and going across the rows they scale up the size of the LLM they used inside VAR. Obviously the bigger LLM does better; I don't know that the early-in-training samples matter much, but it does learn, and late in training you get better-quality images — not surprising.

So to conclude, the key things they're selling us on, why they think this technology is good: it's fast — much faster than more naive autoregression; it generates very high-quality images; and they say it has the right inductive bias, going from low scale to high scale. There's a very clear 1D ordering — nobody would argue that medium res should somehow come before low res, so everyone agrees on the inductive bias. Whether that inductive bias proves to encapsulate everything we need, I don't know. I would not necessarily have guessed, before GPT-2, that next-token prediction was sufficient to generate really complex things like mathematical proofs, so I don't have the ability to guess whether low-res-to-high-res carries enough information either. But because it has an intuitive inductive bias, it does mean you can do inpainting and outpainting — you can fill in any direction. With a fixed autoregressive patch order you can't: if you do raster scan, give it just the bottom of the image, and ask it to fill in the top, it can't — it can only go from top to bottom. This one, because of the way it works, can go in any direction: in, out, masked regions, whatever. And ByteDance did share the code.
They also have at least two follow-up papers that I've seen — maybe there are more. They had an Infinity paper, they've done video, and they've done text-to-image, since this paper was just ImageNet; that's a pretty obvious next step. And there's now an xAR paper instead of VAR, where they generalize: instead of just scale, you can have arbitrary orderings. They also added something on top of the autoregressive learning — some kind of matching objective, it seems — that I haven't quite worked my way through.

One other thing: if you're reading the paper, there's a lot of terminology, so I actually made a cheat sheet; I'll share the link to the PDF. Basically, they have images, and then they talk about their f; you're doing this on individual patches at individual scales; then you have your tokens, these q's, which combined together make an r — or rather a latent r — at a given resolution; and when you're decoding, you go back through your f-hats and eventually get your reconstructed image. So hopefully that's useful.
eventually you get your chat so that's useful all right so i don't know if you guys want to ask me 00:38:43.540 |
questions or if there's other discussion you guys want to have um i mean first off thank you i mean it's 00:38:52.260 |
great uh that you volunteered in the last minute and i've charted this uh this entire session uh and also 00:39:01.060 |
also thanks for the pdf file it looks uh the all the slides look amazing 00:39:04.660 |
i i'm just thinking in terms of does it even make sense to think or rationalize var 00:39:12.740 |
against how diffusion models work is that going to help me because some of these are rat hole 00:39:19.300 |
things that i don't want to get into but do you think in your based on your experience does it help to understand or rationalize var 00:39:28.020 |
var by comparing it against diffusion models because even in the diffusion models you have a sense of 00:39:33.700 |
uh resolution increments that you see as as uh in practice although behind the scenes uh the the thing 00:39:44.740 |
the thing that is driving is actually the uh differential equations 00:39:47.780 |
um i think there's there's one thing that i think personally is a very important trend from var 00:40:02.820 |
and that's the idea of predicting multiple tokens at once um 00:40:07.380 |
the when when we do next token prediction we're sort of fixing the information content first step 00:40:21.300 |
okay and you know for me if if you have words like the and whatever um there's not a lot of information 00:40:29.620 |
content there but if you're processing code or you're processing like what is you know 12 times 15 or 00:40:37.460 |
whatever and you're outputting digits or something like that there's like really high information 00:40:43.220 |
density and and you cannot get it slightly wrong or you're just dead um and so whether it's scaling up 00:40:51.940 |
or scaling down or whatever but this idea that we can have variable information content so i think the the 00:40:57.620 |
I think the byte latent transformer paper from Meta, and to a certain extent the large concept model — one thing all these papers have in common is that they're addressing the idea that information content may not be uniform token by token. So the idea that, when we're doing reasoning or other things, we can embed things with multiple tokens and get more information content into one inference step — I think there's a huge opportunity there to make transformers and language models faster, or reasoning better, if they have the ability to dial that up and down. So I think there will be a lot of successful research playing with predicting multiple tokens at once. The xAR paper already says that maybe scale is not the one and only way to do it, but they're exploring the same idea.

Diffusion has really strong mathematical foundations, but we know transformers are really powerful, so there's no reason the transformer can't learn the score function that diffusion models are learning. The real question — which I don't know the answer to — is just who can do it in the fewest steps. It seems like if a diffusion model, using a U-Net or replacing the U-Net with a transformer, can learn the score function, then if you gave a transformer something similar — a progression of images — you should be able to learn that same score function using attention. So I think this is something to pay attention to, but not necessarily literally the scale technique these guys use. I'd pay attention to this idea that multiple groups are approaching, of changing our tokenization, whether it's byte latent, LCM, this, or something else.
Sorry — just to answer the question more directly, at least from my understanding, on whether understanding diffusion helps here: to me, both diffusion and this scale-prediction approach are examples of an earlier idea — I forget exactly whose paper — that you don't actually need to solve an ODE if you can train a model to denoise an arbitrary corruption. That corruption can be adding Gaussian noise, or in this case you can interpret next-scale prediction as a corruption run backwards: you have a high-resolution image, you've corrupted it by downsampling, and you're trying to predict backwards. All of these are viable ways to think about the image-generation process. In this case it's a combination of using an LLM, with attention, to predict the next scale rather than a raster scan, plus the fact that you need some sort of tokenization to use an LLM in the first place. Other than that it's kind of similar, and it seems to work really well. Yeah, thanks.
Regarding the point about multi-token prediction: Meta put out a paper last year about training LLMs with multi-token prediction — they did it during training, not inference — and they mentioned in Llama 3 that they use this technique. Ted, basically all your points there about sample efficiency are spot on; they go into that. In terms of inference time, this is kind of what motivated a lot of the work behind speculative decoding: speculative decoding is basically multi-token prediction with a small model, and that idea came out of research where they started with "let's just predict multiple tokens and see how that goes," which led down that path.
I have a question on the training time of this. The approach seems clearly very fast at inference time, because it can predict so many tokens in parallel, but did they mention how long it takes to train, also in comparison to, for example, diffusion models?

That's a great question. I don't remember them discussing it in the paper — or maybe I just didn't pay close attention.

These are actually a lot bigger models than diffusion models, and their inference is a lot slower — if you try 4o image generation, it takes on the order of 15 to 20 seconds, and it's significantly larger. The big thing with autoregressive generation is that you scale up a lot more than you do with diffusion; we have really small local diffusion models — your iPhone can do Genmoji diffusion locally — but autoregressive models, just at the base layer, are larger than diffusion models right now. And the inference tricks like LCM-LoRA, where you skip a bunch of steps for diffusion, haven't really been applied to autoregressive image generation: VAR-style autoregressive image generation doesn't have speculative decoding yet — or maybe it does — but there's a lot of inference optimization that hasn't hit yet. So there's room to grow, but at the base level we expect diffusion models to stay small and autoregressive models to get big as they generalize.

Yeah, I think the optimization point that Fibo shared is key. Compare this to the very first diffusion models, which were doing hundreds or thousands of steps — we've improved that dramatically — so if this is a useful technique, you'd expect a bunch of optimizations and improvements on top of it.
This is just day one. A minor point, I guess: you mentioned 4o image generation being slow. They're clearly doing VAR, but I'm not sure they're doing the same VAR, because you can imagine that if you want to do text-to-image, or this kind of multimodal thing, you want to fine-tune your text model to take in the same tokens and teach it to do the task all with one model, rather than with separate models. That could be the reason: if you just had a dedicated VAR model, you could imagine it being pretty good.

Anyway, a question about this — I think we had a short thread in Discord about it. It doesn't really seem practical to completely jointly train one model; if you have the 4o architecture or something like that, it probably has to be a post-training thing, right? And then you need some way to integrate that with the separate VAR architecture you've already trained. I think RJ in Discord — I'm not sure if he's here — had a plausible explanation: you have some reserved tokens that you use to delineate the image tokens from the text tokens, and then the only thing you need to do in post-training is have the chat model learn to say "okay, now this is an image"; it gives you text tokens, you take those, you have a text-conditional VAR, and that generates the image, or something like that. Which kind of pencils out to me, but I'm very curious how they practically do something like that.
But I don't know — I think you could do it with tool use like you're describing, but people fine-tune on tasks all the time. For example, when you want to add a tabular or numeric modality, or an image modality, to a language model, you do your pre-training on a big text corpus, and then you have a much smaller modality corpus that you use to train your encoder to do the quantization, so that you actually have image tokens; then you fine-tune your model to understand image tokens and text tokens in the same sequence. You don't actually need to co-train everything — it's a post-training thing — and you hope you don't lose the text capability when you gain the image capability. I think Apple has a bunch of papers on basically anything-to-anything prediction — I think it's called 4M or something — though I don't know how well it works. Meta has Segment Anything, and there's another one with six modalities in one: audio, video, depth. Meta also has Chameleon, which is natively trained — unlike LLaVA, it's a native image-language model. I guess Apple might too; I'm thinking of the 4M series of papers, which I think is Apple's, but again I don't know how well it works. Anyway, I don't know how anything works in practice at these scales, but that's my guess. Regardless, you need some sort of tool-use-like signal to say "okay, now this is an image," then you generate the image tokens, end the image, and keep generating text, or something like that. Is that right?

You may need that just for the UI, essentially; the LLM itself may not care. And whether you use tool use to dispatch to a separate model — that's clearly how it used to work, because you could basically trick one model into generating some description that would then fail downstream. But if it's all one model, it's all in the same space, and that allows you to do image understanding, whereas previously these models didn't understand what was in the images they generated themselves. So that's maybe a way to test it: if you ask it to generate an image of, say, a tech bro with no glasses and it generates glasses anyway, you can ask it whether the image contained glasses.
Yeah, that's an issue, and I'm sure it does know that.

Just to move this into a different track of conversation: in the case of diffusion models, lots of things have evolved — ControlNets, LoRAs, image prompts, and whatnot. I wonder what would happen in this universe to get that kind of control. I work in the movie industry, and one of the things we care a lot about is consistency — character consistency and so on — and ControlNets are the current way of dealing with that. Do you know if there's any literature around VARs along those lines?

I don't think there's any literature yet, but OpenAI lets you refine the image, and there's clear character consistency when you do multi-step refinement of your image, so it seems like they've definitely had that in mind. They actually call that out in the blog post about the 4o image capability; what I recall is that they claim it's because of the autoregressive nature of it — you have the history, so the attention can attend to the previous image, and you kind of get consistency for free. That was my interpretation of what they were saying. And that suggests that it is one model,
which is cool — or I guess it could be one chained model.

Yeah, so my point in the Discord was actually that it's unclear to me whether there are reserved tokens that it's trained on in post-training, or whether they're reusing the text tokens with some sort of special token to say "okay, now I'm generating an image" and "now I'm not." I don't think it matters too much, but you might lose less of the text capability if you use reserved tokens.

If I had to bet, I'd bet on them using reserved tokens, because you don't need that many — it's a pretty small vocabulary for the codebook — and, exactly to your point, you don't want any forgetting.

Well, normally you don't reuse tokens — I wasn't aware that people ever do. You have your dictionary for text tokens, and then you have your implicit dictionary, through the VQ-VAE encoder, that is now your image dictionary, and they don't overlap. You just use those, and the model learns to naturally use whatever it needs.

I think there were some early multimodal LLMs where you had a beginning-of-image token and then reused the same space, but yeah, I think we're all saying the same thing: that's probably not the conventional wisdom now. You just use separate tokens, and then you can have, hypothetically, an orthogonal corner of your embedding space.

Yeah, but those tokens — I totally agree — are delimiting tokens. For example, when you train an LLM for fill-in-the-middle tasks, even though it can only generate one way, you add new tokens that say "now you're doing the fill-in-the-middle task, here's the beginning, here's the end, now start generating the middle." You create new tokens for those delimiters, but the model still learns to use the right vocabulary for the content itself.
Well, thank you so much, Ted — we were just going to have a casual discussion, and slides are always appreciated; good to have you here. For next week: does anyone have a topic they want to discuss, want to volunteer for, want someone else to volunteer for, or want to do half a paper on?

I want someone to volunteer for the stuff that just came out of Anthropic.

Okay, I might do that, since I've done the other three — unless someone else wants to do it, I'll do the fourth one. Awesome. I might get someone from Anthropic to join us, we'll see; otherwise, next week is Anthropic. If anyone hasn't seen it, my favorite meme: Matthew Berman said this is the most surprised he's ever been, but he was just as surprised in the thumbnail as always. Very sad. We have a running Discord thread of all his thumbnails. Okay, I'll share — I'm going to make the thumbnail for this YouTube video. Alright, we've got the thumbnail covered. Cool, thanks guys, see you next week — and thank you, Ted.