@lapp0
Created April 17, 2024 20:15
Online AI Feedback T5
from datasets import Dataset, load_from_disk
from transformers import TrainingArguments
from transformers.trainer_utils import EvalLoopOutput
from unsloth import FastLanguageModel
import random
from huggingface_hub import create_repo
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
import statistics
from typing import Dict, Union, Any
import torch
from torch.utils.data import DataLoader
import trl


class DynamicDataLoader:
    """Wraps a dataloader and applies `mutate_fn` to every batch it yields."""

    def __init__(self, base_dataloader, mutate_fn):
        self.base_dataloader = base_dataloader
        self.mutate_fn = mutate_fn

    def __iter__(self):
        for batch in self.base_dataloader:
            yield self.mutate_fn(batch)

    def __len__(self):
        return len(self.base_dataloader)
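
# Illustrative usage sketch (not part of the original gist): DynamicDataLoader applies
# `mutate_fn` to each batch lazily, which is what lets OAIFTrainer regenerate and
# re-annotate preference pairs on every pass instead of training on a static dataset.
#
#     base = DataLoader(list(range(8)), batch_size=4)
#     loader = DynamicDataLoader(base, lambda batch: batch * 2)
#     for batch in loader:
#         ...  # each batch tensor is doubled before it reaches the training loop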


class eval_mode:
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self.was_training = self.model.training
        if self.was_training:
            FastLanguageModel.for_inference(self.model)
            self.model.eval()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.was_training:
            FastLanguageModel.for_training(self.model)
            self.model.train()
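
# Illustrative sketch (not in the original gist): eval_mode temporarily puts an unsloth
# model into its fast-inference configuration for generation, then restores training mode:
#
#     with eval_mode(peft_model):                  # peft_model: any FastLanguageModel PEFT model
#         out = peft_model.generate(**inputs, max_new_tokens=16)
#     # back here: FastLanguageModel.for_training() and model.train() have been re-applied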


class OAIFTrainer(trl.DPOTrainer):
    def __init__(
        self,
        *args,
        oaif_annotator,
        train_dataset,
        eval_oaif_annotator=None,
        eval_dataset=None,
        **kwargs
    ):
        # insert mock chosen / rejected
        train_dataset = self._pre_patch_dataset(train_dataset)
        if eval_dataset is not None:
            eval_dataset = self._pre_patch_dataset(eval_dataset)
        super().__init__(*args, train_dataset=train_dataset, eval_dataset=eval_dataset, **kwargs)
        self.oaif_annotator = oaif_annotator
        self.eval_oaif_annotator = eval_oaif_annotator if eval_oaif_annotator else oaif_annotator

    @staticmethod
    def _pre_patch_dataset(ds):
        # DPOTrainer expects "chosen" / "rejected" columns; insert empty placeholders
        # (the real pairs are generated and annotated online at iteration time)
        if "chosen" in ds.column_names and "rejected" in ds.column_names:
            return ds

        def add_columns(example):
            example['chosen'] = ''
            example['rejected'] = ''
            return example

        return ds.map(add_columns)

    @staticmethod
    def _post_patch_dataset(ds):
        nullified_keys = [
            "chosen", "chosen_input_ids", "chosen_attention_mask", "chosen_labels",
            "rejected", "rejected_input_ids", "rejected_attention_mask", "rejected_labels",
        ]
        for i in range(len(ds)):
            for nullified_key in nullified_keys:
                ds[nullified_key][i] = None

    @staticmethod
    def _batch_list(lst, batch_size):
        return [
            lst[i:i + batch_size]
            for i in range(0, len(lst), batch_size)
        ]

    @staticmethod
    def rstrip_pad_token(t, pad_token_id):
        non_pad_indices = (t != pad_token_id).nonzero(as_tuple=True)[0]
        if len(non_pad_indices) == 0:
            return t
        last_non_pad_index = non_pad_indices[-1].item()
        return t[:last_non_pad_index + 1]

    def oaif_label_rewards(self, batch):
        """
        Sample candidate responses with the policy model, then use self.oaif_annotator
        to label the chosen and rejected completions.
        """
        # parameters from paper
        # https://www.semanticscholar.org/reader/04d64be16fb402f28348faffef484bd419c8bd8f
        #temperature = 0.7
        num_return_sequences = 4
        top_p = 0.9
        # deviates from paper
        temperature = 0.9
        batch_size = 4

        response_groups = []
        with eval_mode(self.model):
            for queries in self._batch_list(batch["prompt"], batch_size):
                masked_inputs = self.tokenizer(queries, padding=True, return_tensors="pt").to("cuda")
                generation = self.model.generate(
                    **masked_inputs,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    # TODO: implement use of self.oaif_generation_kwargs for below args
                    max_new_tokens=128,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                )
                # strip the prompt tokens, then regroup into num_return_sequences per query
                response_group_chunk = generation[:, masked_inputs.input_ids.shape[1]:]
                response_group_chunk = response_group_chunk.view(
                    -1,
                    num_return_sequences,
                    response_group_chunk.shape[1]
                )
                for responses in list(response_group_chunk):
                    response_groups.append(tuple([
                        self.rstrip_pad_token(resp, self.tokenizer.pad_token_id)
                        for resp in responses
                    ]))

        # generate annotations
        base_annotations = self.oaif_annotator(
            batch,
            response_groups,
            self.tokenizer,
        )
        annotated_ds = Dataset.from_dict(base_annotations).map(self.tokenize_row)
        collated_ds = self.data_collator(annotated_ds)
        # hack: DPOTrainer expects these keys in the batch
        collated_ds["attention_mask"] = collated_ds["prompt_attention_mask"]
        collated_ds["input_ids"] = collated_ds["prompt_input_ids"]
        return collated_ds

    def get_train_dataloader(self):
        dataloader = super().get_train_dataloader()
        mutate_fn = lambda batch: {**batch, **self.oaif_label_rewards(batch)}
        return DynamicDataLoader(dataloader, mutate_fn)

    def evaluation_loop(self, dataloader, *args, metric_key_prefix="eval", **kwargs):
        """
        Modified evaluation loop which scores greedy generations against the true
        transformation prompts via sharpened cosine similarity.
        Hacky: this should live in SharpenedCosineSimilarityAnnotator, not this class.
        """
        greedy_responses = []
        true_responses = []
        with eval_mode(self.model):
            for batch in dataloader:
                queries = batch["prompt"]
                true_responses += batch["resolved_prompt"]
                masked_inputs = self.tokenizer(queries, padding=True, return_tensors="pt").to("cuda")
                generation = self.model.generate(
                    **masked_inputs,
                    pad_token_id=self.tokenizer.pad_token_id,
                    max_new_tokens=128,
                    do_sample=False,
                )
                greedy_responses += list(generation[:, masked_inputs.input_ids.shape[1]:])

        observed_responses = [
            self.tokenizer.decode(resp, skip_special_tokens=True).strip()
            for resp in greedy_responses
        ]
        assert len(observed_responses) == len(true_responses)
        rewards = list(map(float,
            self.oaif_annotator.get_reward(observed_responses, true_responses)
        ))

        exact_prefix = f"{metric_key_prefix}_oaif_exact"
        q1, q2, q3 = statistics.quantiles(rewards, n=4)
        oaif_metrics = {
            f"{exact_prefix}_mean": statistics.mean(rewards),
            f"{exact_prefix}_std_dev": statistics.stdev(rewards),
            f"{exact_prefix}_min": min(rewards),
            f"{exact_prefix}_q1": q1,
            f"{exact_prefix}_q2": q2,
            f"{exact_prefix}_q3": q3,
            f"{exact_prefix}_max": max(rewards),
        }

        print()
        for i in range(4):
            print(f"rewards[{i}]:", rewards[i])
            print(f"\tobserved_prompt[{i}]:", observed_responses[i])
            print(f"\ttrue_prompt[{i}]:", true_responses[i])
        max_reward_idx = rewards.index(max(rewards))
        print("rewards[max]:", rewards[max_reward_idx])
        print("\tobserved_prompt[max]:", observed_responses[max_reward_idx])
        print("\ttrue_prompt[max]:", true_responses[max_reward_idx])
        print()

        return EvalLoopOutput(
            predictions=None,
            label_ids=None,
            metrics=oaif_metrics,
            num_samples=len(true_responses),
        )
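
# Summary (descriptive comment, not in the original gist): OAIFTrainer implements the
# online AI feedback recipe from the paper linked above. For every training batch it
# samples 4 candidate completions per prompt, has the annotator rank them, and uses the
# best/worst pair as that step's chosen/rejected pair for the DPO loss. The evaluation
# loop instead greedy-decodes one completion per prompt and reports sharpened cosine
# similarity statistics against the true transformation prompts.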


class SharpenedCosineSimilarityAnnotator:
    def __init__(self, reference_prompts, embed_model_id='sentence-transformers/sentence-t5-base'):
        self.embed_model = SentenceTransformer(embed_model_id)
        self.reference_prompts = reference_prompts
        self.embedding_cache = {}

    def get_embedding(self, s):
        if s not in self.embedding_cache:
            self.embedding_cache[s] = self.embed_model.encode([s])[0]
        return self.embedding_cache[s]

    def get_sharpened_cos_sim(self, observed_prompts, true_prompts):
        observed_embeddings = [self.get_embedding(p) for p in observed_prompts]
        true_embeddings = [self.get_embedding(p) for p in true_prompts]
        similarities = [
            (1 - cosine(o, r))**3
            for o, r in zip(observed_embeddings, true_embeddings)
        ]
        return torch.tensor(similarities)

    def get_closest_n_similarity(self, observed_prompt, n=None):
        """
        Return the sharpened similarity of the n-th most similar reference prompt.
        """
        if n is None:
            n = int(len(self.reference_prompts) / 4)
        observed_embedding = self.get_embedding(observed_prompt)
        reference_embeddings = [self.get_embedding(tp) for tp in self.reference_prompts]
        similarities = torch.tensor([
            (1 - cosine(observed_embedding, ref_emb))**3
            for ref_emb in reference_embeddings
        ])
        top_n_sim, _ = torch.topk(similarities, n)
        return torch.min(top_n_sim)

    def get_reward(self, observed_prompts, true_prompts):
        exact_rewards = self.get_sharpened_cos_sim(
            list(map(str.strip, observed_prompts)),
            true_prompts
        )
        # 3/4 of the reward if "\n" appears within the response
        return [
            r * 3 / 4 if "\n" in op.strip() else r
            for op, r in zip(observed_prompts, exact_rewards)
        ]

    def __call__(self, batch, response_groups, tokenizer):
        chosen = []
        rejected = []
        for resp_ids_group, resolved_prompt in zip(response_groups, batch["resolved_prompt"]):
            best = None
            best_reward = None
            worst = None
            worst_reward = None
            for resp_ids in resp_ids_group:
                response_string = tokenizer.decode(resp_ids, skip_special_tokens=True)
                reward = self.get_reward([response_string], [resolved_prompt])[0]
                if best_reward is None or reward > best_reward:
                    best_reward = reward
                    best = tokenizer.decode(resp_ids, skip_special_tokens=False)
                if worst_reward is None or reward < worst_reward:
                    worst_reward = reward
                    worst = tokenizer.decode(resp_ids, skip_special_tokens=False)
            chosen.append(best)
            rejected.append(worst)
        return {
            "prompt": batch["prompt"],
            "chosen": chosen,
            "rejected": rejected,
        }
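
# Illustrative reward sketch (hypothetical strings, not from the dataset): the reward is
# the cubed cosine similarity between sentence-t5 embeddings, scaled by 3/4 when the
# generation contains an internal newline.
#
#     annotator = SharpenedCosineSimilarityAnnotator(["Rewrite the text as a sea shanty."])
#     annotator.get_reward(["Rewrite the text as a sea shanty."],
#                          ["Rewrite the text as a sea shanty."])   # -> [~1.0]
#     annotator.get_reward(["Summarize the text.\nBe brief."],
#                          ["Rewrite the text as a sea shanty."])   # -> lower, then * 3/4 (newline penalty)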


def get_unsloth_model(base_model_name, max_seq_length=2048):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    peft_model = FastLanguageModel.get_peft_model(
        model,
        target_modules=[
            "q_proj", "v_proj", "k_proj", "o_proj",  # attention (self_attn)
            "gate_proj", "down_proj", "up_proj",     # FFN (mlp)
        ],
        r=8,
        lora_alpha=32,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=True,
        max_seq_length=max_seq_length,
    )
    tokenizer.padding_side = "left"
    return peft_model, tokenizer
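
# Note (not in the original gist): padding_side is set to "left" because the trainer
# batch-generates from this decoder-only model; with left padding the completions start
# at index input_ids.shape[1], so `generation[:, prompt_len:]` cleanly isolates them.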


def get_trainer(
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    output_dir,
    hub_repo_id,
    # parameters from paper
    # https://www.semanticscholar.org/reader/04d64be16fb402f28348faffef484bd419c8bd8f
    train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    beta=0.1,
    eval_batch_size=8,
    loss_type="dpo",
    warmup_steps=0,
    max_prompt_length=2048,
    max_length=2048 + 128,
    seed=42
):
    training_args = TrainingArguments(
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        warmup_steps=warmup_steps,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        report_to="tensorboard",
        logging_steps=1,
        evaluation_strategy="steps",
        eval_steps=12,
        save_strategy="steps",
        save_steps=12,
        push_to_hub=True,
        hub_private_repo=True,
        hub_model_id=hub_repo_id,
        hub_strategy="every_save",
        #load_best_model_at_end=True,
        gradient_checkpointing=True,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing_kwargs=dict(use_reentrant=True),
        output_dir=output_dir,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        seed=seed,
        # experiments suggest learning past 1.5 epochs is mostly useless
        num_train_epochs=3,
        # hack
        metric_for_best_model=None,
        # stabilize
        max_grad_norm=10.0,
        # run label
        run_name="oaif_standard",
    )
    return OAIFTrainer(
        model,
        ref_model=None,  # TODO: would be nice to eval against base model by stripping the adapters
        args=training_args,
        beta=beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=max_prompt_length,
        max_length=max_length,
        oaif_annotator=SharpenedCosineSimilarityAnnotator(train_dataset["resolved_prompt"]),
        eval_oaif_annotator=SharpenedCosineSimilarityAnnotator(eval_dataset["resolved_prompt"]),
    )


prompt_template = """<|im_start|>user
You are a reverse prompt generator. The user will provide two texts, "Original Text" and "Rewritten Text". You will determine the "Transformation Prompt" which was applied to "Original Text" with a language model to generate "Rewritten Text". Analyze the changes in style, theme, etc. to determine which "Transformation Prompt" was used with a language model to convert "Original Text" into "Rewritten Text"<|im_end|>
<|im_start|>user
Original Text:
'''
{}
'''
Rewritten Text:
'''
{}
'''
What is the Transformation Prompt which was applied to modify the input text into the transformed output text?<|im_end|>
<|im_start|>assistant
Transformation Prompt:
"""


def format_query(example):
    return prompt_template.format(
        example['original_text'].strip(),
        example['rewritten_text'].strip(),
    )
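
# Illustrative usage (hypothetical example row, not from dataset_v43):
#
#     format_query({
#         "original_text": "The meeting is at 3pm.",
#         "rewritten_text": "Arr, we gather when the clock strikes three bells!",
#     })
#     # -> the ChatML prompt above with both texts substituted into the {} slots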


def get_datasets():
    def coll(example):
        return {
            "resolved_prompt": example["prompt"],
            "prompt": format_query(example),
            "chosen": "",
            "rejected": "",
        }

    # Loading the datasets
    ds = load_from_disk("dataset_v43")
    train_dataset = ds["oaif_train"]
    eval_dataset = ds["eval"]

    # Applying transformations
    train_dataset = train_dataset.map(coll)
    train_dataset = train_dataset.remove_columns([
        col for col in train_dataset.column_names
        if col not in ['prompt', 'resolved_prompt', "chosen", "rejected"]
    ])
    train_dataset = train_dataset.shuffle(seed=42)

    eval_dataset = eval_dataset.map(coll)
    eval_dataset = eval_dataset.remove_columns([
        col for col in eval_dataset.column_names
        if col not in ['prompt', 'resolved_prompt', "chosen", "rejected"]
    ])
    return train_dataset, eval_dataset
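
# Expected on-disk layout (inferred from the code above; "dataset_v43" is a local
# DatasetDict not included in this gist): splits "oaif_train" and "eval", each row with
# at least "prompt" (the transformation prompt to recover), "original_text", and
# "rewritten_text". After get_datasets() the columns are reduced to
# prompt / resolved_prompt / chosen / rejected.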


if __name__ == "__main__":
    output_dir = "oaif_v43.1"
    base_model_name = "sft_v43_merged"
    hub_repo_id = f"lapp0/{output_dir}"

    model, tokenizer = get_unsloth_model(base_model_name=base_model_name)
    train_ds, eval_ds = get_datasets()
    trainer = get_trainer(
        model,
        tokenizer,
        train_ds,
        eval_ds,
        output_dir=output_dir,
        hub_repo_id=hub_repo_id,
    )
    trainer.evaluate()  # step 0
    trainer.train()