
Stanford CS25: V1 | DeepMind's Perceiver and Perceiver IO: new data family architecture


Chapters

0:00
3:01 Improving Transformers
6:05 Why non-locality?
8:40 Scalability vs. generality?
10:42 Cross-attention: attention with linear scaling
20:30 In contrast: ViT
26:23 ImageNet classification
30:41 Featurizing multimodality
48:03 What is optical flow?
51:16 Real-world qualitative results

Whisper Transcript

00:00:00.000 | So today, I'm going to be talking about some recent work that we've been doing at DeepMind,
00:00:09.600 | developing this line of architectures that we're calling perceivers, and I'll be motivating
00:00:14.880 | this in terms of a goal that we have, which is to develop general purpose architectures.
00:00:20.280 | And so just right off the bat, I want to motivate why we care about general purpose architectures.
00:00:25.680 | And so both of the reasons are fairly pragmatic.
00:00:30.200 | But basically, the idea is, if we're thinking about all of the data that we could possibly
00:00:34.700 | imagine collecting in the world, a lot of it basically involves what we think of as
00:00:39.800 | sort of traditional sense modalities, and these things range from touch and proprioception
00:00:45.160 | to echolocation to the kind of perception you need to ingest text, however you want
00:00:50.320 | to format that, to more exotic things like event-based cameras and whisker-based touch
00:00:56.680 | sensing, to things like smell and depth, and all the way up to the kinds of sense modalities
00:01:02.680 | that we really think about when we're thinking about scientific perception.
00:01:06.840 | And so basically, if we think about the full set of data and what it would take to actually
00:01:11.560 | model each of these different modalities, it basically looks effectively intractable
00:01:17.000 | to try to engineer inductive biases that will work for every single one of these.
00:01:21.120 | So we don't want to engineer them one by one.
00:01:24.320 | This is an approach that's worked, and in some ways, it's maybe a reasonable description
00:01:27.600 | of how we think about developing new architectures for different problems, but it's just not
00:01:32.480 | going to scale.
00:01:33.480 | We can't afford, as a community, to hand-design inductive biases that will work for each and
00:01:38.440 | every one of these.
00:01:39.440 | And so rather than doing that, we want to sort of build architectures that, at least
00:01:42.440 | at first pass, will allow us to handle everything.
00:01:46.080 | There's another practical argument for why we should work on general-purpose architectures,
00:01:50.600 | and that's because it will allow us to build simpler, more unified systems.
00:01:55.380 | So if you look at how, in particular, complex multimodal data streams are typically approached
00:02:01.920 | in the sensory, computer vision, or pattern recognition literatures, effectively, the
00:02:09.080 | typical way this is done is by using inductive biases that we know hold for the individual
00:02:13.440 | modalities and then engineer ways of combining those different subsystems.
00:02:18.040 | So this can mean building specific heads, specific input modules for each of these things,
00:02:24.480 | and then trying out the various different ways of combining them.
00:02:27.520 | So this can work, but it gives us systems that, in principle, really will only work
00:02:31.760 | on one or a small number of domains, and it gives us systems that are very hard to maintain,
00:02:36.200 | tend to be fragile, tend to depend on specific processing assumptions about the input modalities.
00:02:41.680 | So rather than do that, we sort of want to move in the direction of having unified blackbox
00:02:46.680 | architectures that kind of just work.
00:02:49.120 | And the idea here is that if we can get to that point, we can abstract the architecture
00:02:53.160 | construction process and really focus on other, more high-level problems.
00:02:57.080 | So this is sort of the motivation for this line of work.
00:03:01.400 | And the way that we're going to be doing this is, of course, by working on the most general-purpose
00:03:06.800 | architecture that we have so far, which is basically a transformer, and you'll all be
00:03:10.840 | very familiar with the basic building blocks of a transformer.
00:03:14.560 | But just at a very high level, we can think about what they do right, which is they use
00:03:18.520 | a general-purpose inductive bias.
00:03:21.160 | So they're non-local, which means they're not making domain-specific assumptions about
00:03:24.680 | which points should be compared to each other.
00:03:26.680 | Rather, they tend to be global in terms of the attentional focus that they have.
00:03:31.960 | They use position as a feature rather than a hard constraint of the architecture, and
00:03:36.360 | this is in contrast to MLP-based architectures or ConvNets in the way that they typically
00:03:42.840 | work, which use position as an architectural component to constrain how compute is happening.
00:03:50.380 | And then, of course, finally, there's extensive weight sharing in the way that they're designed.
00:03:56.880 | And because they focus on matmuls, they tend to be TPU and GPU-friendly.
00:03:59.920 | So these are all very nice things about the way transformers work.
00:04:03.800 | Of course, on the other hand, they have very poor compute memory scaling.
00:04:07.520 | And there are two components to this.
00:04:09.200 | So attention itself scales quadratically.
00:04:11.520 | So there's this big O of M squared L complexity at the heart of transformers.
00:04:18.360 | And I like writing it this way because it really emphasizes that this is a property
00:04:22.240 | of-- that basically, as you make your models bigger, either at the input size or as you
00:04:26.840 | make them deeper, this problem is just going to get worse.
00:04:31.080 | And because you have this scaling in depth as well, there's another practical thing that
00:04:35.920 | happens here.
00:04:37.160 | Because the amount of compute that we're doing is proportional to the input size, so there's
00:04:41.360 | no bottleneck in the way that standard transformers work, even the linear scaling becomes a problem.
00:04:48.600 | And so in practice, for very, very large transformers, this can often be the bottleneck that really
00:04:53.800 | matters.
00:04:54.800 | But they're both at play here.
00:04:55.800 | And so we really want to tamp down both of these.
00:04:58.160 | And so the perspective here is that to have really general-purpose architectures, we can't
00:05:03.080 | have ones that are just in principle general.
00:05:04.960 | We have to have ones that you can actually use on the scales and the kinds of data that
00:05:09.640 | we care about.
00:05:13.440 | And so just to-- this will all be old hat for all of you, but just the way that standard
00:05:18.080 | QKV attention works is basically like this.
00:05:21.360 | So it's all matrix multiplication.
00:05:23.320 | So we have some input.
00:05:24.520 | We compute the query keys and values by having a 1D convolution, a one-by-one convolution
00:05:29.240 | that we run over the input.
00:05:31.280 | We then compute the attention scores.
00:05:34.840 | This is a matrix multiply that has the following-- these sorts of shapes.
00:05:38.920 | We then use the output here to compute the weights, to compute the actual output of the
00:05:45.200 | attention module itself.
00:05:46.880 | And then finally, we run this through an additional MLP, which is applied convolutionally, to
00:05:51.360 | get the outputs.
00:05:53.160 | So this is the starting point of what we're working on here.
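
To make the shapes concrete, here is a minimal numpy sketch of the standard QKV self-attention step described above (single head, no masking, and the trailing MLP is omitted; the projection matrices stand in for the one-by-one convolutions, and all sizes and names are illustrative rather than taken from the talk):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    """x: [M, C] input array; Wq/Wk/Wv: [C, D] projections (the one-by-one convolutions)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv           # each [M, D]: linear in the input size
    scores = q @ k.T / np.sqrt(q.shape[-1])    # [M, M]: the O(M^2) matrix multiply
    attn = softmax(scores, axis=-1)
    return attn @ v                            # [M, D]: a second O(M^2) matrix multiply

rng = np.random.default_rng(0)
M, C, D = 64, 32, 32
x = rng.normal(size=(M, C))
Wq, Wk, Wv = (rng.normal(size=(C, D)) * 0.1 for _ in range(3))
out = self_attention(x, Wq, Wk, Wv)
print(out.shape)  # (64, 32)
```
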
00:05:58.040 | And let me just briefly just reiterate why we would want the advantages that we have
00:06:03.040 | with these standard transformers.
00:06:05.180 | So non-locality is one of the two inductive bias principles that we have here.
00:06:09.800 | It's useful, I think, to contrast this to the effective locality that you get in ConvNets
00:06:14.400 | and what this actually means.
00:06:16.240 | So if we look at, basically, as a function of depth, which inputs can see which other
00:06:21.600 | functions, which means how easily it is to express a function of two input points, let's
00:06:26.760 | say we look at this yellow and purple point here at the input.
00:06:29.800 | Now, I've set them as far apart as possible.
00:06:32.800 | But we might ask, how deep would the effective computation have to be before you actually
00:06:38.000 | process these two?
00:06:39.000 | And if you look at a three-by-three convolution, you're having to look, basically, until the
00:06:45.200 | very end of the network, until you're processing these things together.
00:06:50.440 | And what this means is that the functions that you can express that actually look at
00:06:55.640 | both of these points end up being quite shallow, because they have to be built on top of this
00:06:59.400 | very, very deep stack that just gives you the locality.
00:07:02.900 | And so in point of fact, if you look at, for example, the way ResNets work, so you have
00:07:07.480 | an initial block, which has a seven-by-seven convolution, and then afterwards, it's three-by-three
00:07:11.120 | convs all the way up, you need 28 three-by-three convs with that standard processing stack before
00:07:16.700 | all of the 224 by 224 pixels in an image are looking at each other.
00:07:21.260 | And what this means is that in a ResNet-50, the pixels at the very edges of the image
00:07:25.840 | actually never see each other.
00:07:27.560 | And I found this a little bit counterintuitive, but it suggests that we really are constraining
00:07:33.520 | quite a lot the functions that are easy to express with these models.
00:07:36.640 | And so there are some functions of images you just can't capture with a ResNet-50.
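
As a rough illustration of the receptive-field argument, here is a small receptive-field calculator. The stack below (a 7x7 stride-2 stem, a 3x3 stride-2 pool, then stride-1 3x3 convolutions) is a simplified stand-in, not the exact ResNet-50 layout, so the precise layer count depends on these assumptions; it lands in the same ballpark as the ~28 layers mentioned above:

```python
def receptive_field(layers):
    """layers: list of (kernel, stride) pairs. Returns the receptive field in input pixels."""
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump   # each layer widens the field by (k - 1) * current jump
        jump *= s
    return rf

# Hypothetical stack: 7x7 stride-2 stem, 3x3 stride-2 pool, then stride-1 3x3 convs.
stem = [(7, 2), (3, 2)]
n = 0
while receptive_field(stem + [(3, 1)] * n) < 224:
    n += 1
print(n, receptive_field(stem + [(3, 1)] * n))  # 27 227: ~27 3x3 layers before a 224-pixel field
```
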
00:07:42.680 | On the other hand, if you look at an architecture that has global attention over the full input,
00:07:48.180 | so a transformer, if you could scale it that way, or a perceiver, as we're going to be
00:07:52.080 | talking about, all of the pixels can interact.
00:07:55.480 | So the model can basically capture these things and express these functions much more easily
00:07:59.800 | than can be expressed in things that put locality first.
00:08:05.520 | We also-- the other interesting property of these sorts of architectures is that position
00:08:10.440 | is featurized.
00:08:12.120 | And this basically means that we're no longer sort of encoding the architectural location
00:08:16.680 | of something to figure out where it's located with respect to the other ones.
00:08:21.760 | And this allows the network to basically use any positional information that it wants but
00:08:27.260 | can also discard it as it prefers.
00:08:30.560 | And so this is the standard way it's done, of course, in the context of architectures
00:08:34.280 | that use Fourier or sinusoidal-like features, but there's a lot of flexibility here.
00:08:39.400 | OK, so now just thinking in terms of how ConvNets relate to transformers, sort of at the opposite
00:08:45.320 | end, it may look like that we have a sort of scalability versus generality trade-off.
00:08:52.000 | And so if we look at ConvNets, the way that they're applied-- so typically, we can think
00:08:56.880 | about using them on grid-structured data.
00:08:58.920 | There are, of course, generalizations of convolutions that work on data sets with more interesting
00:09:04.720 | topology, but typically, we can think of them as operating on grids in some sort of space,
00:09:10.920 | whereas transformers apply to generic sets.
00:09:13.080 | So transformers are more general from this point of view.
00:09:16.080 | On the other hand, they scale much, much worse.
00:09:18.680 | So ConvNets are linear, both in the input points, the filter size, and the number of
00:09:23.840 | layers of that architecture, whereas transformers have this quadratic scaling, and they're still
00:09:28.940 | linear in the depth.
00:09:31.320 | So from this point of view, what we're interested in doing in the perceiver line of work was
00:09:35.760 | to scale transformers, but to keep the generality property.
00:09:39.040 | So we want something that lives in between these two extremes.
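
To summarize the scaling claims in rough big-O form (M input points, N latents, filter size k, depth L; the Perceiver term anticipates the cross-attention construction described next):

```latex
\underbrace{O(L \cdot M \cdot k)}_{\text{ConvNet}}
\qquad
\underbrace{O(L \cdot M^{2})}_{\text{Transformer}}
\qquad
\underbrace{O(M \cdot N + L \cdot N^{2})}_{\text{Perceiver: cross-attend + latent transformer}}
```
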
00:09:43.120 | And the way that we do this is by looking at self-attention and sort of modifying it
00:09:48.480 | in a way that allows us to scale better.
00:09:51.880 | So to walk through what self-attention actually does in sort of standard transformers, we
00:09:56.760 | take our input array, which here is written as the indices, which is the number of tokens
00:10:02.000 | or the number of pixels, basically the number of input units, depending on what you're looking
00:10:05.840 | at, and the channels.
00:10:08.320 | We have a 1D convolution.
00:10:09.680 | So this is big O of M with respect to the Q, K, and V. We then compute the attention
00:10:14.880 | maps using the output of this operation.
00:10:17.380 | This gives us a matrix multiply, which is the source of the quadratic scaling.
00:10:21.920 | And then finally, we compute output features with another matrix multiply.
00:10:27.320 | This is already-- we're already rate-limited here, because for even standard resolution
00:10:32.880 | images, M is quite large.
00:10:34.080 | So it's around 50,000 for standard ImageNet images, which, again, are very small.
00:10:38.120 | So this is something that just isn't going to work if we want deep architectures.
00:10:43.160 | So what we do is we replace-- at the input to the architecture, we replace the self-attention
00:10:47.920 | with a cross-attention layer.
00:10:49.960 | And we do this using, basically, a learned query.
00:10:53.720 | And so we're replacing only the query from the input here with a learned component.
00:10:59.200 | And so these indices and channels, you can just think of these as basically working like
00:11:03.000 | a learned initial state for an RNN.
00:11:05.300 | There's a variety of names that this idea goes under in the literature.
00:11:09.480 | We refer to them as sort of as latents.
00:11:13.240 | But they're sometimes called inducing points or other things.
00:11:17.520 | So the basic idea is we're learning the input to the query and keeping the key value of
00:11:20.800 | it the same.
00:11:22.680 | The downside-- or the sort of upside of this is that when we compute the attention map
00:11:27.720 | after this, now we basically turn this from a square matrix to a rectangular matrix, which
00:11:35.160 | reduces the complexity of the matrix multiply to big O of MN.
00:11:38.120 | So now it's linear in the input size.
00:11:40.740 | And the second matrix multiply has the exact same property.
00:11:44.220 | So it becomes-- from quadratic, it becomes linear.
00:11:47.840 | And the quite cool thing about this is that, OK, so the cross-attention is linear in complexity.
00:11:52.980 | But the output is actually smaller.
00:11:54.720 | And so this, I think, is actually the more important point here is that this allows us
00:11:58.160 | to map something which is quite large into something that has size that's independent
00:12:02.140 | of the input.
00:12:03.140 | So we have full control over this as a hyperparameter.
00:12:06.000 | And this allows us to build deep networks on top of this latent.
00:12:10.320 | So because this is of a small size that we can control, we can afford to have quadratic
00:12:13.980 | complexity on top of this.
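
Here is the same numpy sketch adapted to the cross-attention encoder just described: the query now comes from a small learned latent array rather than from the input, so the score matrix is N x M instead of M x M. This reuses the softmax helper and imports from the earlier sketch; it is single-head, and the latent count here is purely illustrative (the actual models use several hundred latents):

```python
def cross_attention(latents, x, Wq, Wk, Wv):
    """latents: [N, D] learned latent array; x: [M, C] input byte array."""
    q = latents @ Wq                           # [N, D]: the query comes from the latents
    k, v = x @ Wk, x @ Wv                      # [M, D]
    scores = q @ k.T / np.sqrt(q.shape[-1])    # [N, M]: linear in the input size M
    attn = softmax(scores, axis=-1)
    return attn @ v                            # [N, D]: output size set by the latents, not the input

N, M, C, D = 32, 50176, 3, 64                  # e.g. ~224*224 pixels in, 32 latents out
latents = rng.normal(size=(N, D)) * 0.02       # learned weights, randomly initialized
x = rng.normal(size=(M, C))
Wq = rng.normal(size=(D, D)) * 0.1
Wk, Wv = rng.normal(size=(C, D)) * 0.1, rng.normal(size=(C, D)) * 0.1
out = cross_attention(latents, x, Wq, Wk, Wv)
print(out.shape)  # (32, 64)
```
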
00:12:15.860 | And so we use this idea-- yep.
00:12:16.860 | Go ahead.
00:12:17.860 | Oh, sorry.
00:12:18.860 | I'm still a little bit confused as to how you guys are able to turn this square into
00:12:24.280 | a rectangle in the second step.
00:12:25.800 | Is it because you replaced the query with a learned something that is significantly smaller
00:12:31.020 | compared to the input size in the first step?
00:12:33.560 | Yeah, that's exactly right.
00:12:35.260 | So if you look at the-- so the underlying matrix multiply here, which is written as
00:12:40.640 | the QK transpose, so this will basically-- so the outer dimension here has shape n, which
00:12:47.240 | is determined by the query.
00:12:48.860 | And so by shrinking that query, we're just changing the output of the matrix multiply.
00:12:53.440 | Thank you.
00:12:54.440 | Yeah.
00:12:55.440 | So I guess--
00:12:56.440 | Sorry.
00:12:57.440 | Go ahead, please.
00:12:59.440 | Cool.
00:13:00.440 | So basically, you only do that for the query, right?
00:13:03.280 | So key and value remain like the original size matrices, correct?
00:13:08.740 | That's right.
00:13:10.740 | But so basically-- so I don't know what I'm not understanding, basically.
00:13:17.200 | So the problem for me is that for a query, now in my head, I'm looking for-- let's say
00:13:23.520 | I have the if token.
00:13:25.380 | Now there is no if query anymore.
00:13:29.000 | Doesn't that cause a problem when I'm trying to use it and to compute scores?
00:13:35.520 | Yeah.
00:13:36.820 | So what's happening here is you'll have a smaller subset of queries.
00:13:40.660 | So if you think about this not in terms of the matrix multiplies, but in terms of comparing
00:13:44.260 | each query to each key.
00:13:46.540 | So in normal self-attention, we have one query for each key, so every point compares to every
00:13:51.240 | other point, right?
00:13:52.240 | So here, what we've done is instead of comparing every point to every other point, we have
00:13:56.000 | a set of sort of cluster centers you might be able to think about them as.
00:13:59.620 | So it's a smaller number, and we compare each of those to each of the input points.
00:14:03.940 | But we don't know which tokens technically belong to which clusters, right?
00:14:11.140 | That's right.
00:14:12.140 | So it has to be learned.
00:14:13.140 | [INAUDIBLE]
00:14:14.140 | Yeah, exactly.
00:14:15.660 | So one way to think about this, about what's happening here, is that instead of-- so in
00:14:22.300 | a normal self-attention transformer, by comparing all to all, we're sort of saying, OK, I know
00:14:28.260 | what the feature is at this point, and I want it to attend to similar features.
00:14:32.940 | Here what we're saying is we're learning a bunch of supplementary points that should
00:14:38.380 | be sort of maximally similar to some subset of the inputs.
00:14:41.540 | So correct me if I'm wrong, but this is essentially doing some sort of hard attention, where you're
00:14:46.780 | saying instead of querying over all the points, let's select some points which we think are
00:14:52.060 | very similar, and only put self-attention over this hard point, like these points you
00:14:57.420 | have selected.
00:14:58.420 | Right?
00:14:59.420 | Yeah, so they're related.
00:15:01.660 | That would be one way to think about it.
00:15:03.980 | The slight modifier to that idea, though, is that they basically live in an abstract
00:15:08.420 | space.
00:15:09.420 | So they're not assigned sort of one-to-one to one of the input queries, or to one of
00:15:13.780 | the input points.
00:15:14.780 | They're sort of learned, so they can be somewhere in the middle.
00:15:17.860 | But I think that's a good way to think about it.
00:15:19.340 | That's a good intuition.
00:15:21.500 | But I guess one of the places where I'm a little confused here is you have here indices
00:15:26.900 | and indices for the two, like the purple and green matrices in the far left.
00:15:31.440 | But those indices are not necessarily corresponding to inputs.
00:15:34.460 | Like in the NLP space, those would not necessarily be tokens, right?
00:15:37.380 | These are just sort of--
00:15:38.580 | Exactly.
00:15:39.580 | --indices.
00:15:40.580 | But the--
00:15:41.580 | That's right.
00:15:42.580 | --index in this case is the result of some kind of mapping from the input tokens to an
00:15:46.180 | n-by-d matrix.
00:15:47.180 | Is that right?
00:15:48.180 | No, it's actually-- so it basically acts like-- it's a learned set of weights, is one way
00:15:52.420 | to think about it.
00:15:53.520 | So they function exactly the same way that learned position encodings do.
00:15:56.620 | So it's basically just a-- it's a learned embedding.
00:15:59.660 | But it's not conditioned on anything.
00:16:01.620 | It's just sort of-- it just is-- it's just a set of weights.
00:16:06.780 | Oh, OK.
00:16:07.780 | That makes more sense.
00:16:08.780 | Thank you.
00:16:09.780 | Mm-hmm.
00:16:11.780 | So if there are no more questions, I'm going to keep going.
00:16:16.980 | But of course, feel free to interrupt me.
00:16:20.220 | So the way that-- given this idea-- so we have this learned latent array, which, again,
00:16:26.220 | it functions sort of like an RNN initial state, or it's a set of weights.
00:16:30.260 | We basically randomly initialize that.
00:16:32.820 | And then we use this to attend onto the input byte array.
00:16:36.700 | And so the byte array here is the flattened set of pixels, for example, for ImageNet.
00:16:41.340 | And the output of this is going to live in the same space as-- so the same index space
00:16:45.980 | as the latent array does.
00:16:47.980 | And there's residual connections in the way that you would normally do in an attention
00:16:53.140 | layer as well.
00:16:56.040 | So once we're in the space, we can then build an architecture by taking-- by using a standard
00:17:01.500 | transformer but phrased in the latent space rather than in the input space.
00:17:06.300 | And this is going to allow us to basically end up-- because we've sort of distilled the
00:17:10.520 | input down to the smaller space, we can still flexibly allow all of these points to interact.
00:17:15.300 | So this should still be nearly as expressive as a normal transformer would be.
00:17:21.220 | And then each of the modules here now is quadratic in the latent size rather than the input size.
00:17:25.420 | So this is something that we can control quite a lot.
00:17:29.580 | So in the original version of the perceiver, we found it was very helpful to have additional
00:17:36.020 | cross-attends.
00:17:37.020 | So this is certainly something that you can do.
00:17:39.740 | And the reason-- the intuition behind this is that if this bottleneck is quite severe,
00:17:46.360 | we can't maintain all of the information from the input.
00:17:48.860 | And so we want these queries, which are now sort of conditioned on the past, to be able
00:17:52.740 | to look back at the input point.
00:17:55.000 | And so this is something that we found to be quite helpful when tuning for the first
00:17:59.980 | paper.
00:18:00.980 | But the caveat, I will say, is that we're no longer recommending this as best practice
00:18:04.620 | because these cross-attentions end up being quite heavy.
00:18:07.420 | But this is something that you can explore, certainly, if you want sort of more conditional
00:18:10.500 | queries or if you want to be able to cross-attend to new inputs that are coming in.
00:18:16.020 | The other thing that we found quite helpful in the context of data sets that have a limited
00:18:20.900 | amount of data, which for these architectures includes ImageNet, is to allow weight sharing
00:18:26.020 | in depth.
00:18:27.220 | And so this basically just amounts to tying the weights for the different cross-attention
00:18:31.160 | and different self-attention layers as they're repeated.
00:18:33.940 | So this ends up looking like an RNN that's unrolled in depth.
00:18:38.860 | So this is just at a high level.
00:18:42.140 | This gives us an architecture that we can apply to images but doesn't make any assumptions
00:18:45.580 | about image structure.
00:18:46.820 | So it's one that you can use elsewhere.
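
Putting the pieces together, a skeletal forward pass might look like the following. This is a structural sketch only, reusing the arrays and attention helpers from the sketches above; real blocks add layer norms, MLPs, multi-head attention, and a proper classifier head, and the repeat/depth counts here are arbitrary:

```python
def perceiver(x, latents, cross_blocks, latent_blocks):
    """x: [M, C] inputs; latents: [N, D] learned latent array.
    cross_blocks: one cross-attention callable per repeat; passing the *same* callable
    every repeat corresponds to the weight sharing in depth described above.
    latent_blocks: for each repeat, a list of latent self-attention callables."""
    z = latents
    for cross, stack in zip(cross_blocks, latent_blocks):
        z = z + cross(z, x)        # cross-attend to the inputs: O(M*N)
        for block in stack:
            z = z + block(z)       # latent transformer block: O(N^2), independent of M
    return z.mean(axis=0)          # e.g. average the latents before a classification head

# Smoke test with the single-head sketches above standing in for full blocks:
Wq_c = rng.normal(size=(D, D)) * 0.1
Wk_c, Wv_c = rng.normal(size=(C, D)) * 0.1, rng.normal(size=(C, D)) * 0.1
Wq_s, Wk_s, Wv_s = (rng.normal(size=(D, D)) * 0.1 for _ in range(3))
cross = lambda z, inp: cross_attention(z, inp, Wq_c, Wk_c, Wv_c)
latent_sa = lambda z: self_attention(z, Wq_s, Wk_s, Wv_s)
z_out = perceiver(x, latents, [cross] * 8, [[latent_sa] * 6] * 8)
print(z_out.shape)  # (64,)
```
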
00:18:49.300 | And we give information about the input spatial structure by having positional encodings.
00:18:58.180 | And here we use a 2D Fourier feature position encoding.
00:19:01.300 | And just to show you what that looks like here, to give you a sense.
00:19:04.660 | So each of the input points is assigned basically-- so you'll be in some position here.
00:19:11.340 | And we have sinusoidal and cosinusoidal features in 2D.
00:19:14.560 | So this is basically a Fourier decomposition of the position of the 2D input.
00:19:19.880 | And a couple of things that we found were that if we sample frequencies up to the
00:19:24.540 | Nyquist frequency of the signal, so the maximum frequency that's used is Nyquist, we end up doing better
00:19:30.640 | than if you use a lower maximum frequency.
00:19:32.320 | And this basically is because this will allow every other point to be aware of every distinct
00:19:37.960 | point in the image.
00:19:38.960 | Whereas if you sample at a lower frequency, you're going to end up with aliasing.
00:19:41.920 | And so not all points will be distinguishable.
00:19:45.480 | We also found that sampling the spectrum relatively densely tends to help.
00:19:49.400 | And the contrast here, at the time we were developing, was with respect to NERF.
00:19:54.160 | So NERF, at least in earlier implementations, used quite a small number of frequency bands.
00:19:59.520 | We found that the more we added, the better we did.
00:20:01.840 | So in general, this is something to be attentive to.
00:20:06.000 | And then finally, as opposed to language, where you typically have addition of whatever
00:20:10.560 | your embedding is with the sinusoidal or position encoding that you use, here we found that
00:20:16.960 | concatenating them performed consistently better.
00:20:20.320 | And so this may be because the content embedding is not as sparse as it is in language.
00:20:25.600 | We're not totally sure.
00:20:26.600 | But this is something that I observed consistently.
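
A sketch of the 2D Fourier position features described here, assuming positions normalized to [-1, 1] and frequency bands spaced linearly from 1 cycle per image up to the Nyquist frequency (resolution / 2). The band count and the choice to keep the raw positions are illustrative; the paper's exact parameterization may differ in details:

```python
import numpy as np

def fourier_position_encoding(height, width, num_bands):
    """Returns [height * width, 2 * (2 * num_bands + 1)] position features."""
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    pos = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1).reshape(-1, 2)  # [H*W, 2]
    # Bands from 1 cycle over the image up to the Nyquist frequency (resolution / 2).
    freqs_y = np.linspace(1.0, height / 2.0, num_bands)
    freqs_x = np.linspace(1.0, width / 2.0, num_bands)
    feats = [pos]  # keep the raw positions as well
    for dim, freqs in enumerate((freqs_y, freqs_x)):
        angles = np.pi * pos[:, dim:dim + 1] * freqs[None, :]   # [H*W, num_bands]
        feats += [np.sin(angles), np.cos(angles)]
    return np.concatenate(feats, axis=-1)

pe = fourier_position_encoding(224, 224, num_bands=64)
rgb = np.zeros((224 * 224, 3))                        # stand-in pixel values
inputs = np.concatenate([rgb, pe], axis=-1)           # concatenated with the content, not added
print(pe.shape, inputs.shape)  # (50176, 258) (50176, 261)
```
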
00:20:30.400 | And before I move on to results, I just want to contrast this to some other approaches
00:20:34.400 | for using transformers in the image context.
00:20:38.720 | So the obvious precedent here is visual transformers.
00:20:42.920 | And I think this is a very-- this line of work is great, especially in the image context.
00:20:48.400 | But there are some caveats about it that make it less suitable for sort of more general
00:20:53.000 | purpose use.
00:20:54.840 | So one is that-- so vision transformers do use an input 2D convolution.
00:20:59.060 | So this is often phrased in terms of patches, input patches.
00:21:03.040 | It's a special case of a 2D transformer.
00:21:05.680 | So it does restrict the class of inputs you can use it for.
00:21:10.520 | And because we're basically building this patching or convolution into it, this means
00:21:16.140 | that this as an approach really isn't sufficient to get it to work on non-grid data.
00:21:20.560 | There are other ways you could adapt it.
00:21:21.720 | But this is something that you will have to special case for every domain you're looking at.
00:21:25.400 | And then finally, because we have this sort of input where we're telling the architecture
00:21:30.640 | what it should look at first in the initial grouping, this does amount to getting rid
00:21:34.400 | of the non-locality assumption.
00:21:36.560 | It's not super clear how much doing this just once will make a difference.
00:21:40.600 | But this is something to be aware of when you're thinking about this architecture.
00:21:45.000 | And then finally, cross-attention itself is used quite broadly in the vision literature.
00:21:50.100 | So just to highlight a couple of examples, DETR, which is an object detection method
00:21:56.520 | from Facebook, basically has a convolutional backbone that's then used to give an output
00:22:01.960 | feature map.
00:22:03.120 | This is then passed into a transformer encoder decoder.
00:22:06.000 | And of course, whenever you hear encoder decoder, you think cross-attention, because from the
00:22:09.520 | encoder to the decoder, there's a cross-attention step.
00:22:12.560 | And so they're using basically the cross-attention to go from some feature map representation
00:22:16.480 | to something that looks more like the object bounding boxes.
00:22:20.880 | There's also quite nice work on learning self-supervised or unsupervised object segmentation models.
00:22:30.280 | And in this work, they're doing something very similar where they have a convolutional
00:22:33.320 | backbone.
00:22:34.320 | They then use something like the latents that we introduce here to do-- they call them slots
00:22:42.040 | here.
00:22:43.040 | But basically, to assign some of the output pixels to different slots so that they sort
00:22:48.040 | of have independent complementary decoding of the slots in the segmentation model here.
00:22:53.480 | And there's a lot of other things.
00:22:58.640 | So first, I just-- so now I'm going to sort of walk you through the results of this model.
00:23:02.960 | Can I-- oh, go ahead.
00:23:03.960 | Yeah.
00:23:04.960 | I'll go after you.
00:23:05.960 | Go for it.
00:23:07.960 | Cool.
00:23:08.960 | Sorry for that.
00:23:09.960 | Can you go back a couple of slides where you had the-- how the inputs flow into, I
00:23:19.520 | think, one of-- yeah, that one.
00:23:23.040 | So I have two questions.
00:23:24.040 | So the latent transformer is basically like a self-attention.
00:23:27.080 | Is that correct?
00:23:28.840 | Yeah.
00:23:29.840 | So the latent transformer is a fully self-attentional transformer.
00:23:33.400 | Got it.
00:23:35.200 | And why-- see, for the key and value, they flow directly into the cross-attention.
00:23:41.520 | And there is the query also flowing into it.
00:23:44.440 | But the latent array is flowing into the cross-attention in parallel to the query.
00:23:49.920 | Can you explain that?
00:23:50.920 | Yeah.
00:23:51.920 | So this is just--
00:23:52.920 | [INTERPOSING VOICES]
00:23:53.920 | --here.
00:23:54.920 | I'm just trying to understand.
00:23:55.920 | Yeah.
00:23:56.920 | Yeah.
00:23:57.920 | So how do you pick the residual connection?
00:23:58.920 | So the cross-attention-- this is a cross-attention depicted as a cross-attention module.
00:24:04.200 | And so the cross-attention itself has the attention.
00:24:06.640 | It has a residual connection.
00:24:07.840 | And then there's an MLP.
00:24:09.000 | So that's what that's meant to indicate.
00:24:11.640 | But it's basically-- the QKV is standard.
00:24:14.800 | Got it.
00:24:16.400 | Thanks.
00:24:17.400 | Mm-hmm.
00:24:19.400 | I had a question that is slightly related to this.
00:24:22.320 | We can just stay off the slide, actually.
00:24:24.160 | So I think one thing that's interesting about this argument--
00:24:26.400 | [INAUDIBLE]
00:24:27.400 | I'm sorry.
00:24:28.400 | I lost you.
00:24:29.400 | You're cutting off.
00:24:30.400 | It's mostly consisting of attention layers, whether it's self-attention or in image transformers.
00:24:40.080 | Can you hear me?
00:24:41.080 | Is that coming through?
00:24:42.080 | No, it's cutting off.
00:24:43.080 | I think--
00:24:44.080 | Yeah.
00:24:45.080 | But I think some recent--
00:24:46.080 | Oh, OK.
00:24:47.080 | Should I type it?
00:24:48.080 | Is that-- should I type?
00:24:49.080 | Yeah.
00:24:50.080 | I think that's a good idea.
00:24:51.080 | Yeah.
00:24:53.080 | I'll type it.
00:24:54.080 | [INAUDIBLE]
00:24:55.080 | All right.
00:24:56.080 | It's kind of [INAUDIBLE]
00:24:58.600 | Feel free to go ahead.
00:24:59.600 | I'll type it slowly.
00:25:00.960 | And--
00:25:01.960 | Sounds good.
00:25:02.960 | Sounds good to me.
00:25:03.960 | Yeah.
00:25:04.960 | Actually, can I chime in?
00:25:05.960 | Drew, while you're on that previous slide--
00:25:07.220 | Yeah.
00:25:08.220 | --on the flow, so these residual connections, I actually didn't know the cross-attention
00:25:12.400 | used them.
00:25:13.400 | How reliant are these sequential cross-attention layers on the residual connections?
00:25:18.560 | Yeah.
00:25:20.040 | So here, in the initial-- so two things I will say is that in the initial cross-attention,
00:25:30.680 | it doesn't really make a difference.
00:25:32.080 | So this is something we've ablated.
00:25:34.560 | When we get to the Perceiver I/O version of this, we also did the same thing in the decoder
00:25:38.400 | cross-attention.
00:25:39.400 | And it can make some of a difference-- it can make a difference there, depending on
00:25:41.720 | what you're doing.
00:25:42.720 | I think it's actually essential when you're using repeated cross-attention of this way,
00:25:47.680 | so when you have this sort of iterative structure.
00:25:49.720 | And the reason for this is that the thing that's actually used to condition the query
00:25:54.580 | is basically all the-- that's your full representation of the state of the architecture so far.
00:26:00.000 | And so the skip connection is from the-- it's in, basically, the query channel.
00:26:04.600 | It's in the latent space.
00:26:06.660 | And so this is basically what allows you to end up with this sort of dense and stable
00:26:10.240 | architecture.
00:26:12.240 | Thank you.
00:26:13.240 | Mm-hmm.
00:26:15.240 | So to ImageNet-- OK, so in standard ImageNet processing, basically, we compare against
00:26:28.600 | a few-- so this is a little bit out of date at this point-- but against a few just sanity
00:26:34.280 | check baselines here.
00:26:36.200 | So comparing against ResNet-50 and then, at the time, the best vision transformer model
00:26:40.280 | that was purely on ImageNet.
00:26:42.300 | And we're definitely in the ballpark.
00:26:44.100 | This isn't-- these aren't anywhere near state-of-the-art results.
00:26:47.620 | But this is an architecture that, again, is not using any 2D convolutions.
00:26:51.580 | And so the fact that it was able to do this well was, we found, very, very surprising
00:26:54.760 | at the time.
00:26:55.760 | One of the quite cool things about this is that, because this architecture is not making
00:27:00.400 | any assumptions-- the architecture itself isn't making any assumptions about the spatial
00:27:05.760 | structure of the input images-- we can look at permuted ImageNet.
00:27:10.720 | And in the first version of this, what we do is, basically, we compute the features
00:27:14.960 | using the 2D position.
00:27:16.360 | So the 2D position is fixed to a position-- to the pixel.
00:27:19.580 | And then we just shuffle them all.
00:27:21.040 | And so this is-- basically, we'll give you a sense of how dependent the baselines are
00:27:25.520 | on the input image structure.
00:27:28.360 | And so if we look at the transformer and perceiver, by construction, they don't change.
00:27:33.080 | So this is not an empirical finding.
00:27:34.720 | This is a property of the models.
00:27:36.920 | But we find that ResNet-50 falls by about-- the performance falls by about half.
00:27:41.920 | And VIT, which, again, only has one layer where it's relying on the spatial structure,
00:27:45.880 | also has about a 15-point drop.
00:27:48.200 | And so this suggests that it's relying quite a lot on that very first one to give it some
00:27:53.000 | information about the structure.
00:27:55.560 | We can push this a little bit by, instead of relying on 2D Fourier features, learning
00:28:00.880 | completely learned positional encodings.
00:28:03.600 | And this basically-- this is an architecture now-- this is a model that has absolutely
00:28:07.240 | no information about the input structure.
00:28:10.120 | And so shuffling them and learning them again is absolutely equivalent.
00:28:13.620 | And we find that this architecture also can be pushed above 70%.
00:28:16.920 | And we've gotten slightly better numbers here.
00:28:19.020 | In general, this seems to work worse.
00:28:21.860 | So the 2D information is useful.
00:28:24.300 | But it's quite cool that you can get what would have been numbers comparable to state
00:28:29.120 | of the art about five or six years ago.
00:28:31.080 | So this is quite cool.
00:28:32.400 | Sorry, I'm a little thick here.
00:28:35.920 | You're saying the difference between the last two rows is that the second-to-last row has
00:28:42.360 | a two-dimensional position embedding, and the last one has a one-dimensional position
00:28:45.680 | embedding, essentially.
00:28:46.680 | Is that right?
00:28:47.680 | So it's learned.
00:28:48.900 | So it's basically-- it'll be-- it's, I believe, a 256-dimensional vector that's learned.
00:28:56.560 | But it doesn't-- it basically-- it means that the model itself has no information about
00:29:01.360 | the input spatial structure.
00:29:04.600 | So the 2D positional encodings that we're using end up having about 200-- it's 200-some
00:29:09.880 | features, depending on what you're looking at.
00:29:11.960 | But they're always-- they give you very detailed information about the 2D structure of the
00:29:15.440 | input, because they're based on a Fourier decomposition of the input space.
00:29:18.360 | I see.
00:29:20.360 | That makes sense.
00:29:21.360 | Thank you.
00:29:22.360 | Hi, Drew.
00:29:23.360 | Can I ask a question about the frequencies you use to generate those sinusoidal waves?
00:29:30.440 | Yeah.
00:29:31.440 | So like a couple of slides before.
00:29:36.200 | Yeah.
00:29:37.200 | Yeah.
00:29:38.200 | Yeah.
00:29:39.200 | Yeah.
00:29:40.200 | This slide.
00:29:41.200 | So basically, I do have taken some lectures in signal processing, and I know if I want
00:29:48.900 | to avoid aliasing, I need to sample with at least Nyquist frequency.
00:29:54.360 | So I'm curious to know why do you use frequency starting from 1 to the Nyquist frequency,
00:30:01.000 | instead of starting from Nyquist frequency to some very high frequency?
00:30:05.760 | Oh, I see.
00:30:07.260 | So basically-- so the maximum frequency that's used is always Nyquist.
00:30:13.920 | So anything above Nyquist is going to be aliased, so you're not actually going to be able to
00:30:16.960 | resolve it, because it's in pixel space, right?
00:30:20.780 | So we sample-- one is basically just giving you an oscillation that covers the entire
00:30:25.320 | image.
00:30:26.320 | Yeah.
00:30:27.320 | And so this is basically just a sample of the full range of non-aliased frequencies.
00:30:31.320 | Oh, OK.
00:30:32.720 | Cool.
00:30:33.720 | Thank you.
00:30:36.720 | Yeah.
00:30:38.720 | So after the image results, we wanted to try it on other domains.
00:30:48.680 | And in particular, we were interested in how this could be used to work on sort of multimodal
00:30:52.480 | domains, so ones combining various different types of input features.
00:30:57.480 | And one challenge or one sort of problem that you encounter in these sorts of spaces is
00:31:02.440 | that the data from different modalities end up having different features, and they always
00:31:07.100 | have different semantics.
00:31:08.560 | So if you take the positional encoding plus the RGB for video, you end up with some number
00:31:13.640 | of channels.
00:31:14.640 | And then if you have audio, that corresponds-- the data may be paired, but it tends to have
00:31:18.300 | fewer features, and it only has a 1D positional encoding.
00:31:21.940 | So the way that we handle this is basically by learning modality-specific position encodings.
00:31:27.860 | And so these are basically embeddings that are special and learned for each of the modalities.
00:31:33.100 | And what this does is basically tags-- ends up tagging the features that come from audio
00:31:37.300 | or video with some information that the network can learn that allows it to distinguish which
00:31:41.660 | one's which.
00:31:42.920 | But given these padded-- these sort of learned padded feature vectors, we then concatenate
00:31:48.500 | them all, and that's how we process multimodal data.
00:31:51.640 | So basically, the input to the architecture still looks like just one big array.
00:31:55.120 | It's just that when constructing this, we know that some of those features, some of
00:31:58.060 | the rows in that array, come from video and some come from audio.
00:32:01.220 | But the model itself isn't given information about that other than what it learns.
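
A minimal sketch of this featurization: each modality's rows get a learned modality embedding appended and are zero-padded up to a shared channel count, then everything is concatenated into one big input array. This is one simple way to realize the learned-padding idea; the channel counts, helper names, and exact padding scheme are illustrative assumptions, not the paper's configuration:

```python
import numpy as np
rng = np.random.default_rng(0)

def tag_and_pad(feats, modality_embedding, target_channels):
    """feats: [num_points, c]. Appends the learned modality embedding to every row,
    then zero-pads up to the shared channel count."""
    m, _ = feats.shape
    tagged = np.concatenate([feats, np.tile(modality_embedding, (m, 1))], axis=-1)
    pad = target_channels - tagged.shape[-1]
    return np.pad(tagged, ((0, 0), (0, pad)))

video = rng.normal(size=(16 * 56 * 56, 3 + 258))   # RGB + position features (stand-ins)
audio = rng.normal(size=(1920, 1 + 129))           # waveform + 1D position features (stand-ins)
video_tag = rng.normal(size=(4,)) * 0.02           # learned, one embedding per modality
audio_tag = rng.normal(size=(4,)) * 0.02
target = max(video.shape[1], audio.shape[1]) + 4
inputs = np.concatenate([
    tag_and_pad(video, video_tag, target),
    tag_and_pad(audio, audio_tag, target),
], axis=0)                                         # one big [M, C] array for the encoder
print(inputs.shape)  # (52096, 265)
```
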
00:32:04.460 | Oh, we also have some questions.
00:32:08.100 | Yeah.
00:32:09.100 | You can go first.
00:32:10.100 | Can we turn?
00:32:12.100 | Yeah, sorry.
00:32:13.100 | I thought it was [INAUDIBLE] but yeah.
00:32:14.100 | Yeah, sure.
00:32:15.100 | If you can hear me, this is just a simple question.
00:32:16.100 | I haven't studied a lot of transformer stuff formally, so I just didn't know what a positional
00:32:17.100 | embedding was.
00:32:18.100 | Oh, so what's a positional embedding?
00:32:20.100 | So basically, a positional embedding is-- it's a feature that says this-- so the simplest
00:32:41.320 | way to think about it is in text.
00:32:43.000 | So text, the input is 1D, so things live in some 1D sequence.
00:32:46.680 | And for each point there, you featurize where it's located in that sequence.
00:32:50.540 | So the simplest thing to do would be if you have negative 1 to 1 is the full range.
00:32:55.000 | It just denotes actually where it's located in that sequence.
00:32:58.440 | But we typically will add-- we'll want to featurize this to have more dimensions than
00:33:03.980 | just a single one.
00:33:05.240 | And so the Fourier decomposition is one way to do this to give it privileged information
00:33:10.560 | about the high frequency structure.
00:33:12.960 | But we can also just use the position to index onto some embedding array, which is how we
00:33:18.960 | do it when we're learning things.
00:33:20.440 | So basically, it's just a set of weights that are added to the feature for that point that
00:33:24.160 | give the network information about where it's located in the ground-truth sequence.
00:33:32.280 | You want to go next?
00:33:33.280 | Sorry, I had to find a mute button-- unmute button.
00:33:41.160 | OK, so I actually have two questions regarding the Fourier features.
00:33:48.640 | I think-- do you guys sample them uniformly, or are they-- do you learn these?
00:33:58.720 | Yeah, so basically, we sample them linearly.
00:34:03.080 | So basically, we take the full space, and we sample them linearly with whatever the
00:34:05.800 | budget is.
00:34:08.040 | There are-- so in various settings, we have actually tried learning these.
00:34:12.160 | So you could actually initialize an array with them and then learn them.
00:34:15.560 | And that does help sometimes, actually.
00:34:18.440 | And you could potentially learn-- you could try a more sophisticated strategy on this
00:34:23.040 | OK, cool.
00:34:24.040 | My follow-up question is that basically, I feel like the selling point of your research
00:34:29.680 | is that you don't make any structural assumptions.
00:34:33.040 | You can take any type of format.
00:34:35.200 | However, for the encoding, wouldn't the dimensionality-- so for example, if it's text, it's 1D, right?
00:34:43.640 | If it's an image, it will be 2D.
00:34:46.960 | And if it's a video, it will be 3D.
00:34:49.960 | You have more-- the positional encoding will have more points, right?
00:34:56.480 | Wouldn't that inherently give away the nature of the input?
00:35:03.360 | Yeah, so it does.
00:35:04.960 | So I completely agree with this.
00:35:06.480 | You're totally right.
00:35:08.600 | The version of this where we have learned position encodings is the most pure from that
00:35:12.000 | point of view.
00:35:13.000 | So it's one that gives it basically no information about the ground truth spatial structure.
00:35:18.260 | What it does give the model-- so when you do the learned position encoding, it will
00:35:22.360 | say that, for example, there is a correspondence between point k on image 1 and point k on
00:35:28.800 | image 2.
00:35:30.060 | So that's basically the least amount of information you can give it while still allowing
00:35:34.920 | it to figure out what the structural relationship between the input points is.
00:35:38.920 | So this is the direction that we've been trying to push in.
00:35:42.000 | In general, giving the architecture access to ground truth structural information, like
00:35:47.320 | this lives on this point in 2D, is helpful.
00:35:50.480 | So there's a couple things here.
00:35:53.400 | There's from a practical point of view, if you want good results, you need to exploit
00:35:57.680 | these things, or it's helpful to exploit these things.
00:36:00.480 | But we do want to move in the direction where we're relying on these things less.
00:36:04.400 | And so this is basically something we're actively looking into.
00:36:07.440 | OK, makes sense.
00:36:09.440 | Thank you.
00:36:10.440 | So I think has posted her question on the chat.
00:36:15.920 | I also see you have your hand raised.
00:36:18.280 | So if you want, you can give it a try.
00:36:21.560 | If not, I'll read out the question.
00:36:23.360 | OK, I'll try.
00:36:24.760 | Just let me know if it's choppy.
00:36:26.560 | Yeah, so is it good right now?
00:36:29.280 | So far, so good.
00:36:30.280 | Yeah.
00:36:31.280 | Oh, good.
00:36:32.280 | OK, cool.
00:36:33.280 | So if you look at the perceiver diagram you had, it's a bunch of attention layers, right?
00:36:36.320 | Like cross-attention and self-attention.
00:36:38.360 | And I think there's been this small trend in recent work in vision transformers to try
00:36:44.440 | to sort of replace the last few layers instead of having attention, like make them be convolutions
00:36:49.200 | to address this attention scaling problem, right, in a different manner.
00:36:53.600 | And so here, the perceiver architecture is trying to make self-attention less expensive.
00:36:58.600 | And there, they're just trying to replace it, and they kind of just avoid the problem.
00:37:03.640 | And so I'm curious.
00:37:04.640 | And so I've seen papers both ways, like some that try to do things like the ones you cited,
00:37:08.760 | and then some that are trying to do this as well.
00:37:11.120 | And in my mind, everyone always has the good results and stuff.
00:37:14.360 | So I'm curious if you think there's a reason to do one or the other, or if you think this
00:37:18.560 | alternative approach is also promising, or is there a reason research should go in one
00:37:23.840 | direction or the other?
00:37:26.440 | Yeah, so to my mind, the big trade-off is one between...
00:37:29.040 | So the vision literature, I think, has just exploded in terms of these sort of hybrids
00:37:34.040 | and people trying to find the exact right place on the Pareto curve for the trade-off
00:37:38.240 | of speed and performance.
00:37:41.680 | But they're basically looking primarily on vision-specific problems.
00:37:44.900 | So the computer vision community itself typically doesn't regularize itself
00:37:49.800 | away from things that don't work on things that aren't vision.
00:37:54.080 | So you end up with things that are very, very efficient and very performant on vision problems.
00:37:59.040 | So I think from that point of view, it's an incredibly important line of work, and that's
00:38:03.040 | probably the right way of doing things.
00:38:06.780 | What we're aiming for is the things that are as general as possible while still being performant.
00:38:14.080 | Got it.
00:38:15.080 | So this kind of thing is critical...
00:38:16.640 | Oh, sorry to cut you off.
00:38:17.640 | Go ahead.
00:38:18.640 | No, no.
00:38:19.640 | Please, go ahead.
00:38:20.640 | I was going to say, so this kind of thing is important...
00:38:23.380 | Just to summarize.
00:38:24.380 | So you feel like it's important to focus on attention because that's kind of critical
00:38:26.800 | for NLP.
00:38:27.800 | Like you can't just sort of put in a convolution at the end and sort of fix the problem.
00:38:31.180 | But in vision, maybe you can and it's fine.
00:38:33.160 | Is that a right way of understanding it?
00:38:36.280 | That's part of it.
00:38:37.280 | Vision and NLP aren't the only two domains.
00:38:39.280 | And so the thing that we're looking for are really basically...
00:38:42.600 | So the kinds of problems that we're interested in doing with this include things like event-based
00:38:48.000 | cameras, cell biology, sort of proteins, all of these sorts of things where we may or may
00:38:54.680 | not have the right convolutional inductive biases to even know how to build those sorts
00:38:58.700 | of things.
00:38:59.700 | Got it.
00:39:00.700 | They end up being whole research programs, like the mesh-based convolution work.
00:39:04.260 | Oh, cool.
00:39:05.260 | Thank you.
00:39:06.260 | I also had one more question about the architecture.
00:39:09.260 | So I saw that...
00:39:10.260 | I'm sorry if you said this and I just missed it, but you had cross-attention and then like
00:39:15.180 | that tritium transformer and then cross-attention.
00:39:18.020 | I'm curious what happens if you replace the self-attention in those layers with cross-attention.
00:39:22.500 | Does it affect your accuracy?
00:39:24.180 | Is that even feasible?
00:39:25.180 | Is that a valid question?
00:39:28.300 | Yeah.
00:39:29.300 | So the sort of thing that you could do is you could modify this to make it sort of hierarchical
00:39:32.620 | so that there are multiple stages of cross-attention.
00:39:35.100 | We haven't gotten this working yet, but it doesn't mean it's not a good idea.
00:39:40.820 | So there might be a right way to do this that we haven't figured out right, but it's something
00:39:44.380 | we have tried a little bit.
00:39:45.380 | Oh, cool.
00:39:46.380 | Okay.
00:39:47.380 | Thank you so much.
00:39:48.380 | I appreciate it.
00:39:49.380 | Yeah, no problem.
00:39:50.380 | Okay.
00:39:51.380 | Let me...
00:39:52.380 | We're running short on time, so maybe I'll skip ahead.
00:39:57.380 | Okay.
00:39:58.380 | So before we run out of too much time, I want to at least talk about the sort of the modifications
00:40:07.700 | to this architecture that we've made to make it work sort of even more generally.
00:40:12.020 | So one of the problems of the sort of the first architecture that we looked at here,
00:40:16.580 | the basic perceiver, is that it works basically for arbitrary inputs, but it's designed to
00:40:24.460 | work only on classification or regression tasks as an output.
00:40:28.400 | And so basically we wanted to see if we could use the same cross-attention strategy for
00:40:32.540 | decoding and it turns out you can.
00:40:34.420 | It's something that works pretty well, just kind of out of the box.
00:40:37.820 | So the idea is that we have, if we have our cross-attention input and self-attention sort
00:40:43.060 | of to do the processing, we can introduce a set of additional queries.
00:40:48.140 | And these are basically queries that give the semantics of each of the points that you're
00:40:51.780 | trying to decode.
00:40:53.860 | And we pass these as input to another cross-attention layer, which is configured in basically the
00:40:58.860 | opposite way that the encoder cross-attention is configured.
00:41:02.200 | So now the queries are going to be something that's potentially large and the keys and
00:41:05.660 | values are coming from this latent.
00:41:07.660 | And so what this allows us to do basically is to keep all of the nice advantages of the
00:41:11.740 | original perceiver.
00:41:13.020 | So we have an encoder that scales linearly, we have a processor stage, this sort of latent
00:41:16.980 | self-attention that scales independently of the input size, and we now have a decoder
00:41:21.940 | that keeps the decoupling, but gives us linear scaling with respect to output size.
00:41:27.020 | And so by doing this, we can now basically apply the same approach to basically dense
00:41:33.820 | output tasks.
00:41:35.420 | And so to give you a sense of how this works, just sort of intuitively, if we're doing auto
00:41:39.940 | encoding on this image of puppies, basically what we do is we encode process, and then
00:41:45.940 | to decode, we take a query that corresponds to each of the points, and then we pass it
00:41:51.100 | into this decoder.
00:41:52.340 | So we can query one of the points, we get one pixel, query another one, we get another
00:41:56.380 | one, and all the way up till we get all 10,000 points.
00:42:00.340 | And that's how we can do reconstruction with this.
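
A sketch of the decoder cross-attention just described: the output queries (one per point you want to decode, built from whatever position or task features are relevant) attend to the processed latents, so the cost is linear in the number of outputs. This reuses the softmax helper from the earlier sketches, is single-head, and the final linear projection standing in for the output head is an illustrative assumption:

```python
def decode(queries, latents, Wq, Wk, Wv, Wout):
    """queries: [O, Dq] output queries; latents: [N, D] processed latents."""
    q = queries @ Wq                                          # [O, D]
    k, v = latents @ Wk, latents @ Wv                         # [N, D]
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)   # [O, N]: linear in O
    return (attn @ v) @ Wout                                  # [O, out_ch], e.g. RGB or flow per point

O, N, Dq, D, out_ch = 50176, 32, 258, 64, 3   # e.g. one query per pixel, RGB out
queries = rng.normal(size=(O, Dq))            # e.g. Fourier position features per output point
latents_z = rng.normal(size=(N, D))
Wq = rng.normal(size=(Dq, D)) * 0.1
Wk, Wv = rng.normal(size=(D, D)) * 0.1, rng.normal(size=(D, D)) * 0.1
Wout = rng.normal(size=(D, out_ch)) * 0.1
print(decode(queries, latents_z, Wq, Wk, Wv, Wout).shape)  # (50176, 3)
```
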
00:42:02.980 | And the cool thing about this is that it opens up a bunch of new applications.
00:42:08.940 | And we can get different kinds of outputs just by changing how the queries work.
00:42:12.460 | So if we want to do something like multimodal auto encoding, where we have some of the outputs
00:42:16.340 | are videos, we use the same construction trick to get positions, to get queries that have
00:42:22.980 | the relevant semantics for each of the points that we're decoding.
00:42:25.940 | And we can do this even though basically the sizes of these different data, so the number
00:42:30.540 | of points they have is quite diverse.
00:42:32.820 | So in the multimodal auto encoding experiments that we have in this paper, we do this for
00:42:37.180 | video, audio, and labels at the same time, so that all of them are just passed into their
00:42:40.820 | uniform network, and then decoded one by one in this way.
00:42:44.900 | But we can also do mass language modeling now by conditioning on the position in a sequence.
00:42:51.340 | We can do multitask classification by having basically an index that gives which task you're
00:42:56.500 | querying from the network.
00:42:58.940 | And we can do things like optical flow by passing in input features as well as the positions.
00:43:03.940 | And so I'm just going to just quickly skip to a couple of the different-- I can share
00:43:10.300 | these slides with you all afterwards to look through them.
00:43:14.620 | Some of these things are quite cool.
00:43:16.620 | But just quickly, I want to talk about language and then optical flow.
00:43:21.980 | So for language, basically what we wanted to do with this was to see if we could use
00:43:27.180 | this to replace tokenization.
00:43:30.020 | And why might we care about getting rid of tokenization?
00:43:33.580 | So one, we use tokenization primarily because transformers scale poorly with sequence length.
00:43:38.980 | And tokenizing cuts sequence length by about a factor of four.
00:43:44.660 | But there are various problems that arise with this.
00:43:46.900 | And so why might we care about removing tokenizers?
00:43:52.380 | So for one, tokenizers perform less well on rare words.
00:43:58.220 | So if you compare the sort of the byte-based decomposition, the UTF-8 encoding of an input
00:44:04.620 | sequence like this, you can see that there's basically a uniform allocation of points in
00:44:10.220 | memory to each of the input characters.
00:44:12.260 | The exception are diacritics, which end up splitting into two.
00:44:16.140 | But if you look at the sentence piece tokenization, so it's learned that pepper is one token,
00:44:21.620 | but jalapeno gets split into five in this case.
00:44:25.520 | So this basically says the amount of capacity that you allocate depends on how rare the
00:44:31.220 | word is, which can lead to suboptimal encodings.
00:44:35.780 | They're also brittle to subtle perturbations.
00:44:37.740 | A famous example of this is that if you've ever played around with GPT-3, you'll notice
00:44:44.500 | that the output can be quite sensitive to whether you add or omit a space at the end of your prompt.
00:44:49.580 | And that basically is because the space can end up being factorized into different parts
00:44:52.380 | of the tokenization.
00:44:53.380 | There are other things that can happen there, too, but this is one cause of that.
00:44:58.780 | And finally, tokens don't transfer across languages.
00:45:01.620 | So if you wanted to have a model that without any tuning could be used on many different
00:45:05.220 | languages at the same time, tokenizers are a blocker for this.
00:45:08.620 | So if we can get rid of them, it'll simplify the pipeline, it'll also make things less
00:45:12.140 | brittle, and then hopefully lead to more general models.
00:45:16.620 | So the way that we do mass language modeling is the same as the way that I showed in that
00:45:20.260 | schematic auto-encoding experiment.
00:45:23.020 | So we mask some fraction of our inputs, about 15%, is sort of the standard magic number.
00:45:28.780 | We then decode at each of the positions that are masked, and we task the model with decoding
00:45:33.980 | whatever characters were masked at those locations.
00:45:38.420 | And then once we have this model, so this is what we do for pre-training, we can then
00:45:41.740 | fine tune it by replacing the decoder with a multitask decoder that takes in the tasks
00:45:47.180 | that we're using on the downstream evaluation setting, and training the model to reconstruct
00:45:53.060 | the logits on a per-task basis.
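Here is a minimal sketch of the byte-level masking step just described; the 15% rate comes from the talk, while the helper name and the choice of mask id are assumptions made for illustration.

```python
import numpy as np

def mask_bytes(byte_ids, mask_rate=0.15, mask_id=256, seed=0):
    """Mask ~15% of byte positions; the model is trained to predict the
    original bytes at exactly those positions.

    mask_id=256 assumes a reserved token outside the 0-255 byte range."""
    rng = np.random.default_rng(seed)
    mask = rng.random(byte_ids.shape) < mask_rate   # True where masked
    corrupted = np.where(mask, mask_id, byte_ids)   # inputs with masked bytes hidden
    targets = byte_ids[mask]                        # labels only at masked slots
    return corrupted, mask, targets

byte_ids = np.frombuffer("Perceiver IO reads raw bytes.".encode("utf-8"),
                         dtype=np.uint8).astype(np.int64)
corrupted, mask, targets = mask_bytes(byte_ids)
print(int(mask.sum()), "of", mask.size, "positions masked")
```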
00:45:56.340 | Okay, so to look at how this model performs, we basically first compare it to BERT base.
00:46:04.420 | So this is just a solid benchmark that we understand very well.
00:46:07.860 | And first, by looking at sort of matched, two models that have matched flops, we can
00:46:13.420 | see that Perceiver IO and BERT base work on par.
00:46:18.420 | You see there's a different trade-off here.
00:46:19.980 | So to get the same number of flops, basically we make Perceiver IO deeper, and this ends
00:46:24.660 | up giving it more parameters, but on a per-flops basis, it ends up performing about the same.
00:46:31.900 | On the other hand, if we remove the tokenizer from BERT and keep the flops the same, we
00:46:36.620 | see that the number of parameters and the depth just drastically fall down, and this
00:46:41.540 | is because BERT scales quite poorly with sequence length, because it uses a normal transformer.
00:46:47.260 | But if we use a Perceiver without the tokenization, we can see that we only get a slight reduction
00:46:54.140 | in the number of parameters at the same flop count, but the model performs almost exactly
00:46:59.220 | the same.
00:47:00.500 | So this means that the Perceiver in this setting is performing basically the same with and
00:47:03.980 | without the tokenization.
00:47:04.980 | It's learning a different strategy, it's using different parameters, but it basically can
00:47:08.380 | be brought to the same performance.
00:47:10.980 | We can then scale this more by leaning into what happens in the tokenizer-free setting,
00:47:15.740 | and we see that we can get a moderate performance boost as well.
00:47:20.660 | In the language setting, it's also useful to look at the attention
00:47:24.820 | maps that are learned, and what's being visualized here is basically, for some subset of the latents,
00:47:29.740 | where each one is attending in the input sequence,
00:47:35.220 | and some of these end up being local, so looking at specific points in the sentence.
00:47:40.780 | Some of them are periodic, so they look at recurring points over the sequence, and some
00:47:46.020 | of them also look like they pick out syntactic features, which is quite nice.
00:47:50.260 | So they pick out basically exclamation points, or capital letters, or other punctuation that's
00:47:55.460 | quite useful and decodable right at the beginning of the sequence.
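The maps being described are just the cross-attention weights from each latent onto the input byte positions. A minimal sketch of how such a map could be computed, with random tensors standing in for the trained queries and keys:

```python
import numpy as np

def latent_to_input_attention(latent_queries, input_keys):
    """Attention weights of each latent over the input positions.

    latent_queries: [N, D], input_keys: [M, D]  ->  weights [N, M]"""
    scores = latent_queries @ input_keys.T / np.sqrt(latent_queries.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # softmax with numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

N, M, D = 256, 2048, 64   # latents, input bytes, key dimension
attn = latent_to_input_attention(np.random.randn(N, D), np.random.randn(M, D))
# Plotting attn[i] along the sequence for a handful of latents i is what gives
# the local, periodic, and punctuation-selective patterns described above.
print(attn.shape, attn.sum(axis=-1)[:3])   # each row sums to 1
```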
00:48:02.420 | We can also basically use this exact same architecture on optical flow, and optical
00:48:08.100 | flow is basically an important classical problem in computer vision, where given a pair of
00:48:13.220 | frames in a video, we want to basically track all of the points, so figure out the motion
00:48:18.340 | from every point from one frame to the other.
00:48:21.540 | And so optical flow is usually visualized using these sort of colorized images that
00:48:25.900 | are shown on the bottom, and what this gives you basically is a per pixel indication of
00:48:30.460 | the velocity at every single point.
00:48:33.420 | And so you can see that, so the blade that the character here is holding is moving to
00:48:40.060 | the right, whereas this creature behind her is sort of moving downwards.
00:48:46.980 | So there are a couple of problems with optical flow that make it interesting to sort of approach.
00:48:52.900 | So one is that it's a dense task, and it basically involves long-range correspondences. But in the
00:48:59.060 | standard training protocol, there's basically no large-scale realistic training data, just
00:49:02.660 | because it's incredibly hard to label all of the pixels in a real-world scene and
00:49:06.660 | figure out where they go.
00:49:08.540 | So the typical way to do this is to train on some synthetic data, and then evaluate
00:49:11.980 | on more realistic scenes, and optical flow is also interesting, because it's basically
00:49:19.620 | the locus of some of the most complicated visual architectures in the literature.
00:49:24.780 | So the previous state-of-the-art result here is this method called RAFT, which won the
00:49:28.860 | best paper award at ECCV last year.
00:49:31.540 | And I'm just highlighting this to give you a sense of how much work people put into sort
00:49:34.980 | of hand-engineering these architectures.
00:49:37.580 | So this is a very, very cleverly designed architecture, and basically it incorporates
00:49:41.580 | things like global correlation volumes that are explicitly computed at different offsets
00:49:46.340 | to basically allow the model to reason about how things at different scales are moving
00:49:50.580 | with respect to each other.
00:49:53.460 | It also has local neighborhood gather operations, as well as update blocks to keep track of
00:49:58.980 | what's happening within each specific correlation block.
00:50:02.180 | And then finally, there are flow-specific upsampling operators that were developed.
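For context, the global correlation volume at the heart of RAFT is essentially an all-pairs dot product between the per-pixel features of the two frames (RAFT additionally pools this volume at several scales). The following is a rough sketch of that idea, not RAFT's actual implementation.

```python
import numpy as np

def all_pairs_correlation(feat1, feat2):
    """Correlation volume: dot product between every pixel's feature vector in
    frame 1 and every pixel's feature vector in frame 2.

    feat1, feat2: [H, W, D]  ->  [H, W, H, W]"""
    H, W, D = feat1.shape
    corr = feat1.reshape(H * W, D) @ feat2.reshape(H * W, D).T
    return corr.reshape(H, W, H, W) / np.sqrt(D)

vol = all_pairs_correlation(np.random.rand(16, 16, 8), np.random.rand(16, 16, 8))
print(vol.shape)   # (16, 16, 16, 16)
```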
00:50:07.560 | So in contrast to this, we basically wanted to see how well Perceiver IO would
00:50:13.100 | do here.
00:50:14.240 | And just to give you a sense of sort of what we were expecting coming into this, we thought,
00:50:17.820 | well, maybe-- so Perceiver IO was throwing a lot of the structure away, so we were hoping
00:50:21.540 | that we would get some good results, but it would probably overfit, and there's this sort
00:50:25.060 | of problem of the domain transfer that's happening here.
00:50:27.780 | But on the other hand, self-attention seems to be a reasonable way to match this sort
00:50:31.060 | of correspondence thing.
00:50:33.260 | What we actually found was that just by doing the very, very simple preprocessing here,
00:50:38.900 | so extracting basically a patch around each pixel, and then using the standard Perceiver IO
00:50:45.100 | architecture, we were able to get state-of-the-art results here.
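A minimal sketch of that kind of preprocessing: gather a small patch around every pixel of each frame and concatenate the two frames' patch features per pixel. The patch size, padding mode, and helper name are assumptions; position encodings would be appended to each row before it goes into Perceiver IO.

```python
import numpy as np

def extract_pixel_patches(frame, patch=3):
    """For every pixel, gather the patch x patch neighbourhood around it.

    frame: [H, W, C]  ->  [H * W, patch * patch * C]"""
    pad = patch // 2
    padded = np.pad(frame, ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    H, W, C = frame.shape
    shifted = [padded[dy:dy + H, dx:dx + W, :]       # neighbourhood at offset (dy, dx)
               for dy in range(patch) for dx in range(patch)]
    return np.concatenate(shifted, axis=-1).reshape(H * W, patch * patch * C)

frame1 = np.random.rand(8, 8, 3)
frame2 = np.random.rand(8, 8, 3)
inputs = np.concatenate(
    [extract_pixel_patches(frame1), extract_pixel_patches(frame2)], axis=-1)
print(inputs.shape)   # (64, 54): one row per pixel, ready for the encoder
```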
00:50:47.820 | And so this was basically a validation of this general approach of trying to have
00:50:55.140 | general-purpose architectures that would transfer over.
00:50:58.460 | And so basically, with minimal tuning, we were able to get results that would beat both
00:51:04.380 | of the sort of compelling baselines on both of the Sintel evaluation methods, and to get
00:51:11.900 | comparable results on KITTI.
00:51:13.620 | So these are the standard ones.
00:51:15.540 | And we can also sort of visualize what happens when we apply this on real-world data.
00:51:20.620 | So there's no ground truth here, so we can't really compare it, but it's still useful to
00:51:24.780 | sort of see how it moves around.
00:51:27.060 | And we can see that qualitatively, it's able to capture a lot of the fine structure, and
00:51:31.700 | to sort of get the right motion for the things that are very clearly moving in a specific
00:51:37.380 | direction.
00:51:39.540 | It's also, I think, informative to look at how it
00:51:43.780 | manages to represent sort of small structure.
00:51:47.620 | Is this video playing?
00:51:48.620 | Yeah, we can see it.
00:51:51.540 | OK, cool.
00:51:53.300 | So the thing to look at here is the fine water droplets that are sort of flying through the
00:51:57.200 | air as that bird flies by.
00:51:59.340 | And because we're decoding at every single output point, the architecture is able to
00:52:04.140 | represent those.
00:52:05.420 | So it's able to capture very, very fine-scale segmentation that would be difficult to capture
00:52:09.500 | if you had, for example, a convolutional upsampler here.
00:52:12.860 | OK, so I'm just going to-- oh, the light has gone off in this room.
00:52:20.860 | I'm also curious if you also tried other tasks like depth estimation, because Perceiver IO
00:52:26.620 | looks like it can also do quite well on those modalities.
00:52:30.500 | Yeah, so we haven't published anything, but some internal results suggest that it works
00:52:36.420 | just fine.
00:52:37.980 | One of the surprising things, the thing that
00:52:41.860 | we were a little bit unsure about, was how much information was going to be contained
00:52:45.180 | in this latent.
00:52:46.700 | Because basically, you're abstracting quite a lot, and it doesn't have any 2D structure
00:52:50.380 | intrinsically.
00:52:51.900 | But it does seem like this-- it seems to be able to represent things quite well.
00:52:56.820 | And these sorts of decoding mechanisms do seem to be able to do that.
00:53:00.500 | Got it.
00:53:01.500 | Great.
00:53:02.500 | So I'm just going to-- just in the interest of time, I'm going to skip ahead to the conclusion.
00:53:08.420 | Drew, I had one question with respect to the metrics that you've shared for the optical
00:53:13.380 | flow, the number.
00:53:15.380 | So in the table, it was like Sintel Final, Sintel Clean, and KITTI.
00:53:20.420 | Were these different data sets or different metrics?
00:53:23.900 | Same metric for different data sets, or these are three different metrics?
00:53:27.540 | Yeah, so these are three different data sets.
00:53:30.140 | So Sintel Clean and Sintel Final are basically two-- they're two ways of doing the final
00:53:35.180 | rendering for Sintel.
00:53:37.620 | In all cases, these methods are trained just on the AutoFlow data set.
00:53:41.780 | So they're trained on this sort of general purpose, kind of wacky synthetic motion data set.
00:53:48.540 | And then we're evaluating on these different domains without fine-tuning.
00:53:52.060 | Yeah, the flow data sets are quite small, so it's generally even problematic
00:54:00.140 | to fine-tune.
00:54:01.140 | Thank you.
00:54:02.580 | Mm-hmm.
00:54:03.580 | OK, so just to summarize--
00:54:07.540 | What was the ground truth used to find the endpoint error?
00:54:12.740 | Yeah, so the way this works is Sintel is a computer-- it's basically a relatively high-quality
00:54:20.500 | CGI movie that was basically open source.
00:54:24.180 | And so they actually have the ground truth.
00:54:25.740 | So if you know the ground truth 3D state, you can compute the pixel correspondence from
00:54:31.340 | frame to frame.
00:54:32.420 | So that's what's used on Sintel.
00:54:33.940 | And then KITTI, they basically have a LiDAR sensor that's used to figure out the depth
00:54:39.740 | of all points.
00:54:40.740 | And then they compute the correspondences.
00:54:43.260 | So the ground truth is actually the ground truth optical flow.
00:54:46.460 | But in general, it's hard to get dense optical flow.
00:54:49.580 | It's very expensive to collect it.
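For reference, the numbers in those tables are average end-point errors. A minimal sketch of how that metric is computed, with an optional validity mask for sparse ground truth like KITTI's:

```python
import numpy as np

def endpoint_error(pred_flow, gt_flow, valid=None):
    """Average end-point error: mean Euclidean distance between predicted and
    ground-truth (dx, dy) vectors, restricted to pixels with ground truth."""
    epe = np.linalg.norm(pred_flow - gt_flow, axis=-1)   # [H, W]
    if valid is not None:
        epe = epe[valid]
    return float(epe.mean())

pred = np.zeros((4, 4, 2))
gt = np.ones((4, 4, 2))
print(endpoint_error(pred, gt))   # sqrt(2) ≈ 1.414
```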
00:54:52.540 | Great.
00:54:53.540 | Thanks.
00:54:54.540 | Mm-hmm.
00:54:56.540 | So basically, just to summarize, so the perceivers are attention-based architectures that scale
00:55:03.180 | linearly and work as drop-in replacements for transformers on a variety of settings.
00:55:07.340 | They also seem to be able to achieve results that are comparable, at least in performance,
00:55:13.420 | with models that rely on 2D convolutions.
00:55:15.140 | But of course, there is a trade-off here.
00:55:17.260 | And so it's good to be very aware of this, of generality versus speed in specific domains.
00:55:22.620 | And so as was pointed out, in settings where you can use 2D convolutions, it's certainly
00:55:27.380 | good to have them in the loop.
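A back-of-the-envelope sketch of that generality-versus-speed trade-off and of the linear-scaling claim above; the latent count and layer numbers are illustrative, and constant factors and channel widths are ignored.

```python
# Per-forward-pass attention cost, up to constant factors and channel widths:
# a vanilla transformer pays layers * M^2 in the input length M, while an
# encode/process/decode Perceiver pays M*N + latent_layers*N^2 + O*N with a
# fixed number of latents N, i.e. linear in input size M and output size O.
def vanilla_attention_cost(M, layers=12):
    return layers * M * M

def perceiver_attention_cost(M, O, N=512, latent_layers=26):
    return M * N + latent_layers * N * N + O * N

M = 50_000   # e.g. a byte-level sequence or flattened pixels
print(f"{vanilla_attention_cost(M):.2e}  vs  {perceiver_attention_cost(M, O=M):.2e}")
```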
00:55:30.620 | It's basically a unified architecture that allows joint modeling of different modalities
00:55:36.340 | of different sizes.
00:55:38.820 | And basically, overall, it seems to be a quite flexible architecture that's able to produce
00:55:43.980 | state-of-the-art or near state-of-the-art results on a variety of different domains.
00:55:48.460 | And in the two papers, we look at a number of other domains that I didn't talk about,
00:55:53.380 | including 3D point cloud modeling, replacing the transformer that's used in the
00:55:58.300 | StarCraft behavioral cloning agent, and a couple of others.
00:56:03.260 | So we have a lot of evidence that this general approach seems to work broadly.
00:56:07.920 | And there's a lot of things we still haven't tried.
00:56:09.860 | So we're very interested in pushing this and always open for suggestions and so forth.
00:56:16.420 | So we're relying on a large body of related work because we're drawing from a lot of different
00:56:21.260 | areas here.
00:56:22.260 | So here are some highlights.
00:56:24.140 | And then I just want to thank my co-authors on this work.
00:56:29.420 | And of course, I'm happy to talk more.
00:56:31.700 | Thanks, Drew.
00:56:33.540 | Yeah, thanks a lot.
00:56:36.340 | Thanks, Drew.
00:56:37.340 | So one question I had is, so what do you think is the future of perceiver models?
00:56:44.300 | Do you think this is going to be used more in the transformer community to replace convnets
00:56:52.220 | and attention and all of this stuff?
00:56:54.580 | Yeah.
00:56:55.580 | So I think, broadly speaking, I think of perceivers now as sort of-- because we know how to adapt
00:57:02.860 | them pretty well to sort of domains where we don't have a great idea of the right way
00:57:06.820 | to structure an architecture, an inductive bias.
00:57:09.940 | So I think that's one of the really strong cases for it, so settings in which you don't
00:57:15.220 | really know what the right way to structure a problem is.
00:57:18.420 | I also think these kinds of approaches can be used in conjunction with convnets for sort
00:57:24.540 | of things that are as domain agnostic as needed.
00:57:28.940 | But I think multimodal and new domains is really the-- that's where these are obvious
00:57:35.500 | choices.
00:57:36.500 | Got it.
00:57:37.500 | Also, what do you think are the current bottlenecks with this?
00:57:39.420 | And if you don't mind, if you can disclose, what are you working on towards next with
00:57:43.780 | this stuff?
00:57:45.900 | So I can't talk about too many details about that.
00:57:50.080 | But a couple of domains-- so one, we don't really have a great handle on how to use them
00:57:56.420 | on sort of small-scale data, so data where you don't have the data to sort of recover
00:58:04.300 | the inductive bias.
00:58:05.540 | So this is, I think, a really important area.
00:58:08.100 | The other thing that we haven't sort of talked about here, but you could probably imagine
00:58:11.940 | that we'd be thinking about, would be how to train on multiple modalities or sort of
00:58:15.220 | multiple things at the same time.
00:58:17.780 | So right now, all of these architectures are sort of trained in isolation.
00:58:22.420 | But there are a lot of opportunities for sort of figuring out how to pose problems together
00:58:26.940 | and use a single architecture on all of them.
00:58:28.940 | Got it.
00:58:29.940 | Also, I'm not sure if you've tried, but can you also use this for tabular data stuff?
00:58:35.860 | Yeah.
00:58:36.860 | So effectively, the architecture treats any input data as tabular data.
00:58:41.340 | So I think that's exactly the right way to think about it.
00:58:44.260 | Cool.
00:58:45.260 | Sounds good.
00:58:46.260 | Yeah.
00:58:47.260 | Thanks for the talk.
00:58:48.260 | I will open it up to general questions from the students.
00:58:50.900 | So let's stop the recording.
00:58:51.900 | Okay.
00:58:52.900 | Great.
00:58:53.900 | Thank you.