Back to Index

Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math


Chapters

0:0 Introduction
1:46 Sequence modeling
7:12 Differential equations (basics)
11:38 State Space Models
13:53 Discretization
23:8 Recurrent computation
26:32 Convolutional computation
34:18 Skip connection term
35:21 Multidimentional SSM
37:44 The HIPPO theory
43:30 The motivation behind Mamba
46:56 Selective Scan algorithm
51:34 The Scan operation
54:24 Parallel Scan
57:20 Innovations in Selective Scan
58:0 GPU Memory Hierarchy
61:23 Kernel Fusion
61:48 Activations recomputation
66:48 Mamba architecture
70:18 Performance considerations
72:54 Conclusion

Transcript

Hello guys, welcome back to my channel. Today we are going to talk about Mamba So Mamba is a new model for sequence modeling that came out just one month ago in a paper called Mamba linear time sequence modeling With selective state spaces. Let's review the topics of today. In the first part of the video I will be introducing what are sequence models and what kind of sequence modeling we can do The second part of the video I will be talking about state space models But to fully understand the state space models, we need to have a little background on differential equation I of course don't expect you to have this background because in some Bachelor degree or some master degree it is taught but in some most of the cases It's not taught So I will give you the necessary background to understand differential equations and later we will talk about state space models and we will derive the formula for the Magnetization and we will also derive the formula for the convolutional computation and the recurrent computation I will show you what do we mean by the hippo matrix and the importance of the A matrix in state space models In the second the third part of the video we will be talking about Mamba So what was the motivation that led to the to Mamba and what is the innovation of Mamba, which is the selective scan algorithm?

So first of all, what do we mean by scan operation and we will see what that this scan operation can be parallelized With the parallel scan we will see what is kernel fusion the recomputation of the activations and finally We will explore the architecture of Mamba and some performance consideration with respect to the transformer So as a prerequisite, I just hope that you have a basic basics of calculus I think a high school mathematics will be more than enough and You have a basic understanding of the transformer model and neural networks in general.

So let's start our journey The goal of a sequence model is to map an input sequence to an output sequence The input sequence can be a continuous signal in that case We want to map it to an output continuous signal or it can be a discrete input signal and we want to map it into a discrete output signal for example Continuous signal could be audio and a discrete signal could be text for example actually most of the time we work with the discrete signals even with the case of audio because we sample the audio file over time and if we talk about language modeling, we are talking about a discrete input because we have a finite number of tokens and We want to map it to an output sequence of tokens We can choose among many models to do sequence modeling let's review them The first model that comes to mind to do sequence modeling is the recurrent neural network Which is a network in which we have a hidden state and we compute the output as follows So we have for example our input sequence made up of x1 x2 and x3 What we do is the first time the hidden sequence we initialize the hidden state is initialized with zeros So we feed to the network the hidden state the first hidden state So zeros along with the first input and it will produce the first output Then we use the previously produced the hidden state.

So the output of the previous step It will produce a new hidden state and the new output token for a new Input token and this will be the y2 so the output number two So we use the previously generated hidden state Along with a new input token to produce a new output token and a new hidden state that will be used for the next Token as you can see This is sequential generation of output is not parallelizable because to generate the nth token We need the n-1th token So the training of this kind of model cannot be parallelizable.

And this is one of the reason the transformer has been so successful However, the inference time is a constant for each token Which means that the number of the effort from a computational point of view and also from a memory point of view that we need to put in order to produce one token of the output is the same doesn't matter if it's the first token or the 100th token the effort with that we are doing so the number of Operations and computations that we are doing is always the same.

We are taking the previous state and the current input to produce one output token Theoretically, this one has infinite context length because we can continue with this sequence sequence of forever But practically we cannot because it suffers from a vanishing and exploding gradient Another model that we can use to do sequence modeling, but it's not very used recent It's not very used.

It's the convolutional neural network, which is mostly used for computer vision tasks It has a finite context window and it needs to build a kernel that is run So this is the kernel that is run through the input to produce output features. So this is the output It is easily parallelizable because each output uses the same kernel So we can run the kernel in parallel on all the possible input windows to produce the output features The last model that we will see is the transformer The transformer is easily parallelizable when it comes to training because we have this self-attention mechanism that computes many dot products and since it's a matrix multiplication we can parallelize the operation and it has a finite context window defined by the sequence the input sequence and the attention mask and The inference of this kind of model, however is not constant for each token.

So in the transformer model if we are Producing what the first token of the output we will do one dot product By using the KVCache if we are using the we are producing the 10th output token We will need to do 10 dot products to produce it And if you are producing the 100th output token, we will need to do 100 dot products So the effort that we put to produce the first token is not the same as the effort that we need to produce the 10th token and this is not good because This doesn't allow us to scale to very long inputs also for the training because the training as you can see that scales Quadratically with the input sequence length, which means that if we double the sequence length We need to do four times more computation to train this model But it's easily parallelizable So and in an ideal world that we would like the model for which we can Parallelize the training just like the transformer because we can exploit our GPU very well and it can scale linearly to long sequences Just like the RNN because it scales linearly with the input sequence length And in an ideal world, we also would like to inference each token with a constant memory and the computation cost Which means that the effort we need to put to produce the first token should be the same as the effort that we put to Produce the 10th token or the 100th token Now let's explore state space models and how they can help us solve the problems of the recurrent neural networks and the transformer But first we need to review some maths Let me give you a very simple introduction to differential equation with a very simple example Imagine you have some bunnies and the population of these bunnies grows at a constant rate of lambda Proportional to the number of bunnies that you have which means that every bunny will give birth to lambda baby bunnies Now also suppose that the lambda is equal to 2.

So let's write lambda equal to 2 So we can see that the rate of change of the population Which means the number of babies that are born is equal to lambda Multiplied by the number of bunnies that you have at a particular time step t So which means also that the rate of change of this population because the number of babies that these bunnies produce is also the rate Of change of this population So how much the population is growing is equal to lambda multiplied by the number of bunnies that you have But if you remember from high school, what is the rate of change of a particular function?

