Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math
Chapters
0:00 Introduction
1:46 Sequence modeling
7:12 Differential equations (basics)
11:38 State Space Models
13:53 Discretization
23:08 Recurrent computation
26:32 Convolutional computation
34:18 Skip connection term
35:21 Multidimensional SSM
37:44 The HIPPO theory
43:30 The motivation behind Mamba
46:56 Selective Scan algorithm
51:34 The Scan operation
54:24 Parallel Scan
57:20 Innovations in Selective Scan
58:00 GPU Memory Hierarchy
61:23 Kernel Fusion
61:48 Activations recomputation
66:48 Mamba architecture
70:18 Performance considerations
72:54 Conclusion
00:00:00.000 |
Hello guys, welcome back to my channel. Today we are going to talk about Mamba 00:00:03.840 |
So Mamba is a new model for sequence modeling that came out just one month ago in a paper called "Mamba: Linear-Time Sequence Modeling 00:00:11.040 |
with Selective State Spaces". Let's review the topics of today. In the first part of the video 00:00:16.140 |
I will be introducing what are sequence models and what kind of sequence modeling we can do 00:00:20.560 |
The second part of the video I will be talking about state space models 00:00:25.820 |
But to fully understand the state space models, we need to have a little background on differential equation 00:00:31.340 |
I of course don't expect you to have this background, because it is taught in some 00:00:35.200 |
bachelor's or master's degrees, but not in most of the cases 00:00:42.360 |
So I will give you the necessary background to understand differential equations 00:00:46.040 |
and later we will talk about state space models and we will derive the formula for the 00:00:51.040 |
discretization and we will also derive the formula for the convolutional computation and the recurrent computation 00:00:55.980 |
I will show you what we mean by the hippo matrix and the importance of the A matrix in state space models 00:01:02.280 |
In the third part of the video we will be talking about Mamba 00:01:06.220 |
So what was the motivation that led to Mamba, and what is the innovation of Mamba, which is the selective scan algorithm? 00:01:12.860 |
So first of all, what do we mean by the scan operation, and we will see that this scan operation can be parallelized 00:01:20.080 |
with the parallel scan. We will see what kernel fusion is and the recomputation of the activations, and finally 00:01:26.080 |
we will explore the architecture of Mamba and some performance considerations with respect to the transformer 00:01:31.860 |
So as a prerequisite, I just hope that you have the basics of calculus 00:01:36.700 |
I think a high school mathematics will be more than enough and 00:01:39.640 |
You have a basic understanding of the transformer model and neural networks in general. So let's start our journey 00:01:47.460 |
The goal of a sequence model is to map an input sequence to an output sequence 00:01:51.940 |
The input sequence can be a continuous signal in that case 00:01:55.100 |
We want to map it to an output continuous signal, or it can be a discrete input signal and we want to map it to a discrete output signal 00:02:04.040 |
Continuous signal could be audio and a discrete signal could be text for example 00:02:08.600 |
actually most of the time we work with the discrete signals even with the case of audio because we sample the audio file over time and 00:02:16.700 |
if we talk about language modeling, we are talking about a discrete input because we have a finite number of tokens and 00:02:22.620 |
We want to map it to an output sequence of tokens 00:02:26.740 |
We can choose among many models to do sequence modeling let's review them 00:02:32.900 |
The first model that comes to mind to do sequence modeling is the recurrent neural network 00:02:39.140 |
Which is a network in which we have a hidden state and we compute the output as follows 00:02:45.220 |
So we have for example our input sequence made up of x1 x2 and x3 00:02:51.140 |
What we do is, the first time, the hidden state is initialized with zeros 00:02:57.720 |
So we feed to the network the hidden state the first hidden state 00:03:01.760 |
So zeros along with the first input and it will produce the first output 00:03:05.940 |
Then we use the previously produced the hidden state. So the output of the previous step 00:03:14.540 |
It will produce a new hidden state and the new output token for a new 00:03:19.520 |
Input token and this will be the y2 so the output number two 00:03:23.720 |
So we use the previously generated hidden state 00:03:26.500 |
along with a new input token to produce a new output token and a new hidden state that will be used for the next step. 00:03:35.100 |
This sequential generation of output is not parallelizable, because to generate the nth token we first need the hidden state produced by the previous step. 00:03:44.020 |
So the training of this kind of model cannot be parallelized. And this is one of the reasons the transformer has been so successful 00:03:50.100 |
However, the inference time is constant for each token: 00:03:56.680 |
the effort, from a computational point of view and also from a memory point of view, that we need to put in to produce one 00:04:02.180 |
token of the output is the same, it doesn't matter if it's the first token or the 00:04:06.980 |
100th token. The effort that we are making, so the number of 00:04:10.740 |
operations and computations that we are doing, is always the same: we are taking the previous state and the current input to produce one output token 00:04:18.420 |
Theoretically, this one has infinite context length because we can continue with this sequence sequence of forever 00:04:25.200 |
But practically we cannot because it suffers from a vanishing and exploding gradient 00:04:30.000 |
Another model that we can use to do sequence modeling, but which is not used very much recently, 00:04:36.620 |
is the convolutional neural network, which is mostly used for computer vision tasks 00:04:42.340 |
It has a finite context window and it needs to build a kernel that is run 00:04:48.040 |
So this is the kernel that is run through the input to produce output features. So this is the output 00:04:53.380 |
It is easily parallelizable because each output uses the same kernel 00:04:58.780 |
So we can run the kernel in parallel on all the possible input windows to produce the output features 00:05:05.860 |
The last model that we will see is the transformer 00:05:08.480 |
The transformer is easily parallelizable when it comes to training because we have this self-attention mechanism 00:05:15.060 |
that computes many dot products and since it's a matrix multiplication we can parallelize the operation and 00:05:21.140 |
it has a finite context window defined by the sequence the input sequence and the attention mask and 00:05:30.220 |
The inference of this kind of model, however is not constant for each token. So in the transformer model if we are 00:05:36.540 |
producing the first token of the output, we will do one dot product 00:05:40.740 |
by using the KV cache. If we are producing the 10th output token, 00:05:47.620 |
We will need to do 10 dot products to produce it 00:05:50.540 |
And if you are producing the 100th output token, we will need to do 100 dot products 00:05:56.740 |
So the effort that we put to produce the first token is not the same as the effort that we need to produce the 10th token 00:06:05.780 |
This doesn't allow us to scale to very long inputs also for the training because the training as you can see that scales 00:06:14.720 |
Quadratically with the input sequence length, which means that if we double the sequence length 00:06:19.860 |
We need to do four times more computation to train this model 00:06:26.340 |
So in an ideal world we would like a model for which we can 00:06:31.140 |
Parallelize the training just like the transformer because we can exploit our GPU very well and it can scale linearly to long sequences 00:06:39.300 |
Just like the RNN because it scales linearly with the input sequence length 00:06:44.460 |
And in an ideal world, we also would like to inference each token with a constant memory and computation cost, 00:06:51.940 |
which means that the effort we need to put to produce the first token should be the same as the effort that we put to produce any later token 00:07:00.580 |
Now let's explore state space models and how they can help us solve the problems of the recurrent neural networks and the transformer 00:07:11.340 |
Let me give you a very simple introduction to differential equations with a very simple example 00:07:18.580 |
Imagine you have some bunnies and the population of these bunnies grows at a constant rate of lambda 00:07:23.900 |
Proportional to the number of bunnies that you have which means that every bunny will give birth to lambda baby bunnies 00:07:30.220 |
Now also suppose that the lambda is equal to 2. So let's write lambda equal to 2 00:07:35.300 |
So we can see that the rate of change of the population 00:07:39.500 |
Which means the number of babies that are born is equal to lambda 00:07:44.140 |
Multiplied by the number of bunnies that you have at a particular time step t 00:07:48.340 |
So which means also that the rate of change of this population because the number of babies that these bunnies produce is also the rate 00:07:56.540 |
So how much the population is growing is equal to lambda multiplied by the number of bunnies that you have 00:08:00.860 |
But if you remember from high school, what is the rate of change of a particular function? 00:08:05.900 |
It is the derivative of this function. So we can say that the 00:08:12.460 |
derivative of the function that describes the number of bunnies over time is 00:08:17.060 |
equal to lambda multiplied by the number of bunnies that you have at a particular time step t 00:08:28.140 |
Now suppose we want to know how many bunnies we will have at, say, time step 100, knowing that the population is made up of five bunnies at time equal to 0 00:08:37.460 |
So we need to find a function B of t that describes the population of our bunnies over time 00:08:46.220 |
Solving a differential equation means to find a function, in this case B of t, that makes the expression true 00:08:58.700 |
for all t. So we need to find a function B of t that, when replaced in this expression, makes the left-hand side equal to the right-hand side 00:09:07.980 |
This differential equation can be solved using a very simple method called the method of separation of variables 00:09:14.160 |
So I will not show you here, but we can clearly see that a solution to this 00:09:18.300 |
Differential equation is this function here B of t is equal to K multiplied by 00:09:23.580 |
Exponential of lambda by t where K is equal to the initial population of bunnies 00:09:30.540 |
How can we verify that this is actually the solution of this differential equation because if we replace this 00:09:36.260 |
Function here inside this expression it will make the left hand side equal to the right hand side 00:09:41.380 |
This is how we verify the solution of a differential equation 00:09:43.900 |
So let's try to replace it first of all in the left hand side 00:09:47.580 |
We have the derivative of this function with respect to t, so let's calculate it: it is K 00:09:56.060 |
multiplied by lambda multiplied by e to the power of lambda t. And the right-hand side is 00:10:04.380 |
lambda multiplied by the function itself, which is K multiplied by e to the power of lambda t 00:10:11.860 |
You can see that these two expressions are the same so this function here 00:10:17.820 |
So this function here is a solution to our differential equation 00:10:22.820 |
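(As a quick check, here is a minimal Python sketch, not from the video, that verifies this solution symbolically with sympy:)

```python
import sympy as sp

# Symbols for time, the growth rate and the initial population
t, lam, K = sp.symbols("t lambda K", positive=True)

B = K * sp.exp(lam * t)             # candidate solution B(t) = K * e^(lambda * t)

# Left-hand side minus right-hand side of  B'(t) = lambda * B(t)
residual = sp.simplify(sp.diff(B, t) - lam * B)
print(residual)                     # 0, so the equation holds for every t
```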
So when you have a differential equation the solution of a differential equation is not a number 00:10:27.620 |
When we solve a quadratic equation, so when you solve for example the 00:10:31.900 |
Equation X to the power of 2 is equal to 4 you get some values of X that make the left-hand side 00:10:38.660 |
Equal to the right-hand side in the case of a differential equation you find a function 00:10:43.900 |
That makes the left-hand side equal to the right-hand side and usually we represent the differential equation by omitting the variable t 00:10:52.180 |
And writing it as follows so B t so B dot is equal to lambda 00:10:57.940 |
Multiplied by B. This dot indicates the derivative of this function with respect to a certain variable, which is implied 00:11:08.180 |
Basically the result of this differential equation 00:11:11.420 |
will describe the evolution of our bunnies, so this function here that we have found, B of t, 00:11:17.980 |
Will describe how the population of our bunnies will grow 00:11:23.260 |
We usually use the differential equations to model the state of a system over time 00:11:28.100 |
With the goal of finding a function that gives us the state of the system at any time step 00:11:34.340 |
Given the initial state of the system, and this is what we do with state space models 00:11:39.380 |
A state space model allow us to map an input signal X of t to an output signal Y of t 00:11:45.980 |
By means of a state representation H of t as follows 00:11:50.740 |
So H prime of t which means the derivative of this function H of t with respect to time is equal to a 00:11:58.140 |
number A or a matrix A multiplied by H of t plus B 00:12:02.820 |
Multiplied by X of t and then the output of the system is computed as follows 00:12:08.260 |
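(For reference, the two relationships just described can be written as follows; this is the standard continuous-time state space formulation, with the D term that will be discussed later included:)

```latex
\begin{aligned}
h'(t) &= A\,h(t) + B\,x(t) \\
y(t)  &= C\,h(t) + D\,x(t)
\end{aligned}
```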
This state space model is linear and time invariant 00:12:12.860 |
It's linear because the relationships in the expressions above are linear and the time invariant because the parameter matrices 00:12:20.020 |
A, B, C and D do not vary over time because you can see they do not depend on the time 00:12:28.100 |
So for now for simplicity, we will consider that all these parameters 00:12:33.980 |
So A, B, C and D, but also the input the output and the H of t are numbers not vectors 00:12:39.500 |
Later, we will expand our analysis to vectors 00:12:44.500 |
Okay, we have these expressions, but how can I compute the output of this model given the input X of t? 00:12:51.940 |
As you can see in the first expression, we have a differential equation and to compute the output 00:12:57.780 |
like this, we need to have a function H of t that describes the state of our system at every time step t. 00:13:09.740 |
So to compute the output of this model, first of all we need to solve this differential equation here, 00:13:14.300 |
So it means to find a function H of t that describes the state of the system for all time step 00:13:20.260 |
but solving this differential equation can be hard analytically and 00:13:24.860 |
Usually also we do not work with the continuous signals 00:13:28.740 |
So our input X of t usually is not continuous, because we work with computers or other digital devices 00:13:37.740 |
So one way is to actually discretize this system such that we can calculate an approximate 00:13:43.020 |
Solution of this H of t in a discretized way. So not in a continuous way 00:13:48.220 |
Let's see how we can discretize our system to compute then the output of the system itself 00:13:53.460 |
To solve a differential equation we need to find the function H of t so in our case this function here 00:14:03.940 |
so a function H of t that, when replaced here, will make the left-hand side equal to the right-hand side. 00:14:10.740 |
But finding this function H of t in the continuous domain 00:14:16.460 |
It can be very hard analytically and also because we work with the discrete input 00:14:21.460 |
We can discretize this system, which means that we can approximate the solution of this differential equation 00:14:27.420 |
So instead of finding H of t for all the time step we can find H of t for certain time steps 00:14:33.880 |
Instead so for example, we may want to find H of t for H of 0, H of 1, H of 2, H of 3, etc 00:14:41.140 |
In this case, we have chosen a step size the step size in my case 00:14:46.040 |
I have chosen 1 because I'm evaluating the function H of t only in certain time steps 00:14:51.300 |
that are separated by a step size of 1, so I go from 0 to 1, from 1 to 2, from 2 to 3, and so on 00:15:01.620 |
Remember the bunnies problem. Let's try to find the approximate solution of this bunny problem using the Euler's method 00:15:07.900 |
So before, I gave you the result of this bunnies problem using the analytical solution, which was very easy to compute; 00:15:18.100 |
now let's find not the same solution, but the approximated solution, using the Euler's method 00:15:24.100 |
So first of all, let's rewrite the our bunny population model 00:15:28.180 |
Which is the derivative of the function that describes the bunny population 00:15:33.220 |
With respect to the variable time is equal to lambda multiplied by the function itself 00:15:42.340 |
What is the definition of the derivative the derivative is equal to the limit of a step? 00:15:47.860 |
Going to 0 of the function evaluated at t plus the step 00:15:53.540 |
Minus the function at time step t divided by the step size. This is the definition of derivative. So there is nothing new here 00:16:03.060 |
The left-hand side is equal to the right-hand side when the delta, so the step size that we are using, 00:16:10.260 |
is very very small. So suppose that we choose a very very small step size. 00:16:15.980 |
Then we can get rid of this equality and change it to an approximation, so we can say that, 00:16:22.980 |
Okay, what if I choose a very small step size, then the left-hand side will be more or less equal to the right-hand side 00:16:30.700 |
Now, by multiplying with the delta on both sides and then bringing this term to the right-hand side, we obtain an expression 00:16:41.160 |
which allows us to compute the value of the function B at the time step t plus delta, so one step forward, 00:16:48.600 |
as the derivative of the function at time step t multiplied by delta, plus the function at time step t 00:16:56.540 |
But what is the derivative of the function here? 00:17:02.700 |
We know that the derivative of the function that we are trying to approximate is equal to the Lambda multiplied by the function itself 00:17:15.100 |
So in this expression here we can replace our population model and we obtain the following 00:17:20.060 |
Expression and this allow us to compute the bunnies population at time step t plus Delta 00:17:26.260 |
So at the next time step given the bunnies population at the previous time step 00:17:31.540 |
Multiplied by a Delta and the Lambda. So we obtained a recurrent formulation 00:17:37.540 |
Wonderful now that we have our recurrent formulation to approximate the state of the bunny population over time 00:17:45.640 |
Suppose we set Lambda equal to 2 which means that every bunny will make two children and Delta is equal to 1 00:17:55.980 |
So t is equal to 0 then t equal to 1, t equal to 2, etc, etc 00:17:59.380 |
For example, if we started with a population of 5 bunnies at time step 0 00:18:04.380 |
We can calculate the evolution of the population as follows 00:18:08.040 |
Knowing the population at time step 0, we can calculate the population at time step 1 using the formula that we have: 00:18:17.660 |
Delta multiplied by Lambda multiplied by the population at the previous time step. So just like this formula here 00:18:24.660 |
Plus the population at the previous time step 00:18:26.540 |
So this gives us a 15 and it makes sense why because we started with 5 bunnies. Each bunny has 2 children 00:18:32.900 |
So they have 10 babies and the initial population, which is 5 plus 10 babies gives us a population of 15 00:18:40.140 |
Now that we have the population at time step 1 we can calculate the population at the time step 2 which is 00:18:46.520 |
45 because we had 15 bunnies each one makes 2 children. So we have 30 children 00:18:56.500 |
Now that we have the population at time step 2 we can calculate the population at time step 3 and it's 00:19:01.860 |
135, because we have 45 bunnies, each one of them makes 2 children, so we have 90 children, plus 45, which is 135 00:19:12.500 |
If you compare the solution that we are obtaining for this time step with the analytical solution that we found before 00:19:19.820 |
That I didn't show you how to find it, but it can be easily found using a method called the separation of variables 00:19:25.880 |
If you compare the plot of these two functions 00:19:29.920 |
You will see that the analytical function grows very fast while the approximated solution grows very slow 00:19:35.340 |
and this is actually because 00:19:43.320 |
the step size that we have chosen is very big: 00:19:46.520 |
I mean we chose delta equal to 1. To make a better approximation 00:19:51.520 |
we need to use a very very small delta, and actually the Euler method does not give us very good 00:19:57.440 |
results, so not very good approximations of the analytical solution 00:20:01.980 |
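(Here is a minimal Python sketch, my own illustration rather than code from the video, that reproduces the Euler steps above and compares them with the analytical solution B(t) = 5·e^(2t):)

```python
import math

lam, delta = 2.0, 1.0              # growth rate and step size used in the example
b = 5.0                            # initial population B(0) = 5

for t in range(4):
    analytical = 5.0 * math.exp(lam * t)        # B(t) = K * e^(lambda * t), K = 5
    print(f"t={t}: Euler = {b:6.0f}, analytical = {analytical:10.1f}")
    b = b + delta * lam * b        # Euler step: B(t + delta) ≈ B(t) + delta * lambda * B(t)
```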
and we use a similar reasoning that we used for the bunny population to 00:20:08.380 |
Discretize the system of the state space model. So how to discretize that system? 00:20:18.320 |
Using the definition of the derivative, we found out that a function evaluated at the time step t plus delta is 00:20:29.240 |
Delta multiplied by the derivative of the function at time step t plus the function at time step t 00:20:36.560 |
But what is the derivative of the function at time step t? 00:20:39.900 |
We know from our state space model that the derivative of this function h prime is 00:20:44.760 |
Equal to a multiplied by h of t plus b multiplied by x of t. So we just replace this 00:20:54.080 |
Into this term here and we obtain the following derivation 00:20:58.920 |
So we replace this stuff here with h prime of t 00:21:02.680 |
We multiply it by delta, we collect this term h of t, and we obtain the following, 00:21:09.560 |
which gives us the parameters of the discretized model. So if we set A bar equal to 00:21:17.160 |
the identity matrix plus delta A, and B bar equal to delta multiplied by B, 00:21:32.480 |
we obtain a recurrent formulation just like the one we found for the bunny population problem, that allows us to compute the next state of the model given the previous state 00:21:45.280 |
So in the paper they show first of all the continuous 00:21:48.880 |
Formulation of the state space model, which is this one and then they show you the discrete 00:21:53.880 |
Discretized model that you can see here and they also show you how to obtain this discretized parameter 00:22:01.360 |
So the parameters of the discretized model, which is this a bar and b bar actually in the paper 00:22:07.200 |
They do not use the Euler method. So the one that I used here to derive the discretized version 00:22:12.900 |
They use a method called the zero-order hold and I believe they use this one because it results in a better approximation 00:22:18.960 |
But the idea of the discretization is that instead of calculating the analytical solution of this differential equation here 00:22:27.360 |
we calculate approximately what is the state H of the system at discrete time steps, 00:22:38.480 |
and then we plug this state value into this second relationship to get the output of the system. 00:22:44.920 |
But as you saw before, we had to choose the delta parameter to discretize the system. In practice 00:22:54.320 |
we do not choose the discretization step delta, but we make it a parameter of the model that we will train with gradient descent, 00:23:01.120 |
So that the model will learn this parameter delta based on the input that it will see 00:23:10.080 |
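(As a sketch, assuming scalar parameters for simplicity, the discretized parameters could be computed like this; the Euler rule is the one derived above, while the zero-order hold formula is the one reported in the Mamba paper:)

```python
import math

def discretize_euler(A, B, delta):
    # Euler rule derived above: A_bar = I + delta*A, B_bar = delta*B
    return 1.0 + delta * A, delta * B

def discretize_zoh(A, B, delta):
    # Zero-order hold (Mamba paper):
    # A_bar = exp(delta*A),  B_bar = (delta*A)^-1 * (exp(delta*A) - I) * delta*B
    A_bar = math.exp(delta * A)
    B_bar = (1.0 / (delta * A)) * (A_bar - 1.0) * delta * B
    return A_bar, B_bar

print(discretize_euler(-1.0, 0.5, 0.1))   # (0.9, 0.05)
print(discretize_zoh(-1.0, 0.5, 0.1))
```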
Now that we have our recurrent formulation to calculate the output of the system sequentially 00:23:15.780 |
How can we use it to calculate the output of the system for various time steps? Let's do it 00:23:21.600 |
Suppose for simplicity that the initial state of the system is zero 00:23:25.600 |
So the first state of the system, so suppose we have an input made up of x0, x1, x2, x3, x4 00:23:33.280 |
So it's a discrete input. That's why we are using the discretized state space model 00:23:37.360 |
we can compute the first state of the system which is equal to b multiplied by x0 because 00:23:43.040 |
We are multiplying a by the previous state, but the previous state we say it's zero 00:23:47.920 |
So this term will not be present and by using this state we can calculate the first output of the system 00:23:55.920 |
Now we can calculate the next state of the system by using the previous state and the next input 00:24:04.160 |
So a multiplied a bar multiplied by the previous state of the system 00:24:08.400 |
Which we already computed plus b multiplied by the next input of the system 00:24:13.040 |
Having the next state we can compute the next output which is y1 is equal to c multiplied by h1 00:24:19.280 |
Now using the previous state we can compute the next state 00:24:23.120 |
So h2, h2 is equal to a multiplied by the previous state plus b multiplied by the next input 00:24:28.880 |
And using h2 we can compute y2. So the output at the next step 00:24:34.320 |
And this is exactly what we used to do with the recurrent neural networks 00:24:38.880 |
So that's why they are called the recurrent neural networks because they have this recurrent formulation to go from the previous step 00:24:47.280 |
And as you can see, this is exactly what we did 00:24:53.600 |
And we use our first input and the previous state to calculate the first output 00:24:58.960 |
This will not only produce the first output, but also the next state of the system 00:25:04.240 |
Then we use the next state of the system plus the next input to produce the next output and the next state of the system 00:25:10.640 |
Then we use the next state of the system plus the next 00:25:13.840 |
input to produce the next output and the next state of the system, etc, etc. Just like we saw before 00:25:19.600 |
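(A minimal Python sketch of this recurrence, assuming scalar, already-discretized parameters and a zero initial state:)

```python
def ssm_recurrent(xs, A_bar, B_bar, C):
    """h_k = A_bar*h_{k-1} + B_bar*x_k,  y_k = C*h_k, one token at a time."""
    h = 0.0                         # initial state assumed to be zero
    ys = []
    for x in xs:
        h = A_bar * h + B_bar * x   # new state from the previous state and current input
        ys.append(C * h)            # the output only needs the current state
    return ys

print(ssm_recurrent([1.0, 2.0, 3.0], A_bar=0.9, B_bar=0.5, C=2.0))
```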
The recurrent formulation that we have just seen is great for inference 00:25:25.200 |
Because we can compute one token at a time with a constant memory and computation requirement because the effort that we are 00:25:33.440 |
Making to compute this output here is exactly the same as the effort that we take to 00:25:38.480 |
Compute this output here. So the second output token 00:25:42.720 |
And this makes it suitable for inference with a large language model, because in large language models 00:25:48.880 |
We generate one token at a time using the previous tokens and the prompt 00:25:53.120 |
However, the recurrent formulation is not good for training because during training we already have all the tokens 00:26:00.080 |
We do not need to generate the output of the model one token at a time 00:26:04.560 |
We already have our sequence of input. We already have a sequence that is the target 00:26:10.000 |
We want to compute the output of the model in parallel without doing this sequential 00:26:14.660 |
Computation to calculate the loss and train the model and this is exactly what we do with the transformer 00:26:20.960 |
But we cannot do it with this recurrent computation 00:26:25.180 |
Thankfully the state space model also provide a convolutional mode as well, which can be parallelized. Let's see how it works 00:26:32.060 |
So using just the recurrent formulation that we have seen before let's expand the output at every step 00:26:39.900 |
So as you saw before to compute the first state of the model 00:26:43.500 |
We need to take the initial input multiplied by b because we suppose that the previous state is zero. So there is no a 00:26:54.140 |
We can then use the first state to compute the first output, which is c multiplied by h zero, 00:27:03.900 |
and we can replace this h zero with what it is equal to. 00:27:17.340 |
So h one is equal to a multiplied by h zero plus b multiplied by x one 00:27:22.860 |
But what is h zero h zero is exactly this stuff here. So b multiplied by x zero 00:27:28.780 |
so we can compute the state of the system at time step one using 00:27:35.420 |
this computation here, and the output of the system at time step one using h1 00:27:42.060 |
And then we expand the expression of h1 to obtain this expression here 00:27:48.540 |
We can use the similar reasoning to expand also the representation of the h2 space 00:27:54.780 |
So that it only depends on the input and the parameters of the model and in the same way 00:27:59.980 |
Also the computation of the output of the model at time step two using only the input and the parameters of the model 00:28:10.700 |
So that to compute the kth token of the output we need to do the following 00:28:16.920 |
Summation. So for example to compute y2 as you can see, we are multiplying x2 multiplied by cb 00:28:23.400 |
So as you can see yk we are multiplying xk multiplied by cb 00:28:28.120 |
Then the previous token, so in here x1, so here xk minus one, multiplied by cab, 00:28:39.880 |
then the token before that multiplied by ca, each time increasing the power by which a is elevated, multiplied by b, 00:28:47.000 |
and this is exactly what we see here: so each previous token with 00:28:50.360 |
an increasing power of a, etc, etc, until we reach the kth power of a 00:28:55.880 |
By using the formula we derived we note something interesting 00:29:01.660 |
The output of the system can be calculated using a convolution of a kernel k with the input x 00:29:09.160 |
So from this formula that we derived here, which basically represents the pattern of all the expanded outputs, 00:29:22.040 |
we can build a kernel, just like the kernel that we use in convolutional neural networks, 00:29:27.480 |
with the first term of the kernel being cb, the second being cab, the third being ca squared b, 00:29:34.440 |
etc, etc, until the last item of the kernel, ca to the power of k b, and we can convolve this kernel with the input to compute directly the output 00:29:42.120 |
And as you remember the convolution is something that we can parallelize 00:29:46.700 |
So that's why this is very powerful because if we build this kernel 00:29:51.240 |
We can convolve it on the input to produce the output 00:29:55.080 |
And I want to prove you that the output computed using the kernel like this is exactly the one we just found 00:30:02.280 |
By expanding these formulas before so let's do it 00:30:09.480 |
So we build our kernel: cb, cab, etc, etc, up to ca to the power of k b, 00:30:14.440 |
Depending on what is the length of the sequence that we are producing 00:30:20.760 |
I am inverting the kernel, so the first term is not cb 00:30:23.560 |
But it's the last term and the second this one here becomes the second last etc because it's easier to visualize 00:30:34.840 |
But I add some padding later. I will show you why and then we have our output that will be produced by this convolution 00:30:43.480 |
The first output of the convolution is just the kernel 00:30:48.760 |
Slided over the first four tokens of the input 00:30:54.280 |
And we multiply this term with this one this term of the kernel with the this term of the input this 00:31:01.240 |
Term of the kernel with this term of the input this current term of the kernel with this term of the input 00:31:05.880 |
And then we sum all these results, and the sum, as you can see, is equal to y0, which is 00:31:11.080 |
x0 multiplied by cb, and this is exactly the result from this formula here 00:31:17.000 |
So as you can see if we only have the first token 00:31:20.200 |
It will be only cb multiplied by the token itself because we don't have any preceding token 00:31:25.320 |
Let's go forward then we slide the kernel forward by one step and it will produce the following output 00:31:33.240 |
So we have this item of the kernel multiplied by x1 00:31:36.760 |
This item of the kernel multiplied by x0 and all the other multiplied by 0 so they will not contribute to the sum 00:31:45.560 |
So as you can see, when k is equal to one we have x1 multiplied by cb 00:31:51.560 |
And then we have the previous term. So x0 multiplied by cab 00:32:01.480 |
Then we slide the kernel forward again: this term is multiplied by this, this term is multiplied by this, and this term is multiplied by this, and the last one is multiplied by the padding 00:32:10.760 |
So y2 is computed like this and as you can see if you compare it with this formula, you will see that 00:32:17.240 |
yk, so when k is equal to 2, it will be cb multiplied by x2, then cab multiplied by x1, and then 00:32:28.120 |
ca to the power of 2 multiplied by b multiplied by x0, which is the token that comes before x1 00:32:37.080 |
And the same goes for step number 4, which you can verify by yourself 00:32:40.700 |
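(Here is a small Python sketch, again with scalar parameters and my own toy numbers, that builds the kernel [cb, cab, ca²b, ...] and checks that the convolution gives the same outputs as the recurrence:)

```python
def ssm_kernel(A_bar, B_bar, C, L):
    # K[j] = C * A_bar^j * B_bar, so that y_k = sum_j K[j] * x_{k-j}
    return [C * (A_bar ** j) * B_bar for j in range(L)]

def ssm_convolutional(xs, A_bar, B_bar, C):
    K = ssm_kernel(A_bar, B_bar, C, len(xs))
    # Each y_k only needs x_0..x_k (causal convolution); the outputs do not
    # depend on each other, so they could all be computed in parallel.
    return [sum(K[j] * xs[k - j] for j in range(k + 1)) for k in range(len(xs))]

xs, A_bar, B_bar, C = [1.0, 2.0, 3.0, 4.0], 0.9, 0.5, 2.0
print(ssm_convolutional(xs, A_bar, B_bar, C))

# Same values as the recurrence h_k = A_bar*h_{k-1} + B_bar*x_k, y_k = C*h_k
h, ys = 0.0, []
for x in xs:
    h = A_bar * h + B_bar * x
    ys.append(C * h)
print(ys)
```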
The best thing about the convolution calculation is that it can be parallelized 00:32:46.220 |
Because the output yk does not depend on the previous output. So what I mean is that the 00:32:54.200 |
product that we are doing here can be done on one thread for example or one core of the 00:32:59.960 |
GPU and this one can be done also simultaneously because they do not depend on each other 00:33:05.800 |
So the convolutional calculation can be parallelized 00:33:12.440 |
However, we need to build the kernel for the convolution, and building this kernel can be a little 00:33:18.680 |
Expensive from a computational point of view, but also from a memory point of view 00:33:23.640 |
We can still use the convolutional computation to perform the training 00:33:27.160 |
Because we can parallelize the training because we already have the input token 00:33:31.240 |
We already have the target so we can compute the output of the model in parallel for all the input 00:33:36.280 |
tokens, even if it's expensive to do it computationally 00:33:39.740 |
And then we can use the recurrent formulation to inference from this model one token at a time 00:33:45.800 |
And we also know that by just doing this the computation cost for producing one token will always be the same 00:33:52.840 |
No matter which token we are producing, if it's the first one or the 100th token or the 1000th token 00:34:00.120 |
And this is different from the transformer in which 00:34:02.680 |
If we want to generate tokens to produce the first token is less expensive than to produce the 100th token 00:34:09.320 |
Because we need to do much more dot products when we generate the 100th token 00:34:14.680 |
So with the KVCache we would do 100 dot products 00:34:17.800 |
One thing that is worth mentioning is that in the paper they do not use the term d 00:34:23.720 |
Here when they compute the output and the reason is that this can be thought of as a skip connection 00:34:31.740 |
So imagine we have our input x which is the input of the system 00:34:36.540 |
So suppose we have this x this is sent to some black box that we will call state space model 00:34:42.860 |
That will do this calculation of the state recurrently. So continuously 00:34:52.220 |
But then we can see that this dx of t basically means that we take our input 00:34:57.900 |
We skip the state space model. We multiply it with some number d and we send it directly to the output 00:35:04.940 |
So we do not need to model this d to model the state space 00:35:09.980 |
Because it does not depend on the state of the system at any time step 00:35:14.140 |
And this is why it can be represented as a skip connection. And this is why in the paper they do not mention it 00:35:20.460 |
In the analysis that we have done so far on this state space model, as I told you before, 00:35:27.100 |
for simplicity we are considering that all the parameters are just single numbers, 00:35:31.500 |
and the input is just a single number and the output is just a single number 00:35:37.500 |
But usually when we work with language models, especially our input is not a single number 00:35:42.700 |
but it's a vector because suppose we have a sequence of tokens each token is represented by a vector of 00:35:48.780 |
512 dimensions in the case of the vanilla transformer 00:35:53.020 |
And maybe 4096 dimensions in the case of llama, for example 00:35:57.500 |
So how can we work with the state space models, but with a vector input and a vector output? 00:36:04.620 |
The idea is that each dimension of the input vector will be managed by an independent state space model 00:36:12.940 |
So imagine we have an input made up of a vector of 512 dimensions 00:36:19.740 |
Which needs to produce an output token of 512 dimensions. We will create one state space model for each dimension 00:36:26.620 |
So one for the dimension zero one for the dimension one dimension two, etc, etc for all the dimensions 00:36:32.000 |
And all these state space models are independent from each other 00:36:39.820 |
But it's not so strange if you think about the multi-head attention of the transformer because in the transformer 00:36:45.820 |
We also have an input made up of vectors each one representing an embedding of a token 00:36:51.740 |
Suppose we have 512 dimensions for each embedding 00:36:55.680 |
Then the multi-head attention one basically groups some dimensions 00:37:00.320 |
Suppose in the vanilla transformer we have eight heads 00:37:04.540 |
So it means that the 512 is divided by eight which means that every 00:37:09.500 |
Head is managing 64 dimensions and each head is independent from the others 00:37:15.820 |
So if it works for the transformer, it can also work for the state space model, and actually it does 00:37:28.060 |
The input x t and the output y t have now become vectors, and they will have the sequence dimension and also the 00:37:35.260 |
D model dimension, which means that we have an input that is made up of suppose 512 dimensions 00:37:45.260 |
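(A minimal Python sketch of this idea, my own simplification in which each channel's state is a single number; in the real model every channel has an N-dimensional state vector:)

```python
import numpy as np

def multichannel_ssm(x, A_bar, B_bar, C):
    """x: (L, D) sequence of D-dimensional tokens.
    A_bar, B_bar, C: (D,) one scalar parameter set per channel (dimension)."""
    L, D = x.shape
    h = np.zeros(D)                   # one hidden state per channel
    y = np.zeros((L, D))
    for t in range(L):                # recurrence over time...
        h = A_bar * h + B_bar * x[t]  # ...but every channel evolves independently
        y[t] = C * h
    return y
```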
Now in the recurrent formulation of the state space model we have this matrix a which is quite important 00:37:53.740 |
The a matrix in the state space model can be intuitively thought of as a matrix that captures the information of the past state 00:38:06.860 |
It also decides how this information is copied forward in time 00:38:12.060 |
So when we produce the output of a state space model at a particular time step for which means for a particular token 00:38:18.860 |
It depends on the state at that time step, but the state at time step t depends on all the previous states 00:38:26.540 |
So basically this a matrix tells us how the information of the model is conveyed forward in time 00:38:35.420 |
Basically captures all the history of the input of the model so far to build the new token 00:38:42.540 |
This means that we need to be very careful about the structure of this a matrix 00:38:47.740 |
Otherwise, it may not very well capture the history of all the inputs seen so far 00:38:55.100 |
And this is very important for language models because the next token generated by the model 00:38:59.660 |
should depend on all the previous tokens, which are the prompt because when you 00:39:04.220 |
Send a prompt to a language model, which is a list of tokens. The model should attend should 00:39:10.460 |
Watch all the previous tokens to produce the new one 00:39:12.940 |
And in the case of the state space model, this is taken care of by the matrix a, so the structure of this matrix is very important 00:39:21.260 |
And if you have some background in controls engineering, you may remember that in a state space model the a matrix determines the stability of the system, 00:39:30.940 |
which means that the output of the system may diverge if the system is not stable 00:39:36.220 |
But if you don't have a background in controls engineering, it doesn't matter 00:39:39.740 |
So to make the a matrix behave well, the authors chose to use the hippo theory. Let's see how it works 00:39:45.980 |
To give you an intuition of what the hippo theory does, I will borrow an analogy from the Fourier transformation 00:39:53.420 |
So as you remember the Fourier transformation allow us to decompose a signal 00:39:58.220 |
so suppose this is our initial signal, into a series of sinusoidal functions, 00:40:03.180 |
so this one, this one and this one, such that when we sum all these functions, 00:40:09.500 |
each one with its respective amplitude, we can rebuild the original signal, as you can see here 00:40:17.100 |
And we do something similar also in the state space model with the hippo theory 00:40:22.640 |
But instead of using sinusoidal functions, we use the Legendre polynomials 00:40:28.540 |
With the hippo theory we build the a matrix in such a way that it approximates all the input signals so far 00:40:35.980 |
So let me show you here. So imagine this is our input signal and it evolves over time 00:40:44.220 |
The goal of the a matrix is to capture information from this input signal and convey it forward for the next 00:40:52.780 |
State so it captures the history of all these signals 00:40:56.940 |
Suppose this is our state if we use the state to rebuild the original signal using the hippo theory 00:41:03.340 |
We will have an output, so a reconstructed signal, that is very well 00:41:09.580 |
approximated for the more recent time steps and not so 00:41:15.580 |
Well approximated for the previous time step. So as you can see this approximation here is not very 00:41:21.180 |
Good, if you compare it with the original signal here, but the more recent is very well reconstructed 00:41:27.420 |
You can see here, and this is what the hippo matrix does: it builds a state space representation 00:41:33.120 |
that captures well the information of recent tokens and 00:41:37.900 |
Decays the information of the past token using some reasoning very similar to the exponentially moving average 00:41:45.280 |
And this is very important for language modeling because in language models 00:41:50.300 |
We have a prompt made up of tokens and we need to produce the next token 00:41:54.460 |
So we need to build a hidden state because the hidden state will determine the output of the system 00:41:59.660 |
So the next token so we need to build a hidden state that captures very well 00:42:03.500 |
Information about the local context, and we can afford to lose some information about the global context. So information about 00:42:10.940 |
Tokens that are very far from the one we are producing 00:42:16.540 |
And in a paper called Efficiently Modeling Long Sequences with Structured State Spaces 00:42:25.020 |
The author says that just initializing the A matrix with the hippo matrix we can see here 00:42:31.980 |
so this is how to build the hippo matrix, which is 00:42:37.340 |
All the values above the principal diagonal here are zero and other values are computed as follows 00:42:44.380 |
So the diagonal values are computed as n plus one, where n is the row and k is the column, and the values below the diagonal are computed as the square root of (2n+1) multiplied by (2k+1) 00:42:55.900 |
It will result in a very big increase in performance of the model 00:43:00.940 |
Why? because the A matrix is the one that is responsible for capturing information about the previous state and carrying it over to the new state 00:43:08.860 |
So we want to carry the information well in the new state 00:43:13.340 |
So that the h of t which will become a vector 00:43:18.540 |
If we have a multi-dimensional state space model 00:43:22.540 |
Then this h of t captures very well the information of all the previous inputs so that we can produce the next output 00:43:30.300 |
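(As a sketch, this is how the hippo matrix just described could be built; the formula, including the values below the diagonal and the overall minus sign, follows the convention used in the S4 and Mamba papers:)

```python
import numpy as np

def hippo_matrix(N):
    """Zero above the diagonal, n + 1 on the diagonal (n = row index),
    sqrt((2n+1)(2k+1)) below the diagonal (k = column index), all negated."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = n + 1
    return -A

print(hippo_matrix(4))
```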
Okay, now it's time to talk about Mamba and let's talk about the motivation that led to the building of Mamba 00:43:37.820 |
And in the paper the authors describe two tasks on which the vanilla state space model, 00:43:43.340 |
So the state space model that we have described up to now or even the S4 model 00:43:49.740 |
Which is basically just the state space model with a very rigid structure on the A matrix like the one we saw before 00:43:55.820 |
So these two models they do not perform well on two specific tasks 00:43:59.900 |
One is the selective copying and one is the induction heads 00:44:03.580 |
So let's introduce these tasks the copying task basically means that we have some input tokens. So this one's here blue 00:44:13.740 |
Green and the model has to produce the same outputs but time shifted 00:44:18.480 |
And this actually can be done by the vanilla state space model 00:44:22.460 |
Because it can be actually done with a simple convolution and the convolution can learn the time shifting that we are doing 00:44:31.420 |
However, the selective copying, which means that we have some input tokens, for example blue, white, white, 00:44:41.180 |
and then green, and the model needs to produce only the colored ones: 00:44:48.780 |
this one cannot be done by the vanilla state space model, because the vanilla state space model 00:44:54.460 |
cannot do content-aware reasoning, because the parameters of the model are the same for every step, 00:45:00.940 |
so the same for every token, so it will treat each token equally. 00:45:07.500 |
It cannot distinguish the blue token from the white token, so it cannot just ignore the white one and keep the colored one 00:45:16.540 |
So to give you an intuition of what this could mean, for example, imagine we are given a twitter 00:45:22.620 |
Comment on twitter and we want to rewrite this comment by removing all the bad words 00:45:28.220 |
So all the white tokens you can see here and this one cannot be done by the state space model 00:45:32.780 |
Because the parameters a, b, c and d are the same for each input and they cannot change for each particular input 00:45:40.380 |
The second task on which the state space models have a difficulty is the induction heads task, which means that the model needs to recall 00:45:51.640 |
previous inputs to build the current output, and they show for example this example: 00:45:56.300 |
so for example, every time the model sees the 00:45:58.380 |
black token it should output the blue token, by recalling what it has seen before. 00:46:04.540 |
But because the model cannot use content-aware reasoning, it cannot perform this task well 00:46:13.740 |
This is the motivation that led to the creation of mamba 00:46:16.780 |
And this is the reference in the paper in which they talk about these two tasks 00:46:21.020 |
So from the recurrent view, the recurrent dynamics, 00:46:23.840 |
so the transitions given by the a and the b matrix, cannot let the model select the correct information from the current context, 00:46:31.340 |
because the a and b matrices are the same for each input token. And also from the convolutional view, 00:46:37.340 |
it is known that the model is able to solve the copying task because it can learn the time shift, 00:46:43.660 |
but they have a difficulty with the selective 00:46:47.020 |
Copying task because the parameters of the model are the same for each input 00:46:51.740 |
So they do not know how to treat differently a particular input 00:46:55.660 |
Now let's talk about the innovation of mamba and how it differs from the state space model 00:47:03.020 |
Let's first review. What is the algorithm of the state space model? 00:47:06.620 |
So the state space model basically has an input and it has to produce an output 00:47:11.180 |
The input is this one. So it's a tensor of b l and d dimension 00:47:16.460 |
which is a batch dimension sequence dimension and the d model and this is the 00:47:20.620 |
exactly the same as the transformer, because we have a batch of prompts, each made of some tokens, 00:47:27.340 |
which is our sequence length, and each token is made up of a vector of d model dimensions, 00:47:31.980 |
Which is our d here and it has to produce an output of the same dimension 00:47:38.620 |
We have an a matrix which is a matrix of parameters that indicates how to copy the previous state into the new state 00:47:46.060 |
And we model the state as a vector of n dimensions. So this will be a vector. 00:47:55.180 |
So the a matrix is d by n, which is basically because, 00:48:01.100 |
when we have an input vector made up of d dimensions, we have d 00:48:07.880 |
state space models independent from each other, one for each dimension, just like I showed before 00:48:16.120 |
This b matrix is also d by n. The c matrix is also d by n 00:48:21.720 |
And the delta is the step size of the discretization which is learned by the model as I showed before 00:48:28.280 |
So we don't decide it. We just let the model learn this 00:48:31.880 |
step size. And we have d state space models, because our input vector is d dimensional 00:48:40.760 |
And to discretize we just apply the formula that we saw before so depending which 00:48:47.160 |
Discretization rule we are using if we use the Euler method or we use the zero order hold in the case of mamba 00:48:58.040 |
we obtain a bar and b bar, and we can run the ssm as a recurrence or a convolution, depending on whether we are training or inferencing: 00:49:08.280 |
we run it as a convolution by building the kernel like I showed you before and running it on the input like I showed before, 00:49:14.120 |
Or as a recurrence using the formula we can see here. So this formulation here 00:49:19.000 |
So we compute the next step the next state using the previous state and then we use each state to build the output 00:49:25.400 |
In the case of mamba they make the state space model selective which means that the parameters of the model 00:49:33.720 |
are changed for each input token. Let's see. The input is still b, l, d, so we have a 00:49:42.680 |
batch of prompts, each prompt is made of l tokens, each token is made up of d dimensions, and the output has the same shape. 00:49:56.680 |
The b matrix is now computed from the input by a linear layer, this s_B, and this s_B will project the d dimension into the n dimension, as you can see here. 00:50:04.600 |
This means that basically now the b matrix is not the same for all input tokens because now we have this l dimension here 00:50:12.360 |
Which means that for every input token of each batch we will have a different b matrix 00:50:18.520 |
And the same goes for the c matrix you can see here and the same goes for the delta you can see here 00:50:25.240 |
For each input token, we will have a different delta 00:50:38.920 |
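(In PyTorch-like pseudocode, the selective parameterization could be sketched as follows; the layer names are my own and the real Mamba implementation differs in several details:)

```python
import torch
import torch.nn as nn

class SelectiveParams(nn.Module):
    """Produces input-dependent B, C and delta, one set per token."""
    def __init__(self, d_model, d_state):
        super().__init__()
        self.s_B = nn.Linear(d_model, d_state)   # B becomes (batch, L, N)
        self.s_C = nn.Linear(d_model, d_state)   # C becomes (batch, L, N)
        self.s_delta = nn.Linear(d_model, 1)     # one step size per token,
        # later broadcast to the D channels and kept positive with a softplus

    def forward(self, x):                         # x: (batch, L, d_model)
        B = self.s_B(x)
        C = self.s_C(x)
        delta = torch.nn.functional.softplus(self.s_delta(x))
        return B, C, delta
```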
However, because now the model is not a time invariant anymore because the parameters of the model 00:50:45.800 |
change for each input, so for each token, for each time step, 00:50:54.120 |
we can only compute the output as a recurrence, so we can only apply this formulation here 00:50:57.240 |
We cannot compute it with the convolution anymore because the kernel will not be fixed 00:51:02.200 |
Because before because the parameters of the model are fixed for all the inputs 00:51:06.680 |
We can just build a kernel and run it for all the inputs 00:51:09.720 |
but now for every input we should use a different kernel so we cannot compute it as a 00:51:15.000 |
Convolution we are only forced to compute it as a recurrence 00:51:23.240 |
Have you noticed that the authors talk about this scan operation you can see here 00:51:27.960 |
So what is this scan operation that is used when evaluating the model as a recurrence? Let's talk about it 00:51:36.440 |
So if you have ever done competitive programming you are familiar with the prefix sum array 00:51:41.640 |
Which is an array calculated sequentially such that the value at each position indicates the sum of all the previous values 00:51:49.240 |
We can easily compute it with the for loop in linear time 00:51:55.720 |
We can calculate the prefix sum like this: the first value is equal to the first value of the initial array, 00:52:04.040 |
then each following value is computed using the previous value of this array, so this one, plus the current value of the initial array. So 00:52:09.480 |
Wait, this one is equal to this one plus this one 00:52:13.160 |
And this one is computed using the previous value plus the current value of the initial array 00:52:18.600 |
And this one is computed using the previous value plus the current value of the initial array 00:52:23.400 |
And this one is computed using the previous value plus the current value of the initial array, 00:52:29.880 |
Such that each item of this prefix sum indicates the sum of all the items of the initial array up to that 00:52:36.840 |
element. So the number 32 is the result of 10 plus 7 plus 6 plus 9 00:52:45.160 |
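(A minimal Python sketch of the prefix sum computed with a for loop in linear time:)

```python
def prefix_sum(xs):
    out = [xs[0]]                   # the first value stays the same
    for x in xs[1:]:
        out.append(out[-1] + x)     # previous prefix value + current input value
    return out

print(prefix_sum([10, 7, 6, 9]))    # [10, 17, 23, 32] -> 32 = 10 + 7 + 6 + 9
```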
The scan operation refers to computing an array like the prefix sum in which each value can be computed using the 00:52:54.120 |
previously computed value and the current input value 00:52:58.840 |
And the recurrent formula of the state space model can also be thought as a scan operation in which each state 00:53:05.800 |
is the sum of the previous state multiplied by the A matrix plus the current input multiplied by the B matrix. 00:53:13.980 |
So if the model input is X0, X1, X2, X3, X4 and X5, we first compute H0, 00:53:23.640 |
and then we can compute H1 using the previously computed value plus the current input, each one multiplied by the A matrix and the B matrix respectively. 00:53:35.400 |
Then H2 can be computed as H1 multiplied by A plus X2 multiplied by B, and H3 can be computed as H2 00:53:42.680 |
multiplied by A plus X3 multiplied by B, etc, etc, etc 00:53:48.680 |
So to generate the output we just multiply each H_k with the C matrix to generate the output token k 00:53:55.800 |
So if we have this scan output, so if we build a 00:53:59.720 |
Array like this one, so this array you can see here 00:54:03.560 |
We can easily compute the output of the model for each time step 00:54:08.200 |
By just multiplying each of this value with the C matrix 00:54:11.880 |
So we multiply this one by C, this one by C, and this one here by C, to compute all the output tokens 00:54:24.120 |
Now what if I told you that the scan operation that I have shown you can be parallelized 00:54:31.580 |
Of course, you will not believe me because the scan operation is one of those operations that naturally looks like a sequence 00:54:37.560 |
So to compute the current value, I need to have the previous value plus the current input 00:54:42.760 |
So how can I parallelize an operation like this? Actually, it can be parallelized 00:54:47.100 |
As long as the operations that we are doing are associative 00:54:50.780 |
which means that they benefit from the associative property 00:54:53.800 |
Now if you remember from elementary school or middle school when they teach you about the properties of the addition and the multiplication 00:55:00.460 |
You may recall that the associative property means that if you have 00:55:04.120 |
The operation done on three operands, for example A multiplied B multiplied C 00:55:09.400 |
It does not matter the order in which you do these operations 00:55:12.780 |
So it does not matter where you put the parentheses, the result will be the same 00:55:16.520 |
So you can do A multiplied by B and then the result multiplied by C 00:55:19.960 |
Or you can do A multiplied by the result of B multiplied by C 00:55:25.800 |
So as long as the operations that we are doing have this property, then we can parallelize the scan operation 00:55:31.960 |
I want to show you how to actually do it practically 00:55:36.040 |
So imagine the initial array is this we can create multiple threads each one computing 00:55:42.120 |
in parallel a sum. So for example, we can have this: 00:55:47.800 |
okay, this is actually the picture I took from wikipedia, and it's made for a 16-element input array 00:55:56.280 |
The first thread will compute the summation of the first two elements 00:55:59.400 |
the second thread will compute the summation of the third and the fourth, the third thread the summation of the 00:56:04.680 |
fifth and the sixth, etc, etc. And then we use the result of this 00:56:11.640 |
step to compute the next step in parallel, and then the next step in parallel 00:56:20.440 |
This is called the sweep down if I remember correctly and then we have a sweep up operation to rebuild all the 00:56:27.240 |
Values that we didn't compute. So for example, this value here is not computed until the last step 00:56:36.680 |
Basically, we can decrease the time complexity of the scan operation from the sequential one, so O(n); 00:56:43.080 |
it's reduced to O(n / t), where t is the number of parallel threads that are computing this operation 00:56:50.280 |
And in my github repository, I have also put an excel file that shows you 00:56:56.440 |
how this scan operation is computed step by step, so I actually show you all the intermediate steps, if you want to check it out. 00:57:05.480 |
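(To make this concrete, here is a small Python sketch, my own illustration rather than the paper's CUDA kernel, of why the SSM recurrence h_k = A_k·h_{k-1} + B_k·x_k can be computed with an associative combine, which is exactly the property a parallel scan needs:)

```python
from functools import reduce

# Each step is a pair (a, b) representing the map h -> a*h + b.
# Composing two steps is associative, which is what the parallel scan exploits.
def combine(step1, step2):
    a1, b1 = step1
    a2, b2 = step2
    return (a2 * a1, a2 * b1 + b2)   # apply step1 first, then step2

A  = [0.9, 0.8, 0.7, 0.6]            # input-dependent A_k (selective SSM)
Bx = [1.0, 2.0, 3.0, 4.0]            # the B_k * x_k terms
steps = list(zip(A, Bx))

# Sequential scan: all hidden states h_k, starting from h = 0
h, hs = 0.0, []
for a, b in steps:
    h = a * h + b
    hs.append(h)

# Because `combine` is associative, the two halves can be reduced in parallel
# and merged afterwards, giving the same final state:
left  = reduce(combine, steps[:2])
right = reduce(combine, steps[2:])
a, b = combine(left, right)
print(hs[-1], a * 0.0 + b)           # identical values
```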
Now that we know that this scan operation can be done in parallel, this is actually what the authors do: they also do the 00:57:15.080 |
Recurrence to calculate the output of the model in parallel to reduce its time complexity 00:57:20.220 |
Basically since Mamba cannot be evaluated using a convolution because it's time varying 00:57:25.960 |
So it means that the parameters of the model are different for each time step 00:57:30.120 |
Our only way of computing the output is to use the recurrent formulation 00:57:33.800 |
But thanks to the parallel scan algorithm, this can be parallelized to reduce its time complexity 00:57:38.620 |
The authors also indicate some techniques that they have used to make this algorithm faster 00:57:43.480 |
The first technique that they show you is the kernel fusion 00:57:47.000 |
The second is the parallel scan which I have already shown and then we have the 00:57:50.840 |
recomputation of the activations, that we will show later. So let's see all these techniques one by one 00:57:56.280 |
But first let's see how the memory hierarchy of the GPU works 00:57:59.960 |
The GPU basically is a very fast calculator. So it's a very, very big 00:58:06.200 |
computational unit that can do a lot of operations in parallel, and it has different levels of memory 00:58:13.880 |
The one that you are mostly familiar with, the one that you actually check when you buy a GPU, is called the DRAM 00:58:20.600 |
So it's the high bandwidth memory and it's in the order of gigabytes 00:58:24.520 |
And then the GPU also has a smaller memory, a local memory that is called the SRAM 00:58:29.560 |
The difference between the two is that, first of all, the SRAM is much, much smaller 00:58:36.680 |
However, this is where the GPU will do the computation 00:58:40.600 |
So when the GPU needs to do some matrix multiplication 00:58:42.920 |
It will first of all copy the information from the high bandwidth memory to the SRAM 00:58:47.020 |
Then the core of the GPU will access the information in the SRAM to do the computation 00:58:53.720 |
And then the result will be saved back to this high bandwidth memory 00:58:59.400 |
Actually, if we check the data sheet of a GPU, in this case, this is the NVIDIA A100 00:59:04.300 |
You will see that the GPU is very fast at computing operations 00:59:09.320 |
But the copying of information between the DRAM and the SRAM is not very fast, not as fast as computing operations 00:59:19.560 |
The copying speed is much slower 00:59:25.160 |
So this is like two terabytes per second, compared to the number of operations that the GPU can do: 00:59:30.120 |
in this case, it can do about 20 tera floating point operations per second at 32 bit 00:59:36.440 |
So this parameter here is basically 40 times faster than this parameter here 00:59:42.360 |
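As a back-of-the-envelope check of that factor of 40, assuming the comparison is between the fp32 compute throughput and how many fp32 values per second the memory bandwidth can deliver (this interpretation of the two numbers is my assumption):

```python
flops_fp32 = 19.5e12        # A100 peak fp32 throughput, FLOP/s (the "~20 TFLOPS")
hbm_bandwidth = 2.0e12      # A100 HBM bandwidth, bytes/s (the "~2 TB/s")
fp32_values_per_second = hbm_bandwidth / 4   # each fp32 value is 4 bytes

# ~39: the GPU can do roughly 40 operations in the time it takes to load one value
print(flops_fp32 / fp32_values_per_second)
```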
This also means that when we create an algorithm that runs on the GPU, so a CUDA kernel 00:59:48.740 |
Sometimes the kernel may run slowly not because we are doing a lot of operations 00:59:53.940 |
But maybe because we are copying a lot of stuff around which results in a slow overall computation 00:59:59.800 |
And if this happens, we say that our operation is I/O bound, because it's bounded by the I/O speed, by the copying speed 01:00:09.460 |
Now the authors exploit this memory hierarchy of the GPU to make their algorithm, 01:00:16.900 |
so the selective scan, faster. The main idea is to leverage the properties of modern accelerators, 01:00:22.740 |
so the GPU, to materialize the state h, so the hidden state, only in the more 01:00:29.380 |
efficient levels of the memory hierarchy, so only in the SRAM, the smaller memory 01:00:34.420 |
Concretely, since they compute the recurrence as a scan operation, 01:00:40.340 |
just like I showed you before, instead of preparing the scan input, which is a tensor whose 01:00:51.540 |
size involves D, the d_model, so the size of the input vector, and N, the size of the hidden state, 01:00:54.660 |
in the GPU high bandwidth memory, so in the DRAM, they load all the parameters of the state space model, 01:01:01.860 |
so the delta, the A matrix, the B matrix, and the C matrix, directly from the high 01:01:06.100 |
bandwidth memory, so the DRAM, into the fast SRAM. They perform the discretization in the SRAM, 01:01:13.140 |
the recurrence, so the scan operation, is also done in the SRAM, and finally the result of this scan is written back to the high bandwidth memory 01:01:22.900 |
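To make the data flow concrete, here is an unfused reference sketch of what the selective scan computes for a single channel, in plain NumPy. The simplified discretization below (A_bar = exp(delta * A), and delta * B applied directly to the input) is an assumption made for illustration; the real kernel performs these same steps over tensors of shape (batch, length, D, N), fused so that the hidden state h only ever lives in the SRAM.

```python
import numpy as np

def selective_scan_reference(u, delta, A, B, C):
    """Unfused, single-channel sketch of the selective scan.

    u:     (L,)    input sequence for this channel
    delta: (L,)    per-timestep step size (what makes the SSM "selective")
    A:     (N,)    diagonal of the state matrix
    B, C:  (L, N)  per-timestep input and output projections

    Here everything lives in ordinary (slow) memory; the fused kernel keeps
    the hidden state h in SRAM and writes only the outputs y back to HBM.
    """
    L, N = B.shape
    h = np.zeros(N)
    y = np.empty(L)
    for t in range(L):                    # in the real kernel: a parallel scan
        A_bar = np.exp(delta[t] * A)      # discretize A for this timestep
        B_bar_u = delta[t] * B[t] * u[t]  # discretize B and apply the input
        h = A_bar * h + B_bar_u           # recurrence on the hidden state
        y[t] = C[t] @ h                   # project the state to the output
    return y

L, N = 6, 4
rng = np.random.default_rng(0)
y = selective_scan_reference(rng.normal(size=L),          # u
                             np.abs(rng.normal(size=L)),  # delta > 0
                             -np.abs(rng.normal(size=N)), # stable (negative) A
                             rng.normal(size=(L, N)),     # B
                             rng.normal(size=(L, N)))     # C
```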
And they also make use of what is known as a kernel fusion 01:01:28.660 |
So what is kernel fusion? When we perform a tensor operation, our deep learning framework 01:01:34.420 |
So PyTorch, it will load the tensor. Suppose we are doing a matrix multiplication: it will load the tensor 01:01:42.900 |
From the slow memory to the fast memory. So from the DRAM to the SRAM of the GPU 01:01:47.460 |
It will perform the operation, for example the matrix multiplication, and then it will save back the result 01:01:52.580 |
From the SRAM to the DRAM. So from the SRAM to the high bandwidth memory of the GPU 01:01:57.860 |
However, what if we do three operations on the same tensor in sequence? 01:02:03.620 |
The deep learning framework will do something like this. So it will load first of all the tensor from the 01:02:10.260 |
High bandwidth memory to the SRAM. It will compute the first operation 01:02:13.700 |
Which means calling the CUDA kernel associated with the first operation and then save the result back to the high bandwidth memory 01:02:20.580 |
Then it will load again the result of the previous computation from the high bandwidth memory into the fast memory 01:02:28.900 |
Launching the second CUDA kernel associated with this operation and then it will save the result of this operation back to the high bandwidth memory 01:02:38.740 |
Load the result of the previous computation again from the high bandwidth memory into the fast memory 01:02:43.700 |
compute the third operation and then save back the result into the high bandwidth memory. As you can see 01:02:49.620 |
most of the total time, in this case when we have three operations in sequence, is 01:02:55.460 |
occupied by the copying operations that we are performing from the 01:02:58.900 |
high bandwidth memory to the fast memory and from the fast memory back to the high bandwidth memory, because we know that copying is much slower than computing 01:03:10.760 |
So kernel fusion means that, to make a sequence of operations faster, we fuse them, 01:03:17.700 |
so the three operations that we are doing in sequence, into one custom CUDA kernel, such that 01:03:23.220 |
We don't copy the intermediate results to the high bandwidth memory 01:03:27.460 |
But we keep doing these operations in the fast memory until we have done all these three computations 01:03:32.420 |
And then only the last result is saved into the high bandwidth memory 01:03:36.180 |
This speeds up the overall computation because we don't have the intermediate copy operations 01:03:47.920 |
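As a rough illustration of the idea (not Mamba's actual kernel): in eager PyTorch each elementwise operation below launches its own CUDA kernel, with its own round trip to the high bandwidth memory, while a compiler such as torch.compile can fuse the three into a single kernel that keeps the intermediate results in fast memory. The function names are my own.

```python
import torch

def three_ops(x):
    # Three elementwise operations in sequence. In eager mode each one is a
    # separate kernel: load from HBM, compute, store the result back to HBM.
    a = x * 2.0
    b = torch.sin(a)
    return b + 1.0

# torch.compile can fuse these pointwise ops into one kernel, so the
# intermediate tensors a and b never have to be written back to HBM.
fused_three_ops = torch.compile(three_ops)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1_000_000, device=device)
assert torch.allclose(three_ops(x), fused_three_ops(x), atol=1e-5)
```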
Okay, the last innovation of this selective scan algorithm is the recomputation of the activations, let's see what it is 01:03:55.060 |
So when we train a deep learning model, this model gets converted into a computation graph 01:04:00.580 |
When we perform a back propagation in order to calculate the gradients at each node of this computation graph 01:04:06.340 |
We need to cache the output values that each node produced during the forward step 01:04:11.700 |
So imagine we have a very simple model like the one I show you here 01:04:15.380 |
Let me show with the pointer. So this model basically computes the output using just a linear operation: x1 multiplied by this 01:04:23.940 |
parameter w1, plus x2 multiplied by this parameter w2, plus a bias 01:04:30.260 |
Suppose we have done the forward process and it has produced at each node its own value 01:04:34.900 |
During the back propagation our goal is to calculate the gradient of the loss function with respect to each parameter 01:04:42.840 |
Here so with respect to w1 with respect to w2 and with respect to the bias 01:04:50.500 |
The gradient of the loss function with respect to the w1 for example, I show you the step here 01:04:57.060 |
And to compute this gradient we need to also compute the gradient of all the intermediate nodes 01:05:02.900 |
And to compute the gradient of the intermediate nodes 01:05:06.100 |
We need to have the values of the activations of each node that we had during the forward step 01:05:11.940 |
So for example to compute the gradient of the loss function with respect to this node y_pred 01:05:16.900 |
Which results in the expression 2 multiplied by y_pred minus 2 multiplied by the target, we need to cache the 01:05:24.340 |
Value y_pred that we had during the forward step 01:05:28.020 |
And these activations can actually occupy a lot of memory in a very big network. And this is why, in the paper, they propose 01:05:37.780 |
recomputing them. Since caching the activations and then reusing them during back propagation means that we need to save them 01:05:46.340 |
To the high bandwidth memory. So the slow memory and then copy back them from the slow memory during back propagation 01:05:52.900 |
It may be faster to just recompute them during back propagation, because the recomputation 01:05:57.560 |
can be quicker: the GPU is much faster at computing operations than it is at copying data, 01:06:02.820 |
so maybe just recomputing them is faster than copying them 01:06:06.100 |
So this is the reference in the paper in which they describe this technique 01:06:09.780 |
So they say finally we must also avoid saving the intermediate states which are necessary for back propagation 01:06:15.380 |
So the intermediate states are all the activations of all the nodes of this computation graph 01:06:20.420 |
So we carefully apply the classic technique of recomputation to reduce the memory requirements. So the intermediate states are not stored 01:06:27.140 |
They are not stored in the high bandwidth memory 01:06:29.780 |
But recomputed during the backward pass, when the inputs are loaded from the high bandwidth memory to the fast SRAM 01:06:35.380 |
So basically it's faster to just redo the calculations again, instead of copying this information to the high bandwidth memory and then 01:06:42.340 |
Reloading it from high bandwidth memory to the fast memory 01:06:48.820 |
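In PyTorch this trade-off is exposed as gradient (activation) checkpointing: instead of storing every intermediate activation for the backward pass, the forward pass of a segment is recomputed during backward. A minimal sketch, with a made-up module and sizes chosen only for illustration:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.net(x)

block = Block(512)
x = torch.randn(8, 512, requires_grad=True)

# Normal forward: the intermediate activations inside `block` are cached.
y = block(x)

# Checkpointed forward: the activations inside `block` are NOT cached; they
# are recomputed during .backward(), trading extra compute for less memory.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()
```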
Now let's look at the block that makes up the Mamba architecture 01:06:53.860 |
So first I will introduce what the Mamba block is, and then we will see the full Mamba architecture 01:06:58.360 |
So Mamba is built by stacking multiple layers of this Mamba block that we can see here 01:07:03.620 |
And this is very similar to the stacked layers of the transformer: 01:07:08.660 |
we have the encoder and the decoder, and the encoder side and the decoder side are made by stacking blocks 01:07:16.420 |
with the self-attention and the feed forward network 01:07:19.220 |
multiple times, one on top of the other, such that the 01:07:22.500 |
output of one block is sent as input to the next block, and the output of the last block is sent to the output 01:07:29.380 |
Of the model and this is exactly what we do with Mamba in which we create 01:07:33.620 |
First of all, we have our input which will be sent to this block and this block will be repeated many times such that the 01:07:41.060 |
Output of this block here will become the input of the next block 01:07:46.420 |
At the beginning of this block we have two linear layers that convert the size 01:07:52.180 |
d_model into d_inner, at least this is how they call it in the code 01:07:56.740 |
So d model is the size of the vector of our embedding. So imagine we have an embedding size of 512 01:08:03.620 |
This is d model and the d inner can be chosen. You can choose for example double the size of the d model 01:08:13.460 |
Then they have a convolution here on this branch, and this is actually 01:08:16.820 |
to kind of mix up the tokens with each other, because otherwise the state space models would be running independently 01:08:24.040 |
for each dimension, but this convolution mixes up all these dimensions 01:08:28.520 |
And then we have these two SiLU activations. We have the state space 01:08:33.060 |
Model here that runs the recurrence using the parallel scan algorithm that I have just shown you before 01:08:40.660 |
and then we take the element-wise product of this branch and the output of the state space model. Then 01:08:47.140 |
there is another linear layer that will project back the d_inner, 01:08:50.980 |
so the inner dimension of this block, to the outer dimension, which is the d_model 01:08:55.140 |
So we go back to the 512 dimensions of the embedding size, if we have initially chosen d_model equal to 512 01:09:07.380 |
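Putting that description into code, here is a simplified sketch of the Mamba block as I reconstructed it from the explanation above; it is not the official implementation, the selective SSM is left as a placeholder, and names like d_inner just follow the ones mentioned before.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    def __init__(self, d_model=512, expand=2, d_conv=4):
        super().__init__()
        d_inner = expand * d_model                    # e.g. 1024 for d_model=512
        self.in_proj_x = nn.Linear(d_model, d_inner)  # main branch
        self.in_proj_z = nn.Linear(d_model, d_inner)  # gating branch
        # Depthwise convolution along the sequence: mixes nearby tokens per channel.
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv,
                              groups=d_inner, padding=d_conv - 1)
        self.out_proj = nn.Linear(d_inner, d_model)

    def ssm(self, x):
        # Placeholder for the selective state space model (the parallel scan).
        return x

    def forward(self, x):                      # x: (batch, seq_len, d_model)
        z = F.silu(self.in_proj_z(x))          # gate
        h = self.in_proj_x(x).transpose(1, 2)  # (batch, d_inner, seq_len)
        h = self.conv(h)[..., : x.shape[1]].transpose(1, 2)  # causal conv
        h = F.silu(h)
        h = self.ssm(h)                        # selective scan would go here
        return self.out_proj(h * z)            # element-wise gate, back to d_model
```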
And I drew this architecture by myself by analyzing the code so I didn't have time to make it very beautiful 01:09:13.380 |
But okay, it's very similar to what we do with the transformer. So we have our input it gets converted into embeddings 01:09:20.360 |
So it becomes a sequence of tokens, each token made up of an embedding of a certain size, let's say 512 01:09:26.680 |
And then we have many blocks like this one after another 01:09:31.780 |
We have n of them such that the output of one block is sent as input to the next one 01:09:36.660 |
And the Mamba block that I showed you in the previous slide is basically just this one 01:09:41.940 |
But they also include an RMS norm at the beginning and then a skip connection 01:09:46.900 |
that you can see here, and this is repeated N times 01:09:53.060 |
The final part is just like LLaMA and just like Mistral, because we have this RMS norm and then we have the linear layer that will project the output embedding 01:10:00.580 |
back into our vocabulary, and then we have a softmax, 01:10:04.260 |
which will indicate which token from our vocabulary 01:10:08.600 |
we need to choose as the next token, if we are modeling a language model 01:10:20.500 |
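And the overall language model, as just described, is embeddings followed by a stack of these blocks, each wrapped with an RMS norm and a skip connection, then a final norm and a linear layer into the vocabulary. A rough sketch that reuses the MambaBlockSketch above (the hyperparameters are placeholders):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class MambaLMSketch(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_layers=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.norms = nn.ModuleList([RMSNorm(d_model) for _ in range(n_layers)])
        self.blocks = nn.ModuleList([MambaBlockSketch(d_model) for _ in range(n_layers)])
        self.final_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens):                 # tokens: (batch, seq_len)
        x = self.embed(tokens)
        for norm, block in zip(self.norms, self.blocks):
            x = x + block(norm(x))             # pre-norm block with skip connection
        logits = self.lm_head(self.final_norm(x))
        return logits                          # a softmax over these picks the next token
```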
So as you remember, when we started talking about Mamba, it was introduced to solve the problems of the 01:10:27.300 |
selective copying task and the induction heads task, because 01:10:33.060 |
state space models were not performing very well on them 01:10:38.180 |
So they wanted to solve this problem with Mamba, and that's why they introduced the selective state space model 01:10:47.720 |
And we can see that the state space model, so the S4 model, so the structured state space model, 01:10:55.780 |
performs quite poorly on this selective copying task, but Mamba performs very well 01:11:02.900 |
in terms of accuracy. And in Mamba, basically, this layer here is called the S6 layer 01:11:09.540 |
And the s4 layer is the one described in the previous paper. So structure state space model 01:11:14.340 |
While on the induction heads we can see that also mamba is performing very well 01:11:19.780 |
So the accuracy of mamba you can see here is always 01:11:23.700 |
Nearly 100% actually it's 100% for sequence length that can reach 10 to the power of 6 01:11:29.620 |
So very, very long sequence lengths in comparison 01:11:33.380 |
For example, the transformer model with the absolute positional encoding or also rotary positional encodings 01:11:38.360 |
Start degrading in accuracy when the sequence length reaches a certain size 01:11:43.060 |
So they are quite good up to a few hundred tokens, 01:11:49.860 |
but they start degrading as soon as they reach thousands of tokens, while Mamba maintains a very 01:11:55.540 |
Consistent performance over even very long sequence length and this is very important for language modeling because 01:12:02.180 |
The prompts especially with the retrieval augmented generation, but also with chat applications, etc 01:12:06.980 |
they are becoming very long, so we want models that can perform well on very, very long sequences 01:12:17.280 |
Here we can also see the training performance, so the number of operations that we need to do to train a model to reach a certain perplexity 01:12:26.800 |
So mamba actually performs as good as the best transformer model that we have now 01:12:32.080 |
So transformer models like LLaMA and Mistral; this is the Transformer++ you can see here 01:12:38.560 |
And it performs very similarly to the best model that we have here. So it's a very good 01:12:45.420 |
competitor to the transformer, but as we saw in the previous slide it can scale much better for longer sequences 01:12:50.780 |
And this is why it became quite popular recently 01:12:53.660 |
Thank you guys for watching my video. I hope you learned a lot in this video 01:12:57.900 |
I wanted to make a video that was very descriptive and also very 01:13:02.140 |
technically detailed, because I wanted to derive all the formulations of Mamba. I just don't like to throw formulas at people 01:13:08.620 |
And mamba I think will be a very popular model in the future 01:13:12.460 |
Even if I think it has its own limitations. For example, 01:13:15.180 |
it's still a recurrent neural network, because it still runs like a recurrence, 01:13:19.120 |
and it may have its own limitations: for example, 01:13:22.460 |
we still don't know how well it performs on massive amounts of data, like the data that has been used for LLaMA 01:13:30.780 |
But I think people are looking for alternatives to the transformer, because the transformer has shown its limitations, especially for 01:13:38.300 |
scaling to very long sequence lengths, which are very much needed for language modeling, 01:13:43.260 |
but also with recent models for image generation, movie generation and audio generation 01:13:52.700 |
The memory required by the transformer is massive because the scaling is quadratic, 01:13:57.840 |
so it results in a really high memory consumption, and that's why 01:14:02.860 |
normal people cannot even run inference on a model like Mistral on their computer, unless they use model sharding 01:14:14.380 |
So I hope more research is done in this area. Thank you for watching my video 01:14:18.060 |
I hope you like this video and you will subscribe to my channel 01:14:21.900 |
Please share this video with your friends and share it on your social media. This is the best way to support me