
Stanford CS25: V1 | Self-Attention and Non-Parametric Transformers (NPTs)


Chapters

0:00 Introduction
3:25 Fast Autoregressive Decoding
14:01 Motivation and a Brief Summary
15:06 Parametric Prediction
16:02 Non-Parametric Transformer Architecture
23:40 Three Key Components of NPTs
24:44 Dataset Matrix
25:08 Masking Matrix
26:50 Linear Embedding
31:03 Masking-Based Training Objective
32:24 Stochastic Target Masking
37:20 Results
41:28 Poker Hand Data
42:32 Deep Kernel Learning
54:45 Semi-Synthetic Dataset
55:51 Attention Maps


00:00:00.000 | >> Thanks so much. Great to be here and happy Halloween, belated Halloween everyone. So
00:00:10.880 | I think the talk is going to be split into two sections. So I'll start by spending like
00:00:16.000 | 10 minutes, 15 minutes chatting about transformers in general. But I'm assuming most of you are
00:00:22.120 | familiar with them and we can move on to NPTs, which Yannick and Neil will be presenting.
00:00:28.960 | So let's see. I'm going to like try to fly through the transformer overview and maybe
00:00:36.520 | spend a little bit extra time on like the history of transformers and maybe just tell
00:00:42.200 | the story a little bit. I think that might be more interesting. So just in terms of the
00:00:49.640 | transformer architecture, the two kinds of things that it introduced for the first time
00:00:54.520 | were multi-head attention and self-attention. And then it combined those with fast autoregressive
00:01:00.880 | decoding. So before the transformer, pretty much everyone was using LSTMs and LSTMs with
00:01:07.560 | attention. But I'll try to get into the difference of self-attention, multi-head attention. So
00:01:16.520 | originally you would have two sequences and you would have an attention module, which would
00:01:23.160 | attend from the source to the target. And so each token or each word in the source sequence
00:01:29.120 | would get associated with, you know, a soft approximation of one element in the target
00:01:34.800 | sequence. And so you'd end up with something like this. But with self-attention, we did
00:01:41.680 | away with the two separate sequences. We make them both the same. And so you're relating
00:01:47.200 | each element within the sequence to another element in the sequence. And so the idea here
00:01:56.320 | is that you're learning a relationship of the words within a sentence to the other words.
00:02:01.080 | So you can imagine something like an adjective, which is being applied to a noun. And so you
00:02:06.240 | want to relate that adjective, like the blue ball, you want to relate blue as referring
00:02:11.120 | to ball. So we're learning patterns within the sequence, intra-sequence patterns. So
00:02:19.480 | sorry, I gave this talk in Kenya, so I am using Kiswahili here. But with multi-head
00:02:27.160 | attention, the idea is you have like each word represented by an embedding, which is
00:02:32.200 | in the depth dimension here. And then you have your sentence of words. You split that
00:02:36.720 | up into a bunch of different groups. So here I've chopped it depth-wise into four groups.
00:02:43.080 | You apply attention to each one of these groups independently. And then when you get the result
00:02:49.400 | back, you can concatenate them together and you're back to your model dimension representation.
00:02:56.760 | So what this lets you do is each attention head can now focus
00:03:02.600 | on learning one pattern. So maybe attention head one is learning the relationship of adjectives
00:03:09.240 | to nouns. And the second attention head can learn something different. So this lets us
00:03:15.600 | learn like a hierarchy or a list of different relationships. Okay. So that was self-attention.
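A minimal sketch of that depth-wise split (PyTorch; the four-head choice and all names are illustrative, and the learned query/key/value projections of a real transformer are omitted for brevity):

```python
import torch

def multi_head_self_attention(x, num_heads=4):
    # x: (seq_len, d_model). Chop the embedding depth-wise into
    # num_heads groups, attend within each group independently,
    # then concatenate the results back to the model dimension.
    seq_len, d_model = x.shape
    d_head = d_model // num_heads
    heads = x.view(seq_len, num_heads, d_head).transpose(0, 1)  # (h, seq, d_head)
    scores = heads @ heads.transpose(-2, -1) / d_head ** 0.5    # (h, seq, seq)
    out = scores.softmax(dim=-1) @ heads                        # (h, seq, d_head)
    return out.transpose(0, 1).reshape(seq_len, d_model)        # concatenate heads

x = torch.randn(6, 32)                     # 6 words, 32-dim embeddings
print(multi_head_self_attention(x).shape)  # torch.Size([6, 32])
```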
00:03:24.560 | The other piece is fast autoregressive decoding. And do I really want to go into this? Okay,
00:03:33.740 | I will. So the important thing about this is that if you're doing normal autoregressive
00:03:40.100 | decoding, what you do is you generate your first token. And now conditioned on that first
00:03:44.000 | token, you generate the second and conditioned on the first two, you generate the third and
00:03:47.760 | so on and so forth. But that's super slow, right? Like it's a loop applying this thing
00:03:52.120 | again and again. And so what we can do instead is we make an assumption in the code that
00:03:59.200 | our model always generates the right thing. And then we generate a prediction, only one
00:04:07.360 | token ahead. So you have your outputs, which are Y, you have your targets, which are Y
00:04:26.760 | hat. And what you do is you feed in those gold targets so that you don't need to actually
00:04:34.480 | do this loop. So instead of having to generate the first token, feed
00:04:39.360 | it back into your architecture, generate a second token, you feed in the entire target
00:04:44.360 | sequence and you just pretend that you generated all the right tokens up to position K. And
00:04:50.520 | then you predict the K plus first and you compute your loss on that. So in reality,
00:04:56.440 | your model might've generated junk at the beginning of training, but you're getting a loss
00:05:01.880 | as if your model had seen all the correct tokens and is now just predicting the next
00:05:07.240 | one. This is a little bit subtle, but it's hugely impactful for training speed because
00:05:13.280 | all of this can be done in parallel. And so it's actually what makes transformers
00:05:18.280 | so scalable. Okay. So in order to do this successfully, if you were just feeding in
00:05:25.780 | all of the correct tokens, uh, naively, what would happen is your model would
00:05:32.080 | just be able to, uh, look forward in time and cheat. So you've put in all of
00:05:39.360 | your true targets, the things that you're trying to get your model to predict. And so
00:05:44.200 | if that's what you're computing your loss on, it could just look forward in time
00:05:47.240 | and say, okay, I'm just going to grab that. And we'd get zero error trivially, right?
00:05:51.360 | Because you've given it all the right answers. So what we have to do inside the architecture
00:05:56.100 | is we need to actually prevent, uh, the attention mechanism from being able to look at tokens
00:06:02.100 | that it shouldn't have been able to see already. Um, so the way that this looks is you create
00:06:09.460 | a mask on your attention. Um, and so, sorry, this is the example of like doing a trivial
00:06:16.780 | attention. If you don't mask your attention properly, um, what it's going to do is it's
00:06:22.220 | just going to look into the future, just grab the token that you're telling it to predict
00:06:27.460 | and copy it over. And so it learned something trivial, something that doesn't actually generalize.
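A toy version of the fix described next (an additive causal mask; shapes and names are illustrative, not the lecture's code):

```python
import torch

def causally_masked_attention(x):
    # x: (seq_len, d) gold target embeddings, all fed in at once.
    # Position i may only attend to positions <= i, so it cannot
    # simply copy the future token it is being asked to predict.
    seq_len, d = x.shape
    scores = x @ x.T / d ** 0.5                                  # (seq, seq)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))           # block the future
    return scores.softmax(dim=-1) @ x

print(causally_masked_attention(torch.randn(5, 16)).shape)       # torch.Size([5, 16])
```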
00:06:31.540 | And so what we do is we actually prevent it from attending to those tokens. We prevent
00:06:36.100 | it from attending into the future for each position in the source sequence. We block
00:06:42.460 | out, uh, everything that it shouldn't be able to see everything into the future. And then
00:06:47.700 | as we move down, we gradually unblock so it can start to see into the past. Um, so those
00:06:54.340 | are kind of like the three major components of transformers, um, the self-attention,
00:07:03.380 | the multi-head attention, and then this decoding with gold targets, which is fast autoregressive
00:07:11.260 | decoding. In terms of the story, which might be a little bit more interesting. Um, so transformers,
00:07:19.660 | I was an intern, uh, with Lukasz Kaiser, uh, at Google back in 2017. Um, and I was sitting
00:07:26.700 | next to Noam, uh, and Ashish was like a couple of seats down from us. Um, and what's really
00:07:33.900 | incredible is that essentially this entire project came together in like three months
00:07:39.260 | and it was done. So I showed up at Google, uh, Noam had been working on autoregressive
00:07:46.180 | models. Um, same thing with like Ashish and Jakob and Niki. And, um, they'd just been kind
00:07:52.420 | of like exploring the space, figuring it out. Uh, and Lukasz and I, at the same time, we'd
00:07:57.620 | been working on this framework called tensor to tensor, um, which was like explicitly made
00:08:03.340 | for, uh, multimodal learning, autoregressive learning. Um, and Lukasz is kind of like a
00:08:10.900 | master of keeping track of everything that's happening in the field and adopting it. And
00:08:17.740 | so within tensor to tensor, there were like these kind of emerging
00:08:24.820 | little things that maybe one paper had been written about, uh, and people were interested
00:08:29.020 | in like layer norm, um, but it hadn't actually taken off yet. Uh, the warmup in the learning
00:08:34.660 | rate schedule, um, all of these little pieces were just default, like on by default. Um,
00:08:43.420 | and so when Noam and Ashish and Niki and Jakob came over and adopted tensor to tensor,
00:08:50.660 | all of these things were just on by default. Um, and so a lot of people, when they look
00:08:56.460 | at the, the transformer paper, it just seems like there's so many like arbitrary little
00:09:00.940 | things thrown in. And now, like in present day, these have become standard for like a
00:09:06.180 | lot of different training algorithms, like the learning rate warmup, um, the way that
00:09:11.860 | we did initialization, uh, all of these pieces have just become the norm, but back then they
00:09:16.980 | had like kind of just been introduced. Um, and so we, uh, we spent a lot of time running
00:09:24.940 | ablations trying to figure out like, which were the necessary pieces and what made it
00:09:28.420 | work.
00:09:29.420 | And if any of you have actually tried training transformers and tried like pulling out the
00:09:34.580 | learning rate warmup, um, or changing any of these little pieces, you'll see that it
00:09:39.860 | really does break down optimization. Like it actually really does, um, hurt performance
00:09:46.340 | for instance, like removing the layer norms, that type of thing. Um, so I always thought
00:09:52.180 | it was kind of funny how all of these random additions that Lukasz had just like thrown
00:09:57.860 | in, 'cause he was playing around with them, turned out to be crucial. Uh, and they were
00:10:02.260 | just on by default. Um, so anyway, it was like three months. I remember, um, it all
00:10:11.660 | really started coming together towards the end, like just before the NeurIPS deadline.
00:10:16.740 | Um, and I can still remember sitting in the micro kitchen and Ashish telling me like,
00:10:21.500 | that's just like, I was a little intern, uh, telling me like, this is going to be such
00:10:25.620 | a big deal. And I was like, yeah, sure. Okay. I have no idea what's happening. I just showed
00:10:30.620 | up. Um, and he was like, no dude, like this, this actually matters. Like, you know, we,
00:10:36.580 | we bumped up BLEU three points and I was like, sick, great. Anyway, um, and then I can remember
00:10:45.620 | on the night of the deadline for NeurIPS, it was like 2am, uh, and Ashish was the
00:10:54.340 | only one left at the office and we were still like moving around figures and like adjusting
00:10:59.620 | things. And then, um, I went to bed, but Ashish stayed up and I slept in like this tiny little
00:11:04.740 | phone booth. Um, and then for the other paper that I was submitting, I forgot to press submit,
00:11:11.580 | but luckily like some lady opened the door to the phone booth and hit me in the head
00:11:16.740 | while I was sleeping in the morning. And just before the deadline, I got the paper in. And
00:11:22.100 | so I owe it to that lady, uh, for, for submitting to NeurIPS that year. But yeah, anyway, the,
00:11:29.740 | I think the crazy thing about transformers was that it all came together in like three
00:11:32.900 | months. Like most of the ideas happened in that span. And it was just like this sprint
00:11:37.580 | towards the NeurIPS deadline. Um, I think a lot of the other members on the team, Jakob,
00:11:43.220 | Lukasz, Noam, Ashish, they knew how important it was, but for me, I was like, I don't know.
00:11:48.500 | I really did not appreciate, uh, the impact. Um, but in retrospect, it's been amazing how
00:11:55.540 | the community has kind of like come together and adopted it. Um, and I think most of that
00:12:01.860 | can be ascribed to the ease of optimization. It seems very robust to hyperparameter
00:12:07.980 | choices. So you don't need to like tune the hell out of it, spend a lot of time tweaking
00:12:12.220 | little things. Um, and the other side is that it's like super tailored to the accelerators
00:12:17.060 | that we run on. Um, so it's like very parallelizable, um, hyper-efficient. Um, and so it lends itself
00:12:26.220 | to that kind of scaling law effort that's, that's really taken off in, in popularity.
00:12:31.900 | Okay. Uh, unless there are any questions.
00:12:34.620 | That's such a cool story. Oh my God. We're both excited. So we just unmuted at the same
00:12:42.260 | time.
00:12:43.260 | Yeah. So that's, that's my section. If there's any questions, happy to answer them. Otherwise,
00:12:50.700 | uh, let's get into NPTs. NPTs are, I think, um, such a nice next-level
00:12:58.540 | abstraction of the architecture. So you've probably seen the trend of, um, transformers
00:13:04.140 | getting applied to, to new domains, um, first into vision and video and audio. Um, but this
00:13:13.100 | is kind of like cutting back to an even more abstract level. Um, like I think tabular data.
00:13:19.580 | Yeah. I don't know. I'll let Yannick and Neil take over from here, but, uh, I think NPTs
00:13:24.420 | is a pretty sick, pretty sick project.
00:13:28.900 | Thanks for the introduction. Um, thanks all for the invitation. We're very happy to be
00:13:33.740 | here and Neil and I are now going to tell you about our, um, self-attention between
00:13:40.420 | data points paper, where we introduce the non-parametric transformer architecture.
00:13:45.380 | Um, we'll start with a little bit of motivation, move on to explaining the architecture in
00:13:50.260 | detail, then show you the experiments. This is more or less a step through of the paper,
00:13:54.580 | but maybe, you know, with a little bit of extra insight here and there, all right, as
00:14:02.220 | promised: motivation and a brief summary. So, um, we'll start by thinking about something
00:14:08.060 | that we don't often think about. And that is that from CNNs to transformers, most of
00:14:13.340 | supervised deep learning relies on parametric prediction. So what that means is that we
00:14:19.220 | have some self-training data and we want to learn to predict the outcomes Y from inputs
00:14:25.020 | X. And for this, we set up some model with tunable parameters, theta, then we optimize
00:14:31.020 | these parameters to maximize predictive likelihoods on a training set, or, you know, equivalently,
00:14:36.660 | we minimize some loss. And then after training, we have this optimized set of parameters theta
00:14:43.100 | and then at test time, we just put these into the model and use these parameters to predict
00:14:48.340 | on novel test data. And so crucially here, our prediction at test time only depends on
00:14:54.620 | these parameters, right? It's parametric. Also, that means that given these parameters,
00:15:00.340 | the prediction is entirely independent of the training data. And so why would we want
00:15:06.380 | to do parametric prediction? Well, it's really convenient because all that we've learned
00:15:11.260 | from the training data can be summarized in the parameters. And so prediction time, we
00:15:16.300 | only need these final parameters and we do not need to store the training data, which
00:15:19.820 | might be really, really large. On the other hand, we usually have models that already
00:15:25.380 | predict for a bunch of data in parallel, right? Think of mini-batching in modern architectures.
00:15:30.920 | And actually things like batch norm already make these data interact. And so our thinking
00:15:37.100 | here was that if we've got all of this data in parallel anyways, there's no reason not
00:15:41.780 | to make use of it. And so, a bit more grandly, we kind of challenge parametric prediction
00:15:48.460 | as the dominant paradigm in deep learning. And so we want to give models the additional
00:15:53.780 | flexibility of using the training data directly when making predictions. And so a bit more
00:16:00.340 | concretely, we introduced the non-parametric transformer architecture. And this is going
00:16:06.220 | to be a general deep learning architecture, meaning we can apply it to a variety of scenarios.
00:16:12.420 | NPTs will take the entire data set as input whenever possible. And NPTs then crucially
00:16:19.540 | learn to predict from interactions between data points. And to achieve this, we use multi-head
00:16:26.140 | self-attention that, as Aiden has introduced us to, has just really established itself
00:16:32.260 | as a general purpose layer for reasoning. We also take another thing from the NLP community
00:16:39.500 | and we use a stochastic masking mechanism. And we use that to tell NPTs where to predict
00:16:45.300 | and also to regularize the learning task of it. And lastly, of course, we hope to convince
00:16:50.980 | you that this ends up working really, really well, and that this kind of simple idea of
00:16:55.700 | learning to predict from the other data points of the input, from the training points of
00:17:00.380 | the input, ends up working rather well. And so very briefly summarizing what we've heard
00:17:08.980 | already. A, we input into NPTs the entire data set. And then B, let's say for the purpose
00:17:15.780 | of this slide here, we only care about predicting the orange question mark in that green row.
00:17:22.280 | And then we can compare NPTs to parametric prediction, right? So a classical deep learning
00:17:27.780 | model would predict this target value only from the features of that single green input.
00:17:33.460 | To do that, it would use the parameters theta, those would depend on whatever training data
00:17:37.620 | we've seen and so on. But at test time, we only look at that single row for which we
00:17:42.580 | care about the prediction. In contrast, NPTs predict with an explicit dependence on all samples
00:17:49.060 | in the input. They can look beyond that single green datum of interest and look at all other
00:17:54.200 | samples that are there and consider their values for prediction. So this presents an
00:17:58.440 | entirely different way of thinking about how we learn predictive mechanisms. And somebody
00:18:04.540 | on Twitter called this KNN 2.0, which we would have not written in the paper, but maybe is
00:18:09.760 | kind of a nice way of thinking about how NPTs can learn to predict. So of course, nonparametric
00:18:17.240 | models are a thing already. We didn't invent them at all. And I define them here as prediction
00:18:23.720 | in explicit dependence on the training data, which is certainly what NPTs do. Classical
00:18:29.200 | examples like Gaussian processes, k-nearest neighbor, kernel methods, those might be familiar
00:18:34.320 | to you. And there also exist efforts to combine the benefits of nonparametrics and representation
00:18:40.280 | learning in a similar fashion to how we did it in NPTs. However, these approaches are
00:18:47.220 | usually limited in some sense in comparison to NPTs, right? They're often kind of motivated
00:18:51.960 | from the statistics community a bit more. They often require more finicky approximative
00:18:55.920 | inference schemes, are limited in the interactions they can learn, or things like that. And so
00:19:01.560 | we really think NPTs present maybe the most versatile and most widely applicable of these
00:19:07.760 | nonparametric prediction approaches. But that's something we explicitly wanted to have. We
00:19:12.240 | wanted to have something that's really easy to use, plug and play, works in a ton of scenarios,
00:19:16.940 | and works really well. And so with that, I hand over to Neil, who's going to tell you
00:19:23.380 | about the nonparametric transformer architecture in all of its details.
00:19:27.800 | Yeah, we also have one question. Go ahead. Yes, yes.
00:19:32.360 | Hi, Yannick. Hi. So could you please go back to the previous slide?
00:19:40.680 | The very previous slide? Yeah, yes. This slide, yeah. So in terms of
00:19:45.440 | the problem definition, I think it's quite similar to some meta-learning problem, which
00:19:51.160 | basically there is a mapping from a data point on the data sets to some predictions. So could
00:19:58.480 | you please suggest any differences between your problem setting and the meta-learning
00:20:05.040 | problem settings? I can't really figure out any differences between these two problems.
00:20:11.760 | Well, I think it really depends on the framing that you want to have. So I would say meta-learning
00:20:19.240 | would be when I try to predict over multiple data sets. So when I try to learn some sort
00:20:26.160 | of prediction model, or I can just plug in a different data set, and it will almost automatically
00:20:32.080 | give me new predictions on this different data distribution. But that's not what we
00:20:36.080 | do at all. We're training a single model for a fixed data set. And so this is why I wouldn't
00:20:41.720 | really call that meta-learning, because we're trying to predict on the same tasks that all
00:20:47.160 | the supervised deep learning, or any supervised machine learning method is trying to predict
00:20:53.400 | well on. Okay, so you mean you use kind of same test
00:20:59.920 | set to test your training model, right? So basically, in meta-learning, we're going to
00:21:16.360 | test on different kind of meta-test sets. But in your case, you just want to use a test
00:21:22.800 | set, which is similar to the distribution of your training set, right?
00:21:29.600 | Yeah, absolutely. So we explore data set distribution shift a bit. I think it's a really interesting
00:21:34.520 | scenario. I think meta-learning different data sets is also an interesting scenario,
00:21:40.000 | right, when you have this model where you could just push in different data sets. But
00:21:44.160 | for the scope of this paper, it's very much training set, test set, they come from the
00:21:49.040 | same distribution, and we're just trying to do supervised learning in a standard setting.
00:21:55.720 | I see, cool. Thank you. Thank you for the question.
00:21:59.720 | Yeah, and I would chime in a couple of additional things, I guess. So at least from what I understand
00:22:06.480 | from the problem definition of meta-learning, I think the aim is more perhaps being able
00:22:11.960 | to perform well on a new data set with a relatively small number of additional gradient steps
00:22:16.280 | on that data set. So I think there are some interesting ways that you could actually consider
00:22:20.880 | applying NPTs in a meta-learning type setting. And so we'll get into this a little bit more,
00:22:26.080 | but for example, there might be ways to essentially add in a new data set. So let's suppose we've
00:22:32.280 | trained on a bunch of different data sets. We now add in a new data set. We can perhaps
00:22:36.480 | do some sorts of kind of zero shot meta-learning, basically, where there's no need for additional
00:22:43.800 | gradient steps because we're basically predicting kind of similar to how you might do prompting
00:22:48.520 | nowadays in NLP literature. Anyways, yeah, I think we'll get into some more details.
00:22:54.800 | Just to chime in on that, I don't think that every meta-learning algorithm-- I think the
00:23:04.640 | ones that you're describing right now are optimization-based, but there are also black box ones. You don't
00:23:09.800 | need to further-- I think the main difference seems to be that there is one task versus
00:23:15.960 | multiple tasks for meta-learning. Yeah, I think so too. I think the main framing question
00:23:23.640 | is whether or not there's multiple data sets. Cool. OK, awesome. If there's no other questions,
00:23:33.360 | I'll dive a bit more into the architecture. Awesome. So there's three key components to
00:23:42.920 | NPTs. I'm going to first state them at a high level, and then we'll go through each of them
00:23:46.480 | in more detail. So first of all, we take the entire data set, all data points as input.
00:23:53.440 | So for example, at test time, the model is going to take as input both training and test
00:23:57.880 | data, and we approximate this with mini-batches for large data. We apply self-attention between
00:24:05.560 | data points. So for example, at test time, we model relationships amongst training points,
00:24:10.960 | amongst test points, and between the two sets. And then finally, we have this masking-based
00:24:17.240 | training objective. It's a BERT-like stochastic masking, and the key point is that we actually
00:24:22.680 | use it on both features as well as on training targets. And we'll get into why that leads
00:24:27.660 | to an interesting predictive mechanism later.
00:24:32.040 | So to start with this idea of data sets as input, there's two things that compose the
00:24:37.140 | input to NPT. It's a full data set in the form of a matrix X and a masking matrix M.
00:24:43.880 | And so Yannick has described this data set matrix a little bit. We basically have data
00:24:48.200 | points as rows. The columns are attributes, and each attribute shares some kind of semantic
00:24:53.160 | meaning among all of its data points. So say, for example, you're just doing single-target
00:24:58.260 | classification or regression. The last column would be the target, and the rest of the matrix
00:25:03.560 | would be input features. So for example, the pixels of an image.
00:25:08.520 | We also have a masking matrix. So let's say we're thinking about masked language modeling.
00:25:13.720 | The mask tokens will just tell us where we're going to conceal words and where we're going
00:25:17.760 | to back-propagate a loss. We do a similar type of thing here, where we use this binary
00:25:22.340 | mask matrix to specify which entries are masked. And the goal is to predict masked values from
00:25:29.920 | observed values. I see that there was a question about handling inputs with different lengths.
00:25:38.440 | In the data sets we've considered, we'll get into it in the results section, but it's mostly
00:25:43.360 | been sort of tabular and image data, where the lengths for each of the data points is
00:25:47.840 | the same, but it would work with something like padding. That would be a reasonable way to go about
00:25:51.960 | that. There's also kind of an interesting-- yeah, go for it, Yannick.
00:25:56.500 | Just to add to that, I'm not sure if length refers to columns or to rows, right? Rows,
00:26:02.720 | we don't care about how many rows. Length, padding or something would be an option.
00:26:07.520 | Yeah, my question was about column. Exactly, that makes sense. Awesome, thanks.
00:26:13.160 | Yeah, I mean, that goes along with the whole meta-learning discussion. I think if we wanted
00:26:17.600 | to adapt to data sets that have a different number of data points per data set, we can
00:26:22.360 | take advantage of the fact that self-attention is kind of OK with that. Cool.
00:26:31.560 | So continuing on, left to discuss here is basically how we do the embedding. So to put
00:26:39.440 | this more explicitly, we have this data matrix that has n data points. It's called x, and
00:26:45.280 | it also has d attributes. And we have this binary mass matrix m. We're going to stack
00:26:49.160 | them, and then we're going to do a linear embedding.
00:26:52.800 | So specifically, we're doing the same linear embedding independently for each data point.
00:26:57.720 | We're learning a different embedding for each attribute. We have a positional encoding on
00:27:02.320 | the index of the attributes because we don't really care about, say, being equivariant
00:27:06.000 | over the columns. If it's tabular data, you, of course, want to treat all these kind of
00:27:09.280 | heterogeneous columns differently. And then finally, we have an encoding on the type of
00:27:14.000 | column, so whether or not it's continuous or categorical.
00:27:17.800 | And that ends up giving us this input data set representation that is dimensions n by
00:27:22.760 | d by e. The second key component of NPTs is attention between data points. So to do that,
00:27:33.360 | we first take this representation we have and flatten to an n by d times e representation.
00:27:40.040 | So basically, we're treating each of these d times e size rows as if it's a token representation.
00:27:46.120 | We're actually going to just accomplish this operation using multi-head self-attention.
00:27:50.720 | We've reviewed this a lot, but the nice thing is that we know from language modeling, if
00:27:54.600 | we stack this multiple times, we can model these higher order dependencies. And here
00:27:58.040 | they're between data points, and that's really the key draw of this architecture.
00:28:02.480 | There's been other kind of instances of people using attention for similar sorts of things.
00:28:06.960 | So for example, like attentive neural processes. A lot of times, they've sort of used just
00:28:12.200 | a single layer as kind of a representational lookup. And we believe that this actually
00:28:17.320 | ends up limiting expressivity, and that by stacking this many times, you can learn more
00:28:20.860 | complex relationships between the data points.
00:28:23.680 | Neil, we also have some questions. So you can go ahead first.
00:28:30.280 | Oh, cool. Thanks. I have a question about how you guys do the embedding. Is there always
00:28:35.200 | part of these convolutional filters or linear layers? What is the type of embedding that
00:28:41.520 | you guys use?
00:28:42.520 | Yeah, so I'm attempting to go back to the slide. I think it's not very happy with me
00:28:47.560 | right now. But yeah, so for tabular data, we did just linear embeddings, actually. So
00:28:55.960 | we could get into, I guess, details of featurization for categorical and continuous, but it's literally
00:29:00.120 | like, say, for categorical, you do a one-hot encoding, and then you learn this embedding
00:29:04.440 | that is specific to that attribute. And then for numerical, I believe we were just standard
00:29:09.480 | normalizing. For the image data, we did end up using a ResNet-18 encoder for CIFAR-10.
00:29:17.480 | However, I think that-- I mean, we'll discuss that a bit later in results, but that embedding
00:29:23.160 | is a bit arbitrary. You can sort of do whatever. The key part of the architecture is the attention
00:29:28.120 | between data points. So in terms of how you actually want to embed each attribute, it's
00:29:32.520 | kind of up to you.
00:29:35.840 | Thanks.
00:29:36.840 | I think-- you had a question?
00:29:42.360 | Same question, I was curious.
00:29:46.040 | OK, awesome. Cool. So here we have attention between data points done. So we can also do
00:29:54.080 | this attention between attributes. So we reshape back to this n by d by e representation, and
00:30:01.640 | then we can just apply self-attention independently to each row, in other words, to a single data
00:30:06.760 | point. And the intuition for why we would kind of do this nested-type idea where we
00:30:11.320 | switch between attention between data points and attention between attributes is just we're
00:30:15.200 | trying to learn better per-data-point representations for the between-data-point interactions. This
00:30:21.440 | is literally just normal self-attention, as you'd see in language modeling or image classification.
00:30:26.600 | The attributes are the tokens here.
00:30:31.080 | And finally, we just rinse and repeat. So what are we actually getting out of this?
00:30:36.720 | To summarize, we're learning higher-order relationships between data points. We're learning
00:30:40.920 | transformations of individual data points. And then importantly, NPT is equivariant to
00:30:46.640 | a permutation of the data points. This basically just reflects the intuition that the learned
00:30:52.080 | relationships between the data points should not depend on the ordering in which you receive
00:30:55.800 | them or in which you observe your data set.
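A rough sketch of the pipeline up to this point (PyTorch; every name and size is illustrative, and the real architecture also has residual connections, layer norm, feed-forward sublayers, and per-attribute output heads):

```python
import torch
import torch.nn as nn

class NPTBlock(nn.Module):
    # One sketched NPT block: attention between data points on the
    # flattened (n, d*e) view, then attention between attributes on
    # the (n, d, e) view.
    def __init__(self, d, e, heads=4):
        super().__init__()
        self.between_datapoints = nn.MultiheadAttention(d * e, heads, batch_first=True)
        self.between_attributes = nn.MultiheadAttention(e, heads, batch_first=True)

    def forward(self, H):                         # H: (n, d, e)
        n, d, e = H.shape
        flat = H.reshape(1, n, d * e)             # each data point is one "token"
        flat, _ = self.between_datapoints(flat, flat, flat)
        H = flat.reshape(n, d, e)                 # each attribute is one "token"
        H, _ = self.between_attributes(H, H, H)
        return H

n, d, e = 8, 5, 16                                # 8 rows, 5 attributes, embed dim 16
X = torch.randn(n, d)                             # data set matrix
M = torch.randint(0, 2, (n, d)).float()           # binary mask matrix
embed = nn.ModuleList(nn.Linear(2, e) for _ in range(d))  # one embedding per attribute
H = torch.stack([embed[j](torch.stack([X[:, j], M[:, j]], -1)) for j in range(d)], dim=1)
print(NPTBlock(d, e)(H).shape)                    # torch.Size([8, 5, 16])
```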
00:31:01.200 | The third key component of NPTs is a masking-based training objective. So recall that what we're
00:31:08.000 | trying to do is we're trying to predict missing entries from observed entries. And those masked
00:31:13.000 | values can be both features or targets. So again, the classic use, say, in masked language
00:31:18.320 | modeling is to do self-supervised learning on a sequence of tokens, which you could think
00:31:22.320 | of as just having features in our setting. Ours is a bit different in that we do stochastic
00:31:28.960 | feature masking to mask feature values with a probability p sub feature. And then we also
00:31:33.320 | do this masking of training targets with this probability p sub target.
00:31:39.540 | So if we write out the training objective, we are just taking a weighted sum of the negative
00:31:45.280 | log likelihood loss from targets as well as from features. And of course, at test time,
00:31:50.000 | we're only going to mask and compute a loss over the targets of test points.
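Roughly, the objective might be sketched like this (a sketch only: `lam` stands in for the paper's weighting hyperparameter, and a squared-error term stands in for the per-entry negative log likelihood of a regression target):

```python
import torch

def sample_masks(n, d, p_feature=0.1, p_target=0.5):
    # Stochastically choose which feature entries (first d-1 columns)
    # and which training targets (last column) get concealed at input.
    feature_mask = torch.rand(n, d - 1) < p_feature
    target_mask = torch.rand(n, 1) < p_target
    return feature_mask, target_mask

def masking_loss(pred, X_true, feature_mask, target_mask, lam=0.5):
    # Weighted sum of losses on masked targets and masked features;
    # a loss is only back-propagated on entries hidden at input.
    per_entry = (pred - X_true) ** 2              # stand-in for per-entry NLL
    loss_targets = per_entry[:, -1:][target_mask].mean()
    loss_features = per_entry[:, :-1][feature_mask].mean()
    return lam * loss_targets + (1 - lam) * loss_features
```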
00:31:55.220 | So to break this down a bit further and point out some of the cool parts of it here, the
00:32:00.760 | thing that's highlighted right now on the far right is the term relating to the features.
00:32:05.280 | It's the feature masking. Basically, we find that this has a nice regularizing effect.
00:32:11.760 | More or less, the model can now predict anywhere, which makes the task a bit harder and introduces
00:32:15.240 | some more supervision. And we found in an ablation for the tabular data sets that it
00:32:19.080 | helped for eight of 10 of those. And then there's this other term, which is kind of
00:32:23.880 | interesting. It's the stochastic target masking. And the idea is that you're actually going
00:32:30.060 | to have some training targets unmasked to the model at input at training time, which
00:32:36.920 | means that the NPT can learn to predict the mask targets of certain training data points
00:32:42.400 | using the targets of other training data points as well as all of the training features.
00:32:47.680 | And so that means you don't actually need to memorize a mapping between training inputs
00:32:52.240 | and outputs in your parameters. You can instead devote the representational capacity of the
00:32:56.240 | model to learn functions that use other training features and targets as input.
00:33:01.600 | So this is kind of getting into the idea of this sort of like learn KNN idea. Obviously,
00:33:06.380 | we can learn more complex relational lookups and those sorts of things from this. But you
00:33:12.320 | can imagine one such case being we have a bunch of test data points coming in. We're
00:33:18.160 | going to look at their features and use that to assign them to clusters of training data
00:33:22.840 | points. And then our prediction for those points is just going to be an interpolation
00:33:27.220 | of the training targets in that respective cluster. That's like an example of something
00:33:31.400 | that this mechanism lets NPTs learn. All right. So if there's any questions, we can take them
00:33:40.640 | now. Otherwise, I'm happy to take them in the discussion or something. All right. So
00:33:51.160 | let's discuss. Yeah, go for it.
00:33:54.000 | I'm curious when you're using the entire data set, does that limit the type of data sets
00:34:00.080 | you can use because of the size?
00:34:03.120 | Yeah. So in practice, we do random mini batching as an approximation. So the idea is just to
00:34:10.600 | say, if you have a reasonably large mini batch, you're going to benefit a bit from still having
00:34:15.560 | kind of this lookup ability. Because if you have a reasonable number of classes, probably
00:34:19.760 | you're going to be able to learn some interesting mappings based on features and targets amongst
00:34:25.200 | those classes. We found in practice that-- and we'll get into this a little bit. But
00:34:30.720 | we do actually indeed learn to use relationships between data points on prediction for data
00:34:38.160 | points where we're doing mini batching. And we also didn't necessarily find that you need
00:34:42.360 | like a ludicrously large batch size for this to be a thing. But I do think it's just--
00:34:48.160 | this is, in general, an important point. And it's one that points us towards looking into,
00:34:52.240 | say, sparse transformers literature for trying to expand to some larger data sets without
00:34:58.120 | having the mini batching assumption.
00:34:59.920 | Great. Thank you.
00:35:02.800 | If I can add a number to that, we can, without mini batching, accommodate data sets of around
00:35:10.120 | like 8,000 points or so. So that already accounts for a fair proportion, I would say, of the
00:35:17.720 | tabular data sets out there. But we also do data sets with 11 million points where, obviously,
00:35:22.760 | we then resort to mini batching. So it's very good to have an idea of the sizes that we're
00:35:27.600 | talking about.
00:35:29.440 | I'm curious on that. I mean, it's pretty exciting. I feel like you don't normally hear about
00:35:35.000 | transformers being applied to data sets of size 8,000. I'm curious-- and we can talk
00:35:42.360 | about this sort of later once we've covered the other material-- if you found that sample
00:35:45.640 | efficiency is one of the key gains here, or just experience working on small data sets
00:35:50.640 | of transformers generally. And yeah. But I'm happy to punt the answer to that until after
00:35:55.480 | as part of the discussion.
00:35:57.760 | Yeah, I think that'd be really nice to talk about a bit. And it was something that, in
00:36:02.560 | general, I guess I'd say was surprising to us in terms of how robust NPTs were on small
00:36:07.600 | data sets, and how we surprisingly didn't have to tune a terrible number of parameters.
00:36:11.560 | But we can get into details in a bit.
00:36:17.560 | Awesome. So to get into the experiments, we focused a lot on tabular data because it's
00:36:26.440 | a very general setting. And it's also notoriously challenging for deep learning. So we know
00:36:32.160 | tree-based boosting methods, stuff like XGBoost, are very dominant. And this is also a very
00:36:37.920 | relevant domain to, I think, people in industry and that sort of thing. So we were excited
00:36:42.260 | about the idea of trying to do better on this.
00:36:44.940 | So we chose a broad selection of data sets varying across a few different dimensions.
00:36:51.120 | As we mentioned, on the order of hundreds to tens of millions of instances, broad range
00:36:57.160 | in the number of features, in the composition of features in terms of being categorical
00:37:00.880 | or continuous, various types of tasks, binary and multi-class classification, as well as
00:37:06.000 | regression. And like I said, the baselines were kind of the usual suspects for tabular
00:37:09.240 | data-- XGBoost, CatBoost, LightGBM, tuned MLPs, and TabNet, which is a transformer architecture
00:37:16.200 | for tabular data.
00:37:19.360 | So to get into the results, here I'm showing the average rank for the various subtasks.
00:37:25.720 | We did well in terms of rank-wise performance against methods like CatBoost and XGBoost,
00:37:30.600 | which are designed specifically for tabular data. And in fact, we find that NPT is the
00:37:35.260 | top performer on four of the 10 of these data sets.
00:37:38.920 | On image data, I mentioned that we used a CNN encoder. And with that, we were performing
00:37:43.560 | well on CIFAR-10. And we also think that, in general, with, let's say, new work on image
00:37:49.440 | transformers on small data, this can probably just be done with linear patching. And so
00:37:54.000 | this kind of-- the manner in which you're embedding things is probably not the key.
00:37:59.000 | Neil, if I can jump in with two questions. Can you go back two slides first? One is just
00:38:08.080 | a small, minor point. Back one more, please. Thank you. Here are the features. 50 plus.
00:38:14.120 | What does plus mean here?
00:38:18.160 | I'll have to double-check what the exact number is. I'm pretty sure it's probably around 50,
00:38:22.240 | I would guess.
00:38:23.240 | Ah, so the 50 is really an order of magnitude. It's not like 150 or 5,000.
00:38:29.320 | Yes. Yeah. I mean, I'll double-check for you, or you can check with the metadata statistics
00:38:35.600 | at the end of the paper. But no, it wasn't arbitrarily large. I would say, though, we
00:38:43.480 | did these ablations on whether or not we actually need attention between attributes. We did
00:38:49.040 | find that this ended up benefiting us. But you could perhaps do, say, just an MLP embedding
00:38:57.680 | in that dimension and go to a relatively small number of hidden dimensions and fit an arbitrary
00:39:03.280 | number of features. So I think that, yeah, if you kind of relax the necessity of attention
00:39:10.960 | between attributes, you can probably scale out at least that dimension quite a lot.
00:39:15.240 | OK. And then my second question, if you could go forward one slide. Thank you. Here, I'm
00:39:22.360 | not sure I quite caught. What do four of 10 data sets, two of 10 data sets, and four
00:39:26.600 | of 10 mean?
00:39:29.120 | This is of all the tabular data sets that we had. So. Oh, I see. OK. Yeah, exactly.
00:39:37.680 | Awesome. Any other questions?
00:39:41.280 | The standard errors here, because I mean, there's just 10 data sets, right?
00:39:47.160 | Yeah, correct. 10 total tabular data sets. Yeah. But these are, right, yeah, these are
00:39:53.200 | rank-wise performance. Correct. OK, I'm just seeing how the, where the uncertainty
00:40:01.920 | comes from in this case. Yeah, averaged over the data sets, the rank. So for
00:40:09.040 | each particular data set, we have a rank of all the different methods. Then we take the
00:40:11.960 | average and the variance of the rankings within each of the types of task
00:40:20.360 | within binary classification, within multiclass, et cetera.
00:40:24.840 | We also, if you're curious, you know, have the full results in the paper. Yeah.
00:40:28.920 | Thank you. We also have a couple of questions. Some of the questions on the.
00:40:36.760 | Hey, yeah, thanks. I guess I just found it a little surprising that the worst performer
00:40:43.840 | was KNN, given that it's also nonparametric. I guess, could you comment on that? And is
00:40:51.560 | yeah. Is it that there's something like intrinsic to the NPT that makes it just exceptional
00:40:57.600 | far beyond other nonparametric methods? Or yeah, why? Why is it that KNN performs the
00:41:03.720 | worst here?
00:41:05.560 | Well, I suppose ultimately KNN is still a relatively naive predictive method
00:41:11.720 | and that, you know, it might just be predicting based on kind of cluster means. So for example,
00:41:18.600 | you know, I think this is probably universally true for all the data sets, but there's probably
00:41:21.720 | some amount of kind of additional reasoning that needs to occur over the features, at
00:41:25.880 | least at a basic level. So for example, like one of the data sets is this poker hand data
00:41:29.480 | set where it's like a mapping between all of the different hands you have in poker and
00:41:33.600 | what they're commonly known as to people, like full houses or whatever.
00:41:37.160 | So this, this requires some amount of reasoning over the features to be able to group things
00:41:41.020 | together. So just taking like the cluster means of the featurization of those different,
00:41:46.920 | you know, hands is likely not going to give you a great predictive function. Whereas NPTs
00:41:52.600 | can kind of do the classic thing where say you have an MLP type of thing over the features
00:41:57.640 | or like a tree type of thing over the features, you can learn some sort of complex embedding,
00:42:02.440 | but then you also can do some nonparametric sort of prediction based on say like clusters
00:42:08.040 | of embeddings.
00:42:09.040 | I see. Yeah, that makes sense. I guess. What if, what if you used pre-trained embeddings
00:42:16.180 | from a stack of encoders as, as your vector representation for the KNN, how do you think
00:42:23.020 | that would perform compared to the rest of the crowd?
00:42:26.020 | Yeah. Yeah. So this is like, I mean, this idea is kind of like deep kernel learning
00:42:30.740 | or like, yeah, I believe it is deep kernel learning. Basically, you use an MLP independently.
00:42:38.980 | So you learn an MLP on each input data point, and then you apply a GP over all the representations
00:42:44.580 | of those. So you get this sort of like complex embedding and then the lookups. The key difference
00:42:49.260 | between that type of idea and NPTs is that we also learn the relationships between the
00:42:54.060 | data points themselves, because we use this parametric attention mechanism to learn the
00:42:58.460 | relationship. So we're not just learning like an embedding independently. We're basically
00:43:03.920 | backpropagating through the entire process, learning the ways in which we would try to
00:43:07.500 | embed this, but also the ways that, say, the lookup would occur. And essentially the
00:43:13.340 | relationships, which could potentially be kind of higher order as well.
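For reference, deep kernel learning composes a fixed base kernel with a learned embedding, roughly $k_{\text{deep}}(x, x') = k_{\text{base}}(f_\theta(x), f_\theta(x'))$, so only the embedding $f_\theta$ is learned; in NPTs, by contrast, the interaction between points is itself parameterized by stacked, learned attention.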
00:43:17.540 | Okay, cool. Wait, one more, one more follow-up or.
00:43:22.740 | Oh yeah, go for it.
00:43:25.300 | Cool. Yeah. Thanks. So I guess then if, if the advantage of NPT has to do with sort of
00:43:31.940 | the relationships between data points, then what if you, you know, took the,
00:43:41.100 | let's say, you know, encoder representations and then you passed that as input, say for the,
00:43:48.680 | you know, 10 nearest neighbors, along with like some other like input representation,
00:43:56.460 | and sort of had this like weighted-average, attention-style thing where you weighted
00:44:02.180 | the vectors of the nearest neighbors based on the attention weights between those
00:44:11.020 | input data points and the supplied input data point, and then like passed that
00:44:17.180 | as, you know, the vector to like the final prediction layer, like, do you think
00:44:24.220 | that captures some amount of the relationship, or is that off base?
00:44:30.900 | So I think the nice part, like, and really the idea behind this whole thing, is
00:44:35.500 | that these sorts of instances where certain fixed kernels would perform particularly well
00:44:40.900 | in tasks are kind of an annoyance, and like ultimately, tuning a lot of these
00:44:44.980 | types of things or trying to derive the predictive methods that might make a lot of sense for
00:44:49.300 | a given situation kind of stinks. And ideally you'd want to just back propagate on a data
00:44:54.060 | set and kind of learn these relationships yourself. So I actually would be really interested
00:44:58.180 | to see if we can come up with some synthetic experiments that have these sort of like very
00:45:02.700 | particular KNN-like predictive mechanisms and just see if we can learn precisely those
00:45:07.240 | and get, you know, zero error with NPTs. And in fact, like we'll get into this a little
00:45:12.340 | bit with some of the interventional experiments we do, we have like kind of precise lookup
00:45:17.300 | functions that NPTs end up being able to learn, so we can learn interesting relational functions.
00:45:22.380 | Cool. Yeah. Thanks a lot. Appreciate it.
00:45:27.140 | Cool. All right. I have one more question from. Yeah, sure.
00:45:34.260 | I just wanted to clarify something about basically so at test time, you just take the exact same
00:45:40.940 | data set and you just like add like your test examples, right? And then you like do the
00:45:46.420 | same type of like masking. And is that how it works?
00:45:50.620 | Yeah, correct. OK, got it. And I do have one more question.
00:45:55.300 | That is just because I think I misunderstood, like, the effects of your NPT objective.
00:46:02.700 | Do you mind going back to that slide? Sure. Yeah. Can you repeat one more time? Like what
00:46:13.180 | makes this so special? Yeah. So the regularizer on the right
00:46:19.060 | over the features, I would think of very similarly to self-supervised learning with just a standard
00:46:24.780 | transformer. Like, you're basically just introducing a lot more supervision. And
00:46:29.380 | even if, say, you're just doing a supervised objective, this is kind of like some amount
00:46:33.420 | of reconstruction over the features. You learn a more interesting representation, and it
00:46:37.620 | has a regularizing effect, which we think is interesting, but perhaps not as interesting
00:46:41.940 | as this stochastic target masking. This one is unique because in kind of standard parametric
00:46:48.220 | deep learning, you're not going to have an instance in your training process where you're
00:46:53.060 | taking targets as input. And so basically what happens is you have your training data
00:47:00.940 | set as input, whatever, you're going to have some stochastic feature masking stuff happening
00:47:04.820 | on the features. Amongst the training targets, you're randomly going to have some of those
00:47:10.380 | unmasked and some of them will indeed be masked. You're going to be backpropagating a loss
00:47:14.380 | on the ones that are masked, of course, because, you know, you don't want your model to have
00:47:18.060 | those available at input if you're going to actually try to backpropagate a loss on it.
00:47:22.620 | But you can use the other ones as input. And that means you can learn these kind of like
00:47:25.740 | interpolative functions. So that was like this whole idea of like being able to kind
00:47:29.660 | of learn KNN.
00:47:31.860 | But doesn't that allow the model to cheat again?
00:47:36.620 | Yeah. So this is like an interesting point and actually like subtle. So I think it's
00:47:41.940 | really worthwhile to bring up. So first of all, we never actually backpropagate a loss
00:47:48.260 | on something that was visible to the model at input. And so if, for example, the model
00:47:54.740 | did actually end up basically overfitting on training labels, we would not observe the
00:47:59.260 | model's ability to generalize to test data. We don't observe this. So obviously, it seems
00:48:04.900 | like this kind of blocking of backpropagation on labels that are visible at input to the
00:48:11.500 | NPT is helping.
00:48:13.380 | It could also be possible that in BERT style stochastic masking, you also randomly will
00:48:18.740 | flip some labels to be in a different category. So this is like kind of just like a random
00:48:24.940 | fine print that was introduced in the BERT masking scheme. We also do that. So it's possible
00:48:29.780 | that that somehow contributes to that. But it's probably pretty likely to just be the
00:48:35.740 | fact that we're not backpropagating a loss on something that's visible.
00:48:39.380 | Great. Thanks. Makes sense.
00:48:45.060 | I have two more questions if I can jump in.
00:48:46.900 | Sure.
00:48:47.900 | Sorry. Can we go to the metrics, the performance, the results slide?
00:48:53.220 | Sure.
00:48:54.220 | I feel like I missed something else. I'm sorry about this. So looking on the binary classification
00:48:59.820 | AUROC, can you clarify what these numbers mean? Are they the AUROC?
00:49:08.860 | So this is the, so on each of the data sets, so say for a particular binary classification
00:49:15.740 | data set, we're going to get a ranking of the methods. We're going to repeat this. Yeah.
00:49:22.980 | Go for it.
00:49:23.980 | So these numbers here are the relative ranking across in this particular case, the four data
00:49:28.500 | sets.
00:49:29.500 | Correct. Yeah.
00:49:30.500 | I see. So this, these values are not the AUROCs on average across the data sets.
00:49:36.380 | No. Yeah. They're not.
00:49:38.540 | I mean like averaging, averaging AUROC might make sense, but averaging things like accuracy
00:49:44.500 | and RMSE seems like a bad idea, right? Because you might have some data sets where everything
00:49:49.500 | has high accuracy or where RMSE means something drastically different.
00:49:53.540 | I see. So this, these numbers here only tell us the relative ranking between the different
00:49:57.660 | methods, not how well they actually perform. I mean, it tells us how they perform relative
00:50:02.220 | to one another, but not how well they perform. I see.
00:50:05.020 | Okay.
00:50:06.020 | But that's all in the appendix. We all have, we have that information.
00:50:08.340 | I see. Okay. I was, I was sitting here confused going like, why is AUROC, why is the best
00:50:12.360 | one the smallest? And accuracy, what is an accuracy of 2.5? Anyways. Okay. That makes
00:50:16.820 | much more sense. Thank you both.
00:50:21.060 | Awesome. Great. So I'll try to speed through this just in the interest of time. But the
00:50:29.460 | thing that you might be thinking, basically, after all of these results is
00:50:33.580 | are we even learning any data point interactions on these real data sets?
00:50:37.820 | And so basically we designed an experiment to figure this out. And the idea is that we're
00:50:42.140 | going to disallow NPT from using other data points when predicting on one of them. If
00:50:47.820 | we do that and we observe that NPT actually predicts or performs significantly worse,
00:50:53.580 | it is indeed using these interactions between data points. A subtle challenge or it's kind
00:50:59.140 | of like an added bonus we can get from this is that ideally we wouldn't actually break
00:51:03.860 | batch statistics. So let's say like the mean of each particular attribute. If we can find
00:51:09.160 | a way to do this experiment such that we don't break these things, we can kind of rule out
00:51:13.980 | the possibility that we learned something that's a bit similar to batch norm.
00:51:18.260 | And so the way that we do this is we basically look at the predictions for each one of the
00:51:22.380 | data points in sequence. So let's say in this case, we're looking at the prediction of the
00:51:26.820 | model for this particular green row. And it's going to be predicting in this last column
00:51:30.860 | that has this question mark which is masked. What we're going to do is we're going to permute
00:51:34.900 | each of the attributes independently amongst all other data points except for that one.
00:51:39.360 | So the information for that row, if it was kind of just predicting like a classic parametric
00:51:43.620 | deep model is still intact, but the information from all of the other rows is gone. So that's
00:51:48.980 | why we call this sort of the corruption experiment. And so we find in general, when we perform
00:51:54.540 | this experiment, performance kind of falls off a cliff for the vast majority of these
00:51:58.940 | methods. And I'll note that the performances between the methods on a lot of these were
00:52:03.580 | fairly close. And so this is actually indeed pretty significant. So for example, on protein,
00:52:09.260 | we went from being the top performer amongst all the methods to the worst performer, worse
00:52:12.420 | than even like KNN or something like that. I'll also note that there's kind of this interesting
00:52:18.700 | behavior where on these data sets like forest and kick and breast cancer, we actually observed
00:52:23.740 | that there's basically no drop in performance. And we basically see this as kind of an interesting
00:52:28.660 | feature and not necessarily a bug of the model, which is that if we're backpropagating on
00:52:34.020 | a given data set, the model can sort of just find that it's actually not that worthwhile
00:52:38.660 | to attempt to predict using some kind of relational predictive mechanism amongst data points and
00:52:43.940 | can instead just learn to predict parametrically and basically ignore other data points when
00:52:48.980 | it's predicting on any given one of them. And so this probably leads to some
00:52:53.700 | interesting ideas, where perhaps you could do post hoc pruning or something like
00:52:57.020 | that, taking away the attention between data points and then fine-tuning, let's say. Alright,
00:53:05.580 | so now I'll hand over to Yannick to talk a bit about learning some interesting relationships.
00:53:10.380 | Yeah, will you though? I see that we're at the end of the allotted time, but I know
00:53:19.060 | there's a buffer planned in or something. I can go through this experiment, we can have
00:53:24.140 | a bit of discussion. What do you guys prefer? Yeah, I think normally what we do is we would
00:53:31.580 | sort of stop the recording at this point and have an off the record discussion. And I guess
00:53:37.860 | the question to ask is, does anyone have any questions at this point? But I think we've
00:53:48.040 | basically been taking questions as they come. So I personally feel fine just considering
00:53:55.700 | this as sort of questions throughout. Yeah, I guess that sounds good. Yannick, you can
00:54:02.420 | go forward with it, with your talk as planned and later we can see about the time thing.
00:54:10.780 | I think this will only be like another four or five minutes talk. Yeah, yeah, that's good.
00:54:16.220 | Then go for it, yeah, for sure. Alright, so Neil has now told us how well NPTs perform
00:54:24.300 | on real data and that they do make use of information from other samples in the input.
00:54:29.780 | Now we're going to take this a bit further and come up with some toy experiments that
00:54:34.300 | test the extent to which NPTs can learn to look up information from other rows, like
00:54:39.780 | the extent to which they can learn this nonparametric prediction mechanism. And so specifically
00:54:45.020 | what we'll do is we'll create the following semi-synthetic data set. I want you to focus
00:54:50.120 | on A now. So we'll take one of the tabular data sets that we've used previously, specifically
00:54:56.140 | the protein data set, but it doesn't really matter. What matters is that it's a regression
00:55:06.460 | data set. And so now what we do is the following: the top half here is the original data set, but
00:55:06.460 | the bottom half is a copy of the original data set where we have unveiled the true target
00:55:13.500 | value. So now NPTs could learn to use attention between data points to achieve arbitrarily
00:55:19.580 | good performance. They could learn to look up the target values in these matching duplicate
00:55:24.860 | rows and then paste them back into that masked out target value. And then at test time, of
00:55:31.740 | course, we put in a novel test data input where this mechanism is also possible just
00:55:37.780 | to make sure that it hasn't learned to memorize anything, but has actually learned this correct relational mechanism.
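(A rough sketch of that construction, assuming features `X` and targets `y` as NumPy arrays; the helper name and mask convention are hypothetical, not the paper's code. The `offset` argument anticipates the "plus one" variant discussed later in the talk.)

```python
import numpy as np

def make_lookup_dataset(X, y, mask_token=np.nan, offset=0.0):
    """Hypothetical sketch: stack the original rows (targets masked) on top
    of duplicate rows whose true targets are revealed. A model that attends
    between rows can match a masked row to its duplicate by the features
    and copy the target back. With offset != 0 (the later 'plus one'
    variant), pure lookup is no longer enough: the model must also learn
    to undo the offset."""
    masked = np.full(len(y), mask_token)
    top = np.column_stack([X, masked])         # original rows, target hidden
    bottom = np.column_stack([X, y + offset])  # duplicates, target revealed
    return np.vstack([top, bottom])
```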
00:55:43.100 | And so what we see is that indeed, NPTs do successfully learn to
00:55:49.220 | perform this lookup. So what I'm visualizing here is attention maps, and they very clearly
00:55:53.420 | show that, let's say when predicting for this green row here, this first green row, what
00:55:58.380 | NPTs look at is exactly only that other green row here. And so this is really nice. We can
00:56:07.260 | further look at the Pearson correlation between what NPTs should predict and what they actually
00:56:14.820 | do predict. And so this is 99.9%. This is much better than anything you could achieve
00:56:20.300 | with parametric prediction. And so it seems that NPTs here can actually discover this
00:56:24.480 | mechanism. And discover, I feel, is the right word here, because NPTs could have,
00:56:30.740 | as we've seen, just continued to predict in parametric fashion, right, from each row
00:56:36.060 | independently. This is really showing us that there is this bias in the model
00:56:42.380 | to learn to predict from other rows. And of course, that is also very attractive in this
00:56:48.020 | setting because it allows you to achieve arbitrarily low loss, or as low as you
00:56:53.380 | can optimize for. And so we kind of take that to mean that our, you know, gradient
00:57:00.300 | based discovery, non-parametric philosophy seems to make some sense. And so we can take
00:57:06.580 | this a bit further by performing somewhat of an interventional experiment that investigates
00:57:11.820 | the extent to which NPTs have actually learned a robust, you know, causal mechanism that's
00:57:18.100 | underlying this semisynthetic data set. And so just appending, you know, this extra column
00:57:28.780 | of test data, that's already kind of cool, but I think we can take it a bit further and
00:57:33.220 | actually study if this generalizes beyond the data that we see in the training set or
00:57:37.900 | beyond data coming from this specific distribution. And so what we now do is we intervene on individual
00:57:44.380 | duplicate data points at test time by varying their target value. So now we only care about
00:57:50.620 | the prediction in a specific row. We do this across all rows, but each time we just
00:57:55.660 | consider a single row. What we do is change the target value of the duplicate, and what we're hoping
00:58:00.780 | to see is that NPT just adjusts its prediction as well, right? It's a very simple intervention
00:58:06.060 | experiment for us to test if NPTs have actually learned this mechanism. And to some extent
00:58:10.820 | it also tests robustness because now we're associating target values with features that
00:58:16.060 | are not part of the training distribution here. And so what we see is that as we adjust
00:58:24.140 | these values here, this being the duplicate value and here the predicted
00:58:29.020 | target value, the correlation stays really, really good.
00:58:33.380 | It's not quite 99.9%; on average, we're now at 99.6%, but it's still very, very good.
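(A hypothetical sketch of that intervention, assuming a trained model object `npt` whose `predict` method maps a whole batch to per-row predictions, and the top-half-masked / bottom-half-duplicates layout from above; none of these names come from the paper's code.)

```python
import numpy as np

def intervention_correlation(npt, batch, target_col, deltas):
    """Hypothetical sketch: shift one duplicate row's revealed target by
    each delta and check that the prediction for the matched masked row
    moves along with it. Returns the Pearson correlation between the
    intervened targets and the model's predictions."""
    n = len(batch) // 2                # top half masked, bottom half duplicates
    shifted, predicted = [], []
    for row in range(n):               # intervene on one duplicate at a time
        for d in deltas:
            b = batch.copy()
            b[n + row, target_col] += d            # change the duplicate's target
            shifted.append(b[n + row, target_col])
            predicted.append(npt.predict(b)[row])  # prediction for masked row
    return np.corrcoef(shifted, predicted)[0, 1]   # should stay close to 1.0
```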
00:58:40.580 | And at this point you might be slightly annoyed with me because standard nonparametric models
00:58:47.580 | can also solve this task. This is a task that I could solve by nearest neighbors. Sure,
00:58:52.900 | maybe I would have to change the input format a bit because this is kind of like in a batch
00:58:57.020 | setting and I could just use masks, but most generally a nearest neighbor
00:59:02.580 | also looks up different input points based on their features. A nearest neighbor doesn't
00:59:08.180 | learn to do this, though. I still think it's cool that NPTs learn this, because it does
00:59:12.100 | require a decent sequence of computational steps: match all the
00:59:17.900 | features, look up the target value, copy it back, and so on. But it is in fact very easy for
00:59:24.100 | us to complicate this task to a degree such that essentially no other model that we know
00:59:29.460 | of can solve this very easily. And so a really simple thing to do is just to add a plus one
00:59:37.580 | to all of the duplicate values. So now nearest neighbor would look up the right row, of course,
00:59:45.980 | but it would always predict the wrong target with a plus one on it. And in fact, many of
00:59:49.980 | the models that we're aware of, they're not modeling the joint distribution over features
00:59:56.340 | and targets. What they're modeling is the conditional distribution of the targets given
01:00:01.580 | the input features. And so they also cannot do this. And so for us it's really not a problem
01:00:06.860 | at all: NPTs will just learn to subtract one again, no problem. And sure, this is also
01:00:13.300 | still a very synthetic setting, but I do think, I mean, I challenge you to come up with some
01:00:19.900 | thing that NPTs can't solve but the other models can solve. I think, in general,
01:00:25.260 | this masking mechanism and the non-parametricity of the approach are really nice
01:00:31.220 | and lead to lots of nice behavior in a variety of settings. And so with that, I think we
01:00:36.940 | can go to the conclusions, which Neil is going to give you.
01:00:41.060 | Yeah, I think, I mean, we're going to cut out the main part here. I'll just fast forward.
01:00:47.860 | Just look at them. Yeah, yeah. I was going to say, I think you'll get the gist. NPTs
01:00:55.020 | take the entire data set as input and they use self-attention to model complex relationships
01:00:58.820 | between data points. They do well in experiments on tabular data as well as image data. We
01:01:06.300 | present some of these interventional experiments to show that they can solve complex reasoning
01:01:10.100 | tasks. There's some more experiments in the paper. I'd say that the interesting type of
01:01:14.820 | future work is scaling-type things. So we could, you know, get rid of this mini-batching
01:01:20.020 | approximation, and then also just try to expand this to some more interesting application
01:01:24.060 | domains. So we talked a little bit about meta-learning, but it could also be things like, you know,
01:01:27.340 | few-shot generalization in general, domain adaptation, semi-supervised learning, et cetera.
01:01:33.380 | So I think if there's some more questions, maybe we can do some more discussion.
01:01:38.340 | Yeah, I think that sounds good. Great. Thanks for the talk. I think everyone had a fun time.
01:01:45.340 | I will just ask some general questions and then we can have like a discussion session
01:01:49.540 | with everyone after that. So I think one thing that I noticed is, like you said,
01:01:55.580 | this is similar to KNNs, and I thought this also seems similar to graph neural
01:01:59.620 | networks, where you can think of each data point as a node and then think
01:02:03.660 | of everything as a fully connected graph, and you're learning some sort of attention weight
01:02:06.820 | in this graph. So this is like a node prediction task you're doing on this sort of
01:02:11.700 | graph structure. So any comments on that? Is it similar to graph neural networks,
01:02:16.220 | or are there other differences? Yeah, this is a very good observation. Yeah,
01:02:21.860 | I think there are a lot of similarities to work on graph neural networks. If we want
01:02:26.340 | to talk about differences, the differences might be that we're kind of assuming a fully
01:02:30.860 | connected graph, right? And so you could maybe also phrase it as: we're discovering the
01:02:36.380 | relational structure, whereas graph neural networks usually assume that it's given. But that's
01:02:40.940 | also not always true. And so there are a lot of similarities. I don't know, Neil, if there
01:02:45.660 | was something specific you would like to mention, go ahead. But it's a very good observation
01:02:49.980 | and we also do feel that that's the case. And we've added an extra section on related
01:02:54.660 | work to graph neural networks in the updated version of the paper that will be online soon.
01:03:00.620 | Yeah, I agree with everything you've said. I think the closest work from the GNN literature
01:03:07.540 | that we were looking at a little bit was this neural relational inference paper, which uses
01:03:11.260 | message-passing neural networks to try to learn edges that may or may not
01:03:17.540 | exist, and to help with extrapolating, I think, positions of particles in a multi-particle
01:03:23.600 | system or something, which is kind of a similar idea to ours. You know, if you
01:03:27.700 | don't have these edges as given, the attention mechanism could kind of approximate an interesting
01:03:31.900 | relationship amongst some interacting things. I see. Got it. Yeah, that's really cool. Another
01:03:39.380 | thing is, you mostly looked at tabular data, but can you also have other
01:03:43.020 | modalities? If you want to do language or something, can you still use non-parametric
01:03:47.580 | transformers? Yeah, so I think part of our motivation for doing tabular was because we
01:03:54.580 | felt like tabular data is, in a sense, a generalization of, let's say, the language data, for example.
01:04:00.420 | I mean, I guess there's these other notions that people have brought up, like padding,
01:04:06.300 | but ultimately you can think of it as like a bunch of categorical attributes. So it is
01:04:11.700 | definitely generalizable to things like sentences and we do, you know, images. So, yeah. I think
01:04:19.500 | actually like, I always go back and forth on whether or not I think smaller or larger
01:04:25.660 | data is more interesting for us. So I think small data is really interesting because we
01:04:29.740 | can just fit the entire data set into it and all of this just works out of the box without
01:04:35.620 | any extra thought. But large data is actually also really interesting because, sure, you
01:04:41.740 | might have to introduce some approximation mechanism or some lookup mechanism because
01:04:45.740 | you can't always fit the entire data set in. But at the same time, you are very explicitly
01:04:51.740 | kind of trading off the compute that you use to look up with the compute that you need
01:04:57.700 | to store. Like how many parameters in GPT are used for storing data, right? There's
01:05:03.460 | lots of memorization happening in these models and we know that. And so maybe we can use
01:05:08.580 | the parameters more efficiently to learn lookup type behavior, right? That is more close to
01:05:13.780 | this, you know, neural KNN or whatever. So I think these are very exciting questions.
01:05:18.340 | Yeah, yeah. I'll also be looking forward to the future work, because it seems like a very
01:05:23.380 | good fit for one-shot learning kinds of situations. So, yeah, really very interesting
01:05:28.700 | to see that. Okay, so I will stop the recording and we can have any other questions.
01:05:36.900 | Okay, thank you.