Train Sentence Transformers by Generating Queries (GenQ)
Chapters
0:00 Intro
0:32 Why GenQ?
2:23 GenQ Overview
4:28 Training Data
6:48 Asymmetric Semantic Search
7:54 T5 Query Generation
13:52 Finetuning Bi-encoders
16:02 GenQ Code Walkthrough
21:40 Finetuning Bi-encoder Walkthrough
26:48 Final Points
00:00:00.000 |
Today we're going to have a look at how to fine-tune or build a sentence 00:00:06.640 |
transformer or bi-encoder model using a very limited data set. So when I say 00:00:14.040 |
limited data set all we actually need to fine-tune this bi-encoder model is 00:00:21.120 |
unstructured text data from our target domain. Okay so the domain or the area 00:00:28.280 |
that we want to apply the model to. So the reason that we might want to do this 00:00:34.720 |
is I think quite clear for anyone who has trained sentence transformers, 00:00:42.200 |
bi-encoders before. We in most cases will find that we don't have that much data 00:00:48.040 |
and the data that we do have is usually not in the correct format for actually 00:00:53.800 |
training a model. So we're going to look at one of the most recent methods for 00:00:59.920 |
actually doing this, and that is called GenQ. At a very high level, the main 00:01:08.440 |
benefits of GenQ, or the reasons for using it, are these three points here. 00:01:14.400 |
Our training data is simply unstructured text. So when I say 00:01:20.560 |
unstructured text it's if you have for example PDF documents or web pages that 00:01:26.080 |
you're scraping data from you can use that as your training data. It's 00:01:32.160 |
specifically for asymmetric semantic search, i.e. where the queries that you're 00:01:37.920 |
searching with are generally shorter than the passages or contexts that you're 00:01:44.320 |
searching for. We'll explain that a little bit more pretty soon. One of the 00:01:51.640 |
very impressive things with GenQ is that the performance can approach the 00:01:57.680 |
performance of models trained using supervised methods. Now that's not going 00:02:02.680 |
to be the case for every single use case and I'd say you're probably going to be 00:02:07.020 |
relatively lucky if you start getting towards supervised performance but it 00:02:12.720 |
does happen so that's I think a pretty good indication that this is a good 00:02:19.120 |
technique to use, or at least try. So at a high level let's go through what GenQ 00:02:27.220 |
actually does. So we start with our unstructured text over here, so our 00:02:32.360 |
passages of text. We take a T5 model, so T5 is a model from Google, and the 00:02:42.840 |
general philosophy of T5 is that every problem in NLP is a text-to-text problem. 00:02:49.800 |
We'll explain that a little more in a moment but what we're doing with 00:02:56.240 |
this T5 model is we take our long passages, unstructured text, and use T5 to 00:03:03.440 |
generate queries which are these short chunks of text over here. Now these 00:03:10.560 |
queries are basically questions that the passage should answer. So what we 00:03:20.720 |
end up getting are these query passage pairs. Now if you have trained a lot of 00:03:27.200 |
bi-encoder models, you probably recognize that a query-passage pair like 00:03:31.040 |
this is basically the format we need to actually train a model using multiple 00:03:39.920 |
negatives ranking loss, which is actually what we do here. So we're 00:03:43.760 |
synthetically generating these queries and then we're actually using a 00:03:48.480 |
supervised training method to train on those queries. So it's almost unsupervised 00:03:55.360 |
because we're not needing to label data or anything here but in 00:04:03.000 |
reality we are actually labeling data, it's just automated. So it's a bit of an 00:04:09.640 |
unsupervised/supervised learning technique, and from that we get our 00:04:17.440 |
bi-encoder or sentence transformer model at the bottom here. Now don't worry if 00:04:23.240 |
none of that makes sense yet we are going to go through everything in a lot 00:04:26.640 |
more detail. So let's start at the very start with our unlabeled data. We can 00:04:33.680 |
describe our data as either being in domain or out of domain. So what I mean 00:04:38.640 |
by that is for example we want to train a model that is going to allow us to 00:04:45.480 |
query or search through German financial documents. So in domain for that 00:04:54.080 |
particular use case might be what we have over here. So we have German finance 00:04:59.080 |
news articles probably in domain and then we also have German finance 00:05:04.400 |
regulation documents probably in domain. Out of domain but not too far out of the 00:05:12.000 |
domain is English financial documents over here. So in some cases it can be 00:05:18.800 |
okay to train your model on some things that are slightly out of domain like in 00:05:24.140 |
this sort of area here if that's all you have but ideally you want things to be 00:05:29.200 |
very specifically in your domain. And then other things like PyTorch 00:05:34.520 |
documentation, whether it's in German or English, it's out of domain 00:05:39.120 |
because it's not finance documents. Doombugs are definitely getting out of 00:05:45.120 |
domain here and even if we had for example something like German pirate 00:05:49.680 |
metal even though it's in German it's still out of domain because there's 00:05:55.080 |
probably very little crossover between that and financial documents. So ideally 00:06:04.160 |
the data that we need for training this can be unstructured, it can just 00:06:10.760 |
be documents or web pages, but it should really be within our topic: if we're 00:06:16.960 |
building a German financial model we would ideally want 00:06:23.400 |
something like the finance news articles or the finance regulations in German. 00:06:28.280 |
Okay so that's in domain and out of domain and GenQ is specifically built 00:06:36.800 |
for cases where we have a lot of in-domain data, but it's unstructured, like 00:06:42.720 |
documents or web pages, as I said before. So another thing that GenQ is very 00:06:50.720 |
specifically built for is what I call asymmetric semantic search. Now 00:06:57.600 |
asymmetric semantic search is basically where you have asymmetry between the 00:07:04.080 |
sizes or the length of the text between what you are querying so your search 00:07:08.920 |
query and what you're searching for, e.g. what you're trying to pull from your vector database. 00:07:14.600 |
Okay so your queries might be very small let's say on average they are 32 words 00:07:21.320 |
long that's probably actually very long your queries are probably even shorter. 00:07:24.840 |
Imagine when you type into Google so maybe it's more like five to ten words 00:07:28.520 |
at most, and then what you get back from Google are usually little chunks of 00:07:34.040 |
text and web pages and so on so your passages are kind of like those chunks 00:07:39.320 |
of text and they're bigger. Right so there's asymmetry in the sizes between 00:07:43.960 |
what you are searching with and searching for. That's where you get the 00:07:49.720 |
asymmetric in asymmetric semantic search. Now the next point is that 00:07:57.520 |
we're synthetically generating queries so we have that unstructured text it's 00:08:03.040 |
just passages but to train the model we're actually going to use a supervised 00:08:09.560 |
training method, and to train with a supervised training method we need to 00:08:16.160 |
have queries that match up to those passages. So to generate those queries we 00:08:24.680 |
use a t5 model that has been fine-tuned specifically for query generation. So as 00:08:33.360 |
I mentioned before the philosophy of t5 is every problem is a text-to-text 00:08:38.320 |
problem. So T5 is not just used for query generation; we can use it to say, okay, 00:08:45.040 |
translate English to German: "that's good", and get the German translation for that. You 00:08:50.520 |
can use it for scoring two sentences on how similar they are, and you get something like 00:08:56.280 |
3.8 for these two because they're relatively similar, and the range there is 00:09:00.360 |
from 0 to 5. And then you might get a CoLA sentence, so that is where you are 00:09:08.320 |
asking: is this sentence correct, does it make sense? "The course is jumping 00:09:14.040 |
well" doesn't make sense, so it's saying, okay, this is not acceptable. All 00:09:18.120 |
of these, although, okay, in the semantic similarity example we do 00:09:23.680 |
output 3.8, it's still actually text, so all of this is just text-to-text, and 00:09:29.760 |
that's a core philosophy behind T5: it's just an encoder-decoder model, and 00:09:36.640 |
every problem is handled as text-to-text. The same applies to our 00:09:41.440 |
problem: we have a passage and we want to generate a query for that passage, okay. 00:09:47.000 |
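As a rough illustration of that text-to-text framing (not code from the video), here is how the original t5-base checkpoint handles different tasks purely through text prefixes; the prompts follow the T5 paper's conventions, and the exact outputs may vary:

```python
# Illustrative sketch only: T5 treats every task as text in, text out.
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-base")
t5 = T5ForConditionalGeneration.from_pretrained("t5-base")

prompts = [
    "translate English to German: That is good.",                          # translation
    "stsb sentence1: The cat sat outside. sentence2: A cat was outdoors.",  # similarity score, returned as text
    "cola sentence: The course is jumping well.",                           # grammatical acceptability
]

for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = t5.generate(input_ids, max_length=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```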
It's just text-to-text, and with T5 it might look something like this. So the 00:09:52.840 |
passage up at the top here, we have this really long passage, you know, this is from Wikipedia 00:09:58.360 |
and the SBERT example page on GenQ, just a paragraph saying Python is an 00:10:05.360 |
interpreted, high-level, etc. etc. programming language, so it's just a 00:10:10.040 |
little intro to Python. T5 takes this and it will say okay let's generate 00:10:19.200 |
three queries okay and they're randomized so it can take this one 00:10:23.100 |
passage and apply a little bit of randomness and generate different 00:10:26.200 |
queries, so we get these: 'define Python program', 'what is Python program', and 'what 00:10:33.960 |
is Python useful for'. They're all queries that we would probably search in 00:10:39.160 |
Google, for example, and return a passage that looks like this. Okay, now for this 00:10:46.640 |
we do need a t5 model that has been trained to produce queries and I'll show 00:10:53.760 |
you later that we are going to use a specifically trained model for that and 00:10:58.600 |
it is worth noting that your queries or the queries that are generated can be a 00:11:05.840 |
little bit noisy, and maybe we'll see an example of this later. What 00:11:09.520 |
I mean by noisy is that sometimes they can be nonsensical, or they're just 00:11:15.480 |
weird queries, like not good quality queries, and that is probably the main 00:11:22.240 |
issue with GenQ that you need to be aware of: because we're generating 00:11:28.560 |
queries using a language generation model, and that's pretty much the only 00:11:32.320 |
step in generating the training data, if there's any noise that will be 00:11:40.520 |
obviously translated into the performance of the model that 00:11:44.240 |
we are training, and will obviously make the performance less than it 00:11:51.340 |
could be if the generated queries were all perfect. Right, but that's to be 00:11:56.000 |
expected; nonetheless, in most cases I think it works pretty well. So let's have 00:12:02.200 |
a look at an example of this in Python. So here we have the generation code for 00:12:09.480 |
just one example, so this paragraph here. Okay, so I'm taking this example from the 00:12:14.920 |
sbert.net web page on GenQ. So all we're doing here is we're downloading a model; 00:12:22.880 |
we're using a BeIR model over here that has been trained specifically for this, 00:12:27.360 |
although as far as I know it's been trained for GPL and not GenQ, but 00:12:36.140 |
nonetheless it works for this. We initialize a tokenizer and a model using 00:12:43.720 |
the transformers library here, and we're using T5 for conditional generation, 00:12:49.920 |
basically to generate text. This is our input text, and it's just the same 00:12:57.720 |
paragraph we saw before, and all we're going to do is create our input IDs by 00:13:03.120 |
tokenizing text and then we're going to generate three sequences here and down 00:13:13.600 |
here all I'm doing is decoding the sentences or sequences that we 00:13:19.320 |
generated into text, and I think this is the same example, this is where I 00:13:27.560 |
actually pulled the example that you saw before from. So we generate these queries, 00:13:31.320 |
we get 'find Python program', 'what is Python used for', 'what is Python program'. Okay, 00:13:36.240 |
that's all there is to the query generation step, okay, so not really 00:13:44.640 |
complicated, we just need to do it for a lot of passages. 00:13:49.120 |
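For reference, here is a minimal sketch of that single-passage generation step, along the lines of the sbert.net GenQ example; the BeIR checkpoint name and the generation settings are assumptions rather than the exact notebook code:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "BeIR/query-gen-msmarco-t5-large-v1"  # assumed query-generation checkpoint
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.eval()

passage = (
    "Python is an interpreted, high-level, general-purpose programming language. "
    "Its design philosophy emphasizes code readability ..."
)

# Encode the passage and sample three different queries from it
input_ids = tokenizer.encode(passage, return_tensors="pt")
outputs = model.generate(
    input_ids=input_ids,
    max_length=64,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=3,
)

for output in outputs:
    print(tokenizer.decode(output, skip_special_tokens=True))
```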
so that's the query generation step and what we get after query generation 00:13:54.200 |
obviously is a set of query-passage pairs, and then we can use those to fine-tune 00:14:01.200 |
our model. Now, for fine-tuning our model we use multiple negatives ranking loss. Now 00:14:07.280 |
there's a sort of relatively abstract animation going through what MNR is 00:14:13.360 |
actually doing so MNR at a high level works by placing all those pairs that we 00:14:19.660 |
created, the query-passage pairs, into batches. Okay, and in each one of these 00:14:26.200 |
batches, you imagine your query at position 0 should have the highest 00:14:31.720 |
similarity score with your passage at position 0, because they're the pair, 00:14:36.960 |
right, but then you have a batch, so you're actually going to compare the 00:14:41.040 |
similarity between the query at position 0 and the passage at position 1 and 2 00:14:48.880 |
and 3 and 4 and so on, and for each batch the model weights are going to be 00:14:55.520 |
optimized so that the pair at position 0 has the highest similarity, and, for 00:15:03.920 |
example, query 0 and passage 3 or 4 or 5 and so on have a lower similarity, 00:15:10.840 |
right and we're not going to go super in-depth on MNR here because we've covered it before 00:15:19.000 |
and I'll make sure there's a link to that actual article and video in the 00:15:23.400 |
description, but at a high level it's basically just a ranking optimization: we 00:15:31.760 |
are optimizing so that the pairs that should be together, those blocks that you 00:15:37.760 |
see more obviously along the diagonal line in that batch, 00:15:42.640 |
have the highest similarity scores. 00:15:53.280 |
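If it helps to see it concretely, here is a rough conceptual sketch of that ranking objective, not the exact sentence-transformers implementation: every query in the batch is scored against every passage, and cross-entropy pushes the diagonal (true pair) scores to the top.

```python
import torch
import torch.nn.functional as F

def mnr_loss(query_embeds: torch.Tensor, passage_embeds: torch.Tensor, scale: float = 20.0):
    # Cosine similarity between every query and every passage in the batch
    q = F.normalize(query_embeds, dim=-1)
    p = F.normalize(passage_embeds, dim=-1)
    scores = q @ p.T * scale                  # (batch_size, batch_size) similarity matrix
    # The correct passage for query i sits at column i (the diagonal)
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```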
Okay, so with that we will have a bi-encoder that has been fine-tuned to our specific use case. Now let's have a 00:15:58.440 |
look at the actual code that we can use to do that. So the very first step that 00:16:04.960 |
we'll need to go through is actually getting our data so to get our data 00:16:10.480 |
we're just going to use the SQuAD dataset. The reason for using this is because the 00:16:15.360 |
test set of this has both the question and context pairs, so we can assess 00:16:22.960 |
the quality of our model at the end of this. I'm just checking here, 00:16:30.880 |
okay, a lot of these contexts are duplicated, because for a sample of 00:16:36.760 |
one context we might have 30 questions, so all we're doing here is removing 00:16:41.480 |
those duplicates so that we just have the passages. So we can see that the 00:16:47.480 |
final number of passages that we have is just under 19,000 00:16:53.160 |
here, which is not that many; I think for a lot of use cases, if you're 00:16:57.760 |
using unstructured data, you can probably get more, but in this case it does work 00:17:03.040 |
very well with just a small amount. But it's a very generic dataset, so for more 00:17:11.560 |
specific use cases maybe you might want more data, maybe not; you'll 00:17:17.280 |
just have to sort of play around with it and see what works. 00:17:26.080 |
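A minimal sketch of that data prep step might look like this, assuming the Hugging Face datasets library; the split name is an assumption:

```python
from datasets import load_dataset

squad = load_dataset("squad", split="train")  # assumed split

# One context can have many questions attached, so keep each unique context once
passages = list(dict.fromkeys(squad["context"]))
print(len(passages))  # just under 19,000 unique passages, as noted above
```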
To generate the queries, you already saw before what we're doing: all we're doing is loading 00:17:32.200 |
this model again, the tokenizer and the model, and I'm setting the model to 00:17:37.080 |
evaluation mode, so that, for example, any normalization layers in there are 00:17:43.560 |
set to evaluation mode, ready for inference. So I'm going to run that, and we 00:17:54.240 |
move on to the inference step where we are going to generate the queries for 00:18:01.480 |
our passages. So here, tqdm, we're just using this as a progress 00:18:08.920 |
bar nothing fancy going on there it's just so we can actually see the progress 00:18:15.400 |
because it can take a little bit of time particularly if you have a very large 00:18:18.920 |
amount of data. We're setting torch no_grad, so we're not calculating the gradients 00:18:25.000 |
of our model because it takes more time we're not updating or optimizing the 00:18:29.640 |
model, we're just using it for inference here. So yeah, and then what we need to do 00:18:36.480 |
is get our passages, so 'for p in tqdm(passages)' is doing that. So if I show 00:18:48.360 |
you here, look at passages[0], just one example, run this, okay, and this is 00:19:07.440 |
just our passage long chunk of text so we're going through those if there's any 00:19:13.640 |
tab characters which I generally don't think there are but maybe 00:19:18.160 |
there's the odd one I'm going to replace those because later we're going to be 00:19:22.400 |
using tab-separated values, a TSV file, for storing our pairs, and obviously 00:19:30.000 |
if we have more than one tab in those pairs separating the query and the 00:19:34.520 |
passage it's going to mess up the formatting of the file so I'm just 00:19:39.760 |
replacing those tabs with spaces, just in case. Then we're creating the input IDs using 00:19:45.920 |
tokenizer encode as we did before then we're generating three queries per 00:19:51.440 |
passage, we decode the queries to human-readable language so we can read them 00:20:00.760 |
and so the next model can actually tokenize that text and understand it; 00:20:07.840 |
otherwise we're using token IDs from another model, which can't be read by our 00:20:13.480 |
next bi-encoder model, the one we're fine-tuning. So we decode them into human language 00:20:21.760 |
and then we are going to add the pairs. So that's our query; just in case any tab 00:20:28.640 |
characters end up in there when we're generating them, we just replace any tab 00:20:33.520 |
character with a space again, and then we use a tab character to separate the query 00:20:37.920 |
from our passage. Okay, and then what I'm doing here is saying every 1024 00:20:47.000 |
pairs we create, I want to save that, I'll write that to file. Okay, and then we 00:20:52.880 |
increase the file count, because every new file is going to be saved as 00:21:00.520 |
pairs, for example the first one will be pairs 0, and then pairs 1 and pairs 2 and so on. 00:21:05.240 |
Okay, and then we just reset or refresh the pairs list. Finally, because 00:21:15.040 |
here we're doing it in batches of 1024, maybe the final batch only has 600 pairs 00:21:21.960 |
in there, so I'm just saying if pairs is not empty, I want you 00:21:27.400 |
to save that final batch. Okay, and with that we've generated our data. 00:21:34.680 |
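Putting that together, a rough sketch of the generation loop might look like the following; it reuses the tokenizer, model, and passages from the earlier snippets, and the batch-of-1024 file naming is an assumption based on the description above:

```python
import torch
from tqdm.auto import tqdm

pairs = []
file_count = 0

with torch.no_grad():                        # inference only, no gradients needed
    for p in tqdm(passages):
        p = p.replace("\t", " ")             # strip tabs so the TSV format isn't broken
        input_ids = tokenizer.encode(p, return_tensors="pt")
        outputs = model.generate(
            input_ids=input_ids, max_length=64,
            do_sample=True, top_p=0.95, num_return_sequences=3,
        )
        for output in outputs:
            query = tokenizer.decode(output, skip_special_tokens=True)
            pairs.append(query.replace("\t", " ") + "\t" + p)

        # every 1024 pairs, write them out to a new numbered TSV file
        if len(pairs) >= 1024:
            with open(f"data/pairs_{file_count}.tsv", "w", encoding="utf-8") as fp:
                fp.write("\n".join(pairs))
            file_count += 1
            pairs = []

# save whatever is left over in the final, smaller batch
if pairs:
    with open(f"data/pairs_{file_count}.tsv", "w", encoding="utf-8") as fp:
        fp.write("\n".join(pairs))
```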
So that's the generation, or query generation, step, and we can then move on to the actual 00:21:42.800 |
fine-tuning of the bi-encoder model. To do that, obviously, we're in a new 00:21:49.520 |
notebook now. So the first thing I want to do is get a list of all the files, the 00:21:55.200 |
pair files from before; so what we're doing here, we saved them into the data directory, 00:22:05.720 |
then because we're going to use sentence transformers to train here we need to 00:22:11.320 |
import the InputExample object, and this is just a specific object type used by 00:22:17.400 |
the sentence transformers library just a standardized format so that we can easily 00:22:23.400 |
read data that we're going to train with. So all we're doing here is looping 00:22:29.200 |
through all of those files we're opening all of them reading all the lines that's 00:22:34.120 |
all the pairs and here if there's no tab character in the line we're just kind of 00:22:42.480 |
avoiding it just in case for whatever reason the tab character isn't in there 00:22:49.160 |
and then here we're splitting getting the query and passage and we're creating 00:22:57.680 |
the input example and we're appending that to pairs so we have this big list 00:23:01.840 |
of input examples of our queries and passages okay and then we set up the 00:23:08.320 |
data loader for multiple negatives ranking loss. We've done this before, but it may be 00:23:14.360 |
new to you, so we're gonna go through it. So I'm using a batch size of 24; with 00:23:19.840 |
MNR, the greater the batch size the better, generally, because you're 00:23:25.280 |
basically making it harder for the model: if you have more pairs in that 00:23:30.240 |
batch, the harder it's going to be to rank that one pair that is the correct 00:23:36.680 |
pair at position number one. If there's only, for example, two in 00:23:41.640 |
your batch, your model has like a 50/50 chance, if it's just guessing, of ranking 00:23:46.820 |
the correct pair at position one; if there's a hundred in your batch it has a 00:23:52.920 |
pretty small chance of actually guessing that correctly okay so that's why we try 00:24:01.080 |
and use a larger batch size but obviously it would depend on your 00:24:05.800 |
available compute that you have and then we are going to use this no duplicate 00:24:12.800 |
data loader from sentence transformers. This is specifically built to avoid 00:24:18.280 |
having duplicate queries or contexts in your 00:24:27.240 |
batch, because obviously if you're trying to rank your 00:24:33.280 |
query and passages correctly and you have duplicate passages in there, your 00:24:38.840 |
model will get confused and it can't actually rank them correctly, so we avoid 00:24:43.520 |
having duplicates in there by using this. 00:24:50.920 |
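A minimal sketch of that loading step, assuming the TSV file layout from the generation notebook above, might be:

```python
from pathlib import Path
from sentence_transformers import InputExample, datasets

pairs = []
for path in sorted(Path("data").glob("pairs_*.tsv")):
    with open(path, encoding="utf-8") as fp:
        for line in fp:
            if "\t" not in line:
                continue                      # skip any malformed lines
            query, passage = line.rstrip("\n").split("\t", maxsplit=1)
            pairs.append(InputExample(texts=[query, passage]))

# NoDuplicatesDataLoader keeps duplicate passages out of the same batch
batch_size = 24
loader = datasets.NoDuplicatesDataLoader(pairs, batch_size=batch_size)
```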
And then we are just going to define the bi-encoder model that we're going to be fine-tuning. So I'm using an MPNet model 00:24:55.320 |
here; MPNet is generally quite good as a sentence transformer bi-encoder model, so 00:25:02.040 |
it's a good one to use. We are using the mean token, or mean pooling, method here, 00:25:10.080 |
so all the word embeddings that the model outputs are being compressed into 00:25:15.320 |
a single sentence embedding by taking the average across all of the token embeddings, 00:25:21.720 |
and I'm putting those both together, the actual transformer model followed by 00:25:26.040 |
the pooling layer, and that creates our bi-encoder model. 00:25:32.280 |
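A minimal sketch of that model definition, with the MPNet checkpoint name and sequence length as assumptions:

```python
from sentence_transformers import SentenceTransformer, models

word_embedding = models.Transformer("microsoft/mpnet-base", max_seq_length=256)
pooling = models.Pooling(
    word_embedding.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,            # average the token embeddings into one vector
)
model = SentenceTransformer(modules=[word_embedding, pooling])
```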
And then, after we initialize our model, we can initialize the MNR loss function. Super easy again 00:25:37.840 |
with sentence transformers: all we do is losses.MultipleNegativesRankingLoss 00:25:42.440 |
and then pass the model in there, and then we train. I put 'single epoch' here; 00:25:48.240 |
you can train for a single epoch, and I think 00:25:52.800 |
that's what they do in the GenQ paper, but I've also seen good results from 00:26:01.120 |
using 3 epochs, which is I think the default in the sentence transformers 00:26:08.800 |
documentation. So you can go single epoch, maybe try it, see if it works, 00:26:14.640 |
otherwise go with three epochs; in this case I use three. Okay, here, ignore that. 00:26:22.280 |
So we pass our loader and loss, like typical sentence transformers training 00:26:29.920 |
here, and then I saved the model as mpnet GenQ squad. Okay, we warm up for 10% of 00:26:37.160 |
the training steps, and then, yep, the model is saved in that directory. 00:26:44.640 |
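The loss setup and training call described here might look roughly like this, reusing the loader and model from above; the epoch count and output path are illustrative:

```python
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

epochs = 3
warmup_steps = int(len(loader) * epochs * 0.1)    # warm up over 10% of training steps

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path="mpnet-genq-squad",
    show_progress_bar=True,
)
```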
And we can just load it as we would any other sentence transformer model. So I think 00:26:49.920 |
that's it for this walkthrough what I'm going to do is in a separate video we'll 00:26:56.840 |
have a look at how we can test this model and evaluate it so that'll be 00:27:03.440 |
interesting I'll release it pretty soon after this so I hope that's been useful 00:27:08.640 |
thank you very much for watching and I will see you again