Last active March 1, 2024 17:08
A simple code for [Nash Learning from Human Feedback](
# [Nash Learning from Human Feedback](
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 16
debugging_is_on = 0
def print_tensor_info(tensor_name, tensor):
    # Check if tensor is floating point, and convert if necessary
    tensor_float = tensor.float() if not tensor.is_floating_point() else tensor
    # Gather the information
    info = {
        "shape": tuple(tensor.shape),
        "min/max": (tensor.min().item(), tensor.max().item()),
        "mean": tensor_float.mean().item(),
        "std": tensor_float.std().item()
    }
    # Print the default representation and the extra information
    print(f"{tensor_name} = {tensor}")
    for key, value in info.items():
        print(f"{key}: {value}")
# Adjusting eta affects the balance between exploiting the current policy
# and exploring new policies suggested by the reference model or feedback.
eta = 0.5
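As a rough illustration of what eta controls, the sketch below mixes two toy action distributions geometrically, pi^(1 - eta) * mu^eta, and renormalizes the result. The helper name `geometric_mixture` and the toy numbers are assumptions made for this example and are not part of the gist.

```python
def geometric_mixture(pi, mu, eta):
    """Return the normalized geometric mixture pi^(1 - eta) * mu^eta."""
    unnormalized = [p ** (1.0 - eta) * m ** eta for p, m in zip(pi, mu)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


pi = [0.7, 0.2, 0.1]  # toy current policy (hypothetical values)
mu = [0.2, 0.3, 0.5]  # toy reference policy (hypothetical values)

# eta = 0 keeps the current policy, eta = 1 falls back to the reference
for eta_value in (0.0, 0.5, 1.0):
    mixed = geometric_mixture(pi, mu, eta_value)
    print(eta_value, [round(x, 3) for x in mixed])
```

With eta between 0 and 1 the mixture stays a valid distribution and interpolates between the two policies in log space.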
"""
In reinforcement learning, a current policy and a reference policy represent
different strategies for decision-making.

The current policy (π) is the strategy the agent is currently using to make
decisions. It is typically represented as a probability distribution over
actions given states.

The reference policy (μ), sometimes used as a baseline or target, can be a
previous version of the current policy, a fixed heuristic, or an expert policy.

Both policies are instances of the same neural network class, but they could be
separate models or even different types of models depending on the NLHF
application. The current policy is what we actively train, and the reference
policy provides a stable comparison point which could be static or
periodically updated.
"""
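One common way to obtain a static reference policy is to snapshot the current policy and freeze the copy. The sketch below assumes a toy `nn.Sequential` stand-in with made-up layer sizes, not the `PolicyNetwork` from this gist, and only shows that training the current policy leaves the frozen copy untouched.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the policy network (hypothetical sizes)
current_policy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Reference policy: a frozen snapshot of the current policy
reference_policy = copy.deepcopy(current_policy)
reference_policy.eval()
for param in reference_policy.parameters():
    param.requires_grad_(False)

# One gradient step on the current policy only
optimizer = torch.optim.SGD(current_policy.parameters(), lr=0.1)
loss = current_policy(torch.randn(2, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The two models now differ, and the reference holds no trainable parameters
changed = any(not torch.equal(c, r) for c, r in
              zip(current_policy.parameters(), reference_policy.parameters()))
print(f"current policy moved away from reference: {changed}")
```

A periodically updated reference would repeat the `copy.deepcopy` step every few epochs instead of doing it once.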
if USE_ATARI:  # flag name assumed; pairs with the `elif USE_NLP:` branch below
    class PolicyNetwork(nn.Module):
        def __init__(self, input_size, output_size):
            super(PolicyNetwork, self).__init__()
            # Define network layers
            self.layers = nn.Sequential(
                nn.Linear(input_size, 256),
                nn.Linear(256, output_size),
            )

        def forward(self, state):
            y = self.layers(state)
            print(f"y has shape of {y.shape}")
            return y

elif USE_NLP:
    from transformers import AutoModel

    # Create TinyBert model instance
    bert_model = AutoModel.from_pretrained("prajjwal1/bert-tiny").to(device)
    print(f"bert_model.config.hidden_size = {bert_model.config.hidden_size}")

    class PolicyNetwork(nn.Module):
        def __init__(self, state_dim):
            super(PolicyNetwork, self).__init__()
            self.bert = bert_model
            self.final_layer = nn.Linear(self.bert.config.hidden_size, 1)
            self.relu = nn.ReLU()

        def forward(self, input_ids, attention_mask=None):
            # Process text: take the [CLS] token embedding
            text_embeddings = self.bert(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]

            # Combine and score
            score = self.final_layer(text_embeddings)

            # Keep the score strictly positive for numerical stability
            score = self.relu(score) + 1e-6
            return score

"""
Implementing a preference model function in Python for a reinforcement
learning context involves comparing the expected rewards or values of
different actions and selecting the one that aligns best with human
preferences.

preference_model uses the probabilities assigned by the current policy
model to each action as a measure of the model's confidence. It then
combines this with a human preference score, which could be obtained
from pre-recorded data or real-time feedback, to produce a final
preference score for the action. This is a simplified version; in
practice, the human preference component might involve more complex
methods such as comparing against a database of preferred actions, or
using a learned model to predict human preferences.

Please note that the function action_to_index(action) would need to be
defined according to how actions are represented in the Atari environment,
and human_preferences would be a data structure we'd need to define
based on how we're collecting and storing human feedback.

See Section 7.2 inside the NLHF paper for an overview.
See also expression (1.1), section 4 and Theorem 1 of
[Transforming and Combining Rewards for Aligning Large Language Models]
"""
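The comparison of two candidate actions can be pictured with a Bradley-Terry style preference probability, where the chance that action A beats action B is the sigmoid of their score difference. This is an illustrative stand-in rather than the exact model used in the gist, and the function names `log_sigmoid` and `preference_probability` are assumptions for this sketch.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) for a scalar x."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def preference_probability(score_a, score_b):
    """Bradley-Terry style P(a preferred over b) from two scalar scores."""
    return math.exp(log_sigmoid(score_a - score_b))


p_ab = preference_probability(2.0, 1.0)
p_ba = preference_probability(1.0, 2.0)
print(f"P(a > b) = {p_ab:.4f}, P(b > a) = {p_ba:.4f}")
```

Working with the log-sigmoid directly avoids taking the log of a probability that has already underflowed to zero for very negative score differences.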
def preference_model(state, state_action_a, state_action_b, model,
                     reference_model, human_preferences):
    """
    A model that scores actions based on a combination of model predictions
    and human preferences.

    :param state: The current state from the environment.
    :param state_action_a: The state combined with candidate action A.
    :param state_action_b: The state combined with candidate action B.
    :param model: The current policy model.
    :param reference_model: The reference policy model.
    :param human_preferences: Human preference scores used to weight the
                              model confidence of each candidate.
    :return: The baseline-subtracted score of the preferred action.
    """
    # Use float32 since the models' internal compute ops are floating-point
    state = state.float()

    # Get the current policy's scores for both candidates
    current_policy_probs_a = model(state_action_a)
    current_policy_probs_b = model(state_action_b)

    # Get the reference policy's scores for both candidates
    reference_policy_probs_a = reference_model(state_action_a)
    reference_policy_probs_b = reference_model(state_action_b)

    # Calculate model confidences using both current and reference models.
    # log(sigmoid(delta_reward)) is preferred: sigmoid squashes delta_reward
    # into (0, 1), and the log keeps the value better behaved for training
    model_confidence_a = torch.log(torch.sigmoid(current_policy_probs_a -
                                                 reference_policy_probs_a))
    model_confidence_b = torch.log(torch.sigmoid(current_policy_probs_b -
                                                 reference_policy_probs_b))
    # model_confidence_a = current_policy_probs_a - reference_policy_probs_a
    # model_confidence_b = current_policy_probs_b - reference_policy_probs_b

    # Calculate the preference score by combining model confidence and
    # human preference
    preference_score_a = model_confidence_a * human_preferences
    preference_score_b = model_confidence_b * human_preferences

    # Keep the preferred action's score; the elementwise maximum also
    # works when the scores are batched tensors
    preference_score = torch.maximum(preference_score_a, preference_score_b)

    # Subtract the baseline (average preference score) to reduce the
    # variance of the policy gradient estimate, which can help stabilize
    # training
    baseline = preference_score.mean()
    return preference_score - baseline

# Dataset selection
if USE_ATARI:  # flag name assumed; pairs with the `elif USE_NLP:` branch below
    import agc.dataset as ds
    import agc.util as util

    # DATA_DIR is the directory, which contains the 'trajectories' and
    # 'screens' folders
    dataset = ds.AtariDataset(DATA_DIR)

    # dataset.trajectories returns the dictionary with all the trajs from
    # the Atari dataset
    all_trajectories = dataset.trajectories

    from IPython.display import display, HTML
    from prettytable import PrettyTable

    titles = [' '] + [util.TITLES[g] for g in util.GAMES]
    table = PrettyTable(titles)
    table.align[''] = "l"
    table.align[''] = "l"

    row = ['episodes']
    for g in util.GAMES:
        ...  # per-game value not preserved in this copy

    row = ['frames']
    for g in util.GAMES:
        row.append(dataset.stats[g]['total_frames'])

    row = ['hours of gameplay']
    for g in util.GAMES:
        hrs = float(dataset.stats[g]['total_frames']//60//60/60)
        row.append('%.2f' % (hrs,))

    row = ['worst score']
    for g in util.GAMES:
        ...  # per-game value not preserved in this copy

    row = ['best score']
    for g in util.GAMES:
        ...  # per-game value not preserved in this copy

    row = ['average score']
    for g in util.GAMES:
        row.append("%.0f" % dataset.stats[g]['avg_score'])

    row = ['score SEM']
    for g in util.GAMES:
        row.append("%.0f" % dataset.stats[g]['sem'])

    row = ['score stddev']
    for g in util.GAMES:
        row.append("%.0f" % dataset.stats[g]['stddev'])

    # We are using the Atari Pong game environment
    import gym

    # print(f"list of all gym environments = {gym.envs.registry.keys()}")
    env = gym.make('Pong-v4')
    state = env.reset()  # Reset the environment to get the initial state

    # Assuming the first element of the tuple is the screen image we want
    screen = state[0] if isinstance(state, tuple) else state

    # Add batch dimension
    state_tensor = torch.tensor(screen, dtype=torch.float32).unsqueeze(0)
    state_size = 33600

elif USE_NLP:
    from datasets import load_dataset
    from transformers import AutoTokenizer

    # Load all the data
    # dataset = load_dataset("stanfordnlp/shp")

    # Load one of the subreddits
    dataset = load_dataset("stanfordnlp/shp", data_dir="explainlikeimfive")

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    def preprocess_shp_entry(entry):
        # Tokenizer arguments are assumed: tensors are needed further down
        # (state['input_ids'].size()) and BERT inputs must fit 512 positions
        tok_args = dict(padding='max_length', truncation=True,
                        max_length=512, return_tensors="pt")
        state = tokenizer(entry['history'], **tok_args)
        state_action_a = tokenizer(entry['history'] + entry['human_ref_A'],
                                   **tok_args)
        state_action_b = tokenizer(entry['history'] + entry['human_ref_B'],
                                   **tok_args)

        # Indicates preference between action_a and action_b
        preference = entry['labels']
        return state, state_action_a, state_action_b, preference

    encoded_inputs_file = ''

    if os.path.exists(encoded_inputs_file):
        print("Loading pre-tokenized data...")
        encoded_inputs = torch.load(encoded_inputs_file)
    else:
        # Process data
        print("Tokenizing data now ...")
        encoded_inputs = [preprocess_shp_entry(entry)
                          for entry in dataset['train']]
        torch.save(encoded_inputs, encoded_inputs_file)
        print("Finished tokenizing data !!!")

    for item in encoded_inputs:
        state, state_action_a, state_action_b, preference = item

    state_tensor = state['input_ids']
    state_size = state_tensor.size()[-1]
    print(f"state has type {type(state)} and length of {len(state)}")
    print(f"state_tensor has shape of {state_tensor.size()}")

action_size = 64

if USE_ATARI:  # flag name assumed; pairs with the `elif USE_NLP:` branch below
    # Initialize current policy π to obtain an action from the current policy
    current_policy = PolicyNetwork(input_size=state_size,
                                   output_size=action_size).to(device)

    # Initialize reference policy μ (could be a previous checkpoint
    # of the current policy)
    reference_policy = PolicyNetwork(input_size=state_size,
                                     output_size=action_size).to(device)
    # reference_policy.load_state_dict(torch.load('path_to_checkpoint'))

elif USE_NLP:
    # Initialize current policy π to obtain an action from the current policy
    current_policy = PolicyNetwork(state_dim=state_size).to(device)

    # Initialize reference policy μ (could be a previous checkpoint
    # of the current policy)
    reference_policy = PolicyNetwork(state_dim=state_size).to(device)
    # reference_policy.load_state_dict(torch.load('path_to_checkpoint'))

# Set reference policy to evaluation mode if it's not being trained
reference_policy.eval()

# Extracting token IDs for state, action_a and action_b
state_ids = state['input_ids'].to(device)
state_action_a_ids = state_action_a['input_ids'].to(device)
state_action_b_ids = state_action_b['input_ids'].to(device)

print("state_ids shape:", state_ids.shape)
print("state_action_a_ids shape:", state_action_a_ids.shape)
print("state_action_b_ids shape:", state_action_b_ids.shape)

# Assuming we have a current policy model, a reference model, and
# a human preference label for this pair
preference_score = preference_model(
    state_ids, state_action_a_ids, state_action_b_ids,
    current_policy, reference_policy, preference)

class SHPDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['preference'])

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.data.items()}
        return item

# Combine into a single dictionary
data = {
    'state': state_ids,
    'state_action_a': state_action_a_ids,
    'state_action_b': state_action_b_ids,
    'preference': torch.tensor(preference).unsqueeze(0)
}

# Split the data into train and validation sets
total_size = len(dataset['train']['labels'])
train_size = int(total_size * 0.8)
print(f"total_size = {total_size}")

train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}

train_dataset = SHPDataset(train_data)
val_dataset = SHPDataset(val_data)

# Create DataLoaders for batch processing;
# train_loader is what the training loop iterates over
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
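To make the batching behaviour concrete, here is a small self-contained sketch of a dictionary-of-tensors dataset feeding a `DataLoader`. The class name `DictDataset`, the tensor shapes and the values are made up for illustration; the point is only that the default collate function stacks each key separately, so every batch is again a dictionary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Wraps a dict of equally long tensors (hypothetical helper)."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['preference'])

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.data.items()}


# Ten toy examples with six token ids each
toy_data = {
    'state': torch.zeros(10, 6, dtype=torch.long),
    'preference': torch.ones(10, dtype=torch.long),
}

loader = DataLoader(DictDataset(toy_data), batch_size=4, shuffle=False)
batch = next(iter(loader))
print({key: tuple(val.shape) for key, val in batch.items()})
```

Each tensor in the dictionary needs the same leading dimension, since `__len__` reads the length from the 'preference' entry alone.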
# Define the loss function and optimizer
# Credit :
# Copyright 2023 Google Research. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PyTorch implementation of the Lion optimizer."""
# import torch
from torch.optim.optimizer import Optimizer


class Lion(Optimizer):
    r"""Implements Lion algorithm."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        """Initialize the hyperparameters.

        Args:
          params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
          lr (float, optional): learning rate (default: 1e-4)
          betas (Tuple[float, float], optional): coefficients used for
            computing running averages of gradient and its square
            (default: (0.9, 0.99))
          weight_decay (float, optional): weight decay coefficient
            (default: 0)
        """
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1]))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
          closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

        Returns:
          the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']
                beta1, beta2 = group['betas']

                # Weight update
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(update.sign_(), alpha=-group['lr'])
                # Decay the momentum running average coefficient
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss

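The update rule inside `step` can be followed on plain tensors. The sketch below is a functional restatement written for illustration, with a hypothetical helper name `lion_update` and toy numbers; it is not the optimizer class itself.

```python
import torch


def lion_update(param, grad, exp_avg, lr=1e-4, betas=(0.9, 0.99),
                weight_decay=0.0):
    """Return (new_param, new_exp_avg) after one Lion-style update."""
    beta1, beta2 = betas
    # Decoupled weight decay
    param = param * (1 - lr * weight_decay)
    # The update direction is only the sign of the interpolated momentum
    update = torch.sign(beta1 * exp_avg + (1 - beta1) * grad)
    param = param - lr * update
    # The momentum buffer is refreshed with the second coefficient
    exp_avg = beta2 * exp_avg + (1 - beta2) * grad
    return param, exp_avg


param = torch.tensor([1.0, -2.0, 3.0])
grad = torch.tensor([0.5, -0.1, 0.0])
new_param, new_exp_avg = lion_update(param, grad, torch.zeros(3), lr=0.1)
print(new_param, new_exp_avg)
```

Because only the sign is applied, every coordinate with a non-zero momentum moves by exactly `lr`, regardless of the gradient's scale.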
class AdamW_on_Lion_Optimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, adam_betas=(0.9, 0.999),
                 lion_betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        # Materialize the iterable so it can be traversed more than once
        params = list(params)
        defaults = dict(lr=lr, adam_betas=adam_betas,
                        lion_betas=lion_betas, eps=eps,
                        weight_decay=weight_decay)
        super(AdamW_on_Lion_Optimizer, self).__init__(params, defaults)

        # Diagnostic code to check parameters
        if not params:
            print("No parameters in params.")
        for idx, param in enumerate(params):
            print(f"Param {idx}: requires_grad={param.requires_grad}")
        if any(param.requires_grad for param in params):
            print("There are parameters that require gradients.")
        else:
            print("No parameters require gradients.")

        # Define the Adam and Lion optimizers
        self.adamW = optim.AdamW(params=params, lr=lr, betas=adam_betas,
                                 eps=eps, weight_decay=weight_decay)
        self.lion = Lion(params=params, lr=lr, betas=lion_betas,
                         weight_decay=weight_decay)

    def step(self, lr=1e-3, max_iter=25, closure=None):
        """Performs a single optimization step.

        Args:
          closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

        Returns:
          the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for i in range(max_iter):
            # Apply the Lion and Adam optimizer
            lion_step = self.lion.step()
            adam_step = self.adamW.step()

            # See [Learning Rate Grafting: Transferability of Optimizer Tuning]
            # (
            # Grafting adamW#lion: update direction from lion, update
            # magnitude from adamW.
            # NOTE: Optimizer.step() returns the closure loss (None here),
            # not the applied update, so these norms need the parameter
            # deltas of each optimizer to be meaningful.
            step = np.linalg.norm(adam_step) * lion_step / np.linalg.norm(lion_step)

        return loss

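The grafting idea, direction from one optimizer and magnitude from another, reduces to a single rescaling once both update tensors are available (for example as the parameter difference before and after each optimizer's step). The sketch below assumes those two update tensors are already in hand; the helper name `graft` and the numbers are hypothetical.

```python
import torch


def graft(magnitude_step, direction_step, eps=1e-12):
    """Rescale direction_step so that its norm matches magnitude_step."""
    scale = magnitude_step.norm() / (direction_step.norm() + eps)
    return scale * direction_step


adamw_step = torch.tensor([0.03, -0.04])  # hypothetical AdamW update
lion_step = torch.tensor([-1.0, 1.0])     # hypothetical Lion (sign) update

grafted = graft(adamw_step, lion_step)
print(grafted, grafted.norm())
```

The grafted update points along the Lion step but has the length of the AdamW step, which matches the adamW#lion description in the comment above.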
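# Grafted optimizer (alternative); arguments assumed to mirror AdamW below
# optimizer_current_policy = AdamW_on_Lion_Optimizer(
#     params=current_policy.parameters(), lr=1e-3)
# optimizer_reference_policy = AdamW_on_Lion_Optimizer(
#     params=reference_policy.parameters(), lr=1e-3)
optimizer_current_policy = optim.AdamW(current_policy.parameters(), lr=1e-3)
optimizer_reference_policy = optim.AdamW(reference_policy.parameters(),
                                         lr=1e-3)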
# Training loop
num_epochs = 25  # Number of epochs to train for

for epoch in tqdm(range(num_epochs)):  # loop over the dataset multiple times
    """
    train_loader is an iterable of states that we're training on.
    action_space is the set of all possible actions.
    human_preferences is a dictionary or function that provides
    human preference scores.
    learning_rate is the learning rate η.
    """

    # Assuming these are defined:
    # current_policy: current policy network
    # reference_policy: reference policy network
    # model: the model being trained
    # state: current state from the environment

    total_loss = 0  # this variable is only valid for one epoch

    for batch in train_loader:
        state = batch['state'].clone().to(device)
        state_action_a = batch['state_action_a'].clone().to(device)
        state_action_b = batch['state_action_b'].clone().to(device)
        human_preferences = batch['preference'].clone().to(device)

        # Initialize a dictionary to store updated policies for each action
        updated_policies = {}
        state_action_space = [state_action_a, state_action_b]

        # Calculate preference score and perform Nash-MD update
        for state_action in state_action_space:
            # Get the action probabilities from the current policy
            current_policy_prob = current_policy(state_action).detach()
            reference_policy_prob = reference_policy(state_action).detach()
            # print(f"current_policy_prob = {current_policy_prob}")
            # print(f"reference_policy_prob = {reference_policy_prob}")

            # Calculate the preference score
            preference_score = preference_model(
                state, state_action_a, state_action_b,
                current_policy, reference_policy, human_preferences)

            # Perform Nash-MD update
            updated_policy_prob = \
                current_policy_prob**(1 - eta) * \
                reference_policy_prob**eta * \
                torch.exp(eta * preference_score)

            # Store the updated policy probability for the action
            updated_policies[state_action] = updated_policy_prob

        """
        See section 7 of the paper.

        In equation (5), the normalization constant c is important.
        It ensures that after updating the policy using the Nash-MD algorithm,
        the updated policy π_t+1 is still a valid probability distribution.
        The constant c is determined after the update so that the
        sum of probabilities across all possible actions y equals 1.
        """
        # Normalize the updated policies
        normalization_constant = sum(updated_policies.values())
        updated_policies_normalized = \
            {action: prob / normalization_constant
             for action, prob in updated_policies.items()}

        """
        Theorem 1 in the paper concerns the convergence of the Nash-MD
        algorithm. Given a Nash equilibrium π* of the regularized
        preference model, it bounds the KL divergence between π* and the
        policy obtained at each iteration of Nash-MD (π_t+1), and that
        bound decays at a rate on the order of log(T)/T, where T is the
        number of iterations.

        The convergence rate depends on the choice of the learning rate η,
        which is suggested to be set as log(T)/T. This rate is significant
        because it dictates how quickly the policies converge to the Nash
        equilibrium in terms of KL divergence, a measure of the difference
        between two probability distributions.

        The theorem is crucial for understanding the Nash-MD loss function
        because it provides the theoretical foundation that guarantees the
        algorithm's policies will converge to a Nash equilibrium. The loss
        function used in the Nash-MD algorithm is designed to both maximize
        alignment with human preferences (as captured by the preference model)
        and minimize the KL divergence from the previous policy, which ensures
        that the updated policy does not diverge too rapidly from the previous
        policy.

        This careful balance allows the sequence of policies to improve
        steadily while maintaining a trajectory toward the Nash equilibrium.

        In equation (4), the KL divergence serves as a regularization term
        within the arg max operation, but once we solve the optimization
        problem for π_t+1, its effects are embedded in the form of the
        solution in equation (5) and don't need to be listed separately.
        """
        # Calculate the KL divergence part of the Nash-MD objective
        KL_divergence = torch.distributions.kl_divergence(

        # Calculate the loss for backpropagation
        loss = -torch.sum(torch.stack(
            for state_action in state_action_space]))
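        # Perform backpropagation
        optimizer_current_policy.zero_grad()
        optimizer_reference_policy.zero_grad()
        loss.backward()

        # Clip gradients: gradients are modified in place
        max_grad_norm = 10.0
        for model in [current_policy, reference_policy]:
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            for name, param in model.named_parameters():
                if 'out_proj.bias' not in name:
                    # clip weights but not bias for out_proj
                    torch.nn.utils.clip_grad_norm_(param, max_grad_norm)
        optimizer_current_policy.step()
        optimizer_reference_policy.step()
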
        if debugging_is_on:
            print("DEBUGGING IS ON !!!")
            print_tensor_info("normalization_constant", normalization_constant)
            for model in [current_policy, reference_policy]:
                for name, parameter in model.named_parameters():
                    if parameter.grad is not None:
                        print(f"{name} gradient: "
                              f"{parameter.grad.norm().item()}")
                    else:
                        print(f"{name} has no gradient")

        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)
    if not train_loss >= 0:
        # True for a negative loss and also for NaN
        print("negative or NaN training loss !!!")
        debugging_is_on = 1

    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}')
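The Nash-MD update and its normalization constant c, described in the training loop above, can be tried on a small discrete action set where everything is a plain list. This is a minimal sketch under stated assumptions: the preference matrix `pref`, the helper name `nash_md_step` and the uniform starting policies are invented for illustration, and each action is scored against the geometric mixture of the current and reference policies, which is how the update is read here.

```python
import math


def nash_md_step(pi, mu, pref, eta):
    """One tabular Nash-MD style step over a finite action set.

    pref[i][j] is the assumed probability that action i is preferred to j.
    """
    n = len(pi)
    # Geometric mixture of the current policy and the reference policy
    mix = [pi[i] ** (1 - eta) * mu[i] ** eta for i in range(n)]
    mix_total = sum(mix)
    mix = [m / mix_total for m in mix]
    # Expected preference of each action against the mixture policy
    score = [sum(pref[i][j] * mix[j] for j in range(n)) for i in range(n)]
    # Exponentiated update followed by normalization with the constant c
    unnormalized = [mix[i] * math.exp(eta * score[i]) for i in range(n)]
    c = sum(unnormalized)
    return [u / c for u in unnormalized]


# Toy preferences: action 0 tends to beat action 1, which tends to beat 2
pref = [[0.5, 0.8, 0.9],
        [0.2, 0.5, 0.7],
        [0.1, 0.3, 0.5]]
mu = [1 / 3, 1 / 3, 1 / 3]
pi = list(mu)

for _ in range(50):
    pi = nash_md_step(pi, mu, pref, eta=0.5)

print([round(p, 4) for p in pi])
```

The reference policy keeps pulling the iterate back toward uniform, so the result favours action 0 without collapsing onto it.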