
Fine-tune Sentence Transformers the OG Way (with NLI Softmax loss)


Chapters

0:00 Intro
0:42 NLI Fine-tuning
1:44 Softmax Loss Training Overview
5:47 Preprocessing NLI Data
12:48 PyTorch Process
19:48 Using Sentence-Transformers
30:45 Results
35:49 Outro

Whisper Transcript

00:00:00.000 | Hi, welcome to the video.
00:00:02.160 | We're going to be covering how we
00:00:04.200 | can train a SBERT model, or a Sentence Transformer,
00:00:07.720 | or SentenceBERT model, using what
00:00:12.240 | is kind of like the original way of training these models
00:00:15.320 | or fine-tuning these models, which is using Softmax Loss.
00:00:19.940 | So let's start with just a quick overview
00:00:22.520 | of the training approach.
00:00:24.840 | [MUSIC PLAYING]
00:00:31.760 | Now, using the Softmax training approach
00:00:47.320 | is part of what we could call the natural language inference
00:00:52.440 | approach to fine-tuning these models.
00:00:55.280 | And within that sort of category of training,
00:00:59.560 | we have two approaches.
00:01:01.520 | We have Softmax Loss or Softmax Classification Loss,
00:01:05.360 | which we're going to cover.
00:01:06.600 | And then we also have something called a Multiple Negatives
00:01:10.600 | Ranking Loss.
00:01:12.040 | Now, in reality, you probably wouldn't use Softmax Loss,
00:01:17.840 | because it's just nowhere near as good as using
00:01:22.120 | the other form of Loss, the Multiple Negatives Ranking.
00:01:25.760 | I'm going to call it MNR from now on.
00:01:28.480 | So MNR is more effective, but Softmax Loss
00:01:32.240 | is sort of the original, and that's
00:01:33.840 | why we're covering it here.
00:01:35.560 | So we're just going to go through it.
00:01:37.160 | We're not going to go into too much depth.
00:01:42.360 | I'm going to just kind of go through it very quickly.
00:01:45.960 | So when we're training these models,
00:01:47.520 | we can either use what is called a Siamese network
00:01:51.480 | or a triplet network.
00:01:52.880 | Now, what you can see right now is a Siamese network.
00:01:55.760 | So we have almost like two copies of the same BERT,
00:01:59.640 | so they're like Siamese twins.
00:02:03.360 | And the idea is we would have two sentences, sentence A
00:02:07.480 | and sentence B, and we would feed both of those
00:02:10.680 | through our BERT model and produce the token embeddings
00:02:14.880 | that we usually get from BERT.
00:02:16.400 | And then we use a pooling operation,
00:02:18.440 | so like a mean average pooling, and then from that,
00:02:22.640 | we get our sentence embedding in U and V.
00:02:26.080 | And what we would be doing is optimizing
00:02:29.920 | to try and get those sentence embeddings as close as
00:02:34.320 | possible for similar sentences, and then
00:02:38.280 | for dissimilar sentences, we want
00:02:39.720 | them to be as far away from each other as possible.
00:02:46.560 | So that's kind of like the start of the model,
00:02:50.880 | but it's not the full model.
00:02:52.640 | We continue, and what we're going to do
00:02:57.040 | is concatenate those two together, so U plus V here.
00:03:03.480 | And then we're also going to do this other operation here.
00:03:06.800 | So we're going to take U and V, and we're
00:03:08.520 | going to get difference between them.
00:03:10.920 | So this is just a positive number here.
00:03:13.080 | We're taking the magnitude here, these bars.
00:03:17.480 | So we're just getting a positive number, which
00:03:19.880 | is a difference between the two vectors,
00:03:22.680 | and we also concatenate that.
00:03:24.080 | And we create this big vector, which is U, V,
00:03:29.560 | and then we have U minus V at the end.
00:03:31.240 | And we take that vector, so that's
00:03:38.120 | what you can see over here, and we feed it
00:03:40.680 | into a very simple feed-forward neural network.
00:03:44.880 | So for the feed-forward neural network,
00:03:47.000 | each one of these sentence embeddings
00:03:49.520 | has a dimensionality of 768.
00:03:55.040 | So obviously, the input dimension
00:03:58.000 | of our feed-forward neural network
00:03:59.880 | is 768 multiplied by 3.
00:04:05.120 | And then the output are our three output activations here.
00:04:11.240 | Now, if you watched the last video
00:04:13.520 | or you read the last article, you
00:04:15.800 | will remember that we had three labels in our training data.
00:04:22.960 | So in our NLI training data, we had entailment, neutral,
00:04:32.520 | and contradiction.
00:04:34.400 | So for those sentence pairs, we're trying to classify:
00:04:36.960 | are they relevant to each other or not?
00:04:40.440 | And then we have the true label over here,
00:04:43.920 | and then we just optimize using cross-entropy loss, which
00:04:47.280 | is what you can see over here.
00:04:49.480 | The reason why this is called softmax loss
00:04:51.880 | is because there is a softmax function
00:04:54.760 | within the cross-entropy loss function.
00:05:00.240 | So that's the process at high level.
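For reference, here is a minimal sketch of the classification head described above, written in PyTorch. This is not the exact SBERT code; the 768 dimensions and three labels come from the video, while the batch of random embeddings is purely illustrative.

```python
import torch
import torch.nn as nn

embed_dim, num_labels = 768, 3
ffnn = nn.Linear(embed_dim * 3, num_labels)   # feed-forward layer, input is 768 * 3
loss_fn = nn.CrossEntropyLoss()               # the softmax lives inside this loss

# u, v: pooled sentence embeddings for sentence A and sentence B (illustrative batch)
u = torch.randn(16, embed_dim)
v = torch.randn(16, embed_dim)
labels = torch.randint(0, num_labels, (16,))  # 0 entailment, 1 neutral, 2 contradiction

x = torch.cat([u, v, torch.abs(u - v)], dim=-1)  # the (u, v, |u - v|) concatenation
loss = loss_fn(ffnn(x), labels)                  # cross-entropy against the true labels
```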
00:05:06.640 | So let's jump in.
00:05:08.520 | We're going to first--
00:05:11.640 | we're mainly going to focus on how we form our data to put it
00:05:15.280 | into here, and then we're going to move on
00:05:17.400 | to how we actually train all of this using the sentence
00:05:23.360 | transformers library.
00:05:24.560 | That's going to be our main focus.
00:05:26.400 | But we'll just very quickly run through the code
00:05:29.680 | in PyTorch, so you can just see how it works.
00:05:34.800 | And if you really want to dive into it,
00:05:36.720 | you can obviously just take a look at the code
00:05:39.000 | and figure out how it works.
00:05:40.720 | It's not hard to read.
00:05:42.320 | So let's jump into it.
00:05:44.720 | OK, so the first thing that we want to do--
00:05:50.680 | I've added a little note up here.
00:05:52.060 | This is just some information about what
00:05:54.720 | is in our data, or the labels of our data.
00:05:58.120 | We're going to have a look at our data anyway now.
00:06:00.560 | So I'm going to use the Hugging Face datasets library.
00:06:04.680 | So we're just going to import datasets.
00:06:07.280 | And we're actually using two different data sets.
00:06:10.160 | We're using the Stanford Natural Language Inference,
00:06:12.520 | or SNLI data set, and also the multi-genre NLI data set
00:06:16.840 | as well.
00:06:19.960 | So SNLI, we just put that in.
00:06:22.920 | So it's datasets.load_dataset.
00:06:29.520 | And then it's called SNLI.
00:06:31.160 | And we want the training subset of that, so train.
00:06:38.080 | It's also-- sorry, not subset--
00:06:40.040 | split.
00:06:40.540 | And then we can have a look.
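A minimal sketch of that loading step, assuming the `snli` dataset hosted on the Hugging Face hub:

```python
import datasets

# Stanford NLI training split from the Hugging Face hub
snli = datasets.load_dataset("snli", split="train")
print(snli)     # shows the premise / hypothesis / label features
print(snli[0])  # a single premise-hypothesis pair with its label
```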
00:06:45.520 | What do we have inside the data?
00:06:49.080 | OK, so we have these three features.
00:06:52.120 | So these are columns, or you can call them columns if you want.
00:06:56.280 | So we have the premise, hypothesis, and label.
00:06:59.000 | Now, in those previous diagrams we saw,
00:07:02.040 | we saw sentence A, which is the premise,
00:07:05.320 | and sentence B, which is the hypothesis.
00:07:08.920 | And then we saw labels at the end, so it's just the same.
00:07:14.360 | Now, if you want to have a quick look at one of those,
00:07:17.480 | we just have this.
00:07:18.280 | So we just get what you can see here.
00:07:22.680 | So what label do we have here?
00:07:25.680 | We have label 1, as we can see up here.
00:07:28.120 | That's neutral, so the premise and the hypothesis
00:07:30.280 | could both be true, but they are not necessarily related.
00:07:33.720 | And then here, we see a person's jumping
00:07:35.840 | over a broken down airplane.
00:07:37.120 | The person is training his horse for the competition.
00:07:39.360 | So they could both be about the same topic,
00:07:42.600 | but one doesn't necessarily follow from the other.
00:07:44.680 | They don't infer each other.
00:07:46.720 | So let's maybe try and find one that
00:07:50.040 | is a contradiction or something else; let me run this again.
00:07:56.840 | OK, so this one is a contradiction.
00:07:59.240 | So a person on a horse jumps over a broken down airplane.
00:08:03.680 | A person's at a diner ordering an omelet.
00:08:05.440 | So those two things aren't about the same topic,
00:08:07.560 | so they're a contradiction.
00:08:09.520 | And then the other one, entailment, I think if I run it again,
00:08:14.200 | we should find one.
00:08:16.680 | We have this one.
00:08:17.360 | So a person on a horse jumps over a broken down airplane.
00:08:20.840 | A person is outdoors on a horse.
00:08:23.920 | So this here would infer--
00:08:28.440 | sorry, this here, this premise infers this hypothesis.
00:08:34.400 | So that's the data.
00:08:36.120 | And like I said, we have two of those data sets.
00:08:38.480 | We have SNLI and MNLI.
00:08:41.960 | So MNLI, we load it the same way.
00:08:45.520 | So datasets.load_dataset again.
00:08:49.080 | It's from the glue data set.
00:08:54.280 | And then the subset is MNLI.
00:08:58.520 | And again, we want to split to be equal to train.
00:09:03.080 | OK, and if we just have a look, we'll
00:09:06.120 | see a very similar format, but not exactly the same.
00:09:10.760 | So see, we have premise, hypothesis, label.
00:09:13.480 | Then we also have this index.
00:09:15.120 | We need to merge these two data sets.
00:09:17.280 | We need to reformat our MNLI data set a little bit.
00:09:21.200 | So first thing we do, we write MNLI.
00:09:23.760 | And we want to remove that column.
00:09:25.840 | So MNLI, remove columns.
00:09:30.560 | And we are doing IDX.
00:09:34.240 | OK, and let's make sure it works.
00:09:36.400 | And we see now we don't have that.
00:09:38.520 | And if we try to merge these, we still
00:09:40.240 | get an error, which is annoying, but it's fine.
00:09:43.480 | So I'm going to call it dataset,
00:09:46.080 | and we use datasets dot concatenate_datasets.
00:09:52.560 | And we just pass them both as a list.
00:09:55.440 | So SNLI, MNLI, and we're going to get this error.
00:09:59.240 | OK, so the schema, so the format of the data set is different.
00:10:04.480 | Even though they both contain the same columns,
00:10:07.720 | I think one of them has a slightly different format.
00:10:11.120 | Like one of them allows you to have nulls.
00:10:13.400 | In fact, it does say here, right?
00:10:15.840 | So they both have slightly different formats.
00:10:20.440 | So to fix that, we just want to change the schema
00:10:24.840 | of one of those data sets.
00:10:27.120 | And all we do for that is we're going to change the SNLI data
00:10:32.120 | set and say snli dot cast -- just cast, maybe --
00:10:39.240 | and pass mnli.features in here, right?
00:10:44.920 | Yeah, and then we can actually do this now.
00:10:51.360 | OK, so now we have our data set.
00:10:53.400 | We can look and see, OK, we now have
00:11:01.280 | roughly 943,000 rows there.
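A sketch of the merge just described, reusing the `snli` variable from earlier; the `glue`/`mnli` names follow the video, the variable names are mine:

```python
# MNLI comes from the GLUE collection; its schema differs slightly from SNLI's
mnli = datasets.load_dataset("glue", "mnli", split="train")
mnli = mnli.remove_columns(["idx"])      # drop the extra index column

snli = snli.cast(mnli.features)          # align the two schemas before merging
dataset = datasets.concatenate_datasets([snli, mnli])
print(dataset.num_rows)                  # roughly 943,000 rows
```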
00:11:03.280 | Now, inside this data, we actually
00:11:06.080 | have some rows that we don't want.
00:11:08.000 | So we should have the labels 0, 1, and 2,
00:11:12.320 | which we have up at the top here, 0, 1, and 2.
00:11:15.840 | But there's actually some rows that have the label minus 1.
00:11:19.760 | And all these are just erroneous rows.
00:11:22.680 | We don't actually want those in there.
00:11:24.280 | It's where someone couldn't figure out what to actually
00:11:27.480 | rate that sentence there.
00:11:29.760 | So what we're going to do is just remove those.
00:11:31.720 | So we write data set equals data set.
00:11:35.920 | And we're going to use filter.
00:11:39.000 | And then we just write lambda function.
00:11:41.240 | This lambda function is going to select rows where the label
00:11:47.960 | value is not minus 1.
00:11:49.800 | So we're going to say false if the label value, so label,
00:11:59.040 | is equal to minus 1.
00:12:00.960 | So we're going to filter those out.
00:12:02.720 | And then obviously, we want to put else true
00:12:04.960 | to keep the other rows.
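The filter itself, written out as described (same lambda, keeping only valid labels):

```python
# Keep only rows with a valid label (0, 1, or 2); -1 marks erroneous annotations
dataset = dataset.filter(lambda row: False if row["label"] == -1 else True)
```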
00:12:08.720 | So let me just print out.
00:12:12.760 | And we can see.
00:12:13.800 | So we have 942.8 thousand there.
00:12:17.560 | And here, we have 942.0 thousand.
00:12:20.240 | So that removed, I think, around 700 or so rows.
00:12:24.600 | So if we're using the sentence transformers way of training
00:12:32.080 | the models, this is pretty much all we have to do.
00:12:34.520 | There's one more step that we have to take,
00:12:38.240 | which is to convert the data into input examples or a list
00:12:44.440 | of input examples, which we'll move on to in a moment.
00:12:47.920 | We won't cover it now.
00:12:48.880 | I'm going to quickly just cover the other training
00:12:52.440 | approach using PyTorch.
00:12:55.200 | But I mean, it's quite complicated.
00:12:58.000 | And at least when I was training,
00:13:00.280 | using that approach, the model was nowhere near as good
00:13:03.680 | as when I trained it using sentence transformers.
00:13:07.680 | So I wouldn't recommend it.
00:13:09.480 | But if you're interested, this is how we do it.
00:13:13.480 | So let me switch over to that notebook.
00:13:17.200 | So if I come over here.
00:13:19.240 | And OK, we're going to see it's basically doing the same thing.
00:13:26.000 | We're loading the data set.
00:13:27.240 | And we come down.
00:13:29.320 | Now, is there a difference?
00:13:31.360 | So the difference here, so we're importing mainly
00:13:35.320 | the BERT tokenizer is what we're focusing on here.
00:13:38.000 | We come down, and then here, so we're
00:13:42.480 | filtering, which is what we did before.
00:13:44.320 | Nothing new there.
00:13:45.800 | But here, we're doing something different.
00:13:47.600 | So here, we're actually tokenizing our text.
00:13:53.120 | So we're using this map function here.
00:13:58.520 | We're tokenizing both the premise sentences and also
00:14:02.040 | the hypothesis sentences.
00:14:04.080 | And we get the input IDs and the attention mask out of those.
00:14:07.880 | And if we have a look down here, we'll
00:14:11.840 | see that this is what we end up with at the end there.
00:14:14.640 | So I've removed all the other features.
00:14:17.040 | And all we have are the labels.
00:14:19.120 | And then we also have the input IDs and attention
00:14:21.200 | mask both our premise, or sentence A,
00:14:24.120 | and also our hypothesis, or sentence B.
00:14:28.960 | And then after that, we need to do this as well.
00:14:31.200 | It's dataset.set_format.
00:14:33.760 | And we use, because we're using PyTorch,
00:14:35.680 | we set it into a Torch format, OK?
00:14:40.680 | From there, typical PyTorch stuff here.
00:14:43.680 | So we're setting up a data loader using batch size 16.
00:14:47.520 | That's what we use in the SBERT paper.
00:14:50.600 | And then if we come down, this is all just examples.
00:14:57.440 | So I'm actually going to go a little further down.
00:15:00.680 | So to here, OK?
00:15:08.120 | So here, I'm defining the-- you remember before in that graph,
00:15:12.080 | we had the--
00:15:14.800 | we just passed sentence A, sentence B.
00:15:17.160 | They both went into the BERT, or the Siamese BERT.
00:15:20.720 | And then there was this pooling method,
00:15:24.120 | which took our token embeddings, which
00:15:26.320 | are 512, 768 dimensional vectors,
00:15:31.840 | and compressed them into just a single 768 dimensional vector.
00:15:38.240 | That's what this function here does.
00:15:40.920 | When we're using Sentence Transformer,
00:15:42.720 | we don't need to worry about this.
00:15:45.000 | Sentence Transformers, the library, by the way,
00:15:47.400 | the framework, that's probably a bit confusing.
00:15:51.680 | But I mean, when I say Sentence Transformers, or using
00:15:54.920 | Sentence Transformers, I mean the framework or library,
00:15:58.120 | which we're going to cover soon.
00:16:00.400 | But obviously, they're also the models.
00:16:02.760 | That's the name of the models.
00:16:04.360 | So here, I'm taking the mean pooling.
00:16:09.040 | It's taking the average of those values
00:16:11.160 | and excluding values that are padded,
00:16:14.000 | which is why we're not just taking the average straight.
00:16:18.440 | We are masking those out using the attention mask values.
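A sketch of what a mean-pooling function like this could look like; the function and argument names are assumptions rather than the notebook's exact code:

```python
import torch

def mean_pool(token_embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeds: (batch, seq_len, 768) last hidden state from BERT
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    summed = torch.sum(token_embeds * mask, dim=1)    # sum over non-padded tokens only
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)   # number of real tokens per sentence
    return summed / counts                            # (batch, 768) sentence embeddings
```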
00:16:24.080 | We go down.
00:16:24.720 | We move our device.
00:16:27.040 | So we check if we have a CUDA-enabled GPU
00:16:29.880 | and move our model to it, if we can.
00:16:32.680 | And then here, these are the layers we use.
00:16:34.800 | So I told you before, we had that--
00:16:38.480 | well, we concatenate our U and V vectors,
00:16:41.200 | the sentence embeddings.
00:16:42.880 | And then we pass them to a Feedforward Neural Network.
00:16:45.760 | And that Feedforward Neural Network
00:16:47.160 | is the size of our sentence embeddings multiplied by 3.
00:16:53.080 | And it outputs three labels or classes.
00:16:58.040 | And then we also use a cross-entropy loss function
00:17:02.760 | between what the Feedforward Neural Network outputs
00:17:06.760 | and our actual labels.
00:17:10.160 | So after that, this is what I mean.
00:17:14.560 | There's quite a bit of code when you
00:17:16.880 | go to the very manual PyTorch approach,
00:17:20.280 | rather than using the Sentence Transformers library.
00:17:22.960 | So here, we're getting this get_linear_schedule_with_warmup.
00:17:30.360 | So that's just saying, for the first 10% of our steps,
00:17:34.280 | we're going to warm up the learning rate.
00:17:37.000 | So it's not going to go full-on training at 1e to the minus 5 straight away.
00:17:42.480 | So we're going to warm it up a little bit.
00:17:45.040 | Now, in the SBERT paper, they used 2e to the minus 5.
00:17:50.000 | For me, it just bounced around a lot.
00:17:51.720 | So I halved it.
00:17:54.960 | But if you can get it working with 2e to the minus 5,
00:18:00.360 | that's what they use in the paper.
00:18:01.780 | So it's probably better.
00:18:05.400 | And then they only train for one epoch as well.
00:18:09.360 | And then also here, I'm using the Adam optimizer with weighted decay.
00:18:13.800 | And then this is the training loop.
00:18:15.620 | So TQDM is just a progress bar.
00:18:19.440 | We do one epoch.
00:18:20.840 | We make sure our model is on training mode.
00:18:24.320 | We initialize a loop, which is going
00:18:26.120 | to get all the batches from our data loader.
00:18:31.600 | And then we're just getting all the data out.
00:18:33.840 | This is just PyTorch stuff.
00:18:36.400 | Getting our U and V, sentence embeddings.
00:18:41.640 | Then here, we're getting the--
00:18:43.680 | actually, sorry, so U and V here are actually token embeddings.
00:18:47.600 | Here, we're converting them into sentence embeddings.
00:18:51.360 | And then we're getting the absolute value
00:18:54.320 | of U minus V, the difference vector.
00:18:57.800 | Here, concatenate it all together.
00:18:59.640 | So we're creating that concatenated vector
00:19:04.160 | that we then feed into the feedforward neural network.
00:19:07.160 | And then we optimize based on the loss here.
00:19:12.520 | And that's pretty much it.
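Putting those pieces together, one training step of this manual loop might look roughly like the following. The batch keys and variable names are hypothetical, and it reuses the `mean_pool` sketch from above; the real notebook also adds the learning-rate scheduler.

```python
import torch
import torch.nn as nn
from transformers import BertModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert = BertModel.from_pretrained("bert-base-uncased").to(device)
ffnn = nn.Linear(768 * 3, 3).to(device)       # concatenated embedding -> 3 NLI labels
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(list(bert.parameters()) + list(ffnn.parameters()), lr=1e-5)

def train_step(batch):
    # token embeddings for sentence A (premise) and sentence B (hypothesis)
    a_mask = batch["premise_attention_mask"].to(device)
    b_mask = batch["hypothesis_attention_mask"].to(device)
    u_tokens = bert(batch["premise_input_ids"].to(device),
                    attention_mask=a_mask).last_hidden_state
    v_tokens = bert(batch["hypothesis_input_ids"].to(device),
                    attention_mask=b_mask).last_hidden_state
    u = mean_pool(u_tokens, a_mask)                   # sentence embedding u
    v = mean_pool(v_tokens, b_mask)                   # sentence embedding v
    x = torch.cat([u, v, torch.abs(u - v)], dim=-1)   # (u, v, |u - v|)
    loss = loss_fn(ffnn(x), batch["label"].to(device))
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item()
```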
00:19:15.160 | And then we're saving the model down here.
00:19:18.600 | So yeah, that's how we train it in PyTorch.
00:19:27.800 | You can see here I was messing around,
00:19:30.760 | seeing if I could see what happened if I did two epochs.
00:19:34.200 | It's better to just stick with one.
00:19:37.680 | Even though the loss was lower, in the end,
00:19:41.760 | the performance wasn't any different.
00:19:44.320 | So I would train for one epoch.
00:19:49.480 | OK, so let's go back to the code and we'll
00:19:51.680 | work through the actual training with sentence transformers,
00:19:54.600 | which is what I would recommend doing.
00:19:57.040 | OK, so I said before we had the list of input examples.
00:20:02.480 | So input example is just a data format
00:20:05.560 | that sentence transformers library uses.
00:20:08.160 | So we just want to write from sentence transformers,
00:20:12.520 | I'm going to import input example.
00:20:14.720 | And then all I'm going to do here is write from tqdm or--
00:20:22.000 | from tqdm.auto, I want to import tqdm.
00:20:26.800 | So this is just for our progress bar,
00:20:28.360 | so we can see what is happening.
00:20:31.920 | And then in here, we just want to--
00:20:34.960 | actually, we want to create our training examples first,
00:20:37.840 | or training samples, whichever you want to--
00:20:40.720 | whatever you want to call it.
00:20:43.000 | It's going to be empty list.
00:20:44.600 | And then we literally just for loop
00:20:46.280 | through all of our training data, through our data set,
00:20:49.160 | and extracting what we need from it,
00:20:51.560 | which is just sentence A, sentence B, and the label.
00:20:55.200 | So write for row in--
00:21:00.520 | I'm going to put tqdm training samples.
00:21:03.200 | So just adding tqdm in there so we have a progress bar,
00:21:06.280 | so we can see where we are.
00:21:09.880 | All we need to do is write train samples, append input example.
00:21:17.320 | And before you get confused, this
00:21:19.720 | should be data set, not train samples.
00:21:23.600 | So data set is where we're looping
00:21:27.360 | through our data set, not the empty list.
00:21:32.200 | And then inside our input example,
00:21:35.320 | we have two variables, text and labels.
00:21:39.280 | So you have to pass your text, which is the input text
00:21:43.800 | that you're going to process into your model.
00:21:46.080 | So we go row, premise, and also row hypothesis.
00:21:54.720 | So they're just our two text features from our data set.
00:22:01.000 | And then here, we also want a label.
00:22:03.200 | So label is just row label.
00:22:07.160 | It's just the feature names from our data set,
00:22:09.840 | which you can still see up here.
00:22:12.200 | Now, we process that.
00:22:14.640 | It can take a little bit of time.
00:22:16.400 | So I won't take too long, unfortunately.
00:22:18.960 | And then from there, we need to-- you remember before,
00:22:22.520 | or when we very quickly went through,
00:22:24.840 | we had the PyTorch data loader.
00:22:27.040 | We also need a data loader here as well.
00:22:30.160 | Sometimes you can use special data loaders
00:22:32.600 | from the Sentence Transformers library, which are quite good.
00:22:35.560 | But for this, we're just using a normal PyTorch data loader.
00:22:38.520 | So we need to import torch for that.
00:22:41.760 | So we can just write from torch utils data, import data loader.
00:22:48.920 | Same as the paper, we're using batch size of 16.
00:22:57.880 | And the data loader or loader is just data loader.
00:23:06.000 | We pass in those train samples, specify our batch size.
00:23:13.200 | And if you'd like to shuffle, which in this case we will,
00:23:17.120 | you also put shuffle.
00:23:18.720 | So shuffle equals true.
00:23:22.280 | And that should work.
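A sketch of the InputExample loop and DataLoader just described, assuming the `dataset` variable from the preprocessing step:

```python
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Wrap each row in the InputExample format the library expects
train_samples = []
for row in tqdm(dataset):
    train_samples.append(
        InputExample(texts=[row["premise"], row["hypothesis"]], label=row["label"])
    )

# Plain PyTorch DataLoader, batch size 16 as in the SBERT paper
batch_size = 16
loader = DataLoader(train_samples, batch_size=batch_size, shuffle=True)
```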
00:23:24.760 | So now we have our data loader.
00:23:26.400 | And what we do now is initialize our model
00:23:31.320 | using Sentence Transformers.
00:23:32.800 | Now, Sentence Transformers uses modules
00:23:36.400 | to set up the model.
00:23:39.520 | So we're going to have a transformer module, which
00:23:42.400 | I'll just leave out a bit.
00:23:44.080 | And then we're also going to have a pooling module, which
00:23:48.440 | is for our mean pooling layer.
00:23:51.880 | So from Sentence Transformers, again,
00:23:55.160 | we're going to import models.
00:23:58.040 | And what is this one?
00:24:00.880 | Sentence Transformer, yeah.
00:24:05.280 | We initialize those two modules.
00:24:08.600 | So we have a BERT module.
00:24:12.400 | So models, transformer.
00:24:15.760 | And then here, it's using the Hugging Face models.
00:24:20.400 | So we can put anything from Hugging Face on here.
00:24:23.920 | I'm going to use bert-base-uncased.
00:24:26.960 | And then we also want our pooler.
00:24:29.320 | So our pooler is from models again:
00:24:33.320 | models dot Pooling.
00:24:36.160 | And then we have BERT.
00:24:37.640 | And we want to get the word embedding dimension.
00:24:42.440 | So get word embedding dimension, which
00:24:46.600 | is the 768 of our token embeddings.
00:24:50.920 | And then, of course, of our sentence embedding as well.
00:24:55.560 | And then we also want to set the type of pooling
00:25:00.280 | that we are going to do.
00:25:01.360 | So pooling mode, you can see that we have
00:25:05.040 | these different ways of pooling.
00:25:08.000 | So we have CLS token, maximum.
00:25:10.440 | This one, I've never actually seen used.
00:25:13.920 | Square root of the length, it's interesting.
00:25:17.440 | And then we also have this one.
00:25:19.000 | This is the mean pooling, and we're going to use that.
00:25:21.560 | So that is-- there are two modules.
00:25:34.560 | And then we just want to initialize our model.
00:25:36.760 | So we write sentence transformer.
00:25:40.880 | And what you can do, by the way, is
00:25:43.600 | this is how you would actually--
00:25:45.800 | say you have a sentence transform model
00:25:47.560 | that you want to load.
00:25:48.560 | You'd write the sentence transformer name in here.
00:25:51.640 | So like all MP net, whatever it's called.
00:25:56.600 | You'd do that as well.
00:25:58.800 | But you can also load or initialize
00:26:01.680 | the model using the modules that we just initialized.
00:26:05.680 | So write BERT followed by the pooler.
00:26:09.120 | And then keep details of that model in there.
00:26:14.360 | OK, so ignore that.
00:26:17.120 | So this bit here.
00:26:18.920 | So this is our sentence transformer structure.
00:26:25.560 | You can think of it as a structure.
00:26:27.680 | So our transformer, we're using BERT model,
00:26:29.920 | the maximum sequence length.
00:26:32.640 | And then in here, we have a pooler.
00:26:34.120 | And we have the word embedding dimension that we'll expect,
00:26:38.080 | And then you see here, we have those different pooling modes.
00:26:40.580 | And we are using pooling mode mean tokens, which is true.
00:26:45.280 | The rest of them are false.
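A sketch of those two modules plus the SentenceTransformer wrapper; `bert-base-uncased` is the checkpoint mentioned in the video:

```python
from sentence_transformers import models, SentenceTransformer

# Transformer wraps a Hugging Face checkpoint; Pooling turns its token
# embeddings into a single sentence embedding using mean pooling
bert = models.Transformer("bert-base-uncased")
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),  # 768 for bert-base
    pooling_mode_mean_tokens=True,
)

model = SentenceTransformer(modules=[bert, pooler])
print(model)  # shows the Transformer + Pooling structure described above
```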
00:26:47.800 | And then from there, so we also need
00:26:50.320 | to initialize our loss function, which is pretty straightforward
00:26:55.720 | as well.
00:26:56.600 | So again, from sentence transformers,
00:26:59.800 | I want to import losses.
00:27:02.320 | And there are plenty of different losses
00:27:05.620 | that you can use.
00:27:06.520 | You can just look on that documentation.
00:27:09.200 | But we're using softmax loss.
00:27:11.040 | So what we want to do is write loss equals losses.
00:27:15.060 | And we write softmax loss.
00:27:20.360 | And then in here, so you think, OK, our loss function,
00:27:23.360 | what does it need?
00:27:24.560 | So we pass in the model.
00:27:27.120 | So it can get the model parameters from that.
00:27:31.760 | So firstly, model equals model.
00:27:35.720 | And then it also needs the embedding dimension.
00:27:38.920 | So it's 768 again.
00:27:41.360 | So it needs the sentence embedding dimension.
00:27:47.240 | And for that, we just want--
00:27:49.400 | what is it?
00:27:53.120 | So it's the model's get embedding dimension again,
00:27:56.600 | model dot get_sentence_embedding_dimension.
00:28:00.800 | And then we also need to pass--
00:28:02.800 | OK, how many labels are we going to have in our model?
00:28:06.200 | We already know it's the number of labels.
00:28:08.800 | We already know it's three.
00:28:10.960 | So I'm sure you can get that dynamically from the data set
00:28:14.240 | if you want as well.
00:28:15.960 | But I'm just going to put three.
00:28:18.320 | And I think that's it.
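The loss setup described here, as a short sketch:

```python
from sentence_transformers import losses

# Softmax (classification) loss over the three NLI labels
loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),  # 768
    num_labels=3,
)
```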
00:28:20.720 | So we have our loss, our model, and our data.
00:28:25.200 | So I think we should be OK to start training.
00:28:30.600 | So I'm going to say we go for, OK, one epoch.
00:28:34.640 | We want to say how many warm-up steps do we want.
00:28:37.400 | So again, it's the 10% warming up that we use.
00:28:43.120 | So we just want 0.1 multiplied by the length of our data set.
00:28:50.240 | So length of the data set.
00:28:54.680 | Yeah, and I'm just going to--
00:28:56.000 | we need that to be an integer value.
00:28:59.840 | So I'm just-- it's quite rough, rounding very roughly there.
00:29:04.920 | But that's fine.
00:29:06.960 | And then we want to just start training our model.
00:29:11.320 | So we write model.fit, it's like TensorFlow.
00:29:14.680 | And we use our train objectives.
00:29:17.360 | So in here, we need to pass a list which
00:29:20.920 | contains a single tuple, which is our loader.
00:29:25.920 | So the data loader and the loss.
00:29:28.560 | So I think with this, you can, if you have multiple train
00:29:31.920 | objectives, you can put another loader, another loss,
00:29:35.040 | and keep going through that.
00:29:36.240 | So that's why we have a tuple within a list.
00:29:41.760 | Then you have your epochs, the number
00:29:46.120 | of warm-up steps, which is just warm-up steps again.
00:29:51.080 | So warm-up steps.
00:29:54.200 | We also need the--
00:29:56.480 | what do we need?
00:29:57.320 | Output path.
00:29:58.720 | So where are we going to save the model?
00:30:02.360 | So I'd just put something like sbert.
00:30:06.120 | Now, what did I call it?
00:30:07.120 | I think testB is what I've called it later on.
00:30:12.640 | And oh, last one is show progress bar.
00:30:18.240 | Now, this is automatically true.
00:30:21.800 | But when I used that, it just printed out
00:30:24.440 | loads of lines that it's printing to a new line
00:30:27.600 | every single update.
00:30:29.880 | So I'll just set to false.
00:30:32.120 | So I wouldn't do that because that's obviously
00:30:34.080 | quite annoying.
00:30:36.000 | So yeah, that's how you train the model.
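The fit call, sketched out; the output path is a placeholder, and the warm-up here is taken as roughly 10% of the training steps:

```python
epochs = 1
warmup_steps = int(0.1 * len(loader))   # warm up over ~10% of the training steps

model.fit(
    train_objectives=[(loader, loss)],  # list of (DataLoader, loss) tuples
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path="sbert_test_b",         # placeholder save path
    show_progress_bar=False,            # avoids printing a new line per update
)
```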
00:30:38.960 | I'm not going to do it again.
00:30:40.160 | I've already done it.
00:30:41.120 | I've already trained this sbert testB model.
00:30:46.600 | So what I'm going to do is switch over
00:30:48.600 | to that notebook where I trained it
00:30:51.040 | and show you those results.
00:30:53.520 | So this is the notebook I pretty much just covered.
00:30:57.520 | I'm not going through all of that again.
00:31:00.080 | And then here, so we have the training, the training time
00:31:05.800 | as well, something I didn't mention just now,
00:31:09.280 | is one hour and 15 minutes for me on an RTX 3090.
00:31:13.680 | So reasonably fast.
00:31:16.400 | It depends on what you're training on.
00:31:20.400 | Yeah, so yeah, it's quick.
00:31:22.840 | So I define these sentences just below random sentences,
00:31:28.760 | like complete nonsense.
00:31:30.600 | But some of them do align.
00:31:33.040 | So see this one, one thinks she saw her raw fish and rice
00:31:38.520 | change position.
00:31:41.120 | And this one, seeing her sushi move, weaving with spaghetti,
00:31:47.240 | and where is the other, knit with noodles,
00:31:50.080 | and dental specialist with construction materials,
00:31:54.280 | and same again, dentist with chewing bricks.
00:31:58.600 | So there's some that are kind of similar,
00:32:00.320 | but they don't share any of the same descriptive words.
00:32:04.720 | But they kind of mean the same thing, roundabouts.
00:32:14.280 | So with our model, so we have loaded the model here,
00:32:23.120 | which we can just do.
00:32:24.640 | So if you've saved the model, which it does automatically
00:32:28.480 | here, you just take this, you take that, you come down here,
00:32:34.440 | and you write sentence transformer.
00:32:37.800 | And then in here, you do that.
00:32:41.040 | And then you would put that in the model variable.
00:32:46.160 | So that's all I've done there.
00:32:49.280 | So remove that.
00:32:50.840 | So loading the model, I'm going to use
00:32:54.040 | it to encode those sentences, which is just in the list,
00:32:56.560 | and create our embedding.
00:32:57.640 | So the sentence embeddings.
00:32:59.720 | And then from there, I'm getting the cosine similarity,
00:33:03.240 | again, using sentence transformers for that.
00:33:05.120 | This makes it a bit easier.
00:33:06.920 | And I'm just comparing all those sentences.
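A sketch of that evaluation step, with placeholder sentences and a placeholder model path; `util.cos_sim` (in recent versions of the library) is one way to get the pairwise cosine similarities:

```python
from sentence_transformers import SentenceTransformer, util

# Load the fine-tuned model back from its save path
model = SentenceTransformer("sbert_test_b")

sentences = [
    "She saw her sushi move.",                                   # illustrative pair
    "She thinks she saw her raw fish and rice change position.",
    "A dentist chewing on bricks.",
]
embeddings = model.encode(sentences)           # one vector per sentence

scores = util.cos_sim(embeddings, embeddings)  # pairwise cosine similarity matrix
print(scores)
```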
00:33:09.880 | So it's just a very qualitative view
00:33:13.840 | on how these embeddings are doing.
00:33:17.600 | And you see that we get these results here.
00:33:22.120 | So it's pretty good, actually.
00:33:28.440 | It's getting the right ones.
00:33:29.600 | So this 7 and 5, 9 and 1, and I think 4 and 3
00:33:37.240 | are the ones that we wanted to get.
00:33:39.680 | And they are, in fact, the highest three-rated scores.
00:33:44.040 | But a lot of these other non-pairs
00:33:46.560 | are still rated kind of high.
00:33:48.640 | And like I said before, softmax loss
00:33:51.800 | is not the best way of training your model anymore,
00:33:54.960 | or fine-tuning your model anymore.
00:33:56.920 | There's other ways, like MNR loss.
00:34:00.520 | So let me show you some of the charts from MNR loss,
00:34:05.960 | and you'll see the difference.
00:34:08.600 | So we have-- this one is just for BERT.
00:34:11.640 | You see they're all very flat.
00:34:13.160 | It does actually get--
00:34:14.760 | almost, it gets two of the correct answers
00:34:17.440 | within its top three, but not all of them.
00:34:20.400 | But it's very flat.
00:34:21.560 | Like, all the values are very near the same value, which
00:34:24.320 | makes it hard to differentiate between similar and not
00:34:27.800 | similar, which is not what we want.
00:34:29.800 | But obviously, BERT hasn't been trained for this,
00:34:31.840 | so you can't expect it.
00:34:34.440 | And this is my PyTorch model.
00:34:37.680 | It's getting better than the BERT model,
00:34:39.720 | but the performance still isn't there
00:34:41.160 | compared to the Sentence Transformers trained model,
00:34:46.040 | which is not here.
00:34:48.840 | This is the actual-- the one that they trained,
00:34:52.400 | Sentence Transformers themselves.
00:34:54.640 | You can see there's better differentiation
00:34:57.360 | between general values here.
00:35:00.640 | But then if we compare that to--
00:35:02.360 | so this is an MNR model that I have trained using
00:35:06.720 | the same Sentence Transformers library.
00:35:08.480 | We use Sentence Transformers here.
00:35:10.080 | We come down, we see a big difference.
00:35:12.200 | So these are our similar pairs.
00:35:14.560 | And they stick out so much more than they
00:35:18.000 | did with the other models.
00:35:19.040 | Everything else is rated--
00:35:20.360 | is scored very lowly.
00:35:22.840 | But these stick out a lot.
00:35:25.760 | And that's really the difference between models.
00:35:31.960 | This is much better because it separates
00:35:34.520 | those similar and dissimilar pairs very well.
00:35:38.120 | It's just a lot more accurate.
00:35:40.720 | So that is-- I mean, that's my MNR model as well.
00:35:46.160 | The actual Sentence Transformers MNR model
00:35:48.040 | is much better than this.
00:35:50.400 | But yeah, I think that's pretty much it.
00:35:52.640 | That's it for this video.
00:35:54.680 | So in the next one, we are going to--
00:35:57.960 | as you probably have guessed, we're
00:36:00.560 | going to have a look at how we can use MNR loss
00:36:03.040 | or multiple negative ranking loss to build a model, which
00:36:07.800 | I think, personally, is a lot more interesting.
00:36:11.400 | Sentence-- sorry, Softmax loss is pretty interesting.
00:36:15.880 | But it's not particularly intuitive.
00:36:18.880 | And even the Sentence Transformers authors,
00:36:23.240 | the SBERT authors, said the same thing.
00:36:25.720 | That's actually where I got it from.
00:36:27.400 | It isn't very intuitive when you think about it.
00:36:29.640 | It's kind of hard to understand why it works.
00:36:33.680 | Because we have that weird concatenation at the end.
00:36:35.880 | It's classifying, and it seems strange.
00:36:38.360 | MNR loss is much more intuitive, and it makes a lot more sense.
00:36:41.560 | And I think it's more interesting,
00:36:43.520 | and you get way better results with it.
00:36:45.200 | So we're going to cover that in the next video.
00:36:48.920 | So I think that should be pretty interesting.
00:36:51.040 | But for now, that's it on Softmax loss.
00:36:55.280 | Thank you very much for watching.
00:36:57.120 | I hope it's been useful.
00:36:58.880 | And I will see you in the next one.