Created
February 6, 2024 16:23
-
-
Save vwxyzjn/277edb651b3ae581129474ab4a4b47b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import random | |
import time | |
from collections import defaultdict | |
from dataclasses import asdict, dataclass, field | |
from types import SimpleNamespace | |
from typing import List, Literal, Optional | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import tyro | |
from accelerate import Accelerator | |
from accelerate.utils import gather_object, broadcast | |
from datasets import load_dataset, Dataset | |
from rich.console import Console | |
from rich.pretty import pprint | |
from rich.table import Table | |
from torch import optim | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm import tqdm | |
from transformers import ( | |
AutoConfig, | |
AutoModel, | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
GenerationConfig, | |
PretrainedConfig, | |
PreTrainedModel, | |
get_scheduler, | |
) | |
from huggingface_hub import HfApi | |
api = HfApi() | |
@dataclass
class Args:
    """Command-line configuration for on-policy DPO training on TL;DR summarization.

    Parsed via `tyro.cli` in `parse_args`; the `Optional` batch-size and hub
    fields (`world_size`, `micro_batch_size`, `batch_size`, `hf_repo_*`, ...)
    are derived there rather than supplied directly by the user.
    """

    # common args
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    cuda: bool = True
    """Whether to use cuda if available."""
    run_name: Optional[str] = None
    """a unique name of this run"""
    load_from_cache_file: bool = False
    """Whether to load data from the local cache file in `dataset.map`"""
    deepspeed: bool = False
    """Whether to use deepspeed to train the model"""
    print_sample_output_freq: int = 220
    """How often to print sample output"""
    run_eval: bool = False
    """Whether to run evaluation"""

    # optimizer args
    eps: float = 1e-5
    """the epsilon value for the optimizer"""
    lr: float = 3e-6
    """the learning rate"""
    optimizer: Literal["adam", "adamw"] = "adamw"
    """Which optimizer to use"""
    scheduler: str = "cosine"
    """Which scheduler to use"""
    warm_up_steps: int = 0
    """Number of warm up steps for the scheduler"""

    # various batch sizes
    world_size: Optional[int] = None
    """The number of processes (GPUs) to use"""
    num_train_epochs: int = 1
    """Number of epochs to train"""
    num_updates: Optional[int] = None
    """The number of updates to train"""
    gradient_accumulation_steps: int = 64
    """The number of gradient accumulation steps"""
    local_micro_batch_size: Optional[int] = 1
    """The micro batch size per GPU (HF's `per_device_train_batch_size`)"""
    total_episodes: Optional[int] = 1000000
    """The total number of episodes in the dataset"""
    micro_batch_size: Optional[int] = None
    """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)"""
    local_batch_size: Optional[int] = None
    """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)"""
    batch_size: Optional[int] = None
    """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)"""
    local_eval_batch_size: int = 1
    """per rank eval batch size"""

    # other args
    base_model: str = "EleutherAI/pythia-160m"
    """the name of the pretrained model to use"""
    query_dataset: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144"
    """the query dataset"""
    response_length: int = 53
    """the length of the response"""
    truncate_token: Literal["eos"] = "eos"
    """the truncate token"""
    truncate_token_id: Optional[int] = None
    """the truncation token id"""
    temperature: float = 0.7
    """the sampling temperature"""
    ipo: bool = False
    """Whether to use IPO loss https://arxiv.org/abs/2310.12036"""
    label_smoothing: float = 0.0
    """Label smoothing for DPO (Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf))"""
    beta: float = 0.05
    """The beta value for DPO"""
    penalty_reward_value: int = -1
    """the reward value for responses that do not contain `truncate_token_id`"""
    reward_model_path: str = ""
    """the path to the reward model"""
    sft_model_path: str = "EleutherAI/pythia-160m"
    """the path to the sft model"""
    local_rollout_forward_batch_size: int = 4
    """per rank no grad forward pass in the rollout phase"""

    # wandb and HF tracking configs
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "tldr_summarize"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    push_to_hub: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: Optional[str] = None
    """the user or org name of the model repository from the Hugging Face Hub"""
    hf_repo_id: Optional[str] = None
    """the id of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_revision: Optional[str] = None
    """the revision of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_url: Optional[str] = None
    """the url of the saved model in the Hugging Face Hub (will be autoset)"""
    hf_dataset_repo_url: Optional[str] = None
    """the url of the dataset in the Hugging Face Hub"""
    output_dir: str = "models/dpo_onpolicy_model"
    """Where to save the model"""
