
Fine-tune High Performance Sentence Transformers (with Multiple Negatives Ranking)


Chapters

0:00 Intro
1:02 NLI Training Data
2:56 Preprocessing
10:11 SBERT Finetuning Visuals
14:14 MNR Loss Visual
16:37 MNR in PyTorch
23:04 MNR in Sentence Transformers
34:20 Results
36:14 Outro

Whisper Transcript

00:00:00.000 | Hi, welcome to the video.
00:00:01.760 | We're going to be having a look today
00:00:03.560 | at something called a multiple negatives ranking
00:00:07.000 | loss, or MNR loss, for training sentence transformers.
00:00:11.880 | Now, if you're new to sentence transformers,
00:00:15.080 | they're essentially NLP models using
00:00:18.440 | transformers that have been trained specifically
00:00:20.480 | for building dense vector representations of text.
00:00:26.040 | And the highest performing sentence transformers
00:00:30.040 | at the moment are all trained using this method.
00:00:32.600 | So they're trained on a natural language inference data set,
00:00:36.600 | and they are trained with MNR loss.
00:00:40.640 | So in this video, we are going to learn about MNR loss,
00:00:44.560 | how it works.
00:00:46.120 | We're going to work through a PyTorch example very quickly,
00:00:48.720 | and then we're going to look at how we can implement it
00:00:51.080 | ourselves using the sentence transformers library.
00:00:54.400 | So let's jump straight into it.
00:00:56.240 | So if you saw our previous article and video,
00:01:07.440 | you'll remember that we had an NLI dataset, which
00:01:12.080 | is built from the Stanford NLI data set and the multi-genre NLI
00:01:17.360 | data set.
00:01:17.880 | We're using the same data set here.
00:01:20.280 | So in essence, all it is is a load of sentence pairs,
00:01:25.000 | and there is a label 0, 1, or 2, which
00:01:28.680 | indicates whether those sentence pairs either infer each other,
00:01:34.240 | are not really related, like they could both be true,
00:01:38.160 | but not necessarily because of each other,
00:01:41.160 | or if they contradict each other.
00:01:43.800 | So there are the three labels.
00:01:45.640 | And what we covered in the previous video and article
00:01:49.440 | was something called softmax loss.
00:01:51.080 | With softmax loss, we use those labels, the 0, 1, or 2,
00:01:57.400 | to produce a classification.
00:02:00.760 | We optimize on that label.
00:02:04.760 | Now, with MNR loss, we don't actually use those labels.
00:02:09.280 | We just use the sentence pairs.
00:02:10.880 | And what I'm going to show you is
00:02:14.160 | where we use something called a Siamese network, which
00:02:17.680 | means we only have an anchor and a positive sentence.
00:02:22.400 | Now, what that means is an anchor is a--
00:02:27.320 | you can think of it as like a base sentence.
00:02:29.720 | And a positive to that anchor sentence
00:02:34.200 | is just a sentence that indicates
00:02:37.520 | that the anchor is true.
00:02:40.440 | Now, we could also have negatives.
00:02:42.320 | And a negative would indicate that that anchor is not true.
00:02:46.840 | And we can extract that from the NLI dataset.
00:02:49.960 | So let's start pre-processing our data
00:02:54.080 | and have a look at what we need to do.
00:02:57.480 | So this code here, we already wrote it
00:03:01.160 | in the previous video and article.
00:03:03.000 | But we're going to go through it.
00:03:04.400 | So if you've never seen any of this before, it's fine.
00:03:06.960 | We're going to go through it.
00:03:08.120 | I'm going to explain everything.
00:03:09.460 | It's not a problem.
00:03:10.760 | So basically, up here, all I'm saying
00:03:13.680 | is what I just told you.
00:03:15.520 | So these are our labels in our data.
00:03:18.280 | We have this 0, entailment, 1, neutral, and 2, contradiction.
00:03:23.760 | And then I'm just saying, for MNR,
00:03:25.520 | we don't actually need those labels.
00:03:27.760 | All we need are anchor-positive pairs,
00:03:30.040 | e.g. a pair where the premise suggests the hypothesis,
00:03:34.800 | so where the positive indicates the anchor or the other way
00:03:42.160 | around.
00:03:45.000 | So essentially, what we need are just rows
00:03:47.960 | where we have the label 0.
00:03:50.400 | So we're going to do that.
00:03:51.600 | But first, we need to actually get our data.
00:03:53.440 | So we're using the Hugging Face datasets library here,
00:03:57.000 | which is very good.
00:03:59.000 | And we're getting the Stanford Natural Language Inference
00:04:02.640 | dataset here.
00:04:04.200 | Now, down here, this format you can see
00:04:08.640 | is the format of the dataset.
00:04:10.840 | So we have these three features, premise, hypothesis,
00:04:13.960 | and label.
00:04:15.160 | And if we come down here, we can see what one of those
00:04:17.440 | looks like.
00:04:17.940 | So we have premise, a person on a horse
00:04:20.640 | jumped over a broken-down airplane,
00:04:22.680 | and this hypothesis, a person is training
00:04:24.840 | his horse for a competition.
00:04:26.480 | And the label for that is 1.
00:04:28.440 | Come up here, 1 means neutral.
00:04:29.880 | So basically, this here, a person
00:04:33.760 | is training his horse for a competition,
00:04:35.440 | does not necessarily mean the person on a horse
00:04:38.200 | is jumping over a broken-down airplane.
00:04:41.480 | And then we come down here, and I
00:04:43.120 | think this one's entailment.
00:04:45.160 | So this one is a pair that we want.
00:04:47.240 | This is an anchor positive pair.
00:04:50.280 | And we have a person's outdoors on a horse,
00:04:53.200 | and a person on a horse jumps over a broken-down airplane.
00:04:55.880 | If they're jumping over a broken-down airplane,
00:04:57.640 | they're probably outside.
00:04:58.680 | In fact, they-- well, almost definitely outside.
00:05:02.440 | So this indicates it.
00:05:06.560 | So this is an entailment, means it's an anchor positive pair.
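
For reference, a minimal sketch of that loading step, assuming the standard Hugging Face `datasets` API and the public `snli` dataset ID:

```python
from datasets import load_dataset

# Stanford NLI training split; features are premise, hypothesis, label
snli = load_dataset('snli', split='train')
print(snli[0])  # {'premise': '...', 'hypothesis': '...', 'label': 1}
```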
00:05:11.280 | So I said before, we have two data sets, not just one.
00:05:16.320 | So we have the Stanford Natural Language Inference Dataset,
00:05:20.360 | and we also have the Multi-Genre Natural Language
00:05:23.400 | Inference Dataset.
00:05:24.160 | And that's what we're getting here, MNLI.
00:05:26.800 | So I'm just loading that with load_dataset here.
00:05:30.800 | We're loading it from the GLUE dataset.
00:05:33.480 | Within GLUE, there is this MNLI subset, which is the data that we want.
00:05:37.440 | And we're just getting the training data from that.
00:05:39.820 | We don't-- because I think there's also--
00:05:42.040 | maybe there's validation data, and there's also test data.
00:05:45.080 | We don't want that.
00:05:46.080 | We just want the training data.
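
A minimal sketch of that step, assuming the `glue` / `mnli` dataset IDs on the Hugging Face Hub:

```python
from datasets import load_dataset

# MNLI is distributed as part of the GLUE benchmark; take only the training split
mnli = load_dataset('glue', 'mnli', split='train')
```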
00:05:48.520 | So we can see the format for this data set,
00:05:53.120 | almost exactly the same.
00:05:54.760 | We just have this extra ID.
00:05:59.000 | So what we can do is we just remove those columns.
00:06:02.080 | So we call mnli.remove_columns, and we specify
00:06:05.800 | that we don't want the idx column,
00:06:07.640 | because we're going to merge these two data sets.
00:06:09.680 | But to merge them, they both need
00:06:11.700 | to have the exact same format.
00:06:13.020 | So after that, what we do is we perform this Cast function.
00:06:20.660 | Now, if we didn't perform this Cast function--
00:06:23.940 | let me show you.
00:06:24.620 | So I need to run these anyway, so start running them now.
00:06:29.420 | Come down here.
00:06:31.380 | I'm going to run load_dataset and remove_columns.
00:06:34.700 | And OK, let's say I'm not going to do this Cast.
00:06:38.180 | I'm just going to concatenate the data sets.
00:06:40.060 | I'm going to merge those two data sets together.
00:06:43.500 | I do that, and we get this ArrowInvalid error.
00:06:47.440 | And the reason for that is because although those two
00:06:50.220 | data sets look very similar, they're not actually
00:06:53.180 | exactly the same.
00:06:54.780 | The data types used in one of them
00:06:57.740 | include this not-null specification,
00:07:01.700 | whereas the other one does not include that
00:07:04.500 | specification.
00:07:05.180 | So I assume that means that this other data set
00:07:09.020 | can include null values, whereas this one cannot.
00:07:12.020 | So we can't merge them both because they're not
00:07:15.460 | the same data type.
00:07:17.180 | So what we do, we come up here, and we have to do this Cast.
00:07:20.340 | So we're casting the features of the SNLI dataset
00:07:25.260 | to be the same as the MNLI dataset features.
00:07:28.220 | So we do that, run it.
00:07:29.900 | We come down here, we can run it again, and it will work.
00:07:35.740 | And now in dataset, we have the full data.
00:07:40.940 | That's both MNLI and SNLI.
00:07:44.660 | And you can see that because we have 943,000 rows.
00:07:49.940 | If you come up here, we only have 392,000 in the MNLI dataset.
00:07:54.260 | And up here, we have 550,000 in the SNLI dataset.
00:07:59.220 | So we have all the data.
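
Putting those pieces together, a sketch of the align-and-merge step (variable names carried over from the snippets above):

```python
from datasets import concatenate_datasets

# drop the extra column so both datasets share the same columns
mnli = mnli.remove_columns(['idx'])

# cast SNLI's feature types to match MNLI's; without this,
# concatenation fails with an ArrowInvalid error because the schemas differ
snli = snli.cast(mnli.features)

dataset = concatenate_datasets([snli, mnli])
print(dataset.num_rows)  # roughly 943,000
```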
00:08:03.780 | And now, in the previous video and article, I mentioned
00:08:09.900 | we have these negative 1 labels.
00:08:14.940 | Now, this is an error in the data.
00:08:18.620 | Or it's not an error, but it's where whoever labeled the data,
00:08:23.180 | they couldn't decide whether this was--
00:08:26.980 | they couldn't decide on the nature of the relationship
00:08:30.060 | between the pair of sentences.
00:08:32.100 | So they just put minus 1.
00:08:33.380 | There's not very few.
00:08:34.500 | I think it's 700 or so sentences in there or pairs in there
00:08:39.740 | that are labeled with this minus 1.
00:08:41.620 | But that's not a label.
00:08:43.980 | We can't do anything with that in our data.
00:08:48.060 | Now, what we do is we use this filter method to remove.
00:08:53.500 | So we say false for the row.
00:08:55.700 | So the row is false if the label is
00:08:58.500 | equal to minus 1, which is saying that row is false,
00:09:02.020 | i.e. we do not keep it if its label value is equal to minus 1.
00:09:07.620 | Otherwise, we do keep it.
00:09:09.940 | Now, things are a bit different now
00:09:12.100 | because we actually only want anchor positive pairs, which
00:09:17.220 | means we want to remove--
00:09:19.660 | or we only want to keep the rows which have a 0 label.
00:09:24.820 | We want to remove everything else.
00:09:27.220 | So we need to modify this to just keep the 0 values.
00:09:35.060 | So we'll say, OK, false if x label is not equal to 0.
00:09:44.660 | Else it's true.
00:09:45.660 | So this is going to keep only 0 values
00:09:48.420 | and remove everything else.
00:09:50.220 | So that removes those minus 1 labeled rows
00:09:52.980 | and also removes the neutral and contradiction rows as well.
00:09:57.780 | Now, we remove that and let that run.
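
A one-line sketch of that filter, matching the logic just described:

```python
# keep only entailment pairs (label 0); this also drops the -1 rows
dataset = dataset.filter(lambda x: x['label'] == 0)
```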
00:10:05.020 | Now, while that's running, let me
00:10:07.020 | show you some visuals of how this will work.
00:10:11.900 | So what we have here is we have that anchor positive pair
00:10:18.580 | from before.
00:10:19.420 | So here, our anchor would be the premise.
00:10:25.580 | And the positive would be our hypothesis.
00:10:29.540 | But it would obviously only be rows
00:10:31.300 | where the label is equal to 0, which
00:10:38.380 | is the entailment label.
00:10:41.700 | Now, we have our anchor and positive.
00:10:44.980 | And as we usually would with a transform model,
00:10:48.740 | we tokenize them.
00:10:50.300 | So we tokenize them.
00:10:52.740 | We're going to do that with just a tokenizer method,
00:10:55.700 | a pre-trained tokenizer from the base transformers library.
00:11:01.100 | Or we do that if we're using PyTorch.
00:11:04.780 | If we're using the sentence transformers library,
00:11:07.700 | it's a lot easier, and we don't actually need to do that.
00:11:10.060 | It will deal with that for us, which is quite nice.
00:11:13.060 | And what that produces is just a tokenized version
00:11:19.700 | of the anchor and the positive.
00:11:22.140 | So here, we have our--
00:11:23.700 | so I'm just going to put A for anchor, and over here,
00:11:27.340 | P for positive.
00:11:29.140 | And then what happens next is we have a single BERT model.
00:11:36.140 | We actually visualize it as two BERT models,
00:11:39.100 | because we're processing the anchor.
00:11:42.020 | And then after we've processed the anchor data,
00:11:45.420 | we move on to processing the positive data.
00:11:49.060 | So it's like we're using-- for every single training step,
00:11:51.940 | we're using the same BERT model twice.
00:11:55.380 | So we process both of those through our BERT model,
00:11:59.380 | and that produces token embeddings.
00:12:01.460 | So the token embeddings are 512 dense vectors,
00:12:06.900 | each containing 768 dimensions.
00:12:10.060 | And then we use something called a mean pooling function.
00:12:12.940 | So mean pooling function is, say we have some vectors here--
00:12:19.420 | 1, 2, and 3.
00:12:23.300 | What we're going to do is take the mean value
00:12:30.220 | across each dimension.
00:12:32.180 | So let's say we have three dimensions here in our--
00:12:36.260 | no, that's a bad idea, because we have three vectors.
00:12:38.980 | Let's say we have five dimensions here.
00:12:43.780 | And we take the average across each of those dimensions.
00:12:46.660 | And what we produce from that mean pooling operation
00:12:51.500 | is a single five-dimensional vector.
00:12:54.220 | So we have the mean of each of those--
00:13:00.940 | across each of those dimensions.
00:13:03.060 | That's the mean pooling operation.
00:13:05.380 | And obviously, from that, we produce
00:13:07.300 | what is our sentence embedding, which
00:13:08.980 | is a single one-dimensional dense vector.
00:13:13.700 | And we produce that sentence embedding both for A, our anchor,
00:13:17.940 | and for P, our positive.
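
As a tiny numeric illustration of that pooling step (the values are made up for the example):

```python
import torch

# three token embeddings, each with five dimensions
tokens = torch.tensor([[1., 2., 3., 4., 5.],
                       [3., 4., 5., 6., 7.],
                       [5., 6., 7., 8., 9.]])

# the mean across the token axis gives one five-dimensional sentence embedding
sentence_embedding = tokens.mean(dim=0)  # tensor([3., 4., 5., 6., 7.])
```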
00:13:20.260 | And what we have here--
00:13:27.140 | I don't know if I mentioned it already,
00:13:30.860 | but this double network is a Siamese network.
00:13:35.860 | And this triple network is a triplet network.
00:13:39.900 | And it works in the exact same way,
00:13:41.860 | but we also have a negative.
00:13:44.660 | So where before we had the contradiction label,
00:13:50.740 | the label is 2,
00:13:52.500 | as we saw earlier
00:13:53.900 | in the label list.
00:13:57.340 | That would be a negative sentence
00:14:01.260 | because it contradicts the anchor.
00:14:05.060 | And what we do is we process that as well.
00:14:08.060 | And then we would also get a negative vector--
00:14:11.900 | negative sentence embedding at the end there.
00:14:14.460 | And what we do with that during training--
00:14:17.300 | so the whole MNR loss thing, multiple negative ranking
00:14:21.860 | thing--
00:14:23.500 | is we take all of those vectors that we produce,
00:14:27.540 | the A, which is the anchors, the P, and the N
00:14:32.260 | for the positive and negatives, if we're using negatives.
00:14:36.260 | If not, we just basically blank out the dark blue part
00:14:40.980 | of this visual.
00:14:43.340 | All we do is we calculate the cosine similarity
00:14:48.100 | between our anchor and our positive,
00:14:51.260 | and maybe our negative as well.
00:14:53.820 | But I'm just going to say I'm going
00:14:55.540 | to go through the anchor and positive version of it for now.
00:15:00.220 | So we calculate the cosine similarity
00:15:02.340 | between the anchor and positive.
00:15:04.540 | And we do that for every anchor.
00:15:07.980 | So anchor 0 for the first one, we
00:15:11.620 | would calculate the cosine similarity
00:15:13.180 | between anchor 0 and positive 0, 1, 2, 3, and 4, and so on,
00:15:20.180 | up until the batch size.
00:15:22.260 | So what we have there is the actual--
00:15:24.460 | so the number of sentence embeddings that we have there
00:15:29.020 | is equal to the batch size that we are using.
00:15:32.380 | And obviously, what we would expect
00:15:34.700 | with our anchor and positive pair
00:15:36.660 | is that we would expect the cosine similarity between A0
00:15:40.780 | and P0 to be greater than the cosine similarity between,
00:15:45.100 | let's say, A0 and P1, or P2, or P3, or P4.
00:15:50.740 | And likewise, if we had A3, we would
00:15:54.500 | expect the cosine similarity between A3 and P3
00:15:58.420 | to be greater than that between A3 and P0, or P1, or P2, or P4.
00:16:08.060 | So that's how we optimize.
00:16:09.620 | We say, OK, the target label for A0
00:16:14.900 | is just the index where the cosine similarity is greatest,
00:16:19.740 | which should be with the pair P0.
00:16:24.340 | And for A3, it's going to be with P3.
00:16:28.540 | So the labels for this are actually just 0, 1, 2, 3, 4,
00:16:33.100 | up until the batch size, which we will see.
00:16:35.980 | I'm going to show you that.
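
Written out, the objective being described is, as a sketch (with $s$ for the scale factor and $\cos$ for cosine similarity, over a batch of $N$ anchor-positive pairs):

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}
\log \frac{\exp\big(s \cdot \cos(a_i, p_i)\big)}{\sum_{j=1}^{N} \exp\big(s \cdot \cos(a_i, p_j)\big)}
$$

That is just categorical cross-entropy over each row of the similarity matrix, with the target index for row $i$ equal to $i$.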
00:16:38.620 | OK, so here we have our more involved MNR loss training
00:16:46.220 | notebook.
00:16:47.180 | So I'm just going to take you through this.
00:16:48.980 | We're just using PyTorch rather than the Sentence Transformers
00:16:51.740 | library here.
00:16:53.020 | Now, I'm going to go through, and we're
00:16:55.740 | going to come to where we actually start training.
00:17:01.940 | So we have a mean pooling operation here.
00:17:05.660 | So I mentioned before we had that mean pooling function
00:17:08.140 | where we're getting the average across dimensions.
00:17:10.380 | That's what this is dealing with.
00:17:11.940 | The reason it looks more complicated
00:17:13.380 | than just taking the average is because we
00:17:15.100 | need to consider where we have padding values, where
00:17:19.140 | the mask value is 0, because we don't
00:17:24.340 | want to consider those in our average function,
00:17:26.700 | because then it's going to obviously bring down
00:17:28.780 | our average a lot just for having more padding tokens,
00:17:32.620 | which doesn't make sense, obviously.
00:17:34.820 | So that's why it looks more complicated than you
00:17:37.860 | probably expect for an averaging function.
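
A sketch of that mask-aware mean pooling written as a standalone function (the names are illustrative):

```python
import torch

def mean_pool(token_embeds, attention_mask):
    # expand the mask to the embedding dimension: (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    # sum only the real (non-padding) token embeddings...
    summed = torch.sum(token_embeds * mask, dim=1)
    # ...and divide by the count of real tokens, not the full sequence length
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
```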
00:17:42.140 | And then so we're using PyTorch here.
00:17:44.300 | So here we set up our model.
00:17:46.620 | It's just a BERT model that we're using here,
00:17:49.020 | plain, straight BERT model, nothing special about it.
00:17:53.580 | I think it's bert-base-uncased.
00:17:56.700 | And just moving it to a CUDA GPU if we have one.
00:18:01.020 | So it is available, checking if it's available.
00:18:04.940 | And then here we're defining some layers
00:18:07.740 | that we're going to be using in MNR loss.
00:18:10.740 | So we have the cosine similarity.
00:18:12.420 | I said before we're doing that.
00:18:14.020 | We're checking and calculating similarity between pairs.
00:18:17.420 | That's what we're doing here, just initializing that function.
00:18:21.020 | And we also have a loss function, of course.
00:18:23.220 | So here we're using a categorical cross-entropy loss.
00:18:27.540 | And we'll see how that works.
00:18:29.820 | We also use a scale value.
00:18:32.460 | I do use it a bit later on.
00:18:37.460 | But it's fine.
00:18:40.020 | It's not really too important.
00:18:42.860 | But we just use that.
00:18:44.180 | We multiply our similarity score value by the scale value
00:18:51.040 | later on.
00:18:52.980 | Down here, we're using the transformers optimization utilities.
00:18:56.060 | So what we're doing here is we're
00:18:59.340 | saying, for the first 10% of the training data,
00:19:03.940 | I want to warm up to this learning rate of 2e-5.
00:19:08.780 | So we're not going to start at that learning rate initially.
00:19:12.060 | We're going to slowly build up to it.
00:19:13.900 | And yeah, that's all we're doing there.
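
A sketch of that warm-up setup, assuming an optimizer named `optim` (e.g. AdamW at 2e-5) and one epoch over the training data loader:

```python
from transformers import get_linear_schedule_with_warmup

total_steps = len(loader)  # one epoch
scheduler = get_linear_schedule_with_warmup(
    optim,
    num_warmup_steps=int(0.1 * total_steps),  # warm up over the first 10% of steps
    num_training_steps=total_steps,
)
```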
00:19:19.480 | And we can see our batch here.
00:19:24.540 | So we have attention mask.
00:19:28.100 | So sorry, we have the anchor.
00:19:29.540 | And in the anchor, we have our attention mask
00:19:31.420 | and we have our input IDs.
00:19:33.780 | And then in positive, we also have the same.
00:19:36.220 | We have attention mask and input IDs.
00:19:38.740 | So input IDs and attention mask, if you
00:19:40.340 | use the Transformers library before,
00:19:42.140 | you probably recognize these.
00:19:43.780 | It's just the input tensors that we
00:19:46.140 | use when we're feeding text into a model.
00:19:49.820 | So that's how our Transformer understands our text.
00:19:57.540 | Yeah, I'll just mention here what is in there.
00:20:01.300 | And let's come down.
00:20:02.740 | Is there anything important?
00:20:04.020 | I don't think so.
00:20:05.820 | OK, let's come down to the training loop.
00:20:07.860 | So yeah, using a scale value of 20 there.
00:20:12.980 | And we come to here.
00:20:14.580 | So here, we have our anchor embeddings.
00:20:19.540 | So these are sentence embeddings.
00:20:21.420 | So we have the anchor sentence embeddings
00:20:23.500 | and we have positive sentence embeddings, which
00:20:25.900 | we've output from our BERT model.
00:20:28.380 | And what we do, we do the mean pooling.
00:20:31.340 | So when I said sentence embeddings here,
00:20:33.700 | they are the token embeddings, sorry.
00:20:35.340 | So we have 512 token embeddings.
00:20:37.860 | And then we do the mean pooling to get the sentence embeddings.
00:20:42.500 | And then what we do is we calculate
00:20:44.940 | the cosine similarity between each of our anchors
00:20:49.660 | and all of the positive values.
00:20:52.860 | So we create that array of values of cosine similarities.
00:21:00.260 | And on each row, obviously, we would
00:21:01.900 | expect the true pair to have the highest cosine similarity.
00:21:05.940 | So that is what we will do a little further down here.
00:21:10.940 | So labels, what this value does here
00:21:14.940 | is just outputs a tensor, which is 0, 1, 2, 3, 4,
00:21:20.220 | up until the batch size of the data.
00:21:25.340 | So that's where you'd expect the argmax value to be.
00:21:31.100 | So for A3, you'd expect the maximum cosine similarity
00:21:36.500 | to be in the third index, where P3, where
00:21:39.980 | it's been compared to P3.
00:21:41.100 | And then here, we are calculating the loss.
00:21:48.620 | So we're calculating between the scores, which
00:21:52.820 | we have up here, and the labels.
00:21:55.340 | So this is taking the--
00:21:58.540 | we're basically looking for this value
00:22:00.700 | here to be the maximum value in a specific row at the index
00:22:07.920 | equal to the current label.
00:22:10.360 | So that's the A3, P3 pair that I'm talking about.
00:22:14.940 | Now we're just optimizing on that.
00:22:16.460 | And that's it.
00:22:17.620 | That's all there is to MNR loss in PyTorch.
00:22:24.140 | But when I say that's all, it's actually quite a lot.
00:22:26.780 | I mean, there's a lot of code going into this.
00:22:28.860 | And it's very confusing.
00:22:32.140 | Oh, here is the--
00:22:33.620 | that's the labels tensor that I just mentioned, by the way.
00:22:36.220 | So you can see it's just counting up
00:22:37.780 | to our batch size, which is 32.
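
Condensed, the core of that training step looks something like this (a sketch; `a` and `p` stand for the mean-pooled anchor and positive embeddings of the current batch):

```python
import torch

cos_sim = torch.nn.CosineSimilarity(dim=-1)
loss_func = torch.nn.CrossEntropyLoss()
scale = 20.0

# (batch, batch) matrix: every anchor scored against every positive
scores = cos_sim(a.unsqueeze(1), p.unsqueeze(0)) * scale

# the true pair for anchor i sits at column i, so the labels just count up
labels = torch.arange(scores.size(0), device=scores.device)

loss = loss_func(scores, labels)
loss.backward()
```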
00:22:40.460 | So that's the PyTorch implementation.
00:22:47.260 | But it's complicated.
00:22:49.780 | And to be honest, if you do the same with sentence
00:22:53.760 | transformers, you're probably going to get better results.
00:22:56.420 | So I'm going to show you how to do that with sentence
00:23:00.620 | transformers.
00:23:01.740 | It's a lot easier and a lot more effective.
00:23:05.540 | So in sentence transformers, so we're back
00:23:08.620 | in the first notebook now.
00:23:11.380 | So we have our data set, 314 rows.
00:23:14.940 | And it is just anchor positive pairs.
00:23:18.740 | So with sentence transformers, we
00:23:22.140 | use something called an input example.
00:23:24.660 | So we have a big list of input examples,
00:23:27.220 | which is the data format that our sentence transformers
00:23:33.460 | library training methods would expect.
00:23:37.300 | So we want to do from sentence transformers import input
00:23:48.060 | example.
00:23:49.020 | And what we'll do is we'll just initialize the list here,
00:23:52.180 | so samples.
00:23:53.260 | And this is very simple.
00:23:55.580 | We're just going to go for sample in--
00:24:00.660 | or for row, let's say, for row in data set.
00:24:06.380 | We want to say samples.append.
00:24:11.580 | And in here, we have our input example.
00:24:14.860 | So the thing we just import, the object,
00:24:18.820 | the special sentence transformers object.
00:24:21.500 | And in the previous video and article,
00:24:24.100 | this accepts two different parameters.
00:24:27.620 | It accepts the text and label.
00:24:30.180 | Now, like I said, we don't have labels.
00:24:32.460 | We just generate them as we're performing the training.
00:24:36.580 | So we don't need the label here, so we just write text.
00:24:41.340 | And in here, we-- so list.
00:24:44.780 | And in there, we want our row.
00:24:47.620 | We want the premise, which is our anchor.
00:24:52.780 | And we also want the hypothesis, which is our positive.
00:25:00.180 | OK, and that's all we need, OK?
00:25:04.700 | So that will take a while.
00:25:10.900 | So what you can do is from tqdm.auto
00:25:18.260 | import tqdm.
00:25:20.580 | So we can add this to our loop here
00:25:23.060 | so that we have a nice progress bar,
00:25:25.140 | so we can see how far along we are,
00:25:27.700 | how long it's going to take.
00:25:30.860 | And I mean, it's very quick anyway,
00:25:33.060 | but I think it's nice to be able to see that, especially
00:25:35.460 | for the longer data sets.
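
Put together, that preprocessing loop is something like this (a sketch over the `dataset` built earlier):

```python
from sentence_transformers import InputExample
from tqdm.auto import tqdm

samples = []
for row in tqdm(dataset):
    # no label needed for MNR; just the anchor-positive text pair
    samples.append(InputExample(
        texts=[row['premise'], row['hypothesis']]
    ))
```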
00:25:37.580 | Now, here we need to initialize what is a data loader.
00:25:42.620 | Usually, we use the PyTorch data loader.
00:25:44.820 | This time, we're not going to.
00:25:46.060 | This time, we're going to use a special data
00:25:48.180 | loader from the sentence transformers library.
00:25:51.620 | And I'll explain why in just a moment.
00:25:54.340 | So from sentence_transformers, I want to import datasets.
00:26:02.300 | We set the batch size, as we usually do.
00:26:04.700 | We're going to set that equal to 32.
00:26:08.540 | And we're creating a data loader,
00:26:12.180 | so we're just going to call that loader again, as we usually do.
00:26:19.300 | And in here, we want to write data sets,
00:26:22.060 | and we want the NoDuplicatesDataLoader.
00:26:26.300 | So you can see that there.
00:26:28.020 | And what this is going to do is, unlike a normal data
00:26:31.420 | loader in PyTorch, which would just feed you--
00:26:35.460 | it would just feed you 32, in this case, samples at once,
00:26:41.980 | it wouldn't check the data in there.
00:26:43.620 | Our NoDuplicatesDataLoader checks
00:26:47.340 | that you don't have any duplicates
00:26:50.260 | within the same batch.
00:26:52.060 | Now, realistically, with this data set,
00:26:53.700 | probably not going to get that anyway.
00:26:56.060 | But if you do have data sets which might have duplicates,
00:27:00.500 | you can use this to avoid that issue.
00:27:03.020 | Because if you think, OK, if you have a duplicate,
00:27:06.020 | and our labels are saying that pair A1 and P1
00:27:11.660 | should be the same, but in reality, over here,
00:27:15.060 | we have A7 and P7, which are exactly the same,
00:27:19.140 | it's going to confuse the model.
00:27:20.420 | And it's going to say that A1 and P7 should be matching.
00:27:25.860 | But in reality, it should just be A1 and P1.
00:27:29.820 | So that's why we use this NoDuplicatesDataLoader
00:27:32.860 | to remove any possibility of that happening.
00:27:36.900 | Now, as well, if that happens occasionally,
00:27:40.420 | it's not really an issue anyway.
00:27:42.380 | But it's nice to just be careful with it.
00:27:46.780 | So we have our samples, and we want
00:27:48.540 | the batch size, which is equal to our batch size.
00:27:54.180 | So that's our data loader.
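
As a sketch, using the `datasets` module from sentence-transformers:

```python
from sentence_transformers import datasets

batch_size = 32

# behaves like a PyTorch DataLoader, but guarantees no duplicate
# pairs appear within the same batch
loader = datasets.NoDuplicatesDataLoader(samples, batch_size=batch_size)
```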
00:27:56.980 | Now, what we need to do is initialize a model.
00:28:00.180 | So in Sentence Transformers, we can do the same thing
00:28:03.620 | as we do with Hug and Face Transformers, where
00:28:05.540 | we load a pre-trained model.
00:28:08.300 | But we can also initialize a new model
00:28:11.580 | using what are called modules.
00:28:13.820 | So in this case, we're going to use
00:28:15.980 | two modules, which is going to be the Transformer module,
00:28:20.060 | so the actual BERT itself.
00:28:21.780 | And that's going to be followed by a pooling module
00:28:24.380 | for the mean pooling that we do.
00:28:28.340 | So we're going to write, from Sentence Transformers,
00:28:33.380 | import models, and also Sentence Transformer.
00:28:40.820 | And we say, OK, BERT is equal to models,
00:28:45.140 | and it's a Transformer model.
00:28:48.180 | So we initialize that, and we're using a bert-base-uncased
00:28:54.220 | model.
00:28:55.660 | And we also have our pooling module,
00:29:01.420 | and that is models pooling.
00:29:07.100 | And in here, the pooling approach
00:29:09.060 | needs to know the dimensionality of the vectors
00:29:12.900 | that it's going to be dealing with.
00:29:14.540 | And we can get that from BERT.
00:29:15.860 | We can say, get word embedding dimension, which is the 768.
00:29:24.220 | And as well as that, we also need
00:29:26.420 | to know which type of pooling we want to use.
00:29:30.660 | And for that, we can write pooling.
00:29:33.940 | And we can see we have all these different methods in here.
00:29:37.620 | So we have the pooling mode.
00:29:38.940 | We can use the CLS token.
00:29:41.740 | We can take the maximum value.
00:29:43.140 | We can take the mean and consider the square root
00:29:48.980 | of the length of the tokens.
00:29:50.780 | I don't know this one, so I could be completely wrong.
00:29:54.500 | And we also have the mean.
00:29:55.900 | So this is a mean pooling method that I mentioned before.
00:29:59.140 | We're going to be using that one, so we say true.
00:30:02.780 | And then to actually initialize the model using
00:30:06.740 | those two parts or two modules, we
00:30:11.140 | are going to write model equals sentence transformer.
00:30:15.460 | And like I said before, this is how you would usually
00:30:18.380 | load a pre-trained model.
00:30:19.660 | So if you wanted to load a pre-trained model,
00:30:21.580 | you'd be like BERT base on case, like in here.
00:30:24.620 | We're not doing that.
00:30:25.620 | We are initializing a new model using these two modules.
00:30:29.620 | So we write modules equals.
00:30:33.260 | And we have BERT followed by the pooling function or module.
00:30:41.900 | And then we can have a look at what we have there.
00:30:44.020 | So we have the model.
00:30:45.540 | So hopefully, it doesn't take too long.
00:30:49.660 | We get our list.
00:30:50.380 | This is fine.
00:30:51.060 | It's just coming from, I think, Hugging Face.
00:30:53.820 | And then here, we have the structure of our model.
00:30:57.220 | So we can see with transformer, we're using the BERT model.
00:31:01.940 | And in here, the pooling--
00:31:04.980 | we have the embedding dimension, 768.
00:31:07.820 | And we see that the only one of these values that is true
00:31:11.300 | is the pooling mode mean tokens.
00:31:13.860 | And the rest of them are false because we're not
00:31:15.860 | using those methods.
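
A sketch of that module-based initialization (assuming the bert-base-uncased checkpoint mentioned above):

```python
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer('bert-base-uncased')
pooling = models.Pooling(
    bert.get_word_embedding_dimension(),  # 768
    pooling_mode_mean_tokens=True,        # mean pooling; the other modes stay False
)

model = SentenceTransformer(modules=[bert, pooling])
```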
00:31:19.340 | Now, all we need to do is do the--
00:31:24.100 | we need to initialize the loss function.
00:31:26.300 | So from Sentence Transformers, import losses.
00:31:32.940 | And our loss function is going to be equal to losses.
00:31:36.660 | And we have the multiple negatives ranking loss,
00:31:40.580 | which we can see there.
00:31:42.860 | And all we need to pass to that is a model.
00:31:45.740 | So it knows what model it is--
00:31:50.020 | it knows the model parameters that it's dealing with.
00:31:53.220 | So let's run that.
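
In code, that is just:

```python
from sentence_transformers import losses

# MNR loss only needs a reference to the model whose parameters it trains
loss = losses.MultipleNegativesRankingLoss(model)
```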
00:31:54.580 | And with that, we're ready to actually start training
00:31:58.700 | or fine-tuning the model.
00:32:01.220 | So we'll say, OK, we want to use one epoch.
00:32:06.900 | And the number of warm-up steps--
00:32:09.220 | so like I said before, we have that 10% of warm-up steps
00:32:13.100 | that we want to use.
00:32:16.220 | But it isn't as complicated to set that this time.
00:32:19.900 | We just want to say, OK, 10% 0.1 multiplied
00:32:23.220 | by the total number of steps.
00:32:25.500 | What total number of steps?
00:32:27.620 | Well, it would be the length of the loader, OK?
00:32:35.300 | And we write int there, OK?
00:32:39.180 | And that was our warm-up steps.
00:32:43.900 | And then we can call model.fit, so like TensorFlow.
00:32:48.640 | And then in here, we just pass a few--
00:32:50.820 | now, we pass our model configuration and setup
00:32:55.420 | and so on.
00:32:57.780 | So the first thing we want to do is set our train objectives,
00:33:03.820 | which is just a list containing train objective pairs,
00:33:08.660 | you could say.
00:33:09.860 | For us, we just have one of those.
00:33:11.340 | And that is loader followed by loss, so our objectives.
00:33:17.420 | And then we have the epochs.
00:33:19.900 | So epochs, just put one there.
00:33:22.340 | It's probably easier.
00:33:24.500 | We have number of warm-up steps.
00:33:26.940 | So we write warm-up steps is equal to warm-up.
00:33:31.220 | And then we have the output path,
00:33:32.500 | so where we're going to save our model.
00:33:36.780 | So I think I use that MNR2 or something for the final one
00:33:42.500 | that I put together.
00:33:45.940 | And then after that, final thing, if you want.
00:33:50.660 | So this will come up with a progress bar
00:33:53.780 | like we saw before.
00:33:54.620 | But the progress bar, for me, it will just
00:33:56.940 | print every single step update.
00:33:59.440 | So it's quite annoying.
00:34:00.940 | So you can set show progress bar equal to false, if you want.
00:34:06.260 | And then you just run this, OK?
00:34:08.180 | We run this, and this will fine-tune our model.
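
Altogether, the training call is roughly the following (the output path is just an example name):

```python
epochs = 1
warmup_steps = int(0.1 * len(loader))  # warm up over the first 10% of steps

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./sbert_test_mnr2',
    show_progress_bar=False,
)
```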
00:34:11.380 | Now, I've already run it, so I'm not going to run it again.
00:34:14.220 | It takes a while.
00:34:17.140 | So what I'm going to show you are just the results from that.
00:34:20.300 | OK, so this is my other notebook where I already
00:34:25.900 | ran all of what you just saw.
00:34:28.980 | We have our sentences here.
00:34:30.700 | So it's a load of random sentences, but a couple of them
00:34:32.980 | do match up, right?
00:34:34.860 | So we have this sushi one here, and there's
00:34:38.260 | another sushi one here.
00:34:39.380 | But what I've done is not use the same words
00:34:42.060 | in both sentences.
00:34:44.740 | We've got knit noodles and weaving spaghetti.
00:34:48.180 | And we also have dental specialists
00:34:50.340 | with construction materials and dentists with trim bricks.
00:34:54.700 | So similar in concept, but they don't
00:34:56.980 | share any of the same words.
00:34:58.180 | And it's quite abstract as well.
00:35:01.740 | So down here, so we have our model.
00:35:05.220 | We're just encoding those sentences.
00:35:08.020 | And then what I'm doing is I'm just
00:35:10.580 | creating a similarity matrix between all
00:35:14.220 | of these different sentences, or the encodings produced
00:35:18.700 | by our model, the sentence embeddings produced
00:35:21.220 | by our model, and all of the equivalent embeddings
00:35:26.180 | from this list of sentences.
00:35:29.100 | And we just use matplotlib and seaborn to actually plot that.
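
A sketch of that evaluation, assuming `sentences` is the list of test sentences just described:

```python
import seaborn as sns
from sentence_transformers import util

embeddings = model.encode(sentences)

# pairwise cosine similarity between every pair of sentence embeddings
sim = util.cos_sim(embeddings, embeddings).numpy()

sns.heatmap(sim, annot=True)
```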
00:35:35.060 | And we see that we get this nice visualization.
00:35:38.820 | And we see that we have 4 and 3 that align,
00:35:42.460 | and also 9 and 1, and 7 and 5.
00:35:45.460 | And those are the three pairs I mentioned before
00:35:50.140 | that are very similar.
00:35:52.060 | And the rest of these values are all very low,
00:35:56.060 | which is obviously very good because those pairs are not
00:36:01.300 | similar.
00:36:02.100 | Maybe they have some sort of similarity,
00:36:04.460 | but they're not similar.
00:36:05.700 | So we can obviously see straight away
00:36:09.020 | where the true pairs are, or the true semantic pairs are there.
00:36:15.020 | OK, so I think that's pretty much it for this video.
00:36:19.980 | We've gone through multiple negatives ranking
00:36:24.500 | loss and how to implement it.
00:36:27.460 | And like I said before, this is really
00:36:29.460 | if you're going to train a sentence transformer,
00:36:32.340 | or fine-tune a sentence transformer,
00:36:34.940 | this is the approach I would probably
00:36:38.060 | go with, depending on the data that you have.
00:36:41.580 | Now, I mean, that's everything for the video.
00:36:44.020 | So thank you very much for watching.
00:36:47.940 | I hope you've enjoyed it.
00:36:49.220 | And I will see you in the next one.