
How does 4o ImageGen work? Visual Autoregressive Modeling paper - Best Paper @ NeurIPS


Whisper Transcript

00:00:00.000 | yeah we're good all right you guys see my screen yeah okay so for the var paper
00:00:10.460 | just a quick overview so previously people said hey llms work great let's
00:00:20.780 | try doing autoregressive modeling on images and you can do this at the pixel level or you can do this
00:00:28.940 | at the patch level and the most obvious first thing you might do is just raster order so i'm going to
00:00:33.980 | start at the top left corner i'm going to go across the top and keep going and you can imagine
00:00:41.480 | that this is great because there's a natural sequencing for language but for images
00:00:47.440 | it feels a little wrong right because if you're in the middle of the image then
00:00:55.560 | your context is all the pixels above you but you have no context for the pixels below you
00:01:00.320 | right and so what the var paper is going to suggest is that if you go from low res to high res
00:01:08.600 | then this provides a natural and good inductive bias and so this is just my overview slide obviously
00:01:15.380 | we'll get into the details of how they do all this but that's the big change rather than just
00:01:21.040 | switching from raster order to a spiral or a space filling curve those all still fundamentally have
00:01:25.980 | the problem that when you're halfway done half of the content half the pixels are in your context but
00:01:31.100 | half the pixels are not okay whereas if you go from low res to high res then what you say is all of the
00:01:37.180 | pixels but only at a low res are in my context and so initially we'll get global features like
00:01:43.940 | it's a picture of a dog and then you'll have the background is blue sky above green grass below
00:01:49.620 | no high frequency fine details you'll just have that high level stuff
00:01:56.100 | and so in order to implement this what they used is variational autoencoders and
00:02:03.700 | actually they use vector quantization so it's a vqvae and they purposefully decided to use a gpt-2 like
00:02:14.500 | llm which actually i think is one of the strengths of the paper they didn't go for the best
00:02:20.340 | possible transformer architecture they said we want to show that what worked was the low to high scale
00:02:30.020 | technique not that we had a really awesome transformer okay they did make a few
00:02:35.460 | changes to things like the normalization but they tried not to just say let me get the best
00:02:40.340 | and one of the consequences of the way they went from low scale to high scale
00:02:46.420 | is if i'm at a medium scale and i'm going to predict the medium high scale image
00:02:52.020 | that medium high scale image is multiple tokens so even though we are still going left to right
00:03:00.580 | linearly we're not going one token at a time we're actually going multiple tokens at a time and
00:03:06.100 | again this is just the overview and we'll see the details in a minute so i'll pause just to see if there are
00:03:14.100 | any burning questions before we get into the details of how they do this all right yeah i'm
00:03:22.580 | trying to explain vae for those who are newer to the concepts i have a slide on it
00:03:27.540 | although you can explain there because it's just one slide it's not really explaining it helps with
00:03:33.140 | images yeah yeah okay so again i think i kind of covered this but you can imagine so we have llms and
00:03:40.660 | the diagram from the paper says yeah here we have this nice clear ordering the cat sat by blah blah
00:03:45.780 | blah and you could take an image you could break it into these nine patches and then if you go in raster order
00:03:50.900 | then you go one two three four five right so again the problem is when you've seen
00:03:57.620 | one two three four five and you're predicting patch six then patch three is in your context
00:04:04.580 | and so you can say okay yeah but you don't have patch eight in your context and that's a little
00:04:09.380 | bit unnatural right and you can change the ordering like i said but any 1d sequence is still going
00:04:14.820 | to fundamentally have this problem so what they say is let's have lower res and higher res images
00:04:22.580 | and now there's a logical ordering so you can see what they call r3 here maybe you can
00:04:30.020 | maybe you can't tell that it's a bird a parrot at this point but if you knew it was a parrot and i asked you
00:04:35.620 | is the parrot facing to the right or the left at r3 you could probably say yeah i think this is the beak
00:04:41.780 | and so i think it's facing to my right okay and so you could imagine that at r4 you're getting this
00:04:49.700 | global context you know what color is the main body of the bird it's probably blue right there are things
00:04:55.620 | that you can learn without working out the fine details and so as you keep going more and more
00:05:01.860 | local information is available and so now if you're actually trying to predict specific pixels on
00:05:08.900 | a portion of the beak you have a lot more local context in addition to just the
00:05:16.340 | general global context so this is a nice concept but it still raises the question
00:05:24.420 | how do we actually implement this idea of going from low res to high res by the way just stop me at
00:05:32.660 | any point and ask questions i can try to monitor the chat but it's not always easy when i'm screen
00:05:38.420 | sharing so i apologize so somebody else please give a shout if there's a good question in the chat
00:05:44.580 | that i'm missing i have a question here what i'm not sure about is this i
00:05:55.940 | can understand what a token is in a sentence but what would be a token in an image is it
00:06:04.340 | some pixels because you can have lots of different
00:06:14.180 | it would be really huge it's the perfect question that's the perfect question
00:06:20.820 | okay i have this idea that i want to go from low res to high res what the heck is the token how do
00:06:26.180 | i tokenize this what are the pieces so that's what we need to get
00:06:35.620 | into next that's the problem we need to solve this is just a high level description but it doesn't
00:06:40.980 | actually make it obvious at all how you pull this off so perfect question
00:06:46.740 | okay so a couple of concepts that we're going to use in order to explain the tokenization scheme
00:06:56.340 | okay so one is the autoencoder right so here we have a simple autoencoder and you take an image
00:07:05.060 | you pass it through some cnn layers you have a bottleneck in the middle and this can be considered
00:07:09.460 | your embedding this is a deterministic embedding so you can't use a simple autoencoder as a generator
00:07:17.780 | okay so we have variational autoencoders which take the input and you may have cnn layers and
00:07:25.220 | whatnot up front but ultimately instead of getting a single embedding what you're doing is
00:07:30.980 | you're creating a probability distribution okay so generally speaking what
00:07:39.780 | we're doing is we're transforming it to something simple and so here in this slide it's
00:07:45.700 | a multi-dimensional gaussian distribution with diagonal covariance okay and what we want is we
00:07:55.220 | want the encoder to basically transform our data distribution into this simple gaussian distribution
00:08:01.060 | which we know how to sample from and then when we do generation the real power of this variational
00:08:11.540 | autoencoder is that the decoder knows how to do the reverse transform basically to take something
00:08:18.340 | sampled from this multi-dimensional gaussian and map it into the data distribution so i don't know
00:08:26.260 | this is all i have i wasn't really covering vaes in this so if you want to add any color about vaes
00:08:37.780 | it looks like anton's giving and people are talking about vqvaes but my understanding
00:08:43.380 | stops at vae and the only thing i'd add is that this is basically what led to latent diffusion
00:08:49.460 | yeah so vaes were used for image generation standalone i don't know how many years ago right
00:08:58.500 | like four years ago or whatever and they weren't very good at the time they were very
00:09:03.940 | interesting but ultimately you got some blurry images and you can do things to clean up the fact
00:09:08.900 | that the images are blurry so then yes next we're going to go to vector quantization because vaes
00:09:16.340 | actually if i go back this latent vector that you sample is continuous okay so if we
00:09:25.700 | say the dimensionality here is 256 you have 256 floats they're continuous real valued all right
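For readers newer to vaes, here is a minimal sketch of the sampling step described above. It assumes the common parameterization in which the encoder outputs a mean and a log-variance for each latent dimension. `encode` is a hypothetical stand-in that returns fixed numbers instead of running a cnn. The point is that the sampled latent is 256 continuous floats, which a gpt-style model cannot use directly as a vocabulary.

```python
import math
import random

LATENT_DIM = 256  # the example dimensionality from the talk

def encode(image):
    """Hypothetical stand-in for a cnn encoder.

    A real vae would compute these from the image; here we return a fixed
    mean and log-variance for each latent dimension.
    """
    mu = [0.1 * (i % 5) for i in range(LATENT_DIM)]
    log_var = [-1.0] * LATENT_DIM
    return mu, log_var

def sample_latent(mu, log_var, rng):
    """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(sigma^2)) || N(0, I)), the term that pulls codes toward the simple prior."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv
                     for m, lv in zip(mu, log_var))

rng = random.Random(0)
mu, log_var = encode(image=None)
z = sample_latent(mu, log_var, rng)  # training: sample around the encoding
# generation: sample straight from the prior, then a decoder would map it to an image
z_prior = [rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)]
print(len(z), type(z[0]).__name__, round(kl_to_standard_normal(mu, log_var), 2))
```

Every entry of `z` is a real-valued float. Nothing here is a discrete symbol yet, which is the gap that vector quantization fills.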
00:09:32.580 | so for anything where we're going to do a gpt style thing we need a vocabulary we need a discrete
00:09:39.860 | distribution so we're going to use vector quantization and the way i describe vector quantization is it is
00:09:48.420 | just a multi-dimensional version of rounding or quantization and so super quickly if i said you
00:09:59.220 | have floating point numbers between zero and one and you round them to the nearest hundredth okay you'd create a hundred
00:10:04.740 | buckets and basically right we know how you just look at the hundredth place you
00:10:11.460 | truncate the rest of it and you drop things into the buckets so that would just be very simple rounding
00:10:17.380 | into 100 equal sized buckets but if i said to you hey my data is actually not distributed uniformly
00:10:24.420 | let's say it's heights of people and it's kind of bell curve
00:10:29.300 | shaped right but i want good use of my buckets i want my buckets to have roughly equal numbers
00:10:35.460 | of people in them then one thing you might do is you might have narrow buckets near the middle and wider
00:10:40.980 | buckets at the ends and if you do it right then what you'll wind up with is once you see your data
00:10:47.060 | you'll get approximately one percent of the people falling in every bucket so that's going
00:10:53.460 | from just simple rounding to unevenly sized buckets then the last thing you could do is you could say
00:11:00.580 | rather than predetermining those i'm actually going to get some data and i'm going to learn i'm going
00:11:05.540 | to see based on that data how i should apportion my buckets in order to have them all get roughly
00:11:11.460 | equal samples so good so now we basically have a learned hundred bucket thing for one-dimensional
00:11:17.940 | data how do you do that for multi-dimensional data well if you have a vector of length 256 you just
00:11:24.020 | do the exact same thing that you were doing in the one-dimensional case except you use let's say l2 distance
00:11:29.140 | in order to figure out which bucket you're closest to and then you just shift the buckets around until
00:11:35.300 | you get roughly equal numbers in every bucket so for me i just say that vector quantization is just
00:11:42.980 | learned rounding on multi-dimensional vectors and then for me i have the mental intuition
00:11:51.140 | that it's not something super fancy or anything like that yeah i see your raised hand yeah i have a
00:11:58.900 | question so the basic question i think it's also being asked in the chat window is why do we have to go
00:12:04.820 | towards vector quantization because one of the reasons for actually doing a vae is that you have a
00:12:13.140 | continuous distribution so that a subtle change in the input will lead to a subtle change in the output as
00:12:20.660 | opposed to any kind of dramatic changes whereas once you move into the quantization world you lose
00:12:26.580 | that subtlety the proportional change benefit that you actually get from vaes i don't know if
00:12:38.020 | i can answer the question fully i'm not the author of the paper but what i will say is that if we
00:12:48.740 | want to use a gpt style transformer we are going to have a fixed vocabulary
00:12:55.620 | and so we want to discretize this if you're familiar with audio codec models like soundstream or encodec
00:13:01.780 | what they do is residual vector quantization so you vector quantize and there
00:13:08.660 | is this rounding error so you're losing information so then what you can do is you can take the
00:13:13.700 | error between the original and your first quantization and then you can quantize
00:13:20.180 | that error okay yeah and then you can take the leftover from that and you do that four times and now
00:13:26.020 | you've actually got a much more accurate approximation of your original than if you just did
00:13:31.300 | one round of quantization and so we'll see that they don't explicitly use residual vector quantization in the
00:13:38.020 | var paper but the process they use ends up imitating that as they go through the scales
00:13:44.900 | from low to high res they are doing residuals and so it ends up kind of being like rvq
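Here is a minimal sketch of residual vector quantization as just described: quantize, then quantize the leftover error, four times. To keep it self-contained it assumes a toy 2-d vector and hand-picked square-grid codebooks that get finer at each stage. Real codecs learn a separate codebook per stage, and all names here are illustrative. The reconstruction is simply the sum of the codes chosen at every stage.

```python
import math

def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def nearest(codebook, v):
    """Index of the codebook entry with the smallest l2 distance to v."""
    return min(range(len(codebook)), key=lambda k: l2(codebook[k], v))

def grid_codebook(step, half_width=4):
    """Toy 2-d codebook: a square grid of points spaced `step` apart (includes 0)."""
    r = range(-half_width, half_width + 1)
    return [(i * step, j * step) for i in r for j in r]

def rvq_encode(v, codebooks):
    """Quantize v, then quantize whatever error is left over, stage by stage."""
    residual = list(v)
    codes = []
    for cb in codebooks:
        k = nearest(cb, residual)
        codes.append(k)
        residual = [r - c for r, c in zip(residual, cb[k])]
    return codes

def rvq_decode(codes, codebooks):
    """Reconstruction is the sum of the selected entries from every stage used."""
    out = [0.0, 0.0]
    for k, cb in zip(codes, codebooks):
        out = [o + c for o, c in zip(out, cb[k])]
    return out

codebooks = [grid_codebook(1.0 / 4 ** s) for s in range(4)]  # four stages, each 4x finer
v = (1.2345, -2.7182)
codes = rvq_encode(v, codebooks)
for n in range(1, 5):
    err = l2(v, rvq_decode(codes[:n], codebooks[:n]))
    print(f"stages used: {n}  error: {err:.5f}")
```

Because every toy codebook contains the zero vector, adding a stage can never make the error worse. In this example the error shrinks from about 0.37 after one stage to well under 0.01 after four.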
00:13:51.780 | got it thank you and one other question while i'm still here would you be able to make this
00:13:59.060 | presentation available to me or to us so yeah yeah i have a pdf on github i can share the link great
00:14:08.100 | thanks yeah so again if you just think that this is multi-dimensional rounding such that
00:14:18.260 | we get equal numbers in every bucket so that we're making good use of our buckets then what do we call
00:14:24.820 | the list of buckets so right if you use equal size buckets then you just have a formula for
00:14:30.660 | calculating the buckets but if you are now learning them you have to just write them
00:14:35.300 | down you have to store them somewhere and so by convention the list of the buckets
00:14:42.420 | is called a codebook okay and rather than super fancy buckets i think what they do is they just
00:14:50.740 | store the center of the bucket the centroid or whatever you want to call it and so then
00:14:56.260 | you're not actually defining the upper and lower bound of the bucket you're just saying
00:15:01.140 | here's the middle of this bucket here's the middle of that bucket and then when you want to quantize
00:15:05.860 | something what you do is you actually compare it to every bucket and you see which has the lowest l2
00:15:13.060 | distance and then that's the bucket that you quantize it to that's the one that you round it to
00:15:19.540 | so it is somewhat expensive if you have 4096 buckets it is order 4096 comparisons
00:15:28.580 | to quantize something you have to compare to every one of those and then you say yeah it was closest
00:15:33.700 | to bucket 17 so i'm now going to throw it in bucket 17.
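To make "learned rounding" and the codebook lookup concrete, here is a minimal sketch. It learns a small codebook with plain k-means on toy 2-d data, then quantizes a vector by brute-force l2 comparison against every entry, which is the order-4096-per-lookup cost mentioned above, just with 3 codes. This is a swapped-in stand-in for illustration. A vqvae learns its codebook jointly with the encoder (via gradients or a moving average), not with a separate k-means pass. K-means also minimizes squared distance rather than explicitly equalizing bucket counts.

```python
import random

def sq_l2(a, b):
    """Squared l2 distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def quantize(codebook, v):
    """Brute-force lookup: compare v to every code and return the closest index.

    The cost is O(K * d) for K codes of dimension d; the index is the token id.
    """
    return min(range(len(codebook)), key=lambda k: sq_l2(codebook[k], v))

def farthest_point_init(data, k):
    """Deterministic start: repeatedly add the point farthest from all codes so far."""
    codebook = [list(data[0])]
    while len(codebook) < k:
        far = max(data, key=lambda p: min(sq_l2(c, p) for c in codebook))
        codebook.append(list(far))
    return codebook

def learn_codebook(data, k, iters=20):
    """Plain k-means: each code moves to the mean of the points assigned to it."""
    codebook = farthest_point_init(data, k)
    for _ in range(iters):
        members = [[] for _ in range(k)]
        for p in data:
            members[quantize(codebook, p)].append(p)
        for j, pts in enumerate(members):
            if pts:  # leave an empty bucket where it is
                codebook[j] = [sum(coords) / len(pts) for coords in zip(*pts)]
    return codebook

rng = random.Random(1)
centers = [(0.0, 0.0), (5.0, 5.0), (0.0, 5.0)]  # toy 2-d data: three blobs
data = [(cx + rng.gauss(0, 0.5), cy + rng.gauss(0, 0.5))
        for cx, cy in centers for _ in range(100)]
codebook = learn_codebook(data, k=3)
token_id = quantize(codebook, (4.8, 5.3))
print("token id:", token_id, "code:", [round(c, 2) for c in codebook[token_id]])
```

The integer `token_id` is what a gpt-style model would actually see. The stored code vector is only needed when decoding back.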
00:15:43.220 | all right so now that we have a little bit of information on vaes and vqvaes the
00:15:50.980 | question is how do we tokenize how do we fit images into an llm okay and so the solution in the var paper
00:16:00.100 | is we're going to break the image into patches and each patch is going to go through
00:16:07.940 | our vector quantizer our vqvae but what we're going to do is if it's a low res image
00:16:16.980 | we're going to give it fewer patches and if it's a high res image we're going to give it more patches
00:16:23.940 | specifically in the var paper they worked on imagenet 256 by 256 color images
00:16:33.060 | and they said the high res in their case is 16 by 16 which is 256 patches and the lowest res is one patch
00:16:42.660 | so low res is very very low res it would really just be like the average color
00:16:50.740 | of the entire image right just the mean so that's extremely low res and if you think
00:16:59.220 | about it 16 by 16 is actually not that high but it's a small image it's only 256 by 256
00:17:05.700 | okay so they actually sped it up a little bit they didn't use every single possible size in between
00:17:14.020 | so you have a one by one as your lowest res then a two by two and then a three by three and then a
00:17:20.180 | four by four and then at some point they got to eight and they jumped to 10 and then they
00:17:24.420 | jumped to 13 and then 16 but the schedule they used is not particularly
00:17:29.540 | important if you just did it naively you would say i have 16 steps from one to 16. they had
00:17:35.860 | i believe 10 steps if you look at the code for the var paper all right so basically that
00:17:45.780 | means that the lowest res image is going to get one token as its embedding and the token id is just
00:17:55.460 | the bucket number from our vector quantizer right so i was saying if you compare it to all
00:18:02.740 | 4096 and 17 is the closest one then literally what you do is you just say 17 is my token number and then
00:18:13.300 | gpt-2 right it goes through the embedding layer that turns it into a dense vector it goes through however
00:18:19.220 | many layers was it 12 i don't remember if they even used gpt-2 small medium whatever it goes through 12 transformer
00:18:24.820 | layers and then out pops your next token prediction you guys all know llms right so basically
00:18:36.660 | the bottom point here is normally our llm is predicting the next token but if we're predicting
00:18:43.060 | the next higher resolution now that's multiple patches so when i start with just one patch i'm going
00:18:49.220 | to predict the next image which means i'm actually predicting four tokens at once and then when i have
00:18:56.580 | the first and the second images i have five tokens in my context and i'm now going to predict the
00:19:02.340 | next nine tokens all at once and then when i have 14 tokens in my context i'm going to predict the next
00:19:09.300 | 16 tokens all at once okay and basically what you can do is if you just change the shape of your
00:19:18.260 | autoregressive mask so instead of being purely triangular it's kind of block triangular
00:19:24.340 | then what you can do is you can say when i have five tokens in my context and i'm predicting nine more so
00:19:31.140 | that's tokens six through 14 i'm going to change my mask so that tokens seven
00:19:38.660 | through 14 can still only see the first five they cannot see token six okay so tokens six through
00:19:48.180 | 14 all have an equal number of keys that they can attend to and an equal amount of
00:19:53.540 | information the fact that they are in a particular order gives them no extra information
00:20:00.820 | because we are creating the mask block-wise so that yes in fact the very last
00:20:10.100 | token the 14th token still can only attend to tokens one through five it cannot see any of the earlier tokens from
00:20:17.220 | its level
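Here is a minimal sketch of the bookkeeping behind next-scale prediction. It assumes the 10-step schedule 1, 2, 3, 4, 5, 6, 8, 10, 13, 16, which is taken here as an assumption about the released code, and it leaves the class token that starts generation out of the counts. Each position is labeled with its scale index. A query may attend to a key only if the key's scale is no finer than its own, which gives a block-triangular mask instead of the usual triangular one.

One nuance relative to the description above: with this style of mask, positions inside a block can also see each other. That leaks nothing provided the inputs placed at block k are built only from coarser scales, which is an assumption here about how the inputs are constructed.

```python
PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # assumed side length of each scale

def scale_ids(patch_nums):
    """Scale index of every position once all token maps are flattened in order."""
    ids = []
    for k, n in enumerate(patch_nums):
        ids.extend([k] * (n * n))
    return ids

def block_causal_mask(patch_nums):
    """mask[q][j] is True when query position q may attend to key position j.

    Allowed iff the key's scale is no finer than the query's, so every position
    in one block sees exactly the same set of keys: a block-triangular mask.
    """
    ids = scale_ids(patch_nums)
    return [[ids[j] <= ids[q] for j in range(len(ids))] for q in range(len(ids))]

context = 0
for step, n in enumerate(PATCH_NUMS, start=1):
    print(f"step {step:2d}: {context:3d} tokens in context, predict {n * n:3d}")
    context += n * n
print("total tokens:", context)  # 680 for this schedule

small = block_causal_mask((1, 2, 3))  # 14 positions, easy to inspect
print([sum(row) for row in small])    # 1, then 5 (x4), then 14 (x9)
```

The printed schedule reproduces the counts from the talk: 1 token in context predicts 4, 5 predict 9, 14 predict 16, and so on.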
00:20:18.900 | okay and then in practice what we do is if you're familiar with prefill versus
00:20:28.340 | decoding right in the prefill stage we are inferencing multiple positions at once
00:20:35.540 | here there's no reason not to do that as well you could do them one at a time but it would just be
00:20:41.220 | less efficient but that's just an implementation issue okay whether you inference the nine tokens
00:20:46.260 | one at a time or you do them all at once because of the attention mask it's exactly the same
00:20:51.220 | all right so now the thing we're going to do is we're going to modify
00:21:02.660 | our vector quantizer a little bit and we're not going to just use a vanilla vqvae there's a
00:21:10.500 | special vqvae that was trained specifically for var and ultimately they tried it both ways they
00:21:20.980 | said i will just take this image and when i want low res they just did linear interpolation okay there's
00:21:28.660 | nothing fancy here whether you're using the cv2 function or the pytorch built-in function
00:21:35.860 | it's literally just linear interpolation so i start with my image and when i want a one by one
00:21:41.780 | you just basically say resize or linearly interpolate down to one by one and it gives you you know the
00:21:48.980 | average and if you say two by two it's just the simple thing so there's no
00:21:55.860 | fancy learning going on here so they tried it where they said i'm just going to do
00:22:03.220 | one by one two by two three by three so on and so forth and every one of those i'm going to
00:22:08.420 | run my patches through the vector quantizer that worked but what they found worked a little bit better
00:22:14.180 | is after you have the one by one you can project that back up to your full size image 256 by 256 and of
00:22:25.620 | course it's a little bit more complicated but you can imagine that you're then just going to
00:22:30.820 | get a solid color for the whole thing because you only have one data point so when you
00:22:36.340 | upscale that you're just going to get this flat uniform thing what they do is they say for the two
00:22:44.340 | by two instead of also predicting the original image i'm going to subtract what the one by one predicted my
00:22:52.420 | full image was going to be so if you have this mean color that you've
00:22:56.900 | predicted i'm going to subtract that from the image and that's what i'm actually going to
00:23:02.020 | quantize for my two by two so this is where the part
00:23:07.140 | i mentioned earlier comes in about it being a little bit like residual vector quantization right so then after
00:23:14.020 | you have a two by two that's been vector quantized you then upscale that back to 256 and subtract that from
00:23:23.460 | what you have and so you're successively doing this so when they did it in 10 steps what it means is
00:23:28.660 | that at the last step the 16 by 16 patches they're not predicting the original image they're
00:23:36.740 | predicting the leftover after all the previous nine quantizations have done their job okay and so
00:23:45.060 | you can imagine that now you're really just predicting fine details looking at
00:23:53.220 | the chat the numbers are going up and i apologize if i should be answering questions or if you guys are just
00:23:58.500 | super chatty but i'm at like 44 freaking messages exactly if it's important someone will interrupt but
00:24:06.340 | that's all okay positional encoding which is always fun for images
00:24:11.940 | i don't actually remember what they did for positional encodings off the top of my head sorry
00:24:21.860 | i could go look at the paper real quick but yes there have to be positional encodings because
00:24:28.580 | attention on its own has no notion of position
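Here is a minimal sketch of the successive downscale, quantize, upscale, subtract loop, and of reconstruction by summing every level. To keep it short and dependency-free it makes several assumptions:

- a toy 1-d signal of 16 scalar "latents" instead of a 16 by 16 grid of vectors
- a 5-step schedule
- block-mean downsampling (so the 1 by 1 level is exactly the mean)
- nearest-neighbour upsampling
- one tiny hand-picked scalar codebook shared by every scale

The real var tokenizer works on 2-d feature maps with a learned codebook and has learned components this toy omits. All names here are illustrative.

```python
SCALES = (1, 2, 4, 8, 16)                   # toy schedule; each must divide the length
CODEBOOK = [i / 8 for i in range(-16, 17)]  # one shared scalar codebook, step 0.125

def downsample(x, n):
    """Block-mean down to n positions; n=1 gives the plain mean."""
    size = len(x) // n
    return [sum(x[i * size:(i + 1) * size]) / size for i in range(n)]

def upsample(x, length):
    """Nearest-neighbour upsampling back to full length."""
    size = length // len(x)
    return [v for v in x for _ in range(size)]

def quantize(v):
    """Index of the nearest codebook entry (the token id)."""
    return min(range(len(CODEBOOK)), key=lambda k: abs(CODEBOOK[k] - v))

def multiscale_encode(f):
    """Each level quantizes what the coarser levels failed to explain."""
    residual = list(f)
    tokens = []
    for n in SCALES:
        ids = [quantize(v) for v in downsample(residual, n)]
        tokens.append(ids)
        approx = upsample([CODEBOOK[k] for k in ids], len(f))
        residual = [r - a for r, a in zip(residual, approx)]
    return tokens

def multiscale_decode(tokens, length):
    """Reconstruction needs every earlier level: sum all upsampled levels."""
    out = [0.0] * length
    for ids in tokens:
        approx = upsample([CODEBOOK[k] for k in ids], length)
        out = [o + a for o, a in zip(out, approx)]
    return out

f = [0.9, 1.1, 0.4, 0.2, -0.3, -0.5, 0.0, 0.7,
     1.3, 1.2, 0.8, 0.1, -0.9, -1.1, -0.2, 0.3]
tokens = multiscale_encode(f)
print([len(t) for t in tokens])  # tokens per level: 1, 2, 4, 8, 16
for levels in range(1, len(SCALES) + 1):
    rec = multiscale_decode(tokens[:levels], len(f))
    err = max(abs(a - b) for a, b in zip(f, rec))
    print(f"levels used: {levels}  max abs error: {err:.3f}")
```

The same `CODEBOOK` serves every level, mirroring the shared-vocabulary constraint. Decoding any one level on its own is meaningless because the levels only make sense as a running sum.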
00:24:31.220 | and i think you mentioned that in the last step they're predicting the deltas from the previous
00:24:38.900 | step right i wonder do you have an intuition for why a vqvae would work better here as opposed
00:24:45.380 | to a residual vae because it seems that a residual one would work better for predicting residuals but i don't know
00:24:52.260 | i'm not familiar with what you're calling a residual vae yeah i mean the idea here is that
00:25:02.660 | you have the vae but in this case we're quantizing it okay and so
00:25:13.940 | you have two forms of error so to speak you have the fact that you
00:25:21.620 | have downscaled this a lot okay and then you upscale it back and so you've lost a lot of information
00:25:29.780 | there and then the other error you have is the fact that when you downscaled it you then quantized it
00:25:35.620 | so you moved it a little bit and then you upscaled that sucker that you moved back
00:25:40.900 | yep so i don't know how to answer your question because i don't know that other
00:25:46.580 | residual vae but i can tell you that what they're doing is the combination of those two errors
00:25:51.460 | is what the next iteration is then trying to encode okay gotcha thank you and then so the two
00:26:01.300 | by two is going to downscale and quantize and have both kinds of errors created when you upscale
00:26:07.460 | that back and then the three by three is going to just look at the leftover errors from both levels one
00:26:13.940 | and two and then the four by four is going to look at the leftover errors after those three and so on and
00:26:19.220 | so forth so one of the key things about this vqvae though is that the codebook used
00:26:29.460 | when you do level one versus the last level the 16 by 16 is the exact same codebook
00:26:37.700 | so i've seen different people ask about this you potentially could have a more accurate quantizer if
00:26:46.740 | you use customized codebooks for the one by one level versus the 16 by 16 level but since we want to
00:26:53.540 | feed all of these into the same gpt model we're forced to use the same codebook for
00:27:01.300 | all the different levels and so basically the vqvae when it's learning its codebook
00:27:13.540 | has to decide on codes that work well both at the high res and at the low res and it has
00:27:23.060 | to come up with some sort of compromise because these codes are going to be shared
00:27:27.780 | everywhere in my head my intuition is that even though the codebook has to be shared and it may
00:27:34.660 | not be optimized for low and high res still what's happening is you're going to get these broad global
00:27:40.980 | features at lower resolutions you're going to get the low frequency information encoded
00:27:48.820 | and as you get to the last levels that's where the high frequency features come in so the details of the grass
00:27:55.140 | you know the texture of the fur that's probably not going to happen until quite late in the
00:28:02.740 | quantization process and so this next slide i wasn't planning on spending a lot of time on but
00:28:09.620 | what you can see just from this algorithm is that you input an image and it's going to loop in this case
00:28:16.580 | 10 times it doesn't matter if it's 10 or 16 times going from the one by one up to the 16 by 16 and
00:28:23.060 | the key thing is that they have this queue that they're pushing each level onto and you get all of those
00:28:28.420 | embeddings at the different levels at once this is not a separate kind of thing you
00:28:34.020 | do the multi-resolution quantization in one fell swoop you get all the resolutions simultaneously
00:28:43.300 | after you run your for loop okay so it's a package deal this is a dedicated multi-scale
00:28:53.060 | vqvae that gives you all your resolutions at once and then reconstruction also has to be an iterative
00:29:01.300 | process you cannot say let me do the reconstruction at the eight by eight level without the information
00:29:08.260 | about the other ones because this is a residual process you have to have all the earlier ones
00:29:13.780 | and then you sum all the predictions from all the earlier ones with your predictions in order to get
00:29:20.100 | your final prediction it's just like in an llm you couldn't say what is the output of just the eighth
00:29:26.980 | layer right it's a residual stream so you have to have layers one through seven in order to know what
00:29:32.500 | the output of that is i didn't say that quite right what's the residual stream after the eighth
00:29:38.740 | layer you can say what the output of just the eighth layer is but you can't say what's the residual stream
00:29:43.540 | after the eighth layer unless you also have the first seven layers so it's the same thing here you can't
00:29:48.260 | say what's the eighth resolution unless you also have the earlier ones all right so just to clarify what the
00:29:58.020 | training process looks like the multi-scale vqvae is trained separately so in this case it was
00:30:05.060 | imagenet so you give it a bunch of imagenet images and it tries to encode these images in a way
00:30:13.780 | that when they're decoded the l2 loss is the lowest and it's going to do this as best it
00:30:19.300 | can and then it's going to be frozen so the actual llm part did not involve any of this training at this
00:30:28.980 | point this sucker is completely frozen then what we do is we say given a fixed vqvae i'm now going
00:30:36.660 | to train my actual llm on this idea of i'm going to give it one patch it's going to predict four i'm
00:30:44.340 | going to give it one plus four patches it's going to predict nine then it's going to predict 16 then
00:30:48.980 | it's going to predict 25 up until at the very end for the 16 by 16 it has all the context of all of the
00:30:55.620 | earlier resolutions and it's going to predict 256 tokens in one fell swoop so in this case it
00:31:09.460 | was trained on imagenet and so if you're wondering for the very beginning of the process what is the
00:31:15.380 | prompt there's one token that encodes the class so there are a thousand classes in imagenet and
00:31:21.780 | i don't remember but i think it was embedded but basically you start with a class token and it predicts
00:31:27.140 | the one by one and then you have the class plus the one by one and it predicts the two by two and then so on and
00:31:32.820 | so forth and this repeats until you get the final 256 tokens and then for
00:31:41.300 | your actual image you need to sum all of this up and have the decoder turn that into pixel values
00:31:50.500 | because this is still a very special vqvae and you still need the decoder part here
00:31:57.940 | so all of this stuff that's happening with the multiple scales is all happening in latent space
00:32:03.940 | this is not predicting pixels this is predicting latents okay so when we have tokens that are
00:32:10.260 | embeddings from our vqvae vocabulary these are all vocabularies in the latent space
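Here is a minimal sketch of the sampling loop just described: start from a class id, predict a whole token map per step, and only decode at the very end. It assumes a 10-step schedule ending at 16 by 16 and a 4096-entry codebook, the example size from the talk. `predict_next_scale` and `decode_to_pixels` are hypothetical placeholders (a random-token stub and a pass-through), not the var code's api.

```python
import random

PATCH_NUMS = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # assumed 10-step schedule
VOCAB_SIZE = 4096                                # example codebook size from the talk
NUM_CLASSES = 1000                               # imagenet classes

def predict_next_scale(class_id, context, n, rng):
    """Hypothetical stand-in for the transformer.

    A real model would attend over the class token plus every earlier token
    map and return predictions for all n*n positions of the next scale in one
    forward pass. Here we just sample random token ids of the right shape.
    """
    return [rng.randrange(VOCAB_SIZE) for _ in range(n * n)]

def decode_to_pixels(token_maps):
    """Hypothetical placeholder for the vqvae decoder (sum levels in latent space, then decode)."""
    return token_maps

def generate(class_id, rng):
    assert 0 <= class_id < NUM_CLASSES
    token_maps = []    # one flattened n*n list per scale, coarse to fine
    flat_context = []  # everything generated so far, in order
    for n in PATCH_NUMS:  # one forward pass per scale, not per token
        new_tokens = predict_next_scale(class_id, flat_context, n, rng)
        token_maps.append(new_tokens)
        flat_context.extend(new_tokens)
    return decode_to_pixels(token_maps)

maps = generate(class_id=207, rng=random.Random(0))
print(len(maps), "forward passes,", sum(len(m) for m in maps), "tokens")
```

Everything inside the loop is token ids in the latent vocabulary. Pixels only appear after the decoder runs on the full stack of levels.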
00:32:18.420 | and that's pretty much it. so then you can see some of the results here, and they had
00:32:27.540 | really good image quality, if you guys are familiar with FID and Inception Score, and then they also
00:32:36.260 | said because we were able to do inference in chunks, we actually had many fewer steps than if you did
00:32:47.380 | autoregression naively. if you did one patch at a time you would have 256 inference steps; they had 10
00:32:55.460 | inference steps, going from 1 by 1 to 16 by 16 in just 10 steps. maybe those individual steps were a little bit more expensive each,
00:33:03.220 | or maybe not, i haven't really done the math, but you can imagine that 10 inference steps is
00:33:09.220 | still going to be a lot cheaper than 256 inference steps
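A back-of-envelope version of the math left open here, under stated assumptions: a 16 by 16 token grid, the same assumed 10-scale schedule, a KV cache for raster order, and block-causal attention for next-scale order (each new token sees its own scale and all coarser ones). The class token, MLP cost, and hardware parallelism are ignored, so this counts work rather than measuring speed.

```python
SCALES = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)  # assumed schedule ending at 16x16


def raster_cost(side=16):
    """Token-by-token raster order with a KV cache: one new token per forward pass."""
    n = side * side
    # token i attends to itself and the i - 1 tokens before it: 1 + 2 + ... + n pairs
    return {"passes": n, "tokens": n, "attn_pairs": n * (n + 1) // 2}


def var_cost(scales=SCALES):
    """Next-scale order: all s*s tokens of a scale in one forward pass, attending
    to every token of that scale and of all coarser scales."""
    passes = tokens = pairs = 0
    for s in scales:
        new = s * s
        tokens += new          # context now includes this scale
        pairs += new * tokens  # each new token sees the whole context so far
        passes += 1
    return {"passes": passes, "tokens": tokens, "attn_pairs": pairs}


print(raster_cost())
print(var_cost())
```

Under these assumptions raster order needs 256 passes and 32,896 query-key pairs, while next-scale order needs 10 passes and 286,434 pairs: roughly 8.7 times more attention work spread over about 25 times fewer sequential steps. Whether that wins on wall-clock time depends on how well each larger pass parallelizes, which this count does not capture.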
00:33:12.740 | and i think they also talked a little about scaling, so if you look at some of these other
00:33:21.620 | models, the performance improves as they get bigger and then they sort of hit a wall,
00:33:26.500 | and in this case they said, yeah, ours didn't hit a wall. they didn't really run this over a very
00:33:32.580 | large number of orders of magnitude, so it's still TBD whether it holds as you really get bigger. i
00:33:40.660 | don't blame them for not having the resources, they're not Google, but nevertheless the fact that
00:33:48.020 | it didn't hit the wall is good, but it's not really proven until you get to probably
00:33:55.700 | more like seven billion parameter scale or something like that. but if you just compare,
00:34:01.620 | for example, DiT, if you looked at these first three data points you'd be like, yeah, this thing scales great,
00:34:06.900 | and then you get the fourth data point and it just fundamentally, architecturally hits a wall.
00:34:12.580 | so they're saying VAR hasn't hit a wall, but we still don't know that it's not going to hit a wall
00:34:17.140 | on the next data point. so the ultimate proof is in the pudding:
00:34:22.020 | until you try to build a GPT-4-sized thing or whatever, you just don't know whether or not
00:34:29.300 | you're going to hit a wall. for images maybe it's not GPT-4 size, but it's the same concept. it's good
00:34:35.380 | to be asking that question, but i just don't think that they quite did it over a large enough spread
00:34:42.420 | some sample images: here, if you look, you can see left is earlier in
00:34:52.900 | training, right is later in training, and this is scaling up the size of the LM that they used inside
00:35:00.340 | VAR, and so obviously the bigger LM does better, and then i don't know that early in training is really
00:35:06.740 | that important, but you know it does learn, and so late in training it produces
00:35:12.260 | better quality images, that's not surprising. so to conclude, the key things that they're
00:35:19.380 | selling us on, why they think this paper is good, why they think this technology is good, is that it's
00:35:24.660 | fast, it's much faster than autoregression done in more naive ways, it generates very high quality
00:35:32.820 | images, and they say that it has the right inductive bias, that going from low scale to high scale
00:35:38.740 | there is a very clear 1D ordering. nobody would argue that somehow medium res should come before low
00:35:45.860 | res, everyone's going to agree on the inductive bias. whether or not that inductive bias proves to
00:35:51.700 | encapsulate everything that we need, you know, i would not have necessarily guessed before GPT-2 that next
00:35:58.740 | token prediction was sufficient to be able to generate really complex mathematical proofs or something
00:36:05.300 | like that, right? so i don't have the ability to guess whether or not low res to high res is enough
00:36:11.220 | information for it to be able to do whatever, but because it has a very, i don't know, intuitive
00:36:21.460 | inductive bias, it does mean that you can do inpainting, you can do outpainting, you can
00:36:27.460 | fill in any direction. with a fixed autoregressive patch order you can't: if you
00:36:34.660 | do raster scan and you give it just the bottom of the image and you tell it to fill in the top
00:36:38.900 | of the image, you can't do that, it can only go from top to bottom. this one, because of the way it works,
00:36:43.620 | can go in any direction: in, out, masked, whatever. so ByteDance did share the code here,
00:36:52.580 | and they also have two follow-up papers that i've seen, maybe there's
00:37:04.740 | more. they had an Infinity paper, they did video, and whereas this was just ImageNet, now
00:37:13.380 | they've done text to image, that's a pretty obvious thing, and there's an xAR instead of VAR paper
00:37:21.860 | now where they generalized: instead of just scale, they say you can have arbitrary precedence,
00:37:33.700 | and then they also added some stuff that i haven't quite worked my way through where they're
00:37:39.060 | doing matching, as what seems like a little carrot on top of the autoregressive learning, to make
00:37:47.620 | the image quality even better
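The in/out-painting point made above can be illustrated with one simple approach for a next-scale model: at every scale, overwrite the model's proposal with the reference image's tokens wherever they are known, so finer scales condition on a map that already agrees with the known region. This is a hypothetical sketch: it assumes the reference image's token maps and a boolean mask per scale are available, and how to downsample a pixel mask to each scale (for example, keep a token only if its whole region is known) is left as a design choice.

```python
import numpy as np


def masked_generate(known_maps, known_masks, predict, scales):
    """Hypothetical in/out-painting loop for a next-scale model.

    known_maps[k]  : token ids of the reference image at scale k (side x side)
    known_masks[k] : True where that token should be kept from the reference
    predict        : stand-in for the model, (context, side) -> side x side ids
    """
    context, out = [], []
    for k, side in enumerate(scales):
        pred = predict(context, side)  # model proposes the whole scale
        r_k = np.where(known_masks[k], known_maps[k], pred)  # force known tokens back in
        out.append(r_k)
        context.append(r_k)  # finer scales see the merged map, so they stay consistent
    return out
```

Because every scale is conditioned on a coarse view of the whole image, the known region can sit anywhere (top, bottom, center, border), unlike raster order where only a prefix can be given.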
00:37:49.220 | one other thing i have: if you're reading the
00:37:56.660 | paper, there's a lot of terminology, and so i did actually make a cheat sheet.
00:38:03.540 | i'll share the link to the PDF. so basically they have images, and then they talk
00:38:10.500 | about their f, and you're doing this on individual patches at individual scales, and so then you have
00:38:18.500 | your tokens here, these q's, and all the q's combined together make an image r, or rather a
00:38:26.260 | latent r, at a given resolution, and then when you're decoding you go back and you get your f hats, and
00:38:32.580 | eventually you get your x hat, so that's useful. all right, so i don't know if you guys want to ask me
00:38:43.540 | questions or if there's other discussion you guys want to have. i mean, first off, thank you, i mean it's
00:38:52.260 | great that you volunteered at the last minute and i've charted this entire session, and also
00:39:01.060 | thanks for the PDF file, all the slides look amazing
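For reference, the cheat-sheet notation (f, the q's, r, the f hats, x hat) can be summarized roughly as follows. This is a simplified reading rather than the paper's exact algorithm: E and D are the VQ-VAE encoder and decoder, Q maps each vector to its nearest codebook index (each index is one of the q's), Z is the codebook, interp resizes a map, scale k has size h_k by w_k, and any extra per-scale layers the paper may use are left out.

```latex
\begin{aligned}
f &= \mathcal{E}(x), \qquad \hat f_0 = 0 \\
r_k &= \mathcal{Q}\!\left(\operatorname{interp}(f - \hat f_{k-1},\, h_k, w_k)\right), \qquad k = 1, \dots, K \\
\hat f_k &= \hat f_{k-1} + \operatorname{interp}\!\left(Z[r_k],\, h_K, w_K\right) \\
\hat x &= \mathcal{D}(\hat f_K) \\
p(r_1, \dots, r_K) &= \prod_{k=1}^{K} p(r_k \mid r_1, \dots, r_{k-1})
\end{aligned}
```

The first four lines are the tokenizer side (encode, quantize residuals scale by scale, accumulate, decode); the last line is what the transformer learns, one whole token map r_k per step.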
00:39:04.660 | i'm just thinking in terms of, does it even make sense to think about or rationalize VAR
00:39:12.740 | against how diffusion models work, is that going to help me? because some of these are rat-hole
00:39:19.300 | things that i don't want to get into, but do you think, based on your experience, does it help to understand or rationalize
00:39:28.020 | VAR by comparing it against diffusion models? because even in diffusion models you have a sense of
00:39:33.700 | resolution increments that you see in practice, although behind the scenes
00:39:44.740 | the thing that is driving it is actually the differential equations
00:39:47.780 | i think there's one thing that i personally think is a very important trend from VAR,
00:40:02.820 | and that's the idea of predicting multiple tokens at once.
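In VAR each step emits all side by side tokens of a scale at once; the text analogue under discussion is multi-token prediction. The sketch below only shows how training targets line up, assuming a common setup in which one shared trunk feeds n output heads and head i predicts the token i positions ahead. The function name is illustrative, and dropping positions near the end of the sequence is a simplification.

```python
def multi_token_targets(tokens, n_heads):
    """Pair each prefix tokens[:t+1] with the next n_heads tokens.

    Head i (1-based) is trained to predict tokens[t + i]; plain next-token
    prediction is the special case n_heads == 1. Positions too close to the end
    to have n_heads future tokens are skipped in this simplified version."""
    return [
        (tokens[: t + 1], tuple(tokens[t + 1 : t + 1 + n_heads]))
        for t in range(len(tokens) - n_heads)
    ]


for prefix, targets in multi_token_targets([10, 11, 12, 13, 14], n_heads=2):
    print(prefix, "->", targets)
```

Each prefix now carries n supervision signals instead of one, which is the "more information per step" idea in its simplest form; VAR's version is different in that the tokens predicted together are a whole spatial map rather than the next few positions of a sequence.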
00:40:07.380 | when we do next token prediction, we're sort of fixing the information content per step,
00:40:21.300 | okay, and you know, for me, if you have words like "the" and whatever, there's not a lot of information
00:40:29.620 | content there, but if you're processing code or you're processing, like, what is 12 times 15 or
00:40:37.460 | whatever, and you're outputting digits or something like that, there's really high information
00:40:43.220 | density, and you cannot get it slightly wrong or you're just dead. and so whether it's scaling up
00:40:51.940 | or scaling down or whatever, there's this idea that we can have variable information content. so i think
00:40:57.620 | the Byte Latent paper from Meta, and to a certain extent the Large Concept Model, for
00:41:04.900 | me, one of the things that all these papers have in common is that they're addressing the idea
00:41:10.660 | that the information content may not be uniform token by token by token, and so i think the idea that
00:41:18.500 | maybe if we're doing reasoning or doing other things, we can embed things with multiple tokens,
00:41:27.060 | that allows us to get more information content into one inference step. and so i think there's
00:41:34.740 | a huge opportunity to either make transformers faster, language models faster, or have better reasoning if
00:41:41.940 | they have certain capabilities to dial up and down. so the idea of predicting multiple tokens at once,
00:41:47.860 | i think there will be a lot of successful research playing with that. the xAR paper already says that
00:41:56.820 | maybe scale is not the one and only way to do it, but they are looking at this idea, and so
00:42:06.900 | diffusion has really strong mathematical foundations, but we know that transformers are just
00:42:17.140 | really powerful, so there's no reason why the transformer can't learn the score function that
00:42:26.500 | diffusion models are learning. the real question, which i don't know, is just who can do it in the
00:42:34.660 | fewest steps. but it seems like if a diffusion model using a U-Net, or replacing the U-Net with a
00:42:43.700 | transformer, can learn the score function, then if you gave it something similar, a progression of images, then
00:42:50.020 | you should be able to learn that same score function using attention. so yeah, i think
00:42:56.500 | that this is something to pay attention to, but not necessarily that it has to be literally just
00:43:01.540 | the scale technique that these guys use. but yeah, i would pay attention to this idea
00:43:09.220 | that multiple people are approaching, of changing around our tokenization, whether it's Byte
00:43:18.500 | Latent, LCM, this, or something like that. so that's my answer. sorry, just to answer the question more
00:43:27.460 | directly though, at least my understanding, of whether there's something to take away from this, of
00:43:32.660 | whether understanding diffusion helps: i think to me both
00:43:43.380 | diffusion and this approach of scale prediction are examples of an earlier paper, i think it's by
00:43:50.180 | such cover, where the idea is you actually don't need to solve an ODE or something if you can train a
00:43:57.140 | model to denoise an arbitrary corruption, either adding Gaussian noise, or maybe in this case you
00:44:06.900 | can interpret the next-scale, like downsampling, prediction as a corruption, where you actually
00:44:13.780 | think of it backwards: you have a high resolution image and you've corrupted it by downsampling it,
00:44:18.260 | and you're trying to predict backwards. all of these are viable ways to think about the image
00:44:25.300 | generation process. so in this case it's a combination of using an LLM's attention
00:44:31.460 | to be able to predict the next scale rather than raster scan, plus the fact that you
00:44:37.380 | need some sort of tokenization to actually use the LLM in the first place. other than that it's kind of
00:44:42.660 | similar and it seems to work really well yeah thanks
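The "downsampling as a corruption" view can be made concrete with a tiny pixel-space sketch: corrupt an image by average-pooling it to a coarse grid and blowing it back up, then form (coarser, finer) training pairs for a model that undoes one step at a time. This is an illustration only; VAR does the analogous thing with token maps in latent space, and the function names and the divisibility assumption are introduced here.

```python
import numpy as np


def degrade(x, side):
    """Corrupt an (H, W) image by average-pooling it to side x side and
    repeating each pooled value back up to (H, W). Assumes side divides H and W."""
    h, w = x.shape
    pooled = x.reshape(side, h // side, side, w // side).mean(axis=(1, 3))
    return np.repeat(np.repeat(pooled, h // side, axis=0), w // side, axis=1)


def training_pairs(x, sides):
    """(input, target) pairs for a model that undoes one corruption step:
    see the image at a coarse scale, predict it at the next finer scale."""
    return [(degrade(x, a), degrade(x, b)) for a, b in zip(sides[:-1], sides[1:])]


x = np.arange(16.0).reshape(4, 4)
pairs = training_pairs(x, (1, 2, 4))
```

At the coarsest level the "corrupted" image is just its mean color; at the finest it is the image itself. Generation runs the chain backwards, which is the same shape as denoising from pure noise, with a different corruption.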
00:44:46.980 | regarding that point of multi-token prediction, Meta put out a paper last year
00:44:57.140 | about training LLMs with multiple tokens. they did it during training and not inference, and then they
00:45:03.460 | mentioned in Llama 3 that they use this technique. and Ted, basically all your points there for sample
00:45:08.580 | efficiency are, like, spot on, they go into that. in terms of inference time, this
00:45:16.500 | is kind of what motivated a lot of the work behind speculative decoding. so speculative decoding is kind
00:45:21.540 | of multi-token prediction with a small model, and that idea came out of research where, you know, they
00:45:27.620 | started off with, let's just predict multiple tokens and see how that goes, and that led down the path to
00:45:32.820 | speculative decoding. nice
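A minimal greedy sketch of the speculative decoding idea mentioned here: a cheap draft model guesses several tokens ahead, and the large target model keeps the longest prefix it agrees with plus one token of its own. This is a simplification and an assumption on my part: the published algorithms sample and accept or reject draft tokens with a probability ratio so the output distribution matches the target model, and the "models" below are plain next-token functions.

```python
from typing import Callable, List

# A "model" here is just a greedy next-token function: context -> next token id.
NextToken = Callable[[List[int]], int]


def speculative_step(prefix: List[int], draft: NextToken, target: NextToken, k: int = 4) -> List[int]:
    """One round of greedy speculative decoding; returns the tokens to append."""
    # 1) The small draft model guesses k tokens ahead, one at a time (cheap).
    guesses, ctx = [], list(prefix)
    for _ in range(k):
        t = draft(ctx)
        guesses.append(t)
        ctx.append(t)
    # 2) The large target model checks the guesses. Written as a loop for clarity;
    #    in practice all k positions are scored in a single batched forward pass.
    accepted, ctx = [], list(prefix)
    for t in guesses:
        expected = target(ctx)
        if expected != t:
            accepted.append(expected)  # first disagreement: keep the target's token, stop
            return accepted
        accepted.append(t)
        ctx.append(t)
    accepted.append(target(ctx))  # every guess matched: the target adds one more token
    return accepted
```

When the draft is usually right, one target pass yields several tokens; when it is wrong, the round still produces at least one correct token, so output quality is set by the target model.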
00:45:37.540 | i have a question on the training time of this one. this approach seems clearly
00:45:46.180 | very fast at inference time because it can predict so many tokens in parallel, but did they
00:45:53.060 | mention how long it takes to train, also in comparison to, for example, diffusion models?
00:46:00.580 | that's a great question, i don't remember them discussing that in the paper,
00:46:04.820 | or maybe i just didn't pay close attention. these are actually a lot bigger models than diffusion models,
00:46:12.660 | and their inference is a lot slower. like, if you try 4o image generation, it takes on the order
00:46:18.820 | of 15-20 seconds, and it's significantly larger. the big thing with autoregressive generation
00:46:24.420 | is you kind of scale up a lot more than you do with diffusion, right? we have really small local
00:46:29.460 | diffusion models, like your iPhone can locally do Genmoji diffusion, but autoregressive models, just
00:46:36.020 | at base layer, are larger than diffusion models. now, the inference tricks like
00:46:43.380 | LCM-LoRA, where you can skip a bunch of steps for diffusion, those haven't really been applied to autoregressive image
00:46:48.740 | generation, right? so VAR-style autoregressive image generators don't have speculative decoding yet,
00:46:55.380 | or maybe they do, but you know there's a lot of inference optimization that hasn't hit yet, so
00:46:59.620 | there's room to grow. but at base level we expect diffusion models to stay small, and autoregressive
00:47:07.940 | models kind of get big as they generalize, and yeah, there's a lot of inference optimization that
00:47:13.860 | hasn't hit yet. yeah, i think that the optimization is a key point that fibo shared, right? so
00:47:20.980 | compare this to the very first diffusion models, they were doing, i don't remember, either
00:47:24.500 | hundreds or thousands of steps, and we've dramatically improved that, and so if this is a useful
00:47:33.060 | technique then you would expect that there are going to be a bunch of optimizations and improvements upon it.
00:47:38.500 | this is just day one. that's a minor point, i guess, but you mentioned 4o image generation being slow. i'm
00:47:46.980 | actually not sure. they're clearly doing VAR, but i'm not sure they're doing the same
00:47:52.740 | VAR, because you can imagine that if you want to do text to image or this kind of multimodal
00:47:59.620 | thing, you want to actually fine-tune your text model to take in the same tokens and, like,
00:48:07.780 | teach it to do the task all with one model rather than doing it with separate models, so
00:48:13.540 | it could be that that is the reason. if you just had a dedicated VAR model you could imagine
00:48:18.740 | it being pretty good. anyway, a question about this actually, and i think we had a short thread in Discord
00:48:26.500 | about this: it doesn't really seem practical to completely jointly train one model, like, you know,
00:48:35.140 | if you have the 4o architecture or something like that, it probably has to be just a post-training
00:48:40.660 | thing, right? and then you need some way to integrate that with the separate
00:48:46.580 | VAR architecture that you've already trained, and i think RJ in Discord, i'm not sure if he's here,
00:48:54.500 | had a plausible explanation that you have some reserved token that you use to delineate the
00:49:02.180 | image tokens from the text tokens, and then the only thing you need to do in post-training is
00:49:07.780 | the chat model needs to learn to say, okay, now this is an image or something, and then it gives
00:49:14.980 | you text tokens, and you take the text tokens and you have a text-conditional VAR, and that generates
00:49:19.140 | the image or something like that, which kind of pencils out to me, but i'm very curious practically how
00:49:24.580 | they do something like that. i don't know, i think you could do it with tool use like you're describing,
00:49:29.940 | but people fine-tune on tasks all the time. for example, if you want to add a tabular or
00:49:36.900 | a numeric modality or an image modality to a language model, you do your pre-training on a big
00:49:43.860 | text corpus, and then you have a much smaller modality corpus that you train your
00:49:50.900 | encoder on to do the quantization, so that you can actually have an image token, but then you
00:49:59.700 | fine-tune your model to understand image tokens and text tokens all in the same sequence, but then you
00:50:06.740 | don't actually need to co-train everything, it is like a post-training thing, and you hope that
00:50:14.180 | you don't lose the text capability when you gain the image capability. but i think Apple has a bunch
00:50:21.940 | of papers on basically any-to-any prediction, i think it's called, like, 4M or
00:50:28.740 | something. i don't know how well it works. Meta has Segment Anything, then
00:50:37.620 | there's another one that's six modalities in one: audio, video, depth. Meta also has Chameleon, which is
00:50:44.580 | native training, so separate from, like, LLaVA, they have a native image-language model. i guess Apple might too,
00:50:52.980 | but, yeah, i'm thinking of the 4M series of papers, which i think is Apple,
00:50:59.620 | but again i don't know how well it works. anyway, i don't know how anything works in practice at these
00:51:05.860 | scales, but yeah, that's my guess. regardless, you need some sort of tool use thing to say, okay,
00:51:13.460 | now this is an image, right, and then you generate the image tokens, and then you end the image and
00:51:19.300 | keep generating text or something like that, is that right? but you may need that just
00:51:24.740 | for the UI, essentially. the LM may not care. and using tool use to
00:51:32.100 | dispatch to a new model, that's clearly how it used to work, because you could basically trick one model into
00:51:37.540 | generating some description that then, you know, would fail downstream, but if it's all one
00:51:43.620 | model, it's all in the same space, and that allows you to do image understanding, whereas previously
00:51:49.780 | these models didn't understand what was in the image that they generated themselves. so that's
00:51:55.220 | maybe a way to check: if you ask it to generate an image of, like, a tech bro
00:52:00.580 | and say no glasses and it generates glasses, you can ask it whether it contained glasses.
00:52:07.060 | yeah, no, that's an issue, and i'm sure it does know that. yeah, and also, just moving this
00:52:14.340 | into a different track of conversation: in the case of diffusion models lots of things have evolved, like
00:52:20.740 | there are ControlNets and LoRAs and image prompts and whatnot. i wonder what would happen in this world,
00:52:27.540 | in this universe, to get that kind of control, because i work in the movie industry, and one of the things
00:52:33.620 | that we actually care a lot about is consistency, like character consistency and so on, and ControlNets
00:52:41.540 | are the current way of dealing with all of that. so do you know if there is any literature
00:52:47.700 | around VARs along those lines? i don't think there's any literature yet, but, like, OpenAI,
00:52:57.940 | they let you refine the image, and there's clear character consistency when you do multi-step
00:53:05.780 | refinement of your image, so it seems like they've definitely had that in mind. yeah, they actually call
00:53:12.740 | that out in the blog post they have about the 4o image capability, and what i recall
00:53:21.460 | is that they claim that's because of the autoregressive nature of it, right? so you have the
00:53:27.940 | history, so the attention can pay attention to the previous image, so you kind of get it for free.
00:53:32.900 | that was my interpretation of what they were saying. yeah, then this suggests that it is one model, and
00:53:39.460 | that's cool, or i guess it could be one change model. yeah, no, so my point in the Discord was
00:53:47.220 | actually that it's unclear to me whether there are reserved tokens that it's
00:53:54.100 | trained on in post-training, or if they're just reusing the text tokens but with some sort of
00:53:59.300 | special token to say, okay, now i'm generating an image, and then now i'm not generating an image. but
00:54:04.580 | i don't think it matters too much, but you might lose less of the text capability if you use
00:54:10.100 | reserved tokens. yeah, if i had to bet, i would bet on them using reserved tokens, because you
00:54:18.020 | don't need that many, it's a pretty small vocabulary in your codebook, and exactly your point, you
00:54:24.260 | don't want any forgetting. yeah, well, normally you don't reuse tokens. i was not aware
00:54:34.100 | that people ever reuse tokens. you just have your dictionary for text tokens, and then
00:54:41.780 | you have your implicit dictionary through your encoder, your VQ-VAE, that is now your image
00:54:48.980 | dictionary, and they don't overlap. so yeah, you just use those, and your model learns to
00:54:58.020 | naturally use whatever it needs. yeah, i think there were some early multimodal LLMs where
00:55:06.900 | you had, like, a beginning-of-image token and then you reused the same space, but yeah, i think
00:55:15.780 | we're all saying the same thing, that's probably not the conventional wisdom now. it's just: use
00:55:22.980 | separate tokens, and then you can have, hypothetically, just an orthogonal
00:55:29.940 | corner of your embedding space. yeah, but those tokens, so i totally agree, those tokens are
00:55:34.820 | delimiting tokens. for example, when you train an LM for fill-in-the-middle tasks,
00:55:42.180 | even though it can only generate one way, right, you add new tokens that say, like,
00:55:48.260 | now you're doing the fill-in-the-middle task, and here's the beginning, and here's the end, and now
00:55:52.740 | start generating the hole. essentially you create new tokens for those delimiters,
00:55:58.020 | but for the vocabulary, it learns to use the right vocabulary
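The layout the group converges on (a separate, non-overlapping image vocabulary plus a few new delimiter tokens) can be sketched as an id-mapping scheme. Everything here is hypothetical: the vocabulary sizes, the `BOI`/`EOI` ids, and the helper names are illustrative assumptions, not how 4o or any specific model is known to do it.

```python
TEXT_VOCAB = 100_000  # assumed size of the text tokenizer's vocabulary
CODEBOOK = 4096       # assumed size of the image VQ codebook

BOI = TEXT_VOCAB          # hypothetical <begin_image> delimiter id
EOI = TEXT_VOCAB + 1      # hypothetical <end_image> delimiter id
IMAGE_OFFSET = TEXT_VOCAB + 2
TOTAL_VOCAB = IMAGE_OFFSET + CODEBOOK  # size of the model's embedding/output table


def interleave(text_ids, image_codes):
    """Text tokens followed by one image. Codebook entry c gets its own reserved
    id IMAGE_OFFSET + c, so image tokens never collide with text tokens."""
    if any(not 0 <= t < TEXT_VOCAB for t in text_ids):
        raise ValueError("text id out of range")
    if any(not 0 <= c < CODEBOOK for c in image_codes):
        raise ValueError("codebook index out of range")
    return list(text_ids) + [BOI] + [IMAGE_OFFSET + c for c in image_codes] + [EOI]


def extract_image(seq):
    """Recover the codebook indices between the first BOI and the following EOI."""
    start = seq.index(BOI) + 1
    end = seq.index(EOI, start)
    return [t - IMAGE_OFFSET for t in seq[start:end]]
```

Adding an image modality this way only grows the embedding and output tables by `CODEBOOK + 2` rows; the existing text rows are untouched, which is one reason reserved tokens are expected to cause less forgetting than reusing text ids.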
00:56:10.420 | well, thank you so much, Ted. yeah, again, we were just gonna casually discuss, and slides are always appreciated,
00:56:16.980 | and it's good to see you here. for next week, anyone have a topic they want to discuss, they want to volunteer,
00:56:24.500 | they want someone else to volunteer, they want to do half a paper?
00:56:27.220 | i want someone to volunteer the stuff that just came out of Anthropic on
00:56:34.900 | okay, i might do that since i've done the other three, so unless someone else wants to do it, i'll do
00:56:40.740 | the fourth one. awesome, okay. i might get someone from Anthropic to join us, we'll see. otherwise
00:56:48.100 | next week is Anthropic. if anyone hasn't seen my favorite meme, Matthew Berman, he said, you know,
00:56:54.900 | this is the most surprised he's ever been, but he was just as surprised in the thumbnail as always.
00:57:00.660 | it's very sad, very sad, we have a running Discord of all his thumbnails.
00:57:03.860 | but okay, i'll share. i'm gonna make the thumbnail for this YouTube. ah, okay, good. all right, we
00:57:15.700 | got the thumbnail. cool, thanks guys, see you next week, and thank you Ted. thank you Ted, thank you