2025-04-14
RL for LLM
I'm writing this more for my own understanding than to teach anyone.
An LLM solves the following problem:
- Given some initial tokens t1 t2 t3 t4, guess probabilities of which token t5 comes next.
- You can apply this multiple times. For instance, given some initial tokens t1 t2 t3 t4 t5, guess probabilities of t6.
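To make "apply this multiple times" concrete, here is a minimal sketch. The "LLM" is a hypothetical hard-coded bigram table (next-token probabilities depend only on the last token) standing in for a real model. `next_token_probs`, `greedy_extend` and `sequence_prob` are illustrative names, not any library's API.

```python
# Hypothetical toy "LLM": P(next token | tokens) depends only on the last token.
BIGRAM: dict[str, dict[str, float]] = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"c": 0.7, "d": 0.3},
    "c": {"d": 0.6, "a": 0.4},
    "d": {"a": 1.0},
}


def next_token_probs(tokens: list[str]) -> dict[str, float]:
    """One LLM call: probabilities of the token that comes next."""
    return BIGRAM[tokens[-1]]


def greedy_extend(tokens: list[str], n: int) -> list[str]:
    """Apply the LLM n times, each time appending the most likely next token."""
    out = list(tokens)
    for _ in range(n):
        probs = next_token_probs(out)
        out.append(max(probs, key=probs.get))
    return out


def sequence_prob(prefix: list[str], continuation: list[str]) -> float:
    """Chain rule: P(continuation | prefix) is the product of per-step next-token probabilities."""
    p = 1.0
    tokens = list(prefix)
    for tok in continuation:
        p *= next_token_probs(tokens).get(tok, 0.0)
        tokens.append(tok)
    return p


print(greedy_extend(["a"], 3))           # ['a', 'b', 'c', 'd']
print(sequence_prob(["a"], ["b", "c"]))  # ≈ 0.42 (0.6 * 0.7)
```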
RL for LLMs solves the following problem:
- Given some initial tokens t1 t2 and some final tokens t5 t6, guess probabilities of which t3 t4 likely follow t1 t2 and likely lead to t5 t6.
- As an intuition, you can look at a sequence of tokens as a state, and at every next token as a state transition. Given any initial state you can in theory compute a tree of possible state transition sequences from there. RL for LLMs is basically trying to efficiently search this tree to get from some initial state to some final state.
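A brute-force sketch of that tree search, assuming a hypothetical bigram table as the stand-in LLM: enumerate every candidate middle t3 t4, score it by the LLM's probability of producing middle + final tokens after the prefix, and normalize. That is Bayes' rule, P(t3 t4 | t1 t2, t5 t6) ∝ P(t3 t4 t5 t6 | t1 t2). The loop over `product(VOCAB, repeat=middle_len)` grows exponentially with the length of the middle, which is exactly why exhaustive search doesn't scale.

```python
from itertools import product

VOCAB = ["a", "b", "c", "d"]
# Hypothetical toy "LLM": P(next token | tokens) depends only on the last token.
BIGRAM = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"c": 0.7, "d": 0.3},
    "c": {"d": 0.6, "a": 0.4},
    "d": {"a": 1.0},
}


def sequence_prob(prefix, continuation):
    """P(continuation | prefix) via the chain rule over the toy model."""
    p, last = 1.0, prefix[-1]
    for tok in continuation:
        p *= BIGRAM[last].get(tok, 0.0)
        last = tok
    return p


def middle_posterior(prefix, suffix, middle_len):
    """P(middle | prefix, suffix) by enumerating every middle of length middle_len."""
    scores = {}
    for middle in product(VOCAB, repeat=middle_len):
        p = sequence_prob(prefix, list(middle) + list(suffix))
        if p > 0:  # keep only middles that can actually connect prefix to suffix
            scores[middle] = p
    total = sum(scores.values())
    return {m: p / total for m, p in scores.items()} if total > 0 else {}


# t1 t2 = "a b", t5 t6 = "b c": which t3 t4 connect them?
post = middle_posterior(["a", "b"], ["b", "c"], middle_len=2)
for middle, p in sorted(post.items(), key=lambda kv: -kv[1]):
    print(middle, round(p, 3))
# ('d', 'a') 0.517
# ('c', 'a') 0.483
```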
A major doubt I have:
- Can we train an LLM on the entire dataset in reverse? Wouldn't that help solve this problem? When solving a maze, a human typically searches both forward from the entrance and backward from the exits.
- A reverse LLM would solve the following problem: given final tokens t5 t6, guess probabilities of which t4 precedes them. Then apply this multiple times, so that, say, given t2 t3 t4 t5 t6 you get probabilities of t1.
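A minimal sketch of the reverse idea, under strong simplifying assumptions: the "LLM" here is a count-based bigram model (a hypothetical stand-in, not a real transformer), trained once on a tiny made-up corpus and once on the same corpus with every sequence reversed. The reversed model answers "which token precedes this one?", and applying it repeatedly walks backward from the final tokens. It doesn't by itself combine the forward and backward searches; it only shows the backward half.

```python
from collections import Counter, defaultdict


def train_bigram(corpus):
    """Count-based bigram 'LM': estimates P(next | prev) from adjacent token pairs."""
    counts = defaultdict(Counter)
    for seq in corpus:
        for prev, nxt in zip(seq, seq[1:]):
            counts[prev][nxt] += 1
    model = {}
    for prev, ctr in counts.items():
        total = sum(ctr.values())
        model[prev] = {tok: n / total for tok, n in ctr.items()}
    return model


def extend_backward(reverse_model, suffix, n):
    """Prepend up to n tokens, each time picking the most likely preceding token."""
    out = list(reversed(suffix))  # reversed order, so out[-1] is the earliest token so far
    for _ in range(n):
        probs = reverse_model.get(out[-1])
        if not probs:  # nothing was ever seen preceding this token
            break
        out.append(max(probs, key=probs.get))
    return list(reversed(out))


# Hypothetical toy corpus; each sequence is a list of tokens.
corpus = [["a", "b", "c", "d"], ["a", "b", "d"], ["b", "c", "d"]]
forward = train_bigram(corpus)                         # P(next | prev)
reverse = train_bigram([seq[::-1] for seq in corpus])  # P(prev | next)

print(forward["b"])                        # {'c': 0.666..., 'd': 0.333...}
print(reverse["d"])                        # {'c': 0.666..., 'b': 0.333...}
print(extend_backward(reverse, ["d"], 3))  # ['a', 'b', 'c', 'd']
```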
In order to do RL for LLMs, you have to train the following:
- LLM.
  - Given some initial tokens t1 t2 t3 t4, the LLM guesses probabilities of which token t5 comes next.
- Value network.
  - Assume we have a trained LLM.
  - Assume some final tokens t5 t6 are a hard-coded constant.
  - Given some initial tokens t1 t2 t3 t4, the value network guesses the probability that t1 t2 t3 t4 leads to t5 t6.
  - We could get exact answers by exhaustively enumerating every possible continuation of t1 t2 t3 t4 and querying the LLM on each one. However, this requires far too much compute, so the value network makes its own guesses using some simpler algorithm.
  - Typically the value network doesn't have just one set of final tokens t5 t6 hard-coded as high value, but thousands or even millions of such final token sets. Then we guess the probability that a given t1 t2 t3 t4 leads to any one of the millions of hard-coded final token sets t5 t6.
- Policy network.
  - Assume we have a trained LLM and a trained value network (with some hard-coded final token sets t5 t6).
  - Given t1 t2, guess probabilities of which t3 t4 lead to any one of the final token sets t5 t6. During training of the policy network, we can query the value network.
  - If you were a human trying to solve a very big maze, you'd probably do something similar. First you'd make (possibly incorrect) guesses about which sections of the maze are closer to at least one of the exits, and then you'd use these guesses, together with your current position in the maze, to guess where to go next.
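A minimal sketch of what the value network approximates and how a policy could use it. Assumptions: a hypothetical bigram table plays the LLM, and "leads to" is taken to mean "the sequence ends with one of the target final-token tuples within `horizon` more tokens" (the horizon is my assumption; the text leaves it open). `value` is the exhaustive enumeration mentioned above, exponential in `horizon`; a real value network would be a learned approximation of it. `policy` reweights each LLM next-token probability by the value of the resulting prefix, giving P(next token | prefix, target reached); applied step by step with a shrinking horizon it yields probabilities over t3 t4. Passing a set of several targets covers the "millions of final token sets" case.

```python
# Hypothetical toy "LLM": P(next token | tokens) depends only on the last token.
BIGRAM = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"c": 0.7, "d": 0.3},
    "c": {"d": 0.6, "a": 0.4},
    "d": {"a": 1.0},
}


def next_token_probs(tokens):
    return BIGRAM[tokens[-1]]


def reached(tokens, targets):
    """True if tokens currently end with any target final-token tuple."""
    return any(tuple(tokens[-len(t):]) == t for t in targets)


def value(tokens, targets, horizon):
    """Exact P(reach any target within `horizon` more tokens), by exhaustive enumeration.
    A value network would learn a cheap approximation of this function."""
    if reached(tokens, targets):
        return 1.0
    if horizon <= 0:
        return 0.0
    return sum(p * value(tokens + [tok], targets, horizon - 1)
               for tok, p in next_token_probs(tokens).items())


def policy(tokens, targets, horizon):
    """P(next token | tokens, some target is reached): LLM probability times value, normalized."""
    scores = {tok: p * value(tokens + [tok], targets, horizon - 1)
              for tok, p in next_token_probs(tokens).items()}
    total = sum(scores.values())
    return {tok: s / total for tok, s in scores.items()} if total > 0 else {}


targets = {("d", "a")}  # hypothetical set of "good" final tokens t5 t6
print(value(["a", "b"], targets, horizon=2))   # ≈ 0.3
print(value(["a", "b"], targets, horizon=3))   # ≈ 0.72
print(policy(["a", "b"], targets, horizon=3))  # {'c': ≈0.583, 'd': ≈0.417}
```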
How to train these models:
- Not discussed here. There are some generic training algos that can produce value and policy networks. It's possible that what works best here is a training algo designed specifically for RL on LLMs rather than a generic RL training algo.
RL for safety
- Safety assumes that some token sequences are "unsafe" and others are "safe". You can basically hard-code some final token sets t5 t6 as unsafe and then solve the problem: given t1 t2, guess probabilities of tokens t3 t4 that likely follow t1 t2 and likely do not precede any t5 t6 in the unsafe set.
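The same kind of enumeration can sketch the safety version: compute the probability of reaching an unsafe final-token tuple within a lookahead horizon, then reweight each next-token probability by the chance of staying safe. As before, the bigram table, the horizon, and the names `unsafe_prob` and `safe_next_token_probs` are hypothetical; in practice the exact enumeration would be replaced by a learned estimate.

```python
# Hypothetical toy "LLM": P(next token | tokens) depends only on the last token.
BIGRAM = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"c": 0.7, "d": 0.3},
    "c": {"d": 0.6, "a": 0.4},
    "d": {"a": 1.0},
}


def ends_with_any(tokens, finals):
    """True if tokens currently end with any of the given final-token tuples."""
    return any(tuple(tokens[-len(f):]) == f for f in finals)


def unsafe_prob(tokens, unsafe, horizon):
    """Exact P(reach an unsafe final-token tuple within `horizon` more tokens)."""
    if ends_with_any(tokens, unsafe):
        return 1.0
    if horizon <= 0:
        return 0.0
    return sum(p * unsafe_prob(tokens + [tok], unsafe, horizon - 1)
               for tok, p in BIGRAM[tokens[-1]].items())


def safe_next_token_probs(tokens, unsafe, horizon):
    """P(next token | tokens, no unsafe tuple within horizon): LLM prob times P(stay safe), normalized."""
    scores = {tok: p * (1.0 - unsafe_prob(tokens + [tok], unsafe, horizon - 1))
              for tok, p in BIGRAM[tokens[-1]].items()}
    total = sum(scores.values())
    return {tok: s / total for tok, s in scores.items()} if total > 0 else {}


unsafe = {("d", "a")}  # hypothetical unsafe final tokens t5 t6
print(BIGRAM["b"])                                   # raw LLM: {'c': 0.7, 'd': 0.3}
print(safe_next_token_probs(["a", "b"], unsafe, 3))  # {'c': 1.0, 'd': 0.0}
```

Here "d" is ruled out because every continuation from it hits "d a" immediately, while "c" only sometimes does, so all the mass moves to "c". If every option eventually hits an unsafe tuple, the function returns an empty dict rather than dividing by zero.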