
Train Sentence Transformers by Generating Queries (GenQ)


Chapters

0:00 Intro
0:32 Why GenQ?
2:23 GenQ Overview
4:28 Training Data
6:48 Asymmetric Semantic Search
7:54 T5 Query Generation
13:52 Finetuning Bi-encoders
16:02 GenQ Code Walkthrough
21:40 Finetuning Bi-encoder Walkthrough
26:48 Final Points

Whisper Transcript

00:00:00.000 | Today we're going to have a look at how to fine-tune or build a sentence
00:00:06.640 | transformer or bi-encoder model using a very limited data set. So when I say
00:00:14.040 | limited data set all we actually need to fine-tune this bi-encoder model is
00:00:21.120 | unstructured text data from our target domain. Okay so the domain or the area
00:00:28.280 | that we want to apply the model to. So the reason that we might want to do this
00:00:34.720 | is I think quite clear for anyone who has trained sentence transformers,
00:00:42.200 | bi-encoders before. We in most cases will find that we don't have that much data
00:00:48.040 | and the data that we do have is usually not in the correct format for actually
00:00:53.800 | training a model. So we're going to look at one of the most recent methods for
00:00:59.920 | actually doing this, and that is called GenQ. At a very high level, the main
00:01:08.440 | benefits of GenQ, or the reasons for
00:01:14.400 | using GenQ, are these three points here. Our training data is simply unstructured text. So when I say
00:01:20.560 | unstructured text it's if you have for example PDF documents or web pages that
00:01:26.080 | you're scraping data from you can use that as your training data. It's
00:01:32.160 | specifically for asymmetric semantic search, i.e. where the queries that you're
00:01:37.920 | searching with are generally smaller than the passages or context that you're
00:01:44.320 | searching for. We'll explain that a little bit more pretty soon. One of the
00:01:51.640 | very impressive things with GenQ is that the performance can approach the
00:01:57.680 | performance of models trained using supervised methods. Now that's not going
00:02:02.680 | to be the case for every single use case and I'd say you're probably going to be
00:02:07.020 | relatively lucky if you start getting towards supervised performance but it
00:02:12.720 | does happen so that's I think a pretty good indication that this is a good
00:02:19.120 | technique to use or at least try. So at a high level let's go through what GenQ
00:02:27.220 | actually does. So we start with our unstructured text over here so our
00:02:32.360 | passages or text. We take a T5 model - T5 is a model from Google, and the
00:02:42.840 | general philosophy of T5 is that every problem in NLP is a text-to-text problem.
00:02:49.800 | We'll explain that a little more in a moment but what we're doing with
00:02:56.240 | this T5 model is we take our long passages, unstructured text, and use T5 to
00:03:03.440 | generate queries which are these short chunks of text over here. Now these
00:03:10.560 | queries are basically questions that the passage should answer. So what we
00:03:20.720 | end up getting are these query passage pairs. Now if you have trained a lot of
00:03:27.200 | bi-encoder models you probably recognize that a query-passage pair like
00:03:31.040 | this is basically the format we need to actually train a model using multiple
00:03:39.920 | negatives ranking loss, which is actually what we do here. So we're
00:03:43.760 | synthetically generating these queries and then we're actually using a
00:03:48.480 | supervised training method to train on those queries. So it's almost unsupervised
00:03:55.360 | because we're not needing to label data or anything here but in
00:04:03.000 | reality we are actually labeling the data, it's just automated. So it's a bit of an
00:04:09.640 | unsupervised/supervised learning technique and from that we get our
00:04:17.440 | bi-encoder or sentence transformer model at the bottom here. Now don't worry if
00:04:23.240 | none of that makes sense yet we are going to go through everything in a lot
00:04:26.640 | more detail.
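Before digging in, here is the whole pipeline at a glance as a minimal Python-style sketch. The helper names (load_passages, generate_queries, train_bi_encoder) are hypothetical placeholders for the concrete steps covered in detail later in the video.

    # Hypothetical sketch of the GenQ pipeline described above.
    # load_passages, generate_queries and train_bi_encoder are placeholders
    # for the concrete steps shown later in the walkthrough.

    passages = load_passages("my_domain_docs/")        # unstructured in-domain text

    # Step 1: a query-generation T5 model writes synthetic queries per passage.
    pairs = []
    for passage in passages:
        for query in generate_queries(passage, n=3):   # short question-like queries
            pairs.append((query, passage))             # (query, passage) training pair

    # Step 2: fine-tune a bi-encoder on those pairs with multiple negatives ranking loss.
    model = train_bi_encoder(pairs, loss="MultipleNegativesRankingLoss")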
00:04:33.680 | So let's start at the very start with our unlabeled data. We can describe our data as either being in domain or out of domain. So what I mean
00:04:38.640 | by that is for example we want to train a model that is going to allow us to
00:04:45.480 | query or search through German financial documents. So in domain for that
00:04:54.080 | particular use case might be what we have over here. So we have German finance
00:04:59.080 | news articles probably in domain and then we also have German finance
00:05:04.400 | regulation documents probably in domain. Out of domain but not too far out of the
00:05:12.000 | domain is English financial documents over here. So in some cases it can be
00:05:18.800 | okay to train your model on some things that are slightly out of domain like in
00:05:24.140 | this sort of area here if that's all you have but ideally you want things to be
00:05:29.200 | very specifically in your domain. And then other things like PyTorch
00:05:34.520 | documentation - whether it's in German or English, it's out of domain
00:05:39.120 | because it's not finance documents. Doombugs are definitely getting out of
00:05:45.120 | domain here and even if we had for example something like German pirate
00:05:49.680 | metal even though it's in German it's still out of domain because there's
00:05:55.080 | probably very little crossover between that and financial documents. So ideally
00:06:04.160 | the data that we need for training this can be unstructured - it can just
00:06:10.760 | be documents or web pages - but it should really be within our topic. If we're
00:06:16.960 | doing a German financial model we would ideally want
00:06:23.400 | something like the finance news articles or the finance regulations in German.
00:06:28.280 | Okay so that's in domain and out of domain and GenQ is specifically built
00:06:36.800 | for where we have a lot of in-domain data, but it's unstructured, like
00:06:42.720 | documents or web pages as I said before. So another thing that GenQ is very
00:06:50.720 | specifically built for is what I call asymmetric semantic search. Now
00:06:57.600 | asymmetric semantic search is basically where you have asymmetry between the
00:07:04.080 | sizes or the length of the text between what you are querying so your search
00:07:08.920 | query and what you're searching for, i.e. what you're trying to pull from your vector database.
00:07:14.600 | Okay so your queries might be very small let's say on average they are 32 words
00:07:21.320 | long that's probably actually very long your queries are probably even shorter.
00:07:24.840 | Imagine when you type into Google so maybe it's more like five to ten words
00:07:28.520 | at most, and then what you return from Google are usually little chunks of
00:07:34.040 | text and web pages and so on so your passages are kind of like those chunks
00:07:39.320 | of text and they're bigger. Right so there's asymmetry in the sizes between
00:07:43.960 | what you are searching with and searching for. That's where you get the
00:07:49.720 | asymmetric in asymmetric semantic search. Now the next point is that
00:07:57.520 | we're synthetically generating queries so we have that unstructured text it's
00:08:03.040 | just passages but to train the model we're actually going to use a supervised
00:08:09.560 | training method, and to train with a supervised training method we need to
00:08:16.160 | have queries that match up to those passages. So to generate those queries we
00:08:24.680 | use a t5 model that has been fine-tuned specifically for query generation. So as
00:08:33.360 | I mentioned before the philosophy of t5 is every problem is a text-to-text
00:08:38.320 | problem. So T5 is not just used for query generation; we can use it and say, okay,
00:08:45.040 | translate English to German "that is good", and get the German translation for that. You
00:08:50.520 | can use it for scoring two sentences on how similar they are and you get like
00:08:56.280 | 3.8 for these two because they're relatively similar and the range there is
00:09:00.360 | from 0 to 5. And then you might get a CoLA sentence, so that is where you are
00:09:08.320 | saying: is this sentence correct, does it make sense? "The course is jumping
00:09:14.040 | well" doesn't make sense, so it's saying okay, this is not acceptable. All
00:09:18.120 | of these - although, okay, in the semantic similarity example we do
00:09:23.680 | output 3.8 - it's still actually text, so all of this is just text-to-text, and
00:09:29.760 | that's the core philosophy behind T5: it's just an encoder-decoder model, and
00:09:36.640 | every problem is handled text-to-text.
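As a rough illustration of that text-to-text idea, here is a minimal sketch using the public t5-base checkpoint with the transformers library. The task prefixes (translate, stsb, cola) are the ones used in the T5 paper; the exact outputs depend on the checkpoint.

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base")

    prompts = [
        "translate English to German: That is good.",                   # translation
        "stsb sentence1: The cat sat. sentence2: A cat was sitting.",   # similarity score
        "cola sentence: The course is jumping well.",                   # acceptability
    ]

    for prompt in prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        output_ids = model.generate(input_ids, max_length=32)
        # every task, even the numeric similarity score, comes back as plain text
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))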
00:09:41.440 | The same applies to our problem: we have a passage and we want to generate a query for that passage - okay,
00:09:47.000 | it's just text-to-text and with t5 it might look something like this so the
00:09:52.840 | passage up at the top here - we have this really long passage, you know, from Wikipedia
00:09:58.360 | and the SBERT example page on GenQ - just a paragraph saying Python is an
00:10:05.360 | interpreted, high-level, etc. etc. programming language, so it's just a
00:10:10.040 | little intro to Python. T5 takes this and it will say okay let's generate
00:10:19.200 | three queries okay and they're randomized so it can take this one
00:10:23.100 | passage and apply a little bit of randomness and generate different
00:10:26.200 | queries, so we get define Python program, what is Python program, and what
00:10:33.960 | is Python useful for - they're all queries that we would probably search in
00:10:39.160 | Google, for example, and return a passage that looks like this. Okay, now for this
00:10:46.640 | we do need a t5 model that has been trained to produce queries and I'll show
00:10:53.760 | you later that we are going to use a specifically trained model for that and
00:10:58.600 | it is worth noting that your queries or the queries that are generated can be a
00:11:05.840 | little bit noisy - maybe we'll see an example of this later. What
00:11:09.520 | I mean by noisy is that sometimes they can be nonsensical or just
00:11:15.480 | weird, not good quality queries, and that is probably the main
00:11:22.240 | issue with GenQ that you need to be aware of: because we're generating
00:11:28.560 | queries using a language generation model, and that's pretty much the only
00:11:32.320 | step in generating the training data, if there's any noise that will
00:11:40.520 | obviously be translated into the performance of the model that
00:11:44.240 | we are training later on, and will obviously make the performance less than it
00:11:51.340 | could be if the generated queries were all perfect. Right, but that's to be
00:11:56.000 | expected; nonetheless, in most cases I think it works pretty well. So let's have
00:12:02.200 | a look at an example of this in Python. So here we have the generation code for
00:12:09.480 | just one example so this paragraph here okay so I'm taking this example from the
00:12:14.920 | sbert.net web page on GenQ. So all we're doing here is we're downloading a model -
00:12:22.880 | we're using a BeIR model over here that has been trained specifically for this,
00:12:27.360 | although as far as I know it's been trained for GPL and not GenQ but
00:12:36.140 | nonetheless this works for this we initialize tokenizer and a model using
00:12:43.720 | the transformers library here and we're using t5 for conditional generation
00:12:49.920 | basically to generate text. This is our input text and it's just the same
00:12:57.720 | paragraph we saw before and all we're going to do is we're going to create our input IDs by
00:13:03.120 | tokenizing text and then we're going to generate three sequences here and down
00:13:13.600 | here all I'm doing is decoding the sentences or sequences that we
00:13:19.320 | generated into text and I think this is the same example this is where I
00:13:27.560 | actually pulled the example that you saw before from so we generate these queries
00:13:31.320 | we get find Python program what is Python used for what is Python program okay
00:13:36.240 | that's all there is to the query generation step. Okay, so it's not really
00:13:44.640 | complicated, we just need to do it for a lot of queries or a lot of passages.
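For reference, a sketch of the snippet being described (based on the sbert.net GenQ example) might look like the following. The checkpoint name is an assumption - any T5 model fine-tuned for query generation can be dropped in.

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    # Checkpoint name is an assumption; any query-generation T5 model works here.
    model_name = "BeIR/query-gen-msmarco-t5-large-v1"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model.eval()

    text = (
        "Python is an interpreted, high-level, general-purpose programming "
        "language. Its design philosophy emphasizes code readability ..."
    )

    input_ids = tokenizer.encode(text, return_tensors="pt")
    outputs = model.generate(
        input_ids,
        max_length=64,
        do_sample=True,          # sampling adds the randomness that varies the queries
        top_p=0.95,
        num_return_sequences=3,  # three queries per passage
    )

    for output in outputs:
        print(tokenizer.decode(output, skip_special_tokens=True))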
00:13:49.120 | So that's the query generation step, and what we get after query generation,
00:13:54.200 | obviously, is a set of query-passage pairs. Then we can use those to fine-tune
00:14:01.200 | our model. Now, to fine-tune our model we use multiple negatives ranking loss.
00:14:07.280 | There's a relatively abstract animation here going through what MNR is
00:14:13.360 | actually doing so MNR at a high level works by placing all those pairs that we
00:14:19.660 | created - the query-passage pairs - into batches. Okay, and in each one of these
00:14:26.200 | batches, you imagine your query at position 0 should have the highest
00:14:31.720 | similarity score with your passage at position 0, because they're the pair,
00:14:36.960 | right? But then you have a batch, so you're actually going to compare the
00:14:41.040 | similarity between the query at position 0 and the passage at position 1 and 2
00:14:48.880 | and 3 and 4 and so on, and for each batch the model weights are going to be
00:14:55.520 | optimized so that the pair at position 0 has the highest similarity and, for
00:15:03.920 | example, query 0 and passage 3 or 4 or 5 and so on have the lowest similarity,
00:15:10.840 | right and we're not going to go super in-depth on MNR here because we've covered it before
00:15:19.000 | and I'll make sure there's a link to that actual article and video in the
00:15:23.400 | description. But at a high level it's basically just a ranking optimization: we
00:15:31.760 | are optimizing so that the pairs that should be together - those blocks that you
00:15:37.760 | see that are more obvious, you know, that diagonal line in the batch -
00:15:42.640 | should have the highest similarity scores.
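As a rough sketch (not the exact sentence-transformers implementation), MNR over one batch boils down to a query-passage similarity matrix fed into a cross-entropy loss, where the correct label for row i is column i - the diagonal.

    import torch
    import torch.nn.functional as F

    def mnr_loss(query_embeds, passage_embeds, scale=20.0):
        # query_embeds, passage_embeds: (batch_size, dim), where row i of each
        # holds the embeddings of the i-th (query, passage) pair
        q = F.normalize(query_embeds, dim=-1)
        p = F.normalize(passage_embeds, dim=-1)
        scores = q @ p.T * scale                        # (batch, batch) cosine similarities

        # the true passage for query i sits on the diagonal, so its label is i;
        # every off-diagonal passage acts as an in-batch negative
        labels = torch.arange(scores.size(0), device=scores.device)
        return F.cross_entropy(scores, labels)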
00:15:53.280 | Okay, so with that we will have a bi-encoder that has been fine-tuned to our specific use case. Now let's have a
00:15:58.440 | look at the actual code that we can use to do that so the very first step that
00:16:04.960 | we'll need to go through is actually getting our data so to get our data
00:16:10.480 | we're just going to use the SQuAD dataset. The reason for using this is that the
00:16:15.360 | test set of this has both the questions and context pairs, so we can look at or we
00:16:22.960 | can assess the quality of our model at the end of this. I'm just checking here -
00:16:30.880 | okay, a lot of these contexts are duplicated, because for a sample of
00:16:36.760 | one context we might have 30 questions, so all we're doing here is removing
00:16:41.480 | those duplicates so that we just have the passages. So we can see that the
00:16:47.480 | final length, or the final number of passages that we have, is just under 19,000
00:16:53.160 | here, which is not that many. I think for a lot of use cases, if you're
00:16:57.760 | using unstructured data, you can probably get more, but in this case it does work
00:17:03.040 | very well with just a small amount. But it's a very generic dataset, so for more
00:17:11.560 | specific use cases maybe you might want to have more data, maybe not - you'll
00:17:17.280 | just have to play around with it and see what works.
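A minimal sketch of that data step with the Hugging Face datasets library - the split choice here is an assumption; the video only mentions ending up with just under 19,000 unique passages after deduplication.

    from datasets import load_dataset

    # Split choice is an assumption; each SQuAD row has a question and a context passage.
    squad = load_dataset("squad", split="train")

    # One context can appear with many different questions, so keep each context once.
    passages = list(dict.fromkeys(row["context"] for row in squad))
    print(len(passages))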
00:17:26.080 | So, to generate the queries, you already saw before what we were doing: all we're doing is loading
00:17:32.200 | this model again - the tokenizer and the model - and I'm setting the model to
00:17:37.080 | evaluation mode, so, for example, any normalization layers in there are
00:17:43.560 | set to evaluation mode, ready for inference. So I'm going to run that, and we
00:17:54.240 | move on to the inference step, where we are going to generate the queries for
00:18:01.480 | our passages. So here I'm using tqdm - we're just using this as a progress
00:18:08.920 | bar, nothing fancy going on there, it's just so we can actually see the progress,
00:18:15.400 | because it can take a little bit of time, particularly if you have a very large
00:18:18.920 | amount of data. We're setting torch.no_grad, so we're not calculating the gradients
00:18:25.000 | of our model, because that takes more time and we're not updating or optimizing the
00:18:29.640 | model, we're just using it for inference here. So, yeah, then what we need to do
00:18:36.480 | is get our passages, so "for p in tqdm(passages)" - we are doing that. So if I show
00:18:48.360 | you here, look at passages[0], just one example - run this - okay, and this is
00:19:07.440 | just our passage, a long chunk of text. So we're going through those, and if there are any
00:19:13.640 | tab characters which I generally don't think there are but maybe
00:19:18.160 | there's the odd one I'm going to replace those because later we're going to be
00:19:22.400 | using a tab-separated values, or TSV, file for storing our pairs, and obviously
00:19:30.000 | if we have more than one tab in those pairs, separating the query and the
00:19:34.520 | passage, it's going to mess up the formatting of the file. So I'm just
00:19:39.760 | replacing those with spaces, just in case. Then we're creating the input IDs using
00:19:45.920 | tokenizer encode as we did before then we're generating three queries per
00:19:51.440 | passage, and we decode the queries to human-readable language so we can read them
00:20:00.760 | and so the next model can actually tokenize that text and understand it;
00:20:07.840 | otherwise we're using token IDs from another model, which can't be read by the
00:20:13.480 | next bi-encoder model, the one we're fine-tuning. So we decode them into human language,
00:20:21.760 | and then we are going to add the pairs. So that's our query - just in case any tab
00:20:28.640 | characters end up in there when we're generating them, we replace any tab
00:20:33.520 | characters with spaces again - and then we use a tab character to separate the query
00:20:37.920 | from our passage. Okay, and then what I'm doing here is I'm saying every 1024
00:20:47.000 | pairs we create, I want to save that - I'll write that to file - okay, and then we
00:20:52.880 | increase the file count, because every new file is going to be saved as
00:21:00.520 | pairs plus a number - for example the first one will be pairs 0, then pairs 1 and pairs 2 and so on -
00:21:05.240 | okay, and then we just reset or refresh the pairs list. Finally, because
00:21:15.040 | here we're doing it in batches of 1024, maybe the final batch only has 600 pairs
00:21:21.960 | in there, so I'm just saying if pairs is not empty, i.e. the list is not empty, I want you
00:21:27.400 | to save that final batch. Okay, and with that we've generated our data, so that's
00:21:34.680 | the query generation step.
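Putting those pieces together, a sketch of the generation loop being described might look like this. It assumes the tokenizer, model and passages from the earlier steps, and the data/pairs_N.tsv naming is inferred from the description.

    import torch
    from tqdm.auto import tqdm

    pairs = []
    file_count = 0
    save_every = 1024                         # pairs to collect before writing a file

    with torch.no_grad():                     # inference only, no gradients needed
        for p in tqdm(passages):
            p = p.replace("\t", " ")          # stray tabs would break the TSV format
            input_ids = tokenizer.encode(p, return_tensors="pt")
            outputs = model.generate(
                input_ids, max_length=64,
                do_sample=True, top_p=0.95,
                num_return_sequences=3,       # three synthetic queries per passage
            )
            for output in outputs:
                query = tokenizer.decode(output, skip_special_tokens=True)
                query = query.replace("\t", " ")
                pairs.append(query + "\t" + p)    # tab separates query from passage

            if len(pairs) >= save_every:
                with open(f"data/pairs_{file_count}.tsv", "w", encoding="utf-8") as fp:
                    fp.write("\n".join(pairs))
                file_count += 1
                pairs = []

    # write whatever is left over in the final, smaller batch
    if pairs:
        with open(f"data/pairs_{file_count}.tsv", "w", encoding="utf-8") as fp:
            fp.write("\n".join(pairs))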
00:21:42.800 | We can then move on to the actual fine-tuning of the bi-encoder model. So to do that - obviously we're in a new
00:21:49.520 | notebook now - so the first thing I want to do is get a list of all the files, or the
00:21:55.200 | pair files, like before. So what we're doing here - we saved them into the data
00:22:00.960 | directory and they're all .tsv files.
00:22:05.720 | then because we're going to use sentence transformers to train here we need to
00:22:11.320 | import the InputExample object, and this is just a specific object type used by
00:22:17.400 | the sentence transformers library just a standardized format so that we can easily
00:22:23.400 | read the data that we're going to train with. So all we're doing here is looping
00:22:29.200 | through all of those files, opening all of them, and reading all the lines - that's
00:22:34.120 | all the pairs - and here, if there's no tab character in the line, we're just
00:22:42.480 | skipping it, just in case for whatever reason the tab character isn't in there,
00:22:49.160 | and then here we're splitting to get the query and passage, and we're creating
00:22:57.680 | the InputExample and appending it to pairs. So we have this big list
00:23:01.840 | of input examples of our queries and passages.
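A sketch of that loading step - the data directory and pairs_*.tsv naming follow the generation step above.

    from pathlib import Path
    from sentence_transformers import InputExample

    pairs = []
    for path in Path("data").glob("pairs_*.tsv"):
        with open(path, "r", encoding="utf-8") as fp:
            lines = fp.read().split("\n")
        for line in lines:
            if "\t" not in line:                    # skip empty or malformed lines
                continue
            query, passage = line.split("\t", maxsplit=1)
            pairs.append(InputExample(texts=[query, passage]))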
00:23:08.320 | Okay, and then we set up the data loader for multiple negatives ranking loss. We've done this before, but
00:23:14.360 | it may be new to you all, so we're going to go through it. So I'm using a batch size of 24; with
00:23:19.840 | MNR, the greater the batch size the better, generally, because you're
00:23:25.280 | basically making it harder for the model: if you have more pairs in that
00:23:30.240 | batch, the harder it is going to be to rank that one pair that is the correct
00:23:36.680 | pair at position number one. If there are only, for example, two pairs in
00:23:41.640 | your batch, your model has like a 50/50 chance, if it's just guessing, of ranking
00:23:46.820 | the correct pair at position one; if there are a hundred in your batch, it has a
00:23:52.920 | pretty small chance of actually guessing that correctly. Okay, so that's why we try
00:24:01.080 | and use a larger batch size but obviously it would depend on your
00:24:05.800 | available compute that you have. And then we are going to use this NoDuplicatesDataLoader
00:24:12.800 | from sentence transformers. So this is specifically built to avoid
00:24:18.280 | having duplicate queries or passages in your
00:24:27.240 | batch, because obviously if you're trying to rank your
00:24:33.280 | queries and passages correctly and you have duplicate passages in there, your
00:24:38.840 | model will get confused and it can't actually rank them correctly. So we avoid
00:24:43.520 | having duplicates in there by using this.
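A sketch of the loader setup, using the sentence-transformers NoDuplicatesDataLoader and the batch size of 24 mentioned in the video.

    from sentence_transformers import datasets

    batch_size = 24   # bigger is generally better for MNR, within your GPU memory

    # Avoids putting the same passage into a batch twice, which would give a query
    # two "correct" answers and make the in-batch ranking ambiguous.
    loader = datasets.NoDuplicatesDataLoader(pairs, batch_size=batch_size)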
00:24:50.920 | Then we are just going to define the bi-encoder model that we're going to be fine-tuning. So I'm using an MPNet model
00:24:55.320 | here - MPNet is generally quite good as a sentence transformer / bi-encoder model, so
00:25:02.040 | it's a good one to use. We are using the mean token, or mean pooling, method here,
00:25:10.080 | so all the word embeddings that the model outputs are being compressed into
00:25:15.320 | a single sentence embedding by taking the average across all of the token embeddings,
00:25:21.720 | and we're putting those both together - the actual transformer model followed by
00:25:26.040 | the pooling layer - and that creates our bi-encoder model. Then, after we
00:25:32.280 | initialize our model we can initialize the MNR loss function super easy again
00:25:37.840 | with sentence transformers: all we do is losses.MultipleNegativesRankingLoss
00:25:42.440 | and then pass the model in there. And then we train. I put a single epoch here -
00:25:48.240 | you can train for a single epoch, or I think
00:25:52.800 | that's what they do in the GenQ paper, but I've also seen good results from
00:26:01.120 | using three epochs here, which is I think the default in the sentence transformers
00:26:08.800 | documentation. So you can go with a single epoch, maybe try it, see if it works;
00:26:14.640 | otherwise go with three epochs - in this case I use three. Okay, here, ignore that.
00:26:22.280 | We pass our loader and loss, like typical sentence transformers training
00:26:29.920 | here, and then I save the model as mpnet-genq-squad. Okay, we warm up for 10% of
00:26:37.160 | the training steps, and then, yep, the model is saved in that directory and we
00:26:44.640 | can just load it as we would any other sentence transformer model.
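A sketch of that training setup; the microsoft/mpnet-base checkpoint name and the output path are assumptions based on the description.

    from sentence_transformers import SentenceTransformer, models, losses

    # Transformer backbone plus a mean-pooling layer that averages the token
    # embeddings into a single sentence embedding.
    backbone = models.Transformer("microsoft/mpnet-base")
    pooling = models.Pooling(
        backbone.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
    )
    model = SentenceTransformer(modules=[backbone, pooling])

    loss = losses.MultipleNegativesRankingLoss(model)

    epochs = 3
    warmup_steps = int(len(loader) * epochs * 0.1)   # warm up for 10% of training steps

    model.fit(
        train_objectives=[(loader, loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        output_path="mpnet-genq-squad",
        show_progress_bar=True,
    )

    # Afterwards the fine-tuned bi-encoder loads like any other model:
    # model = SentenceTransformer("mpnet-genq-squad")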
00:26:49.920 | So I think that's it for this walkthrough. What I'm going to do in a separate video is we'll
00:26:56.840 | have a look at how we can test this model and evaluate it so that'll be
00:27:03.440 | interesting I'll release it pretty soon after this so I hope that's been useful
00:27:08.640 | thank you very much for watching and I will see you again