back to index

LLM Asia Paper Club Survey Round


Chapters

0:00 Let's Think Dot by Dot: Hidden Computation in Transformer Language Models
12:30 Uncertainty Estimation and Quantification for LLMs: A Simple Supervised Approach
32:23 Monosemanticity
44:54 Medusa - Simple Speculative Decoding using multiple heads

Whisper Transcript | Transcript Only Page

00:00:00.000 | Okay, great, great.
00:00:01.000 | Okay, I will now start on presenting this paper.
00:00:04.000 | Basically, I found out about this paper recently,
00:00:08.000 | which is a paper called "Let's think dot by dot".
00:00:11.000 | And the motivation behind choosing this paper is that
00:00:14.000 | basically I always have the burning question as to like,
00:00:17.000 | how do LLMs actually think?
00:00:19.000 | Now, we know that chain of thought reasoning process
00:00:22.000 | or chain of thought prompting actually shows that
00:00:24.000 | by allowing LLMs to think out loud before answering,
00:00:28.000 | their performance actually improves considerably
00:00:30.000 | when compared to direct answering techniques.
00:00:33.000 | This actually provides us some intuition as to
00:00:36.000 | how LLMs reason through their tasks.
00:00:39.000 | However, recent work, as cited in this paper,
00:00:42.000 | suggests that for chain of thought reasoning,
00:00:44.000 | LLM answers could also be unfaithful
00:00:47.000 | to the intermediate reasoning steps.
00:00:49.000 | Simply put, the answers do not really tally
00:00:52.000 | with their workings.
00:00:54.000 | Now, here are some guiding questions
00:00:56.000 | before moving on to the related works section.
00:01:00.000 | The first question is that,
00:01:01.000 | do LLMs actually need to think out loud,
00:01:04.000 | basically like write down their thoughts,
00:01:06.000 | or are they able to think internally like humans?
00:01:09.000 | The second question is that,
00:01:10.000 | are there any specific tasks to show that LLMs
00:01:13.000 | are not simply relying on the semantic information
00:01:17.000 | and possibly relying on other kinds of information
00:01:20.000 | found in the tokens in their inputs?
00:01:23.000 | So I'll just skim through the related works section
00:01:26.000 | just to provide some additional context on this work.
00:01:29.000 | The first section in the related work
00:01:31.000 | is talking about computational complexity.
00:01:33.000 | Basically, they are discussing that there are
00:01:35.000 | different levels of computational complexity
00:01:38.000 | and task complexity.
00:01:39.000 | Basically, the lowest level is something called TC0,
00:01:42.000 | and the higher levels of task complexity include
00:01:45.000 | graph connectivity, solving Sudoku, etc.
00:01:48.000 | And this requires some form of recursion
00:01:51.000 | and checking multiple constraints at the same time.
00:01:54.000 | Now, the second part of the related works section
00:01:56.000 | is talking about how transformers use tokens for reasoning.
00:02:00.000 | In the chain of thought prompting example,
00:02:02.000 | we can see that transformers can indeed
00:02:04.000 | solve higher complexity problems
00:02:06.000 | when additional reasoning tokens are provided.
00:02:09.000 | These tokens actually help to decompose the problem
00:02:12.000 | into smaller problems
00:02:13.000 | and improve the model's reasoning capabilities.
00:02:17.000 | So the question is that,
00:02:18.000 | are the LLMs making full use of these extra tokens
00:02:21.000 | as part of their reasoning process?
00:02:24.000 | And the third part in the related works section
00:02:27.000 | is talking about potential model performance improvements
00:02:30.000 | on more complex tasks.
00:02:32.000 | Some work also suggested that using filler tokens
00:02:35.000 | could improve performance on nested quantifier tasks.
00:02:38.000 | Basically, tasks that have some form of constraints,
00:02:41.000 | nested constraints that depend on each other.
00:02:45.000 | Now, what the authors want is that
00:02:47.000 | they actually wanted to create new datasets
00:02:49.000 | that specifically target these kinds of problems
00:02:51.000 | to test their filler tokens hypothesis.
00:02:55.000 | Now, let me move on to the methodology section
00:02:57.000 | and talk about the tasks created by the authors.
00:03:00.000 | I have a short section here talking about
00:03:02.000 | the nested quantifier,
00:03:04.000 | and then I will send the link over later,
00:03:06.000 | but this will just be a short summary
00:03:08.000 | about what exactly are nested quantifiers.
00:03:12.000 | So due to time constraints,
00:03:13.000 | I'll just move on to the first task created by the authors.
00:03:17.000 | So the first task is called 3SUM.
00:03:19.000 | Basically, it's that you have a following list of numbers
00:03:22.000 | and you want to check if there's a group of three numbers
00:03:25.000 | inside that list that adds up to 0 modulo 10.
00:03:29.000 | So the example question is that
00:03:31.000 | you have a list of 2, 3, 5, 7, 1, 8,
00:03:34.000 | and the answer could be like 2, 3, 5 or 2, 7, 1
00:03:37.000 | because the result will be 10.
00:03:40.000 | And 10 divided by 10, the remainder is 0.
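
For reference, here is a minimal sketch of the 3SUM-modulo-10 check that this task asks for; the function and variable names are illustrative, not from the paper.

```python
from itertools import combinations

def has_3sum_mod10(numbers: list[int]) -> bool:
    """Return True if some triple of entries sums to 0 modulo 10."""
    # Brute-force over all triples is fine for the short lists used in the task.
    return any((a + b + c) % 10 == 0 for a, b, c in combinations(numbers, 3))

print(has_3sum_mod10([2, 3, 5, 7, 1, 8]))  # True: 2 + 3 + 5 = 10, and 10 mod 10 = 0
```
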
00:03:44.000 | And for the second task, it's called 2SUM transform,
00:03:48.000 | which is also regarding a sequence of numbers.
00:03:52.000 | And basically, they have one list
00:03:54.000 | and then they have a hidden transform list.
00:03:56.000 | Basically, using the original list,
00:03:59.000 | and then they are adding some form of permutation
00:04:01.000 | to each element in that list,
00:04:03.000 | and that permutation is only revealed at the end.
00:04:05.000 | So for example, you have an original list
00:04:07.000 | and then you add 3 to each number to get the transformed list.
00:04:12.000 | And then what this task actually requires
00:04:16.000 | is that it requires the model to remember the transformation information
00:04:20.000 | and how to integrate this in its computational thinking process.
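
To make the second task concrete, here is a rough sketch of a 2SUM-transform-style instance as described above: every element is shifted by an offset that is only revealed at the end, so the model has to carry that transformation through its reasoning. The exact prompt format and scoring rule in the paper may differ; the mod-10 pair check below is an assumption for illustration.

```python
def make_2sum_transform_instance(numbers: list[int], offset: int) -> dict:
    """Build a toy 2SUM-transform-style instance.

    The transformation (a fixed offset, applied mod 10) is only revealed at the
    end of the prompt, so the model must remember and apply it while reasoning.
    """
    transformed = [(n + offset) % 10 for n in numbers]
    # Illustrative target: does any pair of transformed numbers sum to 0 mod 10?
    label = any((transformed[i] + transformed[j]) % 10 == 0
                for i in range(len(transformed))
                for j in range(i + 1, len(transformed)))
    return {"numbers": numbers, "offset_revealed_last": offset, "label": label}

print(make_2sum_transform_instance([2, 3, 5, 7], offset=3))
```
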
00:04:25.000 | So let me move on to the experiment section of the paper.
00:04:30.000 | So basically, the authors are using some form of scaled-down,
00:04:33.000 | randomly-initialized Llama model in their training process
00:04:37.000 | using three different model configurations for the setup.
00:04:41.000 | And the first configuration is basically using filler tokens,
00:04:45.000 | basically a series of dots as intermediate tokens
00:04:48.000 | between the input tokens and the output tokens.
00:04:51.000 | Now, the second configuration is that instead of using dots,
00:04:54.000 | they are using the chain of thought reasoning tokens
00:04:57.000 | as the intermediate tokens.
00:04:59.000 | Now, the third configuration is the sort of like the control setup
00:05:02.000 | where there are no filler tokens in between the input and output.
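
As a rough sketch of how the three setups differ, the training sequences might be assembled like this; the exact token layout in the paper's training data may differ.

```python
def build_sequence(question_tokens, answer_tokens, mode, cot_tokens=None, n_filler=10):
    """Assemble question + intermediate + answer tokens for one of the three setups."""
    if mode == "filler":
        intermediate = ["."] * n_filler        # meaningless filler tokens (dots)
    elif mode == "cot":
        intermediate = list(cot_tokens or [])  # real chain-of-thought tokens
    elif mode == "none":
        intermediate = []                      # control: answer immediately
    else:
        raise ValueError(f"unknown mode: {mode}")
    return question_tokens + intermediate + answer_tokens

print(build_sequence(["2", "3", "5", ":"], ["True"], mode="filler", n_filler=4))
```
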
00:05:08.000 | And then the results show that for the 3SUM task,
00:05:11.000 | the first task created by the authors,
00:05:14.000 | as the length of the sequence of numbers increases,
00:05:18.000 | the difference in performance
00:05:20.000 | between the configuration with the filler tokens
00:05:23.000 | and the one with no filler tokens
00:05:25.000 | actually widens quite drastically.
00:05:29.000 | And then this basically shows that as the task complexity increases,
00:05:33.000 | the filler tokens actually do contain
00:05:36.000 | some form of task-related information.
00:05:40.000 | And the next figure is talking about a different experiment
00:05:44.000 | which is basically the authors trying to freeze
00:05:46.000 | all the previous layers of the transformer
00:05:49.000 | and only leave the final attention layer trainable.
00:05:52.000 | And then basically, they are trying to prove that
00:05:54.000 | whatever changes they make in the final layer,
00:05:56.000 | this will cause the results to change.
00:05:59.000 | Basically, the performance improvements or performance changes
00:06:02.000 | are only resulting from the changes in the final layer.
00:06:06.000 | So in the final layer, what they tried is that
00:06:08.000 | they tried to increase the percentage of filler tokens
00:06:11.000 | in the final layer.
00:06:12.000 | And they've seen that in their tasks,
00:06:14.000 | actually it increases the accuracy up to like 100%.
00:06:19.000 | And then it increases sharply
00:06:22.000 | as the percentage of filler tokens kept goes up,
00:06:25.000 | until it tapers off and then it has some diminishing returns
00:06:28.000 | past the 60% mark.
00:06:31.000 | Now, for the two-sum transform task,
00:06:33.000 | which is the second task,
00:06:34.000 | the authors lay out a table comparing the chain of thought,
00:06:37.000 | the filler tokens, and the no-tokens configuration.
00:06:41.000 | And then this shows that, yeah, as expected,
00:06:44.000 | the chain of thought is still triumphing at the very top.
00:06:48.000 | And the filler token is sort of getting quite close
00:06:51.000 | to the chain of thought reasoning tokens
00:06:53.000 | while the no-tokens is showing like a 78% accuracy.
00:06:58.000 | Now, both of the aforementioned configurations
00:07:04.000 | are noticeably exceeding the performance
00:07:06.000 | of no intermediate tokens.
00:07:09.000 | And some additional insight that the authors
00:07:12.000 | actually found out in their experiment is that,
00:07:14.000 | before using the filler tokens
00:07:17.000 | to replace the intermediate tokens,
00:07:19.000 | they actually needed the chain of thought configuration
00:07:22.000 | to be set up to see how the chain of thought tokens
00:07:25.000 | are placed as the intermediate tokens
00:07:27.000 | before directly replacing them one-to-one
00:07:30.000 | with the filler tokens.
00:07:32.000 | So this basically means that this is actually quite challenging.
00:07:35.000 | It is quite a challenging learning problem
00:07:37.000 | whereby prior knowledge of how to place
00:07:39.000 | the intermediate tokens is needed
00:07:41.000 | before you can use the filler tokens.
00:07:44.000 | This also implies that the filler tokens must be placed
00:07:46.000 | in some particular arrangement or configuration
00:07:49.000 | in order to showcase good performance.
00:07:52.000 | So, at the end of the paper,
00:07:55.000 | some food for thought is basically saying that
00:07:59.000 | the experiment results imply that
00:08:01.000 | although filler tokens do not contain
00:08:03.000 | any semantic information or instructions
00:08:05.000 | for the language models to follow,
00:08:07.000 | they still contain some form of structural information
00:08:10.000 | or syntactic information.
00:08:12.000 | So, a possible question is that
00:08:15.000 | why use dots as the filler tokens?
00:08:17.000 | We can also use other kinds of tokens, right?
00:08:20.000 | As long as the tokens do not really contain
00:08:22.000 | any kind of semantic meaning,
00:08:24.000 | they could all be valid choices for the model.
00:08:27.000 | So, I would love to see that maybe
00:08:28.000 | we can try different types of tokens
00:08:30.000 | to see if there are any different kinds of results.
00:08:33.000 | Although dot tokens are technically filler tokens,
00:08:35.000 | they also do serve as a form of
00:08:37.000 | ellipsis in the English language,
00:08:39.000 | acting as some form of continuation.
00:08:43.000 | So, in summary, this paper basically shows
00:08:47.000 | how large language models can use filler tokens
00:08:50.000 | to achieve higher performance in certain tasks
00:08:53.000 | when compared to not using any filler tokens at all
00:08:57.000 | as the intermediate tokens.
00:08:59.000 | The results also show that filler tokens
00:09:01.000 | have some form of computational benefit,
00:09:04.000 | although not through semantics,
00:09:05.000 | but maybe through some form of structural information.
00:09:09.000 | So, that's the end of my survey.
00:09:12.000 | I think it's an interesting paper,
00:09:14.000 | and then I'm open to any questions.
00:09:18.000 | I think Casper had a question in the chat itself,
00:09:21.000 | whether the filler tokens were special tokens
00:09:24.000 | introduced after the model was trained,
00:09:27.000 | or was it just the same token used for pure filler
00:09:29.000 | that the tokenizer had learnt perhaps
00:09:31.000 | during the pre-training phase?
00:09:33.000 | They actually include the filler tokens,
00:09:36.000 | as I said, in the pre-training phase or so.
00:09:40.000 | Yeah.
00:09:41.000 | So, they were not like a special token added in afterwards.
00:09:45.000 | There's another question from Warren,
00:09:50.000 | which is why not just use unknown token or something
00:09:52.000 | instead of dots in this case?
00:09:54.000 | Yeah, that's an interesting question also, yeah.
00:09:57.000 | I think it will probably, yeah,
00:10:00.000 | you'll probably have different results, yeah.
00:10:02.000 | It seems any rare form of token,
00:10:04.000 | yeah, should be fine, yeah.
00:10:06.000 | It seems to, I don't know if you've seen
00:10:08.000 | like the memes on Twitter, where they, like,
00:10:10.000 | if you prompt the model and you put a space in between,
00:10:13.000 | or like one or two spaces,
00:10:15.000 | and then you just let it, like,
00:10:17.000 | and then you let it do the normal inference call,
00:10:19.000 | the performance actually increases sometimes
00:10:21.000 | because the space itself is another token, right?
00:10:23.000 | Oh, okay.
00:10:24.000 | And then that modify,
00:10:25.000 | that sort of like modifies the computation itself.
00:10:28.000 | Oh, yeah, that's quite interesting, yeah.
00:10:31.000 | It's all these hidden complexities that,
00:10:37.000 | that we don't really know.
00:10:38.000 | Yeah, it's quite interesting.
00:10:40.000 | I just comment,
00:10:43.000 | could it also be because of positional biases?
00:10:46.000 | Do you want to maybe elaborate on that?
00:10:48.000 | I'm not too sure about what positional biases are.
00:10:50.000 | I mean, you can drop it in the chat or you can unmute yourself.
00:10:53.000 | Yeah, sure.
00:10:54.000 | So what I meant was because, like,
00:10:55.000 | there's a positional encoding in the transformers, right?
00:10:57.000 | So because the models will train on a lot of sequences,
00:11:00.000 | then you, it's like,
00:11:02.000 | I think maybe if it trains too much, it gets saturated,
00:11:04.000 | then it kind of like learns the positional encoding as well.
00:11:07.000 | So like, if you put a space or two, right,
00:11:09.000 | then it just so happens that, like most of the training sequences,
00:11:11.000 | that particular one also is aligned with
00:11:14.000 | the rest of the positions, with
00:11:16.000 | how it was observed in the training data,
00:11:18.000 | then it might improve in performance
00:11:20.000 | or something like that.
00:11:21.000 | I just know, like, this kind of bias exists,
00:11:24.000 | but I don't really know how to actually explain it.
00:11:26.000 | Let me see where I can find an article on it.
00:11:28.000 | For sure, for sure.
00:11:30.000 | I don't think the positional biases of the,
00:11:32.000 | the prompt and the original input would change,
00:11:34.000 | but I think perhaps what you're thinking about
00:11:36.000 | is the additional space itself plus positional bias
00:11:39.000 | might better reflect the training data
00:11:41.000 | or where the answer might be.
00:11:43.000 | Is that what you're trying to say?
00:11:45.000 | Something like this.
00:11:51.000 | Okay.
00:11:52.000 | Let me see if I can pull this link up.
00:11:55.000 | Large language models are not very valid in this.
00:12:01.000 | I'm looking to it.
00:12:02.000 | Yeah.
00:12:03.000 | I guess if there's no more questions,
00:12:11.000 | then maybe we can move on to the next paper.
00:12:13.000 | Do you want to go next?
00:12:15.000 | Yeah.
00:12:16.000 | Okay, I'll stop sharing my screen.
00:12:18.000 | Oh, you want me to go next?
00:12:23.000 | Okay.
00:12:24.000 | Yeah.
00:12:25.000 | Give me a moment.
00:12:28.000 | Can you see the screen?
00:12:31.000 | Just give me a while.
00:12:33.000 | Okay.
00:12:39.000 | Yeah.
00:12:40.000 | Hi, everyone.
00:12:41.000 | Yeah.
00:12:42.000 | This paper is a bit different.
00:12:43.000 | So it's about how we use a second model actually
00:12:48.000 | to estimate the uncertainty of an LLM's response.
00:12:51.000 | So it's quite an obscure paper, actually.
00:12:53.000 | I spent a few days trying to think of a paper to sort of talk about.
00:12:57.000 | But I did enjoy reading this paper.
00:12:59.000 | For one, it approaches the problem quite differently.
00:13:02.000 | It's from a bunch of people in operations research,
00:13:05.000 | which is actually where I did grad school.
00:13:07.000 | And the sort of approach of doing it is actually more of just getting data,
00:13:11.000 | fitting a model in, and then trying to make use of that model
00:13:14.000 | for a real-world outcome.
00:13:16.000 | So let's just dive right in.
00:13:17.000 | So the TLDR is actually just three points.
00:13:19.000 | Number one, here they train a regression model.
00:13:22.000 | And, in fact, actually, it's a classical one,
00:13:24.000 | random forest, to estimate the uncertainty of an LLM's response.
00:13:28.000 | And the input for that is just the LLM's hidden layers
00:13:32.000 | of the last token, activations of the last token,
00:13:35.000 | or in the case of a model that, let's say,
00:13:38.000 | open AIs, APIs, or some other LLM provider
00:13:41.000 | where you don't actually get the actual hidden layers,
00:13:44.000 | you can actually -- or a gray box model,
00:13:46.000 | you can actually use some of the probability-related output
00:13:50.000 | as an input to this regression model.
00:13:53.000 | And the output, it's actually not the logits of the language model,
00:13:59.000 | but instead it's actually a task-specific score,
00:14:02.000 | and typically between 0 and 1, about the certainty of the answer.
00:14:05.000 | So I'll just elaborate a bit more about that later.
00:14:09.000 | The paper covers a bit about the existing methods,
00:14:12.000 | and I think most of the methods right now for quantifying uncertainty,
00:14:17.000 | they tend to be based directly on the output of the language model.
00:14:20.000 | What do I mean by that?
00:14:22.000 | Say, given a fixed prompt, you might get different outputs,
00:14:25.000 | and you try to do sampling there to see what the variation in the output is,
00:14:29.000 | or you might add some perturbations to your prompt
00:14:31.000 | and see how that results in variation in your output,
00:14:35.000 | and you measure it thereafter.
00:14:37.000 | The paper keeps specifying, keeps mentioning
00:14:39.000 | that these are unsupervised methods,
00:14:41.000 | and in contrast to what they're doing with a supervised method
00:14:44.000 | where they actually have ground truth to some degree,
00:14:46.000 | and I'll share again a bit more about that
00:14:48.000 | when I explain the problem mathematically.
00:14:50.000 | And this work has been somewhat applied
00:14:53.000 | where they take a second model to quantify the base language model,
00:14:57.000 | and it's been done for transformers,
00:14:59.000 | but not for recent large language models like Llama 3 or Gemma.
00:15:03.000 | So before we go into the paper itself,
00:15:05.000 | actually, why does this matter?
00:15:07.000 | I think the paper sort of points out point two and point three
00:15:11.000 | in my list here of four points.
00:15:13.000 | They show that by doing this,
00:15:15.000 | you actually can get improved performance on certain tasks,
00:15:18.000 | and it's a potential use case in detecting hallucinations
00:15:21.000 | because by getting a sort of confidence score for an answer,
00:15:24.000 | you can then perhaps hedge, do a different sampling approach.
00:15:29.000 | Point one and point four are actually my own ideas upon reading the paper
00:15:33.000 | where actually if, say, you have a language model
00:15:35.000 | interface app, then by providing
00:15:40.000 | a sort of certainty score about the response,
00:15:43.000 | imagine a typical chatbot, people talk about RAG chatbots,
00:15:46.000 | if you're able to provide like the certainty score,
00:15:49.000 | the UI/UX of it could incorporate that
00:15:51.000 | into explaining the certainty of the response.
00:15:54.000 | And number four is interesting because,
00:15:56.000 | as you'll see later on in this quick sharing,
00:15:59.000 | what we're doing is basically predicting
00:16:01.000 | the performance of an LLM answer,
00:16:04.000 | and that actually opens up possible use cases of auto evals
00:16:07.000 | where you already built in a sort of,
00:16:09.000 | you've already done your evals for like maybe a good training set,
00:16:13.000 | and then now you can scale this up to auto evals,
00:16:15.000 | and actually when you deploy an LLM system live,
00:16:18.000 | this auto eval could actually help sort of highlight cases
00:16:22.000 | where your system might potentially be giving low confidence answers.
00:16:27.000 | So let's just try to formalize this a little bit,
00:16:30.000 | so that's where we just go to this part
00:16:32.000 | where we express the problem mathematically.
00:16:34.000 | So the first thing is an LLM is just abstracted into,
00:16:37.000 | I give it an input and it generates a response,
00:16:40.000 | and the prompt here denoted by X is a series of tokens,
00:16:44.000 | and it's X1 to XK, and they all belong to this set chi,
00:16:48.000 | which is the vocabulary size of the language model.
00:16:51.000 | Thereafter, it would generate a response Y,
00:16:54.000 | which is over in the second line here, Y, it's a vector Y,
00:16:58.000 | where it consists of multiple tokens, like say M tokens,
00:17:01.000 | and again, they belong to the set Y,
00:17:03.000 | which also could be the vocabulary size,
00:17:06.000 | but it's actually also the probability distribution,
00:17:08.000 | which is the third line here where each token, YJ,
00:17:12.000 | is a probability of the conditional probability
00:17:16.000 | of the input prompts, which is vector X,
00:17:20.000 | and all the previous earlier outputs, Y1 all the way to YJ-1.
00:17:25.000 | So that's how we set up the problem,
00:17:27.000 | where it's just X and Ys.
00:17:29.000 | Typically then, if you use your language model,
00:17:32.000 | we're not just using it to do completions,
00:17:35.000 | we actually want to then use it for downstream tasks
00:17:38.000 | like Q&A, MCQ, translations,
00:17:40.000 | and here we actually have a scoring function,
00:17:43.000 | so this could be BLEU, and so here,
00:17:47.000 | it's actually this function S,
00:17:50.000 | where it sort of takes in the true Y,
00:17:53.000 | like let's say the true answer,
00:17:55.000 | and then models generate the answer Y,
00:17:58.000 | and then it maps it to a 0 to 1,
00:18:00.000 | so it's just any generic scoring function
00:18:03.000 | that you can think of.
00:18:04.000 | So then the task of uncertainty estimation
00:18:07.000 | is effectively learning a function G,
00:18:10.000 | so in this third step over here,
00:18:12.000 | it actually then sort of predicts the score
00:18:16.000 | given the input prompt and the output response.
00:18:20.000 | So given the prompt that I feed into the language model,
00:18:24.000 | and the response that the language model gives,
00:18:26.000 | can I then predict how good that answer is?
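
To recap the setup just described in notation (a compact restatement of the speaker's description, not copied verbatim from the paper):

```latex
% Prompt and response
\mathbf{x} = (x_1, \dots, x_k),\; x_i \in \mathcal{X}, \qquad
\mathbf{y} = (y_1, \dots, y_m),\; y_j \sim p\!\left(\,\cdot \mid \mathbf{x}, y_1, \dots, y_{j-1}\right)

% Task-specific score of the generated answer against the true answer
s\!\left(\mathbf{y}^{\mathrm{true}}, \mathbf{y}\right) \in [0, 1]

% Uncertainty estimation: learn g that predicts the score from prompt and response
g(\mathbf{x}, \mathbf{y}) \approx s\!\left(\mathbf{y}^{\mathrm{true}}, \mathbf{y}\right)
```
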
00:18:30.000 | So then the paper then explains
00:18:32.000 | how they can apply this approach
00:18:34.000 | to all sorts of language models,
00:18:36.000 | white box language models where you have the weights,
00:18:38.000 | grey box language models where you don't have the weights,
00:18:41.000 | and maybe you have more details about output,
00:18:43.000 | like say the probabilities of the various tokens,
00:18:47.000 | the log props, or completely black box models,
00:18:51.000 | like maybe an API provider that doesn't give
00:18:54.000 | any sort of indication of what goes on behind the scenes,
00:18:57.000 | and it just only gives a final check completion.
00:19:00.000 | So let's just share a bit more
00:19:02.000 | about the white box language model,
00:19:03.000 | because these sort of methods thereafter extend from here.
00:19:08.000 | So the sort of, for the white box language models,
00:19:12.000 | we sort of want to first build a data set
00:19:14.000 | that is called d_raw,
00:19:16.000 | where it consists of four things.
00:19:19.000 | Number one, the input prompt, which is x_i,
00:19:22.000 | the output, the answer that was generated
00:19:24.000 | by the language model y_i,
00:19:26.000 | we want the true answer, which is y_i_true,
00:19:30.000 | and lastly, the fourth item, which is the score,
00:19:33.000 | the evaluation of that answer,
00:19:35.000 | vis-a-vis the answer that the model generated.
00:19:39.000 | And notice how if I give the same x_i,
00:19:42.000 | I might get different y_i's
00:19:43.000 | because of the probabilistic nature of a language model,
00:19:45.000 | so that actually gives me more training data as well.
00:19:49.000 | So then from each sort of data,
00:19:52.000 | each row entry in this raw data set,
00:19:55.000 | I want to extract out features
00:19:57.000 | to construct what they call the uncertainty data set.
00:19:59.000 | So the uncertainty data set over here
00:20:02.000 | is denoted by d_un,
00:20:05.000 | is a tuple of your v_i and your scoring,
00:20:09.000 | and v_i here is a vector of selected features.
00:20:13.000 | So there are billions of parameters in a language model,
00:20:16.000 | and what they suggest is
00:20:19.000 | to use the hidden layers of the activation,
00:20:22.000 | and then for this experiment,
00:20:24.000 | they actually used the activations
00:20:25.000 | from the middle layer and the last layer.
00:20:27.000 | They even suggest other ways of getting features,
00:20:31.000 | such as directly asking the model to,
00:20:35.000 | like in the input prompt,
00:20:37.000 | appending the term, appending the phrase,
00:20:40.000 | "How certain are you about the response?"
00:20:42.000 | and then trying to get the activations
00:20:44.000 | with respect to that prompt.
00:20:46.000 | They suggest that as a method
00:20:48.000 | that other people have used,
00:20:49.000 | but I don't think they implemented it
00:20:51.000 | in their approach when I read through the paper and the code.
00:20:54.000 | So once you have the uncertainty data set,
00:20:57.000 | which to recap is just
00:20:59.000 | a supervised learning example,
00:21:01.000 | where you're just a series of feature vectors,
00:21:03.000 | and then a score,
00:21:04.000 | you then train any other good old-fashioned
00:21:06.000 | supervised learning model to predict that score.
00:21:09.000 | And then once you have that trained model,
00:21:11.000 | which is now in 0.4 of what we see here,
00:21:13.000 | I can use that inference time
00:21:15.000 | where I have an input text,
00:21:17.000 | I feed it through the LLM,
00:21:18.000 | I extract the features,
00:21:19.000 | and I use my trained machine learning model
00:21:21.000 | to then predict the uncertainty score.
00:21:24.000 | So they do spend a bit of time in the paper,
00:21:26.000 | actually only in the appendix,
00:21:28.000 | so this is the kind of paper where you read the appendix
00:21:30.000 | to actually get the algorithm,
00:21:31.000 | how they actually implemented it.
00:21:33.000 | So they use 320 features.
00:21:36.000 | So these 320 features consist of 20 features
00:21:40.000 | that they use from the gray box LLM,
00:21:42.000 | which I'll explain later,
00:21:43.000 | and another 300 features of which
00:21:45.000 | they get 100 by running a LASSO regression,
00:21:48.000 | 100 by calculating the mutual information,
00:21:51.000 | and 100 by calculating the correlation coefficient,
00:21:54.000 | and then they train a random forest regressor
00:21:56.000 | with these 320 features.
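
A rough sketch of that pipeline with scikit-learn, assuming you already have a matrix of last-token hidden activations and the task-specific scores; the array sizes, Lasso alpha, and forest settings are placeholders rather than the paper's exact configuration, and the 20 grey-box features would be concatenated in as well.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import mutual_info_regression
from sklearn.linear_model import Lasso

# Assumed inputs: hidden-state activations at the last token (random stand-ins here,
# with a reduced hidden size for speed) and the task-specific scores in [0, 1].
rng = np.random.default_rng(0)
activations = rng.normal(size=(1000, 512))
scores = rng.uniform(size=1000)

def top_k_indices(values, k=100):
    # Indices of the k entries with the largest absolute magnitude.
    return np.argsort(np.abs(values))[-k:]

# Three feature-selection passes, 100 features each, as described above.
lasso_idx = top_k_indices(Lasso(alpha=0.01, max_iter=5000).fit(activations, scores).coef_)
mi_idx = top_k_indices(mutual_info_regression(activations, scores))
corr = np.array([np.corrcoef(activations[:, j], scores)[0, 1] for j in range(activations.shape[1])])
corr_idx = top_k_indices(np.nan_to_num(corr))

selected = np.unique(np.concatenate([lasso_idx, mi_idx, corr_idx]))

# Random forest regressor mapping the selected activation features to the score.
uncertainty_model = RandomForestRegressor(n_estimators=200, random_state=0)
uncertainty_model.fit(activations[:, selected], scores)

# Inference time: extract the same features for a new prompt/response pair.
print(uncertainty_model.predict(activations[:1, selected]))
```
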
00:21:58.000 | You know, actually, I don't think
00:22:00.000 | they fine tuned,
00:22:03.000 | I mean, optimized the hyperparameters
00:22:07.000 | for this regression model,
00:22:08.000 | but it wasn't the point of that paper,
00:22:10.000 | it was just to demonstrate that it's possible,
00:22:11.000 | so it's probably an area of work
00:22:14.000 | if people are interested to tinker.
00:22:16.000 | So that's what they do.
00:22:17.000 | To recap, prompt, feed language model,
00:22:21.000 | get the activations,
00:22:22.000 | those activations, select a few,
00:22:24.000 | train another regression model.
00:22:26.000 | The second thing that they explain
00:22:28.000 | is how you can then incorporate this
00:22:30.000 | for gray box language models.
00:22:32.000 | So here, they actually then just
00:22:34.000 | come up with 20 features
00:22:36.000 | related to the output probabilities
00:22:39.000 | from both the response,
00:22:41.000 | but also the question or the prompt,
00:22:43.000 | and it's sort of covered there in the paper.
00:22:45.000 | It's not too interesting,
00:22:46.000 | but what I thought was somewhat interesting
00:22:49.000 | was how they could get it from 20 features.
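
The transcript doesn't list the exact 20 features, but they are statistics derived from the output token probabilities; something along these lines (the specific statistics here are illustrative stand-ins, not the paper's list):

```python
import numpy as np

def grey_box_features(token_logprobs: list[float]) -> dict:
    """Summary statistics over the per-token log probabilities of a response.

    Illustrative stand-ins for the kind of probability-based features the
    paper describes for grey-box models, not the exact feature list.
    """
    lp = np.array(token_logprobs)
    p = np.exp(lp)
    return {
        "mean_logprob": lp.mean(),
        "min_logprob": lp.min(),
        "max_logprob": lp.max(),
        "std_logprob": lp.std(),
        "mean_prob": p.mean(),
        "perplexity": float(np.exp(-lp.mean())),
    }

print(grey_box_features([-0.1, -0.5, -2.3, -0.05]))
```
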
00:22:52.000 | Then the last category is for the
00:22:54.000 | black box language models.
00:22:55.000 | So it's stuff like, let's say,
00:22:56.000 | OpenAI or Anthropic.
00:22:58.000 | How do you actually get some
00:23:00.000 | uncertainty estimates from there?
00:23:02.000 | So that's what I thought was a bit clever,
00:23:03.000 | where they take an input prompt
00:23:06.000 | and they feed it to the proprietary,
00:23:07.000 | let's say, the black box model,
00:23:08.000 | and they get the output response.
00:23:10.000 | At the same time,
00:23:11.000 | they take that original input prompt
00:23:12.000 | and they feed it through a white box model,
00:23:15.000 | say, Llama 3,
00:23:16.000 | and get the activations of Llama 3,
00:23:19.000 | and they're trying to map the activations
00:23:22.000 | in Llama 3 to the outputs
00:23:24.000 | of a proprietary model,
00:23:26.000 | and somehow that seems to work.
00:23:28.000 | In the paper, they do the experiments
00:23:30.000 | with Llama 7B and Gemma...
00:23:32.000 | I think Llama 2 7B and Gemma 7B
00:23:35.000 | as black boxes,
00:23:36.000 | and by using the other open source model
00:23:38.000 | as the white box
00:23:39.000 | for the uncertainty estimation.
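
In other words, the label still comes from scoring the black-box model's answer, while the features come from running the same prompt through an open white-box model. A tiny sketch of building one training row under that assumption (all names here are illustrative):

```python
def black_box_training_row(prompt, y_true, black_box_generate, white_box_features, score_fn):
    """One (features, score) pair for the black-box setting.

    black_box_generate: prompt -> answer from the proprietary model
    white_box_features: prompt -> activation feature vector from an open model
    score_fn: task-specific scorer returning a value in [0, 1]
    """
    answer = black_box_generate(prompt)        # label source: the black-box answer
    features = white_box_features(prompt)      # feature source: white-box activations
    return features, score_fn(y_true, answer)
```
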
00:23:42.000 | So just going into the results,
00:23:44.000 | they cover three tasks,
00:23:45.000 | Q&A, MCQ, and translation,
00:23:47.000 | and they give the results here,
00:23:49.000 | and they compare it
00:23:50.000 | against other methods
00:23:52.000 | for getting uncertainty estimation
00:23:55.000 | because what you want
00:23:56.000 | is not just an answer.
00:23:58.000 | You want something like a probability
00:24:01.000 | between 0 and 1,
00:24:02.000 | and here what the figures you see
00:24:03.000 | are actually the AUC scores.
00:24:05.000 | So then they demonstrate
00:24:07.000 | that their method is better
00:24:09.000 | in getting a higher AUC
00:24:10.000 | for these various tasks
00:24:11.000 | for Q&A and translation.
00:24:13.000 | So the AUC of the score
00:24:16.000 | vis-à-vis the answer being correct,
00:24:18.000 | binary, yes or no correct.
00:24:22.000 | So just to give an example
00:24:23.000 | of what this looks like,
00:24:24.000 | so let's say over here
00:24:25.000 | in this screenshot,
00:24:26.000 | the question is,
00:24:27.000 | "What musical featured the songs
00:24:29.000 | 'A Secretary Is Not a Toy'
00:24:30.000 | and 'The Company Way'?"
00:24:31.000 | I have no idea what that is,
00:24:32.000 | but apparently the answer is,
00:24:33.000 | "How to succeed in business
00:24:35.000 | without really trying?"
00:24:36.000 | And if you look at the table over here,
00:24:38.000 | if you were to take
00:24:39.000 | some sort of greedy approach
00:24:40.000 | and take the max probability,
00:24:44.000 | I believe, yeah,
00:24:46.000 | you would sort of get a--
00:24:49.000 | maybe it's the wrong answer,
00:24:51.000 | but yeah, you would get
00:24:52.000 | a probability of--
00:24:53.000 | you get a confidence score of 0.9,
00:24:55.000 | but if you look at, say,
00:24:56.000 | like the white box method,
00:24:57.000 | you would see that
00:24:58.000 | the correct answer
00:24:59.000 | over in this column,
00:25:00.000 | the WB dash S,
00:25:02.000 | the greedy answer gets
00:25:03.000 | a confidence score of 0.14,
00:25:05.000 | but answer one,
00:25:06.000 | which is the correct answer,
00:25:07.000 | gets a confidence score of 0.22.
00:25:10.000 | So in this case,
00:25:11.000 | they found that
00:25:12.000 | this confidence score approach
00:25:13.000 | could actually yield
00:25:15.000 | more accurate answers.
00:25:17.000 | So this is just an example.
00:25:19.000 | Just wrapping up quite soon.
00:25:21.000 | So remember how I said
00:25:22.000 | that what they're doing
00:25:23.000 | is training a regression model
00:25:24.000 | on the hidden activations?
00:25:27.000 | And I said that they were using
00:25:28.000 | the middle layer and the last layer.
00:25:30.000 | They sort of found that
00:25:31.000 | actually you get better performance
00:25:32.000 | when you're using the middle layer,
00:25:34.000 | and they apparently cite this thing
00:25:36.000 | from the literature,
00:25:37.000 | suggesting that actually
00:25:38.000 | the middle layer
00:25:39.000 | of these language models
00:25:40.000 | are better at summarization.
00:25:42.000 | So just to quote,
00:25:43.000 | this may come from the fact
00:25:44.000 | that the last layer
00:25:45.000 | focuses more on generation
00:25:47.000 | of the next token
00:25:48.000 | instead of summarizing information
00:25:49.000 | of the whole sentence
00:25:50.000 | as discussed by other authors.
00:25:53.000 | So that was new information for me.
00:25:55.000 | I thought it was quite interesting.
00:25:57.000 | So last thing is that
00:25:59.000 | actually there's right now
00:26:00.000 | a Kaggle competition
00:26:01.000 | where they're trying to predict
00:26:02.000 | how people rank
00:26:04.000 | the chatbot responses
00:26:06.000 | on the arena.
00:26:07.000 | So I thought that was
00:26:08.000 | an interesting sort of synergy here
00:26:09.000 | with that paper,
00:26:10.000 | and that was pretty cool.
00:26:12.000 | And if sort of we can sort of
00:26:14.000 | automate some of the eval work
00:26:15.000 | and more from my perspective,
00:26:17.000 | where I build a large language model
00:26:20.000 | powered application
00:26:21.000 | and I deploy it,
00:26:22.000 | I also wanted to know
00:26:23.000 | is my large language model
00:26:24.000 | actually doing well in real time
00:26:26.000 | and maybe something like this
00:26:27.000 | could come in useful.
00:26:28.000 | So yeah, thanks for listening.
00:26:33.000 | Awesome, dude.
00:26:34.000 | That was a great presentation.
00:26:35.000 | Honestly,
00:26:36.000 | a super interesting paper.
00:26:39.000 | I think there were two questions
00:26:41.000 | inside the chat itself.
00:26:43.000 | First one was from Oman,
00:26:44.000 | which was,
00:26:45.000 | how is measuring the uncertainty
00:26:46.000 | related to the concept
00:26:48.000 | of a well-calibrated model?
00:26:50.000 | Al does, yeah.
00:26:51.000 | Yeah.
00:26:52.000 | So the paper does cover
00:26:53.000 | the concept of calibration.
00:26:54.000 | So maybe to just share,
00:26:57.000 | to get everyone up to speed
00:26:58.000 | about what calibration is.
00:26:59.000 | So a model is considered
00:27:01.000 | well-calibrated.
00:27:02.000 | Let me give an example,
00:27:03.000 | a well-calibrated model.
00:27:05.000 | So let's say the model says
00:27:06.000 | it's going to rain 40% of the time.
00:27:09.000 | The probability of it raining
00:27:11.000 | is 0.4,
00:27:12.000 | then 40% of the time it would rain.
00:27:14.000 | So calibration
00:27:17.000 | sort of means
00:27:19.000 | that the numerical score
00:27:20.000 | that the model produces
00:27:22.000 | behaves like a probability.
00:27:26.000 | So in this case,
00:27:27.000 | they do sort of mention
00:27:29.000 | calibration in the paper
00:27:30.000 | and they effectively are saying
00:27:32.000 | that number one,
00:27:34.000 | a model that has
00:27:35.000 | a good confidence score
00:27:36.000 | is more likely to be
00:27:37.000 | well-calibrated.
00:27:39.000 | And number two,
00:27:40.000 | that usual methods
00:27:41.000 | of calibration,
00:27:42.000 | like I think isotonic regression
00:27:44.000 | or I believe some other
00:27:45.000 | binning method can be applied.
00:27:46.000 | And they do have
00:27:47.000 | some numerical results
00:27:48.000 | on calibration as well
00:27:49.000 | in the paper.
00:27:53.000 | Well, I guess,
00:27:54.000 | would it be right to say
00:27:55.000 | that then a well-calibrated model
00:27:57.000 | is a model that performs
00:27:58.000 | well on the task
00:27:59.000 | that you've assigned it
00:28:00.000 | or is it slightly more complex than that?
00:28:03.000 | Yes and no.
00:28:04.000 | Yes, it effectively means that.
00:28:06.000 | Well, you can just think of it as
00:28:08.000 | a model that's calibrated,
00:28:09.000 | it's a model that
00:28:11.000 | gives a probability
00:28:12.000 | just because a model,
00:28:13.000 | so just because a model
00:28:14.000 | spits you a number
00:28:15.000 | between zero and one
00:28:16.000 | doesn't mean it's a probability.
00:28:18.000 | So for the folks
00:28:19.000 | that are not too familiar with this,
00:28:21.000 | you can go to the Scikit-learn
00:28:22.000 | documentation and it'll give
00:28:24.000 | an example of calibration,
00:28:26.000 | of how you can calibrate
00:28:27.000 | sort of scores to become
00:28:28.000 | more like probabilities.
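
Since the scikit-learn docs are mentioned, here is a minimal sketch of isotonic calibration: fit a monotone mapping from raw confidence scores to observed correctness rates so the output behaves more like a probability. The numbers below are made up for illustration.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression

# Raw confidence scores from the uncertainty model, and whether each
# answer was actually correct (1) or not (0).
raw_scores = np.array([0.2, 0.4, 0.5, 0.7, 0.9, 0.95])
was_correct = np.array([0, 0, 1, 1, 1, 1])

# Fit a monotone mapping from raw scores to empirical correctness rates,
# so that the calibrated output behaves more like a probability.
calibrator = IsotonicRegression(out_of_bounds="clip").fit(raw_scores, was_correct)
print(calibrator.predict(np.array([0.3, 0.8])))
```
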
00:28:33.000 | I see.
00:28:34.000 | I think there are a few more questions.
00:28:36.000 | If it's okay with you.
00:28:38.000 | No, I'm sorry, I was reading it.
00:28:40.000 | Yeah, sorry.
00:28:41.000 | How does the sort of
00:28:42.000 | white box model transfer
00:28:43.000 | to the black box model?
00:28:44.000 | I am a bit doubtful as well.
00:28:46.000 | Ah yes, so Warren, yeah.
00:28:48.000 | So it doesn't mean, yeah,
00:28:49.000 | initially it told me
00:28:50.000 | I want to understand that
00:28:51.000 | because, yeah,
00:28:52.000 | how can you use one model
00:28:53.000 | for another model
00:28:54.000 | given that the architectures
00:28:55.000 | are different?
00:28:56.000 | So I think what they're,
00:28:57.000 | all they're doing is,
00:28:59.000 | okay, for let's say OpenAI,
00:29:02.000 | I run my prompt through OpenAI
00:29:04.000 | and I get a result
00:29:05.000 | and I keep that result.
00:29:07.000 | At the same time,
00:29:08.000 | I take my original prompt
00:29:09.000 | and I run it through Llama 3 70B
00:29:11.000 | or whatever,
00:29:12.000 | and then I get the activations
00:29:14.000 | of Llama 3 70B
00:29:16.000 | and then I just then use
00:29:17.000 | those activations as the input
00:29:19.000 | for my secondary model
00:29:21.000 | to predict the uncertainty.
00:29:24.000 | So they're trying to like
00:29:25.000 | sort of shortcut their way
00:29:26.000 | into predicting,
00:29:27.000 | to getting a sort of,
00:29:29.000 | they're effectively using,
00:29:30.000 | if I understand,
00:29:31.000 | if I sort of see it intuitively,
00:29:32.000 | is that they're using
00:29:33.000 | the open source
00:29:34.000 | or the white box model
00:29:35.000 | as a way to create
00:29:36.000 | some sort of representation
00:29:37.000 | of the prompt
00:29:38.000 | and then in any case,
00:29:39.000 | they are using
00:29:40.000 | a secondary downstream model
00:29:42.000 | to learn that mapping
00:29:43.000 | of how Llama 3 70B
00:29:46.000 | would map the prompt
00:29:48.000 | into the outputs
00:29:49.000 | of OpenAI's models, yeah.
00:29:54.000 | So for Caspar's question
00:29:55.000 | about how this would compare
00:29:57.000 | to methods just using log probs,
00:29:59.000 | that's another thing
00:30:00.000 | that took me a while
00:30:01.000 | to eventually realize
00:30:02.000 | is that if let's say
00:30:03.000 | I do log probs,
00:30:04.000 | it's sort of predicting
00:30:05.000 | the probability
00:30:06.000 | of the next token,
00:30:07.000 | but the probability
00:30:08.000 | of the next token
00:30:09.000 | isn't necessarily
00:30:10.000 | the probability
00:30:11.000 | that my answer is correct
00:30:12.000 | because it's task specific.
00:30:15.000 | So let's say I want
00:30:16.000 | to do a question and answer
00:30:17.000 | and then I get a probability
00:30:18.000 | of like this is a MacBook.
00:30:20.000 | You can get some probability
00:30:21.000 | for that statement,
00:30:22.000 | but I don't get a probability
00:30:23.000 | for whether the statement
00:30:24.000 | this is a MacBook
00:30:25.000 | is the correct statement
00:30:26.000 | for my given task, yeah.
00:30:30.000 | And next one,
00:30:32.000 | for the mathematical models
00:30:33.000 | that measure uncertainty,
00:30:34.000 | is it dependent
00:30:35.000 | on the loss function?
00:30:36.000 | For example,
00:30:37.000 | when comparing
00:30:38.000 | a simple model train
00:30:39.000 | on L2 loss versus one
00:30:40.000 | that goes through
00:30:41.000 | additional complexity
00:30:42.000 | like RLHF model-based scoring,
00:30:44.000 | does the mathematical theory
00:30:45.000 | still hold?
00:30:46.000 | Okay, I'm not too sure,
00:30:47.000 | but there was a part
00:30:48.000 | of this paper
00:30:49.000 | that did have some theory
00:30:50.000 | about how given
00:30:52.000 | certain conditions,
00:30:54.000 | I believe like this problem,
00:30:58.000 | the optimal solution
00:31:00.000 | is your optimal Bayes classifier,
00:31:02.000 | but there's certain
00:31:03.000 | conditional independence properties
00:31:05.000 | that aren't satisfied
00:31:06.000 | and also because if I recall,
00:31:08.000 | actually I have the paper open here.
00:31:10.000 | Yeah, so if you,
00:31:11.000 | because you don't,
00:31:12.000 | when we train a language model,
00:31:14.000 | the loss function
00:31:15.000 | is actually a different
00:31:16.000 | loss function.
00:31:18.000 | Over here.
00:31:21.000 | Sorry.
00:31:22.000 | Yeah, over here.
00:31:23.000 | So they basically say that,
00:31:28.000 | yeah, so because
00:31:30.000 | if the large language models
00:31:32.000 | aren't trained
00:31:33.000 | on the cross-entropy loss,
00:31:34.000 | the sort of theorem
00:31:35.000 | doesn't hold,
00:31:36.000 | but they do say, like,
00:31:37.000 | oh, because large language models
00:31:39.000 | are trained on the larger data
00:31:40.000 | and so on and so forth,
00:31:41.000 | they do,
00:31:42.000 | there's apparently some,
00:31:44.000 | what do you call that,
00:31:46.000 | approximations that can be done.
00:31:47.000 | Yeah, so I didn't fully understand
00:31:49.000 | section 3.3,
00:31:50.000 | but maybe you could take a read
00:31:52.000 | and see if it helps.
00:31:54.000 | So yeah, it's trying to justify
00:31:56.000 | why they use hidden layers as features.
00:31:58.000 | Yeah.
00:31:59.000 | All right.
00:32:02.000 | All right.
00:32:03.000 | That's about it.
00:32:04.000 | Thanks, everyone.
00:32:05.000 | Thanks for presenting, dude.
00:32:06.000 | Yeah, I think next is Casper
00:32:08.000 | with monosemanticity,
00:32:10.000 | I think.
00:32:14.000 | All right.
00:32:15.000 | Do you mind?
00:32:17.000 | Okay.
00:32:21.000 | I promise it's not just a
00:32:25.000 | but let me just get through
00:32:27.000 | some basic stuff first
00:32:28.000 | before we go through
00:32:29.000 | some interesting dashboards
00:32:31.000 | and visualizations.
00:32:34.000 | So I'm covering the paper
00:34:35.000 | Towards Monosemanticity,
00:32:38.000 | which Anthropic released,
00:32:41.000 | I believe it was late last year.
00:32:45.000 | And basically the primary
00:32:47.000 | contribution of this paper
00:32:49.000 | was demonstrating
00:32:51.000 | that it's possible to use
00:32:54.000 | sparse autoencoders to identify
00:32:56.000 | features in a language model.
00:32:59.000 | So that's really fuzzy,
00:33:01.000 | but maybe just to
00:33:03.000 | level set and make sure everyone
00:33:05.000 | has an idea of what I'm talking about here.
00:33:08.000 | You know, this is
00:33:09.000 | probably one of the most important works
00:33:11.000 | in the field of mechanistic interpretability
00:33:13.000 | or mechinterp
00:33:15.000 | over the past sort of few years.
00:33:19.000 | And so what is mechanistic interpretability?
00:33:22.000 | It's basically making the inner workings
00:33:24.000 | of a neural network human interpretable.
00:33:27.000 | And this is in contrast
00:33:29.000 | to other interpretability
00:33:31.000 | or AI explainability approaches,
00:33:34.000 | which take more of a behavioral approach.
00:33:37.000 | If you think of like
00:33:39.000 | behaviorists in psychology,
00:33:42.000 | that's often the approach that's taken
00:33:44.000 | by some researchers.
00:33:46.000 | Mechinterp, on the other hand,
00:33:48.000 | is very much concerned with
00:33:50.000 | the activations within a model
00:33:53.000 | and figuring out at a very sort of
00:33:56.000 | granular level
00:33:58.000 | how it is a model is working.
00:34:02.000 | So really what we're trying to do here
00:34:05.000 | or what mechinterp is trying to do is
00:34:07.000 | trying to find features, right?
00:34:10.000 | And, you know, it's a bit of a fuzzy definition,
00:34:13.000 | but think of a feature as sort of
00:34:15.000 | a property of a token or a group of tokens.
00:34:19.000 | You know, for example,
00:34:21.000 | there might be a feature that fires
00:34:24.000 | on pronouns where
00:34:28.000 | and this might be an attention feature
00:34:31.000 | and it,
00:34:36.000 | you know, that feature firing,
00:34:38.000 | that feature activation might signify that,
00:34:41.000 | you know, the pronoun is attending
00:34:43.000 | to another proper noun
00:34:46.000 | somewhere in the sequence.
00:34:48.000 | Or, you know, it could be a
00:34:50.000 | sort of more fuzzy sort of feature where
00:34:53.000 | if this feature is firing,
00:34:55.000 | then the sort of sequence
00:34:58.000 | sounds angry, right?
00:35:00.000 | And there's all sorts of features.
00:35:02.000 | Basically, you know, if you think about
00:35:03.000 | what a model is capable of,
00:35:04.000 | what it can represent,
00:35:06.000 | you should have a feature underlying
00:35:08.000 | all of those capabilities, right?
00:35:11.000 | Now, the problem here is that
00:35:14.000 | it's actually really hard to identify features
00:35:19.000 | because models are so big.
00:35:21.000 | And in addition to that,
00:35:24.000 | you don't have nice clean features where
00:35:27.000 | one feature corresponds to a single neuron.
00:35:31.000 | And the sort of hypothesis is that
00:35:34.000 | it's because a model represents
00:35:36.000 | far more features than there are neurons
00:35:39.000 | available in the model.
00:35:41.000 | If you think about, you know,
00:35:42.000 | every single feature that you'd want to represent
00:35:45.000 | to represent the world,
00:35:48.000 | then it's sort of like,
00:35:52.000 | it's sort of obvious that a model can't be
00:35:54.000 | large enough such that you have one neuron
00:35:57.000 | that corresponds to each feature, right?
00:36:03.000 | But empirically, you know,
00:36:05.000 | it's been shown that it is possible
00:36:06.000 | to identify features.
00:36:08.000 | And this was initially done through
00:36:10.000 | really hard manual work
00:36:13.000 | and just, you know, people eyeballing
00:36:16.000 | feature activations.
00:36:20.000 | And broadly, this approach of figuring out
00:36:23.000 | what features a model represents is
00:36:26.000 | sometimes also called dictionary learning.
00:36:29.000 | So with that sort of background,
00:36:31.000 | you know, maybe I'll jump into the paper.
00:36:33.000 | So Towards Monosemanticity was
00:36:37.000 | one of the first papers to
00:36:40.000 | really introduce and demonstrate that
00:36:42.000 | sparse autoencoders
00:36:45.000 | appear to be a pretty effective
00:36:47.000 | dictionary learning method.
00:36:50.000 | And what was done in this paper
00:36:52.000 | is that the authors at Anthropic
00:36:56.000 | trained sparse autoencoders to reconstruct
00:36:59.000 | the outputs of an MLP layer.
00:37:02.000 | But in this reconstruction,
00:37:04.000 | there's two things added.
00:37:05.000 | One is a sparsity penalty to remove noise
00:37:08.000 | and find, you know, more interpretable
00:37:11.000 | feature activations.
00:37:13.000 | And two, an expansion factor,
00:37:18.000 | which makes it so that you can --
00:37:25.000 | it's easier to sort of represent more features
00:37:28.000 | or identify more features using the SAE.
00:37:31.000 | And the way they did that in the paper
00:37:34.000 | was using a toy model with a single layer.
00:37:39.000 | So a single layer MLP.
00:37:43.000 | And they took the sort of hidden representations
00:37:45.000 | from the MLP and trained SAEs
00:37:48.000 | with a range of expansion factors
00:37:50.000 | from 1x to 256x.
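
A minimal sketch of the sparse autoencoder setup described here: reconstruct MLP activations through a wider hidden layer (the expansion factor) with an L1 sparsity penalty on the feature activations. Details such as decoder weight normalization and bias handling from the paper are omitted, and the hyperparameters are placeholders.

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    def __init__(self, d_mlp: int, expansion: int = 8, l1_coeff: float = 1e-3):
        super().__init__()
        d_hidden = d_mlp * expansion          # expansion factor: more SAE features than neurons
        self.encoder = nn.Linear(d_mlp, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_mlp)
        self.l1_coeff = l1_coeff

    def forward(self, activations: torch.Tensor):
        features = torch.relu(self.encoder(activations))   # sparse feature activations
        reconstruction = self.decoder(features)
        recon_loss = (reconstruction - activations).pow(2).mean()
        sparsity_loss = self.l1_coeff * features.abs().sum(dim=-1).mean()
        return reconstruction, features, recon_loss + sparsity_loss

# Usage: collect MLP-layer activations from the language model, then train the
# SAE to reconstruct them; the learned feature directions form the "dictionary".
sae = SparseAutoencoder(d_mlp=512, expansion=8)
dummy_acts = torch.randn(32, 512)
_, features, loss = sae(dummy_acts)
loss.backward()
```
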
00:37:54.000 | So, you know, I think --
00:38:00.000 | now that I've gone through that,
00:38:01.000 | let me actually just skip to --
00:38:03.000 | let me actually share my other screen
00:38:09.000 | where I've actually pulled up the dashboards
00:38:13.000 | of the SAEs, right?
00:38:15.000 | And then you can actually get a sense
00:38:16.000 | of what's actually being identified
00:38:19.000 | by these SAEs, right?
00:38:21.000 | So, you know, one example feature
00:38:24.000 | which they talk about in the paper
00:38:25.000 | is a feature that happens to fire
00:38:29.000 | when the model thinks that something
00:38:34.000 | is a DNA sequence in lowercase, right?
00:38:38.000 | And that looks pretty accurate.
00:38:41.000 | You know, what's interesting about these features
00:38:43.000 | is that if you look at what they do
00:38:49.000 | in terms of affecting downstream output
00:38:53.000 | is that if this feature fires on any single token,
00:38:57.000 | it has a very strong impact
00:39:00.000 | in terms of up-weighting subsequent tokens
00:39:05.000 | which are also, you know, DNA-like, right?
00:39:10.000 | You know, you have other sort of --
00:39:13.000 | there's all sorts of features.
00:39:14.000 | Some of them aren't so interpretable.
00:39:19.000 | Some of them are.
00:39:21.000 | You know, you'll have some that fire
00:39:22.000 | on some languages.
00:39:23.000 | You'll have some that fire
00:39:24.000 | on Caesar-shifted encoded words.
00:39:27.000 | And I'd encourage you to actually look
00:39:29.000 | at the visualization.
00:39:31.000 | It's linked in the Towards Monosemanticity paper
00:39:35.000 | and super interesting.
00:39:39.000 | Maybe just to give you more --
00:39:41.000 | you know, another sort of --
00:39:44.000 | give you a sense of what other work is going on.
00:39:47.000 | So that work was done on a toy model,
00:39:49.000 | a single layer.
00:39:50.000 | It's not a real language model.
00:39:54.000 | But on the open-source side,
00:39:55.000 | there's people training SAEs now
00:39:58.000 | using the same techniques
00:40:02.000 | on all sorts of models.
00:40:04.000 | GPT-2-small is a favorite
00:40:06.000 | just because it's very well understood
00:40:08.000 | by interpretability researchers.
00:40:10.000 | And you can see all sorts of features
00:40:12.000 | identified in GPT-2-small.
00:40:14.000 | And it's actually super easy to train an SAE now.
00:40:17.000 | You can just use an off-the-shelf library
00:40:20.000 | and play around with it.
00:40:24.000 | But maybe just going back to, like,
00:40:27.000 | you know, why --
00:40:32.000 | going back to, you know,
00:40:36.000 | why this matters and what the limitations are
00:40:38.000 | and where we're headed in the future.
00:40:40.000 | Let me just switch back to my other screen.
00:40:43.000 | And here we go.
00:40:50.000 | You know, so --
00:40:52.000 | maybe let me just talk about limitations first, right?
00:40:54.000 | This is a pretty new --
00:40:56.000 | it's a pretty new technique.
00:40:57.000 | Less than a year old now, really.
00:41:00.000 | At least in the context of mechinterp.
00:41:03.000 | You know, it's pretty clear
00:41:05.000 | that the feature set identified by SAEs isn't complete.
00:41:08.000 | There's rare features.
00:41:10.000 | There's features that might get ignored
00:41:12.000 | for some reason or another
00:41:13.000 | due to training methods not being that sort of refined.
00:41:21.000 | You know, you also have this limitation
00:41:23.000 | where as you increase the expansion factor in the SAE,
00:41:28.000 | you get this phenomenon of feature splitting
00:41:30.000 | where, you know, what was previously a single feature
00:41:34.000 | now becomes a family of, like, 20 features, for example,
00:41:39.000 | all of which are, like, sub-features.
00:41:41.000 | And it gets a bit complicated.
00:41:43.000 | So you don't get, like,
00:41:44.000 | these very nice, clean representations necessarily.
00:41:47.000 | Everything's sort of fuzzy.
00:41:50.000 | But there is work going towards, like, improving them, right?
00:41:53.000 | And there's quite a lot of work on improving SAEs.
00:41:57.000 | This is now seen as the most promising line
00:42:00.000 | of mechinterp work within both Anthropic and DeepMind.
00:42:06.000 | And there's also folks at OpenAI working on SAEs.
00:42:09.000 | And OpenAI actually open-sourced their SAEs for GPT-2 small,
00:42:13.000 | which is pretty interesting.
00:42:16.000 | But, yeah, you know, I don't think this is --
00:42:18.000 | this is, like, very much like an exploratory paper, right?
00:42:20.000 | It's sort of like --
00:42:22.000 | I think the sort of point is that, hey,
00:42:25.000 | there's this new technique.
00:42:27.000 | It looks pretty interesting.
00:42:29.000 | But the jury is still out on whether this ends up being,
00:42:32.000 | you know, the technique that solves mechinterp.
00:42:37.000 | But, yeah, that's it.
00:42:42.000 | >> Awesome. Thanks for the presentation.
00:42:44.000 | I think Warren had two questions.
00:42:46.000 | What is the difference between SAE versus a probing classifier?
00:42:50.000 | I think the other one he was saying is,
00:42:52.000 | what makes SAEs such a good fit for this specific task?
00:42:56.000 | Why not use other models to extract monosemanticity?
00:43:00.000 | >> Yeah, sure.
00:43:02.000 | So it is similar in some regards to linear probes.
00:43:08.000 | But I think the nice thing about SAEs is that --
00:43:11.000 | and this sort of kind of covers the second question to some extent --
00:43:15.000 | is that you get a few nice features.
00:43:17.000 | One is that it's unsupervised,
00:43:20.000 | so you can learn a lot of features at once, which is nice.
00:43:23.000 | It doesn't need any sort of humans eyeballing things.
00:43:29.000 | And two, so why SAEs in particular?
00:43:36.000 | You know, I think, you know, the features that are helpful --
00:43:42.000 | or rather than features, the elements of the SAE,
00:43:45.000 | which are helpful in this regard, one is like the sparsity,
00:43:48.000 | which is important.
00:43:49.000 | And also the ability to just reconstruct the inputs, right?
00:43:51.000 | Because you want that.
00:43:52.000 | And that helps you actually --
00:43:56.000 | if you're reconstructing hidden activations,
00:44:00.000 | you get a very nice natural way to assess the goodness of an SAE,
00:44:04.000 | because you can measure the reconstruction error.
00:44:07.000 | So, you know, how you typically evaluate these models.
00:44:10.000 | One metric that you measure them on is actually reconstruction error, right?
00:44:14.000 | So you ablate the actual sort of activations in the model,
00:44:19.000 | and you replace them with actually the reconstructed activations from the SAE.
00:44:36.000 | All righty.
00:44:37.000 | Any other questions?
00:44:43.000 | >> Seems like I think that's about it.
00:44:46.000 | I guess if that's the case, then let me just share briefly on Medusa.
00:44:49.000 | Okay.
00:44:52.000 | Let me just see if I can find the screen.
00:44:55.000 | Awesome.
00:44:56.000 | Thanks for the reminder, Doug Abel.
00:44:57.000 | Appreciate it.
00:44:58.000 | All right.
00:44:59.000 | Sorry about that.
00:45:00.000 | So today I'll just be presenting quickly on Medusa.
00:45:04.000 | So the concept behind Medusa is basically that it's a better way to do
00:45:08.000 | speculative decoding.
00:45:10.000 | And so I think before we move into what speculative decoding is,
00:45:13.000 | I think it's important to see what are the main problems with model
00:45:16.000 | inference.
00:45:17.000 | Just so we're all on the same page,
00:45:18.000 | I'm just going to paste the link that I'm looking at in the chat itself.
00:45:23.000 | So if you're familiar with any large language model,
00:45:26.000 | a large chunk of the time that we spend is basically
00:45:29.000 | transferring the weights from, well,
00:45:31.000 | I guess the high-bandwidth memory, like what you see over here,
00:45:33.000 | over to the cache,
00:45:35.000 | which is where there's a very limited amount of space and where most of the
00:45:38.000 | calculations happen.
00:45:40.000 | So what that means is that if we run inference one time,
00:45:43.000 | we load in all the parameters, we load in all the inputs,
00:45:46.000 | and at the end of all that
00:45:49.000 | we only get one token out, and we have to repeat the whole step.
00:45:52.000 | There are a lot of optimizations that have been done around it,
00:45:55.000 | but basically the main problem is still not fixed that you,
00:46:00.000 | you do all this work and you only get one token out.
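As a rough illustration of why that one-token-per-pass loop hurts, here is a back-of-the-envelope calculation; the numbers are assumptions for a generic 7B model on an A100-class GPU, not figures from the talk.

```python
# Back-of-the-envelope sketch: at batch size 1, each decoded token streams all
# the weights from HBM to on-chip memory, so decode speed is roughly
# bandwidth-bound. All numbers below are illustrative assumptions.
params = 7e9            # e.g. a 7B-parameter model
bytes_per_param = 2     # fp16 / bf16 weights
hbm_bandwidth = 2e12    # ~2 TB/s, roughly A100-class

model_bytes = params * bytes_per_param
tokens_per_sec = hbm_bandwidth / model_bytes  # upper bound; ignores KV cache, overlap, etc.
print(f"~{tokens_per_sec:.0f} tokens/s upper bound at batch size 1")  # ~143
```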
00:46:03.000 | So what people have done is this thing called speculative decoding.
00:46:07.000 | So what does that mean?
00:46:09.000 | So let's say we have a huge model, like a Llama 70B,
00:46:12.000 | and we have a smaller model, which we call a draft model,
00:46:15.000 | like a Llama 7B.
00:46:18.000 | The Llama 7B can take an initial prompt and quickly generate a whole bunch of
00:46:22.000 | tokens itself.
00:46:23.000 | So you can think of it as, given some prompt,
00:46:26.000 | it generates like N tokens
00:46:29.000 | that it thinks are going to follow this specific prompt.
00:46:34.000 | So now we have a bunch of candidates, right?
00:46:36.000 | A proposed sequence from this draft model.
00:46:39.000 | So how do we know what percentage of these tokens are correct?
00:46:43.000 | Or, along the same lines, how many of these tokens should we reject?
00:46:47.000 | The way to do this is to basically just batch all these tokens and feed
00:46:50.000 | it into the model itself.
00:46:52.000 | So what this could look like is let's say our prompt is the capital of
00:46:56.000 | France is, and our smaller model says, hey,
00:46:58.000 | the capital of France is Paris and it's a beautiful city.
00:47:01.000 | To verify, we would pass in "the capital of France is",
00:47:04.000 | "the capital of France is Paris", and so on, as a batch.
00:47:07.000 | And at each step, what we're basically trying to see is,
00:47:09.000 | does the completion from the smaller model match the big model?
00:47:14.000 | And so what we're doing here is we're able to batch all of these inside a
00:47:16.000 | single forward pass.
00:47:18.000 | And we don't incur a huge back and forth transfer.
00:47:21.000 | And so we see this huge speed up in terms of the decoding speed itself
00:47:25.000 | when we're generating tokens.
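Here is a rough sketch of the draft-then-verify loop being described, under greedy decoding. The `draft_model` and `target_model` are hypothetical stand-ins with HuggingFace-style `.logits` outputs, and real implementations typically use a probabilistic acceptance rule rather than this exact-match check.

```python
# Sketch of speculative decoding with a small draft model and a large target
# model, assuming greedy decoding and HF-style outputs. Not the exact algorithm
# from any particular library; real systems use rejection sampling to accept.
import torch

@torch.no_grad()
def speculative_step(target_model, draft_model, tokens, n_draft=5):
    # 1) The draft model cheaply proposes n_draft tokens, one at a time.
    draft = tokens.clone()
    for _ in range(n_draft):
        nxt = draft_model(draft).logits[:, -1].argmax(-1, keepdim=True)
        draft = torch.cat([draft, nxt], dim=-1)

    # 2) The target model scores the whole proposed sequence in ONE forward pass.
    logits = target_model(draft).logits
    prompt_len = tokens.shape[1]
    target_preds = logits[:, prompt_len - 1:-1].argmax(-1)  # target's own choices

    # 3) Accept drafted tokens until the first disagreement, then take the
    #    target model's token at that position, so every step gains >= 1 token.
    proposed = draft[:, prompt_len:]
    n_ok = 0
    while n_ok < n_draft and proposed[0, n_ok] == target_preds[0, n_ok]:
        n_ok += 1
    correction = target_preds[:, n_ok:n_ok + 1]  # empty if everything matched
    return torch.cat([tokens, proposed[:, :n_ok], correction], dim=-1)
```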
00:47:27.000 | So there's some problems with this though.
00:47:31.000 | The first one is, of course, the extra overhead,
00:47:33.000 | because now we need a second, smaller model,
00:47:35.000 | we need to feed it this whole chunk of data,
00:47:38.000 | and we need to somehow do the reconciliation between the original
00:47:41.000 | completions and the new completions.
00:47:43.000 | The second, which I think is the bigger problem,
00:47:45.000 | is that the draft model might not accurately reflect the capabilities or
00:47:50.000 | the world knowledge of the larger model itself.
00:47:54.000 | If you're going to use, let's say, a Gemma 2B as the draft model,
00:47:56.000 | it might not really behave the same as a Llama 7B,
00:47:58.000 | even if it's able to decode like six times as fast.
00:48:02.000 | So in comes Medusa.
00:48:04.000 | So traditionally we have some input that passes through an embedding.
00:48:09.000 | It goes through your transformer layers and we get our hidden state.
00:48:12.000 | This is going to be a single vector of some dimension,
00:48:15.000 | which is going to be the same as your embedding dimension.
00:48:18.000 | And what we would always do is we'd say, okay,
00:48:21.000 | you might then have a linear layer.
00:48:26.000 | And you get out basically a vector with a whole bunch of probabilities that
00:48:30.000 | correspond to the probability that each individual token for that position is
00:48:34.000 | the next token that should be selected.
00:48:37.000 | So that's the original transformer flow itself.
00:48:41.000 | What Medusa does is that it slaps on a whole bunch of new MLP heads that operate
00:48:45.000 | on the same hidden state and try to predict tokens further ahead.
00:48:49.000 | So you can see the LM head predicts the next token here.
00:48:52.000 | The first Medusa head
00:48:55.000 | predicts the second token in the completion,
00:48:58.000 | the second head predicts the third token, and so on.
00:49:03.000 | These heads aren't anything special.
00:49:05.000 | They're really just MLP networks that generate a distribution over the vocabulary.
00:49:08.000 | So you can see over here that all it does is it takes, well,
00:49:12.000 | the final hidden state.
00:49:13.000 | This is going to be a one times D vector, right?
00:49:16.000 | It's multiplied by a weight matrix, W1 of head k, which is D by D, with an activation,
00:49:20.000 | and then we add the hidden state back as a residual.
00:49:23.000 | That result gets projected up to the vocabulary size, so you get a one times V vector,
00:49:26.000 | and then we do a softmax.
00:49:29.000 | I might be glossing over the dimensions a bit, but basically
00:49:33.000 | you're going to get out a probability distribution at the end over
00:49:36.000 | the whole vocabulary.
00:49:38.000 | Each head is essentially going to produce a probability distribution over all
00:49:42.000 | these different choices,
00:49:43.000 | and so you're going to get, well,
00:49:45.000 | basically the top s_k candidate options from each head.
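To pin that down, here is a minimal sketch of one Medusa head as a residual MLP over the final hidden state; the SiLU activation, names, and initialization are assumptions based on my reading of the paper, not the official implementation.

```python
# Sketch of a single Medusa head: residual MLP on the final hidden state,
# then a projection to vocabulary logits. Names/init are assumptions.
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_model)          # the D x D block
        self.act = nn.SiLU()
        self.lm_proj = nn.Linear(d_model, vocab_size)  # D -> V projection

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, D) hidden state from the backbone at the current position
        z = self.act(self.w1(h)) + h   # residual connection back to h
        return self.lm_proj(z)         # (batch, V) logits; softmax gives the distribution

# Head k is trained to predict the token k+1 positions ahead of the current one;
# at decode time you keep the top s_k candidates from each head's distribution.
```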
00:49:47.000 | I think the best way to see it is sort of over here,
00:49:50.000 | we can see the completions, right?
00:49:52.000 | So this is what the original language modeling head predicts.
00:49:56.000 | These are going to be the next tokens that are predicted by the
00:50:01.000 | first Medusa head, the second Medusa head, the third Medusa head, and so on.
00:50:06.000 | And so what they do is that they always accept the first token that's
00:50:10.000 | generated by the original language modeling head to guarantee that you
00:50:13.000 | always get at least one new token per step.
00:50:15.000 | But then when it comes to how the other candidate tokens are chosen,
00:50:20.000 | there's a separate scheme for that.
00:50:22.000 | So I think they also mentioned that they run some sort of greedy algorithm
00:50:25.000 | over a training dataset,
00:50:27.000 | whereby they try to determine the optimal shape of the candidate tree,
00:50:30.000 | i.e. how many candidates to keep at each node level.
00:50:33.000 | So there are two ways that they do this training.
00:50:34.000 | One is that you freeze the base LM and you only train the heads.
00:50:37.000 | And what this does is that you basically are just doing the same cross
00:50:41.000 | entropy loss,
00:50:42.000 | but you apply this sort of decay constant term here, a constant
00:50:48.000 | taken to the power of k.
00:50:49.000 | So what this means is that for the overall loss that you're calculating,
00:50:53.000 | they call it L_Medusa-1,
00:50:55.000 | which is the first way that you train the Medusa heads,
00:50:58.000 | you're essentially weighting
00:51:00.000 | the heads that are predicting tokens that are further and further out in the
00:51:03.000 | sequence, less and less.
00:51:05.000 | And this is actually super fast:
00:51:08.000 | you just need about five hours with 60K samples and your 7B model is good
00:51:12.000 | to go.
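Here is a small sketch of that frozen-backbone objective as I understand it: one cross-entropy term per head, down-weighted by a constant raised to the head index. The decay value of 0.8 is used here as an assumption.

```python
# Sketch of the Medusa-1 loss: per-head cross entropy, decayed by lam**k so
# heads predicting further ahead count less. lam=0.8 is an assumed setting.
import torch.nn.functional as F

def medusa1_loss(head_logits, targets, lam=0.8):
    """head_logits: list of (batch, seq, vocab) tensors for heads k = 1..K.
    targets: (batch, seq) token ids; head k predicts the token k+1 steps ahead."""
    loss = 0.0
    for k, logits in enumerate(head_logits, start=1):
        shift = k + 1                      # head k looks shift positions ahead
        pred = logits[:, :-shift]          # positions that still have a target
        gold = targets[:, shift:]
        ce = F.cross_entropy(pred.reshape(-1, pred.size(-1)), gold.reshape(-1))
        loss = loss + (lam ** k) * ce      # the "constant to the power of k" term
    return loss
```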
00:51:14.000 | The harder way, which gives a lot better results, is basically for you to
00:51:17.000 | train the LM together with the individual heads.
00:51:20.000 | That results in a new loss equation: this is your original language
00:51:23.000 | modeling loss, and this is, well, what we had over here.
00:51:27.000 | So this time they have a small weighting term called lambda_0,
00:51:29.000 | which is basically a very small constant,
00:51:31.000 | so that the head predictions don't mess up the overall loss, because the
00:51:35.000 | Medusa heads are going to be super wrong at the start.
00:51:38.000 | Since they're not trained on the dataset, they haven't seen anything.
00:51:42.000 | They're just MLPs.
00:51:43.000 | So they do some sort of linear warmup whereby the learning rate is slowly
00:51:46.000 | increased over time and then maybe decreased, and there's some scheduling
00:51:49.000 | that's going on there.
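Putting the two pieces together, the joint objective being described is roughly the base LM loss plus the head loss scaled by that small constant; a minimal sketch, with the value of lambda_0 as an assumption:

```python
# Sketch of the joint (Medusa-2) objective: usual LM loss plus the head losses
# scaled by a small lambda_0 so the untrained heads don't swamp training early.
def medusa2_loss(lm_loss, head_loss, lambda_0=0.1):  # lambda_0 value is assumed
    return lm_loss + lambda_0 * head_loss            # head_loss = medusa1_loss(...)
```

In practice this is paired with the warmup scheduling mentioned above, ramping the heads' contribution up as they improve.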
00:51:51.000 | The last part is just the dataset, which I think
00:51:54.000 | is pretty interesting.
00:51:57.000 | If you're only training the Medusa heads,
00:52:00.000 | the dataset is not really a problem.
00:52:03.000 | You can just take a public dataset and you can train on it.
00:52:08.000 | But in this case,
00:52:09.000 | if you're worried that the dataset that you're training the model on doesn't
00:52:11.000 | actually reflect what the model has learned,
00:52:13.000 | you can just basically take a public dataset with a whole bunch of prompts
00:52:16.000 | and get your model to generate the completions itself.
00:52:20.000 | And for certain models that are trained on
00:52:23.000 | the system/user chat format,
00:52:28.000 | you can do this with multi-turn conversations, which is great.
00:52:31.000 | So that generally works pretty well,
00:52:33.000 | from what they say, for training the Medusa heads.
00:52:35.000 | That's if you just freeze the base LM and only train the Medusa heads.
00:52:39.000 | But they do say that if you are training the whole model itself,
00:52:43.000 | plus the Medusa heads, which is this step over here,
00:52:46.000 | you probably want to also include a little KL divergence term so that
00:52:50.000 | when you run the loss,
00:52:51.000 | the model parameters don't change so much, and you want to sort of minimize
00:52:55.000 | the difference of your model from the original model,
00:52:58.000 | so that essentially you're still outputting high-quality completions.
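A rough sketch of that self-distillation-plus-KL idea, using HuggingFace-style generate/tokenizer calls as placeholders; the prompt source, hyperparameters, and the exact form of the regularizer are assumptions, not the paper's code.

```python
# Sketch: build a self-distillation dataset by letting the model answer public
# prompts, and add a KL term against a frozen copy of the original model so
# full-model Medusa training doesn't drift. HF-style APIs used as placeholders.
import torch
import torch.nn.functional as F

@torch.no_grad()
def build_self_distill_set(model, tokenizer, prompts, max_new_tokens=256):
    data = []
    for p in prompts:
        ids = tokenizer(p, return_tensors="pt").input_ids
        out = model.generate(ids, max_new_tokens=max_new_tokens)
        data.append(out)  # the model's own completion becomes the training target
    return data

def kl_regularizer(student_logits, ref_logits):
    # Penalize divergence of the fine-tuned model from the frozen original,
    # keeping completion quality close to the base model's.
    return F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(ref_logits, dim=-1),
        reduction="batchmean",
    )
```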
00:53:04.000 | So yeah, that's basically the Medusa paper summarized pretty fast.
00:53:07.000 | There's a whole bunch of stuff that I've skipped over,
00:53:10.000 | but this is basically the main high level idea behind the Medusa paper itself.
00:53:15.000 | So yeah, happy to take any questions.
00:53:18.000 | Let me just try to pull up the chat if there's any questions.
00:53:22.000 | But yeah.
00:53:23.000 | Okay.
00:53:32.000 | Does it cost more than increasing the beam search parameter?
00:53:38.000 | So if I use five Medusa heads,
00:53:42.000 | is it like five times six by 30?
00:53:45.000 | So I don't think they actually use beam search inside this itself.
00:53:49.000 | So the way that I've seen,
00:53:52.000 | so I looked at the code before this to try and see,
00:53:54.000 | and they provide a few different ways.
00:53:55.000 | The first one is basically greedy and nucleus-style sampling.
00:53:58.000 | Basically the idea is that all you're doing with Medusa is that you're
00:54:03.000 | changing how the speculation is done.
00:54:05.000 | You're changing the way that you generate these speculated tokens.
00:54:08.000 | So you originally use the draft model,
00:54:11.000 | but with Medusa you use these heads.
00:54:13.000 | You're still running the candidates through a verification step.
00:54:15.000 | So depending on how you use your beam,
00:54:18.000 | how your beam search is implemented with the Medusa heads themselves,
00:54:21.000 | I guess it will really determine the completion.
00:54:26.000 | But I think it's not super clear in the paper how exactly they do the
00:54:33.000 | final computation.
00:54:34.000 | They just say that they try to find the longest prefix length that's
00:54:37.000 | common across all the different potential completions that are generated.
00:54:42.000 | I hope that answers the question, Bennett.
00:54:45.000 | I'm just trying to pull up some stuff.
00:54:50.000 | Let's see.
00:54:52.000 | Multi-token prediction.
00:54:53.000 | I've seen this paper.
00:54:54.000 | I haven't read it yet.
00:54:55.000 | So I think perhaps I will look at it after this and figure it out.
00:55:02.000 | Yeah.
00:55:03.000 | But I guess if not, then I'm just going to end the recording over here.
00:55:06.000 | If anyone has any questions, happy to answer it.
00:55:08.000 | I'm just going to end the recording over here.
00:55:12.000 | All right.
00:55:19.000 | Yongxin, I think I need you to end the recording.
00:55:21.000 | Okay.
00:55:22.000 | Okay, let's share.