It is the variable is the derivative of this function so we can also say that the derivative of the function that describes the number of bunnies in time is Equal to lambda multiplied by the number of bunnies that you have at a particular time step t How can we find the population at time step 100 knowing that the population is made up of five bunnies at time equal to 0 We need to find a function B of t So we need to find a function B of t that describes the population of our bunnies over time So the evolution of our population over time Solving a differential equation means to find a function in this case B of t that makes the expression above so this expression here true For all t so we need to find a function t that when replaced in this expression makes the left hand side equal to the right hand side and This differential equation can be solved using a very simple method called the method of separation of variables So I will not show you here, but we can clearly see that a solution to this Differential equation is this function here B of t is equal to K multiplied by Exponential of lambda by t where K is equal to the initial population of bunnies How can we verify that this is actually the solution of this differential equation because if we replace this Function here inside this expression it will make the left hand side equal to the right hand side This is how we verify the solution of it equation So let's try to replace it first of all in the left hand side We have the derivative of this function with respect to t so let's calculate the derivative of this function with respect to t Which is equal to K Multiplied by lambda multiplied by A to the power of lambda t and This is equal to Lambda multiplied by the function itself, which is K multiplied by e to the power of lambda t You can see that these two expressions are the same so this function here So this function here is a solution to our differential equation So when you have a differential equation the solution of a differential equation is not a number When we like when we solve a quadratic equation so when you solve for example the Equation X to the power of 2 is equal to 4 you get some values of X that make the left-hand side Equal to the right-hand side in the case of a differential equation you find a function That makes the left-hand side equal to the right-hand side and usually we represent the differential equation by omitting the variable t And writing it as follows so B t so B dot is equal to lambda Multiplied by B.

This dot indicates that it's derivative of this function with respect to a certain variable, which is implied and basically the result of this differential equation will Will describe the evolution of our bunnies so this function here that we have found B of t Will describe how the population of our bunnies will grow and We usually use the differential equations to model the state of a system over time With the goal of finding a function that gives us the state of the system at any time step Given the initial state of the system, and this is what we do with state space models A state space model allow us to map an input signal X of t to an output signal Y of t By means of a state representation H of t as follows So H prime of t which means the derivative of this function H of t with respect to time is equal to a number A or a matrix A multiplied by H of t plus B Multiplied by X of t and then the output of the system is computed as follows This state space model is linear and time invariant It's linear because the relationships in the expressions above are linear and the time invariant because the parameter matrices A, B, C and D do not vary over time because you can see they do not depend on the time They are always fixed for each time step So for now for simplicity, we will consider that all these parameters So A, B, C and D, but also the input the output and the H of t are numbers not vectors Later, we will expand our analysis to vectors Now you may be wondering Okay, we have these expressions, but how can I compute the output of this model given the input X of t?

As you can see in the first expression, we have a differential equation and to compute the output Like this we need to have a function H of t that describes the state of our system at every time step t So to find the output Y of t of this This model we need to find first of all, we need to solve this differential equation here So it means to find a function H of t that describes the state of the system for all time step but solving this differential equation can be hard analytically and Usually also we do not work with the continuous signals So we do our input X of t usually is not continuous because when we do work with computer or any digital device We always work with discrete systems So one way is to actually discretize this system such that we can calculate an approximate Solution of this H of t in a discretized way.

So not in a continuous way Let's see how we can discretize our system to compute then the output of the system itself To solve a differential equation we need to find the function H of t so in our case this function here So let me change to the pointer so this function a function H of t that when replaced here will make the left-hand side equal to the right-hand side and But solving this finding this function H of t in the continuous domain It can be very hard analytically and also because we work with the discrete input We can discretize this system, which means that we can approximate the solution of this differential equation So instead of finding H of t for all the time step we can find H of t for certain time steps Instead so for example, we may want to find H of t for H of 0, H of 1, H of 2, H of 3, etc In this case, we have chosen a step size the step size in my case I have chosen 1 because I'm evaluating the function H of t only in certain time steps that are separated by a step size of 1 so I go from 0 to 1 from 1 to 2 from 2 to 3 and This is will be our Delta so Remember the bunnies problem.

Let's try to find the approximate solution of this bunny problem using the Euler's method So before I give you the result of this bunnies problem using the analytical solution, which was very easy to compute But let's try to build the same solution But using the Euler's method or actually not the same, but it's the approximated solution using the Euler's method So first of all, let's rewrite the our bunny population model Which is the derivative of the function that describes the bunny population With respect to the variable time is equal to lambda multiplied by the function itself If you remember from high school What is the definition of the derivative the derivative is equal to the limit of a step?

Going to 0 of the function evaluated at t plus the step Minus the function at time step t divided by the step size. This is the definition of derivative. So there is nothing new here now This this is the the left-hand side is equal to the right-hand side when the Delta so the step size that we are using It's very very very very very small.

So suppose that we choose a very very very small step size So then we can get rid of this equality and change it to an approximate approximation so we can see that Okay, what if I choose a very small step size, then the left-hand side will be more or less equal to the right-hand side Now by multiplying with the Delta on both sides and then bringing this term to the right-hand side We can write the following formula here Which allow us to compute the value of the function B at the time step t plus Delta.

So one step forward As the derivative of the function at time step t multiplied by Delta plus the function at time step t But what is the derivative of the function here? We know that the derivative of the function that we are trying to approximate is equal to the Lambda multiplied by the function itself So we can replace here in this Expression here we can replace our population model and we obtain the following Expression and this allow us to compute the bunnies population at time step t plus Delta So at the next time step given the bunnies population at the previous time step Multiplied by a Delta and the Lambda.

