
Train Sentence Transformers by Generating Queries (GenQ)


Chapters

0:00 Intro
0:32 Why GenQ?
2:23 GenQ Overview
4:28 Training Data
6:48 Asymmetric Semantic Search
7:54 T5 Query Generation
13:52 Finetuning Bi-encoders
16:02 GenQ Code Walkthrough
21:40 Finetuning Bi-encoder Walkthrough
26:48 Final Points

Transcript

Today we're going to look at how to fine-tune, or build, a sentence transformer (bi-encoder) model using a very limited dataset. When I say limited, I mean that all we actually need to fine-tune this bi-encoder is unstructured text data from our target domain, that is, the domain or area we want to apply the model to.

The reason we might want to do this is, I think, quite clear to anyone who has trained sentence transformers or bi-encoders before: in most cases we don't have much data, and the data we do have is usually not in the right format for training a model.

So we're going to look at one of the most recent methods for doing this, called GenQ. At a very high level, the main reasons for using GenQ come down to three points. First, our training data is simply unstructured text.

By unstructured text I mean, for example, PDF documents or web pages that you're scraping: you can use those as your training data. Second, GenQ is specifically for asymmetric semantic search, i.e. where the queries you search with are generally shorter than the passages or contexts you're searching for.

We'll explain that a little more soon. Third, and one of the most impressive things about GenQ, its performance can approach that of models trained with supervised methods. That won't be the case for every use case, and I'd say you'd be relatively lucky to get close to supervised performance, but it does happen, which I think is a good indication that this technique is worth using, or at least trying.

So, at a high level, let's go through what GenQ actually does. We start with our unstructured text, our passages. We then take a T5 model. T5 is a model from Google, and its general philosophy is that every problem in NLP is a text-to-text problem.

We'll explain that more in a moment, but what we do with this T5 model is take our long passages of unstructured text and use T5 to generate queries, which are short chunks of text. These queries are basically questions that the passage should answer.

What we end up with are query-passage pairs. If you have trained many bi-encoder models, you'll probably recognize that a query-passage pair like this is exactly the format we need to train a model with multiple negatives ranking (MNR) loss, which is what we do here.
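To make the ranking objective concrete, here is a minimal pure-Python sketch of MNR loss for a single batch. It assumes cosine similarity multiplied by a scale of 20 (I believe this matches the sentence-transformers default, but treat it as an assumption), and the function names are hypothetical. Each query is scored against every passage in the batch: its own passage is the correct "class" and every other passage acts as an in-batch negative.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnr_loss(query_embs: List[Sequence[float]], passage_embs: List[Sequence[float]],
             scale: float = 20.0) -> float:
    """Mean multiple negatives ranking loss for one batch of (query, passage) pairs.

    Query i is scored against every passage in the batch. Passage i (the diagonal of
    the similarity matrix) is the correct answer; all other passages are negatives.
    """
    total = 0.0
    for i, q in enumerate(query_embs):
        scores = [scale * cosine(q, p) for p in passage_embs]
        # Softmax cross-entropy with target index i, using a stable log-sum-exp.
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[i]
    return total / len(query_embs)


if __name__ == "__main__":
    queries = [[1.0, 0.0], [0.0, 1.0]]
    print(mnr_loss(queries, [[1.0, 0.0], [0.0, 1.0]]))  # pairs aligned: loss near 0
    print(mnr_loss(queries, [[0.0, 1.0], [1.0, 0.0]]))  # pairs swapped: loss near 20
```

When every passage in a batch of n looks identical, the loss is log n, the cost of a blind guess, which is one way to see why more pairs per batch make the ranking task harder.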

So we're synthetically generating these queries and then training on them with a supervised method. It's almost unsupervised, because we don't need to label any data by hand, but in reality we are labeling data; it's just automated. So it's a bit of an unsupervised/supervised hybrid, and from it we get our bi-encoder, or sentence transformer, model at the end.

Don't worry if none of that makes sense yet; we'll go through everything in much more detail. Let's start at the very beginning with our unlabeled data. We can describe our data as being either in domain or out of domain. To explain what I mean, say we want to train a model that lets us search through German financial documents.

For that use case, German finance news articles are probably in domain, and German finance regulation documents are probably in domain too. English financial documents are out of domain, but not too far out.

In some cases it can be okay to train your model on data that is slightly out of domain like this, if that's all you have, but ideally you want data that is very specifically within your domain. Other things, like PyTorch documentation, are out of domain whether they're in German or English, because they're not finance documents.

Doombugs are definitely getting out of domain, and even something like German pirate metal, although it's in German, is still out of domain, because there's probably very little crossover between it and financial documents. So the data we need for training can be unstructured, just documents or web pages, but it should really be on topic: for a German financial model, we would ideally want something like the finance news articles or finance regulations in German.

So that's in domain versus out of domain. GenQ is specifically built for the case where we have a lot of in-domain data, but it's unstructured, like the documents or web pages I mentioned before. The other thing GenQ is specifically built for is what's called asymmetric semantic search.

Asymmetric semantic search is where there is an asymmetry in length between the text you query with (your search query) and the text you're searching for, i.e. what you're trying to retrieve from your vector database. Your queries might be short: say they average 32 words, which is probably actually very long; real queries are probably even shorter.

Think of typing into Google: that's maybe five to ten words at most, and what Google returns are usually small chunks of text from web pages and so on. Your passages are like those chunks of text, and they're longer.

So there's an asymmetry in size between what you search with and what you search for; that's where the "asymmetric" in asymmetric semantic search comes from. The next point is that we synthetically generate queries. We have unstructured text, just passages, but we're going to train the model with a supervised training method, and to do that we need queries that match up with those passages.

To generate those queries we use a T5 model that has been fine-tuned specifically for query generation. As I mentioned, the philosophy of T5 is that every problem is a text-to-text problem, so T5 isn't only used for query generation. For example, we can give it "translate English to German: That is good." and get back the German translation.

You can also use it to score how similar two sentences are: for a pair of relatively similar sentences you might get 3.8, on a range from 0 to 5. Or you might give it a CoLA input, which asks whether a sentence is grammatically acceptable: "The course is jumping well" doesn't make sense, so the output is "not acceptable".
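The text-to-text framing is easy to see if you write the inputs out. This small sketch builds T5-style inputs using the task prefixes and example targets from the overview figure of the T5 paper; the helper name `t5_input` and the task keys are hypothetical, and a particular fine-tuned checkpoint (such as a query-generation model) may expect different input formatting.

```python
# Illustrative T5-style task framing: every task becomes "prefixed input text" -> "output text".
# Prefixes and example targets follow the T5 paper's overview figure.
T5_PREFIXES = {
    "translate_en_de": "translate English to German: {text}",
    "stsb": "stsb sentence1: {sentence1} sentence2: {sentence2}",
    "cola": "cola sentence: {sentence}",
}


def t5_input(task: str, **fields: str) -> str:
    """Build the input text for a task; raises KeyError for an unknown task."""
    return T5_PREFIXES[task].format(**fields)


# (input text, target text) pairs: even the similarity score "3.8" is just text.
EXAMPLES = [
    (t5_input("translate_en_de", text="That is good."), "Das ist gut."),
    (t5_input("stsb", sentence1="The rhino grazed on the grass.",
              sentence2="A rhino is grazing in a field."), "3.8"),
    (t5_input("cola", sentence="The course is jumping well."), "not acceptable"),
]

if __name__ == "__main__":
    for source, target in EXAMPLES:
        print(f"{source!r} -> {target!r}")
```

Query generation fits the same mould: a passage goes in as text and a query comes out as text.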

Even in the semantic similarity example, where the output is 3.8, that output is still text, so all of this is text-to-text. That's the core philosophy behind T5: it's an encoder-decoder model, and every problem is handled as text-to-text. The same applies to our problem: we have a passage and want to generate a query for it, which is just text-to-text. With T5 it might look like this. At the top we have a long passage, taken from Wikipedia via the SBERT example on GenQ: a paragraph saying that Python is an interpreted, high-level, general-purpose programming language, just a short intro to Python.
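The passage-to-query step, plus the deduplication and batched TSV writing from the code walkthrough, can be sketched roughly as follows. All function names are hypothetical; `max_length=64` and `top_p=0.95` are assumed sampling settings; the `pairs_<n>.tsv` file naming is an assumption; and also replacing newlines is an extra precaution not in the walkthrough. As far as I know, `generate` in recent `transformers` versions already runs without gradient tracking, so the explicit `torch.no_grad()` from the walkthrough is left out here.

```python
import os
from typing import Callable, Iterable, List


def unique_passages(contexts: Iterable[str]) -> List[str]:
    """Drop duplicate contexts (SQuAD repeats a context once per question), keeping order."""
    return list(dict.fromkeys(contexts))


def clean(text: str) -> str:
    """Replace tabs (and, as an extra precaution, newlines) so each pair stays on one TSV line."""
    return text.replace("\t", " ").replace("\n", " ")


def generate_queries(passage: str, tokenizer, model, num_queries: int = 3,
                     max_length: int = 64, top_p: float = 0.95) -> List[str]:
    """Sample several queries for one passage with a T5 query-generation model.

    Assumes a Hugging Face tokenizer and T5ForConditionalGeneration model loaded elsewhere;
    sampling (do_sample with top_p) is what makes the queries differ from one another.
    """
    input_ids = tokenizer.encode(passage, return_tensors="pt")
    outputs = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        num_return_sequences=num_queries,
    )
    # Decode token IDs back to plain text so the bi-encoder can tokenize it later.
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in outputs]


def write_pair_files(passages: Iterable[str], query_fn: Callable[[str], List[str]],
                     out_dir: str, pairs_per_file: int = 1024) -> int:
    """Write 'query<TAB>passage' lines, `pairs_per_file` pairs per file; return the file count."""
    os.makedirs(out_dir, exist_ok=True)
    pairs: List[str] = []
    file_count = 0

    def flush() -> None:
        nonlocal pairs, file_count
        path = os.path.join(out_dir, f"pairs_{file_count}.tsv")
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(pairs))
        file_count += 1
        pairs = []

    for passage in passages:
        passage = clean(passage)
        for query in query_fn(passage):
            pairs.append(f"{clean(query)}\t{passage}")
            if len(pairs) == pairs_per_file:
                flush()
    if pairs:  # save the final, partially filled batch
        flush()
    return file_count
```

With the real model you would load `T5Tokenizer` and `T5ForConditionalGeneration` via `from_pretrained(...)` on a query-generation checkpoint (for example `BeIR/query-gen-msmarco-t5-large-v1`, an assumption on my part), call `model.eval()`, and pass `lambda p: generate_queries(p, tokenizer, model)` as `query_fn`. Keeping the generator as a parameter keeps the file-writing logic independent of the model.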

T5 takes this passage and generates, say, three queries. Generation is randomized, so it can take one passage, apply a little randomness, and produce different queries. Here we get "define Python program", "what is Python program", and "what is Python useful for". These are all queries we might type into Google, for example, and expect to get back a passage like this one.

For this we need a T5 model that has been trained to produce queries, and I'll show you later the specifically trained model we use. It's worth noting that the generated queries can be a little noisy. By noisy I mean that sometimes they are nonsensical, weird, or just low quality; we may see an example of this later. That is probably the main issue with GenQ that you need to be aware of. Because we generate queries with a language generation model, and that's essentially the only step in producing the training data, any noise carries straight through to the model we train and makes its performance worse than it would be if the generated queries were all perfect. That's to be expected, though, and in most cases I think it works pretty well.

Let's look at an example of this in Python. Here we have the generation code for just one example, this paragraph, which I took from the sbert.net page on GenQ. All we're doing is downloading a model. We use a BEIR model that has been trained specifically for query generation; as far as I know it was trained for GPL rather than GenQ, but it works for this nonetheless. We initialize a tokenizer and a model using the transformers library, using T5ForConditionalGeneration to generate text. Our input text is the same paragraph we saw before. We create our input IDs by tokenizing the text, generate three sequences, and then decode the generated sequences back into text. This is where I pulled the earlier example from: we get "find Python program", "what is Python used for", and "what is Python program". That's all there is to the query generation step. It's not complicated; we just need to do it for a lot of passages.

What we get after query generation is a set of query-passage pairs, which we use to fine-tune our model with multiple negatives ranking loss. At a high level, MNR works by placing the query-passage pairs we created into batches. Within each batch, the query at position 0 should have a high similarity score with the passage at position 0, because they are a pair. But because we have a batch, we also compare the query at position 0 with the passages at positions 1, 2, 3, 4, and so on. For each batch, the model weights are optimized so that the pair at position 0 has the highest similarity, while query 0 compared with passages 3, 4, 5, and so on has lower similarity. We won't go too deep into MNR here because we've covered it before, and I'll make sure there's a link to that article and video in the description. At a high level it's a ranking optimization: we optimize so that the pairs that belong together, the blocks along the diagonal of the batch's similarity matrix, have the highest similarity scores. With that, we have a bi-encoder that has been fine-tuned to our specific use case.

Now let's look at the code we can use to do that. The first step is getting our data. We use the SQuAD dataset because its test set contains both questions and contexts, so we can assess the quality of our model at the end. Many of these contexts are duplicated, because a single context might have around 30 questions, so we remove the duplicates to leave just the passages. The final number of passages is just under 19,000, which isn't that many. For a lot of use cases with unstructured data you can probably get more, but here it works very well with a small amount. SQuAD is a very generic dataset, though, so for more specific use cases you might want more data, or maybe not; you'll have to experiment and see what works.

To generate the queries we do what you saw before: load the model and tokenizer, and set the model to evaluation mode so that layers such as dropout behave as they should for inference. Then we move on to the inference step, where we generate queries for our passages. We use tqdm as a progress bar, nothing fancy, just so we can see progress, because this can take a while, particularly with a lot of data. We wrap generation in torch.no_grad() so we don't calculate gradients: that takes more time, and we're not updating or optimizing the model, only using it for inference.

Then we loop through our passages with for p in tqdm(passages). If we look at passages[0], it's just one long chunk of text. For each passage we replace any tab characters with spaces. I don't think there are many, but there may be the odd one, and because we store our pairs in a tab-separated values (TSV) file, an extra tab between the query and the passage would break the file's formatting. Next we create the input IDs with tokenizer.encode, as before, and generate three queries per passage. We decode the queries into human-readable text, both so we can read them and so the next model can tokenize them; otherwise we'd be passing token IDs from another model, which the bi-encoder we're fine-tuning can't read. Then we add the pairs: in case any tab characters end up in a generated query, we replace those with spaces too, and then we use a tab to separate the query from the passage.

Every 1,024 pairs, we write the pairs to a file and increment a file counter, so each new file gets its own name: pairs 0, pairs 1, pairs 2, and so on. Then we reset the pairs list. Because we write in batches of 1,024, the final batch might only have, say, 600 pairs, so at the end we check whether the pairs list is non-empty and, if it is, save that final batch. With that, we've generated our data, and the query generation step is complete.

Now we can move on to fine-tuning the bi-encoder model, in a new notebook. First, we get a list of all the pair files; we saved them into the data directory as .tsv files. Because we're training with sentence-transformers, we import the InputExample object, a standardized format the sentence-transformers library uses for training data. We loop through the files, open each one, and read all the lines, i.e. all the pairs. If a line has no tab character, we skip it, just in case. Otherwise, we split it into query and passage, create an InputExample, and append it to pairs, giving us a big list of InputExamples of queries and passages.

Next we set up the data loader for MNR loss. We've done this before, but it may be new to you, so let's go through it. I use a batch size of 24. With MNR, a larger batch size is generally better because it makes the task harder for the model: the more pairs in a batch, the harder it is to rank the one correct pair first. If there are only two pairs in your batch, a model that is just guessing has a 50/50 chance of ranking the correct pair first; with 100 pairs in the batch, it has a very small chance of guessing correctly. That's why we try to use a larger batch size, although it depends on the compute you have available.

We use the NoDuplicatesDataLoader from sentence-transformers, which is built specifically to avoid duplicate queries or contexts within a batch. If you're trying to rank queries and passages correctly and there are duplicate passages in the batch, the model gets confused and can't rank them correctly, so this loader avoids that.

Then we define the bi-encoder model we'll fine-tune. I'm using MPNet, which generally works well as a sentence transformer bi-encoder. We use mean pooling, so the token embeddings the model outputs are compressed into a single sentence embedding by averaging them across all tokens. Putting the transformer model and the pooling layer together gives us our bi-encoder. After initializing the model, we initialize the MNR loss function, which is again very easy with sentence-transformers: we just call losses.MultipleNegativesRankingLoss and pass in the model.

Then we train. I've put a single epoch here. You can train for a single epoch, which I think is what the GenQ paper does, but I've also seen good results with three epochs, which I think is the default in the sentence-transformers documentation. So try a single epoch and see if it works; otherwise use three. In this case I used three. We pass in our loader and loss as in typical sentence-transformers training, warm up for 10% of the training steps, and save the model as mpnet-genq-squad. The model is saved in that directory, and we can load it like any other sentence transformer model.

That's it for this walkthrough. In a separate video, which I'll release soon after this one, we'll look at how to test and evaluate this model. I hope this has been useful. Thank you very much for watching, and I'll see you again.
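The loading and fine-tuning steps described above can be sketched as follows. The function names are hypothetical, `microsoft/mpnet-base` is an assumed checkpoint for "MPNet", and training uses the classic `model.fit` API; sentence-transformers is imported inside `train_bi_encoder` so the file-reading helper works without it installed. Treat this as a sketch under those assumptions rather than the exact notebook code.

```python
import glob
import os
from typing import List, Tuple


def read_pair_files(data_dir: str) -> List[Tuple[str, str]]:
    """Read 'query<TAB>passage' lines from every .tsv file in data_dir."""
    pairs: List[Tuple[str, str]] = []
    for path in sorted(glob.glob(os.path.join(data_dir, "*.tsv"))):
        with open(path, encoding="utf-8") as f:
            for line in f.read().splitlines():
                if "\t" not in line:
                    continue  # skip malformed lines, just in case
                query, passage = line.split("\t", 1)
                pairs.append((query, passage))
    return pairs


def train_bi_encoder(pairs: List[Tuple[str, str]], base_model: str = "microsoft/mpnet-base",
                     batch_size: int = 24, epochs: int = 3,
                     output_path: str = "mpnet-genq-squad"):
    """Fine-tune a mean-pooled bi-encoder with MNR loss on (query, passage) pairs."""
    from sentence_transformers import InputExample, SentenceTransformer, datasets, losses, models

    examples = [InputExample(texts=[query, passage]) for query, passage in pairs]
    # Keeps duplicate texts out of a batch, since a duplicate would be a false negative.
    loader = datasets.NoDuplicatesDataLoader(examples, batch_size=batch_size)

    word_model = models.Transformer(base_model)
    pooling = models.Pooling(word_model.get_word_embedding_dimension(),
                             pooling_mode_mean_tokens=True)  # average token embeddings
    model = SentenceTransformer(modules=[word_model, pooling])

    loss = losses.MultipleNegativesRankingLoss(model)
    warmup_steps = int(0.1 * len(loader) * epochs)  # warm up for 10% of training steps
    model.fit(train_objectives=[(loader, loss)], epochs=epochs,
              warmup_steps=warmup_steps, output_path=output_path, show_progress_bar=True)
    return model
```

For example, `train_bi_encoder(read_pair_files("data"), epochs=1)` tries the single-epoch setting first; the saved directory can then be loaded with `SentenceTransformer("mpnet-genq-squad")`.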