
Stanford CS25: V2 I Introduction to Transformers w/ Andrej Karpathy


Chapters

0:00 Introduction
0:47 Introducing the Course
3:19 Basics of Transformers
3:35 The Attention Timeline
5:01 Prehistoric Era
6:10 Where we were in 2021
7:30 The Future
10:15 Transformers - Andrej Karpathy
10:39 Historical context
60:30 Thank you - Go forth and transform

Whisper Transcript

00:00:00.000 | Hi everyone. Welcome to CS25 Transformers United V2. This was a course that was held
00:00:11.000 | at Stanford in the winter of 2023. This course is not about robots that can transform into
00:00:15.720 | cars, as this picture might suggest. Rather, it's about deep learning models that have
00:00:19.980 | taken the world by storm and have revolutionized AI and other fields. Starting from
00:00:24.460 | natural language processing, transformers have been applied all over, from computer
00:00:27.880 | vision to reinforcement learning, biology, robotics, and more. We have an exciting set of videos lined
00:00:33.560 | up for you, with some truly fascinating speakers giving talks, presenting how they're applying
00:00:38.720 | transformers to research in different fields and areas. We hope you'll enjoy and
00:00:46.280 | learn from these videos. So without any further ado, let's get started.
00:00:52.200 | This is a purely introductory lecture, and we'll go into the building blocks of transformers.
00:00:59.000 | So first, let's start with introducing the instructors.
00:01:03.600 | So for me, I'm currently on a temporary deferral from the PhD program, and I'm leading AI at
00:01:08.000 | a robotics startup, Collaborative Robotics, working on some general-purpose robots.
00:01:13.000 | And yeah, I'm very passionate about robotics and building efficient learning
00:01:18.920 | systems. My research interests are in reinforcement learning, computer vision, linear modeling,
00:01:23.800 | and I have a bunch of publications in robot technology and other areas. My undergrad was
00:01:29.080 | at Cornell. Nice to meet you all.
00:01:33.960 | So I'm Stephen, I'm a first-year CS PhD student. Previously did my master's at CMU, and undergrad
00:01:39.800 | at Waterloo. I'm mainly into NLP research, anything involving language and text. But
00:01:44.760 | more recently, I've been getting more into computer vision as well as multilingual work. And
00:01:49.080 | just some stuff I do for fun, a lot of music stuff, mainly piano. Some self-promo, but
00:01:54.080 | I post a lot on my Insta, YouTube, and TikTok, so if you guys want to check it out.
00:01:58.840 | My friends and I are also starting a Stanford Piano Club, so if anybody's interested, feel
00:02:04.000 | free to email me for details. Other than that, martial arts, bodybuilding, and huge fan of
00:02:11.600 | K-dramas, anime, and occasional gamer.
00:02:14.880 | Okay, cool. Yeah, so my name's Ryland. Instead of talking about myself, I just want to very
00:02:22.160 | briefly say that I'm super excited to take this class. I took it the last time it was
00:02:27.920 | offered, I had a bunch of fun. I thought we brought in a really great group of speakers
00:02:32.640 | last time, I'm super excited for this offering. And yeah, I'm thankful that you're all here,
00:02:36.960 | and I'm looking forward to a really fun quarter together. Thank you.
00:02:39.360 | Yeah, so fun fact, Ryland was the most outspoken student last year, and so if someone wants
00:02:44.720 | to become an instructor next year, you know what to do.
00:02:52.240 | Okay, cool. So what we hope you will learn in this class is, first of all, how do
00:03:00.320 | transformers work? How are they being applied? And nowadays, they are pretty much everywhere
00:03:06.960 | in AI and machine learning. And what are some new and interesting directions of research in
00:03:13.680 | these topics?
00:03:15.680 | Cool. So this class is just introductory, so we'll be just talking about the basics
00:03:21.280 | of transformers, introducing them, talking about the self-attention mechanism on which
00:03:25.520 | they're founded, and we'll do a deep dive more on, like, models like BERT, GPT, stuff
00:03:31.520 | like that. So, great, happy to get started. Okay, so let me start with presenting the
00:03:38.160 | attention timeline. Attention all started with this one paper, Attention is All You
00:03:43.760 | Need, by Vaswani et al. in 2017. That was the beginning of transformers. Before that,
00:03:49.760 | we had the prehistoric era, where we had models like RNNs, LSTMs, and their simple attention
00:03:56.800 | mechanisms that didn't evolve or scale at all. Starting in 2017, we saw this explosion
00:04:02.400 | of transformers into NLP, where people started using it for everything. I even heard this
00:04:07.840 | quote from Google, it's like, "Our performance increased every time we fired our linguists."
00:04:11.520 | Then, from 2018 to 2020, we saw this explosion of transformers into
00:04:19.040 | other fields, like vision, a bunch of other stuff, and biology with AlphaFold. And then,
00:04:27.120 | 2021 was the start of the generative era, where we got a lot of generative modeling.
00:04:30.720 | It started with models like Codex, GPT, DALL-E, Stable Diffusion, so a lot of things happening
00:04:37.520 | in generative modeling. And we started scaling up in AI. And now it's the present. So this
00:04:46.160 | is 2022 and the start of 2023. And now we have models like ChatGPT, Whisper, and a bunch of others.
00:04:54.560 | And we are scaling onwards without slowing down. So that's great. So that's the future.
00:04:59.440 | So going more into this: once there were RNNs, so we had sequence-to-sequence models, LSTMs,
00:05:09.200 | GRUs. What worked here was that they were good at encoding history. But what did not work was
00:05:15.440 | that they didn't encode long sequences, and they were very bad at encoding context.
00:05:19.440 | So consider this example. Consider trying to predict the last word in the text, "I grew up in
00:05:27.920 | France, dot, dot, dot. I speak fluent, dash." Here, you need to understand the context for it
00:05:33.600 | to predict French. And the attention mechanism is very good at that. Whereas if you're just using LSTMs,
00:05:39.600 | it doesn't work that well. Another thing attention is good at is
00:05:47.040 | context prediction, which you can see in attention maps. If I have something
00:05:54.480 | like the word "it," what noun does it refer to? We can put an attention probability on
00:06:02.400 | each of the possible candidates. And this works better than previous mechanisms.
00:06:09.120 | OK. So where we were in 2021, we were on the verge of takeoff. We were starting to realize
00:06:16.880 | the potential of transformers in different fields. We solved a lot of long-sequence
00:06:21.680 | problems like protein folding with AlphaFold, and offline RL. We started to see few-shot and
00:06:29.280 | zero-shot generalization. We saw multimodal tasks and applications like generating images
00:06:34.160 | from language. So that was DALL-E. Yeah. And it feels like ages ago, but it was only like two years
00:06:40.160 | ago. And this is also a talk on transformers that you can watch on YouTube. Cool. And this is where
00:06:51.040 | we were going from 2021 to 2022, which is we have gone from the verge of taking off to actually
00:06:57.200 | taking off. And now we are seeing unique applications in audio generation, art, music,
00:07:01.920 | storytelling. We are starting to see reasoning capabilities like common sense, logical reasoning,
00:07:08.160 | mathematical reasoning. We are also now able to get human alignment and interaction.
00:07:14.000 | Models are able to use reinforcement learning from human feedback; that's how they are trained to
00:07:18.160 | perform really well. We have a lot of mechanisms for controlling toxicity, bias, and ethics now.
00:07:24.400 | And also a lot of developments in other areas like diffusion models. Cool.
00:07:30.640 | So the future is a spaceship, and we are all excited about it.
00:07:36.400 | And there are a lot more applications that we could enable. And it'd be great if we can see
00:07:45.600 | transformers also work there. One big example is video understanding and generation. That is
00:07:50.080 | something that everyone is interested in. And I'm hoping we'll see a lot of models in this
00:07:54.000 | area this year. Also finance and business. I'd be very excited to see GPT author a novel,
00:08:01.920 | but we need to solve very long sequence modeling. And most transformer models are still limited to
00:08:07.920 | like 4,000 tokens or something like that. So we need to make them generalize much better
00:08:14.240 | on long sequences. We also want to have generalized agents that can do a lot of multitask,
00:08:22.320 | multi-input predictions, like Gato. And so I think we will see more of that too. And finally,
00:08:30.640 | we also want domain-specific models. So you might want like a GPT model that's good at, maybe,
00:08:38.480 | your health. So that could be like a doctor GPT model. You might have like a lawyer GPT model
00:08:43.200 | that's trained only on law data. So currently we have GPT models that are trained on everything,
00:08:47.440 | but we might start to see more niche models that are good at one task. And we could have like
00:08:52.480 | a mixture of experts. It's like, you can think like, this is like how you normally consult an
00:08:56.640 | expert. You'll have like expert AI models and you can go to a different AI model for your different
00:09:00.080 | needs. There are still a lot of missing ingredients to make this all successful. The first of all is
00:09:09.840 | external memory. We are already starting to see this with models like ChatGPT, where
00:09:15.840 | the interactions are short-lived. There's no long-term memory, and they don't have the ability
00:09:20.320 | to remember or store conversations for long term. And this is something we want to fix.
00:09:26.800 | Second is reducing the computational complexity. The attention mechanism is quadratic in
00:09:34.320 | the sequence length, which is slow, and we want to reduce that or make it faster.
00:09:39.520 | Another thing we want to do is we want to enhance the controllability of this model.
00:09:45.760 | It's like a lot of these models can be stochastic, and we want to be able to control what sort of
00:09:50.560 | outputs we get from them. And you might have experienced with ChatGPT, if you just refresh,
00:09:55.200 | you get a different output each time, but you might want to have a mechanism that controls
00:09:59.040 | what sort of things you get. And finally, we want to align our state-of-the-art language models with
00:10:04.800 | how the human brain works. We are seeing research in this direction, but we still need more work on
00:10:10.320 | how that can be done. Thank you. Great. Hi. Yes, I'm excited to be here. I live very nearby,
00:10:19.360 | so I got the invite to come to class. And I was like, OK, I'll just walk over.
00:10:22.560 | But then I spent like 10 hours on the slides, so it wasn't as simple.
00:10:26.880 | So yeah, I want to talk about transformers. I'm going to skip the first two over there.
00:10:32.560 | We're not going to talk about those. We'll talk about that one,
00:10:35.040 | just to simplify the lecture since we don't have time. OK. So I wanted to provide a little bit of
00:10:41.440 | context of why does this transformers class even exist. So a little bit of historical context.
00:10:45.840 | I feel a bit like Bilbo over there, telling you guys about this history. I don't know if
00:10:50.880 | you guys caught that reference. And basically, I joined AI in roughly 2012 in full force,
00:10:56.720 | so maybe a decade ago. And back then, you wouldn't even say that you joined AI, by the way. That was
00:11:00.800 | like a dirty word. Now it's OK to talk about. But back then, it was not even deep learning. It was
00:11:05.760 | machine learning. That was a term you would use if you were serious. But now, AI is OK to use,
00:11:10.880 | I think. So basically, do you even realize how lucky you are potentially entering this area
00:11:15.280 | in roughly 2023? So back then, in 2011 or so, when I was working specifically on computer vision,
00:11:21.200 | your pipelines looked like this. So you wanted to classify some images. You would go to a paper,
00:11:29.120 | and I think this is representative. You would have three pages in the paper describing all kinds of
00:11:33.680 | a zoo, a kitchen sink, of different kinds of features and descriptors. And you would go to a poster
00:11:38.560 | session at a computer vision conference, and everyone would have their favorite feature
00:11:41.840 | descriptors that they're proposing. It was totally ridiculous. And you would take notes on which one
00:11:45.200 | you should incorporate into your pipeline, because you would extract all of them, and then you would
00:11:48.640 | put an SVM on top. So that's what you would do. So there's two pages. Make sure you get your sparse
00:11:53.040 | SIFT histograms, your SSIMs, your color histograms, textons, tiny images. And don't forget the
00:11:57.920 | geometry-specific histograms. All of them basically had complicated code by themselves. So
00:12:02.640 | you're collecting code from everywhere and running it, and it was a total nightmare.
00:12:05.680 | So on top of that, it also didn't work. So this, I think, represents the predictions from
00:12:14.480 | that time. You would just get predictions like this once in a while, and you'd be like, you just
00:12:18.480 | shrug your shoulders like that just happens once in a while. Today, you would be looking for a bug.
00:12:23.680 | And worse than that, every single chunk of AI had their own completely separate
00:12:32.480 | vocabulary that they work with. So if you go to NLP papers, those papers would be completely
00:12:37.520 | different. So you're reading the NLP paper, and you're like, what is this part of speech tagging,
00:12:42.320 | morphological analysis, syntactic parsing, coreference resolution? What is NP, what is
00:12:47.520 | VP, and all of that? So the vocabulary and everything was completely different,
00:12:51.280 | and you couldn't read papers, I would say, across different areas.
00:12:53.920 | So now that changed a little bit starting in 2012 when Alex Krizhevsky and colleagues
00:13:00.960 | basically demonstrated that if you scale a large neural network on a large data set,
00:13:06.640 | you can get very strong performance. And so up till then, there was a lot of focus on algorithms,
00:13:10.960 | but this showed that actually neural nets scale very well. So you need to now worry about compute
00:13:14.880 | and data, and if you scale it up, it works pretty well. And then that recipe actually did copy-paste
00:13:19.360 | across many areas of AI. So we started to see neural networks pop up everywhere since 2012.
00:13:25.040 | So we saw them in computer vision, in NLP, in speech, in translation, in RL, and so on. So
00:13:30.480 | everyone started to use the same kind of modeling tool kit, modeling framework. And now when you go
00:13:34.720 | to NLP and you start reading papers there, in machine translation, for example, this is a
00:13:39.120 | sequence-to-sequence paper, which we'll come back to in a bit. You start to read those papers, and
00:13:43.360 | you're like, OK, I can recognize these words, like there's a neural network, there's a parameter,
00:13:46.960 | there's an optimizer, and it starts to read things that you know of. So that decreased
00:13:52.240 | tremendously the barrier to entry across the different areas. And then I think the big deal
00:13:57.840 | is that when the transformer came out in 2017, it's not even that just the toolkits and the
00:14:02.480 | neural networks were similar, it's that literally the architectures converged to one architecture
00:14:06.960 | that you copy-paste across everything seemingly. So this was kind of an unassuming machine
00:14:12.560 | translation paper at the time proposing the transformer architecture, but what we found
00:14:15.680 | since then is that you can just basically copy-paste this architecture and use it everywhere,
00:14:21.520 | and what's changing is the details of the data and the chunking of the data and how you feed it in.
00:14:26.560 | And that's a caricature, but it's kind of like a correct first-order statement.
00:14:29.920 | And so now papers are even more similar looking because everyone's just using
00:14:33.760 | transformer. And so this convergence was remarkable to watch and unfolded over the
00:14:39.440 | last decade, and it's pretty crazy to me. What I find kind of interesting is I think this is
00:14:44.560 | some kind of a hint that we're maybe converging to something that maybe the brain is doing,
00:14:47.760 | because the brain is very homogeneous and uniform across the entire sheet of your cortex.
00:14:52.880 | And okay, maybe some of the details are changing, but those feel like hyperparameters of a
00:14:56.640 | transformer, but your auditory cortex and your visual cortex and everything else looks very
00:15:00.320 | similar. And so maybe we're converging to some kind of a uniform, powerful learning algorithm
00:15:04.880 | here, something like that, I think is kind of interesting and exciting.
00:15:08.000 | Okay, so I want to talk about where the transformer came from briefly, historically.
00:15:12.880 | So I want to start in 2003. I like this paper quite a bit. It was the first sort of popular
00:15:20.000 | application of neural networks to the problem of language modeling. So predicting, in this case,
00:15:24.160 | the next word in a sequence, which allows you to build generative models over text.
00:15:27.680 | And in this case, they were using multi-layer perceptron, so a very simple neural net.
00:15:31.040 | The neural nets took three words and predicted the probability distribution for the fourth word
00:15:34.720 | in a sequence. So this was well and good at this point. Now, over time, people started to apply
00:15:41.280 | this to machine translation. So that brings us to sequence-to-sequence paper from 2014 that was
00:15:46.960 | pretty influential. And the big problem here was, okay, we don't just want to take three words and
00:15:51.360 | predict the fourth. We want to predict how to go from an English sentence to a French sentence.
00:15:56.720 | And the key problem was, okay, you can have arbitrary number of words in English and
00:16:00.400 | arbitrary number of words in French, so how do you get an architecture that can process this
00:16:05.120 | variably-sized input? And so here, they used an LSTM. And there's basically two chunks of this,
00:16:11.120 | which are covered by the slide here. But basically, you have an encoder LSTM on the left,
00:16:18.960 | and it just consumes one word at a time and builds up a context of what it has read. And then that
00:16:25.280 | acts as a conditioning vector to the decoder RNN or LSTM that basically goes chunk, chunk,
00:16:30.320 | chunk for the next word in the sequence, translating the English to French or something
00:16:34.640 | like that. Now, the big problem with this that people identified, I think, very quickly and
00:16:39.040 | tried to resolve is that there's what's called this encoder bottleneck. So this entire English
00:16:45.280 | sentence that we are trying to condition on is packed into a single vector that goes from the
00:16:49.600 | encoder to the decoder. And so this is just too much information to potentially maintain in a
00:16:53.440 | single vector, and that didn't seem correct. And so people were looking around for ways to alleviate
00:16:57.680 | sort of the encoder bottleneck, as it was called at the time. And so that brings
00:17:02.800 | us to this paper, Neural Machine Translation by Jointly Learning to Align and Translate.
00:17:06.960 | And here, just going from the abstract, in this paper, we conjectured that use of a fixed-length
00:17:13.200 | vector is a bottleneck in improving the performance of the basic encoder-decoder
00:17:16.720 | architecture, and proposed to extend this by allowing the model to automatically soft-search
00:17:21.680 | for parts of the source sentence that are relevant to predicting a target word,
00:17:26.880 | without having to form these parts as a hard segment explicitly. So this was a way to look
00:17:32.880 | back to the words that are coming from the encoder, and it was achieved using this soft-search. So as
00:17:38.720 | you are decoding the words here, while you are decoding them, you are allowed to look back at
00:17:45.360 | the words at the encoder via this soft attention mechanism proposed in this paper. And so this
00:17:50.960 | paper, I think, is the first time that I saw, basically, attention. So your context vector
00:17:57.760 | that comes from the encoder is a weighted sum of the hidden states of the words in the encoder,
00:18:04.480 | and then the weights of this sum come from a softmax that is based on these compatibilities
00:18:11.200 | between the current state, as you're decoding, and the hidden states generated by the encoder.
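In symbols, as a rough reconstruction of the mechanism being described (using the usual notation where the $h_j$ are the encoder hidden states and $s_{i-1}$ is the current decoder state):

$$
c_i = \sum_j \alpha_{ij}\, h_j, \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}, \qquad
e_{ij} = a(s_{i-1}, h_j),
$$

where $a$ is a small learned compatibility (alignment) function between the decoder state and each encoder hidden state.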
00:18:15.760 | And so this is the first time that really you start to look at it, and these are the current,
00:18:20.800 | modern equations of attention. And I think this was the first paper that I saw it in. It's
00:18:25.520 | the first time that there's a word "attention" used, as far as I know, to call this mechanism.
00:18:31.360 | So I actually tried to dig into the details of the history of the attention.
00:18:35.600 | So the first author here, Dzmitry Bahdanau, I had an email correspondence with him,
00:18:40.000 | and I basically sent him an email. I'm like, "Dzmitry, this is really interesting. Transformers
00:18:43.360 | have taken over. Where did you come up with the soft attention mechanism that ends up being the
00:18:46.880 | heart of the transformer?" And to my surprise, he wrote me back this massive email, which was
00:18:52.160 | really fascinating. So this is an excerpt from that email. So basically, he talks about how he
00:18:58.560 | was looking for a way to avoid this bottleneck between the encoder and decoder. He had some
00:19:02.720 | ideas about cursors that traversed the sequences that didn't quite work out. And then here - so
00:19:08.000 | one day, I had this thought that it would be nice to enable the decoder RNN to learn how to search
00:19:11.840 | where to put the cursor in the source sequence. This was sort of inspired by translation exercises
00:19:16.400 | that learning English in my middle school involved. Your gaze shifts back and forth between
00:19:22.720 | source and target sequence as you translate. So literally, I thought this was kind of interesting
00:19:26.960 | that he's not a native English speaker. And here, that gave him an edge in this machine translation
00:19:31.200 | that led to attention and then led to transformer. So that's really fascinating. I expressed a soft
00:19:37.680 | search as softmax and then weighted averaging of the BiRNN states. And basically, to my great
00:19:42.560 | excitement, this worked from the very first try. So really, I think, interesting piece of history.
00:19:48.320 | And as it later turned out, the name RNNSearch was kind of lame. So the better name,
00:19:53.520 | attention, came from Yoshua on one of the final passes as they went over the paper. So maybe
00:20:00.000 | 'Attention is All You Need' would instead have been called something like 'RNNSearch is All You Need'. But we have Yoshua Bengio to
00:20:04.960 | thank for a little bit of better name, I would say. So apparently, that's the history of this.
00:20:11.360 | OK, so that brings us to 2017, which is attention is all you need. So this attention component,
00:20:16.240 | which in Dzmitry's paper was just like one small segment, and there's all this bi-directional
00:20:20.960 | RNN encoder and decoder. And this attention-only paper is saying, OK, you can actually delete everything.
00:20:26.880 | What's making this work very well is just attention by itself. And so delete everything,
00:20:31.280 | keep attention. And then what's remarkable about this paper, actually, is usually you see papers
00:20:35.520 | that are very incremental. They add one thing, and they show that it's better. But I feel like
00:20:40.560 | attention is all you need with a mix of multiple things at the same time. They were combined in a
00:20:44.960 | very unique way, and then also achieved a very good local minimum in the architecture space.
00:20:50.720 | And so to me, this is really a landmark paper that is quite remarkable, and I think had quite a lot
00:20:56.480 | of work behind the scenes. So delete all the RNN, just keep attention. Because attention operates
00:21:03.040 | over sets, and I'm going to go into this in a second, you now need to positionally encode your
00:21:06.640 | inputs, because attention doesn't have the notion of space by itself. They - oops, I have to be
00:21:15.280 | very careful - they adopted this residual network structure from ResNets. They interspersed
00:21:22.160 | attention with multi-layer perceptrons. They used layer norms, which came from a different paper.
00:21:27.680 | They introduced the concept of multiple heads of attention that were applied in parallel.
00:21:30.800 | And they gave us, I think, like a fairly good set of hyperparameters that to this day are used.
00:21:35.440 | So the expansion factor in the multi-layer perceptron goes up by 4x, and we'll go into
00:21:40.800 | like a bit more detail, and this 4x has stuck around. And I believe there's a number of papers
00:21:45.280 | that try to play with all kinds of little details of the transformer, and nothing sticks, because
00:21:50.000 | this is actually quite good. The only change, to my knowledge, that actually stuck was this
00:21:55.840 | reshuffling of the layer norms to go into the pre-norm version, where here you see the layer
00:21:59.920 | norms are after the multi-headed attention and feedforward, but they just put them before instead.
00:22:04.400 | So just reshuffling of layer norms, but otherwise the GPTs and everything else that you're seeing
00:22:08.160 | today is basically the 2017 architecture from five years ago. And even though everyone is
00:22:13.120 | working on it, it's proven remarkably resilient, which I think is real interesting. There are
00:22:18.000 | innovations that I think have been adopted also in positional encodings. It's more common to use
00:22:22.640 | different rotary and relative positional encodings and so on. So I think there have been changes,
00:22:27.600 | but for the most part it's proven very resilient. So really quite an interesting paper. Now I wanted
00:22:33.360 | to go into the attention mechanism, and I think, I sort of like, the way I interpret it is not
00:22:40.160 | similar to the ways that I've seen it presented before. So let me try a different way of like
00:22:47.280 | how I see it. Basically to me, attention is kind of like the communication phase of the transformer,
00:22:51.440 | and the transformer interleaves two phases. The communication phase, which is the multi-headed
00:22:56.560 | attention, and the computation stage, which is this multilayer perceptron. So in the
00:23:01.760 | communication phase, it's really just a data-dependent message passing on directed graphs.
00:23:06.400 | And you can think of it as, okay, forget everything with machine translation and everything. Let's
00:23:11.280 | just, we have directed graphs at each node. You are storing a vector. And then let me talk now
00:23:18.080 | about the communication phase of how these vectors talk to each other in this directed graph. And
00:23:21.760 | then the compute phase later is just a multilayer perceptron, which now, which then basically acts
00:23:27.200 | on every node individually. But how do these nodes talk to each other in this directed graph?
00:23:32.240 | So I wrote some simple Python, basically, to create one round of
00:23:40.080 | communication using attention as the message passing scheme. So here, a node
00:23:48.560 | has this private data vector, as you can think of it as private information to this node. And then
00:23:54.720 | it can also emit a key, a query, and a value. And simply that's done by linear transformation
00:23:59.440 | from this node. So the query is, what are the
00:24:08.960 | things that I'm looking for? The key is, what are the things that I have? And the value is, what are
00:24:13.040 | the things that I will communicate? And so then when you have your graph that's made up of nodes
00:24:17.600 | and some random edges, when you actually have these nodes communicating, what's happening is
00:24:21.360 | you loop over all the nodes individually in some random order, and you are at some node,
00:24:27.040 | and you get the query vector q, which is, I'm a node in some graph, and this is what I'm looking
00:24:33.200 | for. And so that's just achieved via this linear transformation here. And then we look at all the
00:24:37.920 | inputs that point to this node, and then they broadcast, what are the things that I have,
00:24:42.320 | which is their keys. So they broadcast the keys, I have the query, then those interact by dot product
00:24:49.600 | to get scores. So basically, simply by doing dot product, you get some kind of an unnormalized
00:24:55.520 | weighting of the interestingness of all of the information in the nodes that point to me and to
00:25:00.960 | the things I'm looking for. And then when you normalize that with a softmax, so it just sums to
00:25:04.800 | one, you basically just end up using those scores, which now sum to one and are a probability
00:25:09.600 | distribution, and you do a weighted sum of the values to get your update. So I have a query,
00:25:17.280 | they have keys, dot product to get interestingness, or like affinity, softmax to normalize it,
00:25:23.840 | and then weighted sum of those values flow to me and update me. And this is happening for each
00:25:28.720 | node individually, and then we update at the end. And so this kind of a message passing scheme is
00:25:32.800 | kind of like at the heart of the transformer, and happens in a more vectorized, batched way
00:25:40.240 | that is more confusing, and is also interspersed with layer norms and things like that to make the
00:25:45.760 | training behave better. But that's roughly what's happening in the attention mechanism, I think,
00:25:50.560 | on a high level. So yeah, so in the communication phase of the transformer, then this message
00:25:59.680 | passing scheme happens in every head in parallel, and then in every layer in series, and with
00:26:06.800 | different weights each time. And that's it as far as the multi-headed attention goes.
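Here is a minimal sketch of that message-passing view in plain Python/NumPy. The `Node` class, the dimension `d`, and the toy graph are illustrative (this is not the nanoGPT code), and heads, scaling, layer norms, and batching are all omitted:

```python
import numpy as np

d = 16  # feature dimension (hypothetical)

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

class Node:
    def __init__(self):
        # private information stored at this node
        self.data = np.random.randn(d)
        # linear maps that emit query / key / value from the private data
        self.wq, self.wk, self.wv = (np.random.randn(d, d) for _ in range(3))

    def query(self):  return self.wq @ self.data   # what am I looking for?
    def key(self):    return self.wk @ self.data   # what do I have?
    def value(self):  return self.wv @ self.data   # what will I communicate?

def communicate(node, inputs):
    """One round of attention: `inputs` are the nodes with edges pointing to `node`."""
    q = node.query()
    scores = np.array([q @ n.key() for n in inputs])   # dot-product affinities
    weights = softmax(scores)                          # normalize so the weights sum to 1
    return sum(w * n.value() for w, n in zip(weights, inputs))  # weighted sum of values

# toy graph: every node attends to itself and all previous nodes (decoder-style)
nodes = [Node() for _ in range(8)]
updates = [communicate(nodes[i], nodes[:i + 1]) for i in range(len(nodes))]
for node, u in zip(nodes, updates):   # update all nodes only at the end of the round
    node.data = u
```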
00:26:13.120 | And so if you look at these encoder-decoder models, you can sort of think of it then,
00:26:17.040 | in terms of the connectivity of these nodes in the graph, you can kind of think of it as like,
00:26:20.560 | okay, all these tokens that are in the encoder that we want to condition on, they are fully
00:26:24.880 | connected to each other. So when they communicate, they communicate fully when you calculate their
00:26:29.520 | features. But in the decoder, because we are trying to have a language model, we don't want
00:26:34.240 | to have communication from future tokens, because they give away the answer at this step. So the
00:26:38.720 | tokens in the decoder are fully connected from all the encoder states, and then they are also
00:26:43.840 | fully connected from everything that is before them. And so you end up with this, like, triangular
00:26:48.320 | structure in the directed graph. But that's the message passing scheme that this basically
00:26:53.280 | implements. And then you have to be also a little bit careful, because in the cross-attention here
00:26:58.320 | with the decoder, you consume the features from the top of the encoder. So think of it as, in the
00:27:03.600 | encoder, all the nodes are looking at each other, all the tokens are looking at each other, many,
00:27:07.280 | many times. And they really figure out what's in there. And then the decoder is looking
00:27:11.520 | only at the top nodes. So that's roughly the message passing scheme. I was going to go into
00:27:17.360 | more of an implementation of the transformer. I don't know if there's any questions about this.
00:27:21.680 | Can you explain a little bit about self-attention and multi-headed attention?
00:27:26.480 | Yeah, so self-attention and multi-headed attention. So the multi-headed attention is
00:27:36.880 | just this attention scheme, but it's just applied multiple times in parallel. Multiple heads just
00:27:41.600 | means independent applications of the same attention. So this message passing scheme
00:27:47.360 | basically just happens in parallel multiple times with different weights for the query key and value.
00:27:53.200 | So you can almost look at it like, in parallel, I'm looking for, I'm seeking different kinds of
00:27:56.880 | information from different nodes, and I'm collecting it all in the same node. It's all
00:28:01.360 | done in parallel. So heads is really just like copy-paste in parallel. And layers are copy-paste,
00:28:09.680 | but in series. Maybe that makes sense. And self-attention, when it's self-attention,
00:28:18.800 | what it's referring to is that each node here produces everything from itself. So as I described it
00:28:24.000 | here, this is really self-attention. Because every one of these nodes produces a key query and a
00:28:28.240 | value from this individual node. When you have cross-attention, you have one cross-attention
00:28:34.160 | here coming from the encoder. That just means that the queries are still produced from this node,
00:28:40.720 | but the keys and the values are produced as a function of nodes that are coming from the
00:28:46.480 | encoder. So I have my queries because I'm trying to decode the fifth word in the sequence,
00:28:53.760 | and I'm looking for certain things because I'm the fifth word. And then the keys and the values,
00:28:58.240 | in terms of the source of information that could answer my queries, can come from the previous
00:29:02.720 | nodes in the current decoding sequence, or from the top of the encoder. So all the nodes that
00:29:07.280 | have already seen all of the encoding tokens many, many times can now broadcast what they
00:29:12.320 | contain in terms of information. So I guess to summarize, the self-attention is kind of like,
00:29:19.040 | sorry, cross-attention and self-attention only differ in where the keys and the values come
00:29:23.680 | from. Either the keys and values are produced from this node, or they are produced from some
00:29:29.120 | external source, like an encoder and the nodes over there. But algorithmically, it's the same
00:29:35.520 | mathematical operations. Okay.
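In a toy PyTorch sketch (a single head, no masking, and illustrative names), the only difference really is where the keys and values come from:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d = 16  # embedding dimension (illustrative)
Wq, Wk, Wv = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

def attend(q, k, v):
    # scaled dot-product attention over (T, d) tensors
    att = F.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return att @ v

def self_attention(x):                      # q, k, v all come from the same nodes x
    return attend(Wq(x), Wk(x), Wv(x))

def cross_attention(x, enc):                # queries from the decoder nodes x,
    return attend(Wq(x), Wk(enc), Wv(enc))  # keys and values from the encoder outputs

x, enc = torch.randn(5, d), torch.randn(10, d)
print(self_attention(x).shape, cross_attention(x, enc).shape)  # both (5, 16)
```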
00:29:40.400 | [INAUDIBLE]
00:29:48.640 | So, yeah, so [INAUDIBLE]
00:29:56.640 | So think of - so each one of these nodes is a token.
00:30:00.560 | I guess, like, I don't have a very good picture of it in the transformer, but like
00:30:09.840 | this node here could represent the third word in the output, in the decoder.
00:30:16.800 | And in the beginning, it is just the embedding of the word.
00:30:21.440 | And then, okay, I have to think through this analogy a little bit more. I came up with it
00:30:31.840 | this morning. Actually, I came up with it yesterday.
00:30:36.800 | [INAUDIBLE]
00:30:50.240 | These nodes are basically the vectors. I'll go to an implementation - I'll go to the
00:30:54.720 | implementation, and then maybe I'll make the connections to the graph. So let me try to
00:30:59.680 | first go to - let me now go to, with this intuition in mind at least, to nanoGPT, which is a concrete
00:31:04.480 | implementation of a transformer that is very minimal. So I worked on this over the last few
00:31:08.240 | days, and here it is reproducing GPT-2 on OpenWebText. So it's a pretty serious implementation
00:31:13.760 | that reproduces GPT-2, I would say, and provided enough compute. This was one node of eight
00:31:19.040 | GPUs for 38 hours or something like that, and it's very readable at 300 lines, so everyone
00:31:25.040 | can take a look at it. And yeah, let me basically briefly step through it. So let's try to have a
00:31:32.400 | decoder-only transformer. So what that means is that it's a language model. It tries to model
00:31:36.880 | the next word in a sequence or the next character in a sequence. So the data that we train on is
00:31:43.120 | always some kind of text. So here's some fake Shakespeare. Sorry, this is real Shakespeare.
00:31:47.040 | We're going to produce fake Shakespeare. So this is called the tiny Shakespeare data set,
00:31:50.480 | which is one of my favorite toy data sets. You take all of Shakespeare, concatenate it,
00:31:54.160 | and it's one megabyte file, and then you can train language models on it and get infinite
00:31:57.600 | Shakespeare if you like, which I think is kind of cool. So we have a text. The first thing we
00:32:01.600 | need to do is we need to convert it to a sequence of integers, because transformers natively process integers;
00:32:06.960 | you know, you can't plug text into a transformer. You need to somehow encode it. So the way that
00:32:12.400 | encoding is done is we convert, for example, in the simplest case, every character gets an integer,
00:32:16.560 | and then instead of "hi" there, we would have this sequence of integers. So then you can encode
00:32:22.720 | every single character as an integer and get a massive sequence of integers. You just concatenate
00:32:28.400 | it all into one large, long, one-dimensional sequence, and then you can train on it.
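A rough sketch of that character-level encoding; the file name and variable names here are illustrative rather than nanoGPT's exact preparation script:

```python
with open('input.txt') as f:        # tiny Shakespeare, all concatenated into one file (assumed path)
    text = f.read()

chars = sorted(set(text))           # the vocabulary: every distinct character
stoi = {ch: i for i, ch in enumerate(chars)}    # character -> integer
itos = {i: ch for ch, i in stoi.items()}        # integer -> character

encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: ''.join(itos[i] for i in ids)

data = encode(text)                 # one long 1-D sequence of integers to train on
print(encode("hi there"), decode(encode("hi there")))
```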
00:32:32.560 | Now, here we only have a single document. In some cases, if you have multiple independent
00:32:36.800 | documents, what people like to do is create special tokens, and they intersperse those
00:32:40.160 | documents with those special end-of-text tokens that they splice in between to create boundaries.
00:32:45.040 | But those boundaries actually don't have any modeling impact. It's just that the transformer
00:32:52.080 | is supposed to learn via backpropagation that the end-of-document sequence means that you
00:32:57.040 | should wipe the memory. Okay, so then we produce batches. So these batches of data just mean that
00:33:04.640 | we go back to the one-dimensional sequence, and we take out chunks of this sequence. So say if the
00:33:09.840 | block size is 8, then the block size indicates the maximum length of context that your transformer
00:33:17.600 | will process. So if our block size is 8, that means that we are going to have up to 8 characters
00:33:22.960 | of context to predict the 9th character in the sequence. And the batch size indicates how many
00:33:28.160 | sequences in parallel we're going to process. And we want this to be as large as possible,
00:33:31.760 | so we're fully taking advantage of the GPU and the parallelism it offers.
00:33:34.640 | So in this example, we're doing 4 by 8 batches. So every row here is an independent example, sort of.
00:33:41.440 | And then every row here is a small chunk of the sequence that we're going to train on.
00:33:48.640 | And then we have both the inputs and the targets at every single point here. So to fully spell out
00:33:53.680 | what's contained in a single 4 by 8 batch to the transformer, I sort of compact it here. So when
00:33:59.920 | the input is 47 by itself, the target is 58. And when the input is the sequence 47, 58, the target
00:34:07.760 | is 1. And when it's 47, 58, 1, the target is 51, and so on. So actually the single batch of examples
00:34:14.960 | that's 4 by 8 actually has a ton of individual examples that we are expecting the transformer
00:34:19.200 | to learn on in parallel. And so you'll see that the batches are learned on completely independently,
00:34:25.040 | but the time dimension sort of here along horizontally is also trained on in parallel.
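A sketch of how such a 4-by-8 batch can be pulled out of the one-dimensional sequence, loosely following what a data loader like nanoGPT's does (variable names are mine):

```python
import torch

block_size = 8   # maximum context length the transformer will see
batch_size = 4   # how many sequences we process in parallel

def get_batch(data):
    # data: 1-D torch.LongTensor of token integers, e.g. torch.tensor(encode(text))
    ix = torch.randint(len(data) - block_size, (batch_size,))          # random chunk starts
    x = torch.stack([data[i : i + block_size] for i in ix])            # inputs,  shape (4, 8)
    y = torch.stack([data[i + 1 : i + 1 + block_size] for i in ix])    # targets = inputs shifted by one
    return x, y
```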
00:34:30.880 | So sort of your real batch size is more like b times t. It's just that the context grows linearly
00:34:37.280 | for the predictions that you make along the t direction in the model. So this is all the
00:34:44.000 | examples that the model will learn from this single batch. So now this is the GPT class.
00:34:51.760 | And because this is a decoder-only model, so we're not going to have an encoder because
00:34:57.760 | there's no, like, English we're translating from. We're not trying to condition on some
00:35:01.280 | other external information. We're just trying to produce a sequence of words that follow each other
00:35:06.080 | or are likely to. So this is all PyTorch. And I'm going slightly faster because I'm assuming people
00:35:11.360 | have taken 231n or something along those lines. But here in the forward pass, we take these indices
00:35:18.080 | and then we both encode the identity of the indices just via an embedding lookup table.
00:35:26.720 | So every single integer has a - we index into a lookup table of vectors in this nn.embedding
00:35:33.680 | and pull out the word vector for that token. And then, because the transformer
00:35:40.800 | by itself natively processes sets, we need to also positionally encode
00:35:45.280 | these vectors so that we basically have both the information about the token identity and
00:35:49.680 | its place in the sequence from one to block size. Now those - the information about what and where
00:35:56.960 | is combined additively. So the token embeddings and the positional embeddings are just added
00:36:00.720 | exactly as here. So this x here, then there's optional dropout. This x here basically just
00:36:07.920 | contains the set of words and their positions, and that feeds into the blocks of transformer.
00:36:16.960 | And we're going to look into what's blocked here. But for here, for now, this is just a series of
00:36:20.560 | blocks in the transformer. And then in the end, there's a layer norm, and then you're decoding
00:36:25.760 | the logits for the next word or next integer in the sequence using a linear projection of
00:36:32.080 | the output of this transformer. So lm_head here, short for language model head, is just a linear
00:36:37.600 | function. So basically, positionally encode all the words, feed them into a sequence of blocks,
00:36:45.200 | and then apply a linear layer to get the probability distribution for the next
00:36:48.640 | character. And then if we have the targets, which we produced in the data loader, and you'll notice
00:36:54.960 | that the targets are just the inputs offset by one in time, then those targets feed into a cross
00:37:00.560 | entropy loss. So this is just a negative log likelihood, a typical classification loss.
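A condensed sketch of that forward pass, following the shape of nanoGPT's GPT class but simplified (dropout and weight initialization are omitted, and `Block` here refers to the transformer block sketched a bit further below), so treat it as illustrative rather than the actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_layer):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)   # what:  token identity
        self.pos_emb = nn.Embedding(block_size, n_embd)   # where: position 0..block_size-1
        self.blocks = nn.ModuleList([Block(n_embd) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)      # linear projection to logits

    def forward(self, idx, targets=None):
        b, t = idx.shape
        pos = torch.arange(t, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)         # 'what' and 'where' are added together
        for block in self.blocks:                         # communicate / compute, repeated
            x = block(x)
        logits = self.lm_head(self.ln_f(x))               # distribution over the next token
        loss = None
        if targets is not None:                           # targets are the inputs offset by one
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
```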
00:37:04.000 | So now let's drill into what's here in the blocks. So these blocks that are applied sequentially,
00:37:10.320 | there's again, as I mentioned, this communicate phase and the compute phase.
00:37:14.880 | So in the communicate phase, all the nodes get to talk to each other, and so these nodes are
00:37:19.760 | basically - if our block size is eight, then we are going to have eight nodes in this graph.
00:37:26.560 | There's eight nodes in this graph, the first node is pointed to only by itself,
00:37:30.240 | the second node is pointed to by the first node and itself, the third node is pointed to by the
00:37:35.040 | first two nodes and itself, etc. So there's eight nodes here. So you apply - there's a residual
00:37:41.200 | pathway in x, you take it out, you apply a layer norm, and then the self-attention so that these
00:37:45.840 | communicate, these eight nodes communicate, but you have to keep in mind that the batch is four.
00:37:50.240 | So because batch is four, this is also applied - so we have eight nodes communicating, but there's
00:37:56.080 | a batch of four of them all individually communicating among those eight nodes. There's
00:37:59.920 | no crisscross across the batch dimension, of course. There's no batch normalization anywhere,
00:38:02.960 | luckily. And then once they've exchanged information, they are processed using the
00:38:08.480 | multilayer perceptron, and that's the compute phase. And then also here, we are missing
00:38:14.160 | the cross-attention, because this is a decoder-only model. So all we have is this step
00:38:21.280 | here, the multi-headed attention, and that's this line, the communicate phase, and then we have the
00:38:25.040 | feedforward, which is the MLP, and that's the compute phase. I'll take questions a bit later.
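A sketch of one such block, with the communicate phase and the compute phase each wrapped around a residual connection in the pre-norm arrangement mentioned earlier. It also includes the MLP that is described next; `CausalSelfAttention` is sketched further below. Again this is simplified relative to nanoGPT (no dropout, illustrative defaults):

```python
import torch.nn as nn

class MLP(nn.Module):
    """Compute phase: processes each node (position) individually."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),   # the 4x expansion factor that has stuck around
            nn.GELU(),                        # non-linearity; think of it as a smooth ReLU
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """One transformer block: communicate, then compute, each on the residual pathway."""
    def __init__(self, n_embd, n_head=4):    # n_embd should be divisible by n_head
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head)   # communicate phase
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)                            # compute phase

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # nodes exchange information
        x = x + self.mlp(self.ln2(x))    # each node processes what it gathered
        return x
```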
00:38:30.640 | Then the MLP here is fairly straightforward. The MLP is just individual processing on each node,
00:38:37.360 | just transforming the feature representation sort of at that node. So applying a two-layer neural
00:38:45.040 | net with a GELU non-linearity, which is - just think of it as a RELU or something like that.
00:38:49.360 | It's just a non-linearity. And then MLP is straightforward. I don't think there's anything
00:38:54.480 | too crazy there. And then this is the causal self-attention part, the communication phase.
00:38:58.800 | So this is kind of like the meat of things and the most complicated part. It's only complicated
00:39:04.560 | because of the batching and the implementation detail of how you mask the connectivity in the
00:39:10.800 | graph so that you can't obtain any information from the future when you're predicting your token.
00:39:16.240 | Otherwise, it gives away the information. So if I'm the fifth token, and if I'm the fifth position,
00:39:23.120 | then I'm getting the fourth token coming into the input, and I'm attending to the third,
00:39:27.760 | second, and first, and I'm trying to figure out what is the next token, well then in this batch,
00:39:33.440 | in the next element over in the time dimension, the answer is at the input. So I can't get any
00:39:38.960 | information from there. So that's why this is all tricky. But basically in the forward pass,
00:39:42.800 | we are calculating the queries, keys, and values based on x. So these are the keys,
00:39:51.360 | queries, and values. Here, when I'm computing the attention, I have the queries matrix multiplying
00:39:57.520 | the keys. So this is the dot product in parallel for all the queries and all the keys, and all the
00:40:02.400 | heads. So I failed to mention that there's also the aspect of the heads, which is also done all
00:40:08.080 | in parallel here. So we have the batch dimension, the time dimension, and the head dimension,
00:40:11.920 | and you end up with five-dimensional tensors, and it's all really confusing. So I invite you
00:40:15.200 | to step through it later and convince yourself that this is actually doing the right thing.
00:40:19.040 | But basically, you have the batch dimension, the head dimension, and the time dimension,
00:40:23.360 | and then you have features at them. And so this is evaluating for all the batch elements,
00:40:28.320 | for all the head elements, and all the time elements, the simple Python that I gave you
00:40:32.800 | earlier, which is query dot-product key. Then here, we do a masked fill. And what this is doing is
00:40:39.280 | it's basically clamping the attention between the nodes that are not supposed to communicate
00:40:45.600 | to be negative infinity. And we're doing negative infinity because we're about to softmax,
00:40:50.080 | and so negative infinity will make basically the attention of those elements be zero.
00:40:53.680 | And so here, we are going to basically end up with the weights, the sort of affinities between
00:41:01.520 | these nodes, optional dropout, and then here, attention matrix multiply v is basically the
00:41:07.680 | gathering of the information according to the affinities we've calculated. And this is just a
00:41:12.560 | weighted sum of the values at all those nodes. So this matrix multiply is doing that weighted
00:41:17.840 | sum. And then transpose, contiguous, view, because it's all complicated and batched in five-dimensional
00:41:23.440 | tensors, but it's really not doing anything, optional dropout, and then a linear projection
00:41:28.560 | back to the residual pathway. So this is implementing the communication phase here.
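A sketch of that causal self-attention, batched roughly the way just described (simplified relative to nanoGPT: dropout is omitted and the causal mask is built on the fly rather than registered as a buffer):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # emits q, k, v for every node
        self.c_proj = nn.Linear(n_embd, n_embd)      # projection back to the residual pathway

    def forward(self, x):
        B, T, C = x.shape                             # batch, time, channels
        q, k, v = self.c_attn(x).split(C, dim=2)
        # reshape into (B, n_head, T, head_size): heads are just parallel copies
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # dot-product affinities between all queries and all keys, scaled by sqrt(head_size)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        # mask out edges from the future: those affinities become -inf, hence 0 after softmax
        mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
        att = att.masked_fill(~mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v                                   # weighted sum of values for each node
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble the heads side by side
        return self.c_proj(y)                         # back to the residual pathway
```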
00:41:32.960 | Then you can train this transformer, and then you can generate infinite Shakespeare,
00:41:41.120 | and you will simply do this by - because our block size is eight, we start with some token,
00:41:46.480 | say, in this case, you can use something like a newline character as the start token,
00:41:52.400 | and then you communicate only to yourself because there's a single node, and you get the probability
00:41:57.440 | distribution for the first word in the sequence, and then you decode it, or the first character in
00:42:03.680 | the sequence, you decode the character, and then you bring back the character, and you re-encode
00:42:07.600 | it as an integer, and now you have the second thing. And so you get, okay, we're at the first
00:42:13.360 | position, and this is whatever integer it is, add the positional encodings, goes into the sequence,
00:42:18.800 | goes into transformer, and again, this token now communicates with the first token and its identity.
00:42:26.640 | And so you just keep plugging it back, and once you run out of the block size, which is eight,
00:42:30.880 | you start to crop, because you can never have block size more than eight in the way you've
00:42:34.560 | trained this transformer. So we have more and more context until eight, and then if you want
00:42:38.240 | to generate beyond eight, you have to start cropping, because the transformer only works for
00:42:41.920 | eight elements in time dimension. And so all of these transformers in the naive setting have a
00:42:48.000 | finite block size, or context length, and in typical models, this will be 1024 tokens,
00:42:53.760 | or 2048 tokens, something like that, but these tokens are usually like BPE tokens,
00:42:58.480 | or SentencePiece tokens, or WordPiece tokens, there's many different encodings.
00:43:02.480 | So it's not like that long, and so, as I think was mentioned, we really want
00:43:05.680 | to expand the context size, and it gets gnarly, because the attention is quadratic in many cases.
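The generation loop just described, as a rough sketch that assumes the GPT class from the earlier snippet (nanoGPT's actual generate also has temperature and top-k options, which are omitted here):

```python
import torch

@torch.no_grad()
def generate(model, idx, max_new_tokens):
    # idx is a (B, T) tensor of token integers, e.g. a single start token
    for _ in range(max_new_tokens):
        # crop the context: the model was only trained for up to block_size tokens
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        probs = torch.softmax(logits[:, -1, :], dim=-1)     # distribution for the next token
        next_id = torch.multinomial(probs, num_samples=1)   # sample it
        idx = torch.cat([idx, next_id], dim=1)              # append and plug it back in
    return idx
```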
00:43:09.920 | Now, if you want to implement an encoder instead of a decoder attention, then all you have to do
00:43:19.680 | is take this masking line and just delete it. So if you don't mask the attention,
00:43:25.360 | then all the nodes communicate to each other, and everything is allowed, and information flows
00:43:29.840 | between all the nodes. So if you want to have the encoder here, all the encoder blocks
00:43:36.960 | will just use attention with this line deleted, and that's it. So you're allowing,
00:43:41.440 | whatever this encoder might store, say 10 tokens, like 10 nodes, and they are all allowed to
00:43:47.040 | communicate to each other, going up the transformer. And then if you want to implement cross attention,
00:43:53.360 | so you have a full encoder decoder transformer, not just a decoder only transformer, or GPT,
00:43:59.280 | then we need to also add cross attention in the middle. So here, there's a self attention piece,
00:44:05.520 | a cross attention piece, and this MLP. And in the
00:44:10.160 | cross attention, we need to take the features from the top of the encoder, we need to add one more
00:44:15.760 | line here, and this would be the cross attention. I should have implemented it instead
00:44:21.680 | of just pointing at it, I think. But there'll be a cross attention line here, so we'll have three lines,
00:44:26.480 | because we need to add another block. And the queries will come from x, but the keys and the
00:44:31.520 | values will come from the top of the encoder. And there will be basically information flowing from
00:44:37.040 | the encoder strictly to all the nodes inside x. And then that's it. So it's very simple sort of
00:44:43.760 | modifications on the decoder attention. So you'll hear people talk that you kind of have a decoder
00:44:49.920 | only model, like GPT, you can have an encoder only model, like BERT, or you can have an encoder
00:44:54.800 | decoder model, like say T5, doing things like machine translation. So, and in BERT, you can't
00:45:01.440 | train it using sort of this language modeling setup that's autoregressive, and you're just
00:45:06.080 | trying to predict the next element in the sequence, you're training it with slightly different
00:45:08.880 | objectives, you're putting in like the full sentence, and the full sentence is allowed to
00:45:13.040 | communicate fully, and then you're trying to classify sentiment or something like that.
00:45:17.360 | So you're not trying to model like the next token in the sequence. So these are trained
00:45:21.840 | slightly different with mask, with using masking and other denoising techniques.
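To make those two modifications concrete, here is a hedged sketch (illustrative, not part of nanoGPT, and reusing the `CausalSelfAttention` and `MLP` sketches from above): an encoder block would just use attention with the masking line deleted, and a decoder block in a full encoder-decoder model adds a cross-attention step whose keys and values come from the top of the encoder. A single-head cross-attention is shown for simplicity:

```python
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Queries from the decoder nodes x; keys and values from the encoder output."""
    def __init__(self, n_embd):
        super().__init__()
        self.q = nn.Linear(n_embd, n_embd)
        self.kv = nn.Linear(n_embd, 2 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x, enc):
        q = self.q(x)                                   # (B, Tdec, C)
        k, v = self.kv(enc).split(enc.size(-1), dim=2)  # (B, Tenc, C) each
        att = F.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
        return self.proj(att @ v)                       # information flows encoder -> decoder

class DecoderBlock(nn.Module):
    """Self-attention, then cross-attention, then MLP; an encoder block would just be
    unmasked self-attention plus the MLP."""
    def __init__(self, n_embd, n_head=4):
        super().__init__()
        self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(n_embd) for _ in range(3))
        self.self_attn = CausalSelfAttention(n_embd, n_head)
        self.cross_attn = CrossAttention(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x, enc):
        x = x + self.self_attn(self.ln1(x))        # decoder tokens talk among themselves
        x = x + self.cross_attn(self.ln2(x), enc)  # then consult the top of the encoder
        x = x + self.mlp(self.ln3(x))              # then compute per node
        return x
```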
00:45:30.000 | Okay, so that's kind of like the transformer. I'm going to continue. So yeah, maybe more questions.
00:45:38.320 | These are excellent questions. So when we're, for instance,
00:45:43.760 | applying this to something like a graph,
00:45:48.480 | this transformer still performs like it's a dynamic graph, where the connections
00:45:56.880 | change in every instance, and you also have some feature information. So we are
00:46:02.720 | enforcing these constraints on it by just masking, but is it aware of the structure that it tends to see?
00:46:09.840 | So I'm not sure if I fully followed. So there's different ways to look at this analogy, but one
00:46:16.880 | analogy is you can interpret this graph as really fixed. It's just that every time we do the
00:46:21.120 | communicate, we are using different weights. You can look at it that way. So if we have block size
00:46:25.440 | of eight in my example, we would have eight nodes. Here we have two, four, six, okay, so we'd have
00:46:30.080 | eight nodes. They would be connected in, you lay them out, and you only connect from left to right.
00:46:35.200 | But for a different problem, that might not be the case, but you have a graph where the connections
00:46:40.160 | might change. Why would the connection, usually the connections don't change as a function of
00:46:46.160 | the data or something like that. That means like the molecules look like an actual graph,
00:46:50.880 | and look like that. I don't think I've seen a single example where the connectivity changes
00:47:02.960 | dynamically in function of data. Usually the connectivity is fixed. If you have an encoder
00:47:06.560 | and you're training a BERT, you have how many tokens you want, and they are fully connected.
00:47:10.720 | And if you have a decoder only model, you have this triangular thing. And if you have encoder
00:47:15.600 | decoder, then you have awkwardly sort of like two pools of nodes. Yeah, go ahead.
00:47:24.800 | [inaudible]
00:47:48.880 | Yeah, it's really hard to say. So that's why I think this paper is so interesting is like,
00:48:16.880 | yeah, usually you'd see like a path, and maybe they had path internally. They just didn't publish
00:48:20.320 | it. All you can see is sort of things that didn't look like a transformer. I mean, you have ResNets,
00:48:24.800 | which have lots of this. But a ResNet would be kind of like this, but there's no self-attention
00:48:30.480 | component. But the MLP is there kind of in a ResNet. So a ResNet looks very much like this,
00:48:37.600 | except there's no - you can use layer norms in ResNets, I believe, as well. Typically,
00:48:41.600 | sometimes they can be batch norms. So it is kind of like a ResNet. It is kind of like they took a
00:48:46.000 | ResNet and they put in a self-attention block in addition to the pre-existing MLP block, which
00:48:52.160 | is kind of like convolutions. And the MLP would, strictly speaking, be a convolution, a one-by-one
00:48:56.240 | convolution. But I think the idea is similar in that MLP is just kind of like typical weights,
00:49:02.960 | non-linearity, weights kind of operation.
00:49:11.120 | But I will say, yeah, it's kind of interesting because a lot of the intermediate work just isn't there, and then
00:49:16.800 | they give you this transformer, and then it turns out five years later it hasn't changed, even though
00:49:19.840 | everyone's trying to change it. So it's kind of interesting to me that it arrived as kind of like a complete
00:49:23.280 | package, which I think is really interesting historically. And I also talked to the
00:49:28.160 | paper's authors, and they were unaware of the impact that the transformer would have at the time.
00:49:33.840 | So when you read this paper, actually, it's kind of unfortunate because this is like the paper that
00:49:39.280 | changed everything. But when people read it, it's like question marks, because it reads like a
00:49:43.200 | pretty random machine translation paper. Like, oh, we're doing machine translation. Oh, here's a cool
00:49:48.000 | architecture. OK, great, good results. It doesn't sort of know what's going to happen. And so when
00:49:55.440 | people read it today, I think they're kind of confused, potentially. I will have some tweets
00:50:01.360 | at the end, but I think I would have renamed it with the benefit of hindsight of like, well,
00:50:05.440 | I'll get to it. Yeah, I think that's a good question as well. Currently, I mean, I certainly
00:50:24.560 | don't love the autoregressive modeling approach. I think it's kind of weird to sample a token and
00:50:29.120 | then commit to it. So maybe there's some ways-- some hybrids with diffusion, as an example,
00:50:38.000 | which I think would be really cool. Or we'll find some other ways to edit the sequences later,
00:50:43.760 | but still in the autoregressive framework. But I think diffusion is kind of like an up-and-coming
00:50:48.960 | modeling approach that I personally find much more appealing. When I sample text, I don't go
00:50:53.360 | chunk, chunk, chunk, and commit. I do a draft one, and then I do a better draft two. And that feels
00:50:58.800 | like a diffusion process. So that would be my hope.
00:51:02.320 | OK, also a question. So yeah, to use the graph analogy, where the attention weights
00:51:10.560 | are like a graph: would you say the self-attention is sort of like computing
00:51:16.720 | an edge weight using the dot product on the node similarity, and then once we have the edge weight,
00:51:21.680 | we just multiply it by the values, and then we just propagate it?
00:51:25.200 | Yes, that's right.
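A minimal sketch of that reading of self-attention, with a single head and no masking; the weight matrices and shapes here are illustrative.

```python
import torch
import torch.nn.functional as F

def self_attention(x, Wq, Wk, Wv):
    # x: (T, C) - T nodes (tokens), each with a C-dimensional feature vector.
    q, k, v = x @ Wq, x @ Wk, x @ Wv              # queries, keys, values
    d = q.shape[-1]
    # "Edge weights": pairwise dot-product similarity between queries and keys,
    # scaled by sqrt(d) to keep the softmax well behaved.
    att = F.softmax((q @ k.T) / d**0.5, dim=-1)   # (T, T), each row sums to 1
    # Propagate: each node takes a weighted sum of the values it attends to.
    return att @ v                                # (T, head_size)

T, C, head_size = 8, 32, 16
x = torch.randn(T, C)
out = self_attention(x, *(torch.randn(C, head_size) for _ in range(3)))
```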
00:51:26.960 | And do you think there's like analogy between graph neural networks and self-attention?
00:51:32.320 | I find 'graph neural network' kind of like a confusing term, because
00:51:35.200 | I mean, yeah, previously there was this notion of them, but I kind of feel like maybe today everything
00:51:41.920 | is a graph neural network, because the transformer is like a general graph neural network processor. The native
00:51:46.080 | representation that the transformer operates over is sets that are connected by edges in a directed
00:51:50.880 | way. And so that's the native representation. And then, yeah. OK, I should go on, because I still
00:51:56.640 | have like 30 slides.
00:51:57.440 | Sorry, sorry, sorry. There's a question I want to say about this. [INAUDIBLE]
00:52:07.920 | Oh, yeah. Yeah, the root-D: basically, if you're initializing with random weights
00:52:14.320 | sampled from a Gaussian, then as your dimension size grows, so do your values - the variance grows -
00:52:19.600 | and then your softmax will just become a one-hot vector. So it's just a way to control
00:52:24.720 | the variance and keep it always in a good range for the softmax, with a nice diffuse distribution. OK, so it's almost like an initialization thing.
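A small numerical illustration of that effect; the dimension and number of keys are arbitrary choices.

```python
import torch

d = 512
q = torch.randn(d)       # one query
k = torch.randn(8, d)    # eight keys
scores = k @ q           # raw dot products: their variance grows with d

print(torch.softmax(scores, dim=-1))            # saturates: essentially a one-hot vector
print(torch.softmax(scores / d ** 0.5, dim=-1)) # scaled by sqrt(d): a nicely diffuse distribution
```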
00:52:29.200 | OK, so transformers have been applied to all the
00:52:41.760 | other fields. And the way this was done is, in my opinion, in kind of ridiculous ways, honestly,
00:52:47.680 | because I was a computer vision person, and you have ConvNets, and they kind of make sense.
00:52:51.680 | So what we're doing now with ViTs (Vision Transformers), as an example, is you take an image, and you chop it up into
00:52:55.520 | little squares. And then those squares literally feed into a transformer, and that's it,
00:52:59.440 | which is kind of ridiculous. And so, I mean, yeah. And so the transformer doesn't even,
00:53:07.040 | in the simplest case, like really know where these patches might come from. They are usually
00:53:10.800 | positionally encoded, but it has to sort of like rediscover a lot of the structure, I think,
00:53:16.960 | of them in some ways. And it's kind of weird to approach it that way. But it's just like
00:53:24.000 | the simplest baseline of chopping up big images into small squares and feeding them in
00:53:28.560 | as like the individual nodes actually works fairly well. And then this is in the transformer encoder.
00:53:32.640 | So all the patches are talking to each other throughout the entire transformer.
00:53:35.920 | And the number of nodes here would be sort of like nine.
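A minimal sketch of that patchify step. The 48x48 image and 16x16 patch size are arbitrary choices that give the nine "nodes" mentioned above; a real ViT would also add positional embeddings and a class token.

```python
import torch

def patchify(img, patch=16):
    # img: (C, H, W) -> one flattened vector per square patch ("node").
    C, H, W = img.shape
    p = img.unfold(1, patch, patch).unfold(2, patch, patch)   # (C, H/p, W/p, p, p)
    return p.permute(1, 2, 0, 3, 4).reshape(-1, C * patch * patch)

img = torch.randn(3, 48, 48)
tokens = patchify(img)      # 9 patches of 16x16x3
print(tokens.shape)         # torch.Size([9, 768]) - fed straight into a transformer encoder
```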
00:53:39.440 | Also, in speech recognition, you just take your mel spectrogram, and you chop it up into little
00:53:46.080 | slices and feed them into a transformer. So there were papers like this, but also Whisper. Whisper is
00:53:50.320 | also just a transformer. If you saw Whisper from OpenAI, you just chop up a mel spectrogram and
00:53:55.600 | feed it into a transformer, and then pretend you're dealing with text, and it works very well.
00:53:59.760 | Decision Transformer in RL: you take the states, actions, and rewards that you experience in an
00:54:04.880 | environment, and you just pretend it's a language, and you start to model those sequences.
00:54:09.520 | And then you can use that for planning later. That works pretty well. Even things like AlphaFold.
00:54:13.920 | So we were frequently talking about molecules and how you can plug them in. At the heart of
00:54:18.240 | AlphaFold, computationally, is also a transformer. One thing I wanted to also say about transformers
00:54:23.840 | is I find that they're super flexible, and I really enjoy that. I'll give you an example from
00:54:28.960 | Tesla. You have a ConvNet that takes an image and makes predictions about the image. And then the
00:54:34.880 | big question is, how do you feed in extra information? And it's not always trivial. Say
00:54:38.640 | I have additional information that I want the outputs to be informed
00:54:43.040 | by. Maybe I have other sensors, like radar. Maybe I have some map information, or a vehicle type,
00:54:47.520 | or some audio. And the question is, how do you feed that information into a ConvNet? Where do you feed
00:54:52.000 | it in? Do you concatenate it? Do you add it? At what stage? And so with a transformer, it's much
00:54:58.080 | easier, because you just take whatever you want, you chop it up into pieces, and you feed it in
00:55:01.760 | with a set of what you had before. And you let the self-attention figure out how everything should
00:55:05.200 | communicate. And that actually, frankly, works. So just chop up everything and throw it into the
00:55:10.000 | mix is kind of the way. And it frees neural nets from this burden of Euclidean space,
00:55:16.960 | where previously you had to arrange your computation to conform to the Euclidean
00:55:22.080 | space of three dimensions of how you're laying out the compute. The compute actually kind of
00:55:26.560 | happens in almost 3D space, if you think about it. But in attention, everything is just sets.
00:55:31.920 | So it's a very flexible framework, and you can just throw in stuff into your conditioning set,
00:55:35.680 | and everything just gets self-attended over. So it's quite beautiful from that perspective.
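A rough sketch of that "chop everything up and throw it into the mix" idea. All the encoders, sizes, and modalities here are hypothetical, not taken from any actual system.

```python
import torch
import torch.nn as nn

n_embd = 64
# Hypothetical per-modality encoders that project raw chunks to a shared width.
img_proj    = nn.Linear(768, n_embd)      # flattened 16x16x3 image patches
radar_proj  = nn.Linear(32, n_embd)       # chopped-up radar features (made-up size)
vehicle_emb = nn.Embedding(10, n_embd)    # one learned token per vehicle type

img_tokens    = img_proj(torch.randn(9, 768))    # 9 image-patch nodes
radar_tokens  = radar_proj(torch.randn(4, 32))   # 4 radar nodes
vehicle_token = vehicle_emb(torch.tensor([3]))   # 1 vehicle-type node

# Throw everything into one set and let self-attention figure out the communication.
tokens = torch.cat([img_tokens, radar_tokens, vehicle_token], dim=0)    # (14, n_embd)
attn = nn.MultiheadAttention(n_embd, num_heads=4, batch_first=True)
out, _ = attn(tokens[None], tokens[None], tokens[None])                 # every node sees every node
```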
00:55:39.760 | OK. So now, what exactly makes transformers so effective? I think a good example of this
00:55:44.560 | comes from the GPT-3 paper, which I encourage people to read: 'Language Models are Few-Shot
00:55:49.440 | Learners.' I would have probably renamed this a little bit. I would have said something like,
00:55:54.000 | transformers are capable of in-context learning, or like meta-learning. That's kind of what makes
00:56:03.040 | them really special. So basically, the setting that they're working with is: OK, I have some
00:56:07.120 | context - let's say a passage; this is just one example of many. I have a passage,
00:56:12.960 | and I'm asking questions about it. And then I'm giving, as part of the context, in the prompt,
00:56:12.960 | I'm giving the questions and the answers. So I'm giving one example of question-answer,
00:56:16.240 | another example of question-answer, another example of question-answer, and so on.
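An illustrative few-shot prompt of the kind being described; the passage and question-answer pairs are made up.

```python
prompt = """Passage: The Nile flows north through eleven countries and empties
into the Mediterranean Sea.

Q: In which direction does the Nile flow?
A: North.

Q: How many countries does it flow through?
A: Eleven.

Q: Where does it empty?
A:"""
# The model simply completes the document; the final answer is produced
# "in context", from the examples in the prompt, with no gradient updates.
```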
00:56:19.120 | And this becomes, oh yeah, people are going to have to leave soon now.
00:56:23.040 | OK. This is really important, let me think.
00:56:30.160 | OK, so what's really interesting is basically like, with more examples given in the context,
00:56:35.440 | the accuracy improves. And so what that hints at is that the transformer is able to somehow
00:56:40.000 | learn in the activations without doing any gradient descent in a typical fine-tuning fashion.
00:56:45.120 | So if you fine-tune, you have to give an example and the answer, and you do fine-tuning
00:56:50.160 | using gradient descent. But it looks like the transformer, internally in its weights,
00:56:53.760 | is doing something that looks like gradient descent, some kind of meta-learning
00:56:56.720 | encoded in the weights of the transformer, as it is reading the prompt. And so in this paper,
00:57:00.640 | they go into, OK, distinguishing this outer loop of stochastic gradient descent and this inner
00:57:04.880 | loop of in-context learning. So the inner loop is the transformer sort of reading the
00:57:09.040 | sequence, and the outer loop is the training by gradient descent. So basically,
00:57:14.560 | there's some training happening in the activations of the transformer as it is consuming a sequence
00:57:18.640 | that maybe very much looks like gradient descent. And so there's some recent papers that kind of
00:57:22.320 | hint at this and study it. And so as an example, in this paper here, they propose something called
00:57:27.440 | the raw operator. And they argue that the raw operator is implemented by a transformer,
00:57:32.880 | and then they show that you can implement things like ridge regression on top of a raw operator.
00:57:36.800 | And so this is kind of - their paper is hinting that maybe there is something that looks
00:57:41.680 | like gradient-based learning inside the activations of the transformer. And I think this is not
00:57:46.720 | impossible to think through, because what is gradient-based learning? Forward pass,
00:57:50.320 | backward pass, and then update. Well, that looks like a ResNet, right, because you're just
00:57:54.800 | changing - you're adding to the weights. So you start with an initial random set of weights, forward
00:58:00.160 | pass, backward pass, and update your weights, and then forward pass, backward pass, update weights.
00:58:03.840 | Looks like a ResNet. The transformer is a ResNet. So this is much more hand-wavy, but basically some
00:58:11.280 | papers trying to hint at why that could be potentially possible. And then I have a bunch
00:58:15.760 | of tweets. I just got them pasted here in the end. This was kind of meant for general consumption,
00:58:20.560 | so they're a bit more high-level and hype-y a little bit. But I'm talking about why this
00:58:24.960 | architecture is so interesting and why it potentially became so popular. And I think
00:58:28.720 | it simultaneously optimizes three properties that I think are very desirable. Number one,
00:58:32.320 | the transformer is very expressive in the forward pass. It's able to implement very
00:58:37.600 | interesting functions, potentially functions that can even do meta-learning. Number two,
00:58:42.480 | it is very optimizable, thanks to things like residual connections, layer norms, and so on.
00:58:46.240 | And number three, it's extremely efficient. This is not always appreciated, but the transformer,
00:58:50.080 | if you look at the computational graph, is a shallow wide network, which is perfect to take
00:58:54.720 | advantage of the parallelism of GPUs. So I think the transformer was designed very deliberately
00:58:58.560 | to run efficiently on GPUs. There's previous work like neural GPU that I really enjoy as well,
00:59:04.960 | which is really just like how do we design neural nets that are efficient on GPUs, and thinking
00:59:09.040 | backwards from the constraints of the hardware, which I think is a very interesting way to think
00:59:12.000 | about it. Oh yeah, so here I'm saying I probably would have called the transformer a general
00:59:24.000 | purpose, efficient, optimizable computer instead of 'Attention Is All You Need'. That's what I would
00:59:28.960 | have maybe in hindsight called that paper. It's proposing a model that is very general purpose,
00:59:36.960 | so forward pass is expressive. It's very efficient in terms of GPU usage, and it's
00:59:42.000 | easily optimizable by gradient descent, and trains very nicely. Then I have some other hype tweets
00:59:48.160 | here. Anyway, so you can read them later, but I think this one is maybe interesting.
00:59:54.960 | So if previous neural nets are special purpose computers designed for a specific task,
01:00:00.160 | GPT is a general purpose computer reconfigurable at runtime to run natural language programs.
01:00:05.920 | So the programs are given as prompts, and then GPT runs the program by completing the document.
01:00:11.200 | So I really like these analogies personally to computer. It's just like a powerful computer,
01:00:18.080 | and it's optimizable by gradient descent.
01:00:30.000 | I don't know. Okay, you can read this later, but for now I'll just leave this up.
01:00:36.720 | So sorry, I just found this tweet. So it turns out that if you scale up the training set
01:00:49.520 | and use a powerful enough neural net like a transformer, the network becomes a kind of
01:00:52.880 | general purpose computer over text. So I think that's a kind of like nice way to look at it,
01:00:56.640 | and instead of having it complete a single fixed text sequence, you can design the sequence in the prompt,
01:01:00.400 | and because the transformer is both powerful but also is trained on a large enough,
01:01:04.000 | very hard data set, it kind of becomes a general purpose text computer,
01:01:07.440 | and so I think that's kind of interesting way to look at it. Yeah?
01:01:12.320 | Um, about the three points up there. Yeah. Um, so I guess, like, we've seen
01:01:26.800 | a number of other architectures, and
01:01:56.720 | kind of, like, the idea that they can also be trained by gradient descent,
01:02:01.280 | and I guess my question is, how much of it do you think is, like,
01:02:05.040 | that the transformer is mostly just more efficient, or do you think it's
01:02:10.800 | something more fundamental, like you need the equivalent of a specific [inaudible] or do you [inaudible]
01:02:24.240 | Yeah. So I think there's a bit of that, yeah. So I would say RNNs, like, in principle, yes,
01:02:29.680 | they can implement arbitrary programs. I think it's kind of, like, a useless statement to some
01:02:33.280 | extent, because while they are probably expressive
01:02:38.080 | in the sense of, like, raw power, in that they can implement these arbitrary functions,
01:02:42.160 | they're not optimizable, and they're certainly not efficient, because they are serial
01:02:46.960 | computing devices. So if you look at it as a compute graph, RNNs are a very long,
01:02:53.680 | thin compute graph. Like, if you stretched out the neurons - if you take all the
01:03:01.200 | individual neurons and their connectivity, and stretch them out, and try to visualize them -
01:03:04.400 | RNNs would be, like, a very long graph, and that's bad, and it's bad also for optimizability, because
01:03:09.440 | I don't exactly know why, but just the rough intuition is when you're backpropagating,
01:03:13.680 | you don't want to make too many steps. And so transformers are a shallow, wide graph,
01:03:18.240 | and so from supervision to inputs is a very small number of hops, and it's along residual pathways,
01:03:25.760 | which make gradients flow very easily, and there's all these layer norms to control the
01:03:28.960 | scales of all of those activations. And so there's not too many hops, and you're going
01:03:35.760 | from supervision to input very quickly, and this flows through the graph. And it can all be done
01:03:41.920 | in parallel, so you don't need to do what encoder-decoder RNNs do, where you have to go from the first word,
01:03:46.080 | then the second word, then the third word; here in the transformer, every single word is processed
01:03:50.400 | sort of completely in parallel. So I think all three of these are really important,
01:03:56.080 | and I think number three is the least talked about
01:04:00.080 | but extremely important, because in deep learning, scale matters, and so the size of the network that
01:04:04.800 | you can train is extremely important, and so if it's efficient on the current hardware,
01:04:10.320 | then you can make it bigger.
01:04:16.080 | No, so yeah, so you take your image, and you chop it up into patches,
01:04:33.040 | so those are the first thousand tokens or whatever, and now - so radar could come in
01:04:38.880 | also, but I don't actually know the native representation of radar - but you
01:04:44.720 | just need to chop it up and enter it, and then you have to encode it somehow. Like, the transformer
01:04:48.160 | needs to know that these are coming from radar, so you have some kind of a
01:04:52.800 | special token - these radar tokens are slightly different in representation, and that's
01:04:58.720 | learnable by gradient descent. And, like, vehicle information would also come in with a special
01:05:04.320 | embedding token that can be learned. So how do you order those?
01:05:11.280 | You don't, it's all just a set.
01:05:13.440 | Yeah, it's all just a set, but you can positionally encode these sets if you want,
01:05:23.120 | so - but positional encoding means you can hardwire, for example, the coordinates, like using
01:05:28.320 | sinusoids and cosines, you can hardwire that, but it's better if you don't hardwire the position,
01:05:33.120 | you just - it's just a vector that is always hanging out at this location,
01:05:36.240 | whatever content is there just adds onto it, and this vector is trainable by backprop,
01:05:39.520 | that's how you do it. Yeah, go for it.
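A minimal sketch of that kind of learnable positional (or modality) embedding, as opposed to hardwired sinusoids; the names and sizes are illustrative.

```python
import torch
import torch.nn as nn

class LearnedPositions(nn.Module):
    # One trainable vector "hanging out" at each position; whatever content
    # lands at that position just gets the vector added on top.
    def __init__(self, max_len, n_embd):
        super().__init__()
        self.pos_emb = nn.Parameter(0.02 * torch.randn(max_len, n_embd))  # trained by backprop

    def forward(self, tok_emb):
        # tok_emb: (batch, T, n_embd) content embeddings (text, patches, radar, ...)
        T = tok_emb.shape[1]
        return tok_emb + self.pos_emb[:T]   # nothing is hardwired
```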
01:05:43.760 | I'm not sure if I understand the question. So I mean, the positional encodings,
01:06:12.800 | they actually have - okay, so they have very little inductive
01:06:16.560 | bias or something like that, they're just vectors hanging out at each location, always,
01:06:19.520 | and you're trying to help the network in some way, and I think the intuition is good, but
01:06:27.360 | like, if you have enough data, usually trying to mess with it is a bad thing.
01:06:31.920 | Like, trying to inject knowledge when you have enough knowledge in the data set itself
01:06:36.800 | is not usually productive, so it really depends on what scale you are at. If you have infinite data,
01:06:41.280 | then you actually want to encode less and less - that turns out to work better -
01:06:44.160 | and if you have very little data, then actually you do want to encode some biases,
01:06:47.520 | and maybe if you have a much smaller data set, then maybe convolutions are a good idea,
01:06:50.640 | because you actually get this bias from the local filters. And so - but I think - so the
01:06:57.920 | transformer is extremely general, but there are ways to mess with the encodings to put in more
01:07:01.600 | structure, like you could, for example, encode sines and cosines and fix it, or you could
01:07:05.840 | actually go to the attention mechanism and say, okay, if my image is chopped up into patches,
01:07:11.040 | this patch can only communicate to this neighborhood, and you can - you just do that
01:07:14.080 | in the attention matrix, just mask out whatever you don't want to communicate. And so people
01:07:18.720 | really play with this, because the full attention is inefficient, so they will intersperse,
01:07:23.840 | for example, layers that only communicate in little patches, and then layers that communicate
01:07:27.920 | globally, and they will sort of do all kinds of tricks like that. So you can slowly bring in more
01:07:32.960 | inductive bias if you want to - but the inductive biases are sort of factored out from
01:07:37.920 | the core transformer, and they are factored out into the connectivity of the nodes, and they are
01:07:42.960 | factored out into the positional encodings, and you can mess with those for computational efficiency.
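A minimal sketch of that kind of attention masking, here a simple local window; the window size and shapes are arbitrary.

```python
import torch

def local_mask(T, window=2):
    # Allow node i to attend only to nodes j with |i - j| <= window.
    idx = torch.arange(T)
    return (idx[None, :] - idx[:, None]).abs() <= window     # (T, T) boolean

scores = torch.randn(8, 8)                                   # raw attention logits
scores = scores.masked_fill(~local_mask(8), float("-inf"))   # cut the unwanted edges
att = torch.softmax(scores, dim=-1)                          # each node communicates only locally
```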
01:08:02.480 | So there's probably about 200 papers on this now, if not more. They're kind of hard to keep track of,
01:08:07.920 | honestly - like my Safari browser, which is - oh, it's on my computer - has like 200 open tabs. But
01:08:13.840 | yes, I'm not even sure if I want to pick my favorite, honestly.
01:08:22.000 | Yeah, I think there was a very interesting talk this year where you can think of a
01:08:28.160 | transformer as like a CPU. I think the first test was to take five instructions out of like
01:08:32.400 | 4,000 programs, and then, like at the beginning of a CPU program, you store variables,
01:08:37.440 | you have memory, so it's like, if I want to run a longer program on the CPU, I just run it
01:08:41.120 | multiple times. So maybe you can use a transformer like that. The other one that I actually like
01:08:46.640 | even more is potentially keep the context length fixed, but allow the network to somehow use a
01:08:51.600 | scratchpad. And so the way this works is you will teach the transformer somehow, via examples in
01:08:57.040 | the prompt, that hey, you actually have a scratchpad. Hey, basically, you can't remember too
01:09:01.440 | much. Your context length is finite. But you can use a scratchpad, and you do that by emitting a start
01:09:06.080 | scratchpad, and then writing whatever you want to remember, and then end scratchpad. And then
01:09:10.480 | you continue with whatever you want. And then later, when it's decoding, you actually have
01:09:15.200 | special logic such that when you detect the start-scratchpad token, you will sort of like save whatever
01:09:19.520 | it puts in there in some external thing, and allow it to attend over it. So basically, you can teach
01:09:23.920 | the transformer dynamically, because it's so good at meta-learning. You can teach it dynamically to use
01:09:28.800 | other gizmos and gadgets, and allow it to expand its memory that way, if that makes sense. It's
01:09:32.800 | just like a human learning to use a notepad, right? You don't have to keep it all in your brain. So keeping
01:09:37.440 | things in your brain is kind of like the context length of the transformer. But maybe we can just
01:09:40.720 | give it a notebook. And then it can query the notebook, and read from it, and write to it.
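A rough sketch of the decode-time logic being described; the marker tokens, the `generate_next_token` method, and the `block_size` attribute are all hypothetical, not a real API.

```python
def decode_with_scratchpad(model, prompt_tokens, max_steps=256):
    context, scratchpad, writing = list(prompt_tokens), [], False
    for _ in range(max_steps):
        # The model is allowed to attend over its saved notes as well as the context.
        tok = model.generate_next_token(scratchpad + context)
        if tok == "<start_scratchpad>":
            writing = True
        elif tok == "<end_scratchpad>":
            writing = False
        elif writing:
            scratchpad.append(tok)                   # saved externally, like a notepad
        else:
            context.append(tok)
            if len(context) > model.block_size:      # the finite context window
                context = context[-model.block_size:]
    return context, scratchpad
```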
01:09:45.600 | [INAUDIBLE]
01:10:08.960 | I don't know if I detected that. I kind of feel like-- did you feel like it was more than just
01:10:12.800 | a long prompt that's unfolding? [INAUDIBLE]
01:10:19.680 | I didn't try extensively, but I did see a forgetting event. And I kind of felt like
01:10:24.160 | the block size was just moved. Maybe I'm wrong. I don't actually know about the internals of
01:10:29.840 | [INAUDIBLE]
01:10:51.600 | I mean, so right now, I'm working on things like nanoGPT. Where's nanoGPT?
01:10:59.040 | I mean, I'm basically moving slightly from computer vision and kind of computer-vision-based
01:11:03.360 | products to a little bit in the language domain. Where's ChatGPT? OK, nanoGPT. So originally,
01:11:08.160 | I had minGPT, which I rewrote into nanoGPT, and I'm working on this. I'm trying to reproduce
01:11:12.320 | GPTs. And I mean, I think something like ChatGPT, incrementally improved in a product
01:11:17.840 | fashion, would be extremely interesting. And I think a lot of people feel that, and that's why
01:11:23.440 | it spread so widely. So I think there's something like a Google-plus-plus-plus to build there that
01:11:29.040 | I think is really interesting. So we ran right up against the clock.
01:11:35.380 | Thanks.