AugSBERT: Domain Transfer for Sentence Transformers
Chapters
0:00 Why Use Domain Transfer
4:08 Strategy Outline
6:05 Train Source Cross-Encoder
12:44 Cross-Encoder Outcome
15:12 Labeling Target Data
20:31 Training Bi-encoder
23:58 Evaluate Bi-encoder Performance
28:08 Final Points
00:00:00.000 |
Today we're going to talk about how we can train language models across different domains 00:00:06.880 |
Now, to do this, we're going to be using what is called the augmented SBERT training strategy, 00:00:13.200 |
And we're going to be using the domain transfer flavor of that 00:00:25.280 |
This is for when we have some data, but not enough; or, in our case with domain transfer, we have labeled data in 00:00:31.360 |
one domain, so say maybe we have Quora question pairs, 00:00:35.600 |
but we need Stack Overflow question pairs, and all we have on the Stack Overflow side is unlabeled data. 00:00:47.360 |
What we're essentially doing is using that source dataset, for example the Quora question pairs, to 00:00:56.080 |
transfer its knowledge and label the Stack Overflow question pairs. 00:01:02.640 |
Now, we don't need to stick with question pairs to question pairs; 00:01:05.600 |
we can go from question pairs to semantic similarity pairs, or anything else. 00:01:12.400 |
We do have to make sure that our source domain and target domain are at least similar. 00:01:24.160 |
You can think of it like a bridge across a gap between the source domain and target domain: the 00:01:29.440 |
less overlap those two domains have, the greater that gap. 00:01:38.080 |
If this gap is bigger, we obviously need a bigger, stronger bridge, 00:01:43.200 |
and we can think of that bridge as our model: we need a bigger, stronger model to bridge these 00:01:50.320 |
longer or greater gaps between the two domains. 00:01:54.640 |
So in this video we're going to be going through one example, but I have tested these five different datasets. 00:02:08.100 |
So this is how many n-grams overlap between the two datasets as a whole. 00:02:15.840 |
This is not a perfect measure of similarity, of course; 00:02:22.960 |
it's really just syntactical overlap, and it's not capturing whether the datasets 00:02:31.280 |
have the same language structure or are talking about the same topics 00:02:37.520 |
much, although it does help. This is just a very simple measure to help us compare 00:02:45.280 |
the two domains that we're trying to transfer across. 00:02:59.600 |
Because of these lower similarity scores, it would probably 00:03:03.860 |
not be the best to domain transfer across these different domains. 00:03:15.100 |
MRPC and Quora question pairs are all reasonably similar, 00:03:19.500 |
so we'd expect the transfer, or the ease of transfer, between those domains to be better. 00:03:28.300 |
But like I said, this isn't perfect, and in fact one of the 00:03:33.820 |
best domain transfer performances that I saw was from the medical question pairs. 00:03:46.680 |
That makes sense semantically, because question pairs will probably have more of a similar structure to each other 00:03:53.900 |
than to what we have up here with RTE, or just semantic textual similarity pairs as we have up here, 00:04:04.380 |
but it's not really reflected in this chart here. 00:04:12.120 |
So this is the strategy that we're going to follow with domain transfer augmented SBERT. 00:04:21.960 |
First, we get our data. So we're going to have a labeled source dataset, which in 00:04:25.640 |
the earlier example would have been the Quora question pairs, and we're also going to have the unlabeled 00:04:31.020 |
target dataset, which would be our Stack Overflow question pairs. We're not actually going to use those two examples; we're going to use MRPC 00:04:42.740 |
as our source dataset and STSB as our target dataset. So we'll go ahead and we'll get our source data, 00:04:54.420 |
and we'll train a cross-encoder model with it. Now, one thing I'll just 00:04:59.780 |
quickly mention is the reason we're training a cross-encoder model rather than a bi-encoder model, 00:05:05.620 |
which is a sentence transformer: it's because we need less data. 00:05:17.940 |
So we need less data to train a good cross-encoder, but cross-encoders are also slow, 00:05:24.020 |
which is why we don't use a cross-encoder over here as our final model. 00:05:29.780 |
We're basically using that greater performance with less data to bring our sentence transformer 00:05:41.080 |
over here to a similar level, or similar performance, to that cross-encoder. 00:05:46.500 |
That's basically what we're doing: it's knowledge distillation from our cross-encoder 00:05:52.660 |
all the way over to our sentence transformer. That's all we're doing. 00:05:57.460 |
So let's jump over to the code and we'll have a look at how we perform those first two steps. 00:06:04.900 |
Okay, so we're starting with the source domain, MRPC, as I said, and the target domain is STSB. 00:06:15.060 |
So we use Hugging Face datasets for this, so you may need to pip install that; so just pip install datasets. 00:06:28.580 |
So this is just going to load the MRPC dataset from Hugging Face; it's part of the 00:06:35.780 |
larger GLUE dataset, and MRPC is like a subset of that. 00:06:40.740 |
And then we're taking the train split from there, and we can see the dataset features here. 00:06:46.500 |
So we have sentence one and two, which are the sentence pairs that we're going to be training with; it's not a huge dataset. 00:07:09.220 |
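As a rough sketch of this load step (assuming the Hugging Face `datasets` library; variable names are mine, not necessarily those in the video):

```python
from datasets import load_dataset

# MRPC is a subset (a "config") of the GLUE dataset on Hugging Face
source = load_dataset("glue", "mrpc", split="train")

print(source)
# Dataset({
#     features: ['sentence1', 'sentence2', 'label', 'idx'],
#     num_rows: 3668
# })
```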
That's why we're going with the cross-encoder here, so that we can build something that has good performance, 00:07:14.980 |
whereas a sentence transformer might need a little more data. 00:07:22.980 |
So we need to first format that data for training, and 00:07:29.380 |
to do that we just go through each row in the dataset that we just downloaded, and 00:07:35.300 |
sentence transformers always uses this InputExample object. 00:07:40.500 |
So all we're doing here is creating an empty list, and then 00:07:45.620 |
we're populating it with these InputExample objects, which contain the sentence pairs and their labels. 00:07:52.980 |
And then, behind the scenes, we're using PyTorch, so 00:07:58.340 |
we also need to initialize a PyTorch DataLoader with those input examples, 00:08:04.260 |
and we're using batch size 16 and also shuffling that data as well. 00:08:10.340 |
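A minimal sketch of that formatting step, reusing the `source` dataset from above:

```python
from torch.utils.data import DataLoader
from sentence_transformers import InputExample

# build the list of InputExample objects that sentence-transformers expects
source_examples = []
for row in source:
    source_examples.append(InputExample(
        texts=[row["sentence1"], row["sentence2"]],
        label=float(row["label"])  # MRPC labels are 0/1 integers; cast to float
    ))

# wrap everything in a shuffled PyTorch DataLoader with batch size 16
source_loader = DataLoader(source_examples, shuffle=True, batch_size=16)
```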
With that, our data's ready for training the cross-encoder, and then we move on to the 00:08:21.120 |
validation set as well, which is useful because then we get a performance measure 00:08:26.660 |
from our source validation set. If possible, you really probably want to 00:08:32.260 |
switch this out for the target dataset's validation set, 00:08:36.660 |
so in our case that would be STSB. I'm not going to do that here; in 00:08:48.180 |
an actual project you may not have this validation data 00:08:52.900 |
from your target domain, but I would definitely recommend, if you do have the time, to just manually label something so that you can 00:08:59.540 |
check your performance; you will want to do that later anyway. 00:09:06.660 |
For evaluation we're using this CECorrelationEvaluator, because we're training the cross-encoder; it's coming from the cross-encoder 00:09:13.240 |
section of the sentence transformers library, and it's a correlation evaluator: 00:09:22.740 |
it measures the correlation between the true labels that we have here, so the true labels in our validation set, and the predicted labels from the cross-encoder. 00:09:32.660 |
Remember, a cross-encoder doesn't output sentence vectors like a sentence transformer does; it outputs a similarity score. 00:09:41.720 |
We're initializing that correlation evaluator from input examples, so because we're using input examples up here, we use the from_input_examples method there. 00:09:53.220 |
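A sketch of that evaluator setup, using the MRPC validation split as described (whether the video structures it exactly this way is an assumption):

```python
from datasets import load_dataset
from sentence_transformers import InputExample
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

# source-domain validation data; swap in target-domain pairs if you have them
val = load_dataset("glue", "mrpc", split="validation")

val_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]],
                 label=float(row["label"]))
    for row in val
]

# correlates the cross-encoder's predicted scores with the true labels
evaluator = CECorrelationEvaluator.from_input_examples(val_examples)
```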
Our training and validation data are ready to go, so we can go ahead and start training the cross-encoder. 00:10:03.380 |
We load this CrossEncoder object from sentence transformers, again just using a bert base model. Okay, so we 00:10:14.180 |
initialize our cross-encoder from bert-base-uncased, coming from the Hugging Face models. 00:10:23.060 |
Now, you can change this if you are training a cross-encoder that you'd like to predict NLI 00:10:29.780 |
labels, for example, so you'd have zero, one, two, which would be your contradiction, neutral, and entailment labels; there 00:10:38.020 |
you would put something like three, because you have three labels and you want to output three labels. 00:10:43.860 |
But for us, we want to output a similarity score, so we just have one label. 00:10:48.420 |
Then here we have the number of epochs; now, with anything more than one you're probably going to overfit, 00:10:57.140 |
unless you really have reason to do otherwise. 00:11:02.020 |
And here we have the percentage of warm-up steps that we're going to use, so this is the 00:11:08.980 |
number of steps where we're not going to be training at the full learning rate, but gradually increasing up to it. 00:11:19.060 |
It's quite high at 35%, but it's what I tend to find works best with cross-encoders. 00:11:38.580 |
The learning rate is another training parameter that I modified a little bit; here we're using slightly higher than the default of 2e-5. 00:11:47.540 |
Now, everything else is as you usually would with the cross-encoder. So we're using the fit method; 00:11:53.140 |
we will do that with a sentence transformer as well. 00:11:56.340 |
The train dataloader here is slightly different to the sentence transformer's train_objectives 00:12:02.580 |
parameter, but otherwise everything is exactly the same. 00:12:06.260 |
So we're passing in our source data, which is the training data loader that we created earlier; 00:12:11.860 |
we have our evaluator, number of epochs, warm-up steps, 00:12:14.900 |
learning rate, and where we are going to save the model; that source model is going to be the MRPC cross-encoder. 00:12:25.220 |
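Putting the training step together, it looks something like this; the exact learning rate and save path aren't stated clearly in the audio, so treat those two values as placeholders:

```python
from sentence_transformers.cross_encoder import CrossEncoder

# num_labels=1 -> a single similarity score (use 3 for NLI-style labels)
cross_encoder = CrossEncoder("bert-base-uncased", num_labels=1)

num_epochs = 1  # more than one epoch tends to overfit
warmup_steps = int(len(source_loader) * num_epochs * 0.35)  # 35% warm-up

cross_encoder.fit(
    train_dataloader=source_loader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={"lr": 3e-5},  # placeholder: "slightly higher" than the 2e-5 default
    output_path="mrpc-cross-encoder"  # placeholder save path
)
```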
That trains our model. It doesn't take long; you can see here how long it took on this computer. 00:12:34.660 |
And then we can go ahead and label the target data. 00:12:40.100 |
That's on to the next step. So let me actually go back to here; we're going to cross these steps off: we got our 00:12:48.180 |
labeled source data and we trained our cross-encoder. Okay, so this is the performance of the 00:12:56.180 |
cross-encoder models, so that is the source models over here, on different target datasets over here. 00:13:11.460 |
The STSB target data is over here, so the validation set of the STSB target data is what we are aiming for. That gives 00:13:18.420 |
0.63; now, we should be able to get a better performance than this, I think. 00:13:22.900 |
But we'll have a look, and we can also see the other scores. 00:13:33.220 |
And we can see that here, because all of these 00:13:35.940 |
do achieve better performance. But what I really want to focus on here is the 00:13:42.660 |
bert-base-uncased performance as a benchmark, 00:13:45.800 |
because this is what you would get if you didn't have any of this training; remember, 00:13:52.980 |
we don't have any target domain training data. 00:14:12.500 |
We can see from here, so we have these red circles, which represent that the 00:14:18.980 |
model performance has decreased from that benchmark, and then the blue ones over here, 00:14:26.740 |
they represent that the model performance has increased from the benchmark. And then we also have gray, which is 00:14:32.740 |
to indicate that it was the same value, or maybe a couple of points higher, but not that much higher. 00:14:44.500 |
We can see straight away that the medical question pairs transfer would be worse, and it does perform worse; so 00:14:52.900 |
the GLUE datasets don't really correlate so much with the medical question pairs, and we can see a lot more red in this column. 00:15:02.820 |
As would be expected, in the other GLUE datasets we see a lot more gray and blue, so the results are either similar or better. 00:15:11.540 |
So obviously the next bit is this target domain part. 00:15:16.180 |
We're not going to be doing this optional "augment data" step because it doesn't help much; 00:15:24.100 |
at least for these datasets, I didn't find it to help much. 00:15:28.500 |
But if you do want to see how to do that, I cover it in another video and article; it's essentially an 00:15:41.460 |
augmentation of data using something called random sampling. So we're just taking random samples 00:15:50.260 |
and creating new pairs with them, and then labeling them with our cross-encoder; that's 00:15:54.900 |
all there is to it. But for this, I didn't find any 00:15:59.620 |
significant performance increase; it did increase the performance for a couple of the models a little bit, but not a huge amount, 00:16:06.260 |
so I think, for the time it takes, it's probably not worth doing in some cases. 00:16:12.180 |
But if you think it might help, especially if you have a small dataset, 00:16:15.140 |
I would check out the other video and article. So 00:16:22.820 |
we've got our unlabeled target data, and then we're going to skip that augmentation step, so we'll just cross it out, and 00:16:29.060 |
we are going to label it with our cross-encoder. So, 00:16:33.620 |
taking the cross-encoder from up here, we're going to use it to label this target data. 00:16:46.180 |
First we need the target data. Same again: we're just going to download that from the Hugging Face datasets library. So 00:16:53.460 |
we're coming from the GLUE, like, larger dataset again, 00:16:57.940 |
but this time we're using STSB, which is quite an easy 00:17:02.260 |
dataset to get good results on, to be honest, so I wouldn't expect this sort of performance for everything, 00:17:08.420 |
but this is just the example we're going to use here. 00:17:15.220 |
And so we're obviously using the training split as well, and we can see here, 00:17:19.620 |
okay, in that dataset we have the features sentence one, sentence two, and we also have the label. 00:17:24.180 |
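The target-data load mirrors the source step; as a sketch:

```python
from datasets import load_dataset

# STSB is another GLUE subset; its labels are floats from 0 to 5
target = load_dataset("glue", "stsb", split="train")

print(target)
# features: ['sentence1', 'sentence2', 'label', 'idx']
```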
Now, obviously we don't actually need to use a cross-encoder to create our labels, but we are going to, because 00:17:31.220 |
that's the whole point of this training strategy, 00:17:35.140 |
and I just couldn't find a reasonably good unlabeled dataset for this example; 00:17:44.340 |
we're just pretending the labels don't exist. So this one here, it's not there. 00:17:49.300 |
If you're struggling to train the cross-encoder, you can just download it from here, by the way, 00:17:57.780 |
or if it's taking a long time; I mean, it shouldn't take that long, but I know, 00:18:05.300 |
if you're following along with this in real time, you may just want to do that. 00:18:16.020 |
So we take our target data, and what we do is we zip those sentence pairs together. 00:18:25.140 |
I'll delete that. Let's say we have sentence one, and we have the first 00:18:30.660 |
sentence; I don't know exactly what it is, I think it's something like "a plane", 00:18:44.440 |
where we have the first sentence and we also have the second, related sentence, which is something like 00:18:54.200 |
"lands"; I don't think it's "lands", it is something to do with taking off as well, but I don't remember exactly what it is. 00:19:02.520 |
So you would get something like this, and then the next one would be something else, like "the dog", 00:19:17.160 |
"the dog warps", okay, so something like that. So we're getting these, like, 00:19:21.880 |
tuples of sentence pairs, and we're feeding them into 00:19:25.800 |
the predict method of our cross-encoder, and that will return a set of scores. Obviously, these here would have 00:19:32.200 |
reasonably high similarity, because they're talking about a similar topic, not necessarily the same; we've got "warps" and "runs", and "lands" and "taking off". 00:19:45.160 |
Then, what I'm going to do, and this just makes things easier in terms of 00:19:50.120 |
seeing what we're actually working with, what we're looking at, is I'm just going to 00:19:58.040 |
put the sentence ones, the sentence twos, and the new scores into a pandas DataFrame, 00:20:02.600 |
and then we can see everything here. So we can see 00:20:05.800 |
"a plane is taking off" and "an airplane is taking off"; 00:20:09.400 |
that was the first pair I was talking about. 00:20:12.040 |
And then we have all the labels that our cross-encoder has predicted. Now, all of these are pretty high, 00:20:22.280 |
which is what we would expect, but later on there are other pairs which are less similar. 00:20:29.160 |
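A sketch of that labeling step, reusing the cross-encoder trained above:

```python
import pandas as pd

# zip the target sentences into pairs; STSB's own labels are ignored here
pairs = list(zip(target["sentence1"], target["sentence2"]))

# the cross-encoder predicts a similarity score for every pair
scores = cross_encoder.predict(pairs)

# a DataFrame makes the pairs and their new labels easy to inspect
data = pd.DataFrame({
    "sentence1": target["sentence1"],
    "sentence2": target["sentence2"],
    "label": scores
})
print(data.head())
```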
And then we move on to training the bi-encoder. 00:20:36.060 |
So let's quickly switch back to our visual again. 00:20:39.260 |
Okay, so now we have identified, or we've downloaded, our unlabeled target data, cross that off, and we've also labeled it using our 00:20:48.700 |
cross-encoder model. So the final step, as we saw, is training our bi-encoder with that labeled target data, which is 00:21:05.500 |
in this pandas DataFrame here. So we're just going to iterate through the rows in our DataFrame, 00:21:11.100 |
and I'm going to append all those input examples as we did before. 00:21:19.180 |
Okay, and we have our sentence pairs and we have the predicted labels in there. 00:21:26.380 |
So once we've created that list of input examples, again we're just pushing all of that into a PyTorch DataLoader, 00:21:33.820 |
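As before, a sketch of pushing the newly labeled pairs into training format (the batch size of 16 is an assumption carried over from the cross-encoder step):

```python
from torch.utils.data import DataLoader
from sentence_transformers import InputExample

target_examples = []
for _, row in data.iterrows():
    # sentence pair plus the cross-encoder's predicted score as its label
    target_examples.append(InputExample(
        texts=[row["sentence1"], row["sentence2"]],
        label=float(row["label"])
    ))

target_loader = DataLoader(target_examples, shuffle=True, batch_size=16)
```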
From there we can go on and initialize our bi-encoder. 00:21:41.100 |
To do that, we are using models and SentenceTransformer from the sentence transformers library, and we 00:21:48.380 |
build it in two parts. So the sentence transformer takes the typical transformer model, which outputs one vector per token, but 00:22:02.340 |
what we want is one single sentence vector from our sentence transformer. 00:22:13.640 |
Translating those 512 word vectors into a single vector is the pooling layer's job: it compresses them 00:22:24.760 |
into one vector by taking the average value across each dimension. 00:22:29.260 |
And that's all we do there. So that's why we have this BERT layer 00:22:35.320 |
followed by a pooling layer, which is just taking the mean pooling, 00:22:39.880 |
and then we combine both of those into a single SentenceTransformer object. 00:22:45.160 |
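A minimal sketch of that two-part construction:

```python
from sentence_transformers import models, SentenceTransformer

# the transformer layer outputs one vector per token
bert = models.Transformer("bert-base-uncased")

# mean pooling averages the token vectors into a single sentence vector
pooling = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode="mean"
)

# chain the two modules into one bi-encoder
bi_encoder = SentenceTransformer(modules=[bert, pooling])
```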
And that's our sentence transformer initialized, but obviously we need to train it. So we come down here 00:22:54.120 |
to train it on the data we have. At the moment we have 00:22:56.760 |
continuous values in our labels, from zero to one, 00:23:00.920 |
so what we can use is the cosine similarity loss. 00:23:04.440 |
So we initialize that, and then again, we don't want to train for more than one epoch, 00:23:13.000 |
and this time we're using a lower number of warm-up steps, which is 10% this time. 00:23:26.840 |
With that, we are completely ready to train, and we just call model.fit. 00:23:31.080 |
The train_objectives parameter is slightly different to what we saw with the cross-encoder's 00:23:35.720 |
fit method, but everything else is the same. We're just using the default learning rate here, 00:23:41.000 |
so 2e-5; you don't really need to include that, 00:23:44.520 |
but if you do want to change it, I put that in so you can see where you would change it. 00:23:51.640 |
And yeah, we're ready to train, so we go ahead and train and see how that performs. 00:23:57.560 |
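Put together, the bi-encoder training step might look like this sketch (the 10% warm-up figure and the save path are my assumptions):

```python
from sentence_transformers import losses

# cosine similarity loss suits the continuous 0-1 labels
loss = losses.CosineSimilarityLoss(bi_encoder)

num_epochs = 1
warmup_steps = int(len(target_loader) * num_epochs * 0.10)  # assumed 10% warm-up

bi_encoder.fit(
    train_objectives=[(target_loader, loss)],  # a list of (dataloader, loss) pairs
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={"lr": 2e-5},  # the default, shown only so you can see where to change it
    output_path="stsb-bi-encoder"  # placeholder save path
)
```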
So we evaluate it. This time we're not using a cross-encoder, so the 00:24:02.920 |
evaluator is slightly different: this time we're using an EmbeddingSimilarityEvaluator, which is going to 00:24:09.480 |
take two sentence vectors and calculate the similarity between them, 00:24:19.160 |
and then compare that predicted similarity to the true similarity as per our validation set here. 00:24:27.480 |
In that validation set we do have these labels, 00:24:31.400 |
but in STSB those labels are in the range zero to five, so we 00:24:39.640 |
divide everything by five, which brings us into a range of zero to one, 00:24:44.040 |
which is what we need for this embedding similarity evaluator. 00:24:48.840 |
And then, as we did before, we're creating that list of input examples. 00:24:57.880 |
So we're using the EmbeddingSimilarityEvaluator's from_input_examples method, again because we have input examples, 00:25:03.660 |
and we're passing in our input examples data, and 00:25:08.040 |
write_csv is False; that just means it will print the results rather than write them to file. 00:25:17.960 |
And then, to actually evaluate, all we do is pass the model to our evaluator and it will do everything, of course, and we get 00:25:31.720 |
0.76, which is pretty good. You can think of something like 00:25:42.600 |
0.5 as kind of like your moderate correlation, with anything approaching 1.0 being a strong correlation. 00:25:58.440 |
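The whole evaluation step, as a sketch:

```python
from datasets import load_dataset
from sentence_transformers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

val = load_dataset("glue", "stsb", split="validation")

val_examples = [
    # STSB labels run 0-5, so divide by 5 to bring them into the 0-1 range
    InputExample(texts=[row["sentence1"], row["sentence2"]],
                 label=row["label"] / 5.0)
    for row in val
]

evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    val_examples, write_csv=False
)

# pass the trained bi-encoder to the evaluator to get the correlation score
print(evaluator(bi_encoder))  # ~0.76 in the run shown here
```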
And for the final quick part, I just want to have a look at the other performances that I found, 00:26:05.000 |
because not everything is going to be as good as what we just got there. 00:26:12.760 |
Over here we got 0.76 just now, so we got slightly higher than what I got before. 00:26:31.480 |
And the others are more difficult; the rest of them are more in this moderate correlation range. 00:26:41.780 |
The medical question pairs dataset actually performed better with the bi-encoder training than the cross-encoder training. 00:26:51.220 |
Nonetheless, they're still within the same sort of range 00:26:53.620 |
that I would expect; it's not a massive improvement. There's the benchmark down here, and there are a few above it, which 00:27:04.500 |
is probably reasonable, particularly because these 00:27:08.340 |
GLUE datasets were not that similar to our medical question pairs dataset, 00:27:14.660 |
so that makes sense. And then the other ones, yeah, they're all sort of within that moderate range. 00:27:25.540 |
Although it does make sense; I mean, with the question pairs, 00:27:33.380 |
the transfer from medical question pairs to Quora question pairs was pretty good. 00:27:42.340 |
If you have a look here, from Quora question pairs to medical question pairs it's not as good, 00:27:47.380 |
maybe because the language in Quora question pairs is simpler than in the medical question pairs; I'm not sure. 00:27:53.700 |
But from medical question pairs to Quora question pairs, it worked quite well. 00:27:59.300 |
So that, I suppose, points out where that n-gram similarity doesn't always hold. 00:28:07.860 |
But anyway, I think the results from this are probably pretty typical of what you can expect. It's a good way, 00:28:19.780 |
if you really don't have any labeled data within your target domain, to at least 00:28:28.180 |
squeeze out a few percentage points of performance more than you would be able to otherwise. 00:28:36.500 |
For that reason, I think this can be quite useful. 00:28:40.420 |
Whether or not it is the best approach to take will depend on your data, 00:28:52.500 |
but it's definitely something useful to know about and be able to apply if you need it. 00:28:58.420 |
So, yeah, that's it. Thank you very much for watching. I hope it's been useful.