So today, I'm going to be talking about some recent work that we've been doing at DeepMind, developing this line of architectures that we're calling perceivers, and I'll be motivating this in terms of a goal that we have, which is to develop general purpose architectures. And so just right off the bat, I want to motivate why we care about general purpose architectures.
And so both of the reasons are fairly pragmatic. But basically, the idea is, if we're thinking about all of the data that we could possibly imagine collecting in the world, a lot of it basically involves what we think of as traditional sense modalities, and these range from touch and proprioception to echolocation to the kind of perception you need to ingest text, however you want to format that, to more exotic things like event-based cameras and whisker-based touch, things like smell and depth, and all the way up to the kinds of sense modalities that we really think about when we're thinking about scientific perception.
And so basically, if we think about the full set of data and what it would take to actually model each of these different modalities, it basically looks effectively intractable to try to engineer inductive biases that will work for every single one of these. So we don't want to engineer them one by one.
This is an approach that's worked, and in some ways, it's maybe a reasonable description of how we think about developing new architectures for different problems, but it's just not going to scale. We can't afford, as a community, to hand-design inductive biases that will work for each and every one of these.
And so rather than doing that, we want to build architectures that, at least at first pass, will allow us to handle everything. There's another practical argument for why we should work on general-purpose architectures, and that's because it will allow us to build simpler, more unified systems. So if you look at how complex multimodal data streams in particular are typically approached in the sensor-fusion, computer vision, or pattern recognition literatures, effectively, the typical way this is done is by using inductive biases that we know hold for the individual modalities and then engineering ways of combining those different subsystems.
So this can mean building specific heads, specific input modules for each of these things, and then trying out the various different ways of combining them. So this can work, but it gives us systems that, in principle, really will only work on one or a small number of domains, and it gives us systems that are very hard to maintain, tend to be fragile, tend to depend on specific processing assumptions about the input modalities.
So rather than do that, we sort of want to move in the direction of having unified blackbox architectures that kind of just work. And the idea here is that if we can get to that point, we can abstract the architecture construction process and really focus on other, more high-level problems.
So this is sort of the motivation for this line of work. And the way that we're going to be doing this is, of course, by working on the most general-purpose architecture that we have so far, which is basically a transformer, and you'll all be very familiar with the basic building blocks of a transformer.
But just at a very high level, we can think about what they do right, which is they use a general-purpose inductive bias. So they're non-local, which means they're not making domain-specific assumptions about which points should be compared to each other. Rather, they tend to be global in terms of the attentional focus that they have.
They use position as a feature rather than a hard constraint of the architecture, and this is in contrast to MLP-based architectures or ConvNets in the way that they typically work, which use position as an architectural component to constrain how compute is happening. And then, of course, finally, there's extensive weight sharing in the way that they're designed.
And because they focus on matmuls, they tend to be TPU and GPU-friendly. So these are all very nice things about the way transformers work. Of course, on the other hand, they have very poor compute memory scaling. And there are two components to this. So attention itself scales quadratically. So there's this big O of M squared L complexity at the heart of transformers.
And I like writing it this way because it really emphasizes that as you make your models bigger, either in input size or in depth, this problem is just going to get worse. And because you have this scaling in depth as well, there's another practical thing that happens here.
Because the amount of compute that we're doing is proportional to the input size, so there's no bottleneck in the way that standard transformers work, even the linear scaling becomes a problem. And so in practice, for very, very large transformers, this can often be the bottleneck that really matters. But they're both at play here.
And so we really want to tamp down both of these. And so the perspective here is that to have really general-purpose architectures, we can't have ones that are just in principle general. We have to have ones that you can actually use on the scales and the kinds of data that we care about.
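To put rough numbers on this scaling, here is some purely illustrative arithmetic, using the roughly 50,000-pixel ImageNet figure that comes up later in the talk; the layer counts are made up.

```python
# Back-of-the-envelope scaling of the attention-score matrices, O(M^2 * L).
def attention_score_entries(num_inputs: int, num_layers: int) -> int:
    """Entries in the M x M attention maps summed over L layers (per head)."""
    return num_inputs * num_inputs * num_layers

M = 224 * 224   # pixels in a 224x224 ImageNet image (~50K, as mentioned later in the talk)
for L in (1, 12, 48):
    print(L, f"{attention_score_entries(M, L):.2e}")
# 1  2.52e+09
# 12 3.02e+10
# 48 1.21e+11   -- grows with both the input size and the depth
```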
And so just to recap-- this will all be old hat for all of you-- the way that standard QKV attention works is basically like this. So it's all matrix multiplication. We have some input, and we compute the queries, keys, and values with a 1D convolution, a one-by-one convolution that we run over the input.
We then compute the attention scores; this is a matrix multiply with these sorts of shapes. We then use the resulting weights to compute the actual output of the attention module itself. And then finally, we run this through an additional MLP, which is applied convolutionally, to get the outputs.
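To make the shapes concrete, here is a minimal single-head sketch in NumPy; it omits multiple heads, layer norm, and the final MLP, and the toy sizes are illustrative rather than anything from the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv, Wo):
    """Single-head QKV self-attention; x has shape [M, C].
    The 1x1 convolution over the input is just a per-position matmul."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv            # each [M, D]
    scores = q @ k.T / np.sqrt(q.shape[-1])     # [M, M]  <- the quadratic part
    attn = softmax(scores, axis=-1)
    out = attn @ v                              # [M, D]
    return out @ Wo                             # [M, C], before the per-position MLP

# toy shapes
M, C, D = 64, 32, 32
rng = np.random.default_rng(0)
x = rng.normal(size=(M, C))
Wq, Wk, Wv, Wo = (rng.normal(size=s) * 0.1 for s in [(C, D), (C, D), (C, D), (D, C)])
print(self_attention(x, Wq, Wk, Wv, Wo).shape)  # (64, 32)
```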
So this is the starting point of what we're working on here. And let me just briefly just reiterate why we would want the advantages that we have with these standard transformers. So non-locality is one of the two inductive bias principles that we have here. It's useful, I think, to contrast this to the effective locality that you get in ConvNets and what this actually means.
So if we look, as a function of depth, at which inputs can see which other inputs-- which determines how easy it is to express a function of two input points-- let's say we look at this yellow and purple point here at the input. Now, I've set them as far apart as possible.
But we might ask, how deep would the effective computation have to be before you actually process these two? And if you look at a three-by-three convolution, you're having to look, basically, until the very end of the network, until you're processing these things together. And what this means is that the functions that you can express that actually look at both of these points end up being quite shallow, because they have to be built on top of this very, very deep stack that just gives you the locality.
And so in point of fact, if you look at, for example, the way ResNets work-- so you have an initial block, which has a seven-by-seven convolution, and then afterwards it's three-by-three convs all the way up-- you need 28 three-by-three convs with that standard processing stack before all of the 224 by 224 pixels in an image are looking at each other.
And what this means is that in a ResNet-50, the pixels at the very edges of the image actually never see each other. And I found this a little bit counterintuitive, but it suggests that we really are constraining quite a lot the functions that are easy to express with these models.
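To see where a number like that comes from, here is a generic receptive-field calculation; the layer stack below is illustrative rather than the exact ResNet-50 schedule, so treat it as a sanity check of the order of magnitude.

```python
# Generic receptive-field calculator for a stack of convolutions / pooling layers.
def receptive_field(layers):
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf

stem = [(7, 2), (3, 2)]                 # 7x7 stride-2 conv, then 3x3 stride-2 pool
for n in (10, 20, 27):
    print(n, receptive_field(stem + [(3, 1)] * n))
# 10 91
# 20 171
# 27 227  -> with this (no further striding) layout, roughly 27-28 three-by-three convs
#           are needed before a 224-pixel-wide input is fully covered, in line with the talk.
```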
And so there are some functions of images you just can't capture with a ResNet-50. On the other hand, if you look at an architecture that has global attention over the full input, so a transformer, if you could scale it that way, or a perceiver, as we're going to be talking about, all of the pixels can interact.
So the model can basically capture these things and express these functions much more easily than architectures that put locality first. The other interesting property of these sorts of architectures is that position is featurized. And this basically means that we're no longer encoding the architectural location of a point to figure out where it's located with respect to the other ones.
And this allows the network to basically use any positional information that it wants but can also discard it as it prefers. And so this is the standard way it's done, of course, in the context of architectures that use Fourier or sinusoidal-like features, but there's a lot of flexibility here.
OK, so now just thinking in terms of how ConvNets relate to transformers, sort of at the opposite end, it may look like we have a scalability versus generality trade-off. And so if we look at ConvNets, the way that they're applied-- so typically, we can think about using them on grid-structured data.
There are, of course, generalizations of convolutions that work on data sets with more interesting topology, but typically, we can think of them as operating on grids in some sort of space, whereas transformers apply to generic sets. So transformers are more general from this point of view. On the other hand, they scale much, much worse.
So ConvNets are linear, both in the input points, the filter size, and the number of layers of that architecture, whereas transformers have this quadratic scaling, and they're still linear in the depth. So from this point of view, what we're interested in doing in the perceiver line of work was to scale transformers, but to keep the generality property.
So we want something that lives in between these two extremes. And the way that we do this is by looking at self-attention and modifying it in a way that allows us to scale better. So to walk through what self-attention actually does in standard transformers: we take our input array, which has an index dimension-- the number of tokens or the number of pixels, basically the number of input units, depending on what you're looking at-- and a channel dimension.
We have a 1D convolution. So this is big O of M with respect to the Q, K, and V. We then compute the attention maps using the output of this operation. This gives us a matrix multiply, which is the source of the quadratic scaling. And then finally, we compute output features with another matrix multiply.
This is already-- we're already rate-limited here, because for even standard resolution images, M is quite large. So it's around 50,000 for standard ImageNet images, which, again, are very small. So this is something that just isn't going to work if we want deep architectures. So what we do is we replace-- at the input to the architecture, we replace the self-attention with a cross-attention layer.
And we do this using, basically, a learned query. And so we're replacing only the query from the input here with a learned component. And so these indices and channels, you can just think of these as basically working like a learned initial state for an RNN. There's a variety of names that this idea goes under in the literature.
We refer to them as latents, but they're sometimes called inducing points or other things. So the basic idea is we're learning the input to the query and keeping the keys and values the same. The upside of this is that when we compute the attention map after this, we turn it from a square matrix into a rectangular matrix, which reduces the complexity of the matrix multiply to big O of MN.
So now it's linear in the input size. And the second matrix multiply has the exact same property, so from quadratic it becomes linear. And the quite cool thing about this is that not only is the cross-attention linear in complexity, but the output is actually smaller. And this, I think, is actually the more important point: it allows us to map something which is quite large into something whose size is independent of the input.
So we have full control over this as a hyperparameter. And this allows us to build deep networks on top of this latent. So because this is of a small size that we can control, we can afford to have quadratic complexity on top of this. And so we use this idea-- yep.
Go ahead. Oh, sorry. I'm still a little bit confused as to how you guys are able to turn this square into a rectangle in the second step. Is it because you replaced the query with a learned something that is significantly smaller compared to the input size in the first step?
Yeah, that's exactly right. So if you look at the-- so the underlying matrix multiply here, which is written as the QK transpose, so this will basically-- so the outer dimension here has shape n, which is determined by the query. And so by shrinking that query, we're just changing the output of the matrix multiply.
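In matrix terms, the shape bookkeeping looks like this; a minimal, shape-only sketch with illustrative sizes, with the softmax and scaling left out.

```python
import numpy as np

M, N, D = 4096, 128, 64            # M input points, N learned latents, N << M
keys = np.zeros((M, D))            # computed from the input
values = np.zeros((M, D))          # computed from the input
latent_queries = np.zeros((N, D))  # a learned array of weights, not computed from the input

scores = latent_queries @ keys.T   # [N, M]: rectangular, O(N*M) instead of O(M*M)
out = scores @ values              # [N, D]: the output lives in the small latent space
print(scores.shape, out.shape)     # (128, 4096) (128, 64)
```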
OK. Thank you. Yeah. So I guess-- Sorry. Go ahead, please. OK. Cool. So basically, you only do that for the query, right? So key and value remain like the original size matrices, correct? That's right. OK. But so basically-- so I don't know what I'm not understanding, basically. So the problem for me is that for a query, now in my head, I'm looking for-- let's say I have the if token.
Now there is no if query anymore. Doesn't that cause a problem when I'm trying to use it and to compute scores? Yeah. So what's happening here is you'll have a smaller subset of queries. So if you think about this not in terms of the matrix multiplies, but in terms of comparing each query to each key.
So in normal self-attention, we have one query for each key, so every point compares to every other point, right? So here, what we've done is instead of comparing every point to every other point, we have a set of sort of cluster centers you might be able to think about them as.
So it's a smaller number, and we compare each of those to each of the input points. But we don't know which tokens technically belong to which clusters, right? That's right. So it has to be learned. Yeah, exactly. So one way to think about this, about what's happening here, is that instead of-- so in a normal self-attention transformer, by comparing all to all, we're sort of saying, OK, I know what the feature is at this point, and I want it to attend to similar features.
Here what we're saying is we're learning a bunch of supplementary points that should be sort of maximally similar to some subset of the inputs. So correct me if I'm wrong, but this is essentially doing some sort of hard attention, where you're saying instead of querying over all the points, let's select some points which we think are very similar, and only put self-attention over this hard point, like these points you have selected.
Right? Yeah, so they're related. That would be one way to think about it. The slight modifier to that idea, though, is that they basically live in an abstract space. So they're not assigned sort of one-to-one to one of the input queries, or to one of the input points. They're sort of learned, so they can be somewhere in the middle.
But I think that's a good way to think about it. That's a good intuition. But I guess one of the places where I'm a little confused here is you have here indices and indices for the two, like the purple and green matrices in the far left. But those indices are not necessarily corresponding to inputs.
Like in the NLP space, those would not necessarily be tokens, right? These are just sort of-- Exactly. --indices. But the-- That's right. --index in this case is the result of some kind of mapping from the input tokens to an n-by-d matrix. Is that right? No, it's actually-- so it basically acts like-- it's a learned set of weights, is one way to think about it.
So they function exactly the same way that learned position encodings do. So it's basically just a-- it's a learned embedding. But it's not conditioned on anything. It's just sort of-- it just is-- it's just a set of weights. Oh, OK. That makes more sense. Thank you. Mm-hmm. OK. So if there are no more questions, I'm going to keep going.
But of course, feel free to interrupt me. So the way that-- given this idea-- so we have this learned latent array, which, again, it functions sort of like an RNN initial state, or it's a set of weights. We basically randomly initialize that. And then we use this to attend onto the input byte array.
And so the byte array here is the flattened set of pixels, for example, for ImageNet. And the output of this is going to live in the same space as-- so the same index space as the latent array does. And there's residual connections in the way that you would normally do in an attention layer as well.
So once we're in this space, we can then build an architecture by using a standard transformer, but phrased in the latent space rather than in the input space. And because we've distilled the input down to this smaller space, we can still flexibly allow all of these points to interact.
So this should still be nearly as expressive as a normal transformer. And then each of the modules here is now quadratic in the latent size rather than the input size. So this is something that we can control quite a lot. Now, in the original version of the perceiver, we found it was very helpful to have additional cross-attends.
So this is certainly something that you can do. And the reason-- the intuition behind this is that if this bottleneck is quite severe, we can't maintain all of the information from the input. And so we want these queries, which are now sort of conditioned on the past, to be able to look back at the input point.
And so this is something that we found to be quite helpful when tuning for the first paper. But the caveat, I will say, is that we're no longer recommending this as best practice because these cross-attentions end up being quite heavy. But this is something that you can explore, certainly, if you want sort of more conditional queries or if you want to be able to cross-attend to new inputs that are coming in.
The other thing that we found quite helpful in the context of data sets that have a limited amount of data, which for these architectures includes ImageNet, is to allow weight sharing in depth. And so this basically just amounts to tying the weights for the different cross-attention and different self-attention layers as they're repeated.
So this ends up looking like an RNN that's unrolled in depth. So this is just at a high level. This gives us an architecture that we can apply to images but that doesn't make any assumptions about image structure, so it's one that you can use elsewhere.
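Putting the pieces together, here is a minimal sketch of the encode-and-process structure just described: one cross-attend from a learned latent array onto the input, followed by a latent transformer, with optional weight sharing in depth. It's single-head and leaves out layer norms and MLPs, so it illustrates the data flow rather than the actual implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k, v):
    """q: [N, D], k and v: [M, D] -> [N, D]; cross- or self-attention depending on inputs."""
    a = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)
    return a @ v

def perceiver_encode(byte_array, latents, depth=6, share_weights=True, seed=0):
    """Minimal Perceiver-style encoder sketch: one cross-attend from learned latents onto
    the input, then a stack of latent self-attention layers (optionally weight-shared)."""
    rng = np.random.default_rng(seed)
    D = byte_array.shape[-1]
    def proj():  # stand-in for learned query/key/value projection matrices
        return [rng.normal(size=(D, D)) * 0.1 for _ in range(3)]

    wq, wk, wv = proj()
    z = latents + attend(latents @ wq, byte_array @ wk, byte_array @ wv)  # linear in input size M
    shared = proj() if share_weights else None
    for _ in range(depth):                       # latent transformer: quadratic in N only
        sq, sk, sv = shared if share_weights else proj()
        z = z + attend(z @ sq, z @ sk, z @ sv)
    return z                                     # [N, D], independent of the input size

M, N, D = 2048, 128, 64
rng = np.random.default_rng(1)
out = perceiver_encode(rng.normal(size=(M, D)), rng.normal(size=(N, D)))
print(out.shape)  # (128, 64)
```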
And we give the model information about the input spatial structure by having positional encodings; here we use a 2D Fourier feature position encoding. And just to show you what that looks like, to give you a sense: each of the input points is assigned a position, and we have sinusoidal and cosinusoidal features in 2D. So this is basically a Fourier decomposition of the position of the 2D input.
And a couple of things that we found were that if we sample frequencies up to the Nyquist frequency of the signal-- that's the maximum frequency that's used-- we end up doing better than if you use a lower cutoff. And this is basically because it allows the encoding to distinguish every distinct point in the image.
Whereas if you sample at a lower maximum frequency, you're going to end up with aliasing, and so not all points will be distinguishable. We also found that sampling the spectrum relatively densely tends to help. And the contrast here, at the time we were developing this, was with respect to NeRF. So NeRF, at least in earlier implementations, used quite a small number of frequency bands.
We found that the more we added, the better we did. So in general, this is something to be attentive to. And then finally, as opposed to language, where you typically have addition of whatever your embedding is with the sinusoidal or position encoding that you use, here we found that concatenating them performed consistently better.
And so this may be because the content embedding is not as sparse as it is in language; we're not totally sure, but it's something we observed consistently.
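Here is a minimal sketch of that kind of 2D Fourier featurization: linearly spaced frequency bands from one oscillation across the image up to the Nyquist frequency, sine and cosine per band, concatenated with the raw content channels. The band count and normalization are illustrative, not the exact settings from the paper.

```python
import numpy as np

def fourier_features_2d(h, w, num_bands):
    """2D Fourier position encoding sketch: sin/cos at num_bands frequencies per axis,
    spaced linearly from 1 (one oscillation across the image) up to the Nyquist frequency."""
    ys, xs = np.linspace(-1.0, 1.0, h), np.linspace(-1.0, 1.0, w)
    pos = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)   # [h, w, 2], coords in [-1, 1]
    feats = [pos]                                                 # keep the raw coordinates too
    for axis, res in enumerate((h, w)):
        bands = np.linspace(1.0, res / 2.0, num_bands)            # Nyquist = resolution / 2
        angles = np.pi * pos[..., axis:axis + 1] * bands          # [h, w, num_bands]
        feats += [np.sin(angles), np.cos(angles)]
    return np.concatenate(feats, axis=-1)                         # [h, w, 2 + 4 * num_bands]

pe = fourier_features_2d(224, 224, num_bands=64)
rgb = np.zeros((224, 224, 3))                                       # stand-in image content
inputs = np.concatenate([rgb, pe], axis=-1).reshape(224 * 224, -1)  # concatenate, don't add
print(pe.shape, inputs.shape)   # (224, 224, 258) (50176, 261)
```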
And before I move on to results, I just want to contrast this to some other approaches for using transformers in the image context. So the obvious precedent here is Vision Transformers, and I think this line of work is great, especially in the image context. But there are some caveats about it that make it less suitable for more general-purpose use. So one is that vision transformers do use an input 2D convolution.
So this is often phrased in terms of patches, input patches, but that patching is a special case of a 2D convolution. So it does restrict the class of inputs you can use it for. And because we're basically building this patching or convolution into it, this means that this approach really isn't sufficient to get it to work on non-grid data.
There are other ways you could adapt it. But this is something that you will have to special case for every domain you're looking at. And then finally, because we have this sort of input where we're telling the architecture what it should look at first in the initial grouping, this does amount to getting rid of the non-locality assumption.
It's not super clear how much doing this just once will make a difference. But this is something to be aware of when you're thinking about this architecture. And then finally, cross-attention itself is used quite broadly in the vision literature. So just to highlight a couple of examples: DETR, which is an object detection method from Facebook, basically has a convolutional backbone that's then used to give an output feature map.
This is then passed into a transformer encoder decoder. And of course, whenever you hear encoder decoder, you think cross-attention, because from the encoder to the decoder, there's a cross-attention step. And so they're using basically the cross-attention to go from some feature map representation to something that looks more like the object bounding boxes.
There's also quite nice work on learning self-supervised or unsupervised object segmentation models. And in this work, they're doing something very similar where they have a convolutional backbone. They then use something like the latents that we introduce here to do-- they call them slots here. But basically, to assign some of the output pixels to different slots so that they sort of have independent complementary decoding of the slots in the segmentation model here.
And there's a lot of other things. OK. So first, I just-- so now I'm going to sort of walk you through the results of this model. Hi. Can I-- oh, go ahead. Yeah. I'll go after you. Go for it. OK. Cool. Sorry for that. Can you go back a couple of slides where you had the-- how the inputs flow into, I think, one of-- yeah, that one.
OK. So I have two questions. So the latent transformer is basically like a self-attention. Is that correct? Yeah. So the latent transformer is a fully self-attentional transformer. Got it. And why-- see, for the key and value, they flow directly into the cross-attention. And there is the query also flowing into it.
But the latent array is flowing into the cross-attention in parallel to the query. Can you explain that? Yeah. So this is just-- --here. I'm just trying to understand. Yeah, yeah. So that arrow depicts the residual connection. So this is depicted here as a cross-attention module.
And so the cross-attention itself has the attention. It has a residual connection. And then there's an MLP. So that's what that's meant to indicate. OK. But it's basically-- the QKV is standard. Got it. Thanks. Mm-hmm. Hi. I had a question that is slightly related to this. We can just stay off the slide, actually.
So I think one thing that's interesting about this argument-- I'm sorry, I lost you. You're cutting out. --it's mostly consisting of attention layers, whether it's self-attention or, in image transformers-- Can you hear me? Is that coming through? No, it's cutting out. But I think some recent-- Oh, OK.
Should I type it? Yeah, I think that's a good idea. OK, I'll type it slowly. Feel free to go ahead in the meantime. Sounds good. Sounds good to me. Yeah. Actually, can I chime in? Drew, while you're on that previous slide-- Yeah.
--on the flow, so these residual connections, I actually didn't know the cross-attention used them. How reliant are these sequential cross-attention layers on the residual connections? Yeah. So here, in the initial-- so two things I will say is that in the initial cross-attention, it doesn't really make a difference. So this is something we've ablated.
When we get to the Perceiver IO version of this, we also did the same thing in the decoder cross-attention. And it can make a difference there, depending on what you're doing. I think it's actually essential when you're using repeated cross-attention in this way, so when you have this sort of iterative structure.
And the reason for this is that the thing that's actually used to condition the query is basically all the-- that's your full representation of the state of the architecture so far. And so the skip connection is from the-- it's in, basically, the query channel. It's in the latent space.
And so this is basically what allows you to end up with this sort of dense and stable architecture. OK. Thank you. Mm-hmm. OK. So to ImageNet-- OK, so in standard ImageNet processing, basically, we compare against a few-- so this is a little bit out of date at this point-- but against a few just sanity check baselines here.
So comparing against ResNet-50 and then, at the time, the best vision transformer model that was purely on ImageNet. And we're definitely in the ballpark. This isn't-- these aren't anywhere near state-of-the-art results. But this is an architecture that, again, is not using any 2D convolutions. And so the fact that it was able to do this well was, we found, very, very surprising at the time.
One of the quite cool things about this is that, because this architecture is not making any assumptions-- the architecture itself isn't making any assumptions about the spatial structure of the input images-- we can look at permuted ImageNet. And in the first version of this, what we do is, basically, we compute the features using the 2D position.
So the 2D position is fixed to the pixel, and then we just shuffle them all. And so this will basically give you a sense of how dependent the baselines are on the input image structure. And so if we look at the transformer and perceiver, by construction, they don't change.
So this is not an empirical finding. This is a property of the models. But we find that ResNet-50's performance falls by about half. And ViT, which, again, only has one layer where it's relying on the spatial structure, also has about a 15-point drop. And so this suggests that it's relying quite a lot on that very first layer to give it some information about the structure.
We can push this a little bit further by, instead of relying on 2D Fourier features, using completely learned positional encodings. And this is now a model that has absolutely no information about the input structure, so shuffling the inputs and learning the encodings again is exactly equivalent.
And we find that this architecture also can be pushed above 70%. And we've gotten slightly better numbers here. In general, this seems to work worse. So the 2D information is useful. But it's quite cool that you can get what would have been numbers comparable to state of the art about five or six years ago.
So this is quite cool. Sorry, I'm a little thick here. You're saying the difference between the last two rows is that the second-to-last row has a two-dimensional position embedding, and the last one has a one-dimensional position embedding, essentially. Is that right? So it's learned. So it's basically-- it'll be-- it's, I believe, a 256-dimensional vector that's learned.
But it doesn't-- it basically-- it means that the model itself has no information about the input spatial structure. So the 2D positional encodings that we're using end up having about 200-- it's 200-some features, depending on what you're looking at. But they're always-- they give you very detailed information about the 2D structure of the input, because they're based on a Fourier decomposition of the input space.
I see. OK. That makes sense. Thank you. Hi, Drew. Can I ask a question about the frequencies you use to generate those sinusoidal waves? Yeah. So like a couple of slides before. Yeah. Yeah. Yeah. Yeah. This slide. So basically, I have taken some lectures in signal processing, and I know if I want to avoid aliasing, I need to sample with at least the Nyquist frequency.
So I'm curious to know why you use frequencies starting from 1 up to the Nyquist frequency, instead of starting from the Nyquist frequency and going to some very high frequency? Oh, I see. So basically, the maximum frequency that's used is always Nyquist. Anything above Nyquist is going to be aliased, so you're not actually going to be able to resolve it, because it's in pixel space, right?
So we sample-- one is basically just giving you an oscillation that covers the entire image. Yeah. And so this is basically just a sample of the full range of non-aliased frequencies. Oh, OK. Cool. Thank you. OK. OK. Yeah. OK. So after the image results, we wanted to try it on other domains.
And in particular, we were interested in how this could be used to work on sort of multimodal domains, so ones combining various different types of input features. And one challenge or one sort of problem that you encounter in these sorts of spaces is that the data from different modalities end up having different features, and they always have different semantics.
So if you take the positional encoding plus the RGB for video, you end up with some number of channels. And then if you have audio, that corresponds-- the data may be paired, but it tends to have fewer features, and it only has a 1D positional encoding. So the way that we handle this is basically by learning modality-specific position encodings.
And so these are basically embeddings that are special and learned for each of the modalities. And what this does is basically tags-- ends up tagging the features that come from audio or video with some information that the network can learn that allows it to distinguish which one's which. But given these padded-- these sort of learned padded feature vectors, we then concatenate them all, and that's how we process multimodal data.
So basically, the input to the architecture still looks like just one big array. It's just that when constructing this, we know that some of those features, some of the rows in that array, come from video and some come from audio. But the model itself isn't given information about that other than what it learns.
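A minimal sketch of that padding-and-tagging step; the feature sizes and the modality-embedding width here are made up for illustration.

```python
import numpy as np

def pad_and_tag(features, target_dim, modality_embedding):
    """Pad one modality's features to a common channel width and append a learned
    modality embedding, so different modalities can be concatenated into one array."""
    m, c = features.shape
    padded = np.concatenate([features, np.zeros((m, target_dim - c))], axis=-1)
    tag = np.broadcast_to(modality_embedding, (m, modality_embedding.shape[-1]))
    return np.concatenate([padded, tag], axis=-1)

rng = np.random.default_rng(0)
video = rng.normal(size=(16 * 56 * 56, 3 + 258))   # RGB + Fourier position features
audio = rng.normal(size=(1920, 1 + 65))            # raw audio + 1D position features
D, E = 261, 16                                     # common width and modality-embedding size
video_tag, audio_tag = rng.normal(size=(E,)), rng.normal(size=(E,))

inputs = np.concatenate([pad_and_tag(video, D, video_tag),
                         pad_and_tag(audio, D, audio_tag)], axis=0)
print(inputs.shape)  # (52096, 277): one flat array; the model only "sees" modality via the tags
```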
Oh, we also have some questions. Yeah, you can go first. Oh, yeah, sorry-- I thought it was on, but yeah. Yeah, sure. If you can hear me, this is just a simple question. I haven't studied a lot of transformer stuff formally, so I just didn't know what a positional embedding was.
Oh, what a positional embedding is? Yes. So basically, a positional embedding is a feature that says-- so the simplest way to think about it is in text. In text, the input is 1D, so things live in some 1D sequence. And for each point there, you featurize where it's located in that sequence.
So the simplest thing to do would be if you have negative 1 to 1 is the full range. It just denotes actually where it's located in that sequence. But we typically will add-- we'll want to featurize this to have more dimensions than just a single one. And so the Fourier decomposition is one way to do this to give it privileged information about the high frequency structure.
But we can also just use the position to index into some embedding array, which is how we do it when we're learning things. So basically, it's just a set of weights that are added to the feature for that point, and that gives the network information about where it's located in the original sequence.
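For the learned case, a tiny sketch with made-up sizes: the position just indexes into a trainable table, and the looked-up vector is added to (or concatenated with) that point's features.

```python
import numpy as np

seq_len, emb_dim = 128, 256
rng = np.random.default_rng(0)
token_features = rng.normal(size=(seq_len, emb_dim))

# A learned position encoding is just a trainable table indexed by position.
position_table = rng.normal(size=(seq_len, emb_dim)) * 0.02  # learned jointly with the model
positions = np.arange(seq_len)
with_positions = token_features + position_table[positions]  # add (or concatenate) per point
print(with_positions.shape)  # (128, 256)
```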
You want to go next? Sorry, I had to find a mute button-- unmute button. OK, so I actually have two questions regarding the Fourier features. I think-- do you guys sample them uniformly, or are they-- do you learn these? Yeah, so basically, we sample them linearly. So basically, we take the full space, and we sample them linearly with whatever the budget is.
There are-- so in various settings, we have actually tried learning these. So you could actually initialize an array with them and then learn them. And that does help sometimes, actually. And you could potentially learn-- you could try a more sophisticated strategy on this too. OK, cool. My follow-up question is that basically, I feel like the selling point of your research is that you don't make any structural assumptions.
You can take any type of format. However, for the encoding, wouldn't the dimensionality-- so for example, if it's text, it's 1D, right? If it's an image, it will be 2D. And if it's a video, it will be 3D. You have more-- the positional encoding will have more points, right?
Wouldn't that inherently give away the nature of the input? Yeah, so it does. So I completely agree with this. You're totally right. The version of this where we have learned position encodings is the most pure from that point of view. So it's one that gives it basically no information about the ground truth spatial structure.
What it does give the model-- so when you do the learned position encoding, it will say that, for example, there is a correspondence between point k on image 1 and point k on image 2. So that's basically the least amount of information you can give it while still allowing it to figure out what the structural relationship between the input points is.
So this is the direction that we've been trying to push in. In general, giving the architecture access to ground truth structural information, like this lives on this point in 2D, is helpful. So there's a couple things here. There's from a practical point of view, if you want good results, you need to exploit these things, or it's helpful to exploit these things.
But we do want to move in the direction where we're relying on these things less. And so this is basically something we're actively looking into. OK, makes sense. Thank you. So I think has posted her question on the chat. I also see you have your hand raised. So if you want, you can give it a try.
If not, I'll read out the question. OK, I'll try. Just let me know if it's choppy. Yeah, so is it good right now? So far, so good. Yeah. Oh, good. OK, cool. So if you look at the perceiver diagram you had, it's a bunch of attention layers, right? Like cross-attention and self-attention.
And I think there's been this small trend in recent work in vision transformers to try to sort of replace the last few layers instead of having attention, like make them be convolutions to address this attention scaling problem, right, in a different manner. And so here, the perceiver architecture is trying to make self-attention less expensive.
And there, they're just trying to replace it, and they kind of just avoid the problem. And so I'm curious. And so I've seen papers both ways, like some that try to do things like the ones you cited, and then some that are trying to do this as well. And in my mind, everyone always has the good results and stuff.
So I'm curious if you think there's a reason to do one or the other, or if you think this alternative approach is also promising, or is there a reason research should go in one direction or the other? Yeah, so to my mind, the big trade-off is one between... So the vision literature, I think, has just exploded in terms of these sort of hybrids and people trying to find the exact right place on the Pareto curve for the trade-off of speed and performance.
But they're basically looking primarily at vision-specific problems. So the computer vision community typically doesn't regularize itself away from approaches that don't work on things that aren't vision. So you end up with things that are very, very efficient and very performant on vision problems. So I think from that point of view, it's an incredibly important line of work, and that's probably the right way of doing things.
What we're aiming for is the things that are as general as possible while still being performant. Got it. So this kind of thing is critical... Oh, sorry to cut you off. Go ahead. No, no. Please, go ahead. I was going to say, so this kind of thing is important...
Just to summarize. So you feel like it's important to focus on attention because that's kind of critical for NLP. Like you can't just sort of put in a convolution at the end and sort of fix the problem. But in vision, maybe you can and it's fine. Is that a right way of understanding it?
That's part of it. Vision and NLP aren't the only two domains. And so the thing that we're looking for are really basically... So the kinds of problems that we're interested in doing with this include things like event-based cameras, cell biology, sort of proteins, all of these sorts of things where we may or may not have the right convolutional inductive biases to even know how to build those sorts of things.
Got it. They end up being whole research programs, like the mesh-based convolution work. Oh, cool. Thank you. I also had one more question about the architecture. So I saw that... I'm sorry if you said this and I just missed it, but you had cross-attention and then that latent transformer and then cross-attention.
I'm curious what happens if you replace the self-attention in those layers with cross-attention. Does it affect your accuracy? Is that even feasible? Is that a valid question? Yeah. So the sort of thing that you could do is you could modify this to make it sort of hierarchical so that there are multiple stages of cross-attention.
We haven't gotten this working yet, but it doesn't mean it's not a good idea. So there might be a right way to do this that we haven't figured out right, but it's something we have tried a little bit. Oh, cool. Okay. Thank you so much. I appreciate it. Yeah, no problem.
Okay. Let me... We're running short on time, so maybe I'll skip ahead. Okay. So before we run out of too much time, I want to at least talk about the sort of the modifications to this architecture that we've made to make it work sort of even more generally. So one of the problems of the sort of the first architecture that we looked at here, the basic perceiver, is that it works basically for arbitrary inputs, but it's designed to work only on classification or regression tasks as an output.
And so basically we wanted to see if we could use the same cross-attention strategy for decoding and it turns out you can. It's something that works pretty well, just kind of out of the box. So the idea is that we have, if we have our cross-attention input and self-attention sort of to do the processing, we can introduce a set of additional queries.
And these are basically queries that give the semantics of each of the points that you're trying to decode. And we pass these as input to another cross-attention layer, which is configured in basically the opposite way that the encoder cross-attention is configured. So now the queries are going to be something that's potentially large and the keys and values are coming from this latent.
And so what this allows us to do basically is to keep all of the nice advantages of the original perceiver. So we have an encoder that scales linearly, we have a processor stage, this sort of latent self-attention that scales independently of the input size, and we now have a decoder that keeps the decoupling, but gives us linear scaling with respect to output size.
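Here is a minimal sketch of that decoder cross-attention: the queries now come from the output side (one per output point), while the keys and values come from the small latent array, so the cost is linear in the number of outputs. Single head, no layer norm, and the sizes are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def decode(latents, output_queries, Wq, Wk, Wv):
    """Decoder cross-attention sketch: queries come from the (possibly large) output side,
    keys and values come from the small latent array -- the reverse of the encoder."""
    q = output_queries @ Wq              # [O, D], one query per output point
    k, v = latents @ Wk, latents @ Wv    # [N, D]
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # [O, N], linear in O
    return attn @ v                      # [O, D], then project to pixels/logits/flow per point

N, O, D = 128, 50176, 64                 # small latent, one output query per pixel (illustrative)
rng = np.random.default_rng(0)
latents = rng.normal(size=(N, D))
queries = rng.normal(size=(O, D))        # e.g. Fourier position features for each output point
Wq, Wk, Wv = (rng.normal(size=(D, D)) * 0.1 for _ in range(3))
print(decode(latents, queries, Wq, Wk, Wv).shape)  # (50176, 64)
```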
And so by doing this, we can now basically apply the same approach to dense output tasks. And so to give you a sense of how this works, just intuitively: if we're doing autoencoding on this image of puppies, basically what we do is we encode and process, and then to decode, we take a query that corresponds to each of the points, and then we pass it into this decoder.
So we can query one of the points, we get one pixel, query another one, we get another one, and all the way up till we get all 10,000 points. And that's how we can do reconstruction with this. And the cool thing about this is that it opens up a bunch of new applications.
And we can get different kinds of outputs just by changing how the queries work. So if we want to do something like multimodal autoencoding, where some of the outputs are video, some audio, and so on, we use the same query-construction trick to get queries that have the relevant positions and semantics for each of the points that we're decoding.
And we can do this even though the sizes of these different outputs-- the number of points they have-- are quite diverse. So in the multimodal autoencoding experiments that we have in this paper, we do this for video, audio, and labels at the same time, so that all of them are passed into one unified network and then decoded one by one in this way.
But we can also do masked language modeling now by conditioning on the position in a sequence. We can do multitask classification by having basically an index that tells which task you're querying from the network. And we can do things like optical flow by passing in input features as well as the positions.
And so I'm just going to just quickly skip to a couple of the different-- I can share these slides with you all afterwards to look through them. Some of these things are quite cool. But just quickly, I want to talk about language and then optical flow. So for language, basically what we wanted to do with this was to see if we could use this to replace tokenization.
And why might we care about getting rid of tokenization? So one, we use tokenization primarily because transformers scale poorly with sequence length. And tokenizing cuts sequence length by about a factor of four. But there are various problems that arise with this. And so why might we care about removing tokenizers?
So for one, tokenizers perform less well on rare words. So if you compare the byte-based decomposition, the UTF-8 encoding of an input sequence like this, you can see that there's basically a uniform allocation of points in memory to each of the input characters. The exception is characters with diacritics, which end up splitting into two bytes.
But if you look at the SentencePiece tokenization, it's learned that pepper is one token, but jalapeno gets split into five in this case. So this basically says the amount of capacity that you allocate depends on how rare the word is, which can lead to suboptimal encodings.
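You can see the byte-level behavior directly; the snippet below just prints the UTF-8 decomposition of two words, showing the roughly uniform one-byte-per-character allocation, with the accented character taking two bytes.

```python
# Byte-level inputs sidestep the tokenizer: every character costs one byte,
# except multi-byte UTF-8 characters such as the accented "ñ".
for word in ["pepper", "jalapeño"]:
    b = word.encode("utf-8")
    print(word, len(word), len(b), list(b))
# pepper   6 6 [112, 101, 112, 112, 101, 114]
# jalapeño 8 9 [106, 97, 108, 97, 112, 101, 195, 177, 111]
```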
They're also brittle to subtle perturbations. A famous example of this is that if you've ever played around with GPT-3, you'll notice that the output can be quite sensitive to whether you add or omit a space at the end. And that's basically because the space can end up being folded into different parts of the tokenization.
There are other things that can happen there, too, but this is one cause of that. And finally, tokens don't transfer across languages. So if you wanted to have a model that without any tuning could be used on many different languages at the same time, tokenizers are a blocker for this.
So if we can get rid of them, it'll simplify the pipeline, it'll also make things less brittle, and then hopefully lead to more general models. So the way that we do masked language modeling is the same as the way that I showed in that schematic autoencoding experiment. So we mask some fraction of our inputs-- about 15% is sort of the standard magic number.
We then decode at each of the positions that are masked, and we task the model with decoding whatever characters were masked at those locations. And then once we have this model, so this is what we do for pre-training, we can then fine tune it by replacing the decoder with a multitask decoder that takes in the tasks that we're using on the downstream evaluation setting, and training the model to reconstruct the logits on a per-task basis.
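A minimal sketch of that pretraining setup on raw bytes; the mask rate and mask id follow the usual conventions, but treat the details as illustrative. Hide a fraction of positions, place decoder queries at those positions, and train the model to reconstruct the original bytes there.

```python
import numpy as np

def mask_bytes(byte_ids, mask_rate=0.15, mask_id=256, seed=0):
    """Masked-language-modeling setup sketch on raw bytes: hide ~15% of positions,
    and ask the decoder (queried at those positions) to reconstruct the original bytes."""
    rng = np.random.default_rng(seed)
    byte_ids = np.asarray(byte_ids)
    mask = rng.random(byte_ids.shape) < mask_rate
    corrupted = np.where(mask, mask_id, byte_ids)      # model input
    target_positions = np.flatnonzero(mask)            # where we place decoder queries
    targets = byte_ids[mask]                           # what those queries must predict
    return corrupted, target_positions, targets

text = "This is raw UTF-8 text with no tokenizer in the loop."
corrupted, positions, targets = mask_bytes(list(text.encode("utf-8")))
print(len(corrupted), positions[:5], targets[:5])
```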
Okay, so to look at how this model performs, we basically first compare it to BERT base. So this is just a solid benchmark that we understand very well. And first, by looking at sort of matched, two models that have matched flops, we can see that Perceiver IO and BERT base work on par.
You see there's a different trade-off here. So to get the same number of flops, basically we make Perceiver IO deeper, and this ends up giving it more parameters, but on a per-flops basis, it ends up performing about the same. On the other hand, if we remove the tokenizer from BERT and keep the flops the same, we see that the number of parameters and the depth just drastically fall down, and this is because BERT scales quite poorly with sequence length, because it uses a normal transformer.
But if we use a Perceiver without the tokenization, we can see that we only get a slight reduction in the number of parameters at that flops count, but the performance is almost exactly the same. So this means that the Perceiver in this setting is performing basically the same with and without the tokenization.
It's learning a different strategy, it's using different parameters, but it basically can be brought to the same performance. We can then scale this further by leaning into what happens in the tokenizer-free setting, and we see that we can get a moderate performance boost as well. In the language setting, it's also useful to look at the attention maps that are learned. What's being visualized here is, for some subset of the latents, where each one is attending to in the input sequence, and some of these end up being local, so looking at specific points in the sentence.
Some of them are periodic, so they look at recurring points over the sequence, and some of them also look like they pick out syntactic features, which is quite nice. So they pick out basically exclamation points, or capital letters, or other punctuation that's quite useful and decodable right at the beginning of the sequence.
We can also basically use this exact same architecture on optical flow, and optical flow is basically an important classical problem in computer vision, where given a pair of frames in a video, we want to basically track all of the points, so figure out the motion from every point from one frame to the other.
And so optical flow is usually visualized using these sort of colorized images that are shown on the bottom, and what this gives you basically is a per pixel indication of the velocity at every single point. And so you can see that, so the blade that the character here is holding is moving to the right, whereas this creature behind her is sort of moving downwards.
So there are a couple of properties of optical flow that make it interesting to approach. One is that it's a dense task, and it basically involves long-range correspondences. But in the standard training protocol, there's basically no large-scale realistic training data, just because it's incredibly hard to label all of the pixels in a real-world scene and figure out where they go.
So the typical way to do this is to train on some synthetic data and then evaluate on more realistic scenes. And optical flow is also interesting because it's basically the locus of some of the most complicated visual architectures in the literature. So the previous state-of-the-art result here is this method called RAFT, which won the best paper award at ECCV last year.
And I'm just highlighting this to give you a sense of how much work people put into hand-engineering these architectures. So this is a very, very cleverly designed architecture, and basically it incorporates things like global correlation volumes that are explicitly computed at different offsets to allow the model to reason about how things at different scales are moving with respect to each other.
It also has local neighborhood gather operations, as well as update blocks to keep track of what's happening within each specific correlation block. And then finally, there are flow-specific upsampling operators that were developed. So in contrast to this, we basically wanted to see how well Perceiver IO would do here.
And just to give you a sense of what we were expecting coming into this, we thought, well, maybe-- so Perceiver IO is throwing a lot of the structure away, so we were hoping that we would get some good results, but it would probably overfit, and there's this problem of the domain transfer that's happening here.
But on the other hand, self-attention seems to be a reasonable way to match this sort of correspondence. What we actually found was that just by doing very, very simple preprocessing-- so extracting basically a patch around each pixel-- and then using the standard Perceiver IO architecture, we were able to get state-of-the-art results here.
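To make "very simple preprocessing" concrete, here is a sketch of the per-pixel featurization: gather a small patch around each pixel from both frames and concatenate along channels before flattening. The patch size and frame resolution below are illustrative, not necessarily the exact settings used.

```python
import numpy as np

def flow_inputs(frame1, frame2, patch=3):
    """Sketch of the minimal preprocessing for optical flow: for each pixel, gather a
    small patch from both frames, concatenate along channels, then flatten to [H*W, C]."""
    h, w, c = frame1.shape
    pad = patch // 2
    stacked = np.concatenate([frame1, frame2], axis=-1)                 # [H, W, 2C]
    padded = np.pad(stacked, ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    feats = [padded[dy:dy + h, dx:dx + w] for dy in range(patch) for dx in range(patch)]
    per_pixel = np.concatenate(feats, axis=-1)                          # [H, W, patch*patch*2C]
    return per_pixel.reshape(h * w, -1)   # add Fourier position features, then feed to the model

f1, f2 = np.zeros((368, 496, 3)), np.zeros((368, 496, 3))
print(flow_inputs(f1, f2).shape)  # (182528, 54)
```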
And so this is basically-- was validation of this general approach of trying to have general-purpose architectures that would transfer over. And so basically, with minimal tuning, we were able to get results that would beat both of the sort of compelling benchmarks on both of the Sintel evaluation methods, and to get comparable results on KITTI.
So these are the standard ones. And we can also sort of visualize what happens when we apply this on real-world data. So there's no ground truth here, so we can't really compare it, but it's still useful to sort of see how it moves around. And we can see that qualitatively, it's able to capture a lot of the fine structure, and to sort of get the right motion for the things that are very clearly moving in a specific direction.
We can also sort of-- it's also, I think, informative to look at what happens, how it manages to represent sort of small structure. Is this video playing? Yeah, we can see it. OK, cool. So the thing to look at here is the fine water droplets that are sort of flying through the air as that bird flies by.
And because we're decoding at every single output point, the architecture is able to represent those. So it's able to capture very, very fine-scale segmentation that would be difficult to capture if you had, for example, a convolutional upsampler here. OK, so I'm just going to-- oh, the light has gone off in this room.
I'm also curious whether you also tried other tasks like depth estimation, because Perceiver IO looks like it could also do quite well on those modalities. Yeah, so we haven't published anything, but some internal results suggest that it works just fine. One of the surprising things-- the thing that we were a little bit unsure about-- was how much information was going to be contained in this latent.
Because basically, you're abstracting quite a lot, and it doesn't have any 2D structure intrinsically. But it does seem like this-- it seems to be able to represent things quite well. And these sorts of decoding mechanisms do seem to be able to do that. Got it. Great. So I'm just going to-- just in the interest of time, I'm going to skip ahead to the conclusion.
Drew, I had one question with respect to the metrics that you've shared for the optical flow, the numbers. So in the table, it was Sintel Final, Sintel Clean, and KITTI. Were these different data sets or different metrics? The same metric for different data sets, or three different metrics?
Yeah, so these are three different data sets. So Sintel Clean and Sintel Final are basically two ways of doing the final rendering for Sintel. In all cases, these methods are trained just on the AutoFlow data set. So they're trained on this sort of general-purpose, kind of wacky synthetic motion data set.
And then we're evaluating on these different demands without fine-tuning. OK. Yeah, the flow has quite-- the data sets are quite small, so it's generally even problematic to fine-tune. Thank you. Mm-hmm. OK, so just to summarize-- What was the ground truth to find the endpoint error? Yeah, so the way this works is Sintel is a computer-- it's basically a relatively high-quality CGI movie that was basically open source.
And so they actually have the ground truth. So if you know the ground truth 3D state, you can compute the pixel correspondence from frame to frame. So that's what's used on Sintel. And then KITTI, they basically have a LiDAR sensor that's used to figure out the depth of all points.
And then they compute the correspondences. So the ground truth is actually the ground truth optical flow. But in general, it's hard to get dense optical flow. It's very expensive to collect it. Great. Thanks. Mm-hmm. OK. So basically, just to summarize, so the perceivers are attention-based architectures that scale linearly and work as drop-in replacements for transformers on a variety of settings.
They also seem to be able to achieve results that are comparable, at least in performance, with models that rely on 2D convolutions. But of course, there is a trade-off here. And so it's good to be very aware of this, of generality versus speed in specific domains. And so as was pointed out, in settings where you can use 2D convolutions, it's certainly good to have them in the loop.
It's basically a unified architecture that allows joint modeling of different modalities of different sizes. And overall, it seems to be a quite flexible architecture that's able to produce state-of-the-art or near state-of-the-art results on a variety of different domains. And in the two papers, we look at a number of other domains that I didn't talk about, including 3D point cloud modeling, replacing the transformer that's used in AlphaStar, the StarCraft II behavioral cloning agent, and a couple of others.
So we have a lot of evidence that this general approach seems to work broadly. And there's a lot of things we still haven't tried. So we're very interested in pushing this and always open for suggestions and so forth. So we're relying on a large body of related work because we're drawing from a lot of different areas here.
So here are some highlights. And then I just want to thank my co-authors on this work. And of course, I'm happy to talk more. Thanks, Drew. Yeah, thanks a lot. Thanks, Drew. So one question I had is, what do you think is the future of perceiver models? Do you think this is going to be used more in the transformer community to replace ConvNets and all of this stuff?
Yeah. So I think, broadly speaking, I think of perceivers now as sort of-- because we know how to adapt them pretty well to sort of domains where we don't have a great idea of the right way to structure an architecture, an inductive bias. So I think that's one of the really strong cases for it, so settings in which you don't really know what the right way to structure a problem is.
I also think these kinds of approaches can be used in conjunction with ConvNets for things that are as domain-agnostic as needed. But I think multimodal and new domains is really where these are obvious choices. Got it. Also, what do you think are the current bottlenecks with this?
And if you don't mind, if you can disclose, what are you working on towards next with this stuff? So I can't talk about too many details about that. But a couple of domains-- so one, we don't really have a great handle on how to use them on sort of small-scale data, so data where you don't have the data to sort of recover the inductive bias.
So this is, I think, a really important area. The other thing that we haven't sort of talked about here, but you could probably imagine that we'd be thinking about, would be how to train on multiple modalities or sort of multiple things at the same time. So right now, all of these architectures are sort of trained in isolation.
But there are a lot of opportunities for sort of figuring out how to pose problems together and use a single architecture on all of them. Got it. Also, I'm not sure if you've tried, but can you also use this for tabular data stuff? Yeah. So effectively, the architecture treats any input data as tabular data.
So I think that's exactly the right way to think about it. Cool. Sounds good. Yeah. Thanks for the talk. I'll open it up now for general questions from the students, so let's stop the recording. Okay. Great. Thank you.