Stanford CS25: V2 I Introduction to Transformers w/ Andrej Karpathy
Chapters
0:00 Introduction
0:47 Introducing the Course
3:19 Basics of Transformers
3:35 The Attention Timeline
5:01 Prehistoric Era
6:10 Where we were in 2021
7:30 The Future
10:15 Transformers - Andrej Karpathy
10:39 Historical context
60:30 Thank you - Go forth and transform
00:00:00.000 |
Hi everyone. Welcome to CS25 Transformers United V2. This was a course that was held 00:00:11.000 |
at Stanford in the winter of 2023. This course is not about robots that can transform into 00:00:15.720 |
cars, as this picture might suggest. Rather, it's about deep learning models that have 00:00:19.980 |
taken the world by storm and have revolutionized the field of AI and others. Starting from 00:00:24.460 |
natural language processing, transformers have been applied all over, from computer 00:00:27.880 |
vision, reinforcement learning, biology, robotics, etc. We have an exciting set of videos lined 00:00:33.560 |
up for you, with some truly fascinating speakers giving talks, presenting how they're applying 00:00:38.720 |
transformers to the research in different fields and areas. We hope you'll enjoy and 00:00:46.280 |
learn from these videos. So without any further ado, let's get started. 00:00:52.200 |
This is a purely introductory lecture, and we'll go into the building blocks of transformers. 00:00:59.000 |
So first, let's start with introducing the instructors. 00:01:03.600 |
So for me, I'm currently on a temporary deferral from the PhD program, and I'm leading AI at 00:01:08.000 |
a robotics startup, Collaborative Robotics, working on some general-purpose robots. 00:01:13.000 |
And yeah, I'm very passionate about robotics and building efficient learning 00:01:18.920 |
systems. My research interests are in reinforcement learning, computer vision, linear modeling, 00:01:23.800 |
and I have a bunch of publications in robotics and other areas. My undergrad was 00:01:29.080 |
at Cornell. Nice to meet you all. 00:01:33.960 |
So I'm Stephen, I'm a first-year CS PhD student. Previously did my master's at CMU, and undergrad 00:01:39.800 |
at Waterloo. I'm mainly into NLP research, anything involving language and text. But 00:01:44.760 |
more recently, I've been getting more into computer vision as well as multilingual NLP. And 00:01:49.080 |
just some stuff I do for fun, a lot of music stuff, mainly piano. Some self-promo, but 00:01:54.080 |
I post a lot on my Insta, YouTube, and TikTok, so if you guys want to check it out. 00:01:58.840 |
My friends and I are also starting a Stanford Piano Club, so if anybody's interested, feel 00:02:04.000 |
free to email me for details. Other than that, martial arts, bodybuilding, and huge fan of 00:02:14.880 |
Okay, cool. Yeah, so my name's Ryland. Instead of talking about myself, I just want to very 00:02:22.160 |
briefly say that I'm super excited to take this class. I took it the last time it was 00:02:27.920 |
offered, I had a bunch of fun. I thought we brought in a really great group of speakers 00:02:32.640 |
last time, I'm super excited for this offering. And yeah, I'm thankful that you're all here, 00:02:36.960 |
and I'm looking forward to a really fun quarter together. Thank you. 00:02:39.360 |
Yeah, so fun fact, Ryland was the most outspoken student last year, and so if someone wants 00:02:44.720 |
to become an instructor next year, you know what to do. 00:02:52.240 |
Okay, cool. So what we hope you will learn in this class is, first of all, how do transformers 00:03:00.320 |
work? How are they being applied? And nowadays, like, they are pretty much everywhere 00:03:06.960 |
in AI and machine learning. And what are some new interesting directions of research in this field? 00:03:15.680 |
Cool. So this class is just introductory, so we'll just be talking about the basics 00:03:21.280 |
of transformers, introducing them, talking about the self-attention mechanism on which 00:03:25.520 |
they're founded, and we'll do more of a deep dive on models like BERT, GPT, stuff 00:03:31.520 |
like that. So, great, happy to get started. Okay, so let me start with presenting the 00:03:38.160 |
attention timeline. Attention all started with this one paper, Attention is All You 00:03:43.760 |
Need, by Vaswani et al. in 2017. That was the beginning of transformers. Before that, 00:03:49.760 |
we had the prehistoric era, where we had models like RNNs, LSTMs, and their simple attention 00:03:56.800 |
mechanisms that didn't evolve or scale at all. Starting 2017, we saw this explosion 00:04:02.400 |
of transformers into NLP, where people started using it for everything. I even heard this 00:04:07.840 |
quote from Google, it's like, "Our performance increased every time we fired our linguists." 00:04:11.520 |
Then, from 2018 to 2020, we saw this explosion of transformers into 00:04:19.040 |
other fields, like vision, a bunch of other stuff, and biology with AlphaFold. And last year, 00:04:27.120 |
2021 was the start of the generative era, where we got a lot of generative modeling. 00:04:30.720 |
Starting with models like Codex, GPT, DALL-E, Stable Diffusion, so a lot of things happening 00:04:37.520 |
in generative modeling. And we started scaling up in AI. And now it's the present. So this 00:04:46.160 |
is 2022 and the start of 2023. And now we have models like ChatGPT, Whisper, a bunch of others. 00:04:54.560 |
And we are scaling onwards without slowing down. So that's great. So that's the future. 00:04:59.440 |
So going more into this, so once there were RNNs, so we had sequence-to-sequence models, LSTMs, 00:05:09.200 |
GRUs. What worked here was that they were good at encoding history. But what did not work was 00:05:15.440 |
that they couldn't encode long sequences, and they were very bad at encoding context. 00:05:19.440 |
So consider this example. Consider trying to predict the last word in the text, "I grew up in 00:05:27.920 |
France, dot, dot, dot. I speak fluent, dash." Here, you need to understand the context for it 00:05:33.600 |
to predict French. And attention mechanism is very good at that. Whereas if you're just using LSTMs, 00:05:39.600 |
it doesn't work that well. Another thing transformers are good at is content-based 00:05:47.040 |
context prediction, like finding attention maps. If I have something 00:05:54.480 |
like a word like "it," what noun does it refer to? We can put a probability, an attention weighting, on 00:06:02.400 |
the possible candidates. And this works better than existing mechanisms. 00:06:09.120 |
OK. So where we were in 2021, we were on the verge of takeoff. We were starting to realize 00:06:16.880 |
the potential of transformers in different fields. We solved a lot of long-sequence 00:06:21.680 |
problems like protein folding with AlphaFold, and offline RL. We started to see few-shot, 00:06:29.280 |
zero-shot generalization. We saw multimodal tasks and applications like generating images 00:06:34.160 |
from language. So that was DALL-E. Yeah. And it feels like ages ago, but it was only like two years 00:06:40.160 |
ago. And this is also a talk on transformers that you can watch on YouTube. Cool. And this is where 00:06:51.040 |
we were going from 2021 to 2022, which is we have gone from the verge of taking off to actually 00:06:57.200 |
taking off. And now we are seeing unique applications in audio generation, art, music, 00:07:01.920 |
storytelling. We are starting to see reasoning capabilities like common sense, logical reasoning, 00:07:08.160 |
mathematical reasoning. We are also now able to get human alignment and interaction. 00:07:14.000 |
Models are able to use reinforcement learning from human feedback; that's how models like ChatGPT are trained to 00:07:18.160 |
perform really well. We have a lot of mechanisms for controlling toxicity, bias, and ethics now. 00:07:24.400 |
And also a lot of developments in other areas like diffusion models. Cool. 00:07:30.640 |
So the future is a spaceship, and we are all excited about it. 00:07:36.400 |
And there's a lot more applications that we can enable. And it'd be great if you can see 00:07:45.600 |
transformers also work there. One big example is video understanding and generation. That is 00:07:50.080 |
something that everyone is interested in. And I'm hoping we'll see a lot of models in this 00:07:54.000 |
area this year. Also finance, business. I'll be very excited to see GPT author a novel, 00:08:01.920 |
but we need to solve very long sequence modeling. And most transformer models are still limited to 00:08:07.920 |
like 4,000 tokens or something like that. So we need to make them generalize much better 00:08:14.240 |
on long sequences. We also want to have generalized agents that can do a lot of multitask, 00:08:22.320 |
multimodal input predictions, like Gato. And so I think we will see more of that too. And finally, 00:08:30.640 |
we also want domain-specific models. So you might want like a GPT model that's good at, maybe, 00:08:38.480 |
your health. So that could be like a doctor GPT model. You might have like a lawyer GPT model 00:08:43.200 |
that's trained only on law data. So currently we have GPT models that are trained on everything, 00:08:47.440 |
but we might start to see more niche models that are like good at one task. And we could have like 00:08:52.480 |
a mixture of experts. It's like, you can think like, this is like how you normally consult an 00:08:56.640 |
expert. You'll have like expert AI models and you can go to a different AI model for your different 00:09:00.080 |
needs. There are still a lot of missing ingredients to make this all successful. The first of all is 00:09:09.840 |
external memory. We are already starting to see this with models like ChatGPT, where 00:09:15.840 |
the interactions are short-lived. There's no long-term memory, and they don't have the ability 00:09:20.320 |
to remember or store conversations for long term. And this is something we want to fix. 00:09:26.800 |
Second is reducing the computational complexity. So the attention mechanism is quadratic in 00:09:34.320 |
the sequence length, which is slow. And we want to reduce it or make it faster. 00:09:39.520 |
Another thing we want to do is we want to enhance the controllability of this model. 00:09:45.760 |
It's like a lot of these models can be stochastic, and we want to be able to control what sort of 00:09:50.560 |
outputs we get from them. And as you might have experienced with ChatGPT, if you just refresh, 00:09:55.200 |
you get like different output each time, but you might want to have a mechanism that controls 00:09:59.040 |
what sort of things you get. And finally, we want to align our state-of-the-art language models with 00:10:04.800 |
how the human brain works. And we are seeing the research, but we still need more work on seeing 00:10:10.320 |
how this can be done. Thank you. Great. Hi. Yes, I'm excited to be here. I live very nearby, 00:10:19.360 |
so I got the invites to come to class. And I was like, OK, I'll just walk over. 00:10:22.560 |
But then I spent like 10 hours on the slides, so it wasn't as simple. 00:10:26.880 |
So yeah, I want to talk about transformers. I'm going to skip the first two over there. 00:10:32.560 |
We're not going to talk about those. We'll talk about that one, 00:10:35.040 |
just to simplify the lecture since we don't have time. OK. So I wanted to provide a little bit of 00:10:41.440 |
context of why does this transformers class even exist. So a little bit of historical context. 00:10:45.840 |
I feel like Bilbo over there, like telling you guys about this history. 00:10:50.880 |
And basically, I joined AI in roughly 2012 in full force, 00:10:56.720 |
so maybe a decade ago. And back then, you wouldn't even say that you joined AI, by the way. That was 00:11:00.800 |
like a dirty word. Now it's OK to talk about. But back then, it was not even deep learning. It was 00:11:05.760 |
machine learning. That was a term you would use if you were serious. But now, AI is OK to use, 00:11:10.880 |
I think. So basically, do you even realize how lucky you are potentially entering this area 00:11:15.280 |
in roughly 2023? So back then, in 2011 or so, when I was working specifically on computer vision, 00:11:21.200 |
your pipelines looked like this. So you wanted to classify some images. You would go to a paper, 00:11:29.120 |
and I think this is representative. You would have three pages in the paper describing all kinds of 00:11:33.680 |
a zoo, a kitchen sink, of different kinds of features and descriptors. And you would go to a poster 00:11:38.560 |
session at a computer vision conference, and everyone would have their favorite feature 00:11:41.840 |
descriptors that they're proposing. It was totally ridiculous. And you would take notes on which one 00:11:45.200 |
you should incorporate into your pipeline, because you would extract all of them, and then you would 00:11:48.640 |
put an SVM on top. So that's what you would do. So there's two pages. Make sure you get your sparse 00:11:53.040 |
SIFT histograms, your SSIMs, your color histograms, textons, tiny images. And don't forget the 00:11:57.920 |
geometry-specific histograms. All of them had basically complicated code by themselves. So 00:12:02.640 |
you're collecting code from everywhere and running it, and it was a total nightmare. 00:12:05.680 |
So on top of that, it also didn't work. So this, I think, represents the predictions from 00:12:14.480 |
that time. You would just get predictions like this once in a while, and you'd be like, you just 00:12:18.480 |
shrug your shoulders like that just happens once in a while. Today, you would be looking for a bug. 00:12:23.680 |
And worse than that, every single chunk of AI had their own completely separate 00:12:32.480 |
vocabulary that they work with. So if you go to NLP papers, those papers would be completely 00:12:37.520 |
different. So you're reading the NLP paper, and you're like, what is this part of speech tagging, 00:12:42.320 |
morphological analysis, syntactic parsing, coreference resolution? What is an NP, 00:12:47.520 |
a VP, and all of that? So the vocabulary and everything was completely different, 00:12:51.280 |
and you couldn't read papers, I would say, across different areas. 00:12:53.920 |
So now that changed a little bit starting in 2012 when Alex Krizhevsky and colleagues 00:13:00.960 |
basically demonstrated that if you scale a large neural network on a large data set, 00:13:06.640 |
you can get very strong performance. And so up till then, there was a lot of focus on algorithms, 00:13:10.960 |
but this showed that actually neural nets scale very well. So you need to now worry about compute 00:13:14.880 |
and data, and if you scale it up, it works pretty well. And then that recipe actually did copy-paste 00:13:19.360 |
across many areas of AI. So we started to see neural networks pop up everywhere since 2012. 00:13:25.040 |
So we saw them in computer vision, in NLP, in speech, in translation, in RL, and so on. So 00:13:30.480 |
everyone started to use the same kind of modeling tool kit, modeling framework. And now when you go 00:13:34.720 |
to NLP and you start reading papers there, in machine translation, for example, this is a 00:13:39.120 |
sequence-to-sequence paper, which we'll come back to in a bit. You start to read those papers, and 00:13:43.360 |
you're like, OK, I can recognize these words, like there's a neural network, there's a parameter, 00:13:46.960 |
there's an optimizer, and it starts to read things that you know of. So that decreased 00:13:52.240 |
tremendously the barrier to entry across the different areas. And then I think the big deal 00:13:57.840 |
is that when the transformer came out in 2017, it's not even that just the toolkits and the 00:14:02.480 |
neural networks were similar, it's that literally the architectures converged to one architecture 00:14:06.960 |
that you copy-paste across everything seemingly. So this was kind of an unassuming machine 00:14:12.560 |
translation paper at the time proposing the transformer architecture, but what we found 00:14:15.680 |
since then is that you can just basically copy-paste this architecture and use it everywhere, 00:14:21.520 |
and what's changing is the details of the data and the chunking of the data and how you feed it in. 00:14:26.560 |
And that's a caricature, but it's kind of like a correct first-order statement. 00:14:29.920 |
And so now papers are even more similar looking because everyone's just using 00:14:33.760 |
transformer. And so this convergence was remarkable to watch and unfolded over the 00:14:39.440 |
last decade, and it's pretty crazy to me. What I find kind of interesting is I think this is 00:14:44.560 |
some kind of a hint that we're maybe converging to something that maybe the brain is doing, 00:14:47.760 |
because the brain is very homogeneous and uniform across the entire sheet of your cortex. 00:14:52.880 |
And okay, maybe some of the details are changing, but those feel like hyperparameters of a 00:14:56.640 |
transformer, but your auditory cortex and your visual cortex and everything else looks very 00:15:00.320 |
similar. And so maybe we're converging to some kind of a uniform, powerful learning algorithm 00:15:04.880 |
here, something like that, I think is kind of interesting and exciting. 00:15:08.000 |
Okay, so I want to talk about where the transformer came from briefly, historically. 00:15:12.880 |
So I want to start in 2003. I like this paper quite a bit. It was the first sort of popular 00:15:20.000 |
application of neural networks to the problem of language modeling. So predicting, in this case, 00:15:24.160 |
the next word in a sequence, which allows you to build generative models over text. 00:15:27.680 |
And in this case, they were using multi-layer perceptron, so a very simple neural net. 00:15:31.040 |
The neural nets took three words and predicted the probability distribution for the fourth word 00:15:34.720 |
in a sequence. So this was well and good at this point. Now, over time, people started to apply 00:15:41.280 |
this to machine translation. So that brings us to sequence-to-sequence paper from 2014 that was 00:15:46.960 |
pretty influential. And the big problem here was, okay, we don't just want to take three words and 00:15:51.360 |
predict the fourth. We want to predict how to go from an English sentence to a French sentence. 00:15:56.720 |
And the key problem was, okay, you can have arbitrary number of words in English and 00:16:00.400 |
arbitrary number of words in French, so how do you get an architecture that can process this 00:16:05.120 |
variably-sized input? And so here, they used an LSTM. And there's basically two chunks of this, 00:16:11.120 |
which are covered here on the slide. But basically, you have an encoder LSTM on the left, 00:16:18.960 |
and it just consumes one word at a time and builds up a context of what it has read. And then that 00:16:25.280 |
acts as a conditioning vector to the decoder RNN or LSTM that basically goes chunk, chunk, 00:16:30.320 |
chunk for the next word in the sequence, translating the English to French or something 00:16:34.640 |
like that. Now, the big problem with this that people identified, I think, very quickly and 00:16:39.040 |
tried to resolve is that there's what's called this encoder bottleneck. So this entire English 00:16:45.280 |
sentence that we are trying to condition on is packed into a single vector that goes from the 00:16:49.600 |
encoder to the decoder. And so this is just too much information to potentially maintain in a 00:16:53.440 |
single vector, and that didn't seem correct. And so people were looking around for ways to alleviate 00:16:57.680 |
sort of the encoder bottleneck, as it was called at the time. And so that brings 00:17:02.800 |
us to this paper, Neural Machine Translation by Jointly Learning to Align and Translate. 00:17:06.960 |
And here, just going from the abstract, in this paper, we conjectured that use of a fixed-length 00:17:13.200 |
vector is a bottleneck in improving the performance of the basic encoder-decoder 00:17:16.720 |
architecture, and proposed to extend this by allowing the model to automatically soft-search 00:17:21.680 |
for parts of the source sentence that are relevant to predicting a target word, 00:17:26.880 |
without having to form these parts as a hard segment explicitly. So this was a way to look 00:17:32.880 |
back to the words that are coming from the encoder, and it was achieved using this soft-search. So as 00:17:38.720 |
you are decoding the words here, while you are decoding them, you are allowed to look back at 00:17:45.360 |
the words at the encoder via this soft attention mechanism proposed in this paper. And so this 00:17:50.960 |
paper, I think, is the first time that I saw, basically, attention. So your context vector 00:17:57.760 |
that comes from the encoder is a weighted sum of the hidden states of the words in the encoding, 00:18:04.480 |
and then the weights of this sum come from a softmax that is based on these compatibilities 00:18:11.200 |
between the current state, as you're decoding, and the hidden states generated by the encoder. 00:18:15.760 |
And so this is the first time that really you start to look at it, and this is the current 00:18:20.800 |
modern equations of the attention. And I think this was the first paper that I saw it in. It's 00:18:25.520 |
the first time that there's a word "attention" used, as far as I know, to call this mechanism. 00:18:31.360 |
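As a compact way to see the mechanism being described, here is a rough rendering of those equations in LaTeX (notation mine, following the paper's setup):

$$ c_t = \sum_i \alpha_{t,i}\, h_i, \qquad \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_j \exp(e_{t,j})}, \qquad e_{t,i} = a(s_{t-1}, h_i) $$

where the $h_i$ are the encoder hidden states, $s_{t-1}$ is the current decoder state, $a$ is a learned compatibility (score) function, and $c_t$ is the context vector: a softmax-weighted sum of the encoder states used to predict the next target word.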
So I actually tried to dig into the details of the history of the attention. 00:18:35.600 |
So the first author here, Dzmitry, I had an email correspondence with him, 00:18:40.000 |
and I basically sent him an email. I'm like, "Dzmitry, this is really interesting. Transformers 00:18:43.360 |
have taken over. Where did you come up with the soft attention mechanism that ends up being the 00:18:46.880 |
heart of the transformer?" And to my surprise, he wrote me back this massive email, which was 00:18:52.160 |
really fascinating. So this is an excerpt from that email. So basically, he talks about how he 00:18:58.560 |
was looking for a way to avoid this bottleneck between the encoder and decoder. He had some 00:19:02.720 |
ideas about cursors that traversed the sequences that didn't quite work out. And then here - so 00:19:08.000 |
one day, I had this thought that it would be nice to enable the decoder RNN to learn how to search 00:19:11.840 |
where to put the cursor in the source sequence. This was sort of inspired by translation exercises 00:19:16.400 |
that learning English in my middle school involved. Your gaze shifts back and forth between 00:19:22.720 |
source and target sequence as you translate. So literally, I thought this was kind of interesting 00:19:26.960 |
that he's not a native English speaker. And here, that gave him an edge in this machine translation 00:19:31.200 |
that led to attention and then led to transformer. So that's really fascinating. I expressed a soft 00:19:37.680 |
search as softmax and then weighted averaging of the BiRNN states. And basically, to my great 00:19:42.560 |
excitement, this worked from the very first try. So really, I think, interesting piece of history. 00:19:48.320 |
And as it later turned out, the name RNN Search was kind of lame. So the better name 00:19:53.520 |
attention came from Yoshua on one of the final passes as they went over the paper. So maybe 00:20:00.000 |
"Attention Is All You Need" would have been called something like "RNN Search Is All You Need." But we have Yoshua Bengio to 00:20:04.960 |
thank for a little bit of better name, I would say. So apparently, that's the history of this. 00:20:11.360 |
OK, so that brings us to 2017, which is attention is all you need. So this attention component, 00:20:16.240 |
which in Dzmitry's paper was just like one small segment. And there's all this bi-directional RNN 00:21:20.960 |
encoder and decoder. And this attention-only paper is saying, OK, you can actually delete everything. 00:21:26.880 |
What's making this work very well is just attention by itself. And so delete everything, 00:20:31.280 |
keep attention. And then what's remarkable about this paper, actually, is usually you see papers 00:20:35.520 |
that are very incremental. They add one thing, and they show that it's better. But I feel like 00:20:40.560 |
Attention Is All You Need mixed multiple things at the same time. They were combined in a 00:20:44.960 |
very unique way, and then also achieved a very good local minimum in the architecture space. 00:20:50.720 |
And so to me, this is really a landmark paper that is quite remarkable, and I think had quite a lot 00:20:56.480 |
of work behind the scenes. So delete all the RNN, just keep attention. Because attention operates 00:21:03.040 |
over sets, and I'm going to go into this in a second, you now need to positionally encode your 00:21:06.640 |
inputs, because attention doesn't have the notion of space by itself. They - oops, I have to be 00:21:15.280 |
very careful - they adopted this residual network structure from ResNets. They interspersed 00:21:22.160 |
attention with multi-layer perceptrons. They used layer norms, which came from a different paper. 00:21:27.680 |
They introduced the concept of multiple heads of attention that were applied in parallel. 00:21:30.800 |
And they gave us, I think, like a fairly good set of hyperparameters that to this day are used. 00:21:35.440 |
So the expansion factor in the multi-layer perceptron goes up by 4x, and we'll go into 00:21:40.800 |
like a bit more detail, and this 4x has stuck around. And I believe there's a number of papers 00:21:45.280 |
that try to play with all kinds of little details of the transformer, and nothing sticks, because 00:21:50.000 |
this is actually quite good. The only thing to my knowledge that didn't stick was this 00:21:55.840 |
reshuffling of the layer norms to go into the pre-norm version, where here you see the layer 00:21:59.920 |
norms are after the multi-headed attention and feed forward, but they just put them before instead. 00:22:04.400 |
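As a rough sketch of the reshuffle being described (my own pseudo-PyTorch, with `attn`, `mlp`, `ln1`, `ln2` standing in for the sublayers), assuming the usual residual structure:

```python
def post_norm_block(x, attn, mlp, ln1, ln2):
    # layer norms applied after each sublayer, as in the original 2017 figure
    x = ln1(x + attn(x))
    x = ln2(x + mlp(x))
    return x

def pre_norm_block(x, attn, mlp, ln1, ln2):
    # layer norms moved before each sublayer: the variant that stuck in GPT-style models
    x = x + attn(ln1(x))
    x = x + mlp(ln2(x))
    return x
```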
So just reshuffling of layer norms, but otherwise the GPTs and everything else that you're seeing 00:22:08.160 |
today is basically the 2017 architecture from five years ago. And even though everyone is 00:22:13.120 |
working on it, it's proven remarkably resilient, which I think is real interesting. There are 00:22:18.000 |
innovations that I think have been adopted also in positional encodings. It's more common to use 00:22:22.640 |
different rotary and relative positional encodings and so on. So I think there have been changes, 00:22:27.600 |
but for the most part it's proven very resilient. So really quite an interesting paper. Now I wanted 00:22:33.360 |
to go into the attention mechanism, and I think, I sort of like, the way I interpret it is not 00:22:40.160 |
similar to the ways that I've seen it presented before. So let me try a different way of like 00:22:47.280 |
how I see it. Basically to me, attention is kind of like the communication phase of the transformer, 00:22:51.440 |
and the transformer interleaves two phases. The communication phase, which is the multi-headed 00:22:56.560 |
attention, and the computation stage, which is this multilayer perceptron, or feedforward network. So in the 00:23:01.760 |
communication phase, it's really just a data-dependent message passing on directed graphs. 00:23:06.400 |
And you can think of it as, okay, forget everything with machine translation and everything. Let's 00:23:11.280 |
just say we have a directed graph, and at each node, you are storing a vector. And then let me talk now 00:23:18.080 |
about the communication phase of how these vectors talk to each other in this directed graph. And 00:23:21.760 |
then the compute phase later is just a multilayer perceptron, which now, which then basically acts 00:23:27.200 |
on every node individually. But how do these nodes talk to each other in this directed graph? 00:23:32.240 |
So I wrote some simple Python, basically to create one round of 00:23:40.080 |
communication using attention as the message passing scheme. So here, a node 00:23:48.560 |
has this private data vector, as you can think of it as private information to this node. And then 00:23:54.720 |
it can also emit a key, a query, and a value. And simply that's done by linear transformation 00:23:59.440 |
from this node. So the key is, what are the things that I am, sorry, the query is, what are the 00:24:08.960 |
things that I'm looking for? The key is, what are the things that I have? And the value is, what are 00:24:13.040 |
the things that I will communicate? And so then when you have your graph that's made up of nodes 00:24:17.600 |
and some random edges, when you actually have these nodes communicating, what's happening is 00:24:21.360 |
you loop over all the nodes individually in some random order, and you are at some node, 00:24:27.040 |
and you get the query vector q, which is, I'm a node in some graph, and this is what I'm looking 00:24:33.200 |
for. And so that's just achieved via this linear transformation here. And then we look at all the 00:24:37.920 |
inputs that point to this node, and then they broadcast, what are the things that I have, 00:24:42.320 |
which is their keys. So they broadcast the keys, I have the query, then those interact by dot product 00:24:49.600 |
to get scores. So basically, simply by doing dot product, you get some kind of an unnormalized 00:24:55.520 |
weighting of the interestingness of all of the information in the nodes that point to me and to 00:25:00.960 |
the things I'm looking for. And then when you normalize that with a softmax, so it just sums to 00:25:04.800 |
one, you basically just end up using those scores, which now sum to one and are a probability 00:25:09.600 |
distribution, and you do a weighted sum of the values to get your update. So I have a query, 00:25:17.280 |
they have keys, dot product to get interestingness, or like affinity, softmax to normalize it, 00:25:23.840 |
and then weighted sum of those values flow to me and update me. And this is happening for each 00:25:28.720 |
node individually, and then we update at the end. And so this kind of a message passing scheme is 00:25:32.800 |
kind of like at the heart of the transformer, and happens in a more vectorized, batched way 00:25:40.240 |
that is more confusing, and is also interspersed with layer norms and things like that to make the 00:25:45.760 |
training behave better. But that's roughly what's happening in the attention mechanism, I think, 00:25:50.560 |
on a high level. So yeah, so in the communication phase of the transformer, then this message 00:25:59.680 |
passing scheme happens in every head in parallel, and then in every layer in series, and with 00:26:06.800 |
different weights each time. And that's it as far as the multi-headed attention goes. 00:26:13.120 |
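To make the message-passing description concrete, here is a minimal sketch in that spirit. It is not the code shown in the talk; the class, names, and use of numpy are my own, and it covers just one unbatched round of communication:

```python
import numpy as np

d = 16  # feature dimension, arbitrary for illustration

class Node:
    def __init__(self):
        self.x = np.random.randn(d)        # private information stored at this node
        self.wq = np.random.randn(d, d)    # linear maps that emit query / key / value
        self.wk = np.random.randn(d, d)
        self.wv = np.random.randn(d, d)
    def query(self): return self.wq @ self.x   # what am I looking for?
    def key(self):   return self.wk @ self.x   # what do I have?
    def value(self): return self.wv @ self.x   # what will I communicate?

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def one_round_of_communication(nodes, inputs):
    # inputs[i] is the list of nodes with an edge pointing into nodes[i]
    updates = []
    for i, node in enumerate(nodes):
        q = node.query()
        keys = np.stack([m.key() for m in inputs[i]])
        vals = np.stack([m.value() for m in inputs[i]])
        scores = softmax(keys @ q)     # affinities between my query and their keys
        updates.append(scores @ vals)  # weighted sum of their values flows to me
    for node, u in zip(nodes, updates):  # apply all updates at the end
        node.x = u
```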
And so if you look at these encoder-decoder models, you can sort of think of it then, 00:26:17.040 |
in terms of the connectivity of these nodes in the graph, you can kind of think of it as like, 00:26:20.560 |
okay, all these tokens that are in the encoder that we want to condition on, they are fully 00:26:24.880 |
connected to each other. So when they communicate, they communicate fully when you calculate their 00:26:29.520 |
features. But in the decoder, because we are trying to have a language model, we don't want 00:26:34.240 |
to have communication from future tokens, because they give away the answer at this step. So the 00:26:38.720 |
tokens in the decoder are fully connected from all the encoder states, and then they are also 00:26:43.840 |
fully connected from everything that is before them. And so you end up with this, like, triangular 00:26:48.320 |
structure in the directed graph. But that's the message passing scheme that this basically 00:26:53.280 |
implements. And then you have to be also a little bit careful, because in the cross-attention here 00:26:58.320 |
with the decoder, you consume the features from the top of the encoder. So think of it as, in the 00:27:03.600 |
encoder, all the nodes are looking at each other, all the tokens are looking at each other, many, 00:27:07.280 |
many times. And they really figure out what's in there. And then the decoder is looking 00:27:11.520 |
only at the top nodes. So that's roughly the message passing scheme. I was going to go into 00:27:17.360 |
more of an implementation of the transformer. I don't know if there's any questions about this. 00:27:21.680 |
Can you explain a little bit about self-attention and multi-headed attention? 00:27:26.480 |
Yeah, so self-attention and multi-headed attention. So the multi-headed attention is 00:27:36.880 |
just this attention scheme, but it's just applied multiple times in parallel. Multiple heads just 00:27:41.600 |
means independent applications of the same attention. So this message passing scheme 00:27:47.360 |
basically just happens in parallel multiple times with different weights for the query key and value. 00:27:53.200 |
So you can almost look at it like, in parallel, I'm looking for, I'm seeking different kinds of 00:27:56.880 |
information from different nodes, and I'm collecting it all in the same node. It's all 00:28:01.360 |
done in parallel. So heads is really just like copy-paste in parallel. And layers are copy-paste, 00:28:09.680 |
but in series. Maybe that makes sense. And self-attention, when it's self-attention, 00:28:18.800 |
what it's referring to is that the keys, queries, and values are all produced from these same nodes. So as I described it 00:28:24.000 |
here, this is really self-attention, because every one of these nodes produces a key, query, and a 00:28:28.240 |
value from this individual node. When you have cross-attention, you have one cross-attention 00:28:34.160 |
here coming from the encoder. That just means that the queries are still produced from this node, 00:28:40.720 |
but the keys and the values are produced as a function of nodes that are coming from the 00:28:46.480 |
encoder. So I have my queries because I'm trying to decode the fifth word in the sequence, 00:28:53.760 |
and I'm looking for certain things because I'm the fifth word. And then the keys and the values, 00:28:58.240 |
in terms of the source of information that could answer my queries, can come from the previous 00:29:02.720 |
nodes in the current decoding sequence, or from the top of the encoder. So all the nodes that 00:29:07.280 |
have already seen all of the encoding tokens many, many times can now broadcast what they 00:29:12.320 |
contain in terms of information. So I guess to summarize, the self-attention is kind of like, 00:29:19.040 |
sorry, cross-attention and self-attention only differ in where the keys and the values come 00:29:23.680 |
from. Either the keys and values are produced from this node, or they are produced from some 00:29:29.120 |
external source, like an encoder and the nodes over there. But algorithmically, it's the same 00:29:56.640 |
So think of - so each one of these nodes is a token. 00:30:00.560 |
I guess, like, they don't have a very good picture of it in the transformer, but like 00:30:09.840 |
this node here could represent the third word in the output, in the decoder. 00:30:16.800 |
And in the beginning, it is just the embedding of the word. 00:30:21.440 |
And then, okay, I have to think through this analogy a little bit more. I came up with it 00:30:31.840 |
this morning. Actually, I came up with it yesterday. 00:30:50.240 |
These nodes are basically the vectors. I'll go to an implementation - I'll go to the 00:30:54.720 |
implementation, and then maybe I'll make the connections to the graph. So let me try to 00:30:59.680 |
first go to - let me now go to, with this intuition in mind at least, to nanoGPT, which is a concrete 00:31:04.480 |
implementation of a transformer that is very minimal. So I worked on this over the last few 00:31:08.240 |
days, and here it is reproducing GPT-2 on open web text. So it's a pretty serious implementation 00:31:13.760 |
that reproduces GPT-2, I would say, and provided enough compute. This was one node of eight 00:31:19.040 |
GPUs for 38 hours or something like that, and it's very readable at 300 lines, so everyone 00:31:25.040 |
can take a look at it. And yeah, let me basically briefly step through it. So let's try to have a 00:31:32.400 |
decoder-only transformer. So what that means is that it's a language model. It tries to model 00:31:36.880 |
the next word in a sequence or the next character in a sequence. So the data that we train on is 00:31:43.120 |
always some kind of text. So here's some fake Shakespeare. Sorry, this is real Shakespeare. 00:31:47.040 |
We're going to produce fake Shakespeare. So this is called the tiny Shakespeare data set, 00:31:50.480 |
which is one of my favorite toy data sets. You take all of Shakespeare, concatenate it, 00:31:54.160 |
and it's one megabyte file, and then you can train language models on it and get infinite 00:31:57.600 |
Shakespeare if you like, which I think is kind of cool. So we have a text. The first thing we 00:32:01.600 |
need to do is we need to convert it to a sequence of integers, because transformers natively process, 00:32:06.960 |
you know, you can't plug text into transformer. You need to somehow encode it. So the way that 00:32:12.400 |
encoding is done is we convert, for example, in the simplest case, every character gets an integer, 00:32:16.560 |
and then instead of "hi" there, we would have this sequence of integers. So then you can encode 00:32:22.720 |
every single character as an integer and get a massive sequence of integers. You just concatenate 00:32:28.400 |
it all into one large, long, one-dimensional sequence, and then you can train on it. 00:32:32.560 |
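As a rough illustration of this character-level encoding step (in the spirit of nanoGPT's data preparation, but not its exact code; the file name is a placeholder):

```python
# Build a character-level vocabulary and encode the text as one long list of ints.
with open('input.txt', 'r') as f:   # e.g. the tiny Shakespeare file
    text = f.read()

chars = sorted(set(text))                      # vocabulary of characters
stoi = {ch: i for i, ch in enumerate(chars)}   # char -> integer
itos = {i: ch for ch, i in stoi.items()}       # integer -> char

encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: ''.join(itos[i] for i in ids)

data = encode(text)   # one long 1-D sequence of integers to train on
```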
Now, here we only have a single document. In some cases, if you have multiple independent 00:32:36.800 |
documents, what people like to do is create special tokens, and they intersperse those 00:32:40.160 |
documents with those special end-of-text tokens that they splice in between to create boundaries. 00:32:45.040 |
But those boundaries actually don't have any modeling impact. It's just that the transformer 00:32:52.080 |
is supposed to learn via backpropagation that the end-of-document sequence means that you 00:32:57.040 |
should wipe the memory. Okay, so then we produce batches. So these batches of data just mean that 00:33:04.640 |
we go back to the one-dimensional sequence, and we take out chunks of this sequence. So say if the 00:33:09.840 |
block size is 8, then the block size indicates the maximum length of context that your transformer 00:33:17.600 |
will process. So if our block size is 8, that means that we are going to have up to 8 characters 00:33:22.960 |
of context to predict the 9th character in the sequence. And the batch size indicates how many 00:33:28.160 |
sequences in parallel we're going to process. And we want this to be as large as possible, 00:33:31.760 |
so we're fully taking advantage of the GPU and the parallelism on board. 00:33:34.640 |
So in this example, we're doing 4 by 8 batches. So every row here is an independent example, sort of. 00:33:41.440 |
And then every row here is a small chunk of the sequence that we're going to train on. 00:33:48.640 |
And then we have both the inputs and the targets at every single point here. So to fully spell out 00:33:53.680 |
what's contained in a single 4 by 8 batch to the transformer, I sort of compact it here. So when 00:33:59.920 |
the input is 47 by itself, the target is 58. And when the input is the sequence 47, 58, the target 00:34:07.760 |
is 1. And when it's 47, 58, 1, the target is 51, and so on. So actually the single batch of examples 00:34:14.960 |
that's 4 by 8 actually has a ton of individual examples that we are expecting the transformer 00:34:19.200 |
to learn on in parallel. And so you'll see that the batches are learned on completely independently, 00:34:25.040 |
but the time dimension sort of here along horizontally is also trained on in parallel. 00:34:30.880 |
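A minimal sketch of how such a batch could be drawn, with targets offset by one (assuming `data` is the 1-D tensor of integers from the encoding step; this mirrors the idea, not the exact nanoGPT code):

```python
import torch

block_size = 8   # maximum context length
batch_size = 4   # independent chunks processed in parallel

def get_batch(data):
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])          # inputs
    y = torch.stack([data[i + 1 : i + 1 + block_size] for i in ix])  # targets: inputs shifted by one
    return x, y
```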
So sort of your real batch size is more like b times t. It's just that the context grows linearly 00:34:37.280 |
for the predictions that you make along the t direction in the model. So this is all the 00:34:44.000 |
examples that the model will learn from this single batch. So now this is the GPT class. 00:34:51.760 |
And because this is a decoder-only model, so we're not going to have an encoder because 00:34:57.760 |
there's no, like, English we're translating from. We're not trying to condition on some 00:35:01.280 |
other external information. We're just trying to produce a sequence of words that follow each other 00:35:06.080 |
or are likely to. So this is all PyTorch. And I'm going slightly faster because I'm assuming people 00:35:11.360 |
have taken 231n or something along those lines. But here in the forward pass, we take these indices 00:35:18.080 |
and then we both encode the identity of the indices just via an embedding lookup table. 00:35:26.720 |
So every single integer has a - we index into a lookup table of vectors in this nn.embedding 00:35:33.680 |
and pull out the word vector for that token. And then because the transformer 00:35:40.800 |
by itself processes sets natively, we need to also positionally encode 00:35:45.280 |
these vectors so that we basically have both the information about the token identity and 00:35:49.680 |
its place in the sequence from one to block size. Now those - the information about what and where 00:35:56.960 |
is combined additively. So the token embeddings and the positional embeddings are just added 00:36:00.720 |
exactly as here. So this x here, then there's optional dropout. This x here basically just 00:36:07.920 |
contains the set of words and their positions, and that feeds into the blocks of transformer. 00:36:16.960 |
And we're going to look into what's blocked here. But for here, for now, this is just a series of 00:36:20.560 |
blocks in the transformer. And then in the end, there's a layer norm, and then you're decoding 00:36:25.760 |
the logits for the next word or next integer in the sequence using a linear projection of 00:36:32.080 |
the output of this transformer. So lm_head here, short for language model head, is just a linear 00:36:37.600 |
function. So basically, positionally encode all the words, feed them into a sequence of blocks, 00:36:45.200 |
and then apply a linear layer to get the probability distribution for the next 00:36:48.640 |
character. And then if we have the targets, which we produced in the data loader, and you'll notice 00:36:54.960 |
that the targets are just the inputs offset by one in time, then those targets feed into a cross 00:37:00.560 |
entropy loss. So this is just a negative log likelihood, typical classification loss. 00:37:04.000 |
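A condensed sketch of the forward pass just described (simplified from the idea, not nanoGPT's actual code; dropout omitted, and the embedding tables, blocks, layer norm, and head are passed in as modules so that `blocks` can be the transformer blocks sketched further below):

```python
import torch
import torch.nn.functional as F

def gpt_forward(idx, targets, tok_emb, pos_emb, blocks, ln_f, lm_head):
    b, t = idx.shape
    pos = torch.arange(t, device=idx.device)
    x = tok_emb(idx) + pos_emb(pos)      # token identity ("what") + position ("where"), added
    for block in blocks:                 # each block: communicate (attention) + compute (MLP)
        x = block(x)
    logits = lm_head(ln_f(x))            # final layer norm, then linear head -> next-token logits
    if targets is None:
        return logits, None
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return logits, loss
```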
So now let's drill into what's here in the blocks. So these blocks that are applied sequentially, 00:37:10.320 |
there's again, as I mentioned, this communicate phase and the compute phase. 00:37:14.880 |
So in the communicate phase, all the nodes get to talk to each other, and so these nodes are 00:37:19.760 |
basically - if our block size is eight, then we are going to have eight nodes in this graph. 00:37:26.560 |
There's eight nodes in this graph, the first node is pointed to only by itself, 00:37:30.240 |
the second node is pointed to by the first node and itself, the third node is pointed to by the 00:37:35.040 |
first two nodes and itself, etc. So there's eight nodes here. So you apply - there's a residual 00:37:41.200 |
pathway in x, you take it out, you apply a layer norm, and then the self-attention so that these 00:37:45.840 |
communicate, these eight nodes communicate, but you have to keep in mind that the batch is four. 00:37:50.240 |
So because batch is four, this is also applied - so we have eight nodes communicating, but there's 00:37:56.080 |
a batch of four of them all individually communicating among those eight nodes. There's 00:37:59.920 |
no crisscross across the batch dimension, of course. There's no batch normalization anywhere, 00:38:02.960 |
luckily. And then once they've exchanged information, they are processed using the 00:38:08.480 |
multilayer perceptron, and that's the compute phase. And then also here, we are missing 00:38:14.160 |
the cross-attention, because this is a decoder-only model. So all we have is this step 00:38:21.280 |
here, the multi-headed attention, and that's this line, the communicate phase, and then we have the 00:38:25.040 |
feedforward, which is the MLP, and that's the compute phase. I'll take questions a bit later. 00:38:30.640 |
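A minimal sketch of one such block, tying together the communicate and compute phases (names are mine; `attn` stands in for the causal self-attention module sketched after its discussion below, and can be any callable mapping a (B, T, C) tensor to the same shape):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, n_embd, attn):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = attn                      # communicate phase
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(             # compute phase, applied to every node individually
            nn.Linear(n_embd, 4 * n_embd),    # the 4x expansion factor mentioned earlier
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # nodes exchange information (pre-norm residual)
        x = x + self.mlp(self.ln2(x))    # each node processes what it gathered
        return x
```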
Then the MLP here is fairly straightforward. The MLP is just individual processing on each node, 00:38:37.360 |
just transforming the feature representation sort of at that node. So applying a two-layer neural 00:38:45.040 |
net with a GELU non-linearity, which is - just think of it as a RELU or something like that. 00:38:49.360 |
It's just a non-linearity. And then MLP is straightforward. I don't think there's anything 00:38:54.480 |
too crazy there. And then this is the causal self-attention part, the communication phase. 00:38:58.800 |
So this is kind of like the meat of things and the most complicated part. It's only complicated 00:39:04.560 |
because of the batching and the implementation detail of how you mask the connectivity in the 00:39:10.800 |
graph so that you can't obtain any information from the future when you're predicting your token. 00:39:16.240 |
Otherwise, it gives away the information. So if I'm the fifth token, and if I'm the fifth position, 00:39:23.120 |
then I'm getting the fourth token coming into the input, and I'm attending to the third, 00:39:27.760 |
second, and first, and I'm trying to figure out what is the next token, well then in this batch, 00:39:33.440 |
in the next element over in the time dimension, the answer is at the input. So I can't get any 00:39:38.960 |
information from there. So that's why this is all tricky. But basically in the forward pass, 00:39:42.800 |
we are calculating the queries, keys, and values based on x. So these are the keys, 00:39:51.360 |
queries, and values. Here, when I'm computing the attention, I have the queries matrix multiplying 00:39:57.520 |
the keys. So this is the dot product in parallel for all the queries and all the keys, and all the 00:40:02.400 |
heads. So I failed to mention that there's also the aspect of the heads, which is also done all 00:40:08.080 |
in parallel here. So we have the batch dimension, the time dimension, and the head dimension, 00:40:11.920 |
and you end up with five-dimensional tensors, and it's all really confusing. So I invite you 00:40:15.200 |
to step through it later and convince yourself that this is actually doing the right thing. 00:40:19.040 |
But basically, you have the batch dimension, the head dimension, and the time dimension, 00:40:23.360 |
and then you have features at them. And so this is evaluating for all the batch elements, 00:40:28.320 |
for all the head elements, and all the time elements, the simple Python that I gave you 00:40:32.800 |
earlier, which is query dot product key. Then here, we do a masked fill. And what this is doing is 00:40:39.280 |
it's basically clamping the attention between the nodes that are not supposed to communicate 00:40:45.600 |
to be negative infinity. And we're doing negative infinity because we're about to softmax, 00:40:50.080 |
and so negative infinity will make basically the attention of those elements be zero. 00:40:53.680 |
And so here, we are going to basically end up with the weights, the sort of affinities between 00:41:01.520 |
these nodes, optional dropout, and then here, attention matrix multiply v is basically the 00:41:07.680 |
gathering of the information according to the affinities we've calculated. And this is just a 00:41:12.560 |
weighted sum of the values at all those nodes. So this matrix multipliers is doing that weighted 00:41:17.840 |
sum. And then transpose contiguous view, because it's all complicated and batched in five-dimensional 00:41:23.440 |
tensors, but it's really not doing anything, optional dropout, and then a linear projection 00:41:28.560 |
back to the residual pathway. So this is implementing the communication phase here. 00:41:32.960 |
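A minimal single-head sketch of this causal self-attention (the real nanoGPT version adds the head dimension, dropout, and the batched bookkeeping described above; this just shows the masking and weighted sum):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)   # projection back to the residual pathway
        # lower-triangular mask: token t may only attend to positions <= t
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        att = q @ k.transpose(-2, -1) / math.sqrt(C)                   # (B, T, T) affinities
        att = att.masked_fill(self.mask[:T, :T] == 0, float('-inf'))   # no peeking at the future
        att = F.softmax(att, dim=-1)
        y = att @ v                                                    # weighted sum of values
        return self.proj(y)
```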
Then you can train this transformer, and then you can generate infinite Shakespeare, 00:41:41.120 |
and you will simply do this by - because our block size is eight, we start with some token, 00:41:46.480 |
say, in this case, you can use something like a newline as the start token, 00:41:52.400 |
and then you communicate only to yourself because there's a single node, and you get the probability 00:41:57.440 |
distribution for the first word in the sequence, and then you decode it, or the first character in 00:42:03.680 |
the sequence, you decode the character, and then you bring back the character, and you re-encode 00:42:07.600 |
it as an integer, and now you have the second thing. And so you get, okay, we're at the first 00:42:13.360 |
position, and this is whatever integer it is, add the positional encodings, goes into the sequence, 00:42:18.800 |
goes into transformer, and again, this token now communicates with the first token and its identity. 00:42:26.640 |
And so you just keep plugging it back, and once you run out of the block size, which is eight, 00:42:30.880 |
you start to crop, because you can never have block size more than eight in the way you've 00:42:34.560 |
trained this transformer. So we have more and more context until eight, and then if you want 00:42:38.240 |
to generate beyond eight, you have to start cropping, because the transformer only works for 00:42:41.920 |
eight elements in time dimension. And so all of these transformers in the naive setting have a 00:42:48.000 |
finite block size, or context length, and in typical models, this will be 1024 tokens, 00:42:53.760 |
or 2048 tokens, something like that, but these tokens are usually like BPE tokens, 00:42:58.480 |
or SentencePiece tokens, or WordPiece tokens, there's many different encodings. 00:43:02.480 |
So it's not like that long, and so that's why, as I think was mentioned, we really want 00:43:05.680 |
to expand the context size, and it gets gnarly, because the attention is quadratic in many cases. 00:43:09.920 |
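A rough sketch of that sampling loop, including the crop to the last block_size tokens (assuming a `model(idx)` that returns next-token logits of shape (B, T, vocab_size); greedy-vs-sampling details and temperature are omitted):

```python
import torch

@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size):
    # idx is (B, T0) of token integers, e.g. a single start token
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                    # crop: the model never saw longer contexts
        logits = model(idx_cond)                           # (B, T, vocab_size)
        probs = torch.softmax(logits[:, -1, :], dim=-1)    # distribution over the next token
        next_id = torch.multinomial(probs, num_samples=1)  # sample it
        idx = torch.cat([idx, next_id], dim=1)             # append and feed it back in
    return idx
```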
Now, if you want to implement an encoder instead of a decoder attention, then all you have to do 00:43:19.680 |
is to find this mask line, and you just delete that line. So if you don't mask the attention, 00:43:25.360 |
then all the nodes communicate to each other, and everything is allowed, and information flows 00:43:29.840 |
between all the nodes. So if you want to have the encoder here, then all the encoder blocks 00:43:36.960 |
will use attention where this line is deleted, that's it. So you're allowing, 00:43:41.440 |
whatever this encoder might store, say 10 tokens, like 10 nodes, and they are all allowed to 00:43:47.040 |
communicate to each other, going up the transformer. And then if you want to implement cross attention, 00:43:53.360 |
so you have a full encoder decoder transformer, not just a decoder only transformer, or GPT, 00:43:59.280 |
then we need to also add cross attention in the middle. So here, there's a self attention piece, 00:44:05.520 |
a cross attention piece, and this MLP. And in the 00:44:10.160 |
cross attention, we need to take the features from the top of the encoder, we need to add one more 00:44:15.760 |
line here, and this would be the cross attention - I should have implemented it instead 00:44:21.680 |
of just pointing, I think. But there'll be a cross attention line here, so we'll have three lines, 00:44:26.480 |
because we need to add another block. And the queries will come from x, but the keys and the 00:44:31.520 |
values will come from the top of the encoder. And there will be basically information flowing from 00:44:37.040 |
the encoder strictly to all the nodes inside x. And then that's it. So it's very simple sort of 00:44:43.760 |
modifications on the decoder attention. So you'll hear people talk that you kind of have a decoder 00:44:49.920 |
only model, like GPT, you can have an encoder only model, like BERT, or you can have an encoder 00:44:54.800 |
decoder model, like say T5, doing things like machine translation. And BERT, you don't 00:45:01.440 |
train it using sort of this language modeling setup that's autoregressive, where you're just 00:45:06.080 |
trying to predict the next element in the sequence; you train it with slightly different 00:45:08.880 |
objectives, you're putting in like the full sentence, and the full sentence is allowed to 00:45:13.040 |
communicate fully, and then you're trying to classify sentiment or something like that. 00:45:17.360 |
So you're not trying to model like the next token in the sequence. So these are trained 00:45:21.840 |
slightly different with mask, with using masking and other denoising techniques. 00:45:30.000 |
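To summarize the three variants just contrasted, here is a small sketch of my own (linear projections, heads, and batching omitted) showing where the mask and the keys/values come from in each case:

```python
import math
import torch.nn.functional as F

def attend(q, k, v, mask=None):
    att = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        att = att.masked_fill(mask == 0, float('-inf'))
    return F.softmax(att, dim=-1) @ v

# decoder self-attention:  q, k, v all come from x, with a lower-triangular mask
# encoder self-attention:  q, k, v all come from x, with mask=None (the deleted line)
# cross-attention:         q comes from the decoder stream x, but k and v come from
#                          the top of the encoder, with mask=None
```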
Okay, so that's kind of like the transformer. I'm going to continue. So yeah, maybe more questions. 00:45:38.320 |
These are excellent questions. So when we're employing information, 00:45:43.760 |
for instance, like a graph, something like that, 00:45:48.480 |
you know, does this transformer still perform if it's a dynamic graph, where the connections 00:45:56.880 |
change in every instance, and you also have some feature information? So, like, we are 00:46:02.720 |
enforcing these constraints on it by just masking, but is it aware of the structure it tends to work with? 00:46:09.840 |
So I'm not sure if I fully followed. So there's different ways to look at this analogy, but one 00:46:16.880 |
analogy is you can interpret this graph as really fixed. It's just that every time we do the 00:46:21.120 |
communicate, we are using different weights. You can look at it that way. So if we have block size 00:46:25.440 |
of eight in my example, we would have eight nodes. Here we have two, four, six, okay, so we'd have 00:46:30.080 |
eight nodes. They would be connected in, you lay them out, and you only connect from left to right. 00:46:35.200 |
But for a different problem, that might not be the case, but you have a graph where the connections 00:46:40.160 |
might change. Why would the connections change? Usually the connections don't change as a function of 00:46:46.160 |
the data or something like that. That means like the molecules look like an actual graph, 00:46:50.880 |
and look like that. I don't think I've seen a single example where the connectivity changes 00:47:02.960 |
dynamically as a function of the data. Usually the connectivity is fixed. If you have an encoder 00:47:06.560 |
and you're training a BERT, you have how many tokens you want, and they are fully connected. 00:47:10.720 |
And if you have a decoder only model, you have this triangular thing. And if you have encoder 00:47:15.600 |
decoder, then you have awkwardly sort of like two pools of nodes. Yeah, go ahead. 00:47:48.880 |
Yeah, it's really hard to say. So that's why I think this paper is so interesting is like, 00:48:16.880 |
yeah, usually you'd see like a path, and maybe they had path internally. They just didn't publish 00:48:20.320 |
it. All you can see is sort of things that didn't look like a transformer. I mean, you have ResNets, 00:48:24.800 |
which have lots of this. But a ResNet would be kind of like this, but there's no self-attention 00:48:30.480 |
component. But the MLP is there kind of in a ResNet. So a ResNet looks very much like this, 00:48:37.600 |
except there's no - you can use layer norms in ResNets, I believe, as well. Typically, 00:48:41.600 |
sometimes they can be batch norms. So it is kind of like a ResNet. It is kind of like they took a 00:48:46.000 |
ResNet and they put in a self-attention block in addition to the pre-existing MLP block, which 00:48:52.160 |
is kind of like convolutions. And the MLP would, strictly speaking, be a one-by-one 00:48:56.240 |
convolution. But I think the idea is similar in that MLP is just kind of like typical weights, 00:49:11.120 |
But I will say, yeah, it's kind of interesting because a lot of work is not there, and then 00:49:16.800 |
they give you this transformer, and then it turns out five years later, it's not changed, even though 00:49:19.840 |
everyone's trying to change it. So it's kind of interesting to me that it came 00:49:23.280 |
in like a package, which I think is really interesting historically. And I also talked to 00:49:28.160 |
paper authors, and they were unaware of the impact that the transformer would have at the time. 00:49:33.840 |
So when you read this paper, actually, it's kind of unfortunate because this is like the paper that 00:49:39.280 |
changed everything. But when people read it, it's like question marks, because it reads like a 00:49:43.200 |
pretty random machine translation paper. Like, oh, we're doing machine translation. Oh, here's a cool 00:49:48.000 |
architecture. OK, great, good results. It doesn't sort of know what's going to happen. And so when 00:49:55.440 |
people read it today, I think they're kind of confused, potentially. I will have some tweets 00:50:01.360 |
at the end, but I think I would have renamed it with the benefit of hindsight of like, well, 00:50:05.440 |
I'll get to it. Yeah, I think that's a good question as well. Currently, I mean, I certainly 00:50:24.560 |
don't love the autoregressive modeling approach. I think it's kind of weird to sample a token and 00:50:29.120 |
then commit to it. So maybe there's some ways-- some hybrids with diffusion, as an example, 00:50:38.000 |
which I think would be really cool. Or we'll find some other ways to edit the sequences later, 00:50:43.760 |
but still in the autoregressive framework. But I think diffusion is kind of like an up-and-coming 00:50:48.960 |
modeling approach that I personally find much more appealing. When I sample text, I don't go 00:50:53.360 |
chunk, chunk, chunk, and commit. I do a draft one, and then I do a better draft two. And that feels 00:50:58.800 |
like a diffusion process. So that would be my hope. 00:51:02.320 |
OK, also a question. So, thinking of the input as a graph where attention assigns a weight 00:51:10.560 |
to each edge: would you say that self-attention is sort of computing 00:51:16.720 |
an edge weight using the dot product of node similarity, and then once we have the edge weight, 00:51:21.680 |
we just multiply it by the values, and then we just propagate it? 00:51:26.960 |
And do you think there's an analogy between graph neural networks and self-attention? 00:51:32.320 |
I find the graph neural networks kind of like a confusing term, because 00:51:35.200 |
I mean, yeah, previously there was this notion of-- I kind of feel like maybe today everything 00:51:41.920 |
is a graph neural network, because the transformer is a graph neural network processor. The native 00:51:46.080 |
representation that the transformer operates over is sets that are connected by edges in a directed 00:51:50.880 |
way. And so that's the native representation. And then, yeah. OK, I should go on, because I still have more to cover. 00:51:57.440 |
Sorry, sorry, sorry. There's a question I want to ask about this. [INAUDIBLE] 00:52:07.920 |
Oh, yeah. Yeah, the root D: basically, if you're initializing with random weights 00:52:14.320 |
sampled from a Gaussian, then as your dimension size grows, so do your values, the variance grows, 00:52:19.600 |
and then your softmax will just become a one-hot vector. So it's just a way to control 00:52:24.720 |
the variance and bring it to always be in a good range for the softmax, a nice diffuse distribution. 00:52:29.200 |
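A quick numerical illustration of that point (a sketch; the dimension is arbitrary):

```python
import torch

d = 512
q, k = torch.randn(d), torch.randn(d)   # entries ~ N(0, 1)

raw = q @ k            # this dot product has variance ~ d
scaled = raw / d**0.5  # dividing by sqrt(d) brings the variance back to ~ 1

# With many keys, unscaled scores have a huge spread, so the softmax saturates
# toward a one-hot vector; scaled scores stay in a nice diffuse range.
scores = torch.randn(8, d) @ torch.randn(d, 8)
print(torch.softmax(scores[0], dim=-1))            # nearly one-hot
print(torch.softmax(scores[0] / d**0.5, dim=-1))   # diffuse
```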
OK, so it's almost like an initialization thing. OK, so transformers have been applied to all the 00:52:41.760 |
other fields. And the way this was done is, in my opinion, in kind of ridiculous ways, honestly, 00:52:47.680 |
because I was a computer vision person, and you have convnets, and they kind of make sense. 00:52:51.680 |
So what we're doing now with ViTs, as an example, is you take an image, and you chop it up into 00:52:55.520 |
little squares. And then those squares literally feed into a transformer, and that's it, 00:52:59.440 |
which is kind of ridiculous. And so, I mean, yeah. And so the transformer doesn't even, 00:53:07.040 |
in the simplest case, like really know where these patches might come from. They are usually 00:53:10.800 |
positionally encoded, but it has to sort of like rediscover a lot of the structure, I think, 00:53:16.960 |
of them in some ways. And it's kind of weird to approach it that way. But it's just like 00:53:24.000 |
the simplest baseline of chopping up big images into small squares and feeding them in 00:53:28.560 |
as like the individual nodes actually works fairly well. And then this is in the transformer encoder. 00:53:32.640 |
So all the patches are talking to each other throughout the entire transformer. 00:53:35.920 |
And the number of nodes here would be sort of like nine. 00:53:39.440 |
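A minimal sketch of that patchify step (illustrative; a 224x224 image with 16x16 patches, roughly as in the ViT paper):

```python
import torch

img = torch.randn(3, 224, 224)   # (channels, height, width)
P = 16                           # patch size

# Chop the image into 16x16 squares and flatten each one.
patches = img.unfold(1, P, P).unfold(2, P, P)                     # (3, 14, 14, 16, 16)
patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3 * P * P)   # (196, 768)

# Linear projection to the transformer width plus learned position embeddings;
# these 196 vectors are then simply the "nodes" fed into a transformer encoder.
proj = torch.nn.Linear(3 * P * P, 768)
pos = torch.nn.Parameter(torch.zeros(196, 768))
tokens = proj(patches) + pos
```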
Also, in speech recognition, you just take your mel spectrogram, and you chop it up into little 00:53:46.080 |
slices and feed them into a transformer. There were papers like this, but also Whisper. Whisper is 00:53:50.320 |
basically a copy-paste transformer. If you saw Whisper from OpenAI, you just chop up a mel spectrogram and 00:53:55.600 |
feed it into a transformer, and then pretend you're dealing with text, and it works very well. 00:53:59.760 |
With the Decision Transformer in RL, you take the states, actions, and rewards that you experience in the 00:54:04.880 |
environment, and you just pretend it's a language, and you start to model the sequences of that. 00:54:09.520 |
And then you can use that for planning later. That works pretty well. Even things like AlphaFold: 00:54:13.920 |
people frequently ask about molecules and how you can plug them in, and at the heart of 00:54:18.240 |
AlphaFold, computationally, there is also a transformer. One thing I wanted to also say about transformers 00:54:23.840 |
is I find that they're super flexible, and I really enjoy that. I'll give you an example from 00:54:28.960 |
Tesla. You have a ConvNet that takes an image and makes predictions about the image. And then the 00:54:34.880 |
big question is, how do you feed in extra information? And it's not always trivial. Say 00:54:38.640 |
I have additional information that I want the outputs to be informed 00:54:43.040 |
by. Maybe I have other sensors, like radar. Maybe I have some map information, or a vehicle type, 00:54:47.520 |
or some audio. And the question is, how do you feed that information into a ConvNet? Where do you feed 00:54:52.000 |
it in? Do you concatenate it? Do you add it? At what stage? And so with a transformer, it's much 00:54:58.080 |
easier, because you just take whatever you want, you chop it up into pieces, and you feed it in 00:55:01.760 |
with a set of what you had before. And you let the self-attention figure out how everything should 00:55:05.200 |
communicate. And that actually, frankly, works. So just chop up everything and throw it into the 00:55:10.000 |
mix is kind of the way. And it frees neural nets from this burden of Euclidean space, 00:55:16.960 |
where previously you had to arrange your computation to conform to the Euclidean 00:55:22.080 |
space of three dimensions of how you're laying out the compute. The compute actually kind of 00:55:26.560 |
happens in almost 3D space, if you think about it. But in attention, everything is just sets. 00:55:31.920 |
So it's a very flexible framework, and you can just throw in stuff into your conditioning set, 00:55:35.680 |
and everything just gets self-attended over. So it's quite beautiful from that perspective. 00:55:39.760 |
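A hedged sketch of that "chop everything up and throw it into the mix" idea; all of the names and shapes here are invented for illustration, not anyone's actual stack:

```python
import torch

d = 256  # shared embedding width

# Each modality is chopped into pieces and projected to the same width.
image_tokens  = torch.randn(196, d)   # image patches
radar_tokens  = torch.randn(32, d)    # radar returns, however they get sliced
map_tokens    = torch.randn(10, d)    # map features
vehicle_token = torch.randn(1, d)     # a single token for vehicle type

# Optionally add a learned per-modality embedding so the transformer
# knows which pieces came from radar.
radar_type_emb = torch.nn.Parameter(torch.zeros(1, d))
radar_tokens = radar_tokens + radar_type_emb

# Concatenate everything into one set and let self-attention figure out
# how the pieces should communicate.
tokens = torch.cat([image_tokens, radar_tokens, map_tokens, vehicle_token], dim=0)
# `tokens` then goes into a standard transformer encoder.
```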
OK. So now, what exactly makes transformers so effective? I think a good example of this 00:55:44.560 |
comes from the GPT-3 paper, which I encourage people to read: Language Models are Few-Shot 00:55:49.440 |
Learners. I would have probably renamed this a little bit. I would have said something like, 00:55:54.000 |
transformers are capable of in-context learning, or like meta-learning. That's kind of what makes 00:55:58.960 |
them really special. So basically, the setting that they're working with is, OK, I have some 00:56:03.040 |
context. Let's say I have a passage; this is just one example of many. I have a passage, 00:56:07.120 |
and I'm asking questions about it. And then I'm giving, as part of the context, in the prompt, 00:56:12.960 |
I'm giving the questions and the answers. So I'm giving one example of question-answer, 00:56:16.240 |
another example of question-answer, another example of question-answer, and so on. 00:56:19.120 |
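For concreteness, the few-shot setup looks something like this toy prompt (an illustration, not the paper's exact format):

```python
prompt = """Passage: The Nile is a major north-flowing river in northeastern Africa.

Q: Which direction does the Nile flow?
A: North.

Q: On which continent is the Nile?
A: Africa.

Q: What kind of geographic feature is the Nile?
A:"""
# The model is only ever asked to continue the document; the question-answer
# pairs in the prompt act as the "training examples" it learns from in context,
# with no gradient updates.
```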
And this becomes the prompt. Oh yeah, people are going to have to leave soon now. 00:56:30.160 |
OK, so what's really interesting is basically like, with more examples given in the context, 00:56:35.440 |
the accuracy improves. And so what that hints at is that the transformer is able to somehow 00:56:40.000 |
learn in the activations without doing any gradient descent in a typical fine-tuning fashion. 00:56:45.120 |
So if you fine-tune, you have to give an example and the answer, and you do fine-tuning 00:56:50.160 |
using gradient descent. But it looks like the transformer, internally in its weights, 00:56:53.760 |
is doing something that looks like potentially gradient descent, some kind of a meta-learning 00:56:56.720 |
in the weights of the transformer as it is reading the prompt. And so in this paper, 00:57:00.640 |
they go into, OK, distinguishing this outer loop with stochastic gradient descent and this inner 00:57:04.880 |
loop of the in-context learning. So the inner loop is the transformer sort of reading the 00:57:09.040 |
sequence almost, and the outer loop is the training by gradient descent. So basically, 00:57:14.560 |
there's some training happening in the activations of the transformer as it is consuming a sequence 00:57:18.640 |
that maybe very much looks like gradient descent. And so there's some recent papers that kind of 00:57:22.320 |
hint at this and study it. And so as an example, in this paper here, they propose something called 00:57:27.440 |
the raw operator. And they argue that the raw operator is implemented by a transformer, 00:57:32.880 |
and then they show that you can implement things like ridge regression on top of a raw operator. 00:57:36.800 |
And so their paper is hinting that maybe there is something that looks 00:57:41.680 |
like gradient-based learning inside the activations of the transformer. And I think this is not 00:57:46.720 |
impossible to think through, because what is gradient-based learning? Forward pass, 00:57:50.320 |
backward pass, and then update. Well, that looks like a ResNet, right, because you're just 00:57:54.800 |
adding to the weights. So you start with an initial random set of weights, forward 00:58:00.160 |
pass, backward pass, and update your weights, and then forward pass, backward pass, update weights. 00:58:03.840 |
Looks like a ResNet. The transformer is a ResNet. So this is much more hand-wavy, but basically there are some 00:58:11.280 |
papers trying to hint at why that could be potentially possible. And then I have a bunch 00:58:15.760 |
of tweets. I just got them pasted here in the end. This was kind of meant for general consumption, 00:58:20.560 |
so they're a bit more high-level and hype-y a little bit. But I'm talking about why this 00:58:24.960 |
architecture is so interesting and why it potentially became so popular. And I think 00:58:28.720 |
it simultaneously optimizes three properties that I think are very desirable. Number one, 00:58:32.320 |
the transformer is very expressive in the forward pass. It's able to implement very 00:58:37.600 |
interesting functions, potentially functions that can even do meta-learning. Number two, 00:58:42.480 |
it is very optimizable, thanks to things like residual connections, layer norms, and so on. 00:58:46.240 |
And number three, it's extremely efficient. This is not always appreciated, but the transformer, 00:58:50.080 |
if you look at the computational graph, is a shallow wide network, which is perfect to take 00:58:54.720 |
advantage of the parallelism of GPUs. So I think the transformer was designed very deliberately 00:58:58.560 |
to run efficiently on GPUs. There's previous work like neural GPU that I really enjoy as well, 00:59:04.960 |
which is really just like how do we design neural nets that are efficient on GPUs, and thinking 00:59:09.040 |
backwards from the constraints of the hardware, which I think is a very interesting way to think 00:59:12.000 |
about it. Oh yeah, so here I'm saying I probably would have called the transformer a general 00:59:24.000 |
purpose, efficient, optimizable computer instead of "Attention Is All You Need". That's what I would 00:59:28.960 |
have maybe in hindsight called that paper. It's proposing a model that is very general purpose, 00:59:36.960 |
so forward pass is expressive. It's very efficient in terms of GPU usage, and it's 00:59:42.000 |
easily optimizable by gradient descent, and trains very nicely. Then I have some other hype tweets 00:59:48.160 |
here. Anyway, so you can read them later, but I think this one is maybe interesting. 00:59:54.960 |
So if previous neural nets are special purpose computers designed for a specific task, 01:00:00.160 |
GPT is a general purpose computer reconfigurable at runtime to run natural language programs. 01:00:05.920 |
So the programs are given as prompts, and then GPT runs the program by completing the document. 01:00:11.200 |
So personally I really like these analogies to a computer. It's just like a powerful computer, 01:00:30.000 |
I don't know. Okay, you can read this later, but for now I'll just leave this up. 01:00:36.720 |
So sorry, I just found this tweet. So it turns out that if you scale up the training set 01:00:49.520 |
and use a powerful enough neural net like a transformer, the network becomes a kind of 01:00:52.880 |
general purpose computer over text. So I think that's a kind of like nice way to look at it, 01:00:56.640 |
and instead of completing a single fixed text sequence, you can design the sequence in the prompt, 01:01:00.400 |
and because the transformer is both powerful and also trained on a large enough, 01:01:04.000 |
very hard data set, it kind of becomes a general purpose text computer, 01:01:07.440 |
and so I think that's kind of interesting way to look at it. Yeah? 01:01:12.320 |
Um, about your three points on the board. Yeah. Um, so I guess, for me, I understood 01:01:56.720 |
the idea that, in principle, things like RNNs can also be trained with gradient descent, 01:02:01.280 |
and I guess my question is, how much do you think it's, like, 01:02:05.040 |
really most of it, you know, do you think transformers win 01:02:10.800 |
mostly because they're more efficient, or do you think it's something more, that you 01:02:17.120 |
need the expressiveness specifically [inaudible] or do you [inaudible] 01:02:24.240 |
Yeah. So I think there's a bit of that, yeah. So I would say RNNs, in principle, yes, 01:02:29.680 |
they can implement arbitrary programs. But I think that's kind of a useless statement to some 01:02:33.280 |
extent. They may be expressive 01:02:38.080 |
in the sense of power, in that they can implement these arbitrary functions, 01:02:42.160 |
but they're not optimizable, and they're certainly not efficient, because they are serial 01:02:46.960 |
computing devices. So if you look at it as a compute graph, RNNs are a very long, 01:02:53.680 |
thin compute graph. Like, if you take all the 01:03:01.200 |
individual neurons and their connectivity, and stretch them out, and try to visualize them, 01:03:04.400 |
RNNs would be a very long graph, and that's bad, also for optimizability, because 01:03:09.440 |
I don't exactly know why, but just the rough intuition is when you're backpropagating, 01:03:13.680 |
you don't want to make too many steps. And so transformers are a shallow, wide graph, 01:03:18.240 |
and so from supervision to inputs is a very small number of hops, and it's along residual pathways, 01:03:25.760 |
which make gradients flow very easily, and there's all these layer norms to control the 01:03:28.960 |
scales of all of those activations. And so there's not too many hops, and you're going 01:03:35.760 |
from supervision to input very quickly, and this flows through the graph. And it can all be done 01:03:41.920 |
in parallel. You don't need to go, like in encoder-decoder RNNs, from the first word, 01:03:46.080 |
then the second word, then the third word; here in the transformer, every single word is processed 01:03:50.400 |
completely in parallel. So I think all of these properties are really important, 01:03:56.080 |
and I think number three is less talked about 01:04:00.080 |
but extremely important, because in deep learning, scale matters, and so the size of the network that 01:04:04.800 |
you can train is extremely important. And so if it's efficient on the current hardware, you can scale it up. 01:04:16.080 |
No, so yeah, you take your image, and you chop it up into patches, 01:04:33.040 |
so those are the first thousand tokens or whatever, and now radar could 01:04:38.880 |
also come in. I don't actually know the native representation of radar, but you 01:04:44.720 |
just need to chop it up and enter it, and then you have to encode it somehow. Like, the transformer 01:04:48.160 |
needs to know that these are coming from radar, so you have some kind of a 01:04:52.800 |
special token, like these radar tokens are slightly different in representation, and it's 01:04:58.720 |
learnable by gradient descent. And vehicle information would also come in with a special 01:05:04.320 |
embedding token that can be learned. So are those learned as well, or are they hardwired? 01:05:13.440 |
Yeah, it's all just a set, but you can positionally encode these sets if you want. 01:05:23.120 |
Positional encoding means you can hardwire, for example, the coordinates, using 01:05:28.320 |
sines and cosines; you can hardwire that. But it's better if you don't hardwire the position: 01:05:33.120 |
it's just a vector that is always hanging out at this location, 01:05:36.240 |
whatever content is there just adds onto it, and this vector is trainable by backprop. 01:05:43.760 |
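In code, that learned-position idea is just this (a minimal sketch; sizes are arbitrary):

```python
import torch

T, d = 8, 256
tok_emb = torch.randn(T, d)                      # whatever content sits at each location
pos_emb = torch.nn.Parameter(torch.zeros(T, d))  # one trainable vector "hanging out" per position

x = tok_emb + pos_emb  # the content just adds onto the position vector;
                       # pos_emb gets updated by backprop like any other weight
```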
I'm not sure if I understand the question. So I mean, the positional encodings 01:06:12.800 |
actually have very little inductive 01:06:16.560 |
bias or something like that; they're just vectors always hanging out at each location. 01:06:19.520 |
You're trying to help the network in some way, and I think the intuition is good, but 01:06:27.360 |
if you have enough data, usually trying to mess with it is a bad thing. 01:06:31.920 |
Like, trying to enter knowledge when you have enough knowledge in the data set itself 01:06:36.800 |
is not usually productive, so it really depends on what scale you are. If you have infinity data, 01:06:41.280 |
then you actually want to encode less and less, that turns out to work better, 01:06:44.160 |
and if you have very little data, then actually you do want to encode some biases, 01:06:47.520 |
and maybe if you have a much smaller data set, then maybe convolutions are a good idea, 01:06:50.640 |
because you actually have this bias coming from the filters. But the 01:06:57.920 |
transformer is extremely general, and there are ways to mess with the encodings to put in more 01:07:01.600 |
structure. You could, for example, encode sines and cosines and fix them, or you could 01:07:05.840 |
actually go to the attention mechanism and say, okay, if my image is chopped up into patches, 01:07:11.040 |
this patch can only communicate to this neighborhood, and you can - you just do that 01:07:14.080 |
in the attention matrix, just mask out whatever you don't want to communicate. And so people 01:07:18.720 |
really play with this, because the full attention is inefficient, so they will intersperse, 01:07:23.840 |
for example, layers that only communicate in little patches, and then layers that communicate 01:07:27.920 |
globally, and they will sort of do all kinds of tricks like that. So you can slowly bring in more 01:07:32.960 |
inductive bias, you would do it - but the inductive biases are sort of like, they're factored out from 01:07:37.920 |
the core transformer, and they are factored out in the connectivity of the nodes, and they are 01:07:42.960 |
factored out in the positional encodings, and you can mess with this for computation. 01:08:02.480 |
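As a sketch of pushing that kind of structure into the attention matrix (illustrative only; this is the windowed/local-attention idea used by various sparse-attention models):

```python
import torch

T, window = 16, 4  # 16 tokens; each may only attend within a local neighborhood

idx = torch.arange(T)
# Allow attention only where |i - j| <= window; mask out everything else.
local_mask = (idx[:, None] - idx[None, :]).abs() <= window

scores = torch.randn(T, T)  # stand-in for q @ k.T / sqrt(d)
scores = scores.masked_fill(~local_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
# In practice, layers with a local mask like this are often interspersed with
# layers that attend globally, to save compute on the full attention.
```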
So there's probably about 200 papers on this now, if not more. They're kind of hard to keep track of, 01:08:07.920 |
honestly; my Safari browser on my computer has like 200 open tabs. But 01:08:13.840 |
yes, I'm not even sure I want to pick a favorite, honestly. 01:08:22.000 |
Yeah, I think there was a very interesting talk this year on how you can think of a 01:08:28.160 |
transformer as like a CPU. I think the first test was to take instructions from like 01:08:32.400 |
4,000 programs, and then at the core of this CPU what you have is: you store variables, 01:08:37.440 |
you have memory, so if I want to run a longer program on the CPU, I just run it 01:08:41.120 |
multiple times. So maybe you can use a transformer like that. The other one that I actually like 01:08:46.640 |
even more is potentially keep the context length fixed, but allow the network to somehow use a 01:08:51.600 |
scratchpad. And so the way this works is you will teach the transformer somehow, via examples in 01:08:57.040 |
the prompt, that hey, you actually have a scratchpad. Hey, basically, you can't remember too 01:09:01.440 |
much. Your context length is finite. But you can use a scratchpad, and you do that by emitting a start 01:09:06.080 |
scratchpad, and then writing whatever you want to remember, and then end scratchpad. And then 01:09:10.480 |
you continue with whatever you want. And then later, when it's decoding, you actually have 01:09:15.200 |
special logic that when you detect start scratchpad, you will sort of like save whatever 01:09:19.520 |
it puts in there in some external storage, and allow it to attend over it. So basically, you can teach 01:09:23.920 |
the transformer dynamically, because it's so good at meta-learning. You can teach it dynamically to use 01:09:28.800 |
other gizmos and gadgets, and allow it to expand its memory that way, if that makes sense. It's 01:09:32.800 |
just like human learning to use a notepad, right? You don't have to keep it in your brain. So keeping 01:09:37.440 |
things in your brain is kind of like the context length of the transformer. But maybe we can just 01:09:40.720 |
give it a notebook. And then it can query the notebook, and read from it, and write to it. 01:10:08.960 |
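A very rough sketch of what that decoding-time logic could look like; this is entirely hypothetical, and the special token names and the generate_next_token helper are made up:

```python
def decode_with_scratchpad(generate_next_token, prompt, max_steps=256):
    """Toy loop: text emitted between the scratchpad markers is saved to an
    external store instead of having to live inside the model's finite context."""
    context, scratchpad, in_pad = prompt, [], False
    for _ in range(max_steps):
        tok = generate_next_token(context)   # hypothetical call into the model
        if tok == "<start_scratchpad>":
            in_pad = True
        elif tok == "<end_scratchpad>":
            in_pad = False
        elif in_pad:
            scratchpad.append(tok)           # write to the external "notepad"
        context += tok
        # Separate retrieval logic could later re-insert scratchpad contents
        # into the context when the model needs to read its notes.
    return context, scratchpad
```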
I don't know if I detected that. I kind of feel like... did you feel like it was more than just forgetting? 01:10:19.680 |
I didn't try it extensively, but I did see a forgetting event. And I kind of felt like 01:10:24.160 |
the block size was just sliding along. Maybe I'm wrong. I don't actually know about the internals of ChatGPT. 01:10:51.600 |
I mean, so right now, I'm working on things like nanoGPT. Where's nanoGPT? 01:10:59.040 |
I mean, I'm going basically slightly from computer vision and kind of computer vision-based 01:11:03.360 |
products to a little bit in the language domain. Where's ChatGPT? OK, nanoGPT. So originally, 01:11:08.160 |
I had minGPT, which I rewrote into nanoGPT. And I'm working on this. I'm trying to reproduce 01:11:12.320 |
GPTs. And I mean, I think something like ChatGPT, incrementally improved in a product 01:11:17.840 |
fashion, would be extremely interesting. And I think a lot of people feel it, and that's why 01:11:23.440 |
it spread so widely. So I think there's something like a Google-plus-plus-plus to build there that 01:11:29.040 |
I think is really interesting. So we did our speed round around the clock.