def parse_args() -> tuple[Args, Accelerator]:
    """Parse CLI arguments with tyro and derive distributed/batch-size settings.

    Returns the fully populated `Args` together with the `Accelerator` that
    will drive the run. Side effects: may call the HF Hub (`api.whoami()`)
    when `--push-to-hub` is set without an explicit entity.
    """
    args = tyro.cli(Args)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    # derive effective batch sizes from the per-device micro batch size
    args.world_size = accelerator.num_processes
    args.local_batch_size = args.local_micro_batch_size * args.gradient_accumulation_steps
    args.micro_batch_size = int(args.local_micro_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    # broadcast one timestamp from rank 0 so every process derives the same run_name
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    if args.push_to_hub:
        if args.hf_repo_id is None:  # auto-generate one
            args.hf_repo_id = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
        if args.hf_entity is None:  # find the current user
            args.hf_entity = api.whoami()["name"]
        if "/" not in args.hf_repo_id:  # prepend the current user
            args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
        if args.hf_repo_revision is None:  # auto-generate one
            args.hf_repo_revision = args.run_name
        args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
        args.hf_dataset_repo_url = f"https://huggingface.co/datasets/{args.hf_repo_id}/viewer/default/{args.hf_repo_revision}"
    return args, accelerator
# taken from https://github.com/vwxyzjn/direct-preference-optimization/blob/f8b8c0f49dc92a430bae41585f9d467d3618fe2f/utils.py#L99 | |
# taken from https://github.com/vwxyzjn/direct-preference-optimization/blob/f8b8c0f49dc92a430bae41585f9d467d3618fe2f/utils.py#L99
def disable_dropout(model: torch.nn.Module):
    """Set the drop probability of every `nn.Dropout` submodule of `model` to zero."""
    dropout_modules = (m for m in model.modules() if isinstance(m, torch.nn.Dropout))
    for dropout in dropout_modules:
        dropout.p = 0
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """In-place init: weights ~ Normal(0, std), bias set to `bias_const`.

    Returns the same layer to allow inline use at construction sites.
    """
    nn.init.normal_(layer.weight, std=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer
class ScalarModelConfig(PretrainedConfig):
    """HF config for `ScalarModel`: a pretrained LM backbone plus a scalar head."""

    def __init__(
        self,
        base_model: str = "EleutherAI/pythia-160m",
        # NOTE(review): this default is evaluated once at module import time and
        # may hit the network / local HF cache even when the class is unused —
        # consider defaulting to None and loading lazily; confirm before changing.
        base_config: PretrainedConfig = AutoConfig.from_pretrained("EleutherAI/pythia-160m"),
        hidden_size: int = 768,
        bias: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model = base_model  # hub id of the backbone model
        self.base_config = base_config  # config object for the backbone
        self.hidden_size = hidden_size  # input width of the scalar head
        self.bias = bias  # constant subtracted from every reward output
class ScalarModel(PreTrainedModel):
    """A causal-LM backbone topped with a 1-unit linear head producing per-token scalar rewards."""

    config_class = ScalarModelConfig

    def __init__(self, config: ScalarModelConfig):
        super().__init__(config)
        self.config = config
        # backbone without an LM head; hidden states feed the scalar head
        self.lm_backbone = AutoModel.from_pretrained(
            config.base_model,
            config=self.config.base_config,
            trust_remote_code=True,
        )
        # small-std init (1/sqrt(hidden+1)) for the reward head
        self.scalar_head = layer_init(
            nn.Linear(self.config.hidden_size, 1),
            std=1 / np.sqrt(self.config.hidden_size + 1),
        )

    def forward(self, **kwargs):
        """Return per-token rewards of shape (..., seq, 1), minus the config bias.

        Reads `output.hidden_states[-1]`, so callers must pass
        `output_hidden_states=True` (see `get_reward`).
        """
        output = self.lm_backbone(**kwargs)
        reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias
        return reward
def get_reward(model, query_responses, tokenizer, context_length):
    """Score padded query+response sequences with the scalar reward model.

    Returns a triple: per-token reward logits, the scalar reward taken at the
    last non-pad response token of each sequence, and that token's index.
    """
    attention_mask = query_responses != tokenizer.pad_token_id
    input_ids = query_responses.masked_fill(~attention_mask, 0)
    reward_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=True,
        output_hidden_states=True,
    )
    # last non-pad position: first pad in the response segment, minus one
    response_pads = query_responses[:, context_length:] == tokenizer.pad_token_id
    sequence_lengths = first_true_indices(response_pads) - 1 + context_length
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    batch_indices = torch.arange(reward_logits.size(0), device=reward_logits.device)
    final_scores = reward_logits[batch_indices, sequence_lengths].squeeze(-1)
    return reward_logits, final_scores, sequence_lengths
def forward(model, query_responses, labels, tokenizer):
    """Summed per-sequence log-probs of `labels` under `model` for a chosen/rejected batch.

    `query_responses` is the concatenation [chosen; rejected]; returns
    (chosen_logps, rejected_logps), each summing token log-probs over
    non-pad label positions.
    """
    attention_mask = query_responses != tokenizer.pad_token_id
    input_ids = query_responses.masked_fill(~attention_mask, 0)
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=True,
    )
    # shift so that logits at position t score the token at position t+1
    shifted_labels = labels[:, 1:].clone()
    shifted_logits = output.logits[:, :-1, :]
    loss_mask = shifted_labels != tokenizer.pad_token_id
    token_logps = torch.gather(
        shifted_logits.log_softmax(-1), dim=2, index=shifted_labels.unsqueeze(2)
    ).squeeze(2)
    sequence_logps = (token_logps * loss_mask).sum(-1)
    half = query_responses.shape[0] // 2
    return sequence_logps[:half], sequence_logps[half:]
def generate(lm_backbone, queries, tokenizer, generation_config):
    """Generate continuations for `queries` in a way that does not affect padding tokens.

    Pad positions are zeroed in `input_ids` and excluded via the attention
    mask; the result is the ORIGINAL queries (pads intact) concatenated with
    the newly generated tokens.
    """
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    masked_queries = queries.masked_fill(~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=masked_queries,
        attention_mask=attention_mask,
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    new_tokens = output.sequences[:, context_length:]
    return torch.cat((queries, new_tokens), dim=1)
def first_true_indices(bools, dtype=torch.long):
    """
    Position of the first True along the last dimension of a bool tensor.

    Rows with no True element yield `bools.size(-1)` (one past the end).
    """
    row_len = bools.size(-1)
    positions = torch.arange(row_len, dtype=dtype, device=bools.device)
    # false entries are pushed past the end so min() picks the first True
    penalized = positions + row_len * (~bools).type(dtype)
    return penalized.min(dim=-1).values
def truncate_response(args, tokenizer, responses):
    """Pad out everything after the first truncate token in each response.

    The truncate token itself is kept; all later positions become
    `tokenizer.pad_token_id`.
    """
    trunc_idxs = first_true_indices(responses == args.truncate_token_id).unsqueeze(-1)
    broadcast_shape = [1] * (responses.dim() - 1) + [args.response_length]
    positions = torch.arange(args.response_length, device=responses.device).view(*broadcast_shape)
    return responses.masked_fill(positions > trunc_idxs, tokenizer.pad_token_id)
def evaluate_rm(args: Args, accelerator, tokenizer, model, ref_model, dataloader):
    """Evaluate implicit DPO reward accuracy on a preference dataset.

    For each batch the chosen/rejected halves are scored under the policy and
    reference models; a pair counts as correct when the implicit chosen reward
    exceeds the rejected one. Returns a per-example DataFrame of decoded
    inputs, metadata, and accuracy. Restores `model.train()` on exit.
    """
    model.eval()
    with torch.no_grad():
        items = defaultdict(list)
        for data in tqdm(dataloader):
            # stack [chosen; rejected] so one forward pass scores both halves
            query_responses = torch.cat((data["query_chosen_token"], data["query_rejected_token"]), dim=0)
            labels = torch.cat((data["query_chosen_token_response_label"], data["query_rejected_token_response_label"]), dim=0)
            ref_chosen_logps, ref_rejected_logps = forward(ref_model, query_responses, labels, tokenizer)
            chosen_logps, rejected_logps = forward(model, query_responses, labels, tokenizer)
            # implicit DPO reward: beta * (policy logp - reference logp)
            reward_preferred = args.beta * (chosen_logps - ref_chosen_logps)
            reward_rejected = args.beta * (rejected_logps - ref_rejected_logps)
            accuracy = reward_preferred > reward_rejected
            print(accuracy.float().mean())
            for k in data:
                data[k] = gather_object(data[k])
            # NOTE(review): `accuracy` stays rank-local while `data` was gathered
            # across ranks above — in multi-process runs only the first
            # len(accuracy) gathered rows are recorded; confirm intended.
            for i in range(len(accuracy)):
                items["query"].append(tokenizer.decode(data["query_token"][i], skip_special_tokens=True))
                items["response0"].append(tokenizer.decode(data["response0_token"][i]))
                items["response1"].append(tokenizer.decode(data["response1_token"][i]))
                items["batch"].append(data["batch"][i])
                items["split"].append(data["split"][i])
                items["confidence"].append(data["extra.confidence"][i].item())
                items["choice"].append(data["choice"][i].item())
                items["policies"].append(data["policies"][i])
                items["response0_policy"].append(data["response0_policy"][i])
                items["response1_policy"].append(data["response1_policy"][i])
                items["accuracy"].append(accuracy[i].item())
    model.train()
    return pd.DataFrame(items)
@dataclass
class EvalStorage:
    """Accumulator for generation results produced during policy evaluation.

    The `*_token` fields collect token tensors batch by batch; the plain-string
    fields are filled at the end of `evaluate_policy` by batch-decoding them.
    `score` / `reference_score` are currently unused (scoring is commented out
    in `evaluate_policy`).
    """

    query_token: List[str] = field(default_factory=list)
    postprocessed_response_token: List[str] = field(default_factory=list)
    reference_response_token: List[str] = field(default_factory=list)
    score: List[float] = field(default_factory=list)
    reference_score: List[float] = field(default_factory=list)
    query: List[str] = field(default_factory=list)
    postprocessed_response: List[str] = field(default_factory=list)
    reference_response: List[str] = field(default_factory=list)
def evaluate_policy(args: Args, model, tokenizer, dataloader, generation_config, sampling=True):
    """Generate responses for the queries in `dataloader` and collect them for inspection.

    When `sampling` is True, only the first batch is processed (quick sample);
    otherwise the full dataloader is consumed. Returns the raw `EvalStorage`
    and a DataFrame of decoded queries/responses gathered across processes.
    """
    eval_storage = EvalStorage()
    with torch.no_grad():
        for data in tqdm(dataloader):
            queries = data["query_token"]
            reference_response_token = data["reference_response_token"]
            context_length = queries.shape[1]
            query_responses = generate(
                model,
                queries,
                tokenizer,
                generation_config,
            )
            responses = query_responses[:, context_length:]
            # cut each response at the first truncate token; the rest becomes pad
            postprocessed_responses = truncate_response(args, tokenizer, responses)
            eval_storage.query_token.extend(queries)
            eval_storage.reference_response_token.extend(reference_response_token)
            eval_storage.postprocessed_response_token.extend(postprocessed_responses)
            if sampling:
                break
    # decode everything once at the end
    eval_storage.query = tokenizer.batch_decode(eval_storage.query_token, skip_special_tokens=True)
    eval_storage.reference_response = tokenizer.batch_decode(eval_storage.reference_response_token)
    eval_storage.postprocessed_response = tokenizer.batch_decode(
        eval_storage.postprocessed_response_token, skip_special_tokens=True
    )
    eval_df = pd.DataFrame(
        {
            "query": gather_object(eval_storage.query),
            "postprocessed_response": gather_object(eval_storage.postprocessed_response),
            "reference_responses": gather_object(eval_storage.reference_response),
        }
    )
    return eval_storage, eval_df
if __name__ == "__main__":
    args, accelerator = parse_args()
    # distinct-but-deterministic per-rank seed (100003 is prime)
    local_seed = args.seed + accelerator.process_index * 100003  # Prime
    # load dataset
    dataset = load_dataset(args.query_dataset, split="train")
    dataset = dataset.with_format("torch", columns=["query_token", "reference_response_token"])
    dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size, shuffle=True)
    validation_dataset = load_dataset(args.query_dataset, split="validation")
    eval_dataloaders = {}
    for split in ["validation", "test"]:
        eval_dataset = load_dataset(args.query_dataset, split=split)
        eval_dataset = eval_dataset.with_format("torch", columns=["query_token", "reference_response_token"])
        eval_dataloaders[split] = DataLoader(eval_dataset, batch_size=args.local_eval_batch_size)
    args.num_updates = args.total_episodes // args.batch_size
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        padding_side="right",
        trust_remote_code=True,
    )
    # we use the padding token manually but do not resize the token embedding of the model
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if args.truncate_token == "eos":
        args.truncate_token_id = tokenizer.eos_token_id
    console = Console(force_terminal=True)
    writer = SimpleNamespace()  # dummy writer for non-main processes
    writer.add_scalar = lambda x, y, z: None
    if accelerator.is_main_process:
        if args.track:
            import wandb

            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=asdict(args),
                name=args.run_name,
                save_code=True,
            )
            file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"]
            wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions]))
        writer = SummaryWriter(f"runs/{args.run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
        pprint(args)
    device = accelerator.device
    random.seed(local_seed)
    np.random.seed(local_seed)
    torch.manual_seed(local_seed)
    torch.backends.cudnn.deterministic = True
    model_config = AutoConfig.from_pretrained(args.sft_model_path)
    scalar_model_config = ScalarModelConfig(
        base_model=args.base_model,
        base_config=model_config,
        hidden_size=model_config.hidden_size,
    )
    if not args.reward_model_path:
        # no trained reward model given: fresh scalar head (smoke-test only)
        reward_model: PreTrainedModel = ScalarModel(scalar_model_config)
    else:
        reward_model: PreTrainedModel = ScalarModel.from_pretrained(
            args.reward_model_path,
            trust_remote_code=True,
        )
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        args.sft_model_path,
        config=model_config,
        trust_remote_code=True,
    )
    ref_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        args.sft_model_path,
        config=model_config,
        trust_remote_code=True,
    )
    # NOTE(review): dropout is disabled only on the policy; DPO reference
    # implementations typically also disable it on ref_model — confirm intent.
    disable_dropout(model)
    model.generation_config.eos_token_id = None  # disable `pad_token_id` and `eos_token_id` because we just want to
    model.generation_config.pad_token_id = None  # generate tokens without truncation / padding
    if args.optimizer == "adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=args.eps)
    elif args.optimizer == "adamw":
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps)
    scheduler = get_scheduler(
        args.scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.warm_up_steps,
        num_training_steps=args.num_updates * args.num_train_epochs,
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()}
    torch.manual_seed(local_seed)  # reset the local seed again
    ref_model = ref_model.to(device)
    reward_model = reward_model.to(device)
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        min_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )
    # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27
    validation_generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        min_new_tokens=args.response_length,
        temperature=(0.01 + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )
    accelerator.print("===training model===")
    # ring buffers of per-micro-batch stats, averaged at each logging step
    losses = torch.zeros((args.gradient_accumulation_steps,), device=device)
    accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device)
    reward_preferreds = torch.zeros((args.gradient_accumulation_steps,), device=device)
    reward_rejecteds = torch.zeros((args.gradient_accumulation_steps,), device=device)
    reward_margins = torch.zeros((args.gradient_accumulation_steps,), device=device)
    model.train()
    gradient_accumulation_idx = 0
    global_step = 0
    update = 0
    for epoch in range(args.num_train_epochs):
        accelerator.print(f"epoch: {epoch}")
        # generate samples for the SFT dataset and label them using the reward model as the preference
        annotated_data = defaultdict(list)
        N_GENERATED_CANDIDATES = 2
        with torch.no_grad():
            queries = []
            for data in tqdm(dataloader):
                query = data["query_token"].to(device)
                context_length = query.shape[1]
                queries.append(query)
                # flush accumulated queries in groups of local_rollout_forward_batch_size
                # NOTE(review): a trailing group smaller than the batch size is
                # silently dropped at the end of the epoch — confirm acceptable.
                if len(queries) % args.local_rollout_forward_batch_size == 0:
                    query = torch.cat(queries, dim=0)
                    queries = []
                    # tile each query so we sample N_GENERATED_CANDIDATES responses per query
                    query_tiled = query.unsqueeze(1).repeat(1, N_GENERATED_CANDIDATES, 1).flatten(0, 1)
                    query_response = generate(
                        accelerator.unwrap_model(model),
                        query_tiled,
                        tokenizer,
                        generation_config,
                    )
                    response = query_response[:, context_length:]
                    postprocessed_response = truncate_response(args, tokenizer, response)
                    postprocessed_query_response = torch.cat((query_tiled, postprocessed_response), 1)
                    # label the query-response samples using the reward model
                    _, score, _ = get_reward(reward_model, postprocessed_query_response, tokenizer, context_length)
                    # responses that never emitted the truncate token (hence contain no
                    # pad after truncation) get the penalty reward instead
                    contain_pad_token = torch.any(postprocessed_response == tokenizer.pad_token_id, dim=-1)
                    score = torch.where(contain_pad_token, score, torch.full_like(score, args.penalty_reward_value))
                    # labels: query positions are pad (masked out by `forward`), response kept
                    postprocessed_query_response_label = torch.cat((torch.zeros_like(query_tiled, device=device) + tokenizer.pad_token_id, postprocessed_response), 1)
                    # regroup as (query, candidate, ...) to pick best/worst candidate per query
                    score = score.view(-1, N_GENERATED_CANDIDATES)
                    postprocessed_query_response = postprocessed_query_response.view(-1, N_GENERATED_CANDIDATES, postprocessed_query_response.shape[1])
                    postprocessed_query_response_label = postprocessed_query_response_label.view(-1, N_GENERATED_CANDIDATES, postprocessed_query_response_label.shape[1])
                    postprocessed_response = postprocessed_response.view(-1, N_GENERATED_CANDIDATES, postprocessed_response.shape[1])
                    # get chosen and rejected token
                    chosen = gather_object(tokenizer.batch_decode(postprocessed_query_response[torch.arange(args.local_rollout_forward_batch_size), score.argmax(-1)]))
                    rejected = gather_object(tokenizer.batch_decode(postprocessed_query_response[torch.arange(args.local_rollout_forward_batch_size), score.argmin(-1)]))
                    chosen_token = accelerator.gather(postprocessed_query_response[torch.arange(args.local_rollout_forward_batch_size), score.argmax(-1)])
                    rejected_token = accelerator.gather(postprocessed_query_response[torch.arange(args.local_rollout_forward_batch_size), score.argmin(-1)])
                    chosen_token_label = accelerator.gather(postprocessed_query_response_label[torch.arange(args.local_rollout_forward_batch_size), score.argmax(-1)])
                    rejected_token_label = accelerator.gather(postprocessed_query_response_label[torch.arange(args.local_rollout_forward_batch_size), score.argmin(-1)])
                    annotated_data["chosen"].extend(chosen)
                    annotated_data["rejected"].extend(rejected)
                    annotated_data["chosen_token"].extend(chosen_token.cpu().tolist())
                    annotated_data["rejected_token"].extend(rejected_token.cpu().tolist())
                    annotated_data["chosen_token_label"].extend(chosen_token_label.cpu().tolist())
                    annotated_data["rejected_token_label"].extend(rejected_token_label.cpu().tolist())
        ds = Dataset.from_dict(annotated_data)
        if accelerator.is_main_process and args.push_to_hub:
            ds.push_to_hub(args.hf_repo_id, split=args.hf_repo_revision)
            accelerator.print(f"🔥 pushed to {args.hf_dataset_repo_url}")
        ds = ds.with_format("torch", columns=["chosen_token", "rejected_token", "chosen_token_label", "rejected_token_label"])
        on_policy_dataloader = DataLoader(ds, batch_size=args.local_micro_batch_size, shuffle=True)
        on_policy_dataloader = accelerator.prepare(on_policy_dataloader)
        # ---- DPO training on the freshly annotated on-policy data ----
        for data in on_policy_dataloader:
            update += 1
            global_step += args.micro_batch_size
            query_responses = torch.cat((data["chosen_token"], data["rejected_token"]), dim=0)
            labels = torch.cat((data["chosen_token_label"], data["rejected_token_label"]), dim=0)
            with torch.no_grad():
                # reference log-probs are frozen targets — no gradients
                ref_chosen_logps, ref_rejected_logps = forward(ref_model, query_responses, labels, tokenizer)
            with accelerator.accumulate(model):
                chosen_logps, rejected_logps = forward(model, query_responses, labels, tokenizer)
                pi_logratios = chosen_logps - rejected_logps
                ref_logratios = ref_chosen_logps - ref_rejected_logps
                logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}
                if args.ipo:
                    loss = (logits - 1/(2 * args.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
                else:
                    loss = -F.logsigmoid(args.beta * logits) * (1 - args.label_smoothing) - F.logsigmoid(-args.beta * logits) * args.label_smoothing
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                # implicit DPO rewards for bookkeeping
                reward_preferred = args.beta * (chosen_logps - ref_chosen_logps)
                reward_rejected = args.beta * (rejected_logps - ref_rejected_logps)
                losses[gradient_accumulation_idx] = loss
                accuracies[gradient_accumulation_idx] = (reward_preferred > reward_rejected).float().mean()
                reward_preferreds[gradient_accumulation_idx] = reward_preferred.mean()
                reward_rejecteds[gradient_accumulation_idx] = reward_rejected.mean()
                reward_margins[gradient_accumulation_idx] = (reward_preferred - reward_rejected).mean()
                gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps
            # log (and step the scheduler) once per full accumulation cycle
            if update > 1 and (update - 1) % args.gradient_accumulation_steps == 0:
                scheduler.step()
                train_accuracy = accelerator.gather(accuracies).mean().item()
                writer.add_scalar("train/rm/loss", accelerator.gather(losses).mean().item(), global_step)
                writer.add_scalar("train/rm/accuracy", train_accuracy, global_step)
                writer.add_scalar(
                    "train/rm/reward_preferred", accelerator.gather(reward_preferreds).mean().item(), global_step
                )
                writer.add_scalar("train/rm/reward_rejected", accelerator.gather(reward_rejecteds).mean().item(), global_step)
                writer.add_scalar("train/rm/lr", scheduler.get_last_lr()[0], global_step)
                accelerator.print(
                    f"{train_accuracy=}, {scheduler.get_last_lr()=}, {optimizer.param_groups[0]['lr']=}, {update=}"
                )
    if args.run_eval:
        # final low-temperature generation pass over each eval split
        for eval_split in eval_dataloaders:
            _, evaluate_df = evaluate_policy(
                args,
                accelerator.unwrap_model(model),
                tokenizer,
                eval_dataloaders[eval_split],
                validation_generation_config,
                sampling=False,
            )
            if accelerator.is_main_process:
                evaluate_df.to_csv(f"runs/{args.run_name}/{eval_split}_table.csv")
                if args.track:
                    wandb.log({f"eval/{eval_split}_query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
    # save model
    if args.output_dir and args.num_train_epochs > 0:
        # NOTE(review): this creates only dirname(output_dir) ("models"), not the
        # leaf directory; save_pretrained creates the rest — verify on a fresh run.
        os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
        time_tensor = torch.tensor([int(time.time())], device=device)
        time_int = accelerator.gather(time_tensor)[0].item()  # avoid different timestamps across processes
        repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
        repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                tokenizer.push_to_hub(repo_id=args.hf_repo_id, revision=args.hf_repo_revision)
        unwrapped: PreTrainedModel = accelerator.unwrap_model(model)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped.save_pretrained(
                args.output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
                state_dict=accelerator.get_state_dict(model),
                safe_serialization=False,
            )
            if args.push_to_hub:
                unwrapped.push_to_hub(repo_id=args.hf_repo_id, revision=args.hf_repo_revision, safe_serialization=False)
                accelerator.print(f"🔥 pushed to {args.hf_repo_url}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment