Stanford CS25: V1 I Self Attention and Non-parametric transformers (NPTs)


Chapters

0:0
3:25 Fast Autoregressive Decoding
14:1 Motivation and a Brief Summary
15:6 Parametric Prediction
16:2 Non-Parametric Transformer Architecture
23:40 Three Key Components to NPTs
24:44 Dataset Matrix
25:8 Masking Matrix
26:50 Linear Embedding
31:3 Masking Based Training Objective
32:24 Stochastic Target Masking
37:20 Results
41:28 Poker Hand Data
42:32 Deep Kernel Learning
54:45 Semi-Synthetic Dataset
55:51 Attention Maps

Transcript

>> Thanks so much. Great to be here and happy Halloween, belated Halloween everyone. So I think the talk is going to be split into two sections. So I'll start by spending like 10 minutes, 15 minutes chatting about transformers in general. But I'm assuming most of you are familiar with them and we can move on to NPTs, which Jannik and Neil will be presenting.

So let's see. I'm going to like try to fly through the transformer overview and maybe spend a little bit extra time on like the history of transformers and maybe just tell the story a little bit. I think that might be more interesting. So just in terms of the transformer architecture, the two kinds of things that it introduced for the first time were multi-head attention and self-attention.

And then it combined those with fast autoregressive decoding. So before the transformer, pretty much everyone was using LSTMs and LSTMs with attention. But I'll try to get into the difference of self-attention, multi-head attention. So originally you would have two sequences and you would have an attention module, which would attend from the source to the target.

And so each token or each word in the source sequence would get associated with, you know, a soft approximation of one element in the target sequence. And so you'd end up with something like this. But with self-attention, we did away with the two separate sequences. We make them both the same.

And so you're relating each element within the sequence to another element in the sequence. And so the idea here is that you're learning a relationship of the words within a sentence to the other words. So you can imagine something like an adjective, which is being applied to a noun.

And so you want to relate that adjective, like the blue ball, you want to relate blue as referring to ball. So we're learning patterns within the sequence, intra-sequence patterns. So sorry, I gave this talk in Kenya, so I am using Kiswahili here. But with multi-head attention, the idea is you have like each word represented by an embedding, which is in the depth dimension here.

And then you have your sentence of words. You split that up into a bunch of different groups. So here I've chopped it depth-wise into four groups. You apply attention to each one of these groups independently. And then when you get the result back, you can concatenate them together and you're back to your model dimension representation.
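The split, attend, and concatenate steps described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the original implementation: the learned query, key, value, and output projections are left out, and the names `multi_head_self_attention` and `num_heads` are assumptions made for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, num_heads):
    """x has shape (seq_len, d_model). Learned projections are omitted (assumption)."""
    seq_len, d_model = x.shape
    assert d_model % num_heads == 0
    d_head = d_model // num_heads
    # Chop the embedding depth-wise into groups: (num_heads, seq_len, d_head).
    heads = x.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    # Each head attends over the sequence independently.
    scores = heads @ heads.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)          # (num_heads, seq_len, seq_len)
    out = weights @ heads                       # (num_heads, seq_len, d_head)
    # Concatenate the heads to get back to the model dimension.
    return out.transpose(1, 0, 2).reshape(seq_len, d_model), weights

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))                     # 5 tokens, model dimension 8
out, weights = multi_head_self_attention(x, num_heads=4)
print(out.shape, weights.shape)
```

Each of the four heads produces its own 5-by-5 attention map, which is what allows different heads to pick up different relationships.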

So what this lets you do is, each attention head can now focus on learning one pattern. So maybe attention head one is learning the relationship of adjectives to nouns. And the second attention head can learn something different. So this lets us learn like a hierarchy or a list of different relationships.

Okay. So that was self-attention. The other piece is fast autoregressive decoding. And do I really want to go into this? Okay, I will. So the important thing about this is that if you're doing normal autoregressive decoding, what you do is you generate your first token. And now conditioned on that first token, you generate the second and conditioned on the first two, you generate the third and so on and so forth.

But that's super slow, right? Like it's a loop applying this thing again and again. And so what we can do instead is we make an assumption in the code that our model always generates the right thing. And then we generate a prediction, only one token ahead. So you have your outputs, which are Y, you have your targets, which are Y hat.

And what you do is you feed in those gold targets so that you don't need to actually do this loop. So instead of assuming, instead of having to generate the first token, feed it back into your architecture, generate a second token, you feed in the entire target sequence and you just pretend that you generated all the right tokens up to position K.

And then you predict the K plus first and you compute your loss on that. So in reality, your model might've generated at the beginning of training junk, but you're getting a loss as if your model had seen all the correct tokens and is now just predicting the next one.
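As a rough sketch of the gold-target trick just described: the decoder input is the target sequence shifted right by one, so every position's prediction can be scored in a single pass instead of a generation loop. The token ids and the `BOS` start token here are hypothetical.

```python
BOS = 0  # hypothetical start-of-sequence token id

def teacher_forcing_pairs(targets):
    # Pretend the model already produced the correct tokens up to position k:
    # the input at position k is the gold token from position k - 1.
    decoder_inputs = [BOS] + targets[:-1]
    return decoder_inputs, targets

targets = [5, 7, 2, 9]
decoder_inputs, labels = teacher_forcing_pairs(targets)
for k, (seen, predict) in enumerate(zip(decoder_inputs, labels)):
    print(f"position {k}: last gold token {seen} -> predict {predict}")
```

All four (input, label) pairs are available up front, which is why the loss for every position can be computed in parallel.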

This is a little bit subtle, but it's hugely impactful for training speed because all of this can be done in parallel. And so it's actually what makes transformers so scalable. Okay. So in order to do this successfully, if you were just feeding in all of the correct tokens, uh, naively, what would happen is your model would just be able to, uh, look forward in time and cheat.

So you've put in all of your true targets, the things that you're trying to get your model to predict. And so if that's what you're computing your loss on, the model could just look forward in time and say, okay, I'm just going to grab that, and we'd get zero error trivially, right?

Because you've given it all the right answers. So what we have to do inside the architecture is we need to actually prevent, uh, the attention mechanism from being able to look at tokens that it shouldn't have been able to see already. Um, so the way that this looks is you create a mask on your attention.

Um, and so, sorry, this is the example of like doing a trivial attention. If you don't mask your attention properly, um, what it's going to do is it's just going to look into the future, just grab the token that you're telling it to predict and copy it over. And so it learned something trivial, something that doesn't actually generalize.

And so what we do is we actually prevent it from attending to those tokens. We prevent it from attending into the future for each position in the source sequence. We block out, uh, everything that it shouldn't be able to see everything into the future. And then as we move down, we gradually unblock so it can start to see into the past.
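The masking described here can be sketched as a lower-triangular mask applied to the attention scores before the softmax. This is an illustrative NumPy version with made-up function names, not code from the talk.

```python
import numpy as np

def causal_mask(seq_len):
    # True where attention is allowed: position i may look at positions j <= i.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def masked_softmax(scores, mask):
    # Blocked positions get -inf, so they receive exactly zero weight.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))                 # uniform scores, for illustration only
weights = masked_softmax(scores, causal_mask(4))
print(weights)
```

Moving down the rows, each position can see one more token of the past, and nothing above the diagonal (the future) ever gets weight.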

Um, so those are kind of like the three major components of transformers, um, the self-attention, the multi-head attention, and then employing this gold-targets decoding, this fast autoregressive decoding. In terms of the story, which might be a little bit more interesting. Um, so transformers, I was an intern, uh, with Lukasz Kaiser, uh, at Google back in 2017.

Um, and I was sitting next to Noam, uh, and Ashish was like a couple of seats down from us. Um, and what's really incredible is that essentially this entire project came together in like three months and it was done. So I showed up at Google, uh, Noam had been working on autoregressive models.

Um, same thing with like Ashish and Jakob and Niki. And, um, they'd just been kind of like exploring the space, figuring it out. Uh, and Lukasz and I, at the same time, we'd been working on this framework called Tensor2Tensor, um, which was like explicitly made for, uh, multimodal learning, autoregressive learning.

Um, and Lukasz is kind of like a master of keeping track of everything that's happening in the field and adopting it. And so within Tensor2Tensor, there were like these kind of emerging little things that maybe one paper had been written about, uh, and people were interested in, like layer norm, um, but it hadn't actually taken off yet.

Uh, the warmup in the learning rate schedule, um, all of these little pieces were just default, like on by default. Um, and so when Noam and Ashish and Niki and Jakob came over and adopted Tensor2Tensor, all of these things were just on by default. Um, and so a lot of people, when they look at the, the transformer paper, it just seems like there's so many like arbitrary little things thrown in.

And now, like in present day, these have become standard for like a lot of different training algorithms, like the learning rate warmup, um, the way that we did initialization, uh, all of these pieces have just become the norm, but back then they had like kind of just been introduced.

Um, and so we, uh, we spent a lot of time running ablations trying to figure out like, which were the necessary pieces and what made it work. And if any of you have actually tried training transformers and tried like pulling out the learning rate warmup, um, or changing any of these little pieces, you'll see that it really does break down optimization.

Like it actually really does, um, hurt performance for instance, like removing the layer norms, that type of thing. Um, so I always thought it was kind of funny how all of these random additions that Lukasz had just like thrown in, cause he was playing around with them turned out to be crucial.

Uh, and they were just on by default. Um, so anyway, it was like three months. I remember, um, it all really started coming together towards the end, like just before the NeurIPS deadline. Um, and I can still remember sitting in the micro kitchen and Ashish telling me like, that's just like, I was a little intern, uh, telling me like, this is going to be such a big deal.

And I was like, yeah, sure. Okay. I have no idea what's happening. I just showed up. Um, and he was like, no dude, like this, this actually matters. Like, you know, we, we bumped up BLEU three points and I was like, sick, great anyway. Um, and then I can remember on the night of the deadline for NeurIPS, it was like 2am, uh, Ashish, Ashish was the only one left at the office and we were still like moving around figures and like adjusting things.

And then, um, I went to bed, but Ashish stayed up and I slept in like this tiny little phone booth. Um, and then for the other paper that I was submitting, I forgot to press submit, but luckily like some lady opened the door to the phone booth and hit me in the head while I was sleeping in the morning.

And just before the deadline, I got the paper in. And so I owe it to that lady, uh, for, for submitting to NeurIPS that year. But yeah, anyway, the, I think the crazy thing about transformers was that it all came together in like three months. Like most of the ideas happened in that span.

And it was just like this sprint towards the NeurIPS deadline. Um, I think a lot of the other members on the team, Jakob, Lukasz, Noam, Ashish, they knew how important it was, but for me, I was like, I don't know. I really did not appreciate, uh, the impact. Um, but in retrospect, it's been amazing how the community has kind of like come together and adopted it.

Um, and I think most of that can be ascribed to the ease of optimization. It seems like very robust to hyperparameter choices. So you don't need to like tune the hell out of it, spend a lot of time tweaking little things. Um, and the other side is that it's like super tailored to the accelerators that we run on.

Um, so it's like very parallelizable, um, hyper-efficient. Um, and so it lends itself to that kind of scaling law effort that's, that's really taken off in, in popularity. Okay. Uh, unless there are any questions. That's such a cool story. Oh my God. We're both excited. So we just unmuted at the same time.

Yeah. So that's, that's my section. If there's any questions, happy to answer them. Otherwise, uh, let's get into NPTs. NPTs are like, I think, um, they're such a nice next level abstraction of the architecture. So you've probably seen the trend of, um, transformers getting applied to, to new domains, um, first into vision and video and audio.

Um, but this is kind of like cutting back to an even more abstract level. Um, like I think tabular data. Yeah. I don't know. I'll let Jannik and Neil take over from here, but, uh, I think NPTs is a pretty sick, pretty sick project. Thanks for the introduction. Um, thanks all for the invitation.

We're very happy to be here and Neil and I are now going to tell you about our, um, self-attention between data points paper, where we introduce the non-parametric transformer architecture. Um, we'll start with a little bit of motivation, move on to explaining the architecture in detail, then show you the experiments.

This is more or less a step-through of the paper, but maybe, you know, with a little bit of extra insight here and there. All right, as promised, motivation and a brief summary. So, um, we'll start by thinking about something that we don't often think about. And that is that from CNNs to transformers, most of supervised deep learning relies on parametric prediction.

So what that means is that we have some set of training data and we want to learn to predict the outcomes Y from inputs X. And for this, we set up some model with tunable parameters, theta, then we optimize these parameters to maximize predictive likelihoods on a training set, or, you know, equivalently, we minimize some loss.

And then after training, we have this optimized set of parameters theta and then at test time, we just put these into the model and use these parameters to predict on novel test data. And so crucially here, our prediction at test time only depends on these parameters, right? It's parametric.

Also, that means that given these parameters, the prediction is entirely independent of the training data. And so why would we want to do parametric prediction? Well, it's really convenient because all that we've learned from the training data can be summarized in the parameters. And so at prediction time, we only need these final parameters and we do not need to store the training data, which might be really, really large.

On the other hand, we usually have models that already predict for a bunch of data in parallel, right? Think of mini-batching in modern architectures. And actually things like batch norm already make these data interact. And so our thinking here was that if we've got all of this data in parallel anyways, there's no reason not to make use of it.

And so, a bit more grandly, we kind of challenged parametric prediction as the dominant paradigm in deep learning. And so we want to give models the additional flexibility of using the training data directly when making predictions. And so a bit more concretely, we introduced the non-parametric transformer architecture. And this is going to be a general deep learning architecture, meaning we can apply it to a variety of scenarios.

NPTs will take the entire data set as input whenever possible. And NPTs then crucially learn to predict from interactions between data points. And to achieve this, we use multi-head self-attention that, as Aidan has introduced us to, has just really established itself as a general purpose layer for reasoning. We also take another thing from the NLP community and we use a stochastic masking mechanism.

And we use that to tell NPTs where to predict and also to regularize the learning task of it. And lastly, of course, we hope to convince you that this ends up working really, really well, and that this kind of simple idea of learning to predict from the other data points of the input, from the training points of the input, ends up working rather well.

And so very briefly summarizing what we've heard already. A, we input into NPTs the entire data set. And then B, let's say for the purpose of this slide here, we only care about predicting the orange question mark in that green row. And then we can compare NPTs to parametric prediction, right?

So a classical deep learning model would predict this target value only from the features of that single green input. To do that, it would use the parameters theta, those would depend on whatever training data we've seen and so on. But at test time, we only look at that single row for which we care about the prediction.

In contrast, NPTs predict in explicit dependence on all samples in the input. They can look beyond that single green datum of interest and look at all other samples that are there and consider their values for prediction. So this presents an entirely different way of thinking about how we learn predictive mechanisms.

And somebody on Twitter called this KNN 2.0, which we would have not written in the paper, but maybe is kind of a nice way of thinking about how NPTs can learn to predict. So of course, nonparametric models are a thing already. We didn't invent them at all. And I define them here as prediction in explicit dependence on the training data, which is certainly what NPTs do.
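To make the parametric versus nonparametric contrast concrete, here is a small sketch on synthetic data. A least-squares linear model stands in for the parametric case, and plain k-nearest neighbors stands in for the classical nonparametric case; neither is the NPT itself, and all names and data are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 3))
true_w = np.array([1.0, -2.0, 0.5])        # synthetic ground truth (assumption)
y_train = X_train @ true_w

# Parametric: everything learned from the training data is summarized in theta.
theta, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)

def parametric_predict(x, theta):
    # Needs only theta at test time; the training data can be thrown away.
    return x @ theta

def knn_predict(x, X_train, y_train, k=3):
    # Nonparametric: the prediction depends explicitly on the training data.
    dists = np.linalg.norm(X_train - x, axis=1)
    nearest = np.argsort(dists)[:k]
    return y_train[nearest].mean()

x_new = rng.normal(size=3)
print(parametric_predict(x_new, theta), knn_predict(x_new, X_train, y_train))
```

The difference shows up in the function signatures: `parametric_predict` never sees `X_train`, while `knn_predict` cannot run without it.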

Classical examples like Gaussian processes, k-nearest neighbor, kernel methods, those might be familiar to you. And there also exist efforts to combine the benefits of nonparametrics and representation learning in a similar fashion to how we did it in NPTs. However, these approaches are usually limited in some sense in comparison to NPTs, right?

They're often kind of motivated from the statistics community a bit more. They often require more finicky approximative inference schemes, are limited in the interactions they can learn, or things like that. And so we really think NPTs present maybe the most versatile and most widely applicable of these nonparametric prediction approaches.

But that's something we explicitly wanted to have. We wanted to have something that's really easy to use, plug and play, works in a ton of scenarios, and works really well. And so with that, I hand over to Neil, who's going to tell you about the nonparametric transformer architecture in all of its details.

Yeah, we also have one question. Go ahead. Yes, yes. Hi, Jannik. Hi. So could you please go back to the previous slide? The very previous slide? Yeah, yes. This slide, yeah. So in terms of the problem definition, I think it's quite similar to some meta-learning problem, which basically there is a mapping from a data point on the data sets to some predictions.

So could you please suggest any differences between your problem setting and the meta-learning problem settings? I can't really figure out any differences between these two problems. Well, I think it really depends on the framing that you want to have. So I would say meta-learning would be when I try to predict over multiple data sets.

So when I try to learn some sort of prediction model, or I can just plug in a different data set, and it will almost automatically give me new predictions on this different data distribution. But that's not what we do at all. We're training a single model for a fixed data set.

And so this is why I wouldn't really call that meta-learning, because we're trying to predict on the same tasks that all the supervised deep learning, or any supervised machine learning method is trying to predict well on. Okay, so you mean you use kind of same test set to test your training model, right?

So basically, in meta-learning, we're going to test on different kind of meta-test sets. But in your case, you just want to use a test set, which is similar to the distribution of your training set, right? Yeah, absolutely. So we explore data set distribution shift a bit. I think it's a really interesting scenario.

I think meta-learning different data sets is also an interesting scenario, right, when you have this model where you could just push in different data sets. But for the scope of this paper, it's very much training set, test set, they come from the same distribution, and we're just trying to do supervised learning in a standard setting.

I see, cool. Thank you. Thank you for the question. Yeah, and I would chime in a couple of additional things, I guess. So at least from what I understand from the problem definition of meta-learning, I think the aim is more perhaps being able to perform well on a new data set with a relatively small number of additional gradient steps on that data set.

So I think there are some interesting ways that you could actually consider applying NPTs in a meta-learning type setting. And so we'll get into this a little bit more, but for example, there might be ways to essentially add in a new data set. So let's suppose we've trained on a bunch of different data sets.

We now add in a new data set. We can perhaps do some sorts of kind of zero shot meta-learning, basically, where there's no need for additional gradient steps because we're basically predicting kind of similar to how you might do prompting nowadays in NLP literature. Anyways, yeah, I think we'll get into some more details.

Just to chime in on that, I don't think that every meta-learning algorithm-- I think the one that you're describing right now are optimization-based, but there are also black box ones. You don't need to further-- I think the main difference seems to be that there is one task versus multiple tasks for meta-learning.

Yeah, I think so too. I think the main framing question is whether or not there's multiple data sets. Cool. OK, awesome. If there's no other questions, I'll dive a bit more into the architecture. Awesome. So there's three key components to NPTs. I'm going to first state them at a high level, and then we'll go through each of them in more detail.

So first of all, we take the entire data set, all data points as input. So for example, at test time, the model is going to take as input both training and test data, and we approximate this with mini-batches for large data. We apply self-attention between data points. So for example, at test time, we model relationships amongst training points, amongst test points, and between the two sets.

And then finally, we have this masking-based training objective. It's a BERT-like stochastic masking, and the key point is that we actually use it on both features as well as on training targets. And we'll get into why that leads to an interesting predictive mechanism later. So to start with this idea of data sets as input, there's two things that compose the input to NPT.

It's a full data set in the form of a matrix X and a masking matrix M. And so Jannik has described this data set matrix a little bit. We basically have data points as rows. The columns are attributes, and each attribute shares some kind of semantic meaning among all of its data points.

So say, for example, you're just doing single-target classification or regression. The last column would be the target, and the rest of the matrix would be input features. So for example, the pixels of an image. We also have a masking matrix. So let's say we're thinking about masked language modeling.

The mask tokens will just tell us where we're going to conceal words and where we're going to back-propagate a loss. We do a similar type of thing here, where we use this binary mask matrix to specify which entries are masked. And the goal is to predict masked values from observed values.
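As a toy illustration of the two inputs just described, the sketch below builds a small data set matrix with the target in the last column and a binary mask matrix that conceals one target. The values, and the choice to zero out masked entries, are assumptions made for this example.

```python
import numpy as np

# Rows are data points; the first two columns are features, the last is the target.
X = np.array([
    [5.1, 3.5, 0.0],
    [6.2, 2.9, 1.0],
    [5.9, 3.0, 1.0],
    [4.7, 3.2, 1.0],   # treat this row as the test point
])

# Binary mask matrix M: True marks entries that are hidden and must be predicted.
M = np.zeros(X.shape, dtype=bool)
M[3, -1] = True

# What the model is allowed to see. Masked entries are zeroed out here, and M is
# passed alongside so a masked value can be told apart from a genuine zero.
X_observed = np.where(M, 0.0, X)
print(X_observed)
print(M.astype(int))
```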

I see that there was a question about handling inputs with different lengths. In the data sets we've considered, we'll get into it in the results section, but it's mostly been sort of tabular and image data, where the lengths for each of the data points is the same, but it would work just like padding.

That would be a reasonable way to go about that. There's also kind of an interesting-- yeah, go for it, Jannik. Just to add to that, I'm not sure if length refers to columns or to rows, right? Rows, we don't care about how many rows. Length, padding or something would be an option.

Yeah, my question was about column. Exactly, that makes sense. Awesome, thanks. Yeah, I mean, that goes along with the whole meta-learning discussion. I think if we wanted to adapt to data sets that have a different number of data points per data set, we can take advantage of the fact that self-attention is kind of OK with that.

Cool. So continuing on, left to discuss here is basically how we do the embedding. So to put this more explicitly, we have this data matrix that has n data points. It's called X, and it also has d attributes. And we have this binary mask matrix M. We're going to stack them, and then we're going to do a linear embedding.

So specifically, we're doing the same linear embedding independently for each data point. We're learning a different embedding for each attribute. We have a positional encoding on the index of the attributes because we don't really care about, say, being equivariant over the columns. If it's tabular data, you, of course, want to treat all these kind of heterogeneous columns differently.

And then finally, we have an encoding on the type of column, so whether or not it's continuous or categorical. And that ends up giving us this input data set representation that is dimensions n by d by e. The second key component of NPTs is attention between data points. So to do that, we first take this representation we have and flatten to an n by d times e representation.
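The stacking and per-attribute linear embedding described above might look roughly like the following. This is a simplified sketch: every attribute is treated as a single scalar (no one-hot encoding for categorical columns), the weights are random rather than learned, and all variable names are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, e = 6, 4, 8                               # data points, attributes, embedding size

X = rng.normal(size=(n, d))
M = rng.random(size=(n, d)) < 0.15              # hypothetical mask
X_in = np.where(M, 0.0, X)

# Stack values and mask flags: each entry becomes a (value, is_masked) pair.
stacked = np.stack([X_in, M.astype(float)], axis=-1)        # (n, d, 2)

# A different linear embedding per attribute, shared across all data points.
W = rng.normal(size=(d, 2, e)) * 0.1
b = np.zeros((d, e))
# Encodings for the attribute index and the attribute type.
index_emb = rng.normal(size=(d, e)) * 0.1
is_categorical = np.array([0, 0, 1, 0])                     # hypothetical column types
type_emb = rng.normal(size=(2, e)) * 0.1

H = np.einsum("ndc,dce->nde", stacked, W) + b + index_emb + type_emb[is_categorical]
print(H.shape)                                  # n by d by e
```

Because the same weights are applied to each row independently, embedding one data point on its own gives the same result as embedding it as part of the full matrix.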

So basically, we're treating each of these d times e size rows as if it's a token representation. We're actually going to just accomplish this operation using multi-head self-attention. We've reviewed this a lot, but the nice thing is that we know from language modeling, if we stack this multiple times, we can model these higher order dependencies.
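A single-head sketch of this step: each data point's n-by-d-by-e slice is flattened into one token of size d times e, and ordinary self-attention is run across the n data points. The projection matrices are random stand-ins for learned weights, and the single head is a simplification of the multi-head layer mentioned above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def self_attention(tokens, Wq, Wk, Wv):
    q, k, v = tokens @ Wq, tokens @ Wk, tokens @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = softmax(scores, axis=-1)
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d, e = 6, 4, 8
H = rng.normal(size=(n, d, e))              # embedded data set (random stand-in)

tokens = H.reshape(n, d * e)                # one token per data point
h = d * e
Wq, Wk, Wv = (rng.normal(size=(h, h)) / np.sqrt(h) for _ in range(3))

out, weights = self_attention(tokens, Wq, Wk, Wv)
H_next = out.reshape(n, d, e)
print(weights.shape)                        # one row of weights per data point
```

The attention map here is n by n: row i says how much data point i draws on every other data point in the input.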

And here they're between data points, and that's really the key draw of this architecture. There's been other kind of instances of people using attention for similar sorts of things. So for example, like attentive neural processes. A lot of times, they've sort of used just a single layer as kind of a representational lookup.

And we believe that this actually ends up limiting expressivity, and that by stacking this many times, you can learn more complex relationships between the data points. Neil, we also have some questions. So you can go ahead first. Oh, cool. Thanks. I have a question about how you guys do the embedding.

Is there always part of these convolutional filters or linear layers? What is the type of embedding that you guys use? Yeah, so I'm attempting to go back to the slide. I think it's not very happy with me right now. But yeah, so for tabular data, we did just linear embeddings, actually.

So we could get into, I guess, details of featurization for categorical and continuous, but it's literally like, say, for categorical, you do a one-hot encoding, and then you learn this embedding that is specific to that attribute. And then for numerical, I believe we were just standard normalizing. For the image data, we did end up using a ResNet-18 encoder for CIFAR-10.

However, I think that-- I mean, we'll discuss that a bit later in results, but that embedding is a bit arbitrary. You can sort of do whatever. The key part of the architecture is the attention between data points. So in terms of how you actually want to embed each attribute, it's kind of up to you.

Thanks. I think-- you had a question? Same question, I was victorious. OK, awesome. Cool. So here we have attention between data points done. So we can also do this attention between attributes. So we reshape back to this n by d by e representation, and then we can just apply self-attention independently to each row, in other words, to a single data point.

And the intuition for why we would kind of do this nested-type idea where we switch between attention between data points and attention between attributes is just we're trying to learn better per-data-point representations for the between-data-point interactions. This is literally just normal self-attention, as you'd see in language modeling or image classification.

The attributes are the tokens here. And finally, we just rinse and repeat. So what are we actually getting out of this? To summarize, we're learning higher-order relationships between data points. We're learning transformations of individual data points. And then importantly, NPT is equivariant to a permutation of the data points.
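The alternation and the equivariance claim can be checked numerically with a stripped-down sketch. The attention here has no learned projections, layer norm, or residual connections, so it is only a stand-in for the real layers; the function names are made up for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def attend(tokens):
    # Projection-free self-attention over the rows of `tokens` (simplification).
    scores = tokens @ tokens.T / np.sqrt(tokens.shape[-1])
    return softmax(scores, axis=-1) @ tokens

def attention_between_datapoints(H):
    n, d, e = H.shape
    return attend(H.reshape(n, d * e)).reshape(n, d, e)

def attention_between_attributes(H):
    # Applied independently to each data point; the d attributes are the tokens.
    return np.stack([attend(row) for row in H])

def npt_block(H):
    return attention_between_attributes(attention_between_datapoints(H))

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 3, 4))
perm = rng.permutation(5)

out = npt_block(H)
out_perm = npt_block(H[perm])
# Shuffling the input rows just shuffles the output rows in the same way.
equivariant = np.allclose(out_perm, out[perm])
print(equivariant)
```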

This basically just reflects the intuition that the learned relationships between the data points should not depend on the ordering in which you receive them or in which you observe your data set. The third key component of NPTs is a masking-based training objective. So recall that what we're trying to do is we're trying to predict missing entries from observed entries.

And those masked values can be both features or targets. So again, the classic use, say, in masked language modeling is to do self-supervised learning on a sequence of tokens, which you could think of as just having features in our setting. Ours is a bit different in that we do stochastic feature masking to mask feature values with a probability p sub feature.

And then we also do this masking of training targets with this probability p sub target. So if we write out the training objective, we are just taking a weighted sum of the negative log likelihood loss from targets as well as from features. And of course, at test time, we're only going to mask and compute a loss over the targets of test points.
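A hedged sketch of the training-time masking and the weighted loss: features and targets are masked independently with their own probabilities, and the objective is a weighted sum of the two loss terms. Squared error stands in for the negative log likelihood, and the weight `lam`, its (1 - lam, lam) form, and the probability values are assumptions for illustration.

```python
import numpy as np

def sample_mask(n, d, p_feature, p_target, rng):
    # Last column is the target; all other columns are features.
    M = np.zeros((n, d), dtype=bool)
    M[:, :-1] = rng.random((n, d - 1)) < p_feature
    M[:, -1] = rng.random(n) < p_target
    return M

def masked_objective(pred, X, M, lam):
    # Loss is only computed at masked entries. Squared error is a stand-in
    # for the negative log likelihood (assumption).
    sq_err = (pred - X) ** 2
    feat = sq_err[:, :-1][M[:, :-1]]
    targ = sq_err[:, -1][M[:, -1]]
    feature_loss = feat.mean() if feat.size else 0.0
    target_loss = targ.mean() if targ.size else 0.0
    return (1.0 - lam) * target_loss + lam * feature_loss

rng = np.random.default_rng(0)
n, d = 100, 5
X = rng.normal(size=(n, d))
M = sample_mask(n, d, p_feature=0.15, p_target=0.5, rng=rng)

loss_perfect = masked_objective(X, X, M, lam=0.5)
loss_off_by_one = masked_objective(X + 1.0, X, M, lam=0.5)
print(loss_perfect, loss_off_by_one)
```

At test time the same function could be used with a mask that covers only the targets of test rows, matching the description above.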

So to break this down a bit further and point out some of the cool parts of it here, the thing that's highlighted right now on the far right is the term relating to the features. It's the feature masking. Basically, we find that this has a nice regularizing effect. More or less, the model can now predict anywhere and makes the task a bit harder and introduces some more supervision.

And we found in an ablation for the tabular data sets that it helped for eight of 10 of those. And then there's this other term, which is kind of interesting. It's the stochastic target masking. And the idea is that you're actually going to have some training targets unmasked to the model at input at training time, which means that the NPT can learn to predict the masked targets of certain training data points using the targets of other training data points as well as all of the training features.

And so that means you don't actually need to memorize a mapping between training inputs and outputs in your parameters. You can instead devote the representational capacity of the model to learn functions that use other training features and targets as input. So this is kind of getting into the idea of this sort of like learned KNN idea.

Obviously, we can learn more complex relational lookups and those sorts of things from this. But you can imagine one such case being we have a bunch of test data points coming in. We're going to look at their features and use that to assign them to clusters of training data points.

And then our prediction for those points is just going to be an interpolation of the training targets in that respective cluster. That's like an example of something that this mechanism lets NPTs learn. All right. So if there's any questions, we can take them now. Otherwise, I'm happy to take them in the discussion or something.
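The cluster lookup example can be written out by hand to show the kind of function involved. This is a hard-coded, distance-based soft lookup on synthetic data, not something extracted from a trained NPT; the cluster locations, targets, and the `temperature` parameter are all made up.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def lookup_predict(X_test, X_train, y_train, temperature=1.0):
    # Weight each training point by feature similarity to the test point,
    # then interpolate the training targets with those weights.
    sq_dists = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    weights = softmax(-sq_dists / temperature, axis=-1)
    return weights @ y_train

rng = np.random.default_rng(0)
# Two well-separated clusters of training points with different targets.
cluster_a = rng.normal(loc=0.0, scale=0.1, size=(10, 2))
cluster_b = rng.normal(loc=10.0, scale=0.1, size=(10, 2))
X_train = np.vstack([cluster_a, cluster_b])
y_train = np.array([1.0] * 10 + [5.0] * 10)

X_test = np.array([[0.1, -0.1], [10.0, 10.0]])
preds = lookup_predict(X_test, X_train, y_train)
print(preds)
```

Each test point effectively lands in one cluster and inherits that cluster's training targets, which is the behavior the example in the talk describes.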

All right. So let's discuss. Yeah, go for it. I'm curious when you're using the entire data set, does that limit the type of data sets you can use because of the size? Yeah. So in practice, we do random mini batching as an approximation. So the idea is just to say, if you have a reasonably large mini batch, you're going to benefit a bit from still having kind of this lookup ability.

Because if you have a reasonable number of classes, probably you're going to be able to learn some interesting mappings based on features and targets amongst those classes. We found in practice that-- and we'll get into this a little bit. But we do actually indeed learn to use relationships between data points on prediction for data points where we're doing mini batching.

And we also didn't necessarily find that you need like a ludicrously large batch size for this to be a thing. But I do think it's just-- this is, in general, an important point. And it's one that points us towards looking into, say, sparse transformers literature for trying to expand to some larger data sets without having the mini batching assumption.

Great. Thank you. If I can add a number to that, we can, without mini batching, accommodate data sets of around like 8,000 points or so. So that already accounts for a fair proportion, I would say, of the tabular data sets out there. But we also do data sets with 11 million points where, obviously, we then resort to mini batching.

So it's very good to have an idea of the sizes that we're talking about. I'm curious on that. I mean, it's pretty exciting. I feel like you don't normally hear about transformers being applied to data sets of size 8,000. I'm curious-- and we can talk about this sort of later once we've covered the other material-- if you found that sample efficiency is one of the key gains here, or just experience working on small data sets of transformers generally.

And yeah. But I'm happy to punt the answer to that until after as part of the discussion. Yeah, I think that'd be really nice to talk about a bit. And it was something that, in general, I guess I'd say was surprising to us in terms of how robust NPTs were on small data sets, and how we surprisingly didn't have to tune a terrible number of parameters.

But we can get into details in a bit. Awesome. So to get into the experiments, we focused a lot on tabular data because it's a very general setting. And it's also notoriously challenging for deep learning. So we know tree-based boosting methods, stuff like XGBoost, is very dominant. And this is also a very relevant domain to, I think, people in industry and that sort of thing.

So we were excited about the idea of trying to do better on this. So we chose a broad selection of data sets varying across a few different dimensions. As we mentioned, on the order of hundreds to tens of millions of instances, broad range in the number of features, in the composition of features in terms of being categorical or continuous, various types of tasks, binary and multi-class classification, as well as regression.

And like I said, the baselines were kind of the usual suspects for tabular data-- XGBoost, CatBoost, LightGBM, tuned MLPs, and TabNet, which is a transformer architecture for tabular data. So to get into the results, here I'm showing the average rank for the various subtasks. We did well in terms of rank-wise performance against methods like CatBoost and XGBoost, which are designed specifically for tabular data.

And in fact, we find that NPT is the top performer on four of the 10 of these data sets. On image data, I mentioned that we used a CNN encoder. And with that, we were performing well on CIFAR-10. And we also think that, in general, with, let's say, new work on image transformers on small data, this can probably just be done with linear patching.

And so this kind of-- the manner in which you're embedding things is probably not the key. Neil, if I can jump in with two questions. Can you go back two slides first? One is just a small, minor point. Back one more, please. Thank you. Here are the features. 50 plus.

What does plus mean here? I'll have to double-check what the exact number is. I'm pretty sure it's probably around 50, I would guess. Ah, so the 50 is really an order of. It's not like 150 or 5,000. Yes. Yeah. I mean, I'll double-check for you, or you can check with the metadata statistics at the end of the paper.

But no, it wasn't arbitrarily large. I would say, though, we did these ablations on whether or not we actually need attention between attributes. We did find that this ended up benefiting us. But you could perhaps do, say, just an MLP embedding in that dimension and go to a relatively small number of hidden dimensions and fit an arbitrary number of features.

So I think that, yeah, if you kind of relax the necessity of attention between attributes, you can probably scale out at least that dimension quite a lot. OK. And then my second question, if you could go forward one slide. Thank you. Here, I'm not sure I quite caught it. What do four of 10 data sets, two of 10 data sets, and four of 10 mean?

This is out of all the tabular data sets that we had. So. Oh, I see. OK. Yeah, exactly. Awesome. Any other questions? The standard errors here, because I mean, there's just 10 data sets, right? Yeah, correct. 10 total tabular data sets. Yeah. But these are, right.

Yeah, these are rank-wise performance. Correct. OK, I'm just seeing where the uncertainty comes from in this case. Yeah, averaged over four of 10 data sets, the rank. So for each particular data set, we have a rank of all the different methods. Then we take the average and the variance of the rankings within each of the types of task, within binary classification, within multiclass, et cetera.
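The rank aggregation described here can be sketched in a few lines of NumPy. The method names and scores below are made up for illustration and are not results from the paper. The sketch assumes higher scores are better (as with AUROC) and that there are no ties.

```python
import numpy as np

# Hypothetical methods and made-up scores (rows: data sets, columns: methods).
methods = ["method_a", "method_b", "method_c"]
scores = np.array([
    [0.90, 0.85, 0.80],
    [0.70, 0.75, 0.60],
    [0.95, 0.90, 0.99],
    [0.80, 0.82, 0.81],
])

# Rank methods within each data set: rank 1 is the best (highest) score.
order = np.argsort(-scores, axis=1)
ranks = np.argsort(order, axis=1) + 1

# Aggregate across data sets: mean rank and its standard error per method.
mean_rank = ranks.mean(axis=0)
std_err = ranks.std(axis=0, ddof=1) / np.sqrt(ranks.shape[0])

for name, m, s in zip(methods, mean_rank, std_err):
    print(f"{name}: mean rank {m:.2f} +/- {s:.2f}")
```

A lower mean rank is better, which is why the best entry in such a table is the smallest number. For a metric where lower is better, such as RMSE, the sign flip in `np.argsort(-scores, ...)` would be dropped.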

We also, if you're curious, you know, have the full results in the paper. Yeah. Thank you. We also have a couple of questions. Some of the questions on the. Hey, yeah, thanks. I guess I just found it a little surprising that the worst performer was KNN, given that it's also nonparametric. I guess, could you comment on that?

And is, yeah, is it that there's something like intrinsic to the NPT that makes it just exceptional far beyond other nonparametric methods? Or yeah, why is it that KNN performs the worst here? Well, I suppose ultimately KNN is still a relatively naive predictive method in that, you know, it might just be predicting based on kind of cluster means.

So for example, you know, I think this is probably universally true for all the data sets, but there's probably some amount of kind of additional reasoning that needs to occur over the features, at least at a basic level. So for example, like one of the data sets is this poker hand data set where it's like a mapping between all of the different hands you have in poker and what they're commonly known as to people, like full houses or whatever.

So this, this requires some amount of reasoning over the features to be able to group things together. So just taking like the cluster means of the featurization of those different, you know, hands is likely not going to give you a great predictive function. Whereas NPTs can kind of do the classic thing where say you have an MLP type of thing over the features or like a tree type of thing over the features, you can learn some sort of complex embedding, but then you also can do some nonparametric sort of prediction based on say like clusters of embeddings.

I see. Yeah, that makes sense. I guess, what if you used pre-trained embeddings from a stack of encoders as your vector representation for the KNN? How do you think that would perform compared to the rest of the crowd? Yeah. So this is like, I mean, this idea is kind of like deep kernel learning or like, yeah, I believe it is deep kernel learning, where basically you use an MLP independently.

So you learn an MLP on each input data point, and then you apply a GP over all the representations of those. So you get this sort of like complex embedding and then the lookups. The key difference between that type of idea and NPTs is that we also learn the relationships between the data points themselves, because we use this parametric attention mechanism to learn the relationship.

So we're not just learning like an embedding independently. We're basically backpropagating through the entire process, learning the ways in which we would try to embed this, but also the ways that, say, the lookup would occur, and essentially the relationships, which could potentially be kind of higher order as well.

Okay, cool. Wait, one more, one more follow-up or. Oh yeah, go for it. Cool. Yeah. Thanks. So I guess then if, if the advantage of NPT has to do with sort of the relationships between data points, then what if you, you know, took the, took the, let's say, you know, encoder representations and then you pass that as input, say for the, you know, 10 nearest neighbors, along with like some other like input representation and sort of had this like weighted average like attention style where you, you weighted the, the vectors of the nearest neighbors based on the attention weights between those input data points.

And then like the supplied input data point, and then like pass that as, you know, the vector to like the final prediction layer, like, do you think that captures some amount of the relationship or is that off base? So I think the nice part, like, and really what our idea is behind this whole thing is just that these sorts of instances where certain fixed kernels would perform particularly well in tasks are like kind of an annoyance, and like ultimately tuning a lot of these types of things or trying to derive the predictive methods that might make a lot of sense for a given situation kind of stinks.

And ideally you'd want to just backpropagate on a data set and kind of learn these relationships yourself. So I actually would be really interested to see if we can come up with some synthetic experiments that have these sort of like very particular KNN-like predictive mechanisms and just see if we can learn precisely those and get, you know, zero error with NPTs.

And in fact, like we'll get into this a little bit with some of the interventional experiments we do, we have like kind of precise lookup functions that NPTs end up being able to learn, so we can learn interesting relational functions. Cool. Yeah. Thanks a lot. Appreciate it. Cool. All right.

I have one more question from. Yeah, sure. I just wanted to clarify something about basically so at test time, you just take the exact same data set and you just like add like your test examples, right? And then you like do the same type of like masking. And is that how it works?

Yeah, correct. OK, got it. And I do have one more question. That is just because I think I misunderstood the effects of your NPT objective. Do you mind going back to that slide? Sure. Yeah. Can you repeat one more time? Like what makes this so special?

Yeah. So the regularizer on the right over the features, I would think of very similarly to self-supervised learning with just a standard transformer. Like, you're basically just introducing a lot more supervision, and even if, say, you're just doing a supervised objective, this is kind of like some amount of reconstruction over the features.

You learn a more interesting representation and get like a regularizing effect, which we think is interesting, but perhaps not as interesting as this stochastic target masking. This one is unique because in kind of standard parametric deep learning, you're not going to have an instance in your training process where you're taking targets as input.

And so basically what happens is you have your training data set as input, whatever, you're going to have some stochastic feature masking stuff happening on the features. Amongst the training targets, you're randomly going to have some of those unmasked and some of them will indeed be masked. You're going to be backpropagating a loss on the ones that are masked, of course, because, you know, you don't want your model to have those available at input if you're going to actually try to backpropagate a loss on it.

But you can use the other ones as input. And that means you can learn these kind of like interpolative functions. So that was like this whole idea of like being able to kind of learn KNN. But doesn't that allow the model to cheat again? Yeah. So this is like an interesting point and actually like subtle.

So I think it's really worthwhile to bring up. So first of all, we never actually backpropagate a loss on something that was visible to the model at input. And so if, for example, the model did actually end up basically overfitting on training labels, we would not observe the model's ability to generalize to test data.

We don't observe this. So obviously, it seems like this kind of blocking of backpropagation on labels that are visible at input to the NPT is helping. It could also be possible that in BERT style stochastic masking, you also randomly will flip some labels to be in a different category.

So this is like kind of just like a random fine print that was introduced in the BERT masking text. We also do that. So it's possible that that somehow contributes to that. But it's probably pretty likely to just be the fact that we're not backpropagating a loss on something that's visible.
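The key point in this answer, that a loss is only ever computed on targets hidden from the model at input, can be sketched as follows. This is a simplified illustration under stated assumptions, not the authors' implementation: the helper names `mask_targets` and `masked_mse`, the zero fill value, and the extra mask-indicator column are all choices made here for the example.

```python
import numpy as np

def mask_targets(y, p_mask, rng):
    """Randomly hide a fraction of targets; return the model input and the mask."""
    mask = rng.random(y.shape[0]) < p_mask      # True where the target is hidden
    y_visible = np.where(mask, 0.0, y)          # hidden targets are zeroed out
    # Assumed input format: [target value, mask indicator] per row.
    model_input = np.stack([y_visible, mask.astype(float)], axis=1)
    return model_input, mask

def masked_mse(pred, y, mask):
    """Mean squared error computed only on targets that were hidden at input."""
    if not mask.any():
        return 0.0
    return float(((pred[mask] - y[mask]) ** 2).mean())

rng = np.random.default_rng(0)
y = np.arange(10, dtype=float)
model_input, mask = mask_targets(y, p_mask=0.5, rng=rng)

# Predictions that are wildly wrong on visible targets do not affect the loss.
pred = y.copy()
pred[~mask] += 100.0
loss = masked_mse(pred, y, mask)
```

Because visible targets never contribute to the loss, simply copying a visible label through to the output earns the model nothing, which is the safeguard against "cheating" discussed in this exchange.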

Great. Thanks. Makes sense. I have two more questions if I can jump in. Sure. Sorry. Can we go to the metrics, the performance, the results slide? Sure. I feel like I missed something else. I'm sorry about this. So looking on the binary classification AUROC, can you clarify what these numbers mean?

Are they the AUROC? So this is the, so on each of the data sets, so say for a particular binary classification data set, we're going to get a ranking of the methods. We're going to repeat this. Yeah. Go for it. So these numbers here are the relative ranking across in this particular case, the four data sets.

Correct. Yeah. I see. So these values are not the AUROCs on average across the data sets. No. Yeah. They're not. I mean like averaging AUROC might make sense, but averaging things like accuracy and RMSE seems like a bad idea, right? Because you might have some data sets where everything has high accuracy or where RMSE means something drastically different.

I see. So these numbers here only tell us the relative ranking between the different methods, not how well they actually perform. I mean, it tells us how they perform relative to one another, but not how well they perform. I see. Okay. But that's all in the appendix. We have that information.

I see. Okay. I was, I was sitting here confused going like, why is AUROC, why is the best one the smallest? And accuracy, what is an accuracy of 2.5? Anyways. Okay. That makes much more sense. Thank you both. Awesome. Great. So I'll try to speed through this just in the interest of time.

But basically, the thing that you might be thinking after all of these results is, are we even learning any data point interactions on these real data sets? And so basically we designed an experiment to figure this out. And the idea is that we're going to disallow NPT from using other data points when predicting on one of them.

If we do that and we observe that NPT actually predicts or performs significantly worse, it is indeed using these interactions between data points. A subtle challenge or it's kind of like an added bonus we can get from this is that ideally we wouldn't actually break batch statistics. So let's say like the mean of each particular attribute.

If we can find a way to do this experiment such that we don't break these things, we can kind of rule out the possibility that we learned something that's a bit similar to batch norm. And so the way that we do this is we basically look at the predictions for each one of the data points in sequence.

So let's say in this case, we're looking at the prediction of the model for this particular green row. And it's going to be predicting in this last column that has this question mark which is masked. What we're going to do is we're going to permute each of the attributes independently amongst all other data points except for that one.

So the information for that row, if it was kind of just predicting like a classic parametric deep model is still intact, but the information from all of the other rows is gone. So that's why we call this sort of the corruption experiment. And so we find in general, when we perform this experiment, performance kind of falls off a cliff for the vast majority of these methods.
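The corruption procedure can be sketched with NumPy as below. This is a minimal illustration of the idea as described in the talk, with a hypothetical helper name `corrupt_other_rows`. For one chosen row, every column is shuffled independently among all the other rows, which destroys the information in those rows while keeping per-column statistics intact.

```python
import numpy as np

def corrupt_other_rows(X, i, rng):
    """Permute each attribute independently among all rows except row i."""
    X_corrupt = X.copy()
    others = np.array([r for r in range(X.shape[0]) if r != i])
    for j in range(X.shape[1]):
        # A fresh permutation per column breaks the pairing between attributes.
        X_corrupt[others, j] = X[rng.permutation(others), j]
    return X_corrupt

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))
X_corrupt = corrupt_other_rows(X, i=2, rng=rng)
```

Row `i` is untouched, so a model that predicts purely parametrically from that row should be unaffected. Each column still contains the same set of values, so statistics such as the column mean are preserved, which is what rules out a batch-norm-like explanation.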

And I'll note that the performances between the methods on a lot of these were fairly close. And so this is actually indeed pretty significant. So for example, on protein, we went from being the top performer amongst all the methods to the worst performer worse than even like KNN or something like that.

I'll also note that there's kind of this interesting behavior where on these data sets like forest and kick and breast cancer, we actually observed that there's basically no drop in performance. And we basically see this as kind of an interesting feature and not necessarily a bug of the model, which is that if we're backpropagating on a given data set, the model can sort of just find that it's actually not that worthwhile to attempt to predict using some kind of relational predictive mechanism amongst data points and can instead just learn to predict parametrically and basically ignore other data points when it's predicting on any given one of them.

And so this probably leads to some kind of like interesting ideas where perhaps you could do like post hoc pruning or something like that, taking away the attention between data points and doing fine tuning, let's say. Alright, so now I'll hand over to Yannick to talk a bit about learning some interesting relationships.

Yeah, will you though? I see that we're at the end of what the time is, but like, I know there's a buffer planned in or something. I can go through this experiment, we can have a bit of discussion. What do you guys prefer? Yeah, I think normally what we do is we would sort of stop the recording at this point and have an off the record discussion.

And I guess the question to ask is, does anyone have any questions at this point? But I think we've basically been taking questions as they come. So I personally feel fine just considering this as sort of questions throughout. Yeah, I guess that sounds good. Yannick, you can go forward with it, with your talk as planned and later we can see about the time thing.

I think this will only be like another four or five minutes talk. Yeah, yeah, that's good. Then go for it, yeah, for sure. Alright, so Neil has now told us how well NPTs perform in real data and that they do make use of information from other samples of the input.

But we're now going to take this a bit further and come up with some toy experiments that test the extent to which NPTs can learn to look up information from other rows, like the extent to which they can learn this nonparametric prediction mechanism. And so specifically what we'll do is we'll create the following semi-synthetic data set.

I want you to focus on A now. So we'll take one of the tabular data sets that we've used previously, specifically the protein data set, but it doesn't really matter. What matters is that it's a regression data set. And so now what we do is we, the top half here is the original data set, but the bottom half is a copy of the original data set where we have unveiled the true target value.

So now NPTs could learn to use attention between data points to achieve arbitrarily good performance. They could learn to look up the target values in these matching duplicate rows and then paste them back into that masked out target value. And then at test time, of course, we put in a novel test data input where this mechanism is also possible just to make sure that it hasn't learned to memorize anything, but has actually learned this correct relational mechanism.
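To make the construction concrete, here is a small sketch of the duplicated data set together with a hand-written version of the lookup that NPTs would have to discover on their own. The function names, the zero fill for masked targets, and the exact-match nearest-row rule are assumptions for illustration. The real experiment uses the protein data set and a learned attention mechanism.

```python
import numpy as np

def make_duplicate_dataset(X, y):
    """Stack the data set on a copy of itself; only the copy reveals targets."""
    n = X.shape[0]
    features = np.vstack([X, X])
    target_in = np.concatenate([np.zeros(n), y])  # top half: targets hidden
    is_masked = np.concatenate([np.ones(n, dtype=bool), np.zeros(n, dtype=bool)])
    return features, target_in, is_masked

def lookup_predict(features, target_in, is_masked):
    """Hand-coded mechanism: find the matching duplicate row, copy its target."""
    visible = np.flatnonzero(~is_masked)
    preds = np.empty(int(is_masked.sum()))
    for k, i in enumerate(np.flatnonzero(is_masked)):
        d2 = ((features[visible] - features[i]) ** 2).sum(axis=1)
        preds[k] = target_in[visible[np.argmin(d2)]]
    return preds

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
y = rng.normal(size=5)
features, target_in, is_masked = make_duplicate_dataset(X, y)
recovered = lookup_predict(features, target_in, is_masked)
```

With this mechanism the masked targets are recovered exactly, which is why the setup allows arbitrarily good performance if the model learns to attend to the matching row.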

And so what we see is that indeed, NPTs do successfully learn to perform this lookup. So what I'm visualizing here is attention maps, and they very clearly show that, let's say when predicting for this green row here, this first green row, what NPTs look at is exactly only that other green row here.

And so this is really nice. We can further look at the Pearson correlation between what NPTs should predict and what they actually do predict. And so this is 99.9%. This is much better than anything you could achieve with parametric prediction. And so it seems that NPTs here can actually discover this mechanism.

And discover here, I feel like it's the right word because NPTs could have, as we've seen, just also continue to predict in parametric fashion, right, from each row independently. This is really kind of showing to us that there is this bias in the model to learn to predict from other rows.

And of course, that is also very attractive in this setting because it allows you to achieve arbitrarily low loss, or as low as you can optimize for it. And so we kind of take that to mean that our, you know, gradient-based discovery, non-parametric philosophy seems to make some sense.

And so we can take this a bit further by performing somewhat of an interventional experiment that investigates the extent to which NPTs have actually learned a robust, you know, causal mechanism that's underlying this semisynthetic data set. And so just appending, you know, this extra column of test data, that's already kind of cool, but I think we can take it a bit further and actually study if this generalizes beyond the data that we see in the training set or beyond data coming from this specific distribution.

And so what we now do is we intervene on individual duplicate data points at test time by varying their target value. So now we only care about the prediction in a specific row. We do this across all rows, but at each time we just cover a single row. What we do is we change the target value here, and what we're hoping to see is that NPT just adjusts the prediction as well, right?

This is a very simple intervention experiment for us to test if NPTs have actually learned this mechanism. And to some extent it also tests robustness because now we're associating target values with features that are not part of the training distribution here. And so what we see is that as we adjust these values here, this is kind of the duplicate value.

And then we here see the target value. As we adjust them, we can see the correlation stays really, really good. It's not quite 99.9%, like on average, we're now at 99.6, but it's still very, very good. And at this point you might be slightly annoyed with me because standard nonparametric models can also solve this task.

This is a task that I could solve by nearest neighbors. Sure, maybe I would have to change the input format a bit because this is kind of like in a batch setting and I could just use masks, but most generally a nearest neighbor can also, it also looks up different input points based on their features.

Nearest neighbor doesn't learn to do this. I still think it's cool that we need to learn this because it does require a decent amount of computational sequences that we have to learn, like match all the features, look up target value, copy it back and so on. But it is in fact very easy for us to complicate this task to a degree such that essentially no other model that we know of can solve this very easily.

And so a really simple thing to do is just to add a plus one to all of the duplicate values. So now nearest neighbor would look up the right row, of course, but it would always predict the wrong target with a plus one on it. And in fact, many of the models that we're aware of, they're not modeling the joint distribution over features and targets.

What they're modeling is the conditional distribution of the targets given the input features. And so they also cannot do this. And so for us it's really not a problem at all. NPTs will just learn to subtract another one and no problems. And sure, this is also still a very synthetic setting, but I do think, I mean, I challenge you to come up with something that NPTs can't solve, but the other models can solve.
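The plus-one variant can be illustrated with a short sketch. A plain nearest-neighbor lookup finds the right duplicate row but copies a target that is off by one, whereas a model that also sees true targets during training can learn the correction. Fitting a single scalar offset below is only a stand-in, assumed here for illustration, for what NPTs would learn by gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
y = rng.normal(size=50)
dup_targets = y + 1.0  # duplicate rows reveal target + 1 instead of the target

def nearest_duplicate_target(x, X_dup, t_dup):
    """Plain 1-nearest-neighbor lookup: copy the target of the closest row."""
    d2 = ((X_dup - x) ** 2).sum(axis=1)
    return t_dup[np.argmin(d2)]

copied = np.array([nearest_duplicate_target(x, X, dup_targets) for x in X])

# Split rows into those used to fit the correction and held-out rows.
train, test = np.arange(0, 40), np.arange(40, 50)
offset = (copied[train] - y[train]).mean()  # learned correction, close to 1.0

knn_error = np.abs(copied[test] - y[test]).max()
corrected_error = np.abs(copied[test] - offset - y[test]).max()
```

The uncorrected lookup is wrong by one on every held-out row, while subtracting the fitted offset brings the error to essentially zero.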

I think this, in general, this masking mechanism and the non-parametricity of the approach is really nice in general and leads to lots of nice behavior in a variety of settings. And so with that, I think we can go to the conclusions, which Neil is going to give you. Yeah, I think, I mean, we're going to cut out the main part here.

I'll just fast forward. Just look at them. Yeah, yeah. I was going to say, I think you'll get the gist. NPTs take the entire data set as input and they use self-attention to model complex relationships between data points. They do well in experiments on tabular data as well as image data.

We present some of these interventional experiments to show that they can solve complex reasoning tasks. There's some more experiments in the paper. I'd say that the interesting type of future work is scaling type things. So, you know, not having this mini-batching approximation, and then also just trying to expand this to some more interesting application domains.

So we talked a little bit about meta-learning, but it could also be things like, you know, few-shot generalization in general, domain adaptation, semi-supervised learning, et cetera. So I think if there's some more questions, maybe we can do some more discussion. Yeah. I think sounds good. Great. Thanks for the talk.

I think everyone had a fun time. I will just ask some general questions and then we can have like a discussion session with everyone after that. So I think one thing that I noticed is like this, like you said, this is similar to like KNNs and I thought like this seems similar to like graph neural networks where you can think like each data point is like a node and then you can think of everything as a fully connected graph and you're learning some sort of attention weight in this graph.

So this is like a node prediction task you are kind of doing on this sort of like graph structure. So any comments on that? Like, is it similar to like graph neural networks or is it like other differences? Yeah, this is a very good observation. Yeah, I think there are a lot of similarities to work on graph neural networks.

If we want to talk about differences, the differences might be that we're kind of assuming a fully connected graph, right? And so you could maybe also phrase that as we're discovering the relational structure, whereas graph neural networks usually assume that it's given. But that's also not always true.

And so there are a lot of similarities. I don't know, Neil, if there was something specific you would like to mention, go ahead. But it's a very good observation and we also do feel that that's the case. And we've added an extra section on related work to graph neural networks in the updated version of the paper that will be online soon.

Yeah, I agree with everything you've said. I think the closest work from the GNN literature that we were looking at a little bit was this neural relational inference paper, which uses message-passing neural networks to try to kind of like learn edges that may or may not exist and help for like extrapolating, I think, positions of like particles in like a multi-particle system or something, which is like kind of a similar idea to us.

Like, you know, if you don't have these edges as given, the attention mechanism could kind of approximate an interesting relationship amongst some interacting things. I see. Got it. Yeah, that's really cool. Another thing is like, so you mostly look on like tabular data, but can you also like have other modalities, like if you want to do language or something, can you still use non-parametric transformers?

Yeah, so I think part of our motivation for doing tabular was because we felt like tabular data is, in a sense, a generalization of, let's say, the language data, for example. I mean, I guess there's these other notions that people have brought up, like padding, but ultimately you can think of it as like a bunch of categorical attributes.

So it is definitely generalizable to things like sentences and we do, you know, images. So, yeah. I think actually like, I always go back and forth on whether or not I think smaller or larger data is more interesting for us. So I think small data is really interesting because we can just fit the entire data set into it and all of this just works out of the box without any extra thought.

But large data is actually also really interesting because, sure, you might have to introduce some approximative mechanism or some lookup mechanism because you can't always have the entire data set in. But at the same time, you are very explicitly kind of trading off the compute that you use to look up with the compute that you need to store.

Like how many parameters in GPT are used for storing data, right? There's lots of memorization happening in these models and we know that. And so maybe we can use the parameters more efficiently to learn lookup type behavior, right? That is more close to this, you know, neural KNN or whatever.

So I think these are very exciting questions. Yeah, yeah. I'll also be looking forward to the future works because it seems like a very good way to like do one-shot learning kind of situation. So, yeah, really very interesting to see that. Okay, so I will stop the recording and we can have like any other questions.

Okay, thank you.