LoRA: Low-Rank Adaptation of Large Language Models - Explained visually + PyTorch code from scratch


Chapters

0:00 Introduction
0:47 How neural networks work
1:48 How fine tuning works
3:50 LoRA
8:58 Math intuition
10:25 Math explanation
14:05 PyTorch implementation from scratch

Whisper Transcript

00:00:00.000 | Hello guys, welcome back to my channel. Today we will be exploring a very influential paper
00:00:05.600 | called LoRA. LoRA stands for Low-Rank Adaptation of Large Language Models
00:00:11.360 | and it's a very influential paper. It came out I think two years ago from Microsoft
00:00:17.200 | and in this video we will see what LoRA is and how it works, and we will also
00:00:22.720 | implement it in PyTorch from zero without using any external libraries except for Torch of course
00:00:29.280 | and let's go. So we are in the domain of language models but actually LoRA can be applied to any
00:00:37.680 | kind of model and in fact in my demo that I will also show you later we will apply it to a very
00:00:43.920 | simple classification task and so before we study LoRA we need to understand why we need LoRA in
00:00:52.560 | the first place. So let's review some basics about neural networks. So imagine we have some input
00:00:57.680 | which could be one number or it could be a vector of numbers and then we have some hidden layer in
00:01:02.400 | a neural network which is usually represented by a matrix but here I show you the graphical
00:01:06.880 | representation and then we have another hidden layer and finally we have the output right.
00:01:12.400 | Usually when we train a network we also have a target and what we do is we compare the output
00:01:19.120 | and the target to produce a loss and finally we back propagate the loss to each of the weights
00:01:26.240 | of all the layers. So in this case for example we may have many weights in this layer we will
00:01:33.200 | have a weights matrix and a bias matrix and each of these weights will be modified by the loss
00:01:40.880 | function and also here we will have a weight and a bias matrix here. Now what is fine tuning? Fine
00:01:50.320 | tuning basically means that we have a pre-trained model and we want to fine tune it on some other
00:01:55.840 | data that the original model may have not seen. For example imagine we work for a company that
00:02:01.680 | has built its own database so this new database has its own SQL language right and we have
00:02:08.560 | downloaded a pre-trained model let's say GPT that was trained on a lot of programming languages
00:02:14.960 | but we want to fine tune it on our own SQL language so that the
00:02:23.280 | model can help our users build queries for our database and what we used to do is train this
00:02:30.400 | entire model here on new data and alter all these weights using the new data
00:02:37.520 | however this creates some problems. The problem with full fine tuning is that we must train the
00:02:43.760 | full network which first of all is computationally expensive for the average user because you need
00:02:48.960 | to load the whole model in memory then you need to run back propagation on all the weights
00:02:54.880 | plus the storage requirements for the checkpoints are expensive because for every checkpoint for
00:03:00.400 | every epoch usually we save a checkpoint and we save it on the disk plus if we save also the
00:03:08.320 | optimizer state let's say we are using the Adam optimizer which for each of the weights
00:03:13.520 | also keeps some statistics to better optimize the model so we are saving a lot of data
00:03:20.480 | and suppose we want to use the same base model but fine-tuned on two different data sets
00:03:28.080 | so we will have basically two different fine-tuned models if we need to switch between them it's very
00:03:35.680 | expensive because we need to unload the previous model and then load again all the weights of the
00:03:41.920 | other fine-tuned model so we need to replace all the weight matrices of the model however we
00:03:48.720 | have a better solution to these problems with LoRA. In LoRA there is this difference so we
00:03:54.240 | start with an input and we have our pre-trained model so we want to fine-tune it right so we have
00:04:00.880 | our pre-trained model with its weights and we freeze them basically we tell PyTorch to never
00:04:07.600 | touch these weights just use them as read only never never run back propagation on these weights
00:04:14.400 | then we create two other matrices for each of the matrices that we want to train
00:04:22.400 | so basically in LoRA we don't have to create the matrices b and a for each of the
00:04:32.560 | layers of the original model we can just do it for some layers and we will see later how
00:04:38.800 | but in this case suppose we only have one layer and we introduce the matrix b and a so what's the
00:04:45.680 | difference between this matrix b and a and the original matrix w first of all let's look at the
00:04:51.440 | dimension the original matrix was d by k suppose d is let's say 1000 and k is equal to 5000 we want
00:05:03.680 | to create two new matrices that when multiplied together they produce the same dimension so d by
00:05:11.360 | k so in fact we can see it here d by r when it's multiplied by r by k will produce a new matrix
00:05:18.480 | that is d by k because the inner dimensions cancel out and we want r to be much smaller than d or k
00:05:28.400 | we may as well choose r equal to 1 so if we choose r equal to 1 basically we will have a matrix that
00:05:36.240 | is d by 1 so 1000 by 1 and another matrix that is 1 by 5000 and if we compare the number of
00:05:45.040 | parameters in the original matrix w we have the number of parameters
00:05:50.480 | let's call it p is equal to d multiplied by k which is equal to 5 million numbers in this matrix
00:06:00.640 | in this case however we have two matrices so if r is 1 we will have one matrix that is
00:06:08.960 | d by r so 1000 numbers and one that is r by k so 5000 numbers only 6000 numbers in the two matrices combined but with the advantage
00:06:23.360 | that when we multiply them together we will still produce a matrix of d by k of course you may think
00:06:29.200 | that this matrix will not capture the same information as the original matrix w because
00:06:35.280 | it's much smaller right even if they produce the same dimension it is actually
00:06:39.920 | a smaller representation of something so you should lose some information but this is
00:06:48.160 | the whole idea behind LoRA actually the whole idea behind LoRA is that the matrix w contains a
00:06:54.960 | lot of weights a lot of numbers that are actually not meaningful for our purpose they are actually
00:07:01.600 | not adding any information to the model they are just a combination of the other weights so they
00:07:07.440 | are kind of redundant so we don't need the whole matrix w we can create a lower-rank representation of
00:07:13.600 | this w and fine-tune that one so let's continue with our journey of this model let me delete the
00:07:20.640 | link okay so we create these two matrices b and a what we do is we combine them with w because we can sum
00:07:29.840 | them right because they have the same dimension when we multiply b by a it will have the dimension
00:07:35.520 | d by k so we can sum it with the original w we produce the output and then we have our usual
00:07:42.400 | target we calculate the loss and we only back propagate the loss to the matrix that we want to
00:07:47.680 | train that is the b and a matrix so we never touch the w matrix so our original model which was the
00:07:57.440 | pre-trained model is frozen and we never touch its weights we only modify the b and a matrix
00:08:03.360 | so what are the benefits first of all as we saw before we have less parameters to train and store
00:08:08.960 | because in the case i showed before we have for example five million parameters in the w matrix
00:08:16.080 | of the original one and using r equal to five we only have thirty thousand parameters in total so
00:08:21.680 | less than one percent of the original less parameters also means that we have less storage
00:08:26.320 | requirements and faster back propagation because we don't need to evaluate the gradient for most
00:08:30.480 | of the parameters and we can easily switch between two fine-tuned models because for example imagine
00:08:36.160 | we have two different models one for SQL and one for generating JavaScript code we only need to
00:08:43.120 | reload these two matrices if we want to switch between them we don't need to reload the w matrix
00:08:48.640 | because it was never touched so it's still the same as the original pre-trained model
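
The parameter arithmetic from this part of the video can be checked with a few lines of plain Python. This is only a minimal sketch: the function names `full_params` and `lora_params` are made up for illustration, and d = 1000, k = 5000 are the example sizes used above.

```python
def full_params(d: int, k: int) -> int:
    """Number of entries in the original d x k weight matrix W."""
    return d * k


def lora_params(d: int, k: int, r: int) -> int:
    """Entries in B (d x r) plus entries in A (r x k)."""
    return d * r + r * k


d, k = 1000, 5000
for r in (1, 5):
    ratio = lora_params(d, k, r) / full_params(d, k)
    print(f"r={r}: {lora_params(d, k, r)} vs {full_params(d, k)} ({ratio:.2%})")
# r=1: 6000 vs 5000000 (0.12%)
# r=5: 30000 vs 5000000 (0.60%)
```

Both ranks stay below one percent of the original parameter count, which matches the figures quoted in the video.
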
00:08:56.720 | why does this work so the idea and it's written in the paper is that pre-trained
00:09:07.360 | models have an intrinsic dimension that is smaller
00:09:13.520 | than their actual dimension and inspired by this they hypothesize that the updates to the weights
00:09:19.920 | also have a low intrinsic rank during adaptation and the rank of a matrix we
00:09:27.120 | will see it later with a practical example basically it means imagine we have a matrix
00:09:32.080 | made of many vectors column vectors and the rank of the matrix is the number of the vectors that
00:09:40.080 | are linearly independent from each other so you cannot combine linearly any of them to produce
00:09:46.160 | another one this also indicates kind of how many columns are redundant because they can be obtained
00:09:53.520 | by linearly combining the other ones and what they mean in this paper is that the
00:10:01.360 | w matrix actually is rank deficient it means that it does not have full rank so it has a
00:10:07.440 | dimension maybe 1000 by 1000 but maybe the actual rank is let's say 10 so actually we can use two
00:10:13.680 | smaller matrices 1000 by 10 and 10 by 1000 to capture most of the information and the idea behind this rank reduction is used
00:10:20.080 | in a lot of scenarios also for example in compression algorithms so let's review some
00:10:26.720 | mathematics of rank and matrix decomposition and then we check the lora implementation in pytorch
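
To make the definition of rank concrete, here is a small sketch (not taken from the video's notebook). The 3 by 3 matrix `M` is a made-up example whose third column is the sum of the first two, so only two columns are linearly independent.

```python
import torch

# Third column = first column + second column, so it adds no new information.
M = torch.tensor(
    [[1.0, 0.0, 1.0],
     [0.0, 1.0, 1.0],
     [2.0, 3.0, 5.0]],
    dtype=torch.float64,
)

rank = torch.linalg.matrix_rank(M).item()
print(rank)  # 2: one of the three columns is redundant
```
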
00:10:36.160 | so let's switch here let's go here first so i will show you a very simple example of
00:10:43.200 | matrix decomposition and how a matrix can be rank deficient and how we can produce a smaller matrix
00:10:49.520 | that captures most of the information so let's start by importing the very simple libraries
00:10:55.760 | torch and numpy then i will create a 10 by 10 matrix here that is artificially rank deficient
00:11:03.760 | so i create it in such a way that it is rank deficient with the rank actual rank of 2 so even
00:11:08.880 | if this matrix is 10 by 10 and we can see that it has 100 numbers the rank of this matrix
00:11:16.400 | is actually 2 and we can evaluate that using numpy so we will see that the rank of this matrix
00:11:22.640 | is actually 2 this means that we can decompose it using an algorithm called svd which means
00:11:30.480 | singular value decomposition which produces three matrices u s and v that when multiplied together
00:11:38.880 | they give us w but the dimension of this u s and v can be much smaller based on the rank so
00:11:45.760 | basically it produces three matrices that if we take only the first r columns of these matrices
00:11:52.800 | where r indicates the rank of the original matrix they will capture most of the information of the
00:11:58.560 | original matrix and we can visualize that in a simple way what we do is we calculate the b and
00:12:06.000 | the a matrix just like in the lora case using this decomposition and we can see that we created the
00:12:12.560 | lower-rank representation of the w matrix which originally was 10 by 10 but now we created two
00:12:18.400 | matrices one b and one a that are 10 by 2 and 2 by 10 and what we do is we take some input let's call
00:12:26.960 | it x and some bias both random we compute the output using the w original matrix which was the
00:12:34.800 | 10 by 10 matrix so we multiply it by x we add the bias and we also compute the output using the b
00:12:42.720 | and a matrix that is the result of the decomposition so we calculate y prime using b multiplied by a
00:12:50.080 | just like lora multiplied by x plus bias and we see that the output is the same even if b and a
00:12:59.760 | actually have much less elements so in this case i renamed it i forgot to change the names this is b
00:13:06.720 | and a okay b and a and what's okay so what i want to show and this is not a proof because i actually
00:13:18.960 | created artificially this w matrix and i made it rank deficient artificially i actually took this
00:13:26.000 | code from somewhere i don't remember where and so the the idea is that we can have a smaller matrix
00:13:34.160 | that can produce the same output for the same given input but by using far fewer numbers
00:13:41.120 | so far fewer parameters as you can see the number of elements in the b
00:13:47.360 | matrix and a matrix combined is 40 while in the original matrix we had 100 elements
00:13:53.600 | and they still produce the same output for the same given input which means that
00:13:58.720 | b and a captured most of the information the most important information of w now let's go to lora
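
The decomposition demo described above can be sketched as follows. This is a reconstruction under assumptions, not the exact notebook from the video: it uses `torch.linalg.svd` and `torch.linalg.matrix_rank` (the video mentions numpy for the rank), and it works in float64 so that the rank estimate and the equality check are numerically robust.

```python
import torch

torch.manual_seed(0)
d, k, true_rank = 10, 10, 2

# Build a 10 x 10 matrix that has rank 2 by construction.
W = torch.randn(d, true_rank, dtype=torch.float64) @ torch.randn(true_rank, k, dtype=torch.float64)
r = torch.linalg.matrix_rank(W).item()

# Reduced SVD: W = U @ diag(S) @ Vh. Keep only the first r components.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
B = U[:, :r] * S[:r]  # d x r, columns of U scaled by the singular values
A = Vh[:r, :]         # r x k

# Same random input and bias for both versions.
x = torch.randn(k, dtype=torch.float64)
bias = torch.randn(d, dtype=torch.float64)
y = W @ x + bias
y_prime = (B @ A) @ x + bias

print("rank of W:", r)
print("same output:", torch.allclose(y, y_prime))
print("elements in W:", W.numel(), "elements in B and A:", B.numel() + A.numel())
```

As in the video, this is an illustration rather than a proof: W was made rank deficient on purpose, so B and A can reproduce it with 40 numbers instead of 100.
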
00:14:05.360 | so let's implement lora step by step what we will do is we will do a classification task so imagine
00:14:12.240 | we have a very simple neural network for classifying mnist digits and we want to fine tune
00:14:19.200 | it on one specific digit because we see that the performance on one specific digit is not very good
00:14:25.200 | so we want to fine tune it on only one and we will use lora and show that when we fine tune with
00:14:31.760 | lora we are actually modifying a very small number of parameters and we only need to save a very small
00:14:37.520 | number of parameters compared to the pre-trained model let's start so we import the usual libraries
00:14:44.480 | so torch and matplotlib actually we will not need it and tqdm for visualizing the progress bar
00:14:51.680 | we make it deterministic so it always returns the same results and we load mnist the data set it's
00:15:01.760 | already integrated into torchvision so it's not a big deal and we create the loader we create a
00:15:08.960 | very unoptimized neural network for classifying these digits so basically this is a very big
00:15:15.600 | network for the task we don't need such a big network but i made
00:15:20.320 | it this big on purpose because i want to show the savings in parameters that we get so i call it
00:15:28.000 | rich boy net because daddy got money so i don't care about efficiency right and it's a very simple
00:15:34.000 | network made of three linear layers with the ReLU activation and the final layer is just
00:15:40.080 | basically the classification of the digit into one of its categories 0 1 or 2 or up to 9 so we create
00:15:48.480 | this network and we train it on mnist so we run for only one epoch and we train it just simple
00:15:57.120 | training of mnist for classification and then what we do is we keep a copy of the original weights
00:16:05.600 | because we will need it later to prove that LoRA didn't modify the original weights so the
00:16:12.320 | weights of the original pre-trained model will not be altered by LoRA we can also test the
00:16:20.560 | model the pre-trained model we can test it and check what is the accuracy so if we test it we can
00:16:27.440 | see the accuracy is very high but we can see that for the digit number nine the accuracy is not as
00:16:33.280 | good as the other digits so maybe we want to fine-tune especially on the digit nine okay LoRA
00:16:39.280 | actually in the paper was used to fine-tune large language models which i cannot do because i don't
00:16:43.840 | have the computational resources so that's why i'm using mnist and this very simple example
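
A network like the one described above could look like the sketch below. The hidden sizes of 1000 and 2000 are an assumption inferred from the shapes and the parameter total quoted in the video (a 1000 by 784 first weight matrix and 2,807,010 parameters overall); the attribute names are also assumptions.

```python
import torch
import torch.nn as nn


class RichBoyNet(nn.Module):
    """Deliberately oversized MNIST classifier: three linear layers with ReLU."""

    def __init__(self, hidden_size_1: int = 1000, hidden_size_2: int = 2000):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)  # one logit per digit 0..9
        self.relu = nn.ReLU()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        x = img.view(-1, 28 * 28)  # flatten each 28 x 28 image
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)


net = RichBoyNet()
total = sum(p.numel() for p in net.parameters())
print(total)  # 2807010 with the assumed hidden sizes
```
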
00:16:49.520 | anyway so we have one digit that we want to fine-tune better right let's visualize before
00:16:55.760 | we do any fine-tuning how many parameters we have in this network that we created here
00:17:01.360 | this network here rich boy net so we have in the layer 1 we have this matrix weights and this bias
00:17:08.880 | this weights for the layer 2 and this bias this weights matrix for the layer 3 and this bias in
00:17:15.680 | total we have two million eight hundred seven thousand and ten parameters now let's introduce
00:17:22.400 | LoRA so as we saw before LoRA introduces two matrices called a and b and the size of
00:17:34.400 | these matrices is if the original weights is d by k the b is d by r and a is r by k so i just call it
00:17:49.040 | features in and features out in the paper it's written that they initialize the b matrix with
00:17:54.960 | zero and a matrix with random gaussian initialization and this is what i do here as well
00:18:01.520 | then they also introduce a scale parameter this is from the section 4.1 of the paper
00:18:05.760 | that basically allows us to change the rank without changing the scale of the
00:18:11.120 | numbers and i just keep alpha fixed because maybe you want to try the same model
00:18:20.560 | with different ranks so the scale allows us to keep the scale of the numbers the same
00:18:28.080 | if LoRA is enabled we want the weights matrix so basically we will run LoRA only on
00:18:35.440 | the weights matrix not on the bias because also in the paper they don't do it for the
00:18:39.680 | bias only on the weights so if LoRA is enabled the weights matrix will be the
00:18:48.960 | original weights plus b multiplied by a just like in the paper multiplied by the scale this is also
00:18:55.600 | introduced by the paper so basically instead of multiplying the this should be w instead of
00:19:04.000 | multiplying x by w just like in the original network we multiply it by w plus b multiplied by
00:19:12.080 | a and this is written in the paper we can see it here let's go down it's written here so instead
00:19:21.200 | of multiplying x only by w we multiply it by w plus this delta w which is how much the weights have
00:19:29.920 | moved because of the fine tuning which is b by a and this is what we are doing here
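
The description above (B initialized to zero, A initialized from a Gaussian, a scale factor, and an enabled switch) can be sketched as a small module. The class name `LoRAParametrization` and the attribute names are assumptions made for this illustration; the scale of alpha divided by rank follows section 4.1 of the paper.

```python
import torch
import torch.nn as nn


class LoRAParametrization(nn.Module):
    """Returns W + (B @ A) * scale when enabled, and W unchanged otherwise."""

    def __init__(self, features_in: int, features_out: int, rank: int = 1, alpha: float = 1.0):
        super().__init__()
        # nn.Linear stores its weight as (features_out, features_in), i.e. d x k.
        self.lora_B = nn.Parameter(torch.zeros(features_out, rank))  # zeros: delta W starts at 0
        self.lora_A = nn.Parameter(torch.randn(rank, features_in))   # random Gaussian
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A) * self.scale
        return original_weights


torch.manual_seed(0)
W = torch.randn(1000, 784)
lora = LoRAParametrization(features_in=784, features_out=1000, rank=1, alpha=1)
print(torch.equal(lora(W), W))  # True at initialization because B is all zeros
```

Because B starts at zero, the adapted weights equal the pre-trained weights before any fine-tuning step, so training starts from the behaviour of the original model.
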
00:19:36.800 | and we add this parametrization to our network so to add this parametrization i'm using a special
00:19:44.560 | function of pytorch called pytorch parametrization so if you want to have more information how it
00:19:50.000 | works this is the link but i will briefly introduce it parametrization basically
00:19:56.000 | allows us to replace the weights matrix of the linear one layer in this case with this function
00:20:04.720 | so every time the neural network wants to access the weights matrix it will not
00:20:10.880 | access the weights matrix directly it will access this function and this function is
00:20:17.040 | basically our LoRA parametrization so when it will ask for the weights matrix it will call
00:20:23.200 | this function giving us the original weights and we just alter the original weights by introducing
00:20:29.360 | the b and a matrix so pytorch will keep doing its work so it will
00:20:37.040 | just multiply the w so the weights by x but actually the weights will be the original weights
00:20:42.560 | plus the b and a that we combined in this way according to the paper and we can easily enable
00:20:49.120 | or disable lora in each of the layers by modifying the enabled property we can see it here so if
00:20:56.240 | it's enabled we will use the b and a matrix if it's disabled we will only use the original weights
00:21:02.240 | if we enable basically it means that we enable also the fine-tuned weights if we disable it the
00:21:09.200 | model should behave just like the pre-trained model and we can also visualize the parameters
00:21:15.280 | added by lora so how many parameters were added well in the original layer 1 2 and 3 we only had
00:21:22.400 | the weights and the bias now we also have the lora a matrix and the lora b matrix and i chose a rank
00:21:29.840 | of 1 and this i defined it here rank of 1 and so the matrix b is 1000 by 1 because the
00:21:41.840 | weight matrix is 1000 by 784 so 1000 by 1 multiplied by 1 by 784 gives you the same dimension of the
00:21:51.680 | weights matrix and we do it for all the layers so in the original model without lora we had
00:21:58.240 | 2,807,010 parameters by adding the lora matrices we have 2,813,804 parameters
00:22:08.240 | but only about 6000 of them so the ones introduced by lora will actually be trained all the others
00:22:17.120 | will not be trained and to do it we freeze the non-lora parameters so we can see here i created
00:22:24.960 | the code to freeze the parameters so we just set requires grad equal false for them
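
The registration and freezing steps can be sketched with `torch.nn.utils.parametrize.register_parametrization`. This is an illustration under assumptions, not the code from the video: the `LoRAParametrization` class is a hypothetical minimal version, and the network is written as an `nn.Sequential` with layer sizes 784, 1000, 2000 and 10 inferred from the counts quoted above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRAParametrization(nn.Module):
    """Hypothetical minimal LoRA parametrization: W + (B @ A) * scale."""

    def __init__(self, features_in: int, features_out: int, rank: int = 1, alpha: float = 1.0):
        super().__init__()
        self.lora_B = nn.Parameter(torch.zeros(features_out, rank))
        self.lora_A = nn.Parameter(torch.randn(rank, features_in))
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A) * self.scale
        return original_weights


net = nn.Sequential(
    nn.Linear(784, 1000), nn.ReLU(),
    nn.Linear(1000, 2000), nn.ReLU(),
    nn.Linear(2000, 10),
)

# Attach a rank-1 LoRA parametrization to the weight of every linear layer.
for layer in net:
    if isinstance(layer, nn.Linear):
        features_out, features_in = layer.weight.shape
        parametrize.register_parametrization(
            layer, "weight", LoRAParametrization(features_in, features_out, rank=1)
        )

# Freeze everything that is not a LoRA matrix (original weights and biases).
for name, param in net.named_parameters():
    if "lora" not in name:
        param.requires_grad = False

total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"total: {total}, trainable: {trainable}")
```

With these sizes the trainable count is (1000 + 784) + (2000 + 1000) + (10 + 2000) = 6794, which is the number of parameters the video says has to be saved.
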
00:22:30.000 | and then what we do is we fine-tune the model only on the digit 9 because originally as i show you
00:22:38.960 | here we want to improve the accuracy of the digit 9 so we don't fine-tune it on any other thing so
00:22:46.880 | we have a pre-trained model that was trained on all the digits but now we will train it fine-tune
00:22:52.080 | it only on the digit 9 hoping that it will improve the accuracy of the digit 9 maybe
00:22:56.880 | decreasing the accuracy of the other digits so let's go back here i train it i fine-tune this
00:23:04.080 | model only on the digit 9 and i do it for only 100 batches because i don't want to
00:23:10.880 | alter the model too much so the training is very fast and then basically i
00:23:19.680 | want to show you that the frozen parameters are still unchanged by the fine-tuning so the frozen
00:23:25.840 | parameters are this one and they are still the same as the original weights that we saved after
00:23:31.680 | pre-training our model here so here we save the original parameters we actually clone them so they
00:23:37.680 | don't get altered and we can see that they are still the same and then what we do is we enable
00:23:47.200 | lora and we see that the weights so when we access the weights pytorch will actually replace the
00:23:53.440 | weights by the original weights plus b multiplied by a multiplied by the scale according to the
00:24:01.360 | formula that we have defined here so every time pytorch tries to access the weight matrix it will
00:24:07.040 | actually run this function and this function will return the original weights plus b multiplied by a
00:24:12.480 | multiplied by the scale and this is what is happening here if we enable lora if we disable
00:24:18.560 | lora we are disabling the parameterization so it will just return the original weights
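
The effect of the enabled flag can be checked on a single small layer. This is a sketch with made-up sizes (8 inputs, 4 outputs) and a hypothetical `LoRAParametrization` class; filling B with 0.5 only simulates what fine-tuning would do.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRAParametrization(nn.Module):
    """Hypothetical minimal LoRA parametrization: W + (B @ A) * scale."""

    def __init__(self, features_in: int, features_out: int, rank: int = 1, alpha: float = 1.0):
        super().__init__()
        self.lora_B = nn.Parameter(torch.zeros(features_out, rank))
        self.lora_A = nn.Parameter(torch.randn(rank, features_in))
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A) * self.scale
        return original_weights


torch.manual_seed(0)
layer = nn.Linear(8, 4)
original_weight = layer.weight.detach().clone()  # copy taken before adding LoRA

parametrize.register_parametrization(layer, "weight", LoRAParametrization(8, 4, rank=1))
lora = layer.parametrizations.weight[0]

# Pretend that fine-tuning moved B away from zero.
with torch.no_grad():
    lora.lora_B.fill_(0.5)

lora.enabled = True
w_on = layer.weight.detach().clone()   # original weights + B @ A * scale
lora.enabled = False
w_off = layer.weight.detach().clone()  # original weights only

print("disabled equals pre-trained:", torch.equal(w_off, original_weight))
print("enabled equals pre-trained:", torch.equal(w_on, original_weight))
```

Every access to `layer.weight` goes through the parametrization, so flipping the flag is enough to move between the pre-trained and the fine-tuned behaviour.
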
00:24:23.360 | and why does this happen because here we said that when lora is disabled just return the original
00:24:29.040 | weights and so what we can do now is that we can enable lora and test the model and we can see that
00:24:38.160 | now the digit 9 is performing much better but of course we lost some information about the other
00:24:43.760 | digits and if we disable lora the model will behave exactly the same as the pre-trained model
00:24:50.160 | so without any fine tuning and we can see these numbers are the same as the pre-trained model here
00:24:58.320 | so the digit zero had a wrong count of 33 the wrong count for the digit 9 was 107 and it's the
00:25:06.080 | same as this one so when we disable lora the model will behave exactly the same as the pre-trained
00:25:11.680 | model when we enable lora we introduce the matrix b and a that make the model behave like the fine
00:25:19.040 | tuned one and the best the best thing about lora is that we didn't alter the original weights and
00:25:26.880 | the only weights that we altered are the b and a matrix and their dimension is much smaller compared
00:25:33.040 | to the w matrix so now if we want to save this fine-tuned model we only need to save these 6794
00:25:41.360 | numbers instead of 2 million etc we can fine-tune many versions of this model and we can easily
00:25:50.160 | switch between them just by changing the b and a matrices in this parameterization we don't need to
00:25:56.000 | reload again all the w matrix of the original pre-trained model and this is the power of lora
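
One way to save and swap only the LoRA matrices is to filter the state dict, as sketched below. This is not shown in the video: the helper `lora_state`, the two adapters, and the single 784 to 1000 layer are assumptions made for illustration, and the second adapter is faked by shifting B instead of actually training.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRAParametrization(nn.Module):
    """Hypothetical minimal LoRA parametrization: W + (B @ A) * scale."""

    def __init__(self, features_in: int, features_out: int, rank: int = 1, alpha: float = 1.0):
        super().__init__()
        self.lora_B = nn.Parameter(torch.zeros(features_out, rank))
        self.lora_A = nn.Parameter(torch.randn(rank, features_in))
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A) * self.scale
        return original_weights


def lora_state(module: nn.Module) -> dict:
    """Copy only the LoRA tensors out of the module's state dict."""
    return {k: v.clone() for k, v in module.state_dict().items() if "lora" in k}


torch.manual_seed(0)
layer = nn.Linear(784, 1000)
parametrize.register_parametrization(layer, "weight", LoRAParametrization(784, 1000, rank=1))
lora = layer.parametrizations.weight[0]

# Two stand-in adapters; in practice each would come from a separate fine-tuning run.
adapter_a = lora_state(layer)
with torch.no_grad():
    lora.lora_B.add_(1.0)
adapter_b = lora_state(layer)

# Switching adapters loads only the small tensors; strict=False ignores the missing keys.
layer.load_state_dict(adapter_a, strict=False)

saved = sum(v.numel() for v in adapter_a.values())
frozen = layer.parametrizations.weight.original.numel()
print(f"{saved} numbers per adapter, {frozen} numbers in the frozen weight matrix")
```

The frozen weight matrix stays in memory the whole time, so switching between fine-tuned versions only touches the B and A tensors.
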
00:26:03.280 | uh i hope my video was clear because i try to make videos that are theoretical but also practical
00:26:09.040 | please let me know in the comments if there is something that you want to
00:26:12.080 | be explained a little better you can use my repository it's pytorch lora on my account
00:26:20.160 | and you can play with it and you can try to use different sizes of the rank or you can try
00:26:28.480 | different models it's very easy i suggest you also read about this
00:26:34.080 | parameterization function of pytorch because it's very easy to introduce a different kind
00:26:40.000 | of parameterization and also play with the parameterization of a neural network
00:26:43.680 | thank you again for listening and i hope you enjoyed the video and
00:26:49.360 | please come back to my channel for more videos about machine learning and deep learning