
Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math


Chapters

0:00 Introduction
1:46 Sequence modeling
7:12 Differential equations (basics)
11:38 State Space Models
13:53 Discretization
23:08 Recurrent computation
26:32 Convolutional computation
34:18 Skip connection term
35:21 Multidimensional SSM
37:44 The HIPPO theory
43:30 The motivation behind Mamba
46:56 Selective Scan algorithm
51:34 The Scan operation
54:24 Parallel Scan
57:20 Innovations in Selective Scan
58:00 GPU Memory Hierarchy
61:23 Kernel Fusion
61:48 Activations recomputation
66:48 Mamba architecture
70:18 Performance considerations
72:54 Conclusion

Whisper Transcript

00:00:00.000 | Hello guys, welcome back to my channel. Today we are going to talk about Mamba
00:00:03.840 | So Mamba is a new model for sequence modeling that came out just one month ago in a paper called "Mamba: Linear-Time Sequence Modeling
00:00:11.040 | with Selective State Spaces". Let's review the topics of today. In the first part of the video
00:00:16.140 | I will be introducing what are sequence models and what kind of sequence modeling we can do
00:00:20.560 | The second part of the video I will be talking about state space models
00:00:25.820 | But to fully understand the state space models, we need to have a little background on differential equation
00:00:31.340 | I of course don't expect you to have this background, because in some
00:00:35.200 | bachelor's or master's degrees it is taught, but in most cases
00:00:40.220 | it's not taught
00:00:42.360 | So I will give you the necessary background to understand differential equations
00:00:46.040 | and later we will talk about state space models and we will derive the formula for the
00:00:51.040 | discretization, and we will also derive the formula for the convolutional computation and the recurrent computation
00:00:55.980 | I will show you what we mean by the HIPPO matrix and the importance of the A matrix in state space models
00:01:02.280 | In the third part of the video we will be talking about Mamba
00:01:06.220 | So what was the motivation that led to Mamba and what is the innovation of Mamba, which is the selective scan algorithm?
00:01:12.860 | So first of all, what do we mean by a scan operation, and we will see how this scan operation can be parallelized
00:01:20.080 | With the parallel scan we will see what is kernel fusion the recomputation of the activations and finally
00:01:26.080 | We will explore the architecture of Mamba and some performance consideration with respect to the transformer
00:01:31.860 | So as a prerequisite, I just hope that you have the basics of calculus
00:01:36.700 | I think high school mathematics will be more than enough, and
00:01:39.640 | that you have a basic understanding of the transformer model and neural networks in general. So let's start our journey
00:01:47.460 | The goal of a sequence model is to map an input sequence to an output sequence
00:01:51.940 | The input sequence can be a continuous signal in that case
00:01:55.100 | We want to map it to an output continuous signal or it can be a discrete input signal and we want to map it
00:02:01.360 | into a discrete output signal for example
00:02:04.040 | Continuous signal could be audio and a discrete signal could be text for example
00:02:08.600 | actually most of the time we work with the discrete signals even with the case of audio because we sample the audio file over time and
00:02:16.700 | if we talk about language modeling, we are talking about a discrete input because we have a finite number of tokens and
00:02:22.620 | We want to map it to an output sequence of tokens
00:02:26.740 | We can choose among many models to do sequence modeling let's review them
00:02:32.900 | The first model that comes to mind to do sequence modeling is the recurrent neural network
00:02:39.140 | Which is a network in which we have a hidden state and we compute the output as follows
00:02:45.220 | So we have for example our input sequence made up of x1 x2 and x3
00:02:51.140 | What we do is, the first time, the hidden state is initialized with zeros
00:02:57.720 | So we feed to the network the first hidden state,
00:03:01.760 | so zeros, along with the first input, and it will produce the first output
00:03:05.940 | Then we use the previously produced hidden state, so the one from the previous step
00:03:14.540 | It will produce a new hidden state and the new output token for a new
00:03:19.520 | Input token and this will be the y2 so the output number two
00:03:23.720 | So we use the previously generated hidden state
00:03:26.500 | Along with a new input token to produce a new output token and a new hidden state that will be used for the next
00:03:32.860 | Token as you can see
00:03:35.100 | this sequential generation of output is not parallelizable, because to generate the nth token
00:03:40.380 | we need the (n-1)th token
00:03:44.020 | So the training of this kind of model cannot be parallelized. And this is one of the reasons the transformer has been so successful
00:03:50.100 | However, the inference time is a constant for each token
00:03:54.880 | Which means that the effort,
00:03:56.680 | from a computational point of view and also from a memory point of view, that we need to put in to produce one
00:04:02.180 | token of the output is the same, no matter if it's the first token or the
00:04:06.980 | 100th token: the number of
00:04:10.740 | operations and computations that we are doing is always the same. We are taking the previous state and the current input to produce one output token
00:04:18.420 | Theoretically, this one has infinite context length because we can continue this sequence forever
00:04:25.200 | But practically we cannot because it suffers from a vanishing and exploding gradient
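A minimal sketch of this recurrence in NumPy, assuming a simple tanh cell with hypothetical weight matrices W_xh, W_hh and W_hy (not from the video, just to illustrate the constant per-token cost):

```python
import numpy as np

def rnn_forward(xs, d_hidden=16):
    """Run a toy RNN over a sequence: each step uses the previous hidden
    state and the current input, so the work per token is constant."""
    d_in = xs.shape[1]
    rng = np.random.default_rng(0)
    W_xh = rng.normal(size=(d_in, d_hidden)) * 0.1      # input-to-hidden weights (assumed)
    W_hh = rng.normal(size=(d_hidden, d_hidden)) * 0.1  # hidden-to-hidden weights (assumed)
    W_hy = rng.normal(size=(d_hidden, d_in)) * 0.1      # hidden-to-output weights (assumed)

    h = np.zeros(d_hidden)       # h0 is initialized with zeros, as described above
    ys = []
    for x in xs:                 # strictly sequential: step t needs h from step t-1
        h = np.tanh(x @ W_xh + h @ W_hh)
        ys.append(h @ W_hy)
    return np.stack(ys)

ys = rnn_forward(np.random.randn(5, 8))  # 5 tokens, 8 features each
print(ys.shape)                          # (5, 8)
```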
00:04:30.000 | Another model that we can use to do sequence modeling, but that is not very used recently,
00:04:36.620 | is the convolutional neural network, which is mostly used for computer vision tasks
00:04:42.340 | It has a finite context window and it needs to build a kernel that is run
00:04:48.040 | So this is the kernel that is run through the input to produce output features. So this is the output
00:04:53.380 | It is easily parallelizable because each output uses the same kernel
00:04:58.780 | So we can run the kernel in parallel on all the possible input windows to produce the output features
00:05:05.860 | The last model that we will see is the transformer
00:05:08.480 | The transformer is easily parallelizable when it comes to training because we have this self-attention mechanism
00:05:15.060 | that computes many dot products and since it's a matrix multiplication we can parallelize the operation and
00:05:21.140 | it has a finite context window defined by the sequence the input sequence and the attention mask and
00:05:30.220 | The inference of this kind of model, however, is not constant for each token. So in the transformer model, if we are
00:05:36.540 | producing the first token of the output, we will do one dot product
00:05:40.740 | by using the KV cache; if we are producing the 10th output token,
00:05:47.620 | We will need to do 10 dot products to produce it
00:05:50.540 | And if you are producing the 100th output token, we will need to do 100 dot products
00:05:56.740 | So the effort that we put to produce the first token is not the same as the effort that we need to produce the 10th
00:06:03.340 | token and this is not good because
00:06:05.780 | This doesn't allow us to scale to very long inputs also for the training because the training as you can see that scales
00:06:14.720 | Quadratically with the input sequence length, which means that if we double the sequence length
00:06:19.860 | We need to do four times more computation to train this model
00:06:23.300 | But it's easily parallelizable
00:06:26.340 | So in an ideal world we would like a model for which we can
00:06:31.140 | parallelize the training, just like the transformer, because we can exploit our GPU very well, and that can scale linearly to long sequences
00:06:39.300 | Just like the RNN because it scales linearly with the input sequence length
00:06:44.460 | And in an ideal world, we also would like to inference each token with a constant memory and the computation cost
00:06:51.940 | Which means that the effort we need to put to produce the first token should be the same as the effort that we put to
00:06:57.900 | Produce the 10th token or the 100th token
00:07:00.580 | Now let's explore state space models and how they can help us solve the problems of the recurrent neural networks and the transformer
00:07:09.060 | But first we need to review some maths
00:07:11.340 | Let me give you a very simple introduction to differential equation with a very simple example
00:07:18.580 | Imagine you have some bunnies and the population of these bunnies grows at a constant rate of lambda
00:07:23.900 | Proportional to the number of bunnies that you have which means that every bunny will give birth to lambda baby bunnies
00:07:30.220 | Now also suppose that the lambda is equal to 2. So let's write lambda equal to 2
00:07:35.300 | So we can see that the rate of change of the population
00:07:39.500 | Which means the number of babies that are born is equal to lambda
00:07:44.140 | Multiplied by the number of bunnies that you have at a particular time step t
00:07:48.340 | So which means also that the rate of change of this population because the number of babies that these bunnies produce is also the rate
00:07:55.060 | Of change of this population
00:07:56.540 | So how much the population is growing is equal to lambda multiplied by the number of bunnies that you have
00:08:00.860 | But if you remember from high school, what is the rate of change of a particular function?
00:08:05.900 | It is the derivative of this function
00:08:09.380 | so we can also say that the
00:08:12.460 | derivative of the function that describes the number of bunnies in time is
00:08:17.060 | Equal to lambda multiplied by the number of bunnies that you have at a particular time step t
00:08:21.980 | How can we find the population at time step
00:08:28.140 | 100 knowing that the population is made up of five bunnies at time equal to 0
00:08:34.180 | We need to find a function B of t
00:08:37.460 | So we need to find a function B of t that describes the population of our bunnies over time
00:08:43.820 | So the evolution of our population over time
00:08:46.220 | Solving a differential equation means to find a function, in this case B of t, that makes the expression
00:08:53.700 | above, so this expression here, true
00:08:58.700 | for all t. So we need to find a function B of t that, when replaced in this expression, makes the left hand side equal to the right
00:09:05.980 | hand side and
00:09:07.980 | This differential equation can be solved using a very simple method called the method of separation of variables
00:09:14.160 | So I will not show you here, but we can clearly see that a solution to this
00:09:18.300 | Differential equation is this function here B of t is equal to K multiplied by
00:09:23.580 | Exponential of lambda by t where K is equal to the initial population of bunnies
00:09:30.540 | How can we verify that this is actually the solution of this differential equation because if we replace this
00:09:36.260 | Function here inside this expression it will make the left hand side equal to the right hand side
00:09:41.380 | This is how we verify the solution of a differential equation
00:09:43.900 | So let's try to replace it first of all in the left hand side
00:09:47.580 | We have the derivative of this function with respect to t so let's calculate the derivative of this function with respect to t
00:09:53.260 | Which is equal to K
00:09:56.060 | multiplied by lambda multiplied by e to the power of lambda t, and
00:10:01.780 | This is equal to
00:10:04.380 | Lambda multiplied by the function itself, which is K
00:10:08.220 | multiplied by e to the power of lambda t
00:10:11.860 | You can see that these two expressions are the same so this function here
00:10:17.820 | So this function here is a solution to our differential equation
00:10:22.820 | So when you have a differential equation the solution of a differential equation is not a number
00:10:27.620 | Like when we solve a quadratic equation, so when you solve for example the
00:10:31.900 | Equation X to the power of 2 is equal to 4 you get some values of X that make the left-hand side
00:10:38.660 | Equal to the right-hand side in the case of a differential equation you find a function
00:10:43.900 | That makes the left-hand side equal to the right-hand side and usually we represent the differential equation by omitting the variable t
00:10:52.180 | And writing it as follows: B dot is equal to lambda
00:10:57.940 | multiplied by B. This dot indicates that it's the derivative of this function with respect to a certain variable, which is implied
00:11:08.180 | Basically the solution of this differential equation
00:11:11.420 | will describe the evolution of our bunnies, so this function here that we have found, B of t,
00:11:17.980 | Will describe how the population of our bunnies will grow
00:11:23.260 | We usually use the differential equations to model the state of a system over time
00:11:28.100 | With the goal of finding a function that gives us the state of the system at any time step
00:11:34.340 | Given the initial state of the system, and this is what we do with state space models
00:11:39.380 | A state space model allows us to map an input signal X of t to an output signal Y of t
00:11:45.980 | By means of a state representation H of t as follows
00:11:50.740 | So H prime of t, which means the derivative of this function H of t with respect to time, is equal to a
00:11:58.140 | number A (or a matrix A) multiplied by H of t, plus B
00:12:02.820 | multiplied by X of t, and then the output of the system is computed as Y of t equal to C multiplied by H of t plus D multiplied by X of t
00:12:08.260 | This state space model is linear and time invariant
00:12:12.860 | It's linear because the relationships in the expressions above are linear and the time invariant because the parameter matrices
00:12:20.020 | A, B, C and D do not vary over time because you can see they do not depend on the time
00:12:26.100 | They are always fixed for each time step
00:12:28.100 | So for now for simplicity, we will consider that all these parameters
00:12:33.980 | So A, B, C and D, but also the input the output and the H of t are numbers not vectors
00:12:39.500 | Later, we will expand our analysis to vectors
00:12:41.980 | Now you may be wondering
00:12:44.500 | Okay, we have these expressions, but how can I compute the output of this model given the input X of t?
00:12:51.940 | As you can see in the first expression, we have a differential equation and to compute the output
00:12:57.780 | Like this we need to have a function H of t that describes the state of our system at every time step t
00:13:05.700 | So to find the output Y of t of this
00:13:09.740 | This model we need to find first of all, we need to solve this differential equation here
00:13:14.300 | So it means to find a function H of t that describes the state of the system for all time step
00:13:20.260 | but solving this differential equation can be hard analytically and
00:13:24.860 | Usually also we do not work with the continuous signals
00:13:28.740 | So we do our input X of t usually is not continuous because when we do work with computer or any digital device
00:13:35.060 | We always work with discrete systems
00:13:37.740 | So one way is to actually discretize this system such that we can calculate an approximate
00:13:43.020 | Solution of this H of t in a discretized way. So not in a continuous way
00:13:48.220 | Let's see how we can discretize our system to compute then the output of the system itself
00:13:53.460 | To solve a differential equation we need to find the function H of t so in our case this function here
00:14:01.460 | So let me change to the pointer
00:14:03.940 | so a function H of t that, when replaced here, will make the left-hand side equal to the right-hand side.
00:14:10.740 | But finding this function H of t in the continuous domain
00:14:16.460 | can be very hard analytically, and also, because we work with discrete inputs,
00:14:21.460 | We can discretize this system, which means that we can approximate the solution of this differential equation
00:14:27.420 | So instead of finding H of t for all the time step we can find H of t for certain time steps
00:14:33.880 | Instead so for example, we may want to find H of t for H of 0, H of 1, H of 2, H of 3, etc
00:14:41.140 | In this case, we have chosen a step size the step size in my case
00:14:46.040 | I have chosen 1 because I'm evaluating the function H of t only in certain time steps
00:14:51.300 | that are separated by a step size of 1 so I go from 0 to 1 from 1 to 2 from 2 to 3 and
00:15:01.620 | This will be our delta
00:15:01.620 | Remember the bunnies problem. Let's try to find the approximate solution of this bunny problem using the Euler's method
00:15:07.900 | Before, I gave you the result of this bunnies problem using the analytical solution, which was very easy to compute
00:15:14.860 | Now let's try to build the same solution,
00:15:18.100 | or actually not the same, but the approximated solution, using Euler's method
00:15:24.100 | So first of all, let's rewrite the our bunny population model
00:15:28.180 | Which is the derivative of the function that describes the bunny population
00:15:33.220 | With respect to the variable time is equal to lambda multiplied by the function itself
00:15:39.180 | If you remember from high school
00:15:42.340 | What is the definition of the derivative the derivative is equal to the limit of a step?
00:15:47.860 | Going to 0 of the function evaluated at t plus the step
00:15:53.540 | Minus the function at time step t divided by the step size. This is the definition of derivative. So there is nothing new here
00:16:03.060 | The left-hand side is equal to the right-hand side when the delta, so the step size that we are using,
00:16:10.260 | is very, very small. So suppose that we choose a very small step size
00:16:15.980 | Then we can get rid of this equality and change it to an approximation, so we can say that,
00:16:22.980 | okay, if I choose a very small step size, then the left-hand side will be more or less equal to the right-hand side
00:16:30.700 | Now by multiplying with the Delta on both sides and then bringing this term to the right-hand side
00:16:38.260 | We can write the following formula here
00:16:41.160 | Which allow us to compute the value of the function B at the time step t plus Delta. So one step forward
00:16:48.600 | As the derivative of the function at time step t multiplied by Delta plus the function at time step t
00:16:56.540 | But what is the derivative of the function here?
00:17:02.700 | We know that the derivative of the function that we are trying to approximate is equal to the Lambda multiplied by the function itself
00:17:10.380 | So we can replace
00:17:12.820 | here in this
00:17:15.100 | Expression here we can replace our population model and we obtain the following
00:17:20.060 | Expression and this allow us to compute the bunnies population at time step t plus Delta
00:17:26.260 | So at the next time step given the bunnies population at the previous time step
00:17:31.540 | Multiplied by a Delta and the Lambda. So we obtained a recurrent formulation
00:17:37.540 | Wonderful now that we have our recurrent formulation to approximate the state of the bunny population over time
00:17:45.640 | Suppose we set Lambda equal to 2 which means that every bunny will make two children and Delta is equal to 1
00:17:52.980 | So we want to go forward by one by one step
00:17:55.980 | So t is equal to 0 then t equal to 1, t equal to 2, etc, etc
00:17:59.380 | For example, if we started with a population of 5 bunnies at time step 0
00:18:04.380 | We can calculate the evolution of the population as follows
00:18:08.040 | Knowing that the population at time step 0 we can calculate the population at time step 1 using the formula that we have
00:18:14.620 | So B of 1 is equal to
00:18:17.660 | delta multiplied by lambda multiplied by the population at the previous time step, just like this formula here,
00:18:24.660 | plus the population at the previous time step
00:18:26.540 | So this gives us a 15 and it makes sense why because we started with 5 bunnies. Each bunny has 2 children
00:18:32.900 | So they have 10 babies and the initial population, which is 5 plus 10 babies gives us a population of 15
00:18:40.140 | Now that we have the population at time step 1 we can calculate the population at the time step 2 which is
00:18:46.520 | 45 because we had 15 bunnies each one makes 2 children. So we have 30 children
00:18:53.360 | So 15 plus 30 is equal to 45
00:18:56.500 | Now that we have the population at time step 2 we can calculate the population at time step 3 and it's
00:19:01.860 | 135 because we have a 45 bunnies each one of them makes 2 children. So we have 90 children plus 45, which is
00:19:09.900 | 135 and
00:19:12.500 | If you compare the solution that we are obtaining for this time step with the analytical solution that we found before
00:19:19.820 | That I didn't show you how to find it, but it can be easily found using a method called the separation of variables
00:19:25.880 | If you compare the plot of these two functions
00:19:29.920 | You will see that the analytical function grows very fast while the approximated solution grows more slowly
00:19:35.340 | The two are similar in how they
00:19:39.440 | change, but they will not overlap,
00:19:43.320 | because the step size that we have chosen is very big
00:19:46.520 | I mean, we chose delta equal to 1; to make a better approximation
00:19:51.520 | we need to use a very small delta, and actually the Euler method does not give us very good
00:19:57.440 | results, very good approximations of the analytical solution
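A minimal sketch of this comparison, assuming lambda = 2 and an initial population of 5, running the Euler recurrence B(t + delta) = B(t) + delta * lambda * B(t) against the analytical solution B(t) = 5 * e^(lambda * t):

```python
import math

def euler_bunnies(lmbda=2.0, delta=1.0, b0=5.0, steps=3):
    """Euler approximation of dB/dt = lambda * B, i.e. B(t+delta) ~= B(t) + delta * lambda * B(t)."""
    b = b0
    for _ in range(steps):
        b = b + delta * lmbda * b
    return b

t_end = 3.0
analytical = 5.0 * math.exp(2.0 * t_end)        # B(t) = K * e^(lambda*t), with K = 5
coarse = euler_bunnies(delta=1.0, steps=3)      # step size 1, as above: gives 135
fine = euler_bunnies(delta=0.001, steps=3000)   # much smaller step size
print(f"analytical B(3)   = {analytical:10.1f}")  # ~2017.1
print(f"Euler, delta=1    = {coarse:10.1f}")      # 135.0, a poor approximation
print(f"Euler, delta=1e-3 = {fine:10.1f}")        # much closer to the analytical value
```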
00:20:01.980 | and we use a similar reasoning that we used for the bunny population to
00:20:08.380 | Discretize the system of the state space model. So how to discretize that system?
00:20:14.160 | As you remember, using
00:20:18.320 | the definition of the derivative we found out that a function evaluated at the time step t plus delta is
00:20:25.760 | more or less, so approximately, equal to
00:20:29.240 | delta multiplied by the derivative of the function at time step t, plus the function at time step t
00:20:36.560 | But what is the derivative of the function at time step t?
00:20:39.900 | We know from our state space model that the derivative of this function h prime is
00:20:44.760 | Equal to a multiplied by h of t plus b multiplied by x of t. So we just replace this
00:20:51.600 | expression here
00:20:58.920 | into this term here, and we obtain the following derivation
00:21:02.680 | So we replace this stuff here in place of h prime of t,
00:21:09.560 | we multiply it by delta, we collect the term h of t, and we obtain the
00:21:14.460 | parameters of the discretized model. So if we set A
00:21:17.160 | bar equal to
00:21:24.640 | the identity matrix plus delta A, and B bar equal to
00:21:28.260 | delta B, we will obtain a
00:21:32.480 | recurrent formula, just like with the bunny
00:21:42.440 | population problem, that allows us to compute the next state of the model given the previous state
00:21:45.280 | Actually, in the paper,
00:21:48.880 | they show first of all the continuous
00:21:53.880 | formulation of the state space model, which is this one, and then they show you the
00:22:01.360 | discretized model that you can see here, and they also show you how to obtain the
00:22:07.200 | parameters of the discretized model, which are this A bar and B bar. Actually, in the paper
00:22:07.200 | They do not use the Euler method. So the one that I used here to derive the discretized version
00:22:12.900 | They use a method called the zero-order hold and I believe they use this one because it results in a better approximation
00:22:18.960 | But the idea of the discretization is that instead of calculating the analytical solution of this differential equation here
00:22:27.360 | We calculate approximately what is the H state of the system at discrete time steps
00:22:33.680 | And then we can plug this
00:22:35.680 | this approximated
00:22:38.480 | State value into this second relationship to get the output of the system
00:22:44.920 | But in practice as you saw before we had to choose the delta parameter to discretize the system in practice
00:22:54.320 | We do not choose the discretization step delta, but we make it a parameter of the model that we will train with gradient descent
00:23:01.120 | So that the model will learn this parameter delta based on the input that it will see
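A small sketch of this discretization step with scalar parameters; the values of A, B and delta here are arbitrary, just to illustrate, and the zero-order hold formula is the commonly stated one, so treat it as an assumption rather than the exact code of the paper:

```python
import math

def discretize_euler(A, B, delta):
    """Euler rule from the derivation above: A_bar = 1 + delta*A, B_bar = delta*B (scalar case)."""
    return 1.0 + delta * A, delta * B

def discretize_zoh(A, B, delta):
    """Zero-order hold rule (the one used by S4/Mamba):
    A_bar = exp(delta*A), B_bar = (delta*A)^-1 * (exp(delta*A) - 1) * delta*B."""
    A_bar = math.exp(delta * A)
    B_bar = (1.0 / (delta * A)) * (A_bar - 1.0) * (delta * B)
    return A_bar, B_bar

A, B = -1.0, 0.5                 # arbitrary scalar parameters (assumptions, just for illustration)
for delta in (1.0, 0.1, 0.01):   # in Mamba, delta is learned by the model, not hand-picked
    print(delta, discretize_euler(A, B, delta), discretize_zoh(A, B, delta))
    # the two rules agree more and more closely as delta shrinks
```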
00:23:10.080 | Now that we have our recurrent formulation to calculate the output of the system sequentially
00:23:15.780 | How can we use it to calculate the output of the system for various time steps? Let's do it
00:23:21.600 | Suppose for simplicity that the initial state of the system is zero
00:23:25.600 | So the first state of the system, so suppose we have an input made up of x0, x1, x2, x3, x4
00:23:33.280 | So it's a discrete input. That's why we are using the discretized state space model
00:23:37.360 | we can compute the first state of the system which is equal to b multiplied by x0 because
00:23:43.040 | We are multiplying a by the previous state, but the previous state we say it's zero
00:23:47.920 | So this term will not be present and by using this state we can calculate the first output of the system
00:23:53.920 | Which is just c multiplied by h0
00:23:55.920 | Now we can calculate the next state of the system by using the previous state and the next input
00:24:01.840 | Just using the formula here
00:24:04.160 | So A bar multiplied by the previous state of the system
00:24:08.400 | Which we already computed plus b multiplied by the next input of the system
00:24:13.040 | Having the next state we can compute the next output which is y1 is equal to c multiplied by h1
00:24:19.280 | Now using the previous state we can compute the next state
00:24:23.120 | So h2, h2 is equal to a multiplied by the previous state plus b multiplied by the next input
00:24:28.880 | And using h2 we can compute y2. So the output at the next step
00:24:34.320 | And this is exactly what we used to do with the recurrent neural networks
00:24:38.880 | So that's why they are called the recurrent neural networks because they have this recurrent formulation to go from the previous step
00:24:44.960 | to the next step
00:24:47.280 | And as you can see, this is exactly what we did
00:24:51.040 | So our initial state is zero
00:24:53.600 | And we use our first input and the previous state to calculate the first output
00:24:58.960 | This will not only produce the first output, but also the next state of the system
00:25:04.240 | Then we use the next state of the system plus the next input to produce the next output and the next state of the system
00:25:10.640 | Then we use the next state of the system plus the next
00:25:13.840 | input to produce the next output and the next state of the system, etc, etc. Just like we saw before
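A minimal sketch of this recurrent computation with scalar, already-discretized parameters (the values are arbitrary assumptions), following exactly the two formulas h_k = A_bar * h_{k-1} + B_bar * x_k and y_k = C * h_k:

```python
def ssm_recurrent(xs, a_bar, b_bar, c):
    """Sequential (recurrent) mode: constant work and memory per token."""
    h = 0.0                         # initial state assumed to be zero, as above
    ys = []
    for x in xs:
        h = a_bar * h + b_bar * x   # h_k = A_bar * h_{k-1} + B_bar * x_k
        ys.append(c * h)            # y_k = C * h_k
    return ys

# toy scalar parameters (arbitrary, already discretized)
print(ssm_recurrent([1.0, 2.0, 3.0], a_bar=0.9, b_bar=0.5, c=2.0))
```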
00:25:19.600 | The recurrent formulation that we have just seen is great for inference
00:25:25.200 | Because we can compute one token at a time with a constant memory and computation requirement because the effort that we are
00:25:33.440 | Making to compute this output here is exactly the same as the effort that we take to
00:25:38.480 | Compute this output here. So the second output token
00:25:42.720 | And this makes it suitable for inference during a large language model because in large language model
00:25:48.880 | We generate one token at a time using the previous tokens and the prompt
00:25:53.120 | However, the recurrent formulation is not good for training because during training we already have all the tokens
00:26:00.080 | We do not need to generate the output of the model one token at a time
00:26:04.560 | We already have our sequence of input. We already have a sequence that is the target
00:26:10.000 | We want to compute the output of the model in parallel without doing this sequential
00:26:14.660 | Computation to calculate the loss and train the model and this is exactly what we do with the transformer
00:26:20.960 | But we cannot do it with this recurrent computation
00:26:25.180 | Thankfully the state space model also provide a convolutional mode as well, which can be parallelized. Let's see how it works
00:26:32.060 | So using just the recurrent formulation that we have seen before let's expand the output at every step
00:26:39.900 | So as you saw before to compute the first state of the model
00:26:43.500 | We need to take the initial input multiplied by b because we suppose that the previous state is zero. So there is no a
00:26:49.900 | This term here disappears in the first state
00:26:54.140 | We can use to use the first state to compute the first output, which is c multiplied by h zero
00:26:59.900 | and this gives us
00:27:03.900 | We can replace this h zero with its
00:27:09.980 | expanded expression, b multiplied by x zero
00:27:13.420 | We can then use h zero to compute h one
00:27:17.340 | So h one is equal to a multiplied by h zero plus b multiplied by x one
00:27:22.860 | But what is h zero h zero is exactly this stuff here. So b multiplied by x zero
00:27:32.620 | so we can compute the state of the system at
00:27:35.420 | time step one using
00:27:42.060 | this computation here, and the output of the system at time step one using h1
00:27:42.060 | And then we expand the expression of h1 to obtain this expression here
00:27:48.540 | We can use the similar reasoning to expand also the representation of the h2 space
00:27:54.780 | So that it only depends on the input and the parameters of the model and in the same way
00:27:59.980 | Also the computation of the output of the model at time step two using only the input and the parameters of the model
00:28:07.020 | As you can see we see a pattern
00:28:10.700 | So that to compute the kth token of the output we need to do the following
00:28:16.920 | summation. So for example, to compute y2, as you can see, we are multiplying x2 by CB
00:28:23.400 | So, to compute yk, we multiply xk by CB,
00:28:28.120 | then the previous token, so here x1, so in general xk minus one, multiplied by CAB,
00:28:35.000 | etc,
00:28:36.920 | and then each earlier token is multiplied by
00:28:39.880 | C times A, with the power by which A is raised increasing by one for each step back, times B
00:28:47.000 | And this is exactly what we say here: each earlier token with
00:28:50.360 | an increasing power of A, etc, etc, until we reach the kth power of A
00:28:55.880 | By using the formula we derived we note something interesting
00:29:01.660 | The output of the system can be calculated using a convolution of a kernel k with the input x
00:29:09.160 | So this formulation this formula that we derived here, which basically represents the pattern of all the expanded
00:29:16.460 | formulas that we saw before
00:29:19.880 | Using this this formula here
00:29:22.040 | We can build a kernel a kernel just like the kernel that we use in convolutional neural networks
00:29:27.480 | with the first term of the kernel being CB, the second being CAB, the third being CA squared B,
00:29:34.440 | etc, etc, until the last item of the kernel, which is CA to the power of k B, and we can convolve this kernel with the input to compute directly the output
00:29:42.120 | And as you remember the convolution is something that we can parallelize
00:29:46.700 | So that's why this is very powerful because if we build this kernel
00:29:51.240 | We can convolve it on the input to produce the output
00:29:55.080 | And I want to prove you that the output computed using the kernel like this is exactly the one we just found
00:30:02.280 | By expanding these formulas before so let's do it
00:30:05.720 | We build a kernel like this
00:30:09.480 | Cb cab etc, etc up to the ca to the power of kb
00:30:14.440 | Depending on what is the length of the sequence that we are producing
00:30:18.300 | I am just
00:30:20.760 | Inverting the kernel. So the first term is not a cb
00:30:23.560 | But it's the last term and the second this one here becomes the second last etc because it's easier to visualize
00:30:30.680 | Then we have our input which is x0 x1 x2 x3
00:30:34.840 | But I add some padding later. I will show you why and then we have our output that will be produced by this convolution
00:30:40.700 | So let's run this convolution step by step
00:30:43.480 | The first output of the convolution is just the kernel
00:30:48.760 | Slided over the first four tokens of the input
00:30:54.280 | And we multiply this term with this one this term of the kernel with the this term of the input this
00:31:01.240 | Term of the kernel with this term of the input this current term of the kernel with this term of the input
00:31:05.880 | And then we sum all these results, and the sum, as you can see, is equal to y0, which is
00:31:11.080 | x0 multiplied by CB, and this is exactly the result from this formula here
00:31:17.000 | So as you can see if we only have the first token
00:31:20.200 | It will be only cb multiplied by the token itself because we don't have any preceding token
00:31:25.320 | Let's go forward then we slide the kernel forward by one step and it will produce the following output
00:31:33.240 | So we have this item of the kernel multiplied by x1
00:31:36.760 | This item of the kernel multiplied by x0 and all the other multiplied by 0 so they will not contribute to the sum
00:31:42.440 | So y1 is equal to this here
00:31:45.560 | So as you can see, when k is equal to one, we have x1 multiplied by CB
00:31:51.560 | And then we have the previous term. So x0 multiplied by cab
00:31:55.900 | Then we slide our kernel forward
00:31:58.920 | And we can see that
00:32:01.480 | This term is multiplied by this this term is multiplied by this and this term is multiplied by this and the last one is multiplied
00:32:08.280 | By zero so it doesn't contribute to the sum
00:32:10.760 | So y2 is computed like this and as you can see if you compare it with this formula, you will see that
00:32:17.240 | for yk, so when k is equal to 2, it will be CB multiplied by x2, then CAB multiplied by x1, and then
00:32:28.120 | CA to the power of 2 multiplied by B, multiplied by x0, which is two positions before x2
00:32:37.080 | And so on for the fourth step, which you can verify by yourself
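A small sketch that builds the kernel [CB, CAB, CA^2 B, ...], convolves it with the input, and checks that the result matches the recurrent computation (scalar parameters with arbitrary values, just to illustrate the equivalence):

```python
import numpy as np

def ssm_conv(xs, a_bar, b_bar, c):
    """Convolutional mode: y_k = sum over j of C * A_bar^j * B_bar * x_{k-j}."""
    L = len(xs)
    kernel = np.array([c * (a_bar ** j) * b_bar for j in range(L)])  # [CB, CAB, CA^2 B, ...]
    ys = np.zeros(L)
    for k in range(L):                                # each output is independent (parallelizable)
        ys[k] = np.dot(kernel[:k + 1], xs[k::-1])     # kernel against the reversed prefix x_k, ..., x_0
    return ys

def ssm_recurrent(xs, a_bar, b_bar, c):
    h, ys = 0.0, []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c * h)
    return np.array(ys)

xs = np.array([1.0, 2.0, 3.0, 4.0])
print(ssm_conv(xs, 0.9, 0.5, 2.0))
print(ssm_recurrent(xs, 0.9, 0.5, 2.0))   # identical output, computed sequentially
```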
00:32:40.700 | The best thing about the convolution calculation is that it can be parallelized
00:32:46.220 | because the output yk does not depend on the previous output. So what I mean is, the
00:32:54.200 | product that we are doing here can be done on one thread, for example, or one core of the
00:32:59.960 | GPU and this one can be done also simultaneously because they do not depend on each other
00:33:05.800 | So the convolutional calculation can be parallelized
00:33:09.580 | but the problem is that
00:33:12.440 | we need to build the kernel for the convolution, and building this kernel can be a little
00:33:18.680 | Expensive from a computational point of view, but also from a memory point of view
00:33:22.200 | however
00:33:23.640 | We can still use the convolutional computation to perform the training
00:33:27.160 | Because we can parallelize the training because we already have the input token
00:33:31.240 | We already have the target so we can compute the output of the model in parallel for all the input
00:33:36.280 | tokens, even if it's expensive to do it computationally
00:33:39.740 | And then we can use the recurrent formulation to inference from this model one token at a time
00:33:45.800 | And we also know that by just doing this the computation cost for producing one token will always be the same
00:33:52.840 | No matter how which token we are producing if it's the first one or the 100th token or the 1000th token
00:34:00.120 | And this is different from the transformer in which
00:34:02.680 | If we want to generate tokens to produce the first token is less expensive than to produce the 100th token
00:34:09.320 | Because we need to do much more dot products when we generate the 100th token
00:34:14.680 | So with the KVCache we would do 100 dot products
00:34:17.800 | One thing that is worth mentioning is that in the paper they do not use the term d
00:34:23.720 | Here when they compute the output and the reason is that this can be thought of as a skip connection
00:34:30.060 | Let me show you why
00:34:31.740 | So imagine we have our input x which is the input of the system
00:34:36.540 | So suppose we have this x this is sent to some black box that we will call state space model
00:34:42.860 | That will do this calculation of the state recurrently. So continuously
00:34:48.880 | And it will produce an output y
00:34:52.220 | But then we can see that this D x of t term basically means that we take our input,
00:34:57.900 | We skip the state space model. We multiply it with some number d and we send it directly to the output
00:35:04.940 | So we do not need to model this d to model the state space
00:35:09.980 | Because it does not depend on the state of the system at any time step
00:35:14.140 | And this is why it can be represented as a skip connection. And this is why in the paper they do not mention it
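As a sketch, assuming a generic ssm function, the D term is literally just a skip connection around it:

```python
def ssm_block_with_skip(x, ssm, D):
    """y = SSM(x) + D * x: the D*x path bypasses the state entirely,
    which is why it can be treated as a skip connection and left out of the SSM itself."""
    return ssm(x) + D * x
```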
00:35:20.460 | In the analysis that we have done so far on this state space model, as I told you before,
00:35:27.100 | for simplicity we are considering that all the parameters are just single numbers,
00:35:31.500 | and the input is just a single number and the output is just a single number
00:35:37.500 | But usually when we work with language models, especially our input is not a single number
00:35:42.700 | but it's a vector because suppose we have a sequence of tokens each token is represented by a vector of
00:35:48.780 | 512 dimensions in the case of the vanilla transformer
00:35:53.020 | And maybe 4096 dimensions in the case of llama, for example
00:35:57.500 | So how can we work with the state space models, but with a vector input and a vector output?
00:36:04.620 | The idea is that each dimension of the input vector will be managed by an independent state space model
00:36:12.940 | So imagine we have an input token made up of a vector of
00:36:17.100 | 512 dimensions
00:36:19.740 | Which needs to produce an output token of 512 dimensions. We will create one state space model for each dimension
00:36:26.620 | So one for the dimension zero one for the dimension one dimension two, etc, etc for all the dimensions
00:36:32.000 | And all these state space models are independent from each other
00:36:36.220 | now this idea may look very strange and
00:36:39.820 | But it's not so strange if you think about the multi-head attention of the transformer because in the transformer
00:36:45.820 | We also have an input made up of vectors each one representing an embedding of a token
00:36:51.740 | Suppose we have 512 dimensions for each embedding
00:36:55.680 | Then the multi-head attention one basically groups some dimensions
00:37:00.320 | Suppose in the vanilla transformer we have eight heads
00:37:04.540 | So it means that the 512 is divided by eight which means that every
00:37:09.500 | Head is managing 64 dimensions and each head is independent from the others
00:37:15.820 | And so, if it works for the transformer, it can also work for the state space model, and actually it does
00:37:24.700 | the parameters
00:37:26.060 | a b c d
00:37:28.060 | The input and the y t have now become vectors and they will have the sequence dimension and also the
00:37:35.260 | D model dimension, which means that we have an input that is made up of suppose 512 dimensions
00:37:42.000 | Just like in the transformer model
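A shape-level sketch of this idea, with one independent SSM per input dimension; here each channel keeps a scalar state (i.e. a state size of 1), which is a simplification just to show that the channels never mix:

```python
import numpy as np

def multichannel_ssm(x, a_bar, b_bar, c):
    """x: (L, D) sequence of D-dimensional tokens; a_bar, b_bar, c: (D,) per-channel
    scalar parameters, i.e. one independent SSM for each of the D dimensions."""
    L, D = x.shape
    h = np.zeros(D)                 # one scalar state per channel (simplified state size)
    y = np.zeros((L, D))
    for t in range(L):              # channels never interact: each dimension evolves on its own
        h = a_bar * h + b_bar * x[t]
        y[t] = c * h
    return y

L, D = 6, 512                       # 512-dimensional tokens, as in the example above
x = np.random.randn(L, D)
y = multichannel_ssm(x, a_bar=np.full(D, 0.9), b_bar=np.full(D, 0.5), c=np.full(D, 2.0))
print(y.shape)                      # (6, 512)
```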
00:37:45.260 | Now in the recurrent formulation of the state space model we have this matrix a which is quite important
00:37:51.900 | because
00:37:53.740 | The a matrix in the state space model can be intuitively thought of as a matrix that captures the information of the past state
00:38:01.660 | So the h t minus one
00:38:03.820 | To build the new state the h of t
00:38:06.860 | It also decides how this information is copied forward in time
00:38:12.060 | So when we produce the output of a state space model at a particular time step for which means for a particular token
00:38:18.860 | It depends on the state at that time step, but the state at time step t depends on all the previous states
00:38:26.540 | So basically this a matrix tells us how the information of the model is conveyed forward in time
00:38:33.420 | So this h of t
00:38:35.420 | Basically captures all the history of the input of the model so far to build the new token
00:38:42.540 | This means that we need to be very careful about the structure of this a matrix
00:38:47.740 | Otherwise, it may not very well capture the history of all the inputs seen so far
00:38:52.780 | Which is needed to produce the next token
00:38:55.100 | And this is very important for language models because the next token generated by the model
00:38:59.660 | should depend on all the previous tokens, which are the prompt because when you
00:39:04.220 | Send a prompt to a language model, which is a list of tokens. The model should attend should
00:39:10.460 | Watch all the previous tokens to produce the new one
00:39:12.940 | And in the case of the state space model, this is taken care of by the matrix a so the structure of this matrix is very
00:39:20.140 | important
00:39:21.260 | And if you have some background in controls engineering, you may remember that in a state space model the a matrix
00:39:28.060 | Also determines the stability of the system
00:39:30.940 | Which means that the output of the system may diverge if the system is not stable
00:39:36.220 | But if you don't have a background in controls engineering, it doesn't matter
00:39:39.740 | So to make the a matrix behave well, the authors chose to use the hippo theory. Let's see how it works
00:39:45.980 | To give you an intuition of what the HIPPO theory does, I will borrow an analogy from the Fourier transform
00:39:53.420 | So as you remember the Fourier transformation allow us to decompose a signal
00:39:58.220 | So suppose this is our initial signal into a series of sinusoidal functions
00:40:03.180 | So this one this one and this one such that when we sum all these functions
00:40:09.500 | Each one with its respective amplitude value you can see here
00:40:14.380 | They will rebuild the original signal
00:40:17.100 | And we do something similar also in the state space model with the hippo theory
00:40:22.640 | But instead of using sinusoidal functions, we use the Legendre polynomials
00:40:28.540 | With the hippo theory we build the a matrix in such a way that it approximates all the input signals so far
00:40:35.980 | So let me show you here. So imagine this is our input signal and it evolves over time
00:40:42.140 | So x of t you can see here
00:40:44.220 | The goal of the a matrix is to capture information from this input signal and convey it forward for the next
00:40:52.780 | State so it captures the history of all these signals
00:41:03.340 | Suppose this is our state. If we use the state to rebuild the original signal using the HIPPO theory,
00:41:09.580 | we will have an output, so a reconstructed signal, that is very
00:41:15.580 | well approximated for the more recent time steps and not so
00:41:21.180 | well approximated for the earlier time steps. So as you can see, this approximation here is not very
00:41:27.420 | good, if you compare it with the original signal here, but the more recent part is very well reconstructed,
00:41:33.120 | as you can see here. And this is what the HIPPO matrix does: it builds a state space representation
00:41:33.120 | that captures well the information of recent tokens and
00:41:37.900 | Decays the information of the past token using some reasoning very similar to the exponentially moving average
00:41:45.280 | And this is very important for language modeling because in language models
00:41:50.300 | We have a prompt made up of tokens and we need to produce the next token
00:41:54.460 | So we need to build a hidden state because the hidden state will determine the output of the system
00:41:59.660 | So the next token so we need to build a hidden state that captures very well
00:42:03.500 | information about the local context, and we can afford to lose some information about the global context, so information about
00:42:10.940 | Tokens that are very far from the one we are producing
00:42:16.540 | And in the in a paper called efficiently modeling long sequences with structured state spaces
00:42:22.460 | Which is basically the paper of the S4 model
00:42:25.020 | The author says that just initializing the A matrix with the hippo matrix we can see here
00:42:31.980 | so this is how to build the hippo matrix, which is
00:42:34.540 | basically an n by n matrix in which
00:42:37.340 | all the values above the principal diagonal here are zero, and the other values are computed as follows
00:42:44.380 | For the diagonal they are computed as n plus one, where n is the row index and k is the column index
00:42:50.860 | If we just build the A matrix like this
00:42:55.900 | It will result in a very big increase in performance of the model
00:43:00.940 | Why? because the A matrix is the one that is responsible for capturing information about the previous state and carrying it over to the new state
00:43:08.860 | So we want to carry the information well in the new state
00:43:13.340 | So that the h of t which will become a vector
00:43:18.540 | If we have a multi-dimensional state space model
00:43:22.540 | Then this h of t captures very well the information of all the previous inputs so that we can produce the next output
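A hedged sketch of how such a matrix can be constructed, following the commonly cited HiPPO (LegS) form: zero above the diagonal, n + 1 on the diagonal, and square-root terms below it. Treat the exact signs and scaling as an assumption rather than a quote from the paper:

```python
import numpy as np

def hippo_legs(N):
    """Build an N x N HiPPO (LegS) matrix: zero above the diagonal, n+1 on the diagonal,
    sqrt(2n+1)*sqrt(2k+1) below it (row index n, column index k), negated overall."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = n + 1
    return -A    # the continuous SSM uses h'(t) = A h(t) + ..., so A is taken negative for stability

print(hippo_legs(4))
```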
00:43:30.300 | Okay, now it's time to talk about Mamba and let's talk about the motivation that led to the building of Mamba
00:43:37.820 | And in the paper the the authors they describe two tasks on which the vanilla state space model
00:43:43.340 | So the state space model that we have described up to now or even the S4 model
00:43:47.820 | Which is the structured state space model
00:43:49.740 | Which is basically just the state space model with a very rigid structure on the A matrix like the one we saw before
00:43:55.820 | So these two models they do not perform well on two specific tasks
00:43:59.900 | One is the selective copying and one is the induction heads
00:44:03.580 | So let's introduce these tasks the copying task basically means that we have some input tokens. So this one's here blue
00:44:11.180 | orange red and
00:44:13.740 | Green and the model has to produce the same outputs but time shifted
00:44:18.480 | And this actually can be done by the vanilla state space model
00:44:22.460 | Because it can be actually done with a simple convolution and the convolution can learn the time shifting that we are doing
00:44:31.420 | However, the selective copying task means that we have some input tokens, for example blue, white, white,
00:44:38.300 | orange, red
00:44:41.180 | and then green, and the model needs to produce only the
00:44:45.820 | colored tokens, so not the white tokens
00:44:48.780 | This one cannot be done by the vanilla state space model, because the vanilla state space model
00:44:54.460 | cannot do content-aware reasoning, because the parameters of the model are the same for every step,
00:45:00.940 | so the same for every token, so it will treat each token equally
00:45:04.860 | So it cannot, say, distinguish between
00:45:07.500 | the blue token and the white token and just ignore the white one and keep the colored one and not
00:45:14.460 | the white ones
00:45:16.540 | So to give you an intuition of what this could mean, for example, imagine we are given a
00:45:22.620 | comment on Twitter and we want to rewrite this comment by removing all the bad words
00:45:28.220 | So all the white tokens you can see here and this one cannot be done by the state space model
00:45:32.780 | Because the parameters a b c and d are the same for each input and they cannot do they cannot change for each particular input
00:45:40.380 | The second task on which state space models have a difficulty is the induction head, which means that the model needs to recall
00:45:49.340 | information from the previous
00:45:51.640 | inputs to build the current output, and they show for example this example:
00:45:56.300 | every time the model sees
00:45:58.380 | the black token, it should output the blue token by recalling what it has seen before
00:46:04.540 | But because the model cannot use content-aware reasoning, it cannot perform well on this task
00:46:13.740 | This is the motivation that led to the creation of mamba
00:46:16.780 | And this is the reference in the paper in which they talk about these two tasks
00:46:21.020 | So from the recurrent view, the recurrent dynamics,
00:46:23.840 | so the transitions given by the A and B matrices, cannot let the model select the correct information from the current context,
00:46:31.340 | because the A and B matrices are the same for each input token. And also from the convolutional view,
00:46:37.340 | it is known that the model is able to solve the copying task because it can learn the time shift,
00:46:43.660 | but they have a difficulty with the selective
00:46:47.020 | Copying task because the parameters of the model are the same for each input
00:46:51.740 | So they do not know how to treat differently a particular input
00:46:55.660 | Now let's talk about the innovation of mamba and how it differs from the state space model
00:47:03.020 | Let's first review. What is the algorithm of the state space model?
00:47:06.620 | So the state space model basically has an input and it has to produce an output
00:47:11.180 | The input is this one. So it's a tensor of b l and d dimension
00:47:16.460 | which is a batch dimension sequence dimension and the d model and this is the
00:47:20.620 | Exactly the same as the transformer because we have a batch of prompts
00:47:24.780 | Each prompt is made up of l tokens
00:47:27.340 | Which is our sequence length and each token is made up of a vector of d model dimension
00:47:31.980 | Which is our d here and it has to produce an output of the same dimension
00:47:36.940 | Just like the transformer
00:47:38.620 | We have an a matrix which is a matrix of parameters that indicates how to copy the previous state into the new state
00:47:46.060 | And we model the state as a vector of n dimensions. So this will be a vector
00:47:52.700 | of n dimensions
00:47:55.180 | So the A matrix is d by n, which is basically because,
00:48:01.100 | when we have an input vector made up of d dimensions, we have
00:48:07.880 | d state space models independent from each other, one for each dimension, just like we saw before
00:48:13.240 | So this parameter matrix is d by n
00:48:16.120 | This b matrix is also d by n. The c matrix is also d by n
00:48:21.720 | And the delta is the step size of the discretization which is learned by the model as I showed before
00:48:28.280 | So we don't decide it. We just let the model learn this
00:48:31.880 | Step size and because we have a d state space model because our input vector is d dimensional
00:48:38.200 | So we have a d number of delta
00:48:40.760 | And to discretize we just apply the formula that we saw before so depending which
00:48:47.160 | Discretization rule we are using if we use the Euler method or we use the zero order hold in the case of mamba
00:48:53.400 | They use the zero order hold method
00:48:55.800 | And then we have this discretized parameter
00:48:58.040 | So a bar and b bar and we can run the ssm as a recurrence or a convolution depending if we are
00:49:04.840 | Training it or inferencing
00:49:07.240 | So when we are training it
00:49:08.280 | We run it as a convolution using building the kernel like I showed you before and running it on the input like I showed before
00:49:14.120 | Or as a recurrence using the formula we can see here. So this formulation here
00:49:19.000 | So we compute the next step the next state using the previous state and then we use each state to build the output
00:49:25.400 | In the case of mamba they make the state space model selective which means that the parameters of the model
00:49:33.720 | Are changed for each input token. Let's see. The input is still bld. So we have a sequence of
00:49:42.680 | Batch of prompts each prompt is made of l tokens. Each token is made up of d dimensions and the output has the same shape
00:49:49.960 | Then we have the a matrix which is d by n
00:49:56.680 | The B matrix basically is produced
00:50:04.600 | by a linear layer, this s_B, and this s_B will project the d dimension into the n dimension, as you can see here
00:50:04.600 | This means that basically now the b matrix is not the same for all input tokens because now we have this l dimension here
00:50:12.360 | Which means that for every input token of the each batch
00:50:16.040 | We will have a different b matrix
00:50:18.520 | And the same goes for the c matrix you can see here and the same goes for the delta you can see here
00:50:25.240 | For each input token, we will have a different delta
00:50:28.920 | And the discretized
00:50:32.280 | Matrices have this dimension. So bld n
00:50:35.240 | And we can run the state space model
00:50:38.920 | However, because now the model is not a time invariant anymore because the parameters of the model
00:50:45.800 | Change for each input. So for each token for each step for each time step
00:50:51.640 | We can only run it as a
00:50:54.120 | Recurrence so we can only apply this formulation here
00:50:57.240 | We cannot compute it with the convolution anymore because the kernel will not be fixed
00:51:02.200 | Because before because the parameters of the model are fixed for all the inputs
00:51:06.680 | We can just build a kernel and run it for all the inputs
00:51:09.720 | but now for every input we should use a different kernel so we cannot compute it as a
00:51:15.000 | Convolution we are only forced to compute it as a recurrence
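A minimal, shape-level sketch of this selective recurrence; it is a simplification of mine, not the fused kernel from the paper: the input-dependent projections for B, C and delta and the softplus on delta are assumptions about the general shape of the computation.

```python
import numpy as np

def selective_scan(x, A, W_B, W_C, W_dt):
    """x: (L, D). A: (D, N). W_B, W_C: (D, N) projections; W_dt: (D,) for a per-token delta.
    Because B, C and delta depend on each token, only the recurrent (scan) mode is possible."""
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))                          # one length-N state per channel
    ys = np.zeros((L, D))
    for t in range(L):                            # sequential scan over the L tokens
        B_t = x[t] @ W_B / D                      # per-token B_t, C_t, delta_t (simplified projections)
        C_t = x[t] @ W_C / D
        dt_t = np.log1p(np.exp(x[t] * W_dt))      # softplus keeps the step size positive (assumption)
        A_bar = np.exp(dt_t[:, None] * A)         # ZOH-style discretization of A, per token
        B_bar = dt_t[:, None] * B_t[None, :]      # simplified (Euler-style) discretization of B
        h = A_bar * h + B_bar * x[t][:, None]     # h_t = A_bar * h_{t-1} + B_bar * x_t
        ys[t] = h @ C_t                           # y_t = C_t . h_t, per channel
    return ys

L, D, N = 8, 4, 3
rng = np.random.default_rng(0)
out = selective_scan(rng.normal(size=(L, D)), -np.abs(rng.normal(size=(D, N))),
                     rng.normal(size=(D, N)), rng.normal(size=(D, N)), rng.normal(size=(D,)))
print(out.shape)  # (8, 4)
```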
00:51:23.240 | Have you noticed that the authors talk about this scan operation you can see here
00:51:27.960 | So what is this scan operation to when evaluating the model as a recurrence? Let's talk about it
00:51:36.440 | So if you have ever done competitive programming you are familiar with the prefix sum array
00:51:41.640 | Which is an array calculated sequentially such that the value at each position indicates the sum of all the previous values
00:51:49.240 | We can easily compute it with the for loop in linear time
00:51:53.320 | So imagine we have this initial array
00:51:55.720 | We can calculate the prefix sum as like this the first value is equal to the first value
00:52:01.000 | the second value is computed as the
00:52:04.040 | previous value of this array, so this one, plus the current value of the initial array. So
00:52:09.480 | this one is equal to this one plus this one
00:52:13.160 | And this one is computed using the previous value plus the current value of the initial array
00:52:18.600 | And this one is computed using the previous value plus the current value of the initial array
00:52:23.400 | And this one is computed using the previous value plus the current
00:52:27.320 | element of the array
00:52:29.880 | Such that each item of this prefix sum indicates the sum of all the items of the initial array up to that
00:52:36.840 | element. So the number 32 is the result of 10 plus 7 plus 6 plus 9
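A trivial sketch of this sequential computation:

```python
def prefix_sum(values):
    """Each output element is the sum of all input elements up to that position,
    computed sequentially in linear time: out[i] = out[i-1] + values[i]."""
    out = []
    running = 0
    for v in values:
        running += v
        out.append(running)
    return out

print(prefix_sum([10, 7, 6, 9]))  # [10, 17, 23, 32] -- the last entry is 10 + 7 + 6 + 9 = 32
```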
00:52:45.160 | The scan operation refers to computing an array like the prefix sum in which each value can be computed using the
00:52:54.120 | previously computed value and the current input value
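Using the numbers from the example above, a minimal prefix-sum loop looks like this:

```python
def prefix_sum(values):
    # Each output element is the sum of all input elements up to that position.
    out, running = [], 0
    for v in values:
        running += v          # previous prefix sum + current input value
        out.append(running)
    return out

print(prefix_sum([10, 7, 6, 9]))  # [10, 17, 23, 32]  ->  32 = 10 + 7 + 6 + 9
```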
00:52:58.840 | And the recurrent formula of the state space model can also be thought as a scan operation in which each state
00:53:05.800 | is the sum of the previous state multiplied by an A matrix plus
00:53:10.600 | The current input multiplied by the B matrix
00:53:13.980 | So if the model input is X0, X1, X2, X3, X4 and X5
00:53:19.400 | We can compute for example H0 using only X0
00:53:23.640 | And then we can compute H1 using the previously computed value plus the current input each one multiplied by the A matrix and the B matrix
00:53:31.960 | Then we can compute H2 using H1
00:53:35.400 | multiplied by A plus X2 multiplied by B and H3 can be computed as H2
00:53:42.680 | multiplied by A plus X3 multiplied by B, etc, etc, etc
00:53:48.680 | So to generate the output we just multiply H_k by the C matrix to produce the output token k
00:53:55.800 | So if we have this scan output, so if we build a
00:53:59.720 | Array like this one, so this array you can see here
00:54:03.560 | We can easily compute the output of the model for each time step
00:54:08.200 | By just multiplying each of these values with the C matrix
00:54:11.880 | So we multiply this one by C, this one by C, and this one here by C to compute
00:54:19.400 | This is Y1
00:54:21.080 | This is Y2, etc, etc, etc
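The same pattern, written as a tiny sequential scan over the state space recurrence (the matrices here are made-up toy values, already assumed to be discretized):

```python
import numpy as np

def ssm_scan(A_bar, B_bar, C, xs):
    """h_k = A_bar h_{k-1} + B_bar x_k, then y_k = C h_k, computed step by step."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_k in xs:                     # one step per input token
        h = A_bar @ h + B_bar @ x_k    # previous state in, next state out
        ys.append(C @ h)               # each output only needs the current state
    return np.array(ys)

A_bar = np.array([[0.9, 0.0], [0.0, 0.5]])   # toy 2x2 state matrix
B_bar = np.array([[1.0], [1.0]])             # toy input matrix (N=2, input size 1)
C = np.array([[1.0, 1.0]])                   # toy output matrix
print(ssm_scan(A_bar, B_bar, C, [np.array([1.0])] * 4))
```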
00:54:24.120 | Now what if I told you that the scan operation that I have shown you can be parallelized
00:54:31.580 | Of course, you will not believe me because the scan operation is one of those operations that naturally looks like a sequence
00:54:37.560 | So to compute the current value, I need to have the previous value plus the current input
00:54:42.760 | So how can I parallelize an operation like this? Actually, it can be parallelized
00:54:47.100 | As long as the operations that we are doing are associative
00:54:50.780 | Meaning that they satisfy the associative property
00:54:53.800 | Now if you remember from elementary school or middle school when they teach you about the properties of the addition and the multiplication
00:55:00.460 | You may recall that the associative property means that if you have
00:55:04.120 | An operation on three operands, for example A multiplied by B multiplied by C
00:55:09.400 | It does not matter the order in which you do these operations
00:55:12.780 | So it does not matter where you put the parentheses, the result will be the same
00:55:16.520 | So you can do A multiplied by B and then the result multiplied by C
00:55:19.960 | Or you can do A multiplied by the result of B multiplied by C
00:55:23.560 | So this is the associative property
00:55:25.800 | So as long as the operations that we are doing have this property, then we can parallelize the scan operation
00:55:31.960 | I want to show you how to actually do it practically
00:55:36.040 | So imagine the initial array is this we can create multiple threads each one computing
00:55:42.120 | A sum in parallel. So for example, we can have something like this
00:55:47.800 | Okay, this is actually a picture I took from Wikipedia, and it's made for a 16-element input array,
00:55:54.120 | But imagine we have eight threads
00:55:56.280 | The first thread will compute the summation of the first two elements
00:55:59.400 | the second thread will compute the summation of the third and the fourth element,
00:56:04.680 | the third thread that of the fifth and the sixth, etc., etc. And then we use the result of this
00:56:08.760 | summation to compute the next step
00:56:11.640 | And then in parallel and then we use it to compute the next step in parallel and then to compute the next step in parallel
00:56:20.440 | This first phase is called the up-sweep, and then we have a down-sweep operation to rebuild all the
00:56:27.240 | Values that we didn't compute. So for example, this value here is not computed until the last step
00:56:33.400 | Now,
00:56:35.160 | By doing a parallel scan,
00:56:36.680 | Basically, we can decrease the time complexity of the scan operation: from O(N) for the sequential version,
00:56:43.080 | It's reduced to roughly O(N / T), where T is the number of parallel threads that are computing this operation
00:56:50.280 | And in my GitHub repository, I have also put an Excel file that shows you
00:56:56.440 | How this scan operation is computed step by step, so I actually show you all the intermediate steps,
00:57:03.480 | In case you want to understand how this works
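For the prefix sum, one simple way to organize the work in a logarithmic number of parallel passes is the following (this is a Hillis-Steele style scan, not necessarily the exact scheme of the Wikipedia figure or of the Mamba kernel; it just shows that, thanks to associativity, every update within a pass is independent and could run on its own thread):

```python
import numpy as np

def log_step_scan(values):
    """Inclusive prefix sum in ceil(log2(N)) passes; inside each pass, all the
    element updates are independent of each other, so a GPU could run them in parallel."""
    out = np.array(values, dtype=np.float64)
    offset = 1
    while offset < len(out):
        # out[i] += out[i - offset] for every i >= offset, conceptually all at once
        out[offset:] = out[offset:] + out[:-offset]
        offset *= 2
    return out

print(log_step_scan([10, 7, 6, 9]))  # [10. 17. 23. 32.]
```

For the SSM recurrence the combined operation is not a plain sum: each step carries the pair (A_bar_k, B_bar_k x_k), and composing two steps gives (a1, b1) followed by (a2, b2) -> (a2 a1, a2 b1 + b2), which is still associative, so the same parallelization trick applies.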
00:57:05.480 | Now that we know that this scan operation can be done in parallel: this is actually what the authors do. They also do the
00:57:12.360 | computation of the
00:57:15.080 | Recurrence to calculate the output of the model in parallel to reduce its time complexity
00:57:20.220 | Basically since Mamba cannot be evaluated using a convolution because it's time varying
00:57:25.960 | So it means that the parameters of the model are different for each time step
00:57:30.120 | Our only way of computing the output is to use the recurrent formulation
00:57:33.800 | But thanks to the parallel scan algorithm, this can be parallelized to reduce its time complexity
00:57:38.620 | The authors also indicate some techniques that they have used to make this algorithm faster
00:57:43.480 | The first technique that they show you is the kernel fusion
00:57:47.000 | The second is the parallel scan which I have already shown and then we have the
00:57:50.840 | Recomputation of the activations, which we will show later. So let's see all these techniques one by one
00:57:56.280 | But first let's see how the memory hierarchy of the GPU works
00:57:59.960 | The GPU basically it's a very fast calculator. So it's a very very very big
00:58:06.200 | computational unit that can do a lot of operations in parallel
00:58:11.080 | And it has two main memories
00:58:13.880 | The one that you are mostly familiar with, the one that you actually check when you buy a GPU is called the DRAM
00:58:20.600 | So it's the high bandwidth memory and it's in the order of gigabytes
00:58:24.520 | And then the GPU also has a smaller memory, a local memory that is called the SRAM
00:58:29.560 | The difference between the two is that, first of all, the SRAM is much, much smaller:
00:58:34.600 | It's in the order of megabytes
00:58:36.680 | And however, this is where the GPU will do the computation
00:58:40.600 | So when the GPU needs to do some matrix multiplication
00:58:42.920 | It will first of all copy the information from the high bandwidth memory to the SRAM
00:58:47.020 | Then the core of the GPU will access the information in the SRAM to do the computation
00:58:53.720 | And then the result will be saved back to this high bandwidth memory
00:58:59.400 | Actually, if we check the data sheet of a GPU, in this case, this is the NVIDIA A100
00:59:04.300 | You will see that the GPU is very fast at computing operations
00:59:09.320 | But the copying of information from the SRAM to the DRAM is not very fast, not as fast as computing operations
00:59:16.940 | So as you can see here, for example
00:59:19.560 | The copying speed is much slower:
00:59:25.160 | So this is like two terabytes per second compared to the number of operations that the GPU can do
00:59:30.120 | In this case, it can do 20 tera floating point operations per second of 32 bit
00:59:36.440 | So this parameter here is basically 40 times faster than this parameter here
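As a quick sanity check of that factor of 40, using the two figures just quoted (roughly 20 TFLOPS in FP32 and roughly 2 TB/s of memory bandwidth, i.e. 0.5 tera FP32 values moved per second):

```latex
\frac{20 \times 10^{12}\ \text{FP32 ops/s}}
     {\left(2 \times 10^{12}\ \text{bytes/s}\right) / \left(4\ \text{bytes per FP32 value}\right)}
= \frac{20 \times 10^{12}}{0.5 \times 10^{12}} = 40
```

So the GPU can do on the order of 40 floating-point operations in the time it takes to move a single value in or out of the high bandwidth memory.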
00:59:42.360 | This also means that when we create an algorithm that runs on the GPU, so a CUDA kernel,
00:59:48.740 | Sometimes the kernel may run slowly not because we are doing a lot of operations
00:59:53.940 | But maybe because we are copying a lot of stuff around which results in a slow overall computation
00:59:59.800 | And if this happens, we see that our operation is I/O bound because it's bounded by the I/O speed, by the copying speed
01:00:07.300 | Not by the computation speed of the GPU
01:00:09.460 | Now the authors exploit this memory hierarchy of the GPU to make their algorithm,
01:00:16.900 | So the selective scan, faster. So the main idea is to leverage the properties of modern accelerators,
01:00:22.740 | So the GPU, to materialize the state h, so the hidden state, only in the more
01:00:29.380 | Efficient levels of the memory hierarchy, so only in the SRAM, the smaller memory
01:00:34.420 | Concretely instead of preparing the scan input because they compute the recurrence as a scan operation
01:00:40.340 | Just like I showed you before. So the scan input has shape (B, L, D, N): the batch size, the sequence length, D (the size of the input vector), and N,
01:00:54.660 | In the GPU high bandwidth memory, so in the DRAM, they load all the parameters of the state space model,
01:01:01.860 | So the delta, the A matrix, the B matrix, and the C matrix directly from the highest
01:01:06.100 | Bandwidth memory, so the DRAM, into the fast SRAM. They perform the discretization in the SRAM
01:01:13.140 | The recurrence also, so the scan operation is also done in this SRAM and finally the result of this scan is
01:01:19.700 | Written back to the high bandwidth memory
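To make the order of those steps concrete, here is a plain (and slow) PyTorch reference of the discretize-then-scan computation described above; the shapes are the ones mentioned earlier, the B_bar discretization is the simplified delta-times-B form, and the comments mark what the real fused kernel would keep in SRAM instead of materializing as separate tensors:

```python
import torch

def selective_scan_reference(u, delta, A, B, C):
    """u, delta: (batch, L, D); A: (D, N); B, C: (batch, L, N). Returns y: (batch, L, D).
    Sequential reference only -- not the fused CUDA kernel."""
    batch, L, D = u.shape
    # Discretization (the fused kernel does this in SRAM, never writing dA / dBu to HBM):
    dA = torch.exp(delta.unsqueeze(-1) * A)                       # (batch, L, D, N)
    dBu = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, L, D, N), simplified B_bar = delta * B
    # Scan over the sequence (also kept in SRAM by the fused kernel):
    h = u.new_zeros(batch, D, A.shape[1])
    ys = []
    for t in range(L):
        h = dA[:, t] * h + dBu[:, t]                              # recurrent state update
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))             # y_t = C_t h_t -> (batch, D)
    return torch.stack(ys, dim=1)                                 # only this output goes back to HBM
```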
01:01:22.900 | And they also make use of what is known as a kernel fusion
01:01:28.660 | So what is kernel fusion? When we perform a tensor operation, our deep learning framework
01:01:34.420 | So PyTorch: it will load the tensor. Suppose we are doing a matrix multiplication. It will load the tensor
01:01:42.900 | From the slow memory to the fast memory, so from the DRAM to the SRAM of the GPU
01:01:47.460 | It will perform the operation, for example the matrix multiplication, and then it will save back the result
01:01:52.580 | From the SRAM to the DRAM. So from the SRAM to the high bandwidth memory of the GPU
01:01:57.860 | However, what if we do three operations on the same tensor in sequence?
01:02:03.620 | The deep learning framework will do something like this. So it will load first of all the tensor from the
01:02:10.260 | High bandwidth memory to the SRAM. It will compute the first operation
01:02:13.700 | Which means calling the CUDA kernel associated with the first operation and then save the result back to the high bandwidth memory
01:02:20.580 | Then it will load again the result of the previous computation from the high bandwidth memory into the fast memory
01:02:26.500 | compute the second operation which means
01:02:28.900 | Launching the second CUDA kernel associated with this operation and then it will save the result of this operation back to the high bandwidth memory
01:02:36.100 | then it will
01:02:38.740 | Load the result of the previous computation again from the high bandwidth memory into the fast memory
01:02:43.700 | compute the third operation and then save back the result into the high bandwidth memory. As you can see
01:02:49.620 | The total time in this case, when we have three operations in sequence, is
01:02:55.460 | Dominated by the copying operations that we are performing from the
01:02:58.900 | high bandwidth memory to the fast memory and from the fast memory back to the high bandwidth memory because we know that
01:03:05.300 | The GPUs are relatively slow at copying
01:03:07.620 | data compared to computing operations
01:03:10.760 | So kernel fusion means that to make a sequence of operations faster
01:03:15.300 | We can fuse all these CUDA kernels
01:03:17.700 | So the three operations that we are doing in sequence into one custom CUDA kernel such that
01:03:23.220 | We don't copy the intermediate results to the high bandwidth memory
01:03:27.460 | But we keep doing these operations in the fast memory until we have done all these three computations
01:03:32.420 | And then only the last result is saved into the high bandwidth memory
01:03:36.180 | This speeds up the overall computation because we don't have the intermediate copy operations
01:03:42.120 | Because they would result in an I/O bound
01:03:44.820 | algorithm
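As a small PyTorch illustration of the same idea (the three chained element-wise operations are arbitrary examples of mine): run eagerly, each op is its own kernel with its own round trip through the high bandwidth memory, while a compiler such as torch.compile in PyTorch 2.x can usually fuse a chain like this into a single generated kernel, which is the same effect the hand-written Mamba kernel achieves for its own sequence of operations.

```python
import torch

def three_ops(x):
    a = torch.relu(x)   # eager mode: kernel launch 1, intermediate written to HBM
    b = a * 2.0         # kernel launch 2, reads and writes the intermediate again
    return b + 1.0      # kernel launch 3

# Fused version: the whole element-wise chain can become one kernel, so the
# intermediates a and b never leave the fast on-chip memory.
three_ops_fused = torch.compile(three_ops)

x = torch.randn(1024, 1024)
assert torch.allclose(three_ops(x), three_ops_fused(x))
```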
01:03:47.920 | Okay, the last innovation of this selective scan algorithm is the recomputation of the activations. Let's see what it is
01:03:55.060 | So when we train a deep learning model, this model gets converted into a computation graph
01:04:00.580 | When we perform a back propagation in order to calculate the gradients at each node of this computation graph
01:04:06.340 | We need to cache the output values of each node that we have done during the forward step
01:04:11.700 | So imagine we have a very simple model like the one I show you here
01:04:15.380 | Let me show with a pointer. So this model basically computes the output using just a linear operation
01:04:21.940 | So x1 is multiplied by this
01:04:23.940 | Parameter w1 plus x2 multiplied by this parameter w2 plus a bias
01:04:30.260 | Suppose we have done the forward process and it has produced at each node its own value
01:04:34.900 | During the back propagation our goal is to calculate the gradient of the loss function with respect to each parameter
01:04:42.840 | Here so with respect to w1 with respect to w2 and with respect to the bias
01:04:48.020 | to compute the
01:04:50.500 | The gradient of the loss function with respect to the w1 for example, I show you the step here
01:04:57.060 | And to compute this gradient we need to also compute the gradient of all the intermediate nodes
01:05:02.900 | And to compute the gradient of the intermediate nodes
01:05:06.100 | We need to have the values of the activations of each node that we had during the forward step
01:05:11.940 | So for example to compute the gradient of the loss function with respect to this node y_pred
01:05:16.900 | Which result in the expression 2 multiplied by y_pred minus 2 multiplied by target. We need to cache the
01:05:24.340 | Value y_pred that we had during the forward step
01:05:28.020 | And these activations can actually occupy a lot of memory in a very big network. And this is why
01:05:34.660 | In the paper they talk about
01:05:37.780 | Recomputing them. So since caching the activations and then reusing them during back propagation means that we need to save them
01:05:46.340 | To the high bandwidth memory, so the slow memory, and then copy them back from the slow memory during back propagation,
01:05:52.900 | It may be faster to just recompute them during back propagation, because
01:05:57.560 | The GPU is much faster at computing operations than it is at copying data,
01:06:02.820 | So just recomputing them can be faster than copying them
01:06:06.100 | So this is the reference in the paper in which they describe this technique
01:06:09.780 | So they say finally we must also avoid saving the intermediate states which are necessary for back propagation
01:06:15.380 | So the intermediate states are all the activations of all the nodes of this computation graph
01:06:20.420 | So we carefully apply the classic technique of recomputation to reduce the memory requirements. So the intermediate states are not stored
01:06:27.140 | They are not stored in the high bandwidth memory
01:06:29.780 | But recomputed during the backward pass when the inputs are loaded from the high bandwidth memory to the fast SRAM
01:06:35.380 | So basically it's faster to just redo the calculations than to copy this information to the high bandwidth memory and then
01:06:42.340 | Reload it from the high bandwidth memory into the fast memory
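PyTorch exposes this same trade-off as gradient (activation) checkpointing; here is a minimal sketch (this is the generic PyTorch utility applied to an arbitrary block of mine, not the hand-rolled recomputation inside the fused Mamba kernel):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(8, 512, requires_grad=True)

y = block(x)                                         # normal forward: intermediate activations are cached
y_ckpt = checkpoint(block, x, use_reentrant=False)   # checkpointed forward: activations are NOT stored,
y_ckpt.sum().backward()                              # they are recomputed here, during the backward pass
```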
01:06:48.820 | Now let's look at the block that makes up Mamba in the Mamba architecture
01:06:53.860 | So first I introduce what is the Mamba block and then we will show you all the Mamba architecture
01:06:58.360 | So Mamba is built by stacking multiple layers of this Mamba block that we can see here
01:07:03.620 | And this is very similar to the stacked layer of the transformer
01:07:07.140 | So if you remember the transformer model
01:07:08.660 | We have the encoder and the decoder and the encoder side and the decoder sides are made by stacking these blocks
01:07:16.420 | With the self-attention and the feed-forward network,
01:07:19.220 | Multiple times, one on top of the other, such that the
01:07:22.500 | Output of one block is sent as input to the next block and the output of the last block is sent to the output
01:07:29.380 | Of the model and this is exactly what we do with Mamba in which we create
01:07:33.620 | First of all, we have our input which will be sent to this block and this block will be repeated many times such that the
01:07:41.060 | Output of this block here will become the input of the next block
01:07:46.420 | At the beginning of this block we have two linear layers that convert the size
01:07:52.180 | d_model into d_inner, at least this is how they call it in the code
01:07:56.740 | So d model is the size of the vector of our embedding. So imagine we have an embedding size of 512
01:08:03.620 | This is d model and the d inner can be chosen. You can choose for example double the size of the d model
01:08:10.180 | And this is just a linear projection
01:08:13.460 | Then they have a convolution here on this branch and this is actually
01:08:16.820 | To mix up the tokens with each other, because otherwise the state space model would be running independently
01:08:24.040 | For each dimension, and this convolution mixes information across nearby tokens
01:08:28.520 | And then we have these two SiLU activations. We have the state space
01:08:33.060 | Model here that runs the recurrence using the parallel scan algorithm that I have just shown you before
01:08:40.660 | And then we take the element-wise product of this branch and the output of the state space model. Then
01:08:47.140 | There is another linear layer that will project back the d_inner,
01:08:50.980 | So the inner dimension of this block to the outer dimension, which is the d model
01:08:55.140 | So we go back to the 512 dimensions of the embedding size initially if we have chosen the model equal to 512
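Putting that description together, here is a rough PyTorch sketch of one such block (the class name, the kernel size of 4, and the placeholder ssm method are my own assumptions; the real implementation computes the selective scan there, so this is a structural illustration, not the official code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    """Two input projections, a causal depthwise conv that mixes nearby tokens,
    SiLU activations, the selective SSM, element-wise gating, and an output projection."""
    def __init__(self, d_model: int, d_inner: int, d_conv: int = 4):
        super().__init__()
        self.in_proj = nn.Linear(d_model, 2 * d_inner)       # main branch + gate branch
        self.conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                                groups=d_inner, padding=d_conv - 1)
        self.out_proj = nn.Linear(d_inner, d_model)           # back from d_inner to d_model

    def ssm(self, x):
        # Placeholder: the real block runs the selective scan here
        # (h_k = A_bar h_{k-1} + B_bar x_k with per-token B, C, delta).
        return x

    def forward(self, x):                                          # x: (batch, L, d_model)
        main, gate = self.in_proj(x).chunk(2, dim=-1)              # both (batch, L, d_inner)
        main = self.conv1d(main.transpose(1, 2))[..., : x.size(1)] # causal conv over time
        main = F.silu(main.transpose(1, 2))
        y = self.ssm(main)                                         # selective SSM on the main branch
        y = y * F.silu(gate)                                       # element-wise gating
        return self.out_proj(y)                                    # (batch, L, d_model)

print(MambaBlockSketch(512, 1024)(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])
```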
01:09:07.380 | Now let's see the entire architecture of Mamba
01:09:07.380 | And I drew this architecture by myself by analyzing the code so I didn't have time to make it very beautiful
01:09:13.380 | But okay, it's very similar to what we do with the transformer. So we have our input it gets converted into embeddings
01:09:20.360 | So it becomes a sequence of tokens each token made up of an embedding of size. Let's say 512
01:09:26.680 | And then we have many blocks like this one after another
01:09:31.780 | We have n of them such that the output of one block is sent as input to the next one
01:09:41.940 | And the Mamba block that I showed you in the previous slide is basically just this one,
01:09:41.940 | But they also include a rms norm at the beginning and then a skip connection
01:09:46.900 | You can see here and this is repeated n times
01:09:50.420 | finally, there is an rms norm just like
01:10:00.580 | LLaMA and just like Mistral, and then we have the linear layer that will project the output embedding
01:10:04.260 | Back into our vocabulary, and then we have a softmax,
01:10:08.600 | Which will indicate which token from our vocabulary
01:10:08.600 | We need to choose as the next token if we are modeling a language model
01:10:16.420 | And this is the architecture
01:10:18.020 | Of Mamba, guys
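And here is a rough sketch of that overall wiring (RMS norm plus skip connection around each block, a final RMS norm, and a linear head over the vocabulary; the mixer is left pluggable so the snippet runs on its own, and in practice it would be the Mamba block sketched above):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class ResidualBlock(nn.Module):
    """RMS norm + skip connection wrapped around a mixer (the Mamba block)."""
    def __init__(self, d_model: int, mixer: nn.Module):
        super().__init__()
        self.norm, self.mixer = RMSNorm(d_model), mixer

    def forward(self, x):
        return x + self.mixer(self.norm(x))

class MambaLMSketch(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, mixer_factory):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [ResidualBlock(d_model, mixer_factory(d_model)) for _ in range(n_layers)]
        )
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids):              # (batch, seq_len) of token ids
        x = self.embedding(token_ids)           # (batch, seq_len, d_model)
        for layer in self.layers:
            x = layer(x)                        # output of one block feeds the next
        return self.lm_head(self.norm_f(x))     # logits; a softmax over them picks the next token

# Example wiring; nn.Linear stands in for the Mamba block sketched earlier.
model = MambaLMSketch(vocab_size=1000, d_model=512, n_layers=4,
                      mixer_factory=lambda d: nn.Linear(d, d))
print(model(torch.randint(0, 1000, (2, 16))).shape)   # torch.Size([2, 16, 1000])
```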
01:10:18.020 | So let's look also at the performance
01:10:27.300 | So as you remember, when we started talking about Mamba: Mamba was introduced to solve the problems of the
01:10:30.260 | Selective copying task and the induction heads task, because
01:10:30.260 | we saw that the
01:10:33.060 | State space models were not performing very well in
01:10:35.780 | context aware reasoning
01:10:44.740 | So they wanted to solve this problem with Mamba, and that's why they introduced the selective state space model
01:10:44.740 | with their selective scan algorithm
01:10:51.640 | And we can see that the state space model, so the S4 model, so the structured state space model,
01:10:55.780 | Performs quite poorly on this selective copying but mamba performs very well
01:11:00.580 | So it has a 99.8%
01:11:02.900 | Of accuracy and the mamba basically this layer here is called the s6 layer
01:11:09.540 | And the S4 layer is the one described in the previous paper, so the structured state space model
01:11:14.340 | While on the induction heads we can see that also mamba is performing very well
01:11:19.780 | So the accuracy of mamba you can see here is always
01:11:23.700 | Nearly 100% actually it's 100% for sequence length that can reach 10 to the power of 6
01:11:29.620 | So very very very very very long sequence length in comparison
01:11:33.380 | For example, the transformer model with the absolute positional encoding or also rotary positional encodings
01:11:38.360 | Starts degrading in accuracy when the sequence length reaches a certain size,
01:11:43.060 | So these models are quite good up to a few hundred tokens
01:11:49.860 | But they start degrading as soon as they reach the thousands of tokens, but mamba maintains a very
01:11:55.540 | Consistent performance over even very long sequence length and this is very important for language modeling because
01:12:02.180 | The prompts especially with the retrieval augmented generation, but also with chat applications, etc
01:12:06.980 | They are becoming very long. So we want models that can perform well on very very long sequence
01:12:12.420 | And we can also see here that the model
01:12:17.280 | Performance so the number of operations that we need to do to train a model to reach a certain perplexity
01:12:23.460 | Is very comparable with the transformer
01:12:32.080 | So Mamba actually performs as well as the best transformer model that we have now
01:12:32.080 | So the transformer model like lama and mistral. This is the transformer plus plus you can see here
01:12:38.560 | And it performs very similarly to the best model that we have here. So it's a very good
01:12:50.780 | Competitor to the transformer, but as we saw in the previous slide it can scale much better for longer sequences
01:12:50.780 | And this is why it became quite popular recently
01:12:53.660 | Thank you guys for watching my video. I hope you learned a lot in this video
01:12:57.900 | I wanted to make a video that was very descriptive and also very
01:13:02.140 | Technically detailed, because I wanted to derive all the formulations of Mamba. I just don't like to throw formulas at people
01:13:08.620 | And mamba I think will be a very popular model in the future
01:13:12.460 | Even if I think it has its own limitations, for example
01:13:15.180 | It's still a recurrent neural network, because it still runs as a recurrence,
01:13:19.120 | And it may have its own limitations. For example,
01:13:22.460 | We still don't know how well it performs on massive amounts of data like data that has been used for lama or for
01:13:28.940 | mistral
01:13:30.780 | But I think people are looking for alternatives to the transformer, because the transformer has shown its limitations, especially for
01:13:38.300 | Scaling to very long sequence length which are very much needed for language modeling
01:13:43.260 | But also with recent models for image generation movie generation and audio generation
01:13:47.920 | and so
01:13:50.140 | also, the computational complexity of the
01:13:52.700 | Transformer is massive, because the scaling is quadratic with the sequence length,
01:13:57.840 | so it results in a really high memory consumption and that's why
01:14:02.860 | Normal people cannot even run inference with a model like Mistral on their computer, unless they use model sharding
01:14:09.120 | and very advanced techniques
01:14:11.580 | so I hope that
01:14:14.380 | More research is done in this area. So thank you for watching my video
01:14:18.060 | I hope you like this video and you will subscribe to my channel
01:14:21.900 | Please share this video with your friends and share it on your social media. This is the best way to support me
01:14:27.020 | Thank you and have a nice day