
AugSBERT: Domain Transfer for Sentence Transformers


Chapters

0:00 Why Use Domain Transfer
4:08 Strategy Outline
6:05 Train Source Cross-Encoder
12:44 Cross-Encoder Outcome
15:12 Labeling Target Data
20:31 Training Bi-encoder
23:58 Evaluate Bi-encoder Performance
28:08 Final Points

Transcript

Today we're going to talk about how we can train language models across different domains. To do this we're going to use what is called the Augmented SBERT (AugSBERT) training strategy, and specifically its domain transfer flavor. This training strategy is used where you have some data but not enough or, as in our case with domain transfer, where your data is in a different domain.

We have some data in one domain, say Quora question pairs, but we need Stack Overflow question pairs, and all we have on the Stack Overflow side are unlabeled question pairs. So what we're essentially doing is using the source dataset, for example the Quora question pairs, to train a model that can then transfer its knowledge and label the Stack Overflow question pairs. We don't need to stick with question pairs to question pairs; we could go from question pairs to semantic similarity pairs or anything else. But we do have to make sure that our source domain and target domain are at least similar, and we can think of this required similarity as a bridge across a gap between the source domain and the target domain.

The less overlap those two domains have, the greater this gap is, and if the gap is bigger, we need a bigger, stronger bridge. We can think of that bridge as our model: we need a bigger, stronger model to bridge greater gaps between the two domains.

In this video we're going to go through one example, but I have tested five different datasets. What we can see here is the n-gram similarity, meaning how many n-grams overlap between each pair of datasets as a whole. This is not a perfect measure of similarity, of course, because it only captures syntactic overlap; it doesn't really capture whether the two datasets have the same language structure or talk about the same topics. Still, it's a very simple measure that helps us figure out whether the two domains we're trying to transfer across will be easy to bridge. From this we would expect the medical question pairs down here, because of their lower similarity scores, probably not to transfer well to the other domains, whereas STSB, RTE, MRPC, and Quora question pairs are all reasonably similar, so we'd expect transfer between those domains to be much smoother.

But like I said, this isn't perfect. In fact, one of the best domain transfer performances I saw was from the medical question pairs to the Quora question pairs, which makes sense semantically, because question pairs will probably have a more similar structure to each other than to entailment pairs, as in RTE, or semantic textual similarity pairs, as in STSB. That makes sense logically, but it's not really reflected in this chart.

So this is the overall strategy we're going to follow with AugSBERT domain transfer, and the first step is to get our data.
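The video doesn't say exactly how its n-gram similarity was computed, so here is one plausible version, written as a minimal sketch under that assumption: collect the set of word bigrams in each dataset and take their Jaccard overlap. The function names are hypothetical; only the idea of "how many n-grams overlap between the two datasets as a whole" comes from the chart.

```python
from itertools import chain
from typing import Iterable, Set, Tuple

def ngrams(text: str, n: int) -> Set[Tuple[str, ...]]:
    """Set of word n-grams in one lower-cased, whitespace-tokenized sentence."""
    tokens = text.lower().split()
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def corpus_ngrams(sentences: Iterable[str], n: int) -> Set[Tuple[str, ...]]:
    """Union of the n-grams of every sentence in a dataset."""
    return set(chain.from_iterable(ngrams(s, n) for s in sentences))

def ngram_similarity(source: Iterable[str], target: Iterable[str], n: int = 2) -> float:
    """Jaccard overlap of two datasets' n-gram sets: 0 = disjoint, 1 = identical."""
    a, b = corpus_ngrams(source, n), corpus_ngrams(target, n)
    if not a and not b:
        return 0.0
    return len(a & b) / len(a | b)

source = ["how do i learn python", "what is the best laptop"]
target = ["how do i learn java"]
print(ngram_similarity(source, target))  # 3 shared bigrams out of 9 distinct ones
```

Like the chart, this only measures surface overlap: two datasets can share many bigrams yet differ in structure or topic, which is why the medical-to-Quora transfer can beat what the score suggests.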

We're going to have a labeled source dataset, which in the earlier example would have been the Quora question pairs, and an unlabeled target dataset, which would be our Stack Overflow question pairs. We're not actually going to use those two examples, though. From here on we're going with MRPC as our source dataset and STSB as our target dataset.

So we'll get our labeled source data, MRPC, and train a cross-encoder model with it. I'll quickly mention why we're training a cross-encoder rather than a bi-encoder (which is what a sentence transformer is): we need less data to train a good cross-encoder, and here the cross-encoder is only used to label the target data. Cross-encoders are also slow, which is why we don't use one as our final model. So we're using the cross-encoder's better performance with less data to build a labeled target dataset, and that dataset can train a fast sentence transformer to a similar level of performance as the cross-encoder. That's basically what we're doing.
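The three steps of the strategy can be sketched as plain function composition. This is a minimal outline under stated assumptions: `augsbert_domain_transfer`, `train_cross_encoder`, and `train_bi_encoder` are hypothetical names, and the trainers are passed in as callables so the data flow is visible without any model library. In the real notebook they would wrap `CrossEncoder.fit` and `SentenceTransformer.fit`. The AugSBERT paper calls the cross-encoder-labeled target data "silver" data, a name used in the comments below.

```python
from typing import Any, Callable, List, Tuple

LabeledPair = Tuple[str, str, float]
Scorer = Callable[[str, str], float]

def augsbert_domain_transfer(
    source_data: List[LabeledPair],
    target_pairs: List[Tuple[str, str]],
    train_cross_encoder: Callable[[List[LabeledPair]], Scorer],
    train_bi_encoder: Callable[[List[LabeledPair]], Any],
) -> Tuple[Any, List[LabeledPair]]:
    # Step 1: fit the slow but data-efficient cross-encoder on labeled source pairs.
    cross_encoder = train_cross_encoder(source_data)
    # Step 2: use it to label ("silver" labels) the unlabeled target pairs.
    silver = [(a, b, cross_encoder(a, b)) for a, b in target_pairs]
    # Step 3: train the fast bi-encoder on the silver-labeled target data.
    bi_encoder = train_bi_encoder(silver)
    return bi_encoder, silver

# Toy stand-ins: an "exact match" scorer and a "model" that just counts its data.
exact_match = lambda data: (lambda s1, s2: 1.0 if s1 == s2 else 0.0)
count_rows = lambda data: len(data)
model, silver = augsbert_domain_transfer(
    [("a", "a", 1.0), ("a", "b", 0.0)], [("x", "x"), ("x", "y")], exact_match, count_rows
)
```

The point of passing the trainers in is that the bi-encoder never sees the source data directly; everything it learns arrives through the cross-encoder's labels, which is the knowledge distillation described next.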

It's knowledge distillation from our cross-encoder all the way over to our sentence transformer. That's all we're doing.

So let's jump over to the code and look at how we perform those first two steps. We're starting with the source domain, MRPC, as I said, and the target domain is STSB. The first thing we do is get the source data. We use Hugging Face Datasets for this, so you may need to install it with pip install datasets, and from datasets we import load_dataset. This loads the MRPC dataset from Hugging Face, which is stored in the larger GLUE dataset (MRPC is a subset of it), and we take the train split. We can see the dataset features here: sentence1 and sentence2, which are the sentence pairs we're going to learn from in order to label the STSB sentence pairs, and the label as well. We only have about 3.6 thousand samples here, which is quite a small number, particularly for a sentence transformer.

That's why we're going with the cross-encoder here: it can reach good performance with this little data, whereas a sentence transformer might need more data, although not necessarily. We first need to format the data for training. To do that we go through each row in the dataset we just downloaded. Sentence Transformers always uses the InputExample object, so we create an empty list and populate it with InputExample objects, each containing a sentence pair and a label. Behind the scenes, the library uses PyTorch.
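The formatting step, plus the shuffled, batch-size-16 loading that follows it, can be illustrated without installing anything. This is a dependency-free sketch, not the library code: `Example` is a hypothetical stand-in for `sentence_transformers.InputExample` (which takes `texts=[...]` and `label=...`), and `batches` only approximates what `torch.utils.data.DataLoader(examples, shuffle=True, batch_size=16)` does. The row keys match the GLUE features shown above.

```python
import random
from dataclasses import dataclass
from typing import Iterator, List, Sequence

@dataclass
class Example:
    """Stand-in for InputExample: one sentence pair plus a float label."""
    texts: List[str]
    label: float

def to_examples(rows: Sequence[dict]) -> List[Example]:
    # Each row is assumed to be a GLUE-style record with sentence1, sentence2, label.
    return [Example(texts=[r["sentence1"], r["sentence2"]], label=float(r["label"])) for r in rows]

def batches(examples: List[Example], batch_size: int = 16,
            shuffle: bool = True, seed: int = 0) -> Iterator[List[Example]]:
    """Yield mini-batches in a (seeded) shuffled order, last batch possibly smaller."""
    order = list(range(len(examples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [examples[i] for i in order[start:start + batch_size]]

rows = [{"sentence1": f"s{i}", "sentence2": f"t{i}", "label": i % 2} for i in range(40)]
examples = to_examples(rows)
loaded = list(batches(examples))
```

With 40 rows this gives batches of 16, 16, and 8; the real DataLoader behaves the same way by default (it keeps the smaller final batch unless `drop_last=True`).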

So we also need to initialize a PyTorch DataLoader with those input examples, using a batch size of 16 and shuffling the data. That's our data ready for training the cross-encoder. Then, if we'd like to, we can add a validation set as well, which is useful because it gives us a performance score. Here it comes from the source validation set, but if possible you probably want to switch this out for the target dataset's validation set, which in our case would be STSB. I'm not going to do that here, to emulate an actual project where you may not have validation data from your target domain. But if you have the time, I would definitely recommend manually labeling some target data so that you can check performance; you'll want to do that later anyway. For now we'll use the source dataset. Because we're training a cross-encoder, the evaluator comes from the cross-encoder section of the Sentence Transformers library, and it's a correlation evaluator, so it checks the correlation between the labels.

That is, the correlation between the true labels in our validation set and the labels predicted by the cross-encoder. A cross-encoder does not output sentence vectors like a sentence transformer; it outputs a similarity score. We then initialize that correlation evaluator using from_input_examples, since our validation data is also a list of input examples.
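What the correlation evaluator measures can be shown in a few lines: it compares the cross-encoder's predicted scores with the gold labels. The video doesn't say which coefficient is reported, so this sketch computes both Pearson (linear agreement) and Spearman (rank agreement, i.e. Pearson on ranks with ties averaged). It is a from-scratch illustration, not the library's implementation.

```python
import math
from typing import List, Sequence

def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation: how linearly the two score lists agree."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the mean of the positions they occupy."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks

def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

gold = [1.0, 0.0, 1.0, 0.0, 1.0]            # e.g. MRPC paraphrase labels
predicted = [0.91, 0.12, 0.55, 0.40, 0.97]   # hypothetical cross-encoder scores
print(pearson(gold, predicted), spearman(gold, predicted))
```

Spearman only cares about ordering, so a model whose scores are monotonically related to the labels gets a perfect Spearman score even when Pearson is below 1.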

So that's our training and validation data ready to go, and we can start training the cross-encoder. For that we load the CrossEncoder object from Sentence Transformers, again using bert-base-uncased as the initial model, which comes from the Hugging Face model hub, and we set the number of labels. You can change this if you're training a cross-encoder to predict NLI labels, for example, where you'd have 0, 1, and 2 for your contradiction, neutral, and entailment classes; there you'd put 3, because you have three labels and want to output three labels. But for us, we want to output a similarity score.
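To make the number-of-labels choice concrete: with one label the model's head emits a single logit, and with three labels it emits one logit per NLI class. As far as I know, sentence-transformers' CrossEncoder applies a sigmoid to the single logit by default when there is one label, giving a 0 to 1 score; treat that default as an assumption. This sketch only shows how the two output shapes are usually turned into scores; the logit values are made up.

```python
import math
from typing import List

NLI_LABELS = ["contradiction", "neutral", "entailment"]  # label ids 0, 1, 2

def single_label_score(logit: float) -> float:
    """One label: squash the single logit into (0, 1) with a sigmoid."""
    return 1 / (1 + math.exp(-logit))

def nli_probabilities(logits: List[float]) -> List[float]:
    """Three labels: softmax over one logit per class (max subtracted for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = nli_probabilities([0.1, 0.2, 3.0])
predicted_class = NLI_LABELS[max(range(len(probs)), key=probs.__getitem__)]
```

For similarity labeling we want the first shape: one number per sentence pair that can later serve directly as a training label for the bi-encoder.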

So we just have one label. Then we have the number of epochs. Anything more than one and you're probably going to overfit, so I would always stick with one unless you really have a reason to do otherwise. Next is the percentage of warm-up steps, which is the number of steps during which we don't train at the full learning rate but gradually increase up to it. 35% is quite high, but it's what tends to work best.

I find that with cross-encoders a higher percentage of warm-up steps tends to work reasonably well. Then we have our optimizer parameters, another training parameter I modified a little: instead of the default learning rate of 2e-5, we use a slightly higher 5e-5. Everything else is as you would usually do with a cross-encoder.
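The warm-up percentage becomes a step count through the number of batches. Below is a minimal sketch assuming a linear ramp up to the peak learning rate followed by linear decay to zero; I believe that matches the library's default "WarmupLinear" scheduler, but the function names here are hypothetical and only the shape is being illustrated. The 35% warm-up and 5e-5 peak come from the walk-through above; the 3,600-row figure approximates the MRPC training split.

```python
import math

def warmup_steps(num_batches: int, epochs: int = 1, warmup_fraction: float = 0.35) -> int:
    """Number of optimizer steps spent ramping the learning rate up from zero."""
    return math.ceil(num_batches * epochs * warmup_fraction)

def warmup_linear_lr(step: int, total_steps: int, warmup: int, peak_lr: float = 5e-5) -> float:
    """Linear ramp from 0 to peak_lr over `warmup` steps, then linear decay to 0."""
    if step < warmup:
        return peak_lr * step / warmup
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup))

num_batches = math.ceil(3600 / 16)   # roughly 3.6k MRPC pairs at batch size 16
warmup = warmup_steps(num_batches)   # 35% of one epoch
schedule = [warmup_linear_lr(s, num_batches, warmup) for s in range(num_batches + 1)]
```

With one epoch and 35% warm-up, the model spends about a third of its only pass over the data below the peak learning rate, which is why the higher 5e-5 peak pairs reasonably with it.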

So we use the fit method, as we will with the sentence transformer too. The train_dataloader argument is slightly different from the sentence transformer's parameter, but otherwise everything works exactly the same way. We pass in our source data, which is the training DataLoader we created earlier, our evaluator, the number of epochs, the warm-up steps, the learning rate, and where we're going to save the model, which is a directory for the MRPC cross-encoder. That trains our model.

It doesn't take long; on this computer it took 14 seconds, which is very quick. Then we can go ahead and label the target data, which is the next step. Let me go back to the outline. We can cross these off: we've got our labeled source data and trained our cross-encoder.

Okay, so this is the performance of the cross-encoder models: the source models are over here, and the different target datasets over here. The source model for MRPC on the STSB target data, measured on the STSB validation set, is the result we're interested in.

That's 0.63 here. We should be able to get better performance than this, I think, but we'll see. We can also see, as I said, that STSB is quite an easy dataset to get good results on, because all of these achieve better performance. What I really want to focus on, though, is the bert-base-uncased performance as a benchmark, because this is what you would get if you didn't have any target domain training data, which we don't; instead we're using the AugSBERT training strategy to train from the source dataset. So these are our benchmarks. The red circles indicate that model performance decreased from the benchmark, the blue ones indicate that it increased, and gray indicates that it stayed the same or rose by only a couple of points. We can see straight away that transfer to the medical question pairs performs worse: these GLUE datasets don't correlate much with the medical question pairs, and we see a lot more red in this column, which is to be expected.

In the other GLUE datasets we see a lot more gray and blue, so the results are either similar or better. The next bit is the target domain part. We're not going to do the optional data augmentation step because, at least for these datasets, I didn't find it helped much. If you do want to see how to do it, our last video and article on AugSBERT cover that.

That covers augmenting the data using something called random sampling. We just take random samples from our dataset, create new pairs from them, and then label those pairs with our cross-encoder. That's all there is to it. For this example I didn't find any significant performance increase; it improved a couple of the models a little, but not by a huge amount. So for the time it takes, it's probably not worth it in some cases, but if you think it might help, especially with a small dataset, I would check out the other video and article. So we need to get our unlabeled target data, and we're going to skip the augmentation step.
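The optional random-sampling augmentation that we skip here can be sketched as follows. The video doesn't show this code, so this is an assumed, minimal version: pair a randomly chosen first sentence with a randomly chosen second sentence, skipping self-pairs and pairs we already have. `random_sample_pairs` is a hypothetical name; in the full strategy, each new pair would then be scored by the cross-encoder like any other target pair.

```python
import random
from typing import List, Tuple

def random_sample_pairs(pairs: List[Tuple[str, str]], k: int, seed: int = 0,
                        max_tries_per_pair: int = 20) -> List[Tuple[str, str]]:
    """Build up to k new pairs by recombining sentences from existing pairs at random.

    Pairs that already exist (in either order) and self-pairs are skipped.
    """
    rng = random.Random(seed)
    firsts = [a for a, _ in pairs]
    seconds = [b for _, b in pairs]
    seen = {frozenset(p) for p in pairs}
    new_pairs: List[Tuple[str, str]] = []
    for _ in range(k * max_tries_per_pair):  # cap attempts so small datasets terminate
        if len(new_pairs) == k:
            break
        a, b = rng.choice(firsts), rng.choice(seconds)
        if a == b or frozenset((a, b)) in seen:
            continue
        seen.add(frozenset((a, b)))
        new_pairs.append((a, b))
    return new_pairs

pairs = [("a1", "b1"), ("a2", "b2"), ("a3", "b3")]
augmented = random_sample_pairs(pairs, 4)
```

Most randomly recombined pairs are unrelated, so this mainly adds low-similarity examples; that may be part of why it gave little benefit on these datasets.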

So we'll just cross it out, and we're going to label the target data with our cross-encoder: we take the model from up here and use it to label this target data. Let's have a look at how we do that. To label the target data, we first need the target data.

It's the same as before: we download it from the Hugging Face Datasets library, again from the larger GLUE dataset, but this time using STSB, which to be honest is quite an easy dataset to get good results on, so I wouldn't expect this sort of performance for everything. But this is the example we're going to use here. We're using the training split as well, and we can see that the dataset has the features sentence1, sentence2, and label. So obviously we don't actually need a cross-encoder to create our labels, but we're going to use one anyway because that's the whole point of this training strategy, and I couldn't find a reasonably good dataset that didn't already have labels. So we're just pretending they don't exist.

So as far as we're concerned, that label column isn't there. By the way, if you're struggling to train the cross-encoder, or it's taking a long time, you can just download it from here. It shouldn't take that long, but I know on some machines it can take a while, so if you're following along in real time and having any issues, just do that. So we have our target data, and what we do is zip sentence1 and sentence2 together. Let me clear that and write out an example.

Let's say we have sentence1, and the first one is, I think, something like "A plane is taking off." We'd basically have a tuple with the first sentence and the second, related sentence, which is something like "An airplane lands." I don't think it's "lands."

It's something about taking off as well, but I don't remember exactly what. So you would get something like this, and the next one would be something else, like "The dog runs" and "The dog walks." So we get these tuples of sentence pairs, feed them into the predict method of our cross-encoder, and that returns a set of scores.
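The zip-then-predict step looks like this in outline. In the notebook, `predict` is the trained cross-encoder's `predict` method, which takes a list of sentence pairs and returns one score per pair. Here it is replaced by `overlap_predict`, a hypothetical word-overlap scorer, so the sketch runs on its own; the rows it returns mirror the sentence1 / sentence2 / label columns of the DataFrame built next.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Pair = Tuple[str, str]

def label_target_pairs(sentence1: Sequence[str], sentence2: Sequence[str],
                       predict: Callable[[List[Pair]], Sequence[float]]) -> List[Dict]:
    """Zip the two sentence columns into pairs, score them in one call, return labeled rows."""
    pairs = list(zip(sentence1, sentence2))
    scores = predict(pairs)
    return [{"sentence1": a, "sentence2": b, "label": float(s)}
            for (a, b), s in zip(pairs, scores)]

def overlap_predict(pairs: List[Pair]) -> List[float]:
    """Hypothetical stand-in for a cross-encoder: Jaccard overlap of lower-cased words."""
    scores = []
    for a, b in pairs:
        x, y = set(a.lower().split()), set(b.lower().split())
        scores.append(len(x & y) / len(x | y) if x | y else 0.0)
    return scores

rows = label_target_pairs(
    ["A plane is taking off.", "The dog runs."],
    ["An airplane is taking off.", "The dog walks."],
    overlap_predict,
)
```

Scoring all pairs in a single call, rather than one at a time, is what lets the real model batch the work efficiently.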

These would have reasonably high similarity because they're about similar topics, though not necessarily the same: walks versus runs, and lands versus taking off. Next, to make it easier to see what we're actually working with, I'm going to organize the results.

I'm just going to put the sentence1s, the sentence2s, and the new scores into a pandas DataFrame. Then we can see everything here: "A plane is taking off" and "An airplane is taking off," which were the first two I was talking about, along with all the labels our cross-encoder predicts.

All of these are pretty high because they're all very semantically similar pairs, which is what we'd expect, but later on there are other pairs that are less similar. Then we move on to training the bi-encoder. Let's quickly switch back to our visual. Now we've downloaded our unlabeled target data, so we cross that off, and we've also labeled it using our cross-encoder model.

So the final step, as we saw, is training our bi-encoder with that labeled target data. Let's look at how we'd do that. We have our labeled target data in this pandas DataFrame, so we iterate through its rows and append an InputExample for each one, as we did before, to this data list, with our sentence pairs and the predicted labels. Once we've created that list of input examples, we again push it all into a PyTorch DataLoader, and from there we can initialize our sentence transformer.

To do that, we're using models and SentenceTransformer from the sentence_transformers library. A sentence transformer takes a typical transformer like BERT, and what BERT outputs is one vector per token, up to 512 token vectors. What we want from our sentence transformer is one single sentence vector, so we need a way of translating those token vectors into a single vector. We do that by pooling all of those tokens into one vector, taking the average value across each dimension. That's all we do there.
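Mean pooling itself is just a per-dimension average over token vectors. One detail the summary above leaves out: padding tokens should not count toward the average, so the usual approach (and, to my knowledge, what the library's mean pooling does) is to weight tokens by the attention mask. This pure-Python sketch shows the idea on toy 2-dimensional vectors; real BERT base vectors have 768 dimensions.

```python
from typing import List, Sequence

def mean_pool(token_vectors: Sequence[Sequence[float]],
              attention_mask: Sequence[int]) -> List[float]:
    """Average token vectors dimension by dimension, skipping padding (mask == 0)."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            for d in range(dim):
                totals[d] += vec[d]
            count += 1
    return [t / max(count, 1) for t in totals]  # guard against an all-padding input

# Two real tokens plus one padding token whose values must not leak into the mean.
sentence_vector = mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0])
```

Because every sentence is reduced to one fixed-size vector, two sentences can then be compared with a cheap cosine similarity instead of a full cross-encoder pass.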

So that's why we have this BERT layer followed by a pooling layer, which just takes the mean pooling. Then we combine both into a single SentenceTransformer object. That's our sentence transformer initialized, but obviously we need to train it on the data, so we come down here.

At the moment we have continuous values from zero to one in our labels, so we can use the cosine similarity loss. We initialize that, and again we don't want to overfit, so epochs is set to one. This time we're using a lower percentage of warm-up steps, 10%. Once that's set, we're ready to train and just call model.fit. The train_objectives argument is slightly different from what we saw in the cross-encoder fit method, but everything else is the same.
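The cosine similarity loss ties the bi-encoder to the cross-encoder's 0 to 1 labels: embed both sentences, take the cosine similarity of the two vectors, and penalize its squared distance from the label. To my knowledge, sentence-transformers' CosineSimilarityLoss uses mean squared error by default. This stdlib sketch computes that value for a batch; it shows only the objective, not gradients or training.

```python
import math
from typing import Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def cosine_similarity_loss(emb1: Sequence[Sequence[float]],
                           emb2: Sequence[Sequence[float]],
                           labels: Sequence[float]) -> float:
    """Mean squared error between cosine(u, v) and the target score of each pair."""
    errors = [(cosine(u, v) - y) ** 2 for u, v, y in zip(emb1, emb2, labels)]
    return sum(errors) / len(errors)

# Pair 1: identical vectors labeled 1.0 (no error).
# Pair 2: orthogonal vectors (cosine 0) labeled 0.5 (squared error 0.25).
loss = cosine_similarity_loss([[1.0, 0.0], [1.0, 0.0]],
                              [[1.0, 0.0], [0.0, 1.0]],
                              [1.0, 0.5])
```

This is why the labels must be continuous and on the same scale as cosine similarity: the loss pushes the vectors' cosine toward each pair's label directly.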

We're just using the default learning rate here, 2e-5, so you don't really need to include it, but I put it in so you can see where you'd change it if you want to. We're ready to train, so we go ahead and train, and then see how it performs. To evaluate it, we're not using a cross-encoder this time, so the evaluator is different: we use an embedding similarity evaluator, which takes two sentence vectors, calculates the similarity between them, and compares that predicted similarity to the true similarity in our validation set. The validation set does have labels, but in STSB they are in the range zero to five, so we use a lambda function to divide everything by five, which brings them into the range zero to one, which is what we need for this embedding similarity evaluator. Then, as before, we create the list of input examples and initialize the evaluator using from_input_examples, passing in our input examples data with write_csv set to False.

That just means it won't write the results to a CSV file; the score comes back to our notebook instead. Then, to actually evaluate, all we do is pass the model to our evaluator and it does everything, and we get a correlation score of 0.76, which is pretty good.
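The evaluation can be sketched end to end without any libraries: rescale the STSB labels from 0 to 5 down to 0 to 1, embed each sentence, take the cosine similarity of each pair, and correlate those similarities with the labels. This is a cosine-only, Spearman-only approximation; the library's evaluator may compute and report additional similarity measures. `evaluate_bi_encoder` and the toy `vectors` table are hypothetical stand-ins for the real model.

```python
import math
from typing import Callable, List, Sequence, Tuple

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks with ties sharing their mean position."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks

def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    return cov / (math.sqrt(sum((a - mx) ** 2 for a in x)) * math.sqrt(sum((b - my) ** 2 for b in y)))

def evaluate_bi_encoder(embed: Callable[[str], Sequence[float]],
                        pairs: Sequence[Tuple[str, str]],
                        gold_0_to_5: Sequence[float]) -> float:
    gold = [g / 5.0 for g in gold_0_to_5]                  # STSB labels 0-5 -> 0-1
    sims = [cosine(embed(a), embed(b)) for a, b in pairs]  # predicted similarity per pair
    return pearson(average_ranks(sims), average_ranks(gold))  # Spearman correlation

vectors = {"a": [1.0, 0.0], "b": [1.0, 0.0], "c": [0.0, 1.0], "d": [1.0, 1.0]}
pairs = [("a", "b"), ("a", "c"), ("a", "d")]
score = evaluate_bi_encoder(vectors.__getitem__, pairs, [5.0, 0.0, 3.0])
```

Spearman correlation depends only on rank order, so dividing by five doesn't change this particular score; it keeps the labels on the same 0 to 1 scale the bi-encoder was trained on.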

You can think of something like 0.5 as moderate correlation and 0.8 as high correlation, so that's pretty good. For the final quick part, I want to look at the other results I found, because not everything is going to be as good as what we just got. So this is what we just did, over here.

We got 0.76 just now, slightly higher than what I got before. But, like I said, STSB is an easy dataset to get good results on, and the others are more difficult; the rest are more in the moderate correlation range. The medical question pairs actually performed better with the bi-encoder training than with the cross-encoder training, but they're still within the same sort of range I would expect. It's not a massive improvement.

The benchmark is down here, and they're a few percentage points better, which is probably reasonable, particularly because these GLUE datasets were not that similar to our medical question pairs dataset. So that makes sense, and the other ones are all within that moderate range. One result did surprise me, although it does make sense.

It's the question pairs one here: the transfer from medical question pairs to Quora question pairs was pretty good. If you look here, from Quora question pairs to medical question pairs it's not as good, maybe because the language in Quora question pairs is simpler than in the medical question pairs.

I'm not sure, but from medical question pairs to Quora question pairs it worked quite well. I suppose that shows where n-gram similarity doesn't always correlate exactly with what you'd expect. Anyway, I think these results are probably pretty typical of what you can expect. This can be really useful if you don't have any labeled data in your target domain, to at least squeeze out a few more percentage points of performance than you'd get without this training strategy. For that reason, I think it can be quite useful.

Whether or not it's the best approach will depend on your data. It's not always the best option, but it's definitely something useful to know about and be able to apply if you need it. So that's it. Thank you very much for watching.

I hope it's been useful, and I'll see you again in the next one.