Reinforcement Learning for Language Models

Yoav Goldberg, April 2023.

Why RL?

With the release of the ChatGPT model and follow-up large language models (LLMs), there was a lot of discussion of the importance of "RLHF training", that is, "reinforcement learning from human feedback". I was puzzled for a while as to why RL (Reinforcement Learning) is better than learning from demonstrations (a.k.a. supervised learning) for training language models. Shouldn't learning from demonstrations (or, in language model terminology, "instruction fine-tuning": learning to imitate human-written answers) be sufficient? I came up with a theoretical argument that was somewhat convincing. But I came to realize there is an additional argument which not only supports the case for RL training, but also requires it, in particular for models like ChatGPT. This additional argument is spelled out in (the first half of) a talk by John Schulman from OpenAI. This post pretty much repeats his argument in more words, and also adds some things that John did not say explicitly (but which I am sure he thought about).

I included quite a bit of background to be self-contained. You can skip directly to "The core argument" to get to the main dish.

Background: Supervised learing vs RL

Let's briefly explain the two learning scenarios so we are on the same page. Feel free to skip this part if you are "in the know".

Pre-training In both settings, we assume a language model is first pre-trained over a large body of text, with the objective of predicting the next token. So we have a model that, for every sequence of words, can assign a probability over the options for the potential next word. By doing so, it also acquires some sort of an internal representation of the language. After this process, the model is very capable of generating texts, and providing natural continuations to given text prefixes, but it is not good at "communicating". For example, when prompted with a question, it might either answer it OR it might generate a series of additional questions, OR it might say that this is an important question that was raised in the context of ..., etc. All of these are valid continuations that follow questions in natural language texts. We can make the model perform desired language actions by crafting input texts in a way that their continuation will solve our problem (this is called "prompt engineering"), but this is not a very convenient interaction mode for non-expert users, who just want to ask a question or provide an instruction, and let the model follow it. If we want a model that is capable of consistently answering queries and not only completing them, we need to guide it towards this behavior. This guidance is called "fine-tuning": continuing to train the pre-trained model so that it behaves as we would like. (Some people call this "aligning" the model with a desired behavior.)
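For concreteness, here is a minimal sketch of what "assigning a probability to the next word" looks like for a pre-trained causal language model. The names `model` and `tokenize` are hypothetical placeholders, assumed to return (batch, length, vocabulary) logits and a 1-D tensor of token ids respectively; this is an illustration, not any particular library's API.

```python
import torch
import torch.nn.functional as F

def next_token_distribution(model, tokenize, prefix: str) -> torch.Tensor:
    """Return the model's probability distribution over the next token."""
    ids = tokenize(prefix)                   # (T,) token ids for the prefix
    logits = model(ids.unsqueeze(0))[0][-1]  # logits at the last prefix position
    return F.softmax(logits, dim=-1)         # probability over the whole vocabulary
```

Sampling repeatedly from this distribution produces a continuation; nothing in the pre-training objective forces that continuation to be an answer rather than, say, another question.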

Supervised Training In the supervised learning case (also called learning from demonstrations, or "instruction tuning") we collect a bunch of human-authored texts that have the form of a question or an instruction, followed by the desired output. For example, these texts can be a question followed by its answer, or a task such as summarize the following text {text} followed by its human-authored summary. By continuing to train the model on the same "predict the next token given the prefix" objective, but this time on this collection of instruction-output pairs, the model learns to respond to instructions by performing them. That is, the model receives demonstrations of what a correct output for a given query is, and learns to replicate this output. We hope that it will generalize this behavior to other queries, not seen in training.
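As a rough sketch (using the same hypothetical `model` interface as above), the instruction-tuning loss is just the next-token cross-entropy, restricted to the demonstrated response:

```python
import torch
import torch.nn.functional as F

def sft_loss(model, prompt_ids: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor:
    """Cross-entropy on the response tokens, conditioned on the prompt."""
    input_ids = torch.cat([prompt_ids, response_ids])
    logits = model(input_ids.unsqueeze(0))[0]   # assume (T, vocab) logits for the sequence
    start = prompt_ids.size(0) - 1              # position that predicts the first response token
    predictions = logits[start:-1]              # predictions for each response token
    targets = input_ids[start + 1:]             # the demonstrated response itself
    return F.cross_entropy(predictions, targets)
```

The model is pushed toward the exact demonstrated tokens, which is precisely what the arguments below take issue with.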

Reinforcement Learning (RL) In the reinforcement learning setup, we provide the model with the instructions, but not with their human-authored answers. Instead, the model should generate its own answer. A scoring mechanism (for example, a human) reads the generated answers, and tells the model if they are good or not. The model's objective is to learn how to answer in a way that receives high scores. An alternative mechanism is that the model generates several answers, and the scoring mechanism tells the model which one is the best. The model's objective is to learn to produce the higher-scoring answers and not the lower-scoring ones. In both cases, the model learns by creating its own answer, and receiving feedback. (Note: many researchers consider RL more narrowly, based on some technical aspects of the credit assignment mechanism, and for them the question "do we need RL" may boil down to whether we should use this family of techniques or some alternative family. I share their curiosity, but for the purpose of this post I consider any method that uses an external scoring function as RL, regardless of its mechanics.)
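To make the contrast with supervised training concrete, here is a minimal REINFORCE-style sketch of this loop. This is an illustrative simplification, not the actual RLHF algorithm used for ChatGPT; `score_fn` stands in for the human or automatic scorer, and `model` is the same hypothetical interface as before.

```python
import torch

def sample_answer(model, prompt_ids: torch.Tensor, max_new_tokens: int = 64, eos_id: int = 0):
    """Sample an answer token by token, keeping the log-probs of the sampled choices."""
    ids = prompt_ids.clone()
    log_probs = []
    for _ in range(max_new_tokens):
        logits = model(ids.unsqueeze(0))[0][-1]       # logits for the next position
        dist = torch.distributions.Categorical(logits=logits)
        tok = dist.sample()
        log_probs.append(dist.log_prob(tok))
        ids = torch.cat([ids, tok.unsqueeze(0)])
        if tok.item() == eos_id:                      # stop at end-of-sequence
            break
    return ids[prompt_ids.size(0):], torch.stack(log_probs)

def rl_step(model, prompt_ids, score_fn, optimizer):
    answer_ids, log_probs = sample_answer(model, prompt_ids)
    reward = score_fn(prompt_ids, answer_ids)         # sparse, whole-sequence feedback
    loss = -reward * log_probs.sum()                  # reinforce high-scoring answers
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The important differences from the supervised sketch are that the answer is the model's own, and that the only signal is a single score for the whole sequence, which is exactly the credit-assignment difficulty discussed next.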

RL is much harder than supervised training for several reasons. One such reason is "credit assignment". The language model generates a sequence of tokens, and only gets a score at the end of the sequence. The signal is weak: which parts of the answer are good and which are bad? A lot of technical work in RL attempts to solve this problem, and we put it aside for this post. It is an active research area, but reasonable solutions exist. The other issue is that we need a scoring mechanism to score the answers (either assign a score or compare two answers) and, in the context of language-based tasks, it is hard to come up with an automatic scorer (though that might be changing, as I briefly discuss below). Thus, we are left with "human feedback" for each learning step. This is very expensive and also slow, and the problem is even worse given that each piece of human feedback only gives a rather sparse signal, as we just saw above. Given these difficulties, why should we use RL and not just supervised learning?

The diversity argument

Perhaps the most intuitive argument against supervised learning / instruction tuning in the context of language generation models is that we teach the learner to replicate the exact answer given by the demonstrator, while the reality of human language is that there are many different ways to convey the same message, and they all might be valid answers. We "punish" the model for even slight deviations from our prescribed text, which may confuse it. We may also insist on a phrasing which is hard for the model to learn, while the model already knows how to produce an alternative---and equally valid---answer. We would thus like the diversity afforded by RL training. This is a very intuitive argument, but not a very convincing one, given that supervised learning does seem to work very well in practice, and given the challenges in training RL models. For a long time, I was not convinced that this is a core enough issue, and I am still not convinced.

The theoretical argument

The first "convincing" justification I came up with for RL vs supervied learning in LLMs is that supervised learning allows only positive feedback (we show the model a series of questions and their correct answers) while RL allows also for negative feedback (the model is allowed to generate an answer an get a feedback saying "this is not correct"). From a formal learning theory perspective, there is a big difference between the two: negative feedback is much more powerful. The theoretical argument is, roughly, that when learning only from demonstrations, an adversarial (or neglient..) demonstrator can mislead the learner into learning the wrong hypothesis by witholding some important examples. The demonstrator controls the learning process entirely. However, if you as a learner are allowed to form your own hypotheses and ask the teacher if they are correct (as in the RL setting), even an adversarial teacher can no longer trick you into latching on to a wrong hypothesis. It must disclose that its wrong if you ask about it. The learner is now much more powerful. (Of course, this assumes the adversarial or neglient teacher still plays by the rules and always provides truthful answers. But this is a reasonable assumption to make in a theoretical framework, and it does not hurt the overall argument of why learning from demonstrations is weaker than learning by interaction or by asking questions.)

This is all well and good, and I do believe this is part of the reason RL is needed. But there is also an additional argument which might be even more important in the context of training large language models to communicate by answering questions.

The core argument

This leads me to the core reason that requires RL-like training. The previous two arguments rely on hypotheses such as "it might be harder for the model to learn" or "a negligent demonstrator may confuse the model", which may or may not hold in practice. In contrast, the current argument provably holds.

There are (at least) three modes of interaction with a language model: (a) text-grounded: we provide the model with a text and an instruction ("summarize this text", "based on this text, what is the population of Israel", "what are the chemical names mentioned in this text", "translate this text to Spanish", etc), and expect the answer to be fully grounded in the provided text. (b) knowledge-seeking: we provide the model with a question or instruction, and expect a (truthful) answer based on the model's internal knowledge ("What are common causes of flu"). (c) creative: we provide the model with a question or instruction, and expect some creative output ("Write a story about...").

The argument for RL is based on interaction type (b): knowledge-seeking queries in which we expect a truthful (or confident) answer, and the ability of the model to say "I don't know" or refuse to answer in situations in which it is uncertain.

For this type of interaction, we must use RL training, as supervised training teaches the model to lie. The core issue is that we want to encourage the model to answer based on its internal knowledge, but we don't know what this internal knowledge contains. In supervised training, we present the model with a question and its correct answer, and train the model to replicate the provided answer. There are two cases: (1) the model "knows" the answer. In this case, the supervised training correctly pushes it to associate the answer with the question, hopefully pushing it to perform similar steps to answer similar questions in the future. This is the desired behavior. (2) the model does not know the answer. In this case, the supervised training pushes the model to associate the answer with the question anyhow. Now, there are two options. It may push the model to memorize this particular question-answer pair. This is not harmful, but also not very effective, as our aim is for the model to generalize and learn to answer any question, not only the ones in the instructions training data. We want the model to generalize. But if we succeed in training the model to generalize in these cases, then we essentially teach the model to make stuff up! It actively encourages the model to "lie". This is bad.

Because we don't know what the model knows or not, we cannot avoid case (2), which is a real and serious issue for supervised training. We cannot use pure supervised learning to push the model to produce truthful answers, and we thus must use RL for this. In contrast to the supervised setting, the RL setting does not actively encourage the model to lie: even if the model does initially guess some answers correctly and learns a "making stuff up" behavior by mistake, in the long run it will get bad scores for made-up answers (which are likely to be incorrect) and learn to adopt a policy that relies on its internal knowledge, or abstain.

Smaller remark: teaching to abstain

In case the model doesn't know the answer, we would like it to abstain and respond with "I don't know" or a similar answer. This is non-trivial to do. It is hard in the supervised setting, because we do not know what the model knows or not. We can push it towards not answering questions of a certain kind ("never answer questions about people") and responding instead with "I don't know". But this is not the intended behavior of abstaining when the answer is unknown, only a very weak proxy for it. However, this is also challenging in the RL setting: the model may never produce an "I don't know" answer to begin with, and so we would have no way of encouraging it to generate such answers. One way around this is to start with some supervised training, learning to produce "I don't know" answers in some cases, and then continue the process with RL. In both the supervised and the RL cases there is the worry that the model will learn to over-generate "I don't know". This is an open research question. One possible family of solutions is to tailor a reward that will assign very high scores to correct answers, medium-low scores to abstaining, and strong negative scores to incorrect answers. But this is not easy to get right.
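As an illustration of that last point, here is the shape such a reward could take. The specific values and the correctness check are made-up assumptions, just to show the idea, not a recipe that is known to work.

```python
def abstention_reward(answer: str, is_correct: bool) -> float:
    """Reward correct answers, mildly reward abstaining, strongly punish confident errors."""
    if answer.strip().lower() in {"i don't know", "i do not know"}:
        return 0.2                        # abstaining: medium-low score
    return 1.0 if is_correct else -1.0    # correct: high; made-up answer: strongly negative
```

Getting the relative magnitudes right is exactly the hard part: too generous to abstention and the model over-generates "I don't know", too harsh and it keeps guessing.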

Implications for model stealing / distillation

OpenAI, the company behind the GPT models, has invested a lot of effort in RL-type tuning of its language models. Part of their motivation was to ensure factuality / truthfulness, by encouraging the model to abstain from providing answers when it does not know the answer.

There is a recent trend of taking other, publicly available, base language models, and training them on examples of GPT outputs, in order to replicate the GPT model's impressive behaviors.

Note that this is akin to supervised training / instruction tuning: the models are trained to produce the GPT model's answers exactly. This should work well for teaching the model to perform instructions. However, it will not work well for case (b), teaching the model to answer knowledge-seeking queries. The publicly available base model likely knows a different set of facts from the OpenAI model, and training to replicate GPT's answers will suffer from the same issue supervised training suffers from: the model will be encouraged to make up facts in response to these types of queries, and additionally may learn to abstain in cases where it does know the answer but the GPT model did not.

The solution is, then, to train these models with RL. But isn't it too expensive?

Towards RL without human feedback

For a long time, training generative language tasks with RL has been impractical for most players: lacking a reliable automatic scoring metric, RL training requires human feedback for every training sample. This is both expensive and extremely slow, especially for models that need to see thousands, tens of thousands, or even hundreds of thousands of examples to learn.

However, RL training is now becoming practical: first, it seems that large pre-trained language models manage to somehow learn from fewer examples. But, more importantly, they pave the way towards removing humans from the RL loop.

This relies on the observation that for text-grounded tasks the supervised training paradigm is very effective, and that large models can learn to perform some tasks very well. One such task is considering two texts and asking "do these two texts mean the same thing"; another is "are there facts in text A that do not appear in text B". (We can also decompose, and task the model with "Generate all question-answer pairs that are answerable from this text", and then for each question ask "Is there an answer for this question in this other text, and what is it".)

Empirically, large language models (or even medium ones) seem to be able to learn to perform such tasks rather reliably using supervised learning. This provides us with an effective automatic scoring metric that we can use in an RL setting: train on the human-provided instruction-response pairs, but rather than trying to replicate the human responses directly, let the model generate its own response, and compare the model-generated response to the human-provided one using a dedicated text comparison model that was trained in a supervised fashion.
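Schematically (where `compare_model` is a hypothetical text-pair model, trained in a supervised fashion to judge whether two texts convey the same information), the reward for the RL step sketched earlier could then be derived from the human-written reference rather than from a human in the loop:

```python
def automatic_reward(compare_model, question: str, model_answer: str, reference_answer: str) -> float:
    """Score the model's own answer by comparing it to the human-written reference."""
    agreement = compare_model(question, model_answer, reference_answer)  # assumed to return a probability in [0, 1]
    return 2.0 * agreement - 1.0                                         # map to [-1, 1] for the RL step
```

The model is no longer rewarded for matching the reference token by token, only for conveying the same information, which is what we actually care about.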

@mdda commented Apr 22, 2023

Very interesting analysis - thanks for writing this up.

Maybe fix the following typos for the next revision?

  • "Supervised learing vs RL" (heading) -> "Supervised learning vs RL"
  • " with for RL vs supervied learning in LLMs" -> " with for RL vs supervised learning in LLMs"
  • "generate an answer an get a feedback" -> "generate an answer and get feedback"
  • "an adversarial (or neglient..) demonstrator" -> "an adversarial (or negligent..) demonstrator"
  • "disclose that its wrong" -> "disclose that it's wrong"
  • "or neglient teacher still plays" -> "or negligent teacher still plays"
  • "a neglient demonstrator may confuse the model" -> "a negligent demonstrator may confuse the model"
  • " is akin to supersied training" -> " is akin to supervised training"

@KastanDay

Excellent synthesis of his talk, much appreciated. I really enjoyed your ideas on "negative feedback" and why "learning from demonstrations is weaker than learning by interaction or by asking questions."

To understand John's talk, a key question centers on "small scale fine-tuning," as John calls it, vs full-network (or LoRA) fine-tuning. Does his argument hold in the latter case?

John's talk only considers the "small scale fine-tuning" case, which he implies only tunes the last (few?) layer(s) of an LLM. In this case, his argument about "learning to lie" makes intuitive sense because the internal 'knowledge graph' is not being updated.

However, if we fine-tune all the weights in the network, I have no reason to think his argument still holds. In this case, we're updating the 'knowledge graph' internal to the model along with the outputs. It's essentially pre-training, where we agree the model does not exhibit any "learning to lie" behavior.

Of course, the downside of full-network fine-tuning is compute cost and (possibly) data inefficiency. I'd be curious if fine-tuning with LoRA (instead of full-network fine-tuning) does any better on factualness, since it seems to out-perform full-network fine-tuning on other benchmarks (see LoRA and AdaLoRA papers).

@eminorhan

The InstructGPT paper (of which Schulman is a co-author) has a figure (Figure 4) that clearly shows that supervised finetuning (SFT) models hallucinate less than the base language models and even less than the RLHFed models (PPO and PPO-ptx). So, I'm puzzled by this new narrative going around that SFT somehow causes (or exacerbates) hallucinations in language models and RLHF somehow reduces them. The arguments I've seen to this effect, including this one, are very handwavy and unconvincing. I will remain skeptical until somebody shows actual convincing empirical evidence to the contrary and explains the discrepancy with the InstructGPT results. https://arxiv.org/abs/2203.02155

@carlthome

Wonderful high-level summary for getting some intuition!

@ajitvr commented Apr 23, 2023

Thank you for the post - extremely helpful.

https://www.youtube.com/watch?v=C_78DM8fG6E&t=621s

In this recent talk by Greg Brockman he mentions how Khan Academy fine-tuned the model to do correct arithmetic. Given that approach is claimed to have worked, the requirement for knowledge to be present in pretraining perhaps only means factual information (like KG triples) needs to be present as knowledge? Examples of adding two numbers seem to fall into the knowledge-seeking category you mention. But those examples already exist in pretraining data, and even though the model has not generalized during pretraining (Greg mentions it can add two 40-digit numbers but makes errors when adding a 40- and a 25-digit number), it appears to have been learned with SFT.

So SFT appears to play a role beyond just use for the two other buckets - text-grounded and creative tasks? Perhaps it is a fourth bucket of algorithmic tasks where SFT could help with generalization for that specific task? In this case it helps learn/improve methods the model potentially didn't get right during pretraining. I am not sure if they did SFT/RL or just SFT, but this question still remains, given that pretraining appears to be the phase of knowledge acquisition, from Ilya's recent interview too. Ilya mentions the key distinction between GPT-4 and prior versions is that it has much higher accuracy in next-word prediction (https://youtu.be/XjSUJUL9ADw)

@KastanDay

The InstructGPT paper (of which Schulman is a co-author) has a figure (Figure 4) that clearly shows that supervised finetuning (SFT) models hallucinate less than the base language models and even less than the RLHFed models (PPO and PPO-ptx).

Excellent point. I have a feeling we're subtly misunderstanding John's real takeaway. His point about the knowledge graph + last-layer fine-tuning makes sense, but doesn't apply to the common case of fine-tuning. Typically we don't fine-tune just the last layer; we either do a full-model fine-tune or use a PEFT method like LoRA.

@tomatopuree

Figure 4 is misleading: "these results are collapsed across model sizes". If you check Appendix E3, which has graphs across model sizes, you can see that the prevalence of attempting the correct instruction, following explicit constraints (both of whose failures are types of hallucinations), and appropriateness for a customer assistant is higher for PPO-ptx than for SFT. The fact that SFT is better than PPO-ptx on hallucinations is true, but only for the 175B model, and only by a margin smaller than all the performance differences previously mentioned. I don't think the case against RLHF is as sound as some may think.

[image: InstructGPT results across model sizes, from Appendix E3]

@Wang-Jinxu commented May 2, 2023

Prof. Yoav, your blog is very enlightening. I am wondering if we can translate it into Chinese and post it on our WeChat official account platform. I believe this will benefit more people. We will highlight your name and keep the original link at the top of the translated version. Thank you!

@rama100 commented May 14, 2023

@KastanDay

if we fine-tune all the weights in the network, I have no reason to think his argument still holds. In this case, we're updating the 'knowledge graph' internal to the model along with the outputs. It's essentially pre-training, where we agree the model does not exhibit any "learning to lie" behavior.

Agreed. In John's 'Solo' example, fine-tuning will increase the prob of 'Solo' the next time around (which is what we want). This thing was bothering me enough that I cold-emailed John a couple of days ago (no response! :-)).

John's talk only considers the "small scale fine-tuning" case, which he implies only tunes the last (few?) layer(s) of an LLM. In this case, his argument about "learning to lie" makes intuitive sense because the internal 'knowledge graph' is not being updated.

Why would fine-tuning just the last few layers (as opposed to all the layers) leave the internal knowledge graph untouched?

@saetlan commented May 29, 2023

  • "Supervised training teaches the model to lie". I don't understand why the model would start to learn a lie. When doing the supervision we would now the true answer no ? So we would make the model go towards the expected answer.
  • "There are many different ways to convey the same message": is there no way to use a supervised loss on preference without the RL part ? Like it isn't the RL that solves the fact that many answers may suit. It's the fact to do it on the "preference" task.
    Thanks for clarifying those concepts !

@disperaller

  • "Supervised training teaches the model to lie". I don't understand why the model would start to learn a lie. When doing the supervision we would now the true answer no ? So we would make the model go towards the expected answer.
  • "There are many different ways to convey the same message": is there no way to use a supervised loss on preference without the RL part ? Like it isn't the RL that solves the fact that many answers may suit. It's the fact to do it on the "preference" task.
    Thanks for clarifying those concepts !

When reading this argument, what I have in mind is the following example: suppose we feed the model some examples about the population numbers of some countries, and we can even go further and fine-tune the whole model with these examples. It still does not have the true population numbers for countries that it didn't see during the fine-tuning. Thus, in this type of supervised learning, the SFT process will tweak the model to make its answers more hallucinated, which would partially explain why SL teaches the model to lie.
