LoRA: Low-Rank Adaptation of Large Language Models - Explained visually + PyTorch code from scratch
Chapters
0:00 Introduction
0:47 How neural networks work
1:48 How fine tuning works
3:50 LoRA
8:58 Math intuition
10:25 Math explanation
14:05 PyTorch implementation from scratch
00:00:00.000 |
Hello guys, welcome back to my channel. Today we will be exploring a very influential paper 00:00:05.600 |
called LoRA. LoRA stands for Low-Rank Adaptation of Large Language Models 00:00:11.360 |
and it's a very influential paper. It came out I think two years ago from Microsoft 00:00:17.200 |
and we will see in this video, we will see what is LORA, how does it work and we will also 00:00:22.720 |
implement it in PyTorch from zero without using any external libraries except for Torch of course 00:00:29.280 |
and let's go. So we are in the domain of language models but actually LoRA can be applied to any 00:00:37.680 |
kind of model and in fact in my demo that I will also show you later we will apply it to a very 00:00:43.920 |
simple classification task and so before we study LoRA we need to understand why we need LoRA in 00:00:52.560 |
the first place. So let's review some basics about neural networks. So imagine we have some input 00:00:57.680 |
which could be one number or it could be a vector of numbers and then we have some hidden layer in 00:01:02.400 |
a neural network which is usually represented by a matrix but here I show you the graphical 00:01:06.880 |
representation and then we have another hidden layer and finally we have the output right. 00:01:12.400 |
Usually when we train a network we also have a target and what we do is we compare the output 00:01:19.120 |
and the target to produce a loss and finally we back propagate the loss to each of the weights 00:01:26.240 |
of all the layers. So in this case for example we may have many weights in this layer we will 00:01:33.200 |
have a weights matrix and a bias matrix and each of these weights will be modified by the loss 00:01:40.880 |
function and also here we will have a weight and a bias matrix here. Now what is fine tuning? Fine 00:01:50.320 |
tuning basically means that we have a pre-trained model and we want to fine tune it on some other 00:01:55.840 |
data that the original model may have not seen. For example imagine we work for a company that 00:02:01.680 |
has built its own database so this new database has its own sql language right and we have 00:02:08.560 |
downloaded a pre-trained model let's say gpt that was trained on a lot of programming languages 00:02:14.960 |
but we want to fine tune it on our own sql language so that it can answer so that the 00:02:23.280 |
model can help our users build queries for our database and what we usually do is we train this 00:02:30.400 |
entire model here on new data and we alter all these weights using the new data 00:02:37.520 |
however this creates some problems. The problem with full fine tuning is that we must train the 00:02:43.760 |
full network which first of all is computationally expensive for the average user because you need 00:02:48.960 |
to load the whole model in memory then you need to run back propagation on all the weights 00:02:54.880 |
plus the storage requirements for the checkpoints are expensive because for every checkpoint for 00:03:00.400 |
every epoch usually we save a checkpoint and we save it on the disk plus if we save also the 00:03:08.320 |
optimizer state let's say we are using the adam optimizer which for each of the weights 00:03:13.520 |
keeps also some statistics to better optimize the models so we are saving a lot of data 00:03:20.480 |
and if we suppose we want to use the same base model but fine-tuned on two different data sets 00:03:28.080 |
so we will have basically two different fine-tuned models if we need to switch between them it's very 00:03:35.680 |
expensive because we need to unload the previous model and then load again all the weights of the 00:03:41.920 |
other fine-tuned model so we need to replace all the weight matrices of the model however we 00:03:48.720 |
have a better solution to these problems with LoRA. In LoRA there is this difference so we 00:03:54.240 |
start with an input and we have our pre-trained model so we want to fine-tune it right so we have 00:04:00.880 |
our pre-trained model with its weights and we freeze them basically we tell PyTorch to never 00:04:07.600 |
touch these weights just use them as read only and never run back propagation on these weights 00:04:14.400 |
then we create two other matrices one for each of the matrices that we want to train 00:04:22.400 |
so basically in LoRA we don't have to create the matrices b and a for each of the 00:04:32.560 |
layers of the original model we can just do it for some layers and we will see later how 00:04:38.800 |
but in this case suppose we only have one layer and we introduce the matrix b and a so what's the 00:04:45.680 |
difference between this matrix b and a and the original matrix w first of all let's look at the 00:04:51.440 |
dimension the original matrix was d by k suppose d is let's say 1000 and k is equal to 5000 we want 00:05:03.680 |
to create two new matrices that when multiplied together they produce the same dimension so d by 00:05:11.360 |
k so in fact we can see it here d by r when it's multiplied by r by k will produce a new matrix 00:05:18.480 |
that is d by k because the inner dimensions cancel out and we want r to be much smaller than d or k 00:05:28.400 |
we may as well choose r equal to 1 so if we choose r equal to 1 basically we will have a matrix that 00:05:36.240 |
is d by 1 so 1000 by 1 and another matrix that is 1 by 5000 and if we compare the numbers of 00:05:45.040 |
parameters in this matrix in this part in the original matrix w we have the number of parameters 00:05:50.480 |
let's call it p is equal to d multiplied by k which is equal to 5 million numbers in this matrix 00:06:00.640 |
in this case however we have two matrices so if r is 1 we will have one matrix that is 00:06:08.960 |
d by r so 1000 numbers plus another that is r by k so 5000 numbers only 6000 numbers in the two matrices combined but with the advantage 00:06:23.360 |
that when we multiply them together we will still produce a matrix of d by k of course you may think 00:06:29.200 |
that this matrix will not capture the same information as the original matrix w because 00:06:35.280 |
it's much smaller right even if they produce the same dimension it's actually 00:06:39.920 |
a smaller representation of something so you should lose some information but this is 00:06:48.160 |
the whole idea behind LoRA actually the whole idea behind LoRA is that the matrix w contains a 00:06:54.960 |
lot of weights a lot of numbers that are actually not meaningful for our purpose they are actually 00:07:01.600 |
not adding any information to the model they are just a combination of the other weights so they 00:07:07.440 |
are kind of redundant so we don't need the whole matrix w we can create a lower representation of 00:07:13.600 |
this w and fine-tune that one so let's continue with our journey of this model let me delete the 00:07:20.640 |
link okay so we create these two matrices b and a what we do is we combine them because we can sum 00:07:29.840 |
them right because they have the same dimension when we multiply b by a it will have the dimension 00:07:35.520 |
uh d by k so we can sum it with the original w we produce the output and then we have our usual 00:07:42.400 |
target we calculate the loss and we only back propagate the loss to the matrix that we want to 00:07:47.680 |
train that is the b and a matrix so we never touch the w matrix so our original model which was the 00:07:57.440 |
pre-trained model is frozen and we never touch its weights we only modify the b and a matrix 00:08:03.360 |
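To make this concrete, here is a minimal PyTorch sketch of that forward and backward pass, using the example dimensions from above (d = 1000, k = 5000, r = 1); this is an illustration of the idea, not code from the demo:

```python
import torch

d, k, r = 1000, 5000, 1   # dimensions from the example above; r is the LoRA rank

W = torch.randn(d, k)     # pre-trained weights: frozen, treated as read-only
W.requires_grad_(False)

B = torch.zeros(d, r, requires_grad=True)   # trainable, starts at zero so B @ A adds nothing at first
A = torch.randn(r, k, requires_grad=True)   # trainable, random initialization

x = torch.randn(k)        # some input vector

# the effective weight matrix is W + B @ A (same d x k shape), but only B and A are trained
h = (W + B @ A) @ x

loss = h.sum()            # dummy loss, just to show where gradients flow
loss.backward()

print(W.grad)                         # None: the frozen pre-trained weights are never updated
print(B.grad.shape, A.grad.shape)     # torch.Size([1000, 1]) torch.Size([1, 5000])
```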
so what are the benefits first of all as we saw before we have less parameters to train and store 00:08:08.960 |
because in the case i showed before we have for example five million parameters in the w matrix 00:08:16.080 |
in the original one and using r equal to five we only have thirty thousand parameters in total so 00:08:21.680 |
less than one percent of the original less parameters also means that we have less storage 00:08:26.320 |
requirements and faster back propagation because we don't need to evaluate the gradient for most 00:08:30.480 |
of the parameters and we can easily switch between two fine-tuned models because for example imagine 00:08:36.160 |
we have two different models one for sql and one for generating javascript code we only need to 00:08:43.120 |
reload these two matrices if we want to switch between them we don't need to reload the w matrix 00:08:48.640 |
because it was never touched so it's still the same as the original pre-trained model 00:08:56.720 |
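The parameter savings just mentioned are easy to verify with a couple of lines (same example dimensions as before):

```python
d, k = 1000, 5000
full = d * k                              # parameters in the original W: 5,000,000

for r in (1, 5):
    lora = r * (d + k)                    # parameters in B (d x r) plus A (r x k)
    print(f"r={r}: {lora:,} LoRA parameters, {lora / full:.2%} of W")
# r=1: 6,000 LoRA parameters, 0.12% of W
# r=5: 30,000 LoRA parameters, 0.60% of W -> the "less than one percent" mentioned above
```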
why does this work so the idea and it's written in the paper is that pre-trained 00:09:07.360 |
models have a low intrinsic dimension the intuition is that it is smaller 00:09:13.520 |
than their actual dimension and inspired by this they hypothesize that the updates to the weights 00:09:19.920 |
also have a low intrinsic rank during adaptation and the rank of a matrix we 00:09:27.120 |
will see it later with a practical example basically it means imagine we have a matrix 00:09:32.080 |
made of many vectors column vectors and the rank of the matrix is the number of the vectors that 00:09:40.080 |
are linearly independent from each other so you cannot combine linearly any of them to produce 00:09:46.160 |
another one this also indicates kind of how many columns are redundant because they can be obtained 00:09:53.520 |
by linearly combining the other ones and what they mean in this paper is that the 00:10:01.360 |
w matrix is actually rank deficient it means that it does not have full rank so it has a 00:10:07.440 |
dimension maybe 1000 by 1000 but maybe the actual rank is let's say 10 so actually we can use two 00:10:13.680 |
much smaller matrices like 1000 by 10 and 10 by 1000 to capture most of the information and the idea behind this rank reduction is used 00:10:20.080 |
in a lot of scenarios also for example in compression algorithms so let's review some 00:10:26.720 |
mathematics of rank and matrix decomposition and then we check the lora implementation in pytorch 00:10:36.160 |
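As a tiny warm-up example of what rank means, here is a made-up matrix with a linearly dependent column (the numbers are purely illustrative):

```python
import numpy as np

# two independent columns plus one that is just a linear combination of them
c1 = np.array([1.0, 0.0, 2.0])
c2 = np.array([0.0, 1.0, 1.0])
c3 = 2 * c1 + 3 * c2                      # redundant: adds no new information

M = np.stack([c1, c2, c3], axis=1)        # 3 x 3 matrix
print(np.linalg.matrix_rank(M))           # 2, even though the matrix has 3 columns
```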
so let's switch here let's go here first so i will show you a very simple example of 00:10:43.200 |
matrix decomposition and how a matrix can be rank deficient and how we can produce a smaller matrix 00:10:49.520 |
that captures most of the information so let's start by importing the very simple libraries 00:10:55.760 |
torch and numpy then i will create a 10 by 10 matrix here that is artificially rank deficient 00:11:03.760 |
so i create it in such a way that it is rank deficient with an actual rank of 2 so even 00:11:08.880 |
if this matrix is 10 by 10 so we can see that it has 100 numbers the rank of this matrix 00:11:16.400 |
is actually 2 and we can evaluate that using numpy so we will see that the rank of this matrix 00:11:22.640 |
is actually 2 this means that we can decompose it using an algorithm called svd which means 00:11:30.480 |
singular value decomposition which produces three matrices u s and v that when multiplied together 00:11:38.880 |
they give us w but the dimension of this u s and v can be much smaller based on the rank so 00:11:45.760 |
basically it produces three matrices that if we take only the first r columns of these matrices 00:11:52.800 |
where r indicates the rank of the original matrix they will capture most of the information of the 00:11:58.560 |
original matrix and we can visualize that in a simple way what we do is we calculate the b and 00:12:06.000 |
the a matrix just like in the lora case using this decomposition and we can see that we created the 00:12:12.560 |
lower representation of the w matrix which is originally was 10 by 10 but now we created two 00:12:18.400 |
matrices one b and one a that is 10 by 2 and 2 by 10 and what we do is we take some input let's call 00:12:26.960 |
it x and some bias and it's random we compute the output using the w original matrix which was the 00:12:34.800 |
10 by 10 matrix so we multiply it by x we add the bias and we also compute the output using the b 00:12:42.720 |
and a matrix that is the result of the decomposition so we calculate y prime using b multiplied by a 00:12:50.080 |
just like lora multiplied by x plus bias and we see that the output is the same even if b and a 00:12:59.760 |
actually have much fewer elements so in this case i renamed it i forgot to change the names this is b 00:13:06.720 |
and a okay b and a so what i want to show and this is not a proof because i actually 00:13:18.960 |
created artificially this w matrix and i made it rank deficient artificially i actually took this 00:13:26.000 |
code from somewhere i don't remember where and so the idea is that we can have a smaller matrix 00:13:34.160 |
that can produce the same output for the same given input but by using many fewer numbers 00:13:41.120 |
many fewer parameters so as you can see the number of elements in the b 00:13:47.360 |
matrix and a matrix combined are 40 while in the original matrix we had 100 elements 00:13:53.600 |
and they still produce the same output for the same given input which means that 00:13:58.720 |
b and a captured most of the information the most important information of w now let's go to lora 00:14:05.360 |
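For reference, here is a compact sketch of the decomposition experiment described above; the exact construction and names like W, B and A are assumptions, and the notebook code may differ slightly:

```python
import torch

torch.manual_seed(0)
d, k, r = 10, 10, 2

# a 10x10 matrix that is artificially rank-deficient: it is the product of two rank-2 factors
W = torch.randn(d, r) @ torch.randn(r, k)
print(torch.linalg.matrix_rank(W))        # tensor(2)

# singular value decomposition: W = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(W)

# keep only the first r singular values / vectors, just like in the demo
B = U[:, :r] @ torch.diag(S[:r])          # d x r
A = Vh[:r, :]                             # r x k

x, bias = torch.randn(k), torch.randn(d)
y  = W @ x + bias                         # output using the full matrix (100 numbers)
y2 = (B @ A) @ x + bias                   # output using B and A (10*2 + 2*10 = 40 numbers)
print(torch.allclose(y, y2, atol=1e-4))   # True: same output with far fewer parameters
```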
so let's implement lora step by step what we will do is we will do a classification task so imagine 00:14:12.240 |
we have a very simple neural network for classifying mnist digits and we want to fine tune 00:14:19.200 |
it on one specific digit because we see that the performance on one specific digit is not very good 00:14:19.200 |
so we want to fine tune it on only one digit and we will use lora and show that when we fine tune with 00:14:25.200 |
lora we are actually modifying a very small number of parameters and we only need to save very small 00:14:37.520 |
number of parameters compared to the pre-trained model let's start so we import the usual libraries 00:14:44.480 |
so torch and matplotlib actually we will not need it and tqdm for visualizing the progress bar 00:14:51.680 |
we make it deterministic so it always returns the same results and we load mnist the data set it's 00:15:01.760 |
already integrated into torch vision so it's not a big deal and we create the loader we create a 00:15:08.960 |
very unoptimized neural network for classifying these digits so basically this is a very big 00:15:15.600 |
network for the task we don't need such a big network but i made 00:15:20.320 |
it this big on purpose because i want to show the savings in parameters that we get so i call it 00:15:28.000 |
rich boy net because daddy got money so i don't care about efficiency right and it's a very simple 00:15:34.000 |
network made of three linear layers with relu activations and the final layer is just 00:15:40.080 |
basically the classification of the digit into one of the ten categories from 0 up to 9 so we create 00:15:48.480 |
this network and we train it on mnist so we run for only one epoch and we train it just simple 00:15:57.120 |
training of mnist for classification and then what we do is we keep a copy of the original weights 00:16:05.600 |
because we will need it later to prove that lora didn't modify the original weights so the 00:16:12.320 |
weights of the original pre-trained model will not be altered by lora we can also test the 00:16:20.560 |
model the pre-trained model we can test it on and check what is the accuracy so if we test it we can 00:16:27.440 |
see the accuracy is very high but we can see that for the digit number nine the accuracy is not as 00:16:33.280 |
good as the other digits so maybe we want to fine-tune especially on the digit nine okay lora 00:16:39.280 |
in the paper was actually applied to large language models which i cannot do because i don't 00:16:43.840 |
have the computational resources so that's why i'm using mnist and this very simple example 00:16:49.520 |
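For reference, the classifier described above looks roughly like this; the hidden sizes 1000 and 2000 are inferred from the parameter counts quoted below, so treat the exact sizes (and the instance name net) as assumptions:

```python
import torch.nn as nn

class RichBoyNet(nn.Module):
    # hidden sizes inferred from the totals quoted later:
    # 784*1000 + 1000 + 1000*2000 + 2000 + 2000*10 + 10 = 2,807,010 parameters
    def __init__(self, hidden1=1000, hidden2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, hidden1)   # MNIST images are 28x28 = 784 pixels
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, 10)        # one logit per digit, 0 through 9
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28 * 28)                    # flatten the image
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)

net = RichBoyNet()
```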
anyway so we have one digit that we want to fine-tune better right let's visualize before 00:16:55.760 |
we do any fine-tuning how many parameters we have in this network that we created here 00:17:01.360 |
this network here rich boy net so we have in the layer 1 we have this matrix weights and this bias 00:17:08.880 |
this weights for the layer 2 and this bias this weights matrix for the layer 3 and this bias in 00:17:15.680 |
total we have two million eight hundred seven thousand and ten parameters now let's introduce 00:17:22.400 |
lora so as we saw before lora introduces two matrices called a and b and the size of 00:17:34.400 |
these matrices is if the original weight matrix is d by k then b is d by r and a is r by k so i just call it 00:17:49.040 |
features in and features out in the paper it's written that they initialize the b matrix with 00:17:54.960 |
zero and a matrix with random gaussian initialization and this is what i do here as well 00:18:01.520 |
then they also introduce a scale parameter this is from section 4.1 of the paper 00:18:05.760 |
that basically allows you to change the rank without changing the scale of the 00:18:11.120 |
outputs and i just keep alpha fixed because maybe you want to try the same model 00:18:20.560 |
with different ranks so the scale allows us to keep the magnitude of the numbers the same 00:18:28.080 |
if lora is enabled we want the weights matrix so basically we will run lora only on 00:18:35.440 |
the weights matrix not on the bias because also in the paper they don't do it for the 00:18:39.680 |
bias matrix only on the weights so if lora is enabled the weights matrix will be the 00:18:48.960 |
original weights plus b multiplied by a just like in the paper multiplied by the scale this is also 00:18:55.600 |
introduced by the paper so basically, this should be w here, so instead of 00:19:04.000 |
multiplying x by w just like in the original network we multiply it by w plus b multiplied by 00:19:12.080 |
a and this is written in the paper we can see it here let's go down it's written here so instead 00:19:21.200 |
of multiplying x only by w we also multiply it by this delta w which is how much the weights have 00:19:29.920 |
moved because of the fine tuning which is b multiplied by a and this is what we are doing here 00:19:36.800 |
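Putting these pieces together, the parametrization module looks roughly like this (a sketch along the lines of the code on screen: B starts at zero, A gets a Gaussian init, and the update is scaled by alpha over the rank):

```python
import torch
import torch.nn as nn

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        # as in section 4.1 of the paper: A gets a random Gaussian init, B starts at zero,
        # so that delta_W = B @ A is zero at the beginning of fine-tuning
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        # scaling by alpha / rank keeps the magnitude of the update comparable across ranks
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # the weights seen by the layer become W + (B @ A) * scale
            delta = (self.lora_B @ self.lora_A).view(original_weights.shape)
            return original_weights + delta * self.scale
        return original_weights
```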
and we add this parametrization to our network so to add this parametrization i'm using a special 00:19:44.560 |
function of pytorch called parametrization so if you want to have more information on how it 00:19:50.000 |
works this is the link but i will briefly introduce it parametrization basically 00:19:56.000 |
allows us to replace the weights matrix of the linear layer in this case with this function 00:20:04.720 |
so every time the neural network wants to access the weights matrix it will not 00:20:10.880 |
access the weights matrix directly it will access this function and this function is 00:20:17.040 |
basically our lora parametrization so when it asks for the weights matrix it will call 00:20:23.200 |
this function giving us the original weights and we just alter the original weights by introducing 00:20:29.360 |
the b and a matrix so pytorch will keep doing its work so it will 00:20:37.040 |
just multiply the w so the weights by x but actually the weights will be the original weights 00:20:42.560 |
plus the b and a that we combined in this way according to the paper and we can easily enable 00:20:49.120 |
or disable lora in each of the layers by modifying the enabled property we can see it here so if 00:20:56.240 |
it's enabled we will use the b and a matrix if it's disabled we will only use the original weights 00:21:02.240 |
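A sketch of how the parametrization is registered on each linear layer and how the enabled flag is toggled; the helper names and the net instance are assumptions based on the description above:

```python
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # only the weight matrix is adapted, not the bias, as in the paper
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(features_in, features_out, rank=rank, alpha=lora_alpha, device=device)

# from now on, every access to layer.weight goes through LoRAParametrization.forward
for layer in [net.linear1, net.linear2, net.linear3]:
    parametrize.register_parametrization(layer, "weight", linear_layer_parameterization(layer, "cpu"))

def enable_disable_lora(enabled=True):
    # flip the flag checked inside LoRAParametrization.forward for every adapted layer
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
```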
if we enable basically it means that we enable also the fine-tuned weights if we disable it the 00:21:09.200 |
model should behave just like the pre-trained model and we can also visualize the parameters 00:21:15.280 |
added by lora so how many parameters were added well in the original layer 1 2 and 3 we only had 00:21:22.400 |
the weights and the bias now we also have the lora a matrix and the lora b matrix and i chose a rank 00:21:29.840 |
of 1 and i defined it here rank of 1 and so the matrix b is 1000 by 1 because the 00:21:41.840 |
weight matrix is 1000 by 784 so 1000 by 1 multiplied by 1 by 784 gives you the same dimension as the 00:21:51.680 |
weights matrix and we do it for all the layers so in the original model without lora we had 2 00:21:58.240 |
million 807 010 parameters by adding the lora matrices we have 2 million 813 804 parameters 00:22:08.240 |
but only the roughly 6 000 of them introduced by lora will actually be trained all the others 00:22:17.120 |
will not be trained and to do it we freeze the non-lora parameters so we can see here i created 00:22:24.960 |
the code to freeze the parameters so we just set requires grad equal false for them 00:22:30.000 |
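That freezing step is just a loop over the named parameters, skipping anything with lora in its name (again assuming the network instance is called net):

```python
# freeze everything that is not a LoRA parameter, so only lora_A and lora_B receive gradients
for name, param in net.named_parameters():
    if "lora" not in name:
        print(f"Freezing non-LoRA parameter {name}")
        param.requires_grad = False
```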
and then what we do is we fine-tune the model only on the digit 9 because originally as i show you 00:22:38.960 |
here we want to improve the accuracy of the digit 9 so we don't fine-tune it on any other thing so 00:22:46.880 |
we have a pre-trained model that was trained on all the digits but now we will train it fine-tune 00:22:52.080 |
it only on the digit 9 hoping that it will improve the accuracy of the digit 9 maybe 00:22:56.880 |
decreasing the accuracy of the other digits so let's go back here i train it i fine-tune this 00:23:04.080 |
model only on the digit 9 and i do it for only 100 batches because i don't want to 00:23:10.880 |
alter the model too much so the training is very fast and then basically i 00:23:19.680 |
want to show you that the frozen parameters are still unchanged by the fine-tuning so the frozen 00:23:25.840 |
parameters are these ones and they are still the same as the original weights that we saved after 00:23:31.680 |
pre-training our model here so here we save the original parameters we actually clone them so they 00:23:37.680 |
don't get altered and we can see that they are still the same and then what we do is we enable 00:23:47.200 |
lora and we see that the weights so when we access the weights pytorch will actually replace the 00:23:53.440 |
weights by the original weights plus b multiplied by a multiplied by the scale according to the 00:24:01.360 |
formula that we have defined here so every time pytorch tries to access the weight matrix it will 00:24:07.040 |
actually run this function and this function will return the original weights plus b multiplied by a 00:24:12.480 |
multiplied by the scale and this is what is happening here if we enable lora if we disable 00:24:18.560 |
lora we are disabling the parameterization so it will just return the original weights 00:24:23.360 |
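These checks can be written as a few assertions; original_weights here is assumed to be the dictionary of cloned tensors saved right after pre-training, keyed by parameter name:

```python
# the frozen weights must be identical to the ones cloned right after pre-training
assert torch.all(net.linear1.parametrizations.weight.original == original_weights["linear1.weight"])

p = net.linear1.parametrizations["weight"][0]      # the LoRAParametrization module of layer 1
W0 = net.linear1.parametrizations.weight.original  # the untouched pre-trained weight matrix

enable_disable_lora(enabled=True)
# with LoRA enabled, accessing .weight returns W + (B @ A) * scale
assert torch.allclose(net.linear1.weight, W0 + (p.lora_B @ p.lora_A) * p.scale)

enable_disable_lora(enabled=False)
# with LoRA disabled, .weight is exactly the original pre-trained weight again
assert torch.equal(net.linear1.weight, W0)
```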
and why does this happen because here we said that when lora is disabled just return the original 00:24:29.040 |
weights and so what we can do now is that we can enable lora and test the model and we can see that 00:24:38.160 |
now the digit 9 is performing much better but of course we lost some information about the other 00:24:43.760 |
digits and if we disable lora the model will behave exactly the same as the pre-trained model 00:24:50.160 |
so without any fine tuning and we can see these numbers are the same as the pre-trained model here 00:24:58.320 |
so the digit zero had a wrong count of 33 the wrong count for the digit 9 was 107 and it's the 00:25:06.080 |
same as this one so when we disable lora the model will behave exactly the same as the pre-trained 00:25:11.680 |
model when we enable lora we introduce the matrix b and a that make the model behave like the fine 00:25:19.040 |
tuned one and the best thing about lora is that we didn't alter the original weights and 00:25:26.880 |
the only weights that we altered are the b and a matrices and their dimension is much smaller compared 00:25:33.040 |
to the w matrix so now if we want to save this fine-tuned model we only need to save these 6794 00:25:41.360 |
numbers instead of 2 million etc we can fine-tune many versions of this model and we can easily 00:25:50.160 |
switch between them just by changing the b and the a matrices in this parameterization we don't need to 00:25:56.000 |
reload again all the w matrix of the original pre-trained model and this is the power of lora 00:26:03.280 |
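One possible way to persist and reload only the adapter (not necessarily what the repository does) is to filter the state dict by name; the file name here is hypothetical:

```python
# keep only the LoRA tensors from the state dict: a few thousand numbers instead of ~2.8 million
lora_state = {name: tensor for name, tensor in net.state_dict().items() if "lora" in name}
torch.save(lora_state, "lora_digit9.pt")              # hypothetical file name

# later: rebuild the base model, register the parametrizations, then load the adapter on top
net.load_state_dict(torch.load("lora_digit9.pt"), strict=False)
```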
uh i hope my video was clear because i try to make videos that are theoretical but also practical 00:26:09.040 |
please let me know in the comments if there is something that you want to 00:26:12.080 |
be explained a little better you can use my repository it's pytorch lora on my account 00:26:20.160 |
and you can play with it and you can try to use different sizes of the rank or you can try 00:26:28.480 |
different models it's very easy i suggest you also read about this 00:26:34.080 |
parameterization function of pytorch because it's very easy to introduce a different kind 00:26:40.000 |
of parameterization and also play with the parameterization of a neural network 00:26:43.680 |
thank you again for listening and i hope you enjoyed the video and 00:26:49.360 |
please come back to my channel for more videos about machine learning and deep learning