Stanford CS25: V1 | Self Attention and Non-parametric Transformers (NPTs)
Chapters
0:00
3:25 Fast Autoregressive Decoding
14:01 Motivation and a Brief Summary
15:06 Parametric Prediction
16:02 Non-Parametric Transformer Architecture
23:40 Three Key Components to NPTs
24:44 Dataset Matrix
25:08 Masking Matrix
26:50 Linear Embedding
31:03 Masking-Based Training Objective
32:24 Stochastic Target Masking
37:20 Results
41:28 Poker Hand Data
42:32 Deep Kernel Learning
54:45 Semi-Synthetic Dataset
55:51 Attention Maps
00:00:00.000 |
>> Thanks so much. Great to be here and happy Halloween, belated Halloween everyone. So 00:00:10.880 |
I think the talk is going to be split into two sections. So I'll start by spending like 00:00:16.000 |
10 minutes, 15 minutes chatting about transformers in general. But I'm assuming most of you are 00:00:22.120 |
familiar with them and we can move on to NPTs, which Yannick and Neil will be presenting. 00:00:28.960 |
So let's see. I'm going to like try to fly through the transformer overview and maybe 00:00:36.520 |
spend a little bit extra time on like the history of transformers and maybe just tell 00:00:42.200 |
the story a little bit. I think that might be more interesting. So just in terms of the 00:00:49.640 |
transformer architecture, the two kinds of things that it introduced for the first time 00:00:54.520 |
were multi-head attention and self-attention. And then it combined those with fast autoregressive 00:01:00.880 |
decoding. So before the transformer, pretty much everyone was using LSTMs and LSTMs with 00:01:07.560 |
attention. But I'll try to get into the difference of self-attention, multi-head attention. So 00:01:16.520 |
originally you would have two sequences and you would have an attention module, which would 00:01:23.160 |
attend from the source to the target. And so each token or each word in the source sequence 00:01:29.120 |
would get associated with, you know, a soft approximation of one element in the target 00:01:34.800 |
sequence. And so you'd end up with something like this. But with self-attention, we did 00:01:41.680 |
away with the two separate sequences. We make them both the same. And so you're relating 00:01:47.200 |
each element within the sequence to another element in the sequence. And so the idea here 00:01:56.320 |
is that you're learning a relationship of the words within a sentence to the other words. 00:02:01.080 |
So you can imagine something like an adjective, which is being applied to a noun. And so you 00:02:06.240 |
want to relate that adjective, like the blue ball, you want to relate blue as referring 00:02:11.120 |
to ball. So we're learning patterns within the sequence, intra-sequence patterns. So 00:02:19.480 |
sorry, I gave this talk in Kenya, so I am using Kiswahili here. But with multi-head 00:02:27.160 |
attention, the idea is you have like each word represented by an embedding, which is 00:02:32.200 |
in the depth dimension here. And then you have your sentence of words. You split that 00:02:36.720 |
up into a bunch of different groups. So here I've chopped it depth-wise into four groups. 00:02:43.080 |
You apply attention to each one of these groups independently. And then when you get the result 00:02:49.400 |
back, you can concatenate them together and you're back to your model dimension representation. 00:02:56.760 |
So what this lets you do is that each attention head can now focus 00:03:02.600 |
on learning one pattern. So maybe attention head one is learning the relationship of adjectives 00:03:09.240 |
to nouns. And the second attention head can learn something different. So this lets us 00:03:15.600 |
learn like a hierarchy or a list of different relationships. Okay. So that was self-attention. 00:03:24.560 |
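As a rough sketch of that split-into-heads idea (illustrative PyTorch-style code with made-up dimensions; the learned query/key/value projections of real multi-head attention are omitted for brevity):

```python
import torch

def multi_head_self_attention(x, num_heads):
    """Toy multi-head self-attention: chop the model dimension into groups,
    attend within each group independently, then concatenate the results."""
    batch, seq_len, d_model = x.shape
    d_head = d_model // num_heads

    # (batch, seq, d_model) -> (batch, heads, seq, d_head)
    heads = x.view(batch, seq_len, num_heads, d_head).transpose(1, 2)

    # Scaled dot-product attention, computed independently per head.
    scores = heads @ heads.transpose(-2, -1) / d_head ** 0.5  # (batch, heads, seq, seq)
    attended = scores.softmax(dim=-1) @ heads                 # (batch, heads, seq, d_head)

    # Concatenate the heads back to the full model dimension.
    return attended.transpose(1, 2).reshape(batch, seq_len, d_model)

out = multi_head_self_attention(torch.randn(2, 5, 64), num_heads=4)  # (2, 5, 64)
```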
The other piece is fast autoregressive decoding. And do I really want to go into this? Okay, 00:03:33.740 |
I will. So the important thing about this is that if you're doing normal autoregressive 00:03:40.100 |
decoding, what you do is you generate your first token. And now conditioned on that first 00:03:44.000 |
token, you generate the second and conditioned on the first two, you generate the third and 00:03:47.760 |
so on and so forth. But that's super slow, right? Like it's a loop applying this thing 00:03:52.120 |
again and again. And so what we can do instead is we make an assumption in the code that 00:03:59.200 |
our model always generates the right thing. And then we generate a prediction, only one 00:04:07.360 |
token ahead. So you have your outputs, which are Y, you have your targets, which are Y 00:04:26.760 |
hat. And what you do is you feed in those gold targets so that you don't need to actually 00:04:34.480 |
do this loop. So instead of assuming, instead of having to generate the first token, feed 00:04:39.360 |
it back into your architecture, generate a second token, you feed in the entire target 00:04:44.360 |
sequence and you just pretend that you generated all the right tokens up to position K. And 00:04:50.520 |
then you predict the K plus first and you compute your loss on that. So in reality, 00:04:56.440 |
your model might've generated junk at the beginning of training, but you're getting a loss 00:05:01.880 |
as if your model had seen all the correct tokens and is now just predicting the next 00:05:07.240 |
one. This is a little bit subtle, but it's hugely impactful for training speed because 00:05:13.280 |
all of this can be done in parallel. And so it's actually what makes transformers 00:05:18.280 |
so scalable. Okay. So in order to do this successfully, if you were just feeding in 00:05:25.780 |
all of the correct tokens, uh, naively, what would happen is your model would 00:05:32.080 |
just be able to, uh, look forward in time and cheat. So you've put in all of 00:05:39.360 |
your true targets, the things that you're trying to get your model to predict. And since 00:05:44.200 |
that's what you're computing your loss on, it could just look forward in time 00:05:47.240 |
and say, okay, I'm just going to grab that, and we'd get zero error trivially, right? 00:05:51.360 |
Because you've given it all the right answers. So what we have to do inside the architecture 00:05:56.100 |
is we need to actually prevent, uh, the attention mechanism from being able to look at tokens 00:06:02.100 |
that it shouldn't have been able to see already. Um, so the way that this looks is you create 00:06:09.460 |
a mask on your attention. Um, and so, sorry, this is the example of like doing a trivial 00:06:16.780 |
attention. If you don't mask your attention properly, um, what it's going to do is it's 00:06:22.220 |
just going to look into the future, just grab the token that you're telling it to predict 00:06:27.460 |
and copy it over. And so it learned something trivial, something that doesn't actually generalize. 00:06:31.540 |
And so what we do is we actually prevent it from attending to those tokens. We prevent 00:06:36.100 |
it from attending into the future for each position in the source sequence. We block 00:06:42.460 |
out, uh, everything that it shouldn't be able to see everything into the future. And then 00:06:47.700 |
as we move down, we gradually unblock so it can start to see into the past. Um, so those 00:06:54.340 |
are kind of like the three major components of transformers, um, the self-attention, 00:07:03.380 |
the multi-head attention, and then feeding in these gold targets, which is the fast autoregressive 00:07:11.260 |
decoding.
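As a rough illustration of what's being described (generic PyTorch-style pseudocode, not the actual Transformer implementation; the token values and start symbol are made up):

```python
import torch

seq_len = 5
# Causal mask: position i may attend only to positions j <= i,
# so no position can "look into the future" and copy its own target.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = torch.randn(seq_len, seq_len)                    # raw attention scores
scores = scores.masked_fill(~causal_mask, float("-inf"))  # block future positions
weights = scores.softmax(dim=-1)                          # each row attends to the past only

# Teacher forcing: feed the (shifted) gold targets as the decoder input, so every
# position k predicts token k+1 in one parallel pass instead of a sequential loop.
gold = torch.tensor([3, 7, 1, 4, 9])                      # made-up target token ids
decoder_input = torch.cat([torch.tensor([0]), gold[:-1]]) # 0 stands in for a start token
```

In terms of the story, which might be a little bit more interesting. Um, so transformers, 00:07:19.660 |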
I was an intern, uh, with Lukasz Kaiser, uh, at Google back in 2017. Um, and I was sitting 00:07:26.700 |
next to Noam, uh, and Ashish was like a couple of seats down from us. Um, and what's really 00:07:33.900 |
incredible is that essentially this entire project came together in like three months 00:07:39.260 |
and it was done. So I showed up at Google; uh, Noam had been working on autoregressive 00:07:46.180 |
models. Um, same thing with like Ashish and Jakob and Niki. And, um, they'd just been kind 00:07:52.420 |
of like exploring the space, figuring it out. Uh, and Lukasz and I, at the same time, we'd 00:07:57.620 |
been working on this framework called Tensor2Tensor, um, which was like explicitly made 00:08:03.340 |
for, uh, multimodal learning, autoregressive learning. Um, and Lukasz is kind of like a 00:08:10.900 |
master of keeping track of everything that's happening in the field and adopting it. And 00:08:17.740 |
so within Tensor2Tensor, there were these kind of emerging 00:08:24.820 |
little things that maybe one paper had been written about, uh, and people were interested 00:08:29.020 |
in like layer norm, um, but it hadn't actually taken off yet. Uh, the warmup in the learning 00:08:34.660 |
rate schedule, um, all of these little pieces were just default, like on by default. Um, 00:08:43.420 |
and so when Noam and Ashish and Niki and Jakob came over and adopted Tensor2Tensor, 00:08:50.660 |
all of these things were just on by default. Um, and so a lot of people, when they look 00:08:56.460 |
at the, the transformer paper, it just seems like there's so many like arbitrary little 00:09:00.940 |
things thrown in. And when now, like in present day, these have become standard for like a 00:09:06.180 |
lot of different training algorithms, like the learning rate warmup, um, the way that 00:09:11.860 |
we did initialization, uh, all of these pieces have just become the norm, but back then they 00:09:16.980 |
had like kind of just been introduced. Um, and so we, uh, we spent a lot of time running 00:09:24.940 |
ablations trying to figure out like, which were the necessary pieces and what made it work. 00:09:29.420 |
And if any of you have actually tried training transformers and tried like pulling out the 00:09:34.580 |
learning rate warmup, um, or changing any of these little pieces, you'll see that it 00:09:39.860 |
really does break down optimization. Like it actually really does, um, hurt performance 00:09:46.340 |
for instance, like removing the layer norms, that type of thing. Um, so I always thought 00:09:52.180 |
it was kind of funny how all of these random additions that Lukasz had just like thrown 00:09:57.860 |
in, cause he was playing around with them turned out to be crucial. Uh, and they were 00:10:02.260 |
just on by default. Um, so anyway, it was like three months. I remember, um, it all 00:10:11.660 |
really started coming together towards the end, like just before the NeurIPS deadline. 00:10:16.740 |
Um, and I can still remember sitting in the micro kitchen and Ashish telling me like, 00:10:21.500 |
that's just like, I was a little intern, uh, telling me like, this is going to be such 00:10:25.620 |
a big deal. And I was like, yeah, sure. Okay. I have no idea what's happening. I just showed 00:10:30.620 |
up. Um, and he was like, no dude, like this, this actually matters. Like, you know, we, 00:10:36.580 |
we bumped up BLEU three points and I was like, sick, great, anyway. Um, and then I can remember 00:10:45.620 |
on the night of the deadline for NeurIPS, it was like 2am, uh, Ashish, Ashish was the 00:10:54.340 |
only one left at the office and we were still like moving around figures and like adjusting 00:10:59.620 |
things. And then, um, I went to bed, but Ashish stayed up and I slept in like this tiny little 00:11:04.740 |
phone booth. Um, and then for the other paper that I was submitting, I forgot to press submit, 00:11:11.580 |
but luckily like some lady opened the door to the phone booth and hit me in the head 00:11:16.740 |
while I was sleeping in the morning. And just before the deadline, I got the paper in. And 00:11:22.100 |
so I owe it to that lady, uh, for, for submitting to NeurIPS that year. But yeah, anyway, the, 00:11:29.740 |
I think the crazy thing about transformers was that it all came together in like three 00:11:32.900 |
months. Like most of the ideas happened in that span. And it was just like this sprint 00:11:37.580 |
towards the NeurIPS deadline. Um, I think a lot of the other members on the team, Jakob, 00:11:43.220 |
Lukasz, Noam, Ashish, they knew how important it was, but for me, I was like, I don't know. 00:11:48.500 |
I really did not appreciate, uh, the impact. Um, but in retrospect, it's been amazing how 00:11:55.540 |
the community has kind of like come together and adopted it. Um, and I think most of that 00:12:01.860 |
can be ascribed to the ease of optimization. It seems very robust to hyperparameter 00:12:07.980 |
choices. So you don't need to like tune the hell out of it, spend a lot of time tweaking 00:12:12.220 |
little things. Um, and the other side is that it's like super tailored to the accelerators 00:12:17.060 |
that we run on. Um, so it's like very parallelizable, um, hyper-efficient. Um, and so it lends itself 00:12:26.220 |
to that kind of scaling law effort that's, that's really taken off in, in popularity. 00:12:34.620 |
That's such a cool story. Oh my God. We're both excited. So we just unmuted at the same time. 00:12:43.260 |
Yeah. So that's, that's my section. If there's any questions, happy to answer them. Otherwise, 00:12:50.700 |
uh, let's get into NPTs. NPTs are, I think, um, such a nice next-level 00:12:58.540 |
abstraction of the architecture. So you've probably seen the trend of, um, transformers 00:13:04.140 |
getting applied to, to new domains, um, first into vision and video and audio. Um, but this 00:13:13.100 |
is kind of like cutting back to an even more abstract level. Um, like I think tabular data. 00:13:19.580 |
Yeah. I don't know. I'll let Yannick and Neil take over from here, but, uh, I think NPTs 00:13:28.900 |
Thanks for the introduction. Um, thanks all for the invitation. We're very happy to be 00:13:33.740 |
here and Neil and I are now going to tell you about our, um, self-attention between 00:13:40.420 |
data points paper, where we introduce the non-parametric transformer architecture. 00:13:45.380 |
Um, we'll start with a little bit of motivation, move on to explaining the architecture in 00:13:50.260 |
detail, then show you the experiments. This is more or less a step through of the paper, 00:13:54.580 |
but maybe, you know, with a little bit of extra insight here and there. All right, as 00:14:02.220 |
promised: motivation and a brief summary. So, um, we'll start by thinking about something 00:14:08.060 |
that we don't often think about. And that is that from CNNs to transformers, most of 00:14:13.340 |
supervised deep learning relies on parametric prediction. So what that means is that we 00:14:19.220 |
have some training data and we want to learn to predict the outcomes Y from inputs 00:14:25.020 |
X. And for this, we set up some model with tunable parameters, theta, then we optimize 00:14:31.020 |
these parameters to maximize predictive likelihoods on a training set, or, you know, equivalently, 00:14:36.660 |
we minimize some loss. And then after training, we have this optimized set of parameters theta 00:14:43.100 |
and then at test time, we just put these into the model and use these parameters to predict 00:14:48.340 |
on novel test data. And so crucially here, our prediction at test time only depends on 00:14:54.620 |
these parameters, right? It's parametric. Also, that means that given these parameters, 00:15:00.340 |
the prediction is entirely independent of the training data. And so why would we want 00:15:06.380 |
to do parametric prediction? Well, it's really convenient because all that we've learned 00:15:11.260 |
from the training data can be summarized in the parameters. And so prediction time, we 00:15:16.300 |
only need these final parameters and we do not need to store the training data, which 00:15:19.820 |
might be really, really large. On the other hand, we usually have models that already 00:15:25.380 |
predict for a bunch of data in parallel, right? Think of mini-batching in modern architectures. 00:15:30.920 |
And actually things like batch norm already make these data interact. And so our thinking 00:15:37.100 |
here was that if we've got all of this data in parallel anyways, there's no reason not 00:15:41.780 |
to make use of it. And so, a bit more grandly, we kind of challenge parametric prediction 00:15:48.460 |
as the dominant paradigm in deep learning. And so we want to give models the additional 00:15:53.780 |
flexibility of using the training data directly when making predictions. And so a bit more 00:16:00.340 |
concretely, we introduced the non-parametric transformer architecture. And this is going 00:16:06.220 |
to be a general deep learning architecture, meaning we can apply it to a variety of scenarios. 00:16:12.420 |
NPTs will take the entire data set as input whenever possible. And NPTs then crucially 00:16:19.540 |
learn to predict from interactions between data points. And to achieve this, we use multi-head 00:16:26.140 |
self-attention that, as Aiden has introduced us to, has just really established itself 00:16:32.260 |
as a general purpose layer for reasoning. We also take another thing from the NLP community 00:16:39.500 |
and we use a stochastic masking mechanism. And we use that to tell NPTs where to predict 00:16:45.300 |
and also to regularize the learning task of it. And lastly, of course, we hope to convince 00:16:50.980 |
you that this ends up working really, really well, and that this kind of simple idea of 00:16:55.700 |
learning to predict from the other data points of the input, from the training points of 00:17:00.380 |
the input, ends up working rather well. And so very briefly summarizing what we've heard 00:17:08.980 |
already. A, we input into NPTs the entire data set. And then B, let's say for the purpose 00:17:15.780 |
of this slide here, we only care about predicting the orange question mark in that green row. 00:17:22.280 |
And then we can compare NPTs to parametric prediction, right? So a classical deep learning 00:17:27.780 |
model would predict this target value only from the features of that single green input. 00:17:33.460 |
To do that, it would use the parameters theta, those would depend on whatever training data 00:17:37.620 |
we've seen and so on. But at test time, we only look at that single row for which we 00:17:42.580 |
care about the prediction. In contrast, NPTs predict in explicit dependence on all samples 00:17:49.060 |
in the input. They can look beyond that single green datum of interest and look at all other 00:17:54.200 |
samples that are there and consider their values for prediction.
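Spelled out schematically (the notation here is ours, chosen to make the contrast explicit, not lifted from the paper):

```latex
% Parametric prediction: the test-time output depends on the data only through \hat{\theta}
\hat{y}_{\mathrm{test}} = f_{\hat{\theta}}(x_{\mathrm{test}}),
\qquad
\hat{\theta} = \arg\max_{\theta} \sum_{i=1}^{n} \log p_{\theta}(y_i \mid x_i)

% NPT prediction: an explicit function of the entire input dataset, training and test points alike
\hat{y}_{\mathrm{test}} = f_{\theta}\bigl(x_{\mathrm{test}};\, X_{\mathrm{train}}, y_{\mathrm{train}}, X_{\mathrm{test}}\bigr)
```

So this presents an 00:17:58.440 |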
entirely different way of thinking about how we learn predictive mechanisms. And somebody 00:18:04.540 |
on Twitter called this KNN 2.0, which we would not have written in the paper, but maybe is 00:18:09.760 |
kind of a nice way of thinking about how NPTs can learn to predict. So of course, nonparametric 00:18:17.240 |
models are a thing already. We didn't invent them at all. And I define them here as prediction 00:18:23.720 |
in explicit dependence on the training data, which is certainly what NPTs do. Classical 00:18:29.200 |
examples like Gaussian processes, k-nearest neighbor, kernel methods, those might be familiar 00:18:34.320 |
to you. And there exists also efforts to combine the benefits of nonparametrics and representation 00:18:40.280 |
learning in a similar fashion to how we did it in NPTs. However, these approaches are 00:18:47.220 |
usually limited in some sense in comparison to NPTs, right? They're often kind of motivated 00:18:51.960 |
from the statistics community a bit more. They often require more finicky approximative 00:18:55.920 |
inference schemes, are limited in the interactions they can learn, or things like that. And so 00:19:01.560 |
we really think NPTs present maybe the most versatile and most widely applicable of these 00:19:07.760 |
nonparametric prediction approaches. But that's something we explicitly wanted to have. We 00:19:12.240 |
wanted to have something that's really easy to use, plug and play, works in a ton of scenarios, 00:19:16.940 |
and works really well. And so with that, I hand over to Neil, who's going to tell you 00:19:23.380 |
about the nonparametric transformer architecture in all of its details. 00:19:27.800 |
Yeah, we also have one question. Go ahead. Yes, yes. 00:19:32.360 |
Hi, Yannick. Hi. So could you please go back to the previous slide? 00:19:40.680 |
The very previous slide? Yeah, yes. This slide, yeah. So in terms of 00:19:45.440 |
the problem definition, I think it's quite similar to some meta-learning problem, which 00:19:51.160 |
basically there is a mapping from a data point on the data sets to some predictions. So could 00:19:58.480 |
you please suggest any differences between your problem setting and the meta-learning 00:20:05.040 |
problem settings? I can't really figure out any differences between these two problems. 00:20:11.760 |
Well, I think it really depends on the framing that you want to have. So I would say meta-learning 00:20:19.240 |
would be when I try to predict over multiple data sets. So when I try to learn some sort 00:20:26.160 |
of prediction model, or I can just plug in a different data set, and it will almost automatically 00:20:32.080 |
give me new predictions on this different data distribution. But that's not what we 00:20:36.080 |
do at all. We're training a single model for a fixed data set. And so this is why I wouldn't 00:20:41.720 |
really call that meta-learning, because we're trying to predict on the same tasks that all 00:20:47.160 |
the supervised deep learning, or any supervised machine learning method is trying to predict 00:20:53.400 |
well on. Okay, so you mean you use kind of the same test 00:20:59.920 |
set to test your trained model, right? So basically, in meta-learning, we're going to 00:21:16.360 |
test on different kind of meta-test sets. But in your case, you just want to use a test 00:21:22.800 |
set, which is similar to the distribution of your training set, right? 00:21:29.600 |
Yeah, absolutely. So we explore data set distribution shift a bit. I think it's a really interesting 00:21:34.520 |
scenario. I think meta-learning different data sets is also an interesting scenario, 00:21:40.000 |
right, when you have this model where you could just push in different data sets. But 00:21:44.160 |
for the scope of this paper, it's very much training set, test set, they come from the 00:21:49.040 |
same distribution, and we're just trying to do supervised learning in a standard setting. 00:21:55.720 |
I see, cool. Thank you. Thank you for the question. 00:21:59.720 |
Yeah, and I would chime in a couple of additional things, I guess. So at least from what I understand 00:22:06.480 |
from the problem definition of meta-learning, I think the aim is more perhaps being able 00:22:11.960 |
to perform well on a new data set with a relatively small number of additional gradient steps 00:22:16.280 |
on that data set. So I think there are some interesting ways that you could actually consider 00:22:20.880 |
applying NPTs in a meta-learning type setting. And so we'll get into this a little bit more, 00:22:26.080 |
but for example, there might be ways to essentially add in a new data set. So let's suppose we've 00:22:32.280 |
trained on a bunch of different data sets. We now add in a new data set. We can perhaps 00:22:36.480 |
do some sorts of kind of zero shot meta-learning, basically, where there's no need for additional 00:22:43.800 |
gradient steps because we're basically predicting kind of similar to how you might do prompting 00:22:48.520 |
nowadays in NLP literature. Anyways, yeah, I think we'll get into some more details. 00:22:54.800 |
Just to chime in on that, I don't think that every meta-learning algorithm-- I think the 00:23:04.640 |
ones that you're describing right now are optimization-based, but there are also black-box ones. You don't 00:23:09.800 |
need to further-- I think the main difference seems to be that there is one task versus 00:23:15.960 |
multiple tasks for meta-learning. Yeah, I think so too. I think the main framing question 00:23:23.640 |
is whether or not there's multiple data sets. Cool. OK, awesome. If there's no other questions, 00:23:33.360 |
I'll dive a bit more into the architecture. Awesome. So there's three key components to 00:23:42.920 |
NPTs. I'm going to first state them at a high level, and then we'll go through each of them 00:23:46.480 |
in more detail. So first of all, we take the entire data set, all data points as input. 00:23:53.440 |
So for example, at test time, the model is going to take as input both training and test 00:23:57.880 |
data, and we approximate this with mini-batches for large data. We apply self-attention between 00:24:05.560 |
data points. So for example, at test time, we model relationships amongst training points, 00:24:10.960 |
amongst test points, and between the two sets. And then finally, we have this masking-based 00:24:17.240 |
training objective. It's a BERT-like stochastic masking, and the key point is that we actually 00:24:22.680 |
use it on both features as well as on training targets. And we'll get into why that leads 00:24:27.660 |
to an interesting predictive mechanism later. 00:24:32.040 |
So to start with this idea of data sets as input, there's two things that compose the 00:24:37.140 |
input to NPT. It's a full data set in the form of a matrix X and a masking matrix M. 00:24:43.880 |
And so Yannick has described this data set matrix a little bit. We basically have data 00:24:48.200 |
points as rows. The columns are attributes, and each attribute shares some kind of semantic 00:24:53.160 |
meaning among all of its data points. So say, for example, you're just doing single-target 00:24:58.260 |
classification or regression. The last column would be the target, and the rest of the matrix 00:25:03.560 |
would be input features. So for example, the pixels of an image. 00:25:08.520 |
We also have a masking matrix. So let's say we're thinking about masked language modeling. 00:25:13.720 |
The mask tokens will just tell us where we're going to conceal words and where we're going 00:25:17.760 |
to back-propagate a loss. We do a similar type of thing here, where we use this binary 00:25:22.340 |
mask matrix to specify which entries are masked. And the goal is to predict masked values from 00:25:29.920 |
observed values. I see that there was a question about handling inputs with different lengths. 00:25:38.440 |
In the data sets we've considered, we'll get into it in the results section, but it's mostly 00:25:43.360 |
been sort of tabular and image data, where the lengths for each of the data points is 00:25:47.840 |
the same, but it would work just like padding. That would be a reasonable way to go about 00:25:51.960 |
that. There's also kind of an interesting-- yeah, go for it, Yannick. 00:25:56.500 |
Just to add to that, I'm not sure if length refers to columns or to rows, right? Rows, 00:26:02.720 |
we don't care about how many rows. Length, padding or something would be an option. 00:26:07.520 |
Yeah, my question was about column. Exactly, that makes sense. Awesome, thanks. 00:26:13.160 |
Yeah, I mean, that goes along with the whole meta-learning discussion. I think if we wanted 00:26:17.600 |
to adapt to data sets that have a different number of data points per data set, we can 00:26:22.360 |
take advantage of the fact that self-attention is kind of OK with that. Cool. 00:26:31.560 |
So continuing on, left to discuss here is basically how we do the embedding. So to put 00:26:39.440 |
this more explicitly, we have this data matrix that has n data points. It's called x, and 00:26:45.280 |
it also has d attributes. And we have this binary mass matrix m. We're going to stack 00:26:49.160 |
them, and then we're going to do a linear embedding. 00:26:52.800 |
So specifically, we're doing the same linear embedding independently for each data point. 00:26:57.720 |
We're learning a different embedding for each attribute. We have a positional encoding on 00:27:02.320 |
the index of the attributes because we don't really care about, say, being equivariant 00:27:06.000 |
over the columns. If it's tabular data, you, of course, want to treat all these kind of 00:27:09.280 |
heterogeneous columns differently. And then finally, we have an encoding on the type of 00:27:14.000 |
column, so whether or not it's continuous or categorical. 00:27:17.800 |
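As a loose sketch of that embedding step (illustrative PyTorch; the sizes, the (value, mask) featurization, and the omission of the column-type encoding are simplifications of ours, not the paper's exact code):

```python
import torch
import torch.nn as nn

n, d, e = 8, 4, 16              # datapoints, attributes, embedding size (made up)

X = torch.randn(n, d)            # toy dataset matrix (continuous attributes)
M = torch.zeros(n, d)            # mask matrix: 1 where an entry is masked out
M[:, -1] = 1.0                   # e.g. mask the target column

# A separate linear embedding per attribute, applied independently to each
# datapoint, fed the stacked (value, mask) pair for that entry.
per_attribute = nn.ModuleList([nn.Linear(2, e) for _ in range(d)])
attr_position = nn.Embedding(d, e)   # encoding on the index of the attribute

cols = []
for j in range(d):
    entry = torch.stack([X[:, j], M[:, j]], dim=-1)               # (n, 2)
    cols.append(per_attribute[j](entry) + attr_position.weight[j])
H = torch.stack(cols, dim=1)     # (n, d, e) input representation
```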
And that ends up giving us this input data set representation that is dimensions n by 00:27:22.760 |
d by e. The second key component of NPTs is attention between data points. So to do that, 00:27:33.360 |
we first take this representation we have and flatten to an n by d times e representation. 00:27:40.040 |
So basically, we're treating each of these d times e size rows as if it's a token representation. 00:27:46.120 |
We're actually going to just accomplish this operation using multi-head self-attention. 00:27:50.720 |
We've reviewed this a lot, but the nice thing is that we know from language modeling, if 00:27:54.600 |
we stack this multiple times, we can model these higher order dependencies. And here 00:27:58.040 |
they're between data points, and that's really the key draw of this architecture. 00:28:02.480 |
There's been other kind of instances of people using attention for similar sorts of things. 00:28:06.960 |
So for example, like attentive neural processes. A lot of times, they've sort of used just 00:28:12.200 |
a single layer as kind of a representational lookup. And we believe that this actually 00:28:17.320 |
ends up limiting expressivity, and that by stacking this many times, you can learn more 00:28:20.860 |
complex relationships between the data points. 00:28:23.680 |
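A minimal sketch of that flatten-and-attend step (illustrative PyTorch; layer sizes are arbitrary and the residual and normalization details of the real architecture are left out):

```python
import torch
import torch.nn as nn

n, d, e = 8, 4, 16
H = torch.randn(n, d, e)                   # per-datapoint, per-attribute embeddings

# Attention between datapoints (ABD): flatten each datapoint into a single
# "token" of size d*e, then self-attend across the n rows of the dataset.
tokens = H.reshape(1, n, d * e)            # (batch=1, sequence=n, d*e)
abd = nn.MultiheadAttention(embed_dim=d * e, num_heads=4, batch_first=True)
mixed, _ = abd(tokens, tokens, tokens)     # rows can now exchange information
H = mixed.reshape(n, d, e)                 # back to the n x d x e layout
```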
Neil, we also have some questions. So you can go ahead first. 00:28:30.280 |
Oh, cool. Thanks. I have a question about how you guys do the embedding. Is it always 00:28:35.200 |
part of these convolutional filters or linear layers? What is the type of embedding that you use? 00:28:42.520 |
Yeah, so I'm attempting to go back to the slide. I think it's not very happy with me 00:28:47.560 |
right now. But yeah, so for tabular data, we did just linear embeddings, actually. So 00:28:55.960 |
we could get into, I guess, details of featurization for categorical and continuous, but it's literally 00:29:00.120 |
like, say, for categorical, you do a one-hot encoding, and then you learn this embedding 00:29:04.440 |
that is specific to that attribute. And then for numerical, I believe we were just standard 00:29:09.480 |
normalizing. For the image data, we did end up using a ResNet-18 encoder for CIFAR-10. 00:29:17.480 |
However, I think that-- I mean, we'll discuss that a bit later in results, but that embedding 00:29:23.160 |
is a bit arbitrary. You can sort of do whatever. The key part of the architecture is the attention 00:29:28.120 |
between data points. So in terms of how you actually want to embed each attribute, it's 00:29:46.040 |
OK, awesome. Cool. So here we have attention between data points done. So we can also do 00:29:54.080 |
this attention between attributes. So we reshape back to this n by d by e representation, and 00:30:01.640 |
then we can just apply self-attention independently to each row, in other words, to a single data 00:30:06.760 |
point. And the intuition for why we would kind of do this nested-type idea where we 00:30:11.320 |
switch between attention between data points and attention between attributes is just we're 00:30:15.200 |
trying to learn better per-data-point representations for the between-data-point interactions. This 00:30:21.440 |
is literally just normal self-attention, as you'd see in language modeling or image classification. 00:30:31.080 |
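And a matching sketch of this attention-between-attributes step (same caveats as before: illustrative shapes, no residuals or normalization):

```python
import torch
import torch.nn as nn

n, d, e = 8, 4, 16
H = torch.randn(n, d, e)                 # output of the between-datapoints step

# Attention between attributes (ABA): treat each datapoint as its own "sentence"
# of d attribute tokens and self-attend within it, independently for each row.
aba = nn.MultiheadAttention(embed_dim=e, num_heads=4, batch_first=True)
H, _ = aba(H, H, H)                      # (n, d, e); rows are processed independently
```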
And finally, we just rinse and repeat. So what are we actually getting out of this? 00:30:36.720 |
To summarize, we're learning higher-order relationships between data points. We're learning 00:30:40.920 |
transformations of individual data points. And then importantly, NPT is equivariant to 00:30:46.640 |
a permutation of the data points. This basically just reflects the intuition that the learned 00:30:52.080 |
relationships between the data points should not depend on the ordering in which you receive them. 00:31:01.200 |
The third key component of NPTs is a masking-based training objective. So recall that what we're 00:31:08.000 |
trying to do is we're trying to predict missing entries from observed entries. And those masked 00:31:13.000 |
values can be both features or targets. So again, the classic use, say, in masked language 00:31:18.320 |
modeling is to do self-supervised learning on a sequence of tokens, which you could think 00:31:22.320 |
of as just having features in our setting. Ours is a bit different in that we do stochastic 00:31:28.960 |
feature masking to mask feature values with a probability p sub feature. And then we also 00:31:33.320 |
do this masking of training targets with this probability p sub target. 00:31:39.540 |
So if we write out the training objective, we are just taking a weighted sum of the negative 00:31:45.280 |
log likelihood loss from targets as well as from features. And of course, at test time, 00:31:50.000 |
we're only going to mask and compute a loss over the targets of test points. 00:31:55.220 |
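Written out, the training objective described here looks roughly like the following (the weighting symbol lambda and the exact form are our shorthand for "a weighted sum", not necessarily the paper's notation):

```latex
\mathcal{L}
  \;=\; (1 - \lambda)\,
        \underbrace{\bigl[-\log p_{\theta}(\text{masked targets} \mid \text{observed entries})\bigr]}_{\text{target loss}}
  \;+\; \lambda\,
        \underbrace{\bigl[-\log p_{\theta}(\text{masked features} \mid \text{observed entries})\bigr]}_{\text{feature-masking regularizer}}
```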
So to break this down a bit further and point out some of the cool parts of it here, the 00:32:00.760 |
thing that's highlighted right now on the far right is the term relating to the features. 00:32:05.280 |
It's the feature masking. Basically, we find that this has a nice regularizing effect. 00:32:11.760 |
More or less, the model can now be asked to predict anywhere, which makes the task a bit harder and introduces 00:32:15.240 |
some more supervision. And we found in an ablation for the tabular data sets that it 00:32:19.080 |
helped for eight of 10 of those. And then there's this other term, which is kind of 00:32:23.880 |
interesting. It's the stochastic target masking. And the idea is that you're actually going 00:32:30.060 |
to have some training targets unmasked to the model at input at training time, which 00:32:36.920 |
means that the NPT can learn to predict the masked targets of certain training data points 00:32:42.400 |
using the targets of other training data points as well as all of the training features. 00:32:47.680 |
And so that means you don't actually need to memorize a mapping between training inputs 00:32:52.240 |
and outputs in your parameters. You can instead devote the representational capacity of the 00:32:56.240 |
model to learn functions that use other training features and targets as input. 00:33:01.600 |
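A small sketch of what that stochastic masking could look like during training (illustrative PyTorch; the probabilities and the single-target-column layout are assumptions made for the example):

```python
import torch

n_train, d = 6, 4                      # toy training set: d-1 features + 1 target column
p_feature, p_target = 0.15, 0.5        # masking probabilities (made-up values)

mask = torch.zeros(n_train, d, dtype=torch.bool)
mask[:, :-1] = torch.rand(n_train, d - 1) < p_feature   # stochastic feature masking
mask[:, -1] = torch.rand(n_train) < p_target            # stochastic target masking

# Unmasked training targets stay visible at input, so the model can learn to predict
# a masked target from *other* rows' features and targets; the loss is only
# backpropagated on entries that were actually masked.
loss_weight = mask.float()             # zero weight wherever the true value was visible
```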
So this is kind of getting into the idea of this sort of like learn KNN idea. Obviously, 00:33:06.380 |
we can learn more complex relational lookups and those sorts of things from this. But you 00:33:12.320 |
can imagine one such case being we have a bunch of test data points coming in. We're 00:33:18.160 |
going to look at their features and use that to assign them to clusters of training data 00:33:22.840 |
points. And then our prediction for those points is just going to be an interpolation 00:33:27.220 |
of the training targets in that respective cluster. That's like an example of something 00:33:31.400 |
that this mechanism lets NPTs learn. All right. So if there's any questions, we can take them 00:33:40.640 |
now. Otherwise, I'm happy to take them in the discussion or something. All right. So 00:33:54.000 |
I'm curious, when you're using the entire data set, does that limit the type of data sets you can use? 00:34:03.120 |
Yeah. So in practice, we do random mini batching as an approximation. So the idea is just to 00:34:10.600 |
say, if you have a reasonably large mini batch, you're going to benefit a bit from still having 00:34:15.560 |
kind of this lookup ability. Because if you have a reasonable number of classes, probably 00:34:19.760 |
you're going to be able to learn some interesting mappings based on features and targets amongst 00:34:25.200 |
those classes. We found in practice that-- and we'll get into this a little bit. But 00:34:30.720 |
we do actually indeed learn to use relationships between data points on prediction for data 00:34:38.160 |
points where we're doing mini batching. And we also didn't necessarily find that you need 00:34:42.360 |
like a ludicrously large batch size for this to be a thing. But I do think it's just-- 00:34:48.160 |
this is, in general, an important point. And it's one that points us towards looking into, 00:34:52.240 |
say, sparse transformers literature for trying to expand to some larger data sets without 00:35:02.800 |
If I can add a number to that, we can, without mini batching, accommodate data sets of around 00:35:10.120 |
like 8,000 points or so. So that already accounts for a fair proportion, I would say, of the 00:35:17.720 |
tabular data sets out there. But we also do data sets with 11 million points where, obviously, 00:35:22.760 |
we then resort to mini batching. So it's very good to have an idea of the sizes that we're working with. 00:35:29.440 |
I'm curious on that. I mean, it's pretty exciting. I feel like you don't normally hear about 00:35:35.000 |
transformers being applied to data sets of size 8,000. I'm curious-- and we can talk 00:35:42.360 |
about this sort of later once we've covered the other material-- if you found that sample 00:35:45.640 |
efficiency is one of the key gains here, or just experience working on small data sets 00:35:50.640 |
of transformers generally. And yeah. But I'm happy to punt the answer to that until after 00:35:57.760 |
Yeah, I think that'd be really nice to talk about a bit. And it was something that, in 00:36:02.560 |
general, I guess I'd say was surprising to us in terms of how robust NPTs were on small 00:36:07.600 |
data sets, and how we surprisingly didn't have to tune a terrible number of parameters. 00:36:17.560 |
Awesome. So to get into the experiments, we focused a lot on tabular data because it's 00:36:26.440 |
a very general setting. And it's also notoriously challenging for deep learning. So we know 00:36:32.160 |
tree-based boosting methods, stuff like XGBoost, is very dominant. And this is also a very 00:36:37.920 |
relevant domain to, I think, people in industry and that sort of thing. So we were excited 00:36:42.260 |
about the idea of trying to do better on this. 00:36:44.940 |
So we chose a broad selection of data sets varying across a few different dimensions. 00:36:51.120 |
As we mentioned, on the order of hundreds to tens of millions of instances, broad range 00:36:57.160 |
in the number of features, in the composition of features in terms of being categorical 00:37:00.880 |
or continuous, various types of tasks, binary and multi-class classification, as well as 00:37:06.000 |
regression. And like I said, the baselines were kind of the usual suspects for tabular 00:37:09.240 |
data-- XGBoost, CatBoost, LightGBM, tuned MLPs, and TabNet, which is a transformer architecture for tabular data. 00:37:19.360 |
So to get into the results, here I'm showing the average rank for the various subtasks. 00:37:25.720 |
We did well in terms of rank-wise performance against methods like CatBoost and XGBoost, 00:37:30.600 |
which are designed specifically for tabular data. And in fact, we find that NPT is the 00:37:35.260 |
top performer on four of the 10 of these data sets. 00:37:38.920 |
On image data, I mentioned that we used a CNN encoder. And with that, we were performing 00:37:43.560 |
well on CIFAR-10. And we also think that, in general, with, let's say, new work on image 00:37:49.440 |
transformers on small data, this can probably just be done with linear patching. And so 00:37:54.000 |
this kind of-- the manner in which you're embedding things is probably not the key. 00:37:59.000 |
Neil, if I can jump in with two questions. Can you go back two slides first? One is just 00:38:08.080 |
a small, minor point. Back one more, please. Thank you. Here, the features: 50-plus? 00:38:18.160 |
I'll have to double-check what the exact number is. I'm pretty sure it's probably around 50, 00:38:23.240 |
Ah, so the 50 is really an order of magnitude. It's not like 150 or 5,000. 00:38:29.320 |
Yes. Yeah. I mean, I'll double-check for you, or you can check with the metadata statistics 00:38:35.600 |
at the end of the paper. But no, it wasn't arbitrarily large. I would say, though, we 00:38:43.480 |
did these ablations on whether or not we actually need attention between attributes. We did 00:38:49.040 |
find that this ended up benefiting us. But you could perhaps do, say, just an MLP embedding 00:38:57.680 |
in that dimension and go to a relatively small number of hidden dimensions and fit an arbitrary 00:39:03.280 |
number of features. So I think that, yeah, if you kind of relax the necessity of attention 00:39:10.960 |
between attributes, you can probably scale out at least that dimension quite a lot. 00:39:15.240 |
OK. And then my second question, if you could go forward one slide. Thank you. Here, I'm 00:39:22.360 |
not sure I quite caught. What is it four of 10 data sets, two of 10 data sets and four 00:39:29.120 |
This is of all the tabular data sets that we had. So-- Oh, I see. OK. Yeah, exactly. 00:39:41.280 |
The standard errors here, because I mean, there's like there's just 10 data sets, right? 00:39:47.160 |
Yeah, correct. 10 total tabular data sets. Yeah. But these are right. Yeah, these are 00:39:53.200 |
rank-wise performance. Correct. OK, I'm just trying to see where the uncertainty 00:40:01.920 |
comes from in this case. Yeah, it's averaged over, say, four of the 10 data sets, the rank. So for 00:40:09.040 |
each particular data set, we have a rank of all the different methods. Then we take the 00:40:11.960 |
average and the variance of the rankings within each of the types of task, 00:40:20.360 |
within binary classification, within multiclass, et cetera. 00:40:24.840 |
We also, if you're curious, you know, have have the full results in the paper. Yeah. 00:40:28.920 |
Thank you. We also have a couple of questions. Some of the questions on the. 00:40:36.760 |
Hey, yeah, thanks. I guess I just found it a little surprising that the worst performer 00:40:43.840 |
was KNN, given that it's also nonparametric, I guess, could you comment on that? And is 00:40:51.560 |
yeah. Is it that there's something like intrinsic to the NPT that makes it just exceptional 00:40:57.600 |
far beyond other nonparametric methods? Or yeah, why? Why is it that KNN performs the worst? 00:41:05.560 |
Well, I suppose ultimately KNN is still a relatively naive predictive method, 00:41:11.720 |
in that, you know, it might just be predicting based on kind of cluster means. So for example, 00:41:18.600 |
you know, I think this is probably universally true for all the data sets, but there's probably 00:41:21.720 |
some amount of kind of additional reasoning that needs to occur over the features, at 00:41:25.880 |
least to basic level. So for example, like one of the data sets is this poker hand data 00:41:29.480 |
set where it's like a mapping between all of the different hands you have in poker and 00:41:33.600 |
what they're commonly known as to people, like full houses or whatever. 00:41:37.160 |
So this, this requires some amount of reasoning over the features to be able to group things 00:41:41.020 |
together. So just taking like the cluster means of the featurization of those different, 00:41:46.920 |
you know, hands is likely not going to give you a great predictive function. Whereas NPTs 00:41:52.600 |
can kind of do the classic thing where say you have an MLP type of thing over the features 00:41:57.640 |
or like a tree type of thing over the features, you can learn some sort of complex embedding, 00:42:02.440 |
but then you also can do some nonparametric sort of prediction based on say like clusters 00:42:09.040 |
I see. Yeah, that makes sense. I guess. What if, what if you used pre-trained embeddings 00:42:16.180 |
from a stack of encoders as, as your vector representation for the CNN, how do you think 00:42:23.020 |
that would perform compared to the rest of the crowd? 00:42:26.020 |
Yeah. Yeah. So this is like, I mean, this idea is kind of like deep kernel learning 00:42:30.740 |
or like, yeah, I believe it is deep kernel learning: basically, you use an MLP independently. 00:42:38.980 |
So you learn an MLP on each input data point, and then you apply a GP over all the representations 00:42:44.580 |
of those. So you get this sort of like complex embedding and then the lookups. The key difference 00:42:49.260 |
between that type of idea and NPTs is that we also learn the relationships between the 00:42:54.060 |
data points themselves, because we use this parametric attention mechanism to learn the 00:42:58.460 |
relationship. So we're not just learning like an embedding independently. We're basically 00:43:03.920 |
backpropagating through the entire process, learning the ways in which we would try to 00:43:07.500 |
embed this, but also the ways that, say, the lookup would occur. And essentially 00:43:13.340 |
the relationships could potentially be kind of higher order as well. 00:43:17.540 |
Okay, cool. Wait, one more, one more follow-up or. 00:43:25.300 |
Cool. Yeah. Thanks. So I guess then if, if the advantage of NPT has to do with sort of 00:43:31.940 |
the relationships between data points, then what if you, you know, took the, took the, 00:43:41.100 |
let's say, you know, encoder representations and then you pass that as input, say for the, 00:43:48.680 |
you know, 10 nearest neighbors, along with like some other like input representation 00:43:56.460 |
and sort of had this like weighted average like attention style where you, you weighted 00:44:02.180 |
the, the vectors of the nearest neighbors based on the attention weights between those 00:44:11.020 |
input data points. And then like the supplied input data point, and then like pass that 00:44:17.180 |
as, as you know, the, the vector to like the final prediction layer, like, do you think 00:44:24.220 |
that captures some amount of the relationship or, or is that off base? 00:44:30.900 |
So I think the nice part, like, and really what our idea is behind this whole thing is 00:44:35.500 |
just these sorts of instances where certain fixed kernels would perform particularly well 00:44:40.900 |
in tasks is kind of an annoyance, and like ultimately tuning a lot of these 00:44:44.980 |
types of things, or trying to derive the predictive methods that might make a lot of sense for 00:44:49.300 |
a given situation kind of stinks. And ideally you'd want to just back propagate on a data 00:44:54.060 |
set and kind of learn these relationships yourself. So I actually would be really interested 00:44:58.180 |
to see if we can come up with some synthetic experiments that have these sort of like very 00:45:02.700 |
particular KNN-like predictive mechanisms and just see if we can learn precisely those 00:45:07.240 |
and get, you know, zero error with NPTs. And in fact, like we'll get into this a little 00:45:12.340 |
bit with some of the interventional experiments we do, we have like kind of precise lookup 00:45:17.300 |
functions that NPTs end up being able to learn, so we can learn interesting relational functions. 00:45:27.140 |
Cool. All right. I have one more question from. Yeah, sure. 00:45:34.260 |
I just wanted to clarify something about basically so at test time, you just take the exact same 00:45:40.940 |
data set and you just like add like your test examples, right? And then you like do the 00:45:46.420 |
same type of like masking. And is that how it works? 00:45:50.620 |
Yeah, correct. OK, got it. And I do have one more question. 00:45:55.300 |
That is just because I'm I think I misunderstood like how like the effects of your NPT objective. 00:46:02.700 |
Do you mind going back to that slide? Sure. Yeah. Can you repeat one more time? Like what 00:46:13.180 |
makes this so special? Yeah. So the the regularizer on the right 00:46:19.060 |
over the features, I would think of very similarly to self-supervised learning with just a standard 00:46:24.780 |
transformer like you're you're basically just introducing a lot more supervision and you're 00:46:29.380 |
even if, say, you're just doing a supervised objective, this is kind of like some amount 00:46:33.420 |
of reconstruction over the features. You learn a more interesting representation and like 00:46:37.620 |
what a regularizing effect, which we think is interesting, but perhaps not as interesting 00:46:41.940 |
as this stochastic target masking. This one is unique because in kind of standard parametric 00:46:48.220 |
deep learning, you're not going to have an instance in your training process where you're 00:46:53.060 |
taking targets as input. And so basically what happens is you have your training data 00:47:00.940 |
set as input, whatever, you're going to have some stochastic feature masking stuff happening 00:47:04.820 |
on the features amongst the training targets. You're randomly going to have some of those 00:47:10.380 |
unmasked and some of them will indeed be masked. You're going to be backpropagating a loss 00:47:14.380 |
on the ones that are masked, of course, because, you know, you don't want your model to have 00:47:18.060 |
those available at input if you're going to actually try to backpropagate a loss on it. 00:47:22.620 |
But you can use the other ones as input. And that means you can learn these kind of like 00:47:25.740 |
interpolative functions. So that was like this whole idea of like being able to kind 00:47:31.860 |
But doesn't that allow the model to cheat again? 00:47:36.620 |
Yeah. So this is like an interesting point and actually like subtle. So I think it's 00:47:41.940 |
really worthwhile to bring up. So first of all, we never actually backpropagate a loss 00:47:48.260 |
on something that was visible to the model at input. And so if, for example, the model 00:47:54.740 |
did actually end up basically overfitting on training labels, we would not observe the 00:47:59.260 |
model's ability to generalize to test data. We don't observe this. So obviously, it seems 00:48:04.900 |
like this kind of blocking of backpropagation on labels that are visible at input to the model is working. 00:48:13.380 |
It could also be possible that in BERT style stochastic masking, you also randomly will 00:48:18.740 |
flip some labels to be in a different category. So this is like kind of just like a random 00:48:24.940 |
fine print that was introduced in the BERT masking text. We also do that. So it's possible 00:48:29.780 |
that that somehow contributes to that. But it's probably pretty likely to just be the 00:48:35.740 |
fact that we're not backpropagating a loss on something that's visible. 00:48:47.900 |
Sorry. Can we go to the metrics, the performance, the results slide? 00:48:54.220 |
I feel like I missed something else. I'm sorry about this. So looking on the binary classification 00:48:59.820 |
AUROC, can you clarify what these numbers mean? Are they the AUROC? 00:49:08.860 |
So this is the, so on each of the data sets, so say for a particular binary classification 00:49:15.740 |
data set, we're going to get a ranking of the methods. We're going to repeat this. Yeah. 00:49:23.980 |
So these numbers here are the relative ranking across in this particular case, the four data 00:49:30.500 |
I see. So this, these values are not the AUROCs on average across the data sets. 00:49:38.540 |
I mean like averaging, averaging AUROC might make sense, but averaging things like accuracy 00:49:44.500 |
and RMSE seems like a bad idea, right? Because you might have some data sets where everything 00:49:49.500 |
has high accuracy or where RMSE means something drastically different. 00:49:53.540 |
I see. So this, these numbers here only tell us the relative ranking between the different 00:49:57.660 |
methods, not how well they actually perform. I mean, it tells us how they perform relative 00:50:02.220 |
to one another, but not how well they perform. I see. 00:50:06.020 |
But that's all in the appendix. We all have, we have that information. 00:50:08.340 |
I see. Okay. I was, I was sitting here confused going like, why is AUROC, why is the best 00:50:12.360 |
one the smallest? And accuracy, what is an accuracy of 2.5? Anyways. Okay. That makes sense. 00:50:21.060 |
Awesome. Great. So I'll try to speed through this just in the interest of time. But the 00:50:29.460 |
basically thing, the thing that you might be thinking after all of these results is 00:50:33.580 |
are we even learning any data point interactions on these real data sets? 00:50:37.820 |
And so basically we designed an experiment to figure this out. And the idea is that we're 00:50:42.140 |
going to disallow NPT from using other data points when predicting on one of them. If 00:50:47.820 |
we do that and we observe that NPT actually predicts or performs significantly worse, 00:50:53.580 |
it is indeed using these interactions between data points. A subtle challenge or it's kind 00:50:59.140 |
of like an added bonus we can get from this is that ideally we wouldn't actually break 00:51:03.860 |
batch statistics. So let's say like the mean of each particular attribute. If we can find 00:51:09.160 |
a way to do this experiment such that we don't break these things, we can kind of rule out 00:51:13.980 |
the possibility that we learned something that's a bit similar to batch norm. 00:51:18.260 |
And so the way that we do this is we basically look at the predictions for each one of the 00:51:22.380 |
data points in sequence. So let's say in this case, we're looking at the prediction of the 00:51:26.820 |
model for this particular green row. And it's going to be predicting in this last column 00:51:30.860 |
that has this question mark which is masked. What we're going to do is we're going to permute 00:51:34.900 |
each of the attributes independently amongst all other data points except for that one. 00:51:39.360 |
So the information for that row, if it was kind of just predicting like a classic parametric 00:51:43.620 |
deep model is still intact, but the information from all of the other rows is gone. So that's 00:51:48.980 |
why we call this sort of the corruption experiment. And so we find in general, when we perform 00:51:54.540 |
this experiment, performance kind of falls off a cliff for the vast majority of these 00:51:58.940 |
methods. And I'll note that the performances between the methods on a lot of these were 00:52:03.580 |
fairly close. And so this is actually indeed pretty significant. So for example, on protein, 00:52:09.260 |
we went from being the top performer amongst all the methods to the worst performer worse 00:52:12.420 |
than even like KNN or something like that. I'll also note that there's kind of this interesting 00:52:18.700 |
behavior where on these data sets like forest and kick and breast cancer, we actually observed 00:52:23.740 |
that there's basically no drop in performance. And we basically see this as kind of an interesting 00:52:28.660 |
feature and not necessarily a bug of the model, which is that if we're backpropagating on 00:52:34.020 |
a given data set, the model can sort of just find that it's actually not that worthwhile 00:52:38.660 |
to attempt to predict using some kind of relational predictive mechanism amongst data points and 00:52:43.940 |
can instead just learn to predict parametrically and basically ignore other data points when 00:52:48.980 |
it's predicting on any given one of them. And so this probably leads to some kind of 00:52:53.700 |
like interesting ideas where perhaps you could do like post hoc pruning or something like 00:52:57.020 |
that, taking away the attention between data points and doing fine-tuning, let's say.
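A sketch of that corruption procedure (our illustrative code for the idea described above, not the authors' implementation):

```python
import torch

def corrupt_other_rows(X, row):
    """Independently permute each column across every row except `row`, so the
    evaluated row keeps its own features (and column-wise statistics survive),
    but any information carried by the other datapoints is destroyed."""
    Xc = X.clone()
    others = torch.tensor([i for i in range(X.shape[0]) if i != row])
    for j in range(X.shape[1]):
        shuffled = others[torch.randperm(len(others))]
        Xc[others, j] = X[shuffled, j]
    return Xc

X = torch.randn(5, 3)
X_corrupted = corrupt_other_rows(X, row=2)   # row 2's prediction can now only rely on itself
```

Alright, 00:53:05.580 |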
so now I'll hand over to Yannick to talk a bit about learning some interesting relationships. 00:53:10.380 |
Yeah, will you though? I see that we're at the end of the time, but like, I know 00:53:19.060 |
there's a buffer planned in or something. I can go through this experiment, we can have 00:53:24.140 |
a bit of discussion. What do you guys prefer? Yeah, I think normally what we do is we would 00:53:31.580 |
sort of stop the recording at this point and have an off the record discussion. And I guess 00:53:37.860 |
the question to ask is, does anyone have any questions at this point? But I think we've 00:53:48.040 |
basically been wanting questions as they come. So I personally feel fine just considering 00:53:55.700 |
this as sort of questions throughout. Yeah, I guess that sounds good. Yannick, you can 00:54:02.420 |
go forward with it, with your talk as planned and later we can see about the time thing. 00:54:10.780 |
I think this will only be like another four or five minutes talk. Yeah, yeah, that's good. 00:54:16.220 |
Then go for it, yeah, for sure. Alright, so Neil has now told us how well NPTs perform 00:54:24.300 |
on real data and that they do make use of information from other samples in the input. 00:54:29.780 |
But now we're going to take this a bit further and come up with some toy experiments that 00:54:34.300 |
test the extent to which NPTs can learn to look up information from other rows, like 00:54:39.780 |
the extent to which they can learn this nonparametric prediction mechanism. And so specifically 00:54:45.020 |
what we'll do is we'll create the following semi-synthetic data set. I want you to focus 00:54:50.120 |
on A now. So we'll take one of the tabular data sets that we've used previously, specifically 00:54:56.140 |
the protein data set, but it doesn't really matter. What matters is that it's a regression 00:54:59.720 |
data set. And so what we do now is this: the top half here is the original data set, but 00:55:06.460 |
the bottom half is a copy of the original data set in which we have unveiled the true target 00:55:13.500 |
value. So now NPTs could learn to use attention between data points to achieve arbitrarily 00:55:19.580 |
good performance: they could learn to look up the target values in these matching duplicate 00:55:24.860 |
rows and then paste them back into the masked-out target value. And then at test time, of 00:55:31.740 |
course, we put in novel test inputs where this mechanism is also possible, just 00:55:37.780 |
to make sure that the model hasn't learned to memorize anything, but has actually learned the correct 00:55:43.100 |
relational mechanism. 00:55:49.220 |
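As a rough sketch of how such a semi-synthetic lookup data set could be built (the array layout, function name, and mask convention are assumptions for illustration, not the authors' preprocessing code):

```python
import numpy as np

def make_lookup_dataset(X, y, duplicate_offset=0.0):
    """Stack the original rows (targets masked) on top of duplicate rows
    whose true targets are revealed; `duplicate_offset` allows the '+1'
    variant discussed later."""
    n = X.shape[0]
    features = np.vstack([X, X])                          # duplicate every feature row
    targets = np.concatenate([y, y + duplicate_offset])   # targets revealed only in the copy
    # Mask over the target column: 1 = masked (to be predicted), 0 = observed.
    target_mask = np.concatenate([np.ones(n), np.zeros(n)])
    return features, targets, target_mask
```

A model that attends across rows can then get near-zero loss by matching features and copying the revealed target.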
And what we see is that indeed, NPTs do successfully learn to perform this lookup. So what I'm visualizing here are attention maps, and they very clearly 00:55:53.420 |
show that, let's say when predicting for this green row here, this first green row, what 00:55:58.380 |
NPTs look at is exactly only that other green row here. And so this is really nice. We can 00:56:07.260 |
further look at the Pearson correlation between what NPTs should predict and what they actually 00:56:14.820 |
do predict. And so this is 99.9%. This is much better than anything you could achieve 00:56:20.300 |
with parametric prediction. And so it seems that NPTs here can actually discover this 00:56:24.480 |
mechanism. And discover here, I feel like it's the right word because NPTs could have, 00:56:30.740 |
as we've seen, just continued to predict in a parametric fashion, right, from each row 00:56:36.060 |
independently. This is really showing us that there is a bias in the model 00:56:42.380 |
to learn to predict from other rows. And of course, that is also very attractive in this 00:56:48.020 |
setting because it allows you to achieve arbitrarily low loss, or at least as low a loss as you 00:56:53.380 |
can optimize for. And so we kind of take that to mean that our, you know, gradient- 00:57:00.300 |
based discovery, non-parametric philosophy seems to make some sense. And so we can take 00:57:06.580 |
this a bit further by performing somewhat of an interventional experiment that investigates 00:57:11.820 |
the extent to which NPTs have actually learned a robust, you know, causal mechanism that's 00:57:18.100 |
underlying this semi-synthetic data set. And so just appending, you know, this extra column 00:57:28.780 |
of test data, that's already kind of cool, but I think we can take it a bit further and 00:57:33.220 |
actually study if this generalizes beyond the data that we see in the training set or 00:57:37.900 |
beyond data coming from this specific distribution. And so what we now do is we intervene on individual 00:57:44.380 |
duplicate data points at test time by varying their target value. So now we only care about 00:57:50.620 |
the prediction in a specific row. We do this across all rows, but at each time we just 00:57:55.660 |
consider a single row. What we do is change the duplicate's target value, and what we're hoping 00:58:00.780 |
to see is that NPT just adjusts its prediction as well, right? It's a very simple intervention 00:58:06.060 |
experiment to test whether NPTs have actually learned this mechanism. And to some extent 00:58:10.820 |
it also tests robustness because now we're associating target values with features that 00:58:16.060 |
are not part of the training distribution here. And so what we see is that as we adjust 00:58:24.140 |
these values here, this is kind of the duplicate's value, and here we see the predicted 00:58:29.020 |
target value. As we adjust them, we can see the correlation stays really, really good. 00:58:33.380 |
It's not quite 99.9%, like on average, we're now at 99.6, but it's still very, very good. 00:58:40.580 |
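A rough sketch of this intervention loop, assuming the semi-synthetic arrays from the earlier sketch and a hypothetical `predict_fn(features, targets, target_mask)` that returns one prediction per original row (this interface is an assumption for illustration, not the NPT API):

```python
import numpy as np

def intervention_correlation(predict_fn, features, targets, target_mask, rng=None):
    """For each original row, overwrite its duplicate's revealed target with a
    fresh random value and record the model's prediction for that row; return
    the Pearson correlation between intervened values and predictions."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(targets) // 2              # rows 0..n-1: masked originals, n..2n-1: duplicates
    intervened, predicted = [], []
    for i in range(n):
        new_value = rng.normal()       # a target value decoupled from the row's features
        targets_i = targets.copy()
        targets_i[n + i] = new_value   # intervene on this row's duplicate only
        predicted.append(predict_fn(features, targets_i, target_mask)[i])
        intervened.append(new_value)
    return np.corrcoef(intervened, predicted)[0, 1]
```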
And at this point you might be slightly annoyed with me because standard nonparametric models 00:58:47.580 |
can also solve this task. This is a task that I could solve by nearest neighbors. Sure, 00:58:52.900 |
maybe I would have to change the input format a bit because this is kind of like in a batch 00:58:57.020 |
setting and I could just use masks, but most generally a nearest neighbor also 00:59:02.580 |
looks up other input points based on their features. A nearest neighbor doesn't 00:59:08.180 |
learn to do this, though. I still think it's cool that NPTs need to learn this, because it does 00:59:12.100 |
require a decent sequence of computational steps: match all the 00:59:17.900 |
features, look up the target value, copy it back, and so on. 00:59:24.100 |
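For contrast, a minimal sketch of what the 1-nearest-neighbour baseline being described would do on this lookup task (illustrative only):

```python
import numpy as np

def nearest_neighbour_lookup(x_query, X_train, y_train):
    """1-NN lookup: find the training row whose features are closest to the
    query and copy that row's (revealed) target as the prediction."""
    dists = np.linalg.norm(X_train - x_query, axis=1)
    return y_train[np.argmin(dists)]
```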
But it is in fact very easy for us to complicate this task to a degree such that essentially no other model that we know 00:59:29.460 |
of can solve this very easily. And so a really simple thing to do is just to add a plus one 00:59:37.580 |
to all of the duplicate values. So now nearest neighbor would look up the right row, of course, 00:59:45.980 |
but it would always predict the wrong target, off by exactly one. And in fact, many of 00:59:49.980 |
the models that we're aware of are not modeling the joint distribution over features 00:59:56.340 |
and targets; what they're modeling is the conditional distribution of the targets given 01:00:01.580 |
the input features, so they cannot do this either. For us it's really not a problem 01:00:06.860 |
at all: NPTs just learn to subtract the one as well, no problem. 01:00:13.300 |
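Reusing the hypothetical sketches from above, the '+1' variant is just a different offset when building the duplicates; a pure lookup baseline then misses by exactly one, while a model that learns the relation can also learn to undo the offset:

```python
import numpy as np

# Toy stand-in for the regression data (purely illustrative).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 9))
y = X.sum(axis=1)

# Duplicates now reveal y + 1 instead of y (make_lookup_dataset from the earlier sketch).
features, targets, target_mask = make_lookup_dataset(X, y, duplicate_offset=1.0)

# A pure lookup copies the duplicate's revealed target and is off by exactly one.
lookup_pred = nearest_neighbour_lookup(X[0], features[100:], targets[100:])
print(lookup_pred - y[0])  # ~1.0
```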
And sure, this is still a very synthetic setting, but I challenge you to come up with something 01:00:19.900 |
that NPTs can't solve but these other models can. I think, in general, 01:00:25.260 |
this masking mechanism and the non-parametricity of the approach is really nice in general 01:00:31.220 |
and leads to lots of nice behavior in a variety of settings. And so with that, I think we 01:00:36.940 |
can go to the conclusions, which Neil is going to give you. 01:00:41.060 |
Yeah, I think, I mean, we're going to cut out the main part here. I'll just fast forward. 01:00:47.860 |
Just look at them. Yeah, yeah. I was going to say, I think you'll get the gist. NPTs 01:00:55.020 |
take the entire data set as input and they use self-attention to model complex relationships 01:00:58.820 |
between data points. They do well in experiments on tabular data as well as image data. We 01:01:06.300 |
present some of these interventional experiments to show that they can solve complex reasoning 01:01:10.100 |
tasks. There's some more experiments in the paper. I'd say that the interesting type of 01:01:14.820 |
future work is scaling-type things, so that we don't, you know, need this mini-batching 01:01:20.020 |
approximation, and then also just trying to expand this to some more interesting application 01:01:24.060 |
domains. So we talked a little bit about meta-learning, but it could also be things like, you know, 01:01:27.340 |
few-shot generalization in general, domain adaptation, semi-supervised learning, et cetera. 01:01:33.380 |
So I think if there's some more questions, maybe we can do some more discussion. 01:01:38.340 |
Yeah, I think that sounds good. Great. Thanks for the talk. I think everyone had a fun time. 01:01:45.340 |
I will just ask some general questions and then we can have like a discussion session 01:01:49.540 |
with everyone after that. So I think one thing that I noticed is like this, like you said, 01:01:55.580 |
this is similar to like KNNs and I thought like this seems similar to like graph neural 01:01:59.620 |
networks where you can think like each data point is like a node and then you can think 01:02:03.660 |
of everything as a fully connected graph and you're learning some sort of attention weight 01:02:06.820 |
in this graph. So this is like a node prediction task you are kind of doing on this sort of 01:02:11.700 |
like graph structure. So any comments on that? Like, is it similar to like graph neural networks 01:02:16.220 |
or is it like other differences? Yeah, this is a very good observation. Yeah, 01:02:21.860 |
I think there are a lot of similarities to work on graph neural networks. If we want 01:02:26.340 |
to talk about differences, the differences might be that we're kind of assuming a fully 01:02:30.860 |
connected graph, right? And so you could maybe also phrase that as we're discovering the 01:02:36.380 |
relational structure, whereas graph neural networks usually assume that it's given. But that's 01:02:40.940 |
also not always true. And so there are a lot of similarities. I don't know, Neil, if there 01:02:45.660 |
was something specific you would like to mention, go ahead. But it's a very good observation 01:02:49.980 |
and we also do feel that that's the case. And we've added an extra section on related 01:02:54.660 |
work to graph neural networks in the updated version of the paper that will be online soon. 01:03:00.620 |
Yeah, I agree with everything you've said. I think the closest work from the GNN literature 01:03:07.540 |
that we were looking at a little bit was this neural relational inference paper, which uses 01:03:11.260 |
message-passing neural networks to try to kind of like learn edges that may or may not 01:03:17.540 |
exist and help for like extrapolating, I think, positions of like particles in like a multi-particle 01:03:23.600 |
system or something, which is like kind of a similar idea to us. Like, you know, if you 01:03:27.700 |
don't have these edges as given, the attention mechanism could kind of approximate an interesting 01:03:31.900 |
relationship amongst some interacting things. I see. Got it. Yeah, that's really cool. Another 01:03:39.380 |
thing is like, so you mostly look at tabular data, but can you also like have other 01:03:43.020 |
modalities, like if you want to do language or something, can you still use non-parametric 01:03:47.580 |
transformers? Yeah, so I think part of our motivation for doing tabular was because we 01:03:54.580 |
felt like tabular data is, in a sense, a generalization of, let's say, the language data, for example. 01:04:00.420 |
I mean, I guess there's these other notions that people have brought up, like padding, 01:04:06.300 |
but ultimately you can think of it as like a bunch of categorical attributes. So it is 01:04:11.700 |
definitely generalizable to things like sentences and we do, you know, images. So, yeah. I think 01:04:19.500 |
actually like, I always go back and forth on whether or not I think smaller or larger 01:04:25.660 |
data is more interesting for us. So I think small data is really interesting because we 01:04:29.740 |
can just fit the entire data set into it and all of this just works out of the box without 01:04:35.620 |
any extra thought. But large data is actually also really interesting because, sure, you 01:04:41.740 |
might have to introduce some approximative mechanism or some lookup mechanism because 01:04:45.740 |
you can't always have the entire data set in. But at the same time, you are very explicitly 01:04:51.740 |
kind of trading off the compute that you use to look things up against the capacity you need 01:04:57.700 |
for storage. Like, how many parameters in GPT are used for storing data, right? There's 01:05:03.460 |
lots of memorization happening in these models and we know that. And so maybe we can use 01:05:08.580 |
the parameters more efficiently to learn lookup-type behavior, right? That is closer to 01:05:13.780 |
this, you know, neural KNN or whatever. So I think these are very exciting questions. 01:05:18.340 |
Yeah, yeah. I'll also be looking forward to the future work because it seems like a very 01:05:23.380 |
good way to handle one-shot learning kinds of situations. So, yeah, it's really very interesting 01:05:28.700 |
to see that. Okay, so I will stop the recording and then we can take any other questions.