Reinforcement Learning for Language Models

Yoav Goldberg, April 2023.

Why RL?

With the release of the ChatGPT model and follow-up large language models (LLMs), there was a lot of discussion of the importance of "RLHF training", that is, "reinforcement learning from human feedback". I was puzzled for a while as to why RL (Reinforcement Learning) is better than learning from demonstrations (a.k.a. supervised learning) for training language models. Shouldn't learning from demonstrations (or, in language model terminology, "instruction fine-tuning": learning to imitate human-written answers) be sufficient? I came up with a theoretical argument that was somewhat convincing. But I came to realize there is an additional argument which not only supports the case for RL training, but also requires it, in particular for models like ChatGPT. This additional argument is spelled out in (the first half of) a talk by John Schulman from OpenAI. This post pretty much

@fgolemo
fgolemo / 3d_rotate_reproject.py
Last active July 11, 2023 18:02
3D rotation and reprojection in pytorch, i.e. differentiable
import numpy as np
import math
import torch
from torchvision import datasets
import cv2 # OpenCV, this is only used for visualization, see bottom of file
def rotation_matrix(axis, theta):
    """
    Generalized 3d rotation via Euler-Rodrigues formula, https://www.wikiwand.com/en/Euler%E2%80%93Rodrigues_formula
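The preview above is cut off before the function body, so here is a minimal pure-Python sketch of the Euler-Rodrigues formula that the docstring refers to. It is not the gist's own implementation: the names `rodrigues_rotation_matrix` and `rotate` are hypothetical, plain lists stand in for tensors, and the sign convention (counterclockwise rotation about the axis for positive `theta`) is an assumption.

```python
import math


def rodrigues_rotation_matrix(axis, theta):
    """Return a 3x3 matrix (list of rows) rotating by `theta` radians about `axis`.

    Assumed convention: positive theta rotates counterclockwise when looking
    down the axis toward the origin.
    """
    # Normalise the axis so that a^2 + b^2 + c^2 + d^2 = 1.
    norm = math.sqrt(sum(x * x for x in axis))
    ux, uy, uz = (x / norm for x in axis)

    # Euler-Rodrigues parameters.
    a = math.cos(theta / 2.0)
    s = math.sin(theta / 2.0)
    b, c, d = ux * s, uy * s, uz * s

    return [
        [a * a + b * b - c * c - d * d, 2 * (b * c - a * d), 2 * (b * d + a * c)],
        [2 * (b * c + a * d), a * a + c * c - b * b - d * d, 2 * (c * d - a * b)],
        [2 * (b * d - a * c), 2 * (c * d + a * b), a * a + d * d - b * b - c * c],
    ]


def rotate(matrix, vec):
    """Apply a 3x3 matrix to a 3-vector."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]
```

Under this convention, rotating the x unit vector by a quarter turn about the z axis, `rotate(rodrigues_rotation_matrix([0, 0, 1], math.pi / 2), [1, 0, 0])`, gives the y unit vector up to floating-point error.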
@thomwolf
thomwolf / gpt-2-wikitext-103.py
Last active April 16, 2024 19:27
A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
@thomwolf
thomwolf / top-k-top-p.py
Last active May 14, 2024 00:20
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
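The preview stops before the filtering logic itself. As a rough illustration of what the docstring describes, the sketch below applies top-k and nucleus filtering to a plain Python list of logits. It is an assumption-laden stand-in, not the gist's code: the name `top_k_top_p_filter` is hypothetical, it uses lists rather than torch tensors, and it assumes that ties at the top-k cutoff are kept and that the token which pushes the cumulative probability past `top_p` is kept too.

```python
import math


def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-math.inf):
    """Return a copy of `logits` with filtered-out entries set to `filter_value`."""
    filtered = list(logits)
    n = len(filtered)

    if top_k > 0:
        k = min(top_k, n)
        # The k-th largest logit is the cutoff; anything strictly below it is dropped.
        cutoff = sorted(filtered, reverse=True)[k - 1]
        filtered = [x if x >= cutoff else filter_value for x in filtered]

    if top_p > 0.0:
        # Softmax over the surviving logits, shifted by the max for stability.
        m = max(filtered)
        exps = [math.exp(x - m) for x in filtered]
        total = sum(exps)
        probs = [e / total for e in exps]

        # Walk tokens from most to least probable and keep them until the
        # cumulative probability reaches top_p (the crossing token is kept).
        order = sorted(range(n), key=lambda i: probs[i], reverse=True)
        keep = set()
        cumulative = 0.0
        for i in order:
            keep.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        filtered = [x if i in keep else filter_value for i, x in enumerate(filtered)]

    return filtered
```

To sample from the result, one would apply a softmax to the filtered logits and draw from that distribution; entries set to negative infinity then receive zero probability.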
import bpy
import bmesh
import numpy as np
import time
import os
import colorsys
import random
import yaml
import sys
from graphviz import Digraph
import torch
from torch.autograd import Variable, Function
def iter_graph(root, callback):
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
@schickling
schickling / _README.md
Last active January 4, 2024 09:37
Script to import and export docker-machine configurations to sync between hosts/collaborators

docker-machine import/export

Script to import and export docker-machine configurations to sync between hosts/collaborators

Export (on host A)

$ docker-machine ls
NAME       ACTIVE   DRIVER         STATE     URL                            SWARM   DOCKER    ERRORS
dev        -        digitalocean   Running   tcp://example.com:2376                 v1.10.1
@alexproca
alexproca / docker-machine-rename
Last active November 23, 2020 17:15
Rename docker-machine
#!/usr/bin/env bash
#copy this in a folder from path ex: /usr/local/bin
#usage: docker-machine-rename default my-default
# Authors
#
# alexproca initial script
# eurythmia sed magic
@xflr6
xflr6 / bench_pg_array.py
Last active June 4, 2022 08:29
Benchmark PostgreSQL array vs. join performance
"""Benchmark PostgreSQL array vs. join performance.
Replicate http://shon.github.io/2015/12/21/postgres_array_performance.html
with proper join table indexes (uniqueness constraints) using SQLAlchemy.
$ python -i bench_pg_array.py
>>> setup()
$ python -m timeit -s "import bench_pg_array" "bench_pg_array.test_join()"
500 loops, best of 5: 445 usec per loop
@karpathy
karpathy / min-char-rnn.py
Last active May 16, 2024 19:54
Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy
"""
Minimal character-level Vanilla RNN model. Written by Andrej Karpathy (@karpathy)
BSD License
"""
import numpy as np
# data I/O
data = open('input.txt', 'r').read() # should be simple plain text file
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)