So we obtained a recurrent formulation Wonderful now that we have our recurrent formulation to approximate the state of the bunny population over time Suppose we set Lambda equal to 2 which means that every bunny will make two children and Delta is equal to 1 So we want to go forward by one by one step So t is equal to 0 then t equal to 1, t equal to 2, etc, etc For example, if we started with a population of 5 bunnies at time step 0 We can calculate the evolution of the population as follows Knowing that the population at time step 0 we can calculate the population at time step 1 using the formula that we have So B of 1 is equal to Lambda Delta multiplied by Lambda multiplied by the population at the previous time step.

So just like this formula here Plus the population at the previous time step So this gives us a 15 and it makes sense why because we started with 5 bunnies. Each bunny has 2 children So they have 10 babies and the initial population, which is 5 plus 10 babies gives us a population of 15 Now that we have the population at time step 1 we can calculate the population at the time step 2 which is 45 because we had 15 bunnies each one makes 2 children.

So we have 30 children So 15 plus 30 is equal to 45 Now that we have the population at time step 2 we can calculate the population at time step 3 and it's 135 because we have a 45 bunnies each one of them makes 2 children. So we have 90 children plus 45, which is 135 and If you compare the solution that we are obtaining for this time step with the analytical solution that we found before That I didn't show you how to find it, but it can be easily found using a method called the separation of variables If you compare the plot of these two functions You will see that the analytical function grows very fast while the approximated solution grows very slow and this is actually and the two are similar in how they Change but they will not overlap Because the step size that we have chosen is very big I mean we chose lambda equal to delta equal to 1 to make a better approximation We need to use a delta very very small and actually the Euler method does not give us very good Results very very good approximations of the analytical solution and we use a similar reasoning that we used for the bunny population to Discretize the system of the state space model.

So how to discretize that system? we just as you remember we found using the The definition of the derivative we found out that a function evaluated at the time step t plus delta is more or less so approximately equal to Delta multiplied by the derivative of the function at time step t plus the function at time step t But what is the derivative of the function at time step t?

We know from our state space model that the derivative of this function h prime is Equal to a multiplied by h of t plus b multiplied by x of t. So we just replace this expression here Into this term here and we obtain the following derivation So we replace this stuff here with h prime of t We multiply it by delta we collect this term h of t and we obtain the following Discretized the parameters of the discretized model.

So if we set a bar equal to the identity matrix plus delta a and b bar equal to Delta B. We will obtain a Recurrent formula just like with the bunny Problem the bunny population problem that allow us to compute the next state of the model given the previous state and Actually in the paper this So in the paper they show first of all the continuous Formulation of the state space model, which is this one and then they show you the discrete Discretized model that you can see here and they also show you how to obtain this discretized parameter So the parameters of the discretized model, which is this a bar and b bar actually in the paper They do not use the Euler method.

So the one that I used here to derive the discretized version They use a method called the zero-order hold and I believe they use this one because it results in a better approximation But the idea of the discretization is that instead of calculating the analytical solution of this differential equation here We calculate approximately what is the H state of the system at discrete time steps And then we can plug this this approximated State value into this second relationship to get the output of the system But in practice as you saw before we had to choose the delta parameter to discretize the system in practice We do not choose the discretization step delta, but we let make it a parameter of the model that we will train with gradient descent So that the model will learn this parameter delta based on the input that it will see Okay Now that we have our recurrent formulation to calculate the output of the system sequentially How can we use it to calculate the output of the system for various time steps?

Let's do it Suppose for simplicity that the initial state of the system is zero So the first state of the system, so suppose we have an input made up of x0, x1, x2, x3, x4 So it's a discrete input. That's why we are using the discretized state space model we can compute the first state of the system which is equal to b multiplied by x0 because We are multiplying a by the previous state, but the previous state we say it's zero So this term will not be present and by using this state we can calculate the first output of the system Which is just c multiplied by h0 Now we can calculate the next state of the system by using the previous state and the next input Just using the formula here So a multiplied a bar multiplied by the previous state of the system Which we already computed plus b multiplied by the next input of the system Having the next state we can compute the next output which is y1 is equal to c multiplied by h1 Now using the previous state we can compute the next state So h2, h2 is equal to a multiplied by the previous state plus b multiplied by the next input And using h2 we can compute y2.

So the output at the next step And this is exactly what we used to do with the recurrent neural networks So that's why they are called the recurrent neural networks because they have this recurrent formulation to go from the previous step to the next step And as you can see, this is exactly what we did So our initial state is zero And we use our first input and the previous state to calculate the first output This will not only produce the first output, but also the next state of the system Then we use the next state of the system plus the next input to produce the next output and the next state of the system Then we use the next state of the system plus the next output Input to produce the next output and the next state of the system, etc, etc.

Just like we saw before The recurrent formulation that we have just seen is great for inference Because we can compute one token at a time with a constant memory and computation requirement because the effort that we are Making to compute this output here is exactly the same as the effort that we take to Compute this output here.

So the second output token And this makes it suitable for inference during a large language model because in large language model We generate one token at a time using the previous tokens and the prompt However, the recurrent formulation is not good for training because during training we already have all the tokens We do not need to generate the output of the model one token at a time We already have our sequence of input.

We already have a sequence that is the target We want to compute the output of the model in parallel without doing this sequential Computation to calculate the loss and train the model and this is exactly what we do with the transformer But we cannot do it with this recurrent computation Thankfully the state space model also provide a convolutional mode as well, which can be parallelized.

Let's see how it works So using just the recurrent formulation that we have seen before let's expand the output at every step So as you saw before to compute the first state of the model We need to take the initial input multiplied by b because we suppose that the previous state is zero.

So there is no a This term here disappears in the first state We can use to use the first state to compute the first output, which is c multiplied by h zero and this gives us uh We can replace this the h zero with this with with the what what is the with?

expanded expression b multiplied by x zero We can then use h zero to compute h one So h one is equal to a multiplied by h zero plus b multiplied by x one But what is h zero h zero is exactly this stuff here. So b multiplied by x zero so we can compute the output the state of the system at the time step one using This computation here and the output of the system at time steps one using h1 And then we expand the expression of h1 to obtain this expression here We can use the similar reasoning to expand also the representation of the h2 space So that it only depends on the input and the parameters of the model and in the same way Also the computation of the output of the model at time step two using only the input and the parameters of the model As you can see we see a pattern So that to compute the kth token of the output we need to do the following Summation.

So for example to compute y2 as you can see, we are multiplying x2 multiplied by cb So as you can see yk we are multiplying xk multiplied by cb Then the previous token so in here x1. So here xk minus one multiplied by cab cab, etc and then the previous previous token with Ca with an additional with the by increasing the power by which a is elevated multiplied by b and this is exactly what we say here, so the previous token with a An increasing power of a and etc, etc until we reach the kth power of a By using the formula we derived we note something interesting The output of the system can be calculated using a convolution of a kernel k with the input x So this formulation this formula that we derived here, which basically represents the pattern of all the expanded formulas that we saw before Using this this formula here We can build a kernel a kernel just like the kernel that we use in convolutional neural networks With the first term of the kernel being cb the second being cab the third being ca to the power of kb Etc, etc until the last item of the kernel and we can convolve this kernel with the input to compute directly the output And as you remember the convolution is something that we can parallelize So that's why this is very powerful because if we build this kernel We can convolve it on the input to produce the output And I want to prove you that the output computed using the kernel like this is exactly the one we just found By expanding these formulas before so let's do it We build a kernel like this Cb cab etc, etc up to the ca to the power of kb Depending on what is the length of the sequence that we are producing I am just Inverting the kernel.

So the first term is not a cb But it's the last term and the second this one here becomes the second last etc because it's easier to visualize Then we have our input which is x0 x1 x2 x3 But I add some padding later. I will show you why and then we have our output that will be produced by this convolution So let's run this convolution step by step The first output of the convolution is just the kernel Slided over the first four tokens of the input And we multiply this term with this one this term of the kernel with the this term of the input this Term of the kernel with this term of the input this current term of the kernel with this term of the input And then we sum all these results and the sum you can see is equal to y0 x X0 multiplied by cb and this is exactly result from this formula here So as you can see if we only have the first token It will be only cb multiplied by the token itself because we don't have any preceding token Let's go forward then we slide the kernel forward by one step and it will produce the following output So we have this item of the kernel multiplied by x1 This item of the kernel multiplied by x0 and all the other multiplied by 0 so they will not contribute to the sum So y1 is equal to this here So as you can see when y is equal the k is equal to one we have x1 multiplied by cb And then we have the previous term.

So x0 multiplied by cab Then we slide our kernel forward And we can see that This term is multiplied by this this term is multiplied by this and this term is multiplied by this and the last one is multiplied By zero so it doesn't contribute to the sum So y2 is computed like this and as you can see if you compare it with this formula, you will see that yk so when k is equal to 2 it will be cb multiplied by x2 then cab multiplied by x1 and then ca Without a actually ca to the power of 2 multiplied by b multiplied by x0 which is the previous token compared to x minus 1 And etc for the step number 4 and this one you can verify by yourself The best thing about the convolution calculation is that it can be parallelized Because the output of yk does not depend on the previous output.

So what I mean is this product that we are doing here can be done on one thread for example or one core of the GPU and this one can be done also simultaneously because they do not depend on each other So the convolutional calculation can be parallelized but the problem is that The the kernel we need to build the kernel for the convolution and building this kernel can be a little Expensive from a computational point of view, but also from a memory point of view however We can still use the convolutional computation to perform the training Because we can parallelize the training because we already have the input token We already have the target so we can compute the output of the model in parallel for all the input tokens, even if it's expensive to do it computationally And then we can use the recurrent formulation to inference from this model one token at a time And we also know that by just doing this the computation cost for producing one token will always be the same No matter how which token we are producing if it's the first one or the 100th token or the 1000th token And this is different from the transformer in which If we want to generate tokens to produce the first token is less expensive than to produce the 100th token Because we need to do much more dot products when we generate the 100th token So with the KVCache we would do 100 dot products One thing that is worth mentioning is that in the paper they do not use the term d Here when they compute the output and the reason is that this can be thought of as a skip connection Let me show you why So imagine we have our input x which is the input of the system So suppose we have this x this is sent to some black box that we will call state space model That will do this calculation of the state recurrently.

So continuously And it will produce an output y But then we can see that this dx of t basically means that we take our input We skip the state space model. We multiply it with some number d and we send it directly to the output So we do not need to model this d to model the state space Because it does not depend on the state of the system at any time step And this is why it can be represented as a skip connection.

And this is why in the paper they do not mention it The analysis that we have done so far on this state space model as I told you before we will consider We for simplicity we are considering that all the parameters are just the single numbers But usually and the input is just a single number and the output is just a single number But usually when we work with language models, especially our input is not a single number but it's a vector because suppose we have a sequence of tokens each token is represented by a vector of 512 dimensions in the case of the vanilla transformer And maybe 4096 dimensions in the case of llama, for example So how can we work with the state space models, but with a vector input and a vector output?

The idea is that each dimension of the input vector will be managed by an independent state space model So imagine we have an input made up of a vector of 512 tokens 512 dimensions Which needs to produce an output token of 512 dimensions. We will create one state space model for each dimension So one for the dimension zero one for the dimension one dimension two, etc, etc for all the dimensions And all these state space muscles are independent from each other now this idea may look very strange and But it's not so strange if you think about the multi-head attention of the transformer because in the transformer We also have an input made up of vectors each one representing an embedding of a token Suppose we have 512 dimensions for each embedding Then the multi-head attention one basically groups some dimensions Suppose in the vanilla transformer we have eight heads So it means that the 512 is divided by eight which means that every Head is managing 64 dimensions and each head is independent from the others And this is so it's if it works for the transformer.

It can also work for the state space model and actually it does so the parameters a b c d The input and the y t have now become vectors and they will have the sequence dimension and also the D model dimension, which means that we have an input that is made up of suppose 512 dimensions Just like in the transformer model Now in the recurrent formulation of the state space model we have this matrix a which is quite important because The a matrix in the state space model can be intuitively thought of as a matrix that captures the information of the past state So the h t minus one To build the new state the h of t It also decides how this information is copied forward in time So when we produce the output of a state space model at a particular time step for which means for a particular token It depends on the state at that time step, but the state at time step t depends on all the previous states So basically this a matrix tells us how the information of the model is conveyed forward in time So this h of t Basically captures all the history of the input of the model so far to build the new token This means that we need to be very careful about the structure of this a matrix Otherwise, it may not very well capture the history of all the inputs seen so far Which is needed to produce the next token And this is very important for language models because the next token generated by the model should depend on all the previous tokens, which are the prompt because when you Send a prompt to a language model, which is a list of tokens.

The model should attend should Watch all the previous tokens to produce the new one And in the case of the state space model, this is taken care of by the matrix a so the structure of this matrix is very important And if you have some background in controls engineering, you may remember that in a state space model the a matrix Also determines the stability of the system Which means that the output of the system may diverge if the system is not stable But if you don't have a background in controls engineering, it doesn't matter So to make the a matrix behave well, the authors chose to use the hippo theory.

Let's see how it works To give you an intuition on what the hippo theory does. I will use borrow some the Fourier transformation So as you remember the Fourier transformation allow us to decompose a signal So suppose this is our initial signal into a series of sinusoidal functions So this one this one and this one such that when we sum all these functions Each one with its respective amplitude value you can see here They will rebuild the original signal And we do something similar also in the state space model with the hippo theory But instead of using sinusoidal functions, we use the Legendre polynomials With the hippo theory we build the a matrix in such a way that it approximates all the input signals so far So let me show you here.

So imagine this is our input signal and it evolves over time So x of t you can see here The goal of the a matrix is to capture information from this input signal and convey it forward for the next State so it captures the history of all these signals Suppose this is our state if we use the state to rebuild the original signal using the hippo theory We will have an output.

So the rebuild the reconstructed signal that is very good Approximate well approximated for the more recent time steps and not so good Well approximated for the previous time step. So as you can see this approximation here is not very Good, if you compare it with the original signal here, but the more recent is very well reconstructed You can see here and this is what the hippo matrix does it reconstructs build a state space representation that captures well the information of recent tokens and Decays the information of the past token using some reasoning very similar to the exponentially moving average And this is very important for language modeling because in language models We have a prompt made up of tokens and we need to produce the next token So we need to build a hidden state because the hidden state will determine the output of the system So the next token so we need to build a hidden state that captures very well Information about the local context and it we can afford to lose some information about the global context.

So information about Tokens that are very far from the one we are producing And in the in a paper called efficiently modeling long sequences with structured state spaces Which is basically the paper of the S4 model The author says that just initializing the A matrix with the hippo matrix we can see here so this is how to build the hippo matrix, which is basically an n by n matrix in which All the values above the principal diagonal here are zero and other values are computed as follows So for the diagonal they are computed as n plus one which n is the row and the k is the column If we just build the A matrix like this It will result in a very big increase in performance of the model Why?

because the A matrix is the one that is responsible for capturing information about the previous state and carrying it over to the new state So we want to carry the information well in the new state So that the h of t which will become a vector If we have a multi-dimensional state space model Then this h of t captures very well the information of all the previous inputs so that we can produce the next output Okay, now it's time to talk about Mamba and let's talk about the motivation that led to the building of Mamba And in the paper the the authors they describe two tasks on which the vanilla state space model So the state space model that we have described up to now or even the S4 model Which is the structured state space model Which is basically just the state space model with a very rigid structure on the A matrix like the one we saw before So these two models they do not perform well on two specific tasks One is the selective copying and one is the induction heads So let's introduce these tasks the copying task basically means that we have some input tokens.

So this one's here blue orange red and Green and the model has to produce the same outputs but time shifted And this actually can be done by the vanilla state space model Because it can be actually done with a simple convolution and the convolution can learn the time shifting that we are doing However, the selective copying which means that we have the some input tokens, for example, the blue white white orange red and then green cannot and the model needs to produce only the Colored tokens so not the white tokens This one cannot be done by the vanilla state space model because the vanilla spaces model That is not content cannot do content aware reasoning because the parameters of the model are the same for every step So the same for every token, so it will treat each token equally so it cannot say distinguish between The blue token and the white token and just ignore the white one and keep the blue the the colored one and not the the white ones So to give you an intuition of how what this could mean, for example, imagine we are given a twitter Comment on twitter and we want to rewrite this comment by removing all the bad words So all the white tokens you can see here and this one cannot be done by the state space model Because the parameters a b c and d are the same for each input and they cannot do they cannot change for each particular input The second task on which the state space model have a difficulty is the induction head which means that the model needs to recall information from the previous Inputs to build the current input and they show for example this example so for example, every time the model sees the The black token it should output the blue token by recalling what it has seen before But because the model cannot use a content aware reasoning it cannot do perform.

Well this task and This is the motivation that led to the creation of mamba And this is the reference in the paper in which they talk about these two tasks So from the recurrent view the recursive dynamics So the transition in the because of the a and the b matrix cannot let them select the correct information from the current context Because the a and b metrics are the same for each input token and also from the convolutional view It it is known that the model is able to solve the copying task because it can learn the time shift but they have a difficulty with the selective Copying task because the parameters of the model are the same for each input So they do not know how to treat differently a particular input Now let's talk about the innovation of mamba and how it differs from the state space model Let's first review.

What is the algorithm of the state space model? So the state space model basically has an input and it has to produce an output The input is this one. So it's a tensor of b l and d dimension which is a batch dimension sequence dimension and the d model and this is the Exactly the same as the transformer because we have a batch of prompts Each prompt is made up of l tokens Which is our sequence length and each token is made up of a vector of d model dimension Which is our d here and it has to produce an output of the same dimension Just like the transformer We have an a matrix which is a matrix of parameters that indicates how to copy the previous state into the new state And we we model the state as a vector of n tokens.

So this will be a vector of n tokens So the a matrix is d by n which is basically because we have a when we have an input vector made up of d dimensions, we have d State space models independent from each other one for each dimension just like I saw before So we have this parameter matrix is a d by n This b matrix is also d by n.

The c matrix is also d by n And the delta is the step size of the discretization which is learned by the model as I showed before So we don't decide it. We just let the model learn this Step size and because we have a d state space model because our input vector is d dimensional So we have a d number of delta And to discretize we just apply the formula that we saw before so depending which Discretization rule we are using if we use the Euler method or we use the zero order hold in the case of mamba They use the zero order hold method And then we have this discretized parameter So a bar and b bar and we can run the ssm as a recurrence or a convolution depending if we are Training it or inferencing So when we are training it We run it as a convolution using building the kernel like I showed you before and running it on the input like I showed before Or as a recurrence using the formula we can see here.

So this formulation here So we compute the next step the next state using the previous state and then we use each state to build the output In the case of mamba they make the state space model selective which means that the parameters of the model Are changed for each input token.

Let's see. The input is still bld. So we have a sequence of Batch of prompts each prompt is made of l tokens. Each token is made up of d dimensions and the output has the same shape Then we have the a matrix which is d by n The b matrix basically is modified By a linear layer this sb and this sb will project the d dimension into the n dimension.

You can see here This means that basically now the b matrix is not the same for all input tokens because now we have this l dimension here Which means that for every input token of the each batch We will have a different b matrix And the same goes for the c matrix you can see here and the same goes for the delta you can see here For each input token, we will have a different delta And the discretized Matrices have this dimension.

So bld n And we can run the state space model However, because now the model is not a time invariant anymore because the parameters of the model Change for each input. So for each token for each step for each time step We can only run it as a Recurrence so we can only apply this formulation here We cannot compute it with the convolution anymore because the kernel will not be fixed Because before because the parameters of the model are fixed for all the inputs We can just build a kernel and run it for all the inputs but now for every input we should use a different kernel so we cannot compute it as a Convolution we are only forced to compute it as a recurrence And Have you noticed that the authors talk about this scan operation you can see here So what is this scan operation to when evaluating the model as a recurrence?

Let's talk about it So if you have ever done competitive programming you are familiar with the prefix sum array Which is an array calculated sequentially such that the value at each position indicates the sum of all the previous values We can easily compute it with the for loop in linear time So imagine we have this initial array We can calculate the prefix sum as like this the first value is equal to the first value the second value is computed as the Previous value of this array.

So this one plus the current value of the initial array. So Wait, this one is equal to this one plus this one And this one is computed using the previous value plus the current value of the initial array And this one is computed using the previous value plus the current value of the initial array And this one is computed using the previous value plus the current element of the array Such that each item of this prefix sum indicates the sum of all the items of the initial array up to that element.

So the number 32 is the result of 10 plus 7 plus 6 plus 9 The scan operation refers to computing an array like the prefix sum in which each value can be computed using the previously computed value and the current input value And the recurrent formula of the state space model can also be thought as a scan operation in which each state is the sum of the previous state multiplied by an A matrix plus The current input multiplied by the B matrix So if the model input is X0, X1, X2, X3, X4 and X5 We can compute for example H0 using only X0 And then we can compute H1 using the previously computed value plus the current input each one multiplied by the A matrix and the B matrix Then we can compute H2 using H1 multiplied by A plus X2 multiplied by B and H3 can be computed as H2 multiplied by A plus X3 multiplied by B, etc, etc, etc So to generate the input we just multiply it H_k with the C matrix to generate the output token K So if we have this scan output, so if we build a Array like this one, so this array you can see here We can easily compute the output of the model for each time step By just multiplying each of this value with the C matrix So we multiply it by C multiplied by C and this one here multiplied by C to compute Y0 This is Y1 This is Y2, etc, etc, etc Now what if I told you that the scan operation that I have shown you can be parallelized Of course, you will not believe me because the scan operation is one of those operations that naturally looks like a sequence So to compute the current value, I need to have the previous value plus the current input So how can I parallelize an operation like this?

Actually, it can be parallelized As long as the operations that we are doing are associative Means that they benefit from the associative property Now if you remember from elementary school or middle school when they teach you about the properties of the addition and the multiplication You may recall that the associative property means that if you have The operation done on three operands, for example A multiplied B multiplied C It does not matter the order in which you do these operations So it does not matter where you put the parentheses, the result will be the same So you can do A multiplied by B and then the result multiplied by C Or you can do A multiplied by the result of B multiplied by C So this is the associative property So as long as the operations that we are doing have this property, then we can parallelize the scan operation I want to show you how to actually do it practically So imagine the initial array is this we can create multiple threads each one computing In parallel a sum.

So for example, we have a we can have for example this Okay, this is actually the picture I took from wikipedia and it's made for 16 input array But imagine we have eight threads The first thread will compute the summation of the first two elements the second thread will compute the summation of the third and the fourth the third of the Fourth the fifth and the sixth etc, etc.

And then we use the result of this summation to compute the next step And then in parallel and then we use it to compute the next step in parallel and then to compute the next step in parallel This is called the sweep down if I remember correctly and then we have a sweep up operation to rebuild all the Values that we didn't compute.

So for example, this value here is not computed until the last step now by By doing a parallel scan Basically, we can decrease the time complexity of the scan operation from a sequence or so o of n It's reduced to o of n divided by t where t is the number of parallel threads that are computing this operation And in my github repository, I have also put a excel file that shows you how to How this scan operation is computed step by step.

So I actually show you all the intermediate steps. So if you want to Understand how this works Now that we know that this parallel scan can be done in parallel, this is actually what the authors do. They also do the computation of the Recurrence to calculate the output of the model in parallel to reduce its time complexity Basically since Mamba cannot be evaluated using a convolution because it's time varying So it means that the parameters of the model are different for each time step Our only way of computing the output is to use the recurrent formulation But thanks to the parallel scan algorithm, this can be parallelized to reduce its time complexity The authors also indicate some techniques that they have used to make this algorithm faster The first technique that they show you is the kernel fusion The second is the parallel scan which I have already shown and then we have the Circumputation of the activations that we will show later.

So let's see all these techniques one by one But first let's see how the memory hierarchy of the GPU works The GPU basically it's a very fast calculator. So it's a very very very big computational unit that can do a lot of operations in parallel And it has two main memories The one that you are mostly familiar with, the one that you actually check when you buy a GPU is called the DRAM So it's the high bandwidth memory and it's in the order of gigabytes And then the GPU also has a smaller memory, a local memory that is called the SRAM The difference between the two is that first of all the SRAM is much much much much much smaller It's in the order of megabytes And however, this is where the GPU will do the computation So when the GPU needs to do some matrix multiplication It will first of all copy the information from the high bandwidth memory to the SRAM Then the core of the GPU will access the information in the SRAM to do the computation And then the result will be saved back to this high bandwidth memory and Actually, if we check the data sheet of a GPU, in this case, this is the NVIDIA A100 You will see that the GPU is very fast at computing operations But the copying of information from the SRAM to the DRAM is not very fast, not as fast as computing operations So as you can see here, for example The copy speed of the copying speed is much slower So this is like two terabytes per second compared to the number of operations that the GPU can do In this case, it can do 20 tera floating point operations per second of 32 bit So this parameter here is basically 40 times faster than this parameter here This also means that when we create an algorithm that runs on the GPU, so ECUDA kernel Sometimes the kernel may run slowly not because we are doing a lot of operations But maybe because we are copying a lot of stuff around which results in a slow overall computation And if this happens, we see that our operation is I/O bound because it's bounded by the I/O speed, by the copying speed Not by the computation speed of the GPU Now in the authors, they exploit this different hierarchy of the GPU to make their algorithm So the selective scan faster.

So the main idea is to leverage the properties of modern accelerators So the GPU to maximize the to materialize the state edge. So the hidden state only in more Efficient levels of the memory hierarchy. So only in the SRAM, so the smaller memory Concretely instead of preparing the scan input because they compute the recurrence as a scan operation Just like I showed you before.

So the scan input is what is batch the sequence length is The Size, the D model, so the size of the input vector and N In the GPU high bandwidth memory, so in the DRAM, they load all the parameter of the state space model So the delta, the A matrix, the B matrix, and the C matrix directly from the highest Bandwidth memory, so the DRAM, into the fast SRAM.

They perform the discretization in the SRAM The recurrence also, so the scan operation is also done in this SRAM and finally the result of this scan is computed back to the high bandwidth memory And they also make use of what is known as a kernel fusion So what is kernel fusion?

When we perform a tensor operation, our deep learning framework So PyTorch, it will load the tensor. Suppose we are doing a matrix multiplication. It will load the tensor from the fast memory From the slow memory to the fast memory. So from the DRAM to the SRAM of the GPU It will perform the operation, for example the matrix multiplication, and then it will save back the result From the SRAM to the DRAM.

So from the SRAM to the high bandwidth memory of the GPU However, what if we do three operations on the same tensor in sequence? The deep learning framework will do something like this. So it will load first of all the tensor from the High bandwidth memory to the SRAM.

It will compute the first operation Which means calling the CUDA kernel associated with the first operation and then save the result back to the high bandwidth memory Then it will load again the result of the previous computation from the high bandwidth memory into the fast memory compute the second operation which means Launching the second CUDA kernel associated with this operation and then it will save the result of this operation back to the high bandwidth memory then it will Load the result of the previous computation again from the high bandwidth memory into the fast memory compute the third operation and then save back the result into the high bandwidth memory.

As you can see the total time is occupied in this case when we have three operations in sequence is occupied by the copying operations that we are performing from the high bandwidth memory to the fast memory and from the fast memory back to the high bandwidth memory because we know that The GPUs are relatively slow at copying data instead of computing operations So kernel fusion means that to make a sequence of operations faster We can fuse all these CUDA kernels So the three operations that we are doing in sequence into one custom CUDA kernel such that We don't copy the intermediate results to the high bandwidth memory But we keep doing these operations in the fast memory until we have done all these three computations And then only the last result is saved into the high bandwidth memory This speeds up the overall computation because we don't have the intermediate copy operations Because they would result in an I/O bound algorithm Okay, the last innovation of this selective scan algorithm is the recomputation of the activations, let's see what is it So when we train a deep learning model, this model gets converted into a computation graph When we perform a back propagation in order to calculate the gradients at each node of this computation graph We need to cache the output values of each node that we have done during the forward step So imagine we have a very simple model like the one I show you here Let me show with a pointer.

So this model basically computes the output using just a linear operation So x1 is multiplied by this Parameter w1 plus x2 multiplied by this parameter w2 plus a bias Suppose we have done the forward process and it has produced at each node its own value During the back propagation our goal is to calculate the gradient of the loss function with respect to each parameter Here so with respect to w1 with respect to w2 and with respect to the bias to compute the The gradient of the loss function with respect to the w1 for example, I show you the step here And to compute this gradient we need to also compute the gradient of all the intermediate nodes And to compute the gradient of the intermediate nodes We need to have the values of the activations of each node that we had during the forward step So for example to compute the gradient of the loss function with respect to this node y_pred Which result in the expression 2 multiplied by y_pred minus 2 multiplied by target.

We need to cache the Value y_pred that we had during the forward step And these activations actually can occupy a lot of memory in a very big network. And this is why in the in the paper they talk about Recomputing them. So since caching the activations and then reusing them during back propagation means that we need to save them To the high bandwidth memory.

So the slow memory and then copy back them from the slow memory during back propagation It may be faster to just recompute them during back propagation because maybe the recomputation Speed because the gpu is very fast at computing operations than it is at copying Maybe just recomputing them is faster than copying them So this is the reference in the paper in which they describe this technique So they say finally we must also avoid saving the intermediate states which are necessary for back propagation So the intermediate states are all the activations of all the nodes of this computation graph So we carefully apply the classic technique of recomputation to reduce the memory requirements.

So the intermediate states are not stored They are not stored in the high bandwidth memory But recomputed during the backward pass when the inputs are loaded from the high bandwidth memory to the fast RAM So basically it's faster to just redo the calculations again instead of copying this information to the high band memory and then Reloading it from high bandwidth memory to the fast memory Now let's look at the block that makes up Mamba in the Mamba architecture So first I introduce what is the Mamba block and then we will show you all the Mamba architecture So Mamba is built by stacking multiple layers of this Mamba block that we can see here And this is very similar to the stacked layer of the transformer So if you remember the transformer model We have the encoder and the decoder and the encoder side and the decoder sides are made by stacking these blocks with the self-attention and the feed forward the network multiple times on top of the other such that the Output of the one block is sent as input to the next block and the output of the last block is sent to the output Of the model and this is exactly what we do with Mamba in which we create First of all, we have our input which will be sent to this block and this block will be repeated many times such that the Output of this block here will become the input of the next block At the beginning of this block we have a linear layer two linear layers that convert the size D model into d inner as least this is how they call it in the code So d model is the size of the vector of our embedding.

So imagine we have an embedding size of 512 This is d model and the d inner can be chosen. You can choose for example double the size of the d model And this is just a linear projection Then they have a convolution here on this branch and this is actually To mix up kind of the tokens with each other because otherwise the state space model they will be running independently For each dimension, but this convolution makes up all these dimensions And then we have these two silo activations.

We have the state space Model here that runs the recurrence using the parallel scan algorithm that I have just shown you before and then we will multiply element wise product of this branch and the output of the state space model then the Linear, there is another linear layer that will project back the d inner So the inner dimension of this block to the outer dimension, which is the d model So we go back to the 512 dimensions of the embedding size initially if we have chosen the model equal to 512 Now let's see the entire architecture of And I drew this architecture by myself by analyzing the code so I didn't have time to make it very beautiful But okay, it's very similar to what we do with the transformer.

So we have our input it gets converted into embeddings So it becomes a sequence of tokens each token made up of an embedding of size. Let's say 512 And then we have many blocks like this one after another We have n of them such that the output of one block is sent as input to the next one And the mama block that I show you in the previous slide is basically just this one But they also include a rms norm at the beginning and then a skip connection You can see here and this is repeated n times finally, there is an rms norm just like Lama and just like mistral because we have this rms norm and then we have the linear layer that will project the output embedding back into our vocabulary and then we have a softmax to choose the Which will indicate which token from our vocabulary We need to choose as the next token if we are modeling a language model and this is the architecture of of mamba guys So let's look also at the performance So as you remember at the initial when we started talking about mamba mamba was introduced to solve the problems of the selecting selective copying task and the induction task because we saw that the State space models were not performing very well in context aware reasoning So they wanted to solve this problem with mamba by using they introduced that's why they introduced the selective state space model with their selective scan algorithm And we can see that the state space model.

So the s4 model. So the structural space space model Performs quite poorly on this selective copying but mamba performs very well So it has a 99.8% Of accuracy and the mamba basically this layer here is called the s6 layer And the s4 layer is the one described in the previous paper.

So structure state space model While on the induction heads we can see that also mamba is performing very well So the accuracy of mamba you can see here is always Nearly 100% actually it's 100% for sequence length that can reach 10 to the power of 6 So very very very very very long sequence length in comparison For example, the transformer model with the absolute positional encoding or also rotary positional encodings Start degrading in accuracy when the sequence length reaches a certain size And so they are very quite good up to a few hundred tokens But they start degrading as soon as they reach the thousands of tokens, but mamba maintains a very Consistent performance over even very long sequence length and this is very important for language modeling because The prompts especially with the retrieval augmented generation, but also with chat applications, etc They are becoming very long.

So we want models that can perform well on very very long sequence And we can also see here that the model Performance so the number of operations that we need to do to train a model to reach a certain perplexity Is very comparable with the transformer So mamba actually performs as good as the best transformer model that we have now So the transformer model like lama and mistral.

This is the transformer plus plus you can see here And it performs very similarly to the best model that we have here. So it's a very good Concurrent to the transform, but as we saw in the previous slide it can scale much better for longer sequences And this is why it became quite popular recently Thank you guys for watching my video.

I hope you learned a lot in this video I wanted to make a video that was very descriptive and also very Technically in detail because I wanted to derive all the formulations of mamba. I just don't like to throw formulas at people And mamba I think will be a very popular model in the future Even if I think it has its own limitations, for example It's still a recurrent neural network because it's still run like a recurrence And it may have its own limitations for for example We still don't know how well it performs on massive amounts of data like data that has been used for lama or for mistral So but I think people are looking for alternatives for the transformer because the transformer has shown its limitations, especially for Scaling to very long sequence length which are very much needed for language modeling But also with recent models for image generation movie generation and audio generation and so also, the computational complexity of the Transformer is massive because the the scaling power is quadratic so it results in a really high memory consumption and that's why People normal people cannot even inference a model like mistral on their computer unless they use the model sharding and very advanced techniques so I hope that More research is done in this area.

So thank you for watching my video I hope you like this video and you will subscribe to my channel Please share this video with your friends and share it on your social media. This is the best way to support me Thank you and have a nice day