
AugSBERT: Domain Transfer for Sentence Transformers


Chapters

0:00 Why Use Domain Transfer
4:08 Strategy Outline
6:05 Train Source Cross-Encoder
12:44 Cross-Encoder Outcome
15:12 Labeling Target Data
20:31 Training Bi-encoder
23:58 Evaluating Bi-encoder Performance
28:08 Final Points


00:00:00.000 | Today we're going to talk about how we can train language models across different domains
00:00:06.880 | Now, to do this, we're going to be using what is called the augmented SBERT training strategy,
00:00:13.200 | and we're going to be using the domain transfer flavor of that.
00:00:19.440 | The training strategy is used where
00:00:22.160 | you have some data,
00:00:25.280 | but not enough. Or, in our case with domain transfer, we have some data in
00:00:31.360 | one domain. So say maybe we have Quora question pairs,
00:00:35.600 | but we need Stack Overflow question pairs, and all we have on the Stack Overflow side
00:00:42.080 | are unlabeled
00:00:44.720 | question pairs.
00:00:47.360 | What we're essentially doing is using that source dataset, for example the Quora question pairs,
00:00:54.000 | to train a model that can then
00:00:56.080 | transfer its knowledge and label the Stack Overflow question pairs.
00:01:02.640 | Now we don't need to stick with question pairs to question pairs.
00:01:05.600 | We can go from question pairs to semantic similarity pairs or anything else.
00:01:12.400 | We do have to make sure that our source domain and target domain are at least similar
00:01:17.680 | and we can think of this
00:01:21.440 | required similarity as a bridge
00:01:24.160 | across a gap between the source domain and target domain.
00:01:29.440 | The less overlap those two domains have,
00:01:34.880 | the greater this gap here is, and
00:01:38.080 | if this gap is bigger, we obviously need a bigger, stronger bridge.
00:01:43.200 | And we can think of that bridge as our model. We need a bigger, stronger model to bridge these
00:01:50.320 | longer or greater gaps between the two domains.
00:01:54.640 | So in this video we're going to be going through one example, but I have tested these five different datasets.
00:02:04.400 | What we can see is the n-gram similarity.
00:02:08.100 | So this is how many n-grams overlap between the two datasets as a whole.
00:02:15.840 | Now this is not a perfect measure of similarity, of course,
00:02:20.960 | because it's only capturing the
00:02:22.960 | syntactic overlap, and it's not really capturing whether the
00:02:29.360 | two datasets
00:02:31.280 | have the same language structure or are talking about the same topics.
00:02:37.520 | Although it does help, and this is just a very simple measure to help us
00:02:42.640 | figure out whether
00:02:45.280 | the two domains that we're trying to transfer across
00:02:50.000 | will be easy to
00:02:52.000 | bridge or not.
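
As a rough illustration of this kind of measure (my own sketch, not necessarily the exact calculation behind the chart in the video):

```python
# Sketch of a word-level n-gram overlap score between two text corpora.
def ngrams(text: str, n: int = 3) -> set:
    """Return the set of word-level n-grams in a string."""
    words = text.lower().split()
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}

def ngram_similarity(corpus_a: list, corpus_b: list, n: int = 3) -> float:
    """Jaccard overlap between the n-gram sets of two corpora."""
    grams_a = set().union(*(ngrams(t, n) for t in corpus_a))
    grams_b = set().union(*(ngrams(t, n) for t in corpus_b))
    return len(grams_a & grams_b) / len(grams_a | grams_b)
```
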
00:02:54.080 | So we would expect from this that the
00:02:56.560 | medical question pairs down here,
00:02:59.600 | because of these lower similarity scores, would probably
00:03:03.860 | not transfer as well to these other domains,
00:03:10.080 | whereas these other domains, STSB, RTE,
00:03:15.100 | MRPC and Quora question pairs, are all reasonably similar.
00:03:19.500 | So we'd expect the transfer, or ease of transfer, between those domains
00:03:25.500 | to be much smoother.
00:03:28.300 | But like I said, this isn't perfect, and in fact one of the
00:03:33.820 | best domain transfer performances that I saw was from the medical question pairs
00:03:41.580 | to the Quora question pairs, which
00:03:44.780 | obviously makes sense
00:03:46.680 | semantically, because question pairs will probably have more of a similar structure
00:03:51.340 | than entailment pairs,
00:03:53.900 | which is what we have up here with RTE, or just semantic textual similarity pairs, as we have up here.
00:04:02.380 | That does make sense logically
00:04:04.380 | But it's not really reflected in this chart here
00:04:07.820 | So this is the
00:04:10.460 | overall
00:04:12.120 | strategy that we're going to follow with domain transfer AugSBERT.
00:04:16.700 | So the first step is to
00:04:21.960 | get our data. So we're going to have a labeled source dataset, which in
00:04:25.640 | the earlier example would have been the Quora question pairs, and we're also going to have the unlabeled
00:04:31.020 | target dataset, which would be our Stack Overflow question pairs. We're not actually going to use those two examples.
00:04:37.800 | From here, we're going to go with MRPC
00:04:42.740 | as our source dataset and STSB as our target dataset. So we'll go ahead and we'll get our
00:04:49.780 | labeled source data.
00:04:52.420 | So that's going to be MRPC,
00:04:54.420 | and we'll train a cross-encoder model with it. Now, I'll just
00:04:59.780 | quickly mention the reason we're training a cross-encoder model rather than a bi-encoder model,
00:05:05.620 | which is a sentence transformer: it's because we need less data
00:05:09.300 | to train a good cross-encoder,
00:05:11.860 | and the cross-encoder here
00:05:14.340 | is only used to label the target data.
00:05:17.940 | So we need less data to train a good cross-encoder, but cross-encoders are also slow,
00:05:24.020 | which is why we don't use a cross-encoder over here as our final model.
00:05:27.860 | So we're
00:05:29.780 | basically using that greater performance with less data to
00:05:35.140 | build a dataset
00:05:36.820 | over here
00:05:38.100 | that can train a fast sentence transformer
00:05:41.080 | over here to a similar level, or similar performance, to that cross-encoder.
00:05:46.500 | That's basically what we're doing. It's knowledge distillation from our cross-encoder
00:05:52.660 | all the way over to our sentence transformer. That's all we're doing.
00:05:57.460 | So let's jump over to the code and we'll have a look at how we perform those first two steps.
00:06:04.900 | Okay, so we're starting with the source domain MRPC, as I said, and the target domain is STSB.
00:06:12.820 | First thing we do is get that source data.
00:06:15.060 | So we use Hugging Face Datasets for this, so you may need to pip install that: just pip install
00:06:23.060 | datasets.
00:06:25.860 | And from datasets we import load_dataset.
00:06:28.580 | So this is just going to load the MRPC dataset from Hugging Face,
00:06:33.940 | which is stored in the
00:06:35.780 | larger GLUE dataset, and MRPC is like a subset of that.
00:06:40.740 | And then we're taking the train split from there, and we can see the dataset features here.
00:06:46.500 | So we have sentence one and two, which are sentence pairs that we're going to be
00:06:50.180 | learning from in order to label the other,
00:06:54.660 | STSB, sentence pairs.
00:06:57.140 | And we also have the label as well.
00:06:59.140 | And we only have 3.6 thousand
00:07:03.300 | samples here, which is quite a small number,
00:07:06.340 | particularly for a sentence transformer. So
00:07:09.220 | that's why we're going with the cross-encoder here, so that we can build something that has good performance,
00:07:14.980 | whereas a sentence transformer might need a little more data,
00:07:19.860 | although not necessarily.
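
A sketch of that data-loading step (the snippets through this walkthrough build on each other, and the exact notebook code may differ slightly):

```python
# pip install datasets
from datasets import load_dataset

# MRPC is stored as a subset of the larger GLUE dataset
source = load_dataset('glue', 'mrpc', split='train')
print(source)  # features: sentence1, sentence2, label (plus an idx column)
```
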
00:07:22.980 | Um, so we need to first format that data for training, and
00:07:29.380 | to do that we just go through each row in the dataset that we just downloaded.
00:07:35.300 | Sentence transformers always uses this InputExample object.
00:07:40.500 | So all we're doing here is creating an empty list, and then
00:07:45.620 | we're populating it with these InputExamples, which contain
00:07:49.700 | the sentence pairs here and also a label.
00:07:52.980 | And then behind the scenes we're using PyTorch, so
00:07:58.340 | we also need to initialize a PyTorch data loader with those InputExamples,
00:08:04.260 | and we're using batch size 16 and also shuffling that data as well.
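
Roughly, that formatting step looks like this, assuming the `source` dataset from above:

```python
from sentence_transformers import InputExample
from torch.utils.data import DataLoader

train_data = []
for row in source:
    # each InputExample holds one sentence pair and its label
    train_data.append(InputExample(
        texts=[row['sentence1'], row['sentence2']],
        label=float(row['label'])
    ))

# batch size 16, shuffled, as described
source_loader = DataLoader(train_data, batch_size=16, shuffle=True)
```
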
00:08:10.340 | With that, our data is ready for training the cross-encoder, and then we move on.
00:08:16.660 | If we'd like to, we can
00:08:19.460 | move on to adding this
00:08:21.120 | validation set as well, which is useful because then we get a performance measure
00:08:26.660 | from our source validation set. If possible, you really probably want to
00:08:32.260 | switch this out for the target dataset's validation set,
00:08:36.660 | so in our case that would be STSB. I'm not going to do that here,
00:08:46.180 | to emulate
00:08:48.180 | an actual project, where you may not have this validation data
00:08:52.900 | from your target domain. But I would definitely recommend, if you do have the time, to just manually label something so that you can
00:08:59.540 | check performance. You will want to do that later anyway.
00:09:03.540 | But for now, we'll just use the source dataset.
00:09:06.660 | We're using this evaluator because we're training a cross-encoder: it's coming from the cross-encoder
00:09:13.240 | section of the sentence-transformers library, and it's a correlation evaluator.
00:09:17.640 | Okay, so that's going to
00:09:20.740 | check the correlation between the
00:09:22.740 | labels that we have here, so the true labels in our validation set, and the predicted labels from the cross-encoder.
00:09:30.580 | So a cross-encoder does not output
00:09:32.660 | sentence vectors like a sentence transformer; it outputs a similarity score.
00:09:37.700 | And then we just initialize
00:09:41.720 | that correlation evaluator. We're using from_input_examples, since we have InputExamples up here, so we use that there.
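
A sketch of that evaluator setup, using the MRPC validation split as described:

```python
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

# source-domain validation set; ideally you'd swap in labeled target-domain data
val = load_dataset('glue', 'mrpc', split='validation')
val_data = [
    InputExample(texts=[row['sentence1'], row['sentence2']], label=float(row['label']))
    for row in val
]

evaluator = CECorrelationEvaluator.from_input_examples(val_data)
```
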
00:09:51.380 | Yeah, that's our
00:09:53.220 | training and validation data ready to go, so we can go ahead and start training the cross-encoder.
00:09:59.620 | So for training the cross-encoder, we
00:10:03.380 | load this CrossEncoder object from sentence-transformers, again just using a bert-base model. Okay, so that's the
00:10:09.620 | initial model that we're using.
00:10:14.180 | We initialize our cross-encoder from bert-base-uncased, coming from the Hugging Face models,
00:10:20.740 | um, and set the number of labels.
00:10:23.060 | So you can change this if you are training a cross-encoder that you'd like to predict NLI
00:10:29.780 | labels, for example. So you'd have zero, one, two, which would be your contradiction,
00:10:34.840 | neutral and entailment classes.
00:10:38.020 | You would put something like three, because you have three labels and you want to output three labels.
00:10:43.860 | But for us, we want to output a similarity score, so we just have one label.
00:10:48.420 | Then here we have the number of epochs. Now, with anything more than one you're probably going to overfit,
00:10:53.460 | so I would,
00:10:55.140 | I think, always stick with one,
00:10:57.140 | um, unless you really have reason to do otherwise.
00:11:02.020 | And here we have the percentage of warm-up steps that we're going to use. So this is the
00:11:08.980 | number of steps where we're not going to be training at the full learning rate, but we're going to be gradually increasing up to that
00:11:17.140 | full learning rate.
00:11:19.060 | So it's quite high, 35%, but it's what I find tends to work best with cross-encoders:
00:11:25.380 | a higher
00:11:27.380 | percentage of warm-up steps tends to work
00:11:29.620 | reasonably well.
00:11:32.420 | And then over here we have our
00:11:35.620 | optimizer parameters, which is another
00:11:38.580 | training parameter that I modified a little bit. So here we're using slightly higher than the default of 2e-5:
00:11:45.540 | we're using 5e-5.
00:11:47.540 | Now everything else is as you usually would with the cross-encoder. So we're using the fit method;
00:11:53.140 | we will do that with a sentence transformer as well.
00:11:56.340 | The train_dataloader argument, this is slightly different to the
00:12:00.020 | sentence transformer
00:12:02.580 | parameter, but otherwise everything is exactly the same.
00:12:06.260 | So we're passing in our source data, which is the training data loader that we created earlier.
00:12:11.860 | We have our evaluator, number of epochs, warm-up steps,
00:12:14.900 | learning rate, and where we are going to save the model. So that is going to be mrpc-
00:12:21.160 | cross-encoder.
00:12:23.780 | Um, so
00:12:25.220 | that trains our model. It doesn't take long. You can see here, it took, on this computer,
00:12:30.420 | 14 seconds.
00:12:33.140 | Very quick.
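
Putting that together, the cross-encoder training looks roughly like this (the output path is illustrative):

```python
from sentence_transformers import CrossEncoder

# num_labels=1 because we want a single similarity score, not class logits
cross_encoder = CrossEncoder('bert-base-uncased', num_labels=1)

num_epochs = 1
# 35% of training steps used for learning-rate warm-up
warmup = int(len(source_loader) * num_epochs * 0.35)

cross_encoder.fit(
    train_dataloader=source_loader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup,
    optimizer_params={'lr': 5e-5},  # slightly higher than the 2e-5 default
    output_path='mrpc-cross-encoder'
)
```
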
00:12:34.660 | And then we can go ahead and label the target data
00:12:40.100 | That's on to the next step. So let me actually go back to here; we're going to kind of cross these off.
00:12:47.060 | So we've
00:12:48.180 | got our labeled source data and we've trained our cross-encoder. Okay, so this is the performance
00:12:54.200 | of the
00:12:56.180 | cross-encoder models. So that is the source models over here, on different target datasets over here.
00:13:06.180 | The source model for MRPC
00:13:08.680 | on the
00:13:11.460 | STSB target data is over here. So the validation set of the STSB target data is what we are aiming for. So this is
00:13:18.420 | 0.63. Now, we should be able to get a better performance than this, I think.
00:13:22.900 | But we'll have a look, and we can also see,
00:13:25.940 | as I said, that STSB is quite easy
00:13:28.820 | as a dataset to get good results on,
00:13:33.220 | and we can see that here because all of these
00:13:35.940 | do achieve better performance. But what I really want to focus on here is the
00:13:42.660 | bert-base-uncased performance as a benchmark,
00:13:45.800 | because this is what you would get if you didn't have
00:13:49.380 | any training data, which
00:13:52.980 | we don't; we don't have any target domain training data.
00:13:57.140 | Instead, we're kind of using this
00:14:00.660 | AugSBERT training strategy
00:14:02.820 | to train from the source
00:14:05.540 | dataset.
00:14:07.540 | So these are our
00:14:10.020 | benchmarks.
00:14:12.500 | From here, we have these red circles, which represent that the
00:14:18.980 | model performance has decreased from that benchmark, and then the blue ones over here,
00:14:26.740 | they represent that the model performance has increased from the benchmark. And then we also have gray, which is
00:14:32.740 | to indicate that it was the same value, or maybe a couple of points higher, but not that much higher.
00:14:40.660 | So we can see
00:14:44.500 | straight away where it performs worse. So
00:14:50.660 | all of these
00:14:52.900 | GLUE datasets don't really correlate so much with the medical question pairs, and we can see a lot more red in this column.
00:15:01.700 | That is
00:15:02.820 | to be expected. In the other GLUE datasets we see a lot more gray and blue, so the results are either similar or better.
00:15:11.540 | So obviously the next bit is this target domain part.
00:15:16.180 | We're not going to be doing this optional augment data step, because it doesn't
00:15:22.660 | tend to help;
00:15:24.100 | at least for these datasets, I didn't find this to help much.
00:15:28.500 | But if you do want to see how to do that, the
00:15:34.340 | last video that we did on
00:15:36.340 | augmented SBERT, the last article,
00:15:38.880 | does cover that.
00:15:41.460 | It's augmentation of data using something called random sampling. So we're just taking random samples
00:15:47.380 | from our dataset
00:15:50.260 | and creating new pairs with that, and then labeling them with our cross-encoder. That's
00:15:54.900 | all there is to it. But for this I didn't find any
00:15:59.620 | significant performance increase. It did increase the performance for a couple of the models a little bit, but not a huge amount.
00:16:06.260 | So I think, for the time it takes, it's probably not worth it in some cases.
00:16:12.180 | But if you think it might help, especially if you have a small dataset,
00:16:15.140 | I would check out the other video and article.
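
For reference, the random-sampling idea is roughly this (a sketch only, not used in this walkthrough; the pair count is arbitrary):

```python
import random

# pool all sentences, pair random samples, then label the new pairs
# with the cross-encoder to create extra "silver" training data
sentences = [r['sentence1'] for r in source] + [r['sentence2'] for r in source]
new_pairs = [tuple(random.sample(sentences, 2)) for _ in range(1000)]
silver_scores = cross_encoder.predict(new_pairs)
```
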
00:16:20.900 | So, we need to
00:16:22.820 | get our unlabeled target data, and then we're going to skip that augmentation step, so we'll just cross it out, and
00:16:29.060 | we are going to label it with our cross-encoder. So,
00:16:33.620 | from up here, we're going to take that and label this target data.
00:16:38.420 | So let's
00:16:40.260 | have a look at how we do that. So,
00:16:42.260 | yeah, labeling that target data.
00:16:46.180 | First we need target data. Same again, we're just going to download that from the Hugging Face Datasets library. So
00:16:53.460 | we're coming from the larger GLUE dataset again,
00:16:57.940 | but this time we're using STSB, which is quite an easy
00:17:02.260 | dataset to get good results on, to be honest, so I wouldn't expect this sort of performance for everything,
00:17:08.420 | but this is just the example we're going to use here.
00:17:15.220 | And so we're obviously using the training split as well, and we can see here,
00:17:19.620 | okay, in that dataset we have the features sentence one, sentence two, and we also have the label.
00:17:24.180 | So obviously we don't actually need to use a cross-encoder to create our labels, but we are going to, because
00:17:31.220 | that's the whole point of this training strategy,
00:17:35.140 | and I just couldn't find a reasonably good dataset
00:17:40.580 | that didn't have the labels already. So
00:17:44.340 | we're just pretending they don't exist. So this label here, it's not there.
00:17:49.300 | Um, if you're struggling to train a cross-encoder, you can just download it from here, by the way,
00:17:57.780 | or if it's taking a long time. I mean, it shouldn't take that long, but I know
00:18:01.700 | on some machines it can take a while. So
00:18:05.300 | if you're following along with this in real time, just do that if you're
00:18:11.300 | having any issues.
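
Getting that target data looks roughly like:

```python
# STSB is also stored as a subset of GLUE
target = load_dataset('glue', 'stsb', split='train')
print(target)  # features: sentence1, sentence2, label (which we'll ignore)
```
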
00:18:14.340 | So we have our
00:18:16.020 | target data, and what we do is we zip those
00:18:19.940 | sentence ones and sentence twos together.
00:18:25.140 | Let's say we have sentence one, and we have the first one.
00:18:30.660 | I don't know what it could be; I think it's something like a plane:
00:18:35.220 | "A plane is taking off."
00:18:38.020 | And we would have basically a tuple
00:18:44.440 | where we have the first sentence, and we also have the second related sentence, which is something like
00:18:50.600 | "An airplane
00:18:54.200 | lands." I don't think it's "lands"; it's something about taking off as well, but I don't remember exactly what it is.
00:19:02.520 | So you would get something like this, and then the next one would be something else, like, um, "the dog
00:19:12.200 | runs" and,
00:19:14.200 | I don't know,
00:19:17.160 | "the dog walks." Okay, so something like that. So we're getting these
00:19:21.880 | tuples of sentence pairs, and we're feeding them into
00:19:25.800 | the predict method of our cross-encoder, and that will return a set of scores. Obviously these here would be
00:19:32.200 | reasonably high similarity, because they're talking about similar topics. Not necessarily the same; we've got walks and runs, and lands and taking off,
00:19:41.160 | but they would be
00:19:43.160 | reasonably high.
00:19:45.160 | And then what I'm going to do, and this just makes things easier in terms of
00:19:50.120 | seeing what we're actually working with, what we're looking at, is I'm just going to
00:19:55.480 | pass those sentence ones
00:19:58.040 | and sentence twos, and then those new scores, into a pandas DataFrame,
00:20:02.600 | and then we can see everything here. So we can see
00:20:05.800 | "A plane is taking off" and "An airplane is taking off."
00:20:09.400 | So that was the first pair I was talking about.
00:20:12.040 | And then we have all the labels that our cross-encoder is predicting. Now, all of these are pretty high,
00:20:16.920 | because these are all
00:20:20.120 | semantically similar pairs,
00:20:22.280 | so that's what we would expect. But later on there are other pairs which are less similar.
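
A sketch of that labeling step, assuming the `target` dataset and `cross_encoder` from earlier:

```python
import pandas as pd

# zip the sentence pairs together and score them with the cross-encoder,
# ignoring the existing labels (we're pretending they don't exist)
pairs = list(zip(target['sentence1'], target['sentence2']))
scores = cross_encoder.predict(pairs)

# put sentences and predicted scores in a DataFrame so we can inspect them
target_df = pd.DataFrame({
    'sentence1': target['sentence1'],
    'sentence2': target['sentence2'],
    'label': scores
})
target_df.head()
```
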
00:20:29.160 | And then we move on to training the bi-encoder.
00:20:36.060 | So let's quickly switch back to our visual again.
00:20:39.260 | Okay, so now we have downloaded our unlabeled target data, cross that off, and we've also labeled it using our
00:20:48.700 | cross-encoder model. So the final step, as we saw, is training our bi-encoder with that labeled target data.
00:20:56.780 | So let's have a look at
00:20:59.100 | how we would do that. So we have our
00:21:03.500 | labeled target data
00:21:05.500 | in this pandas DataFrame here. So we're just going to iterate through the rows in our DataFrame,
00:21:11.100 | and I'm going to append all those InputExamples, as we did before, to
00:21:17.180 | this data list here.
00:21:19.180 | Okay, and we have our sentence pairs and we have the predicted labels in there.
00:21:26.380 | So once we've created that list of InputExamples, again we're just pushing all of that to a PyTorch data loader.
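
Roughly, assuming the `target_df` DataFrame from the labeling step:

```python
target_data = []
for _, row in target_df.iterrows():
    target_data.append(InputExample(
        texts=[row['sentence1'], row['sentence2']],
        label=float(row['label'])  # the cross-encoder's predicted score
    ))

target_loader = DataLoader(target_data, batch_size=16, shuffle=True)
```
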
00:21:33.820 | And from there we can go on and initialize
00:21:38.380 | our sentence transformer.
00:21:41.100 | To do that, we are using models and SentenceTransformer from the sentence-transformers library.
00:21:48.380 | So the sentence transformer takes the typical
00:21:53.020 | transformer, like BERT, and
00:21:57.640 | BERT outputs
00:21:59.640 | 512 word vectors (one per token), and
00:22:02.340 | what we want is one single sentence vector from our sentence transformer.
00:22:09.020 | So to do that,
00:22:11.480 | we need a way of
00:22:13.640 | translating those 512 word vectors into a single vector,
00:22:17.720 | and we do that by pooling
00:22:21.000 | all of those 512 token vectors
00:22:24.760 | into one vector by taking the average value across each dimension.
00:22:29.260 | And that's all we do there. So that's why we have this BERT layer
00:22:35.320 | followed by a pooling layer, which is just taking the mean pooling.
00:22:39.880 | And then we combine both of those into a single SentenceTransformer object,
00:22:45.160 | and that's our sentence transformer initialized.
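
A sketch of that initialization:

```python
from sentence_transformers import models, SentenceTransformer

# BERT produces one vector per token; mean pooling averages them
# into a single fixed-size sentence vector
bert = models.Transformer('bert-base-uncased')
pooling = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)
model = SentenceTransformer(modules=[bert, pooling])
```
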
00:22:54.120 | But obviously we need to train it, so we come down here to train it on the data we have. At the moment we have
00:22:56.760 | continuous values in our labels, from zero to one,
00:23:00.920 | so what we can use is the cosine similarity loss.
00:23:04.440 | So we initialize that, and then again, we don't want to
00:23:09.640 | overfit, so epochs is set to one.
00:23:13.000 | And this time we're using a lower number of warm-up steps, which is 10% this time.
00:23:20.200 | We train with that.
00:23:24.760 | Once that is done,
00:23:26.840 | we are completely ready to train, and we just call model.fit.
00:23:31.080 | So train_objectives is slightly different to what we saw with the cross-encoder's
00:23:35.720 | fit method, but everything else is the same. We're just using the default learning rate here,
00:23:41.000 | so 2e-5. You don't really need to include that,
00:23:44.520 | but if you do want to change it, I put that in so you can see where you would change it.
00:23:51.640 | And yeah, we're ready to train, so we go ahead and train and see how that performs.
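
A sketch of that training setup (the 10% warm-up fraction and output path are my reading of the description above):

```python
from sentence_transformers import losses

# continuous 0-1 labels, so cosine similarity loss fits
loss = losses.CosineSimilarityLoss(model)

num_epochs = 1
warmup = int(len(target_loader) * num_epochs * 0.1)  # ~10% warm-up this time

model.fit(
    train_objectives=[(target_loader, loss)],
    epochs=num_epochs,
    warmup_steps=warmup,
    optimizer_params={'lr': 2e-5},  # the default, shown so you can see where to change it
    output_path='stsb-bi-encoder'
)
```
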
00:23:57.560 | So we evaluate it. This time we're not using a cross-encoder, so the
00:24:02.920 | evaluator is slightly different; this time we're using an embedding similarity evaluator, which is going to
00:24:09.480 | take two sentence vectors and calculate the similarity between them,
00:24:15.560 | and then it's going to compare
00:24:19.160 | that predicted similarity to the true similarity, as per our validation set here.
00:24:27.480 | In that validation set we do have these labels,
00:24:31.400 | but in STSB those labels are in the range zero to five,
00:24:36.120 | so we use this lambda function to
00:24:39.640 | divide everything by five, which brings us into a range of zero to one,
00:24:44.040 | which is what we need for this embedding similarity evaluator.
00:24:48.840 | And then, as we did before, we're creating that list of InputExamples.
00:24:55.400 | Then we can initialize the evaluator.
00:24:57.880 | So we're using the EmbeddingSimilarityEvaluator's from_input_examples method, again because we have InputExamples,
00:25:03.660 | and we're passing in our InputExamples data, and
00:25:08.040 | write_csv is False. That just means it will print the
00:25:12.200 | score to
00:25:16.280 | our notebook.
00:25:17.960 | And then to actually evaluate, all we do is pass the model to our evaluator and it will do everything, of course, and we get a
00:25:25.880 | similarity here, or a correlation score, of
00:25:31.720 | 0.76, which is pretty good. Um, you can think of something like
00:25:42.600 | 0.5 as kind of like your moderate correlation,
00:25:46.620 | and 0.8 as like high correlation. So that's
00:25:54.760 | pretty good.
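
A sketch of that evaluation, with the STSB validation labels scaled down to the zero-to-one range:

```python
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# STSB validation labels are 0-5; scale them to 0-1
val = load_dataset('glue', 'stsb', split='validation')
val = val.map(lambda x: {'label': x['label'] / 5.0})

val_data = [
    InputExample(texts=[row['sentence1'], row['sentence2']], label=float(row['label']))
    for row in val
]

evaluator = EmbeddingSimilarityEvaluator.from_input_examples(val_data, write_csv=False)
print(evaluator(model))  # correlation score; ~0.76 in this walkthrough
```
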
00:25:58.440 | And for the final quick part, I just want to have a look at the other performances that I found,
00:26:05.000 | because not everything is going to be as good as what we just got there.
00:26:10.120 | So we can see, this is what we just did
00:26:12.760 | over here. We got 0.76 just now, so we got slightly higher than what I got before.
00:26:19.560 | So we got 0.76 here,
00:26:22.680 | but obviously,
00:26:24.920 | like I said, STSB is
00:26:27.080 | an easy
00:26:29.480 | dataset to get good results on,
00:26:31.480 | and the others are more difficult. So with the rest of them you're more in this moderate,
00:26:38.660 | um, similarity range or correlation range. So
00:26:41.780 | the medical question pairs dataset actually performed better with the bi-encoder training than the cross-encoder training.
00:26:51.220 | Nonetheless, they're still within the same sort of range
00:26:53.620 | that I would expect. It's not a massive improvement; this is the benchmark down here, and they are a few
00:27:01.300 | percentage points better, which
00:27:04.500 | is probably reasonable, particularly because
00:27:08.340 | these GLUE datasets were not that similar to our medical question pairs dataset.
00:27:14.660 | So that makes sense. And then the other ones, yeah, they're all sort of within that moderate range.
00:27:22.020 | Um, the one which did surprise me,
00:27:25.540 | although it does make sense, I mean the question pairs,
00:27:30.980 | is this one here. So
00:27:33.380 | the transfer from medical question pairs to Quora question pairs was pretty good.
00:27:41.060 | I suppose,
00:27:42.340 | if you have a look here, from Quora question pairs to medical question pairs it's not as good,
00:27:47.380 | maybe because the language in Quora question pairs is simpler than in the medical question pairs. I'm not sure.
00:27:53.700 | But from medical question pairs to Quora question pairs, it worked quite well.
00:27:59.300 | So that, I suppose, points out where that n-gram similarity doesn't always
00:28:04.660 | correlate exactly to what you would expect.
00:28:07.860 | But anyway, I think the results from this are probably pretty typical of what you can expect.
00:28:16.900 | I think this can be really useful if you
00:28:19.780 | really don't have any labeled data within your target domain, to at least
00:28:28.180 | squeeze out a few percentage points of performance more than you would be able to without
00:28:33.700 | this training strategy.
00:28:36.500 | So for that reason, I think this can be quite useful.
00:28:40.420 | Whether or not it is the best approach to take will depend on your data.
00:28:47.540 | So I think it's useful,
00:28:50.180 | not always the best option,
00:28:52.500 | but definitely something useful to know about and be able to apply if you need it.
00:28:58.420 | So yeah, that's it. Thank you very much for watching. I hope it's been useful,
00:29:05.380 | and I will see you again in the next one.