Skip to content

Instantly share code, notes, and snippets.

@vwxyzjn
Created February 15, 2024 15:20
Show Gist options
  • Save vwxyzjn/c4b2d0142deca9450fbc7836d99c8e9a to your computer and use it in GitHub Desktop.
Save vwxyzjn/c4b2d0142deca9450fbc7836d99c8e9a to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
from datasets import load_dataset
import llm_blender
from transformers import HfArgumentParser
import multiprocessing
import random
import itertools
import warnings
from collections import defaultdict
# Suppress all warnings globally — keeps the batched .map() output readable.
warnings.filterwarnings("ignore")
@dataclass
class Args:
    """Command-line arguments for the PairRM ranking script.

    The bare string after each field is picked up by HfArgumentParser and
    used as the --help text for the corresponding flag, so these are part
    of the user-facing interface (kept deliberately, per file convention).
    """

    path: str = "HuggingFaceH4/openhermes_2.5_dpo_v0"
    """Path to the dataset"""
    split: str = "train_prefs"
    """Dataset split to use"""
    output_path: str = "openhermes_2.5_dpo_pairrm_v0"
    """Save to disk path"""
    batch_size: int = 512
    """Batch size for dataset mapping function"""
    num_shards: int = 1
    """Number of shards to split the data"""
    shard_index: int = 0
    """Index of the shard to use"""
    max_samples: int = 1024
    """The maximum number of samples to generate (use -1 for all)"""
    debug: bool = False
    """Debug mode"""
# Build the CLI from the Args dataclass and parse flags into a single Args
# instance (parse_args_into_dataclasses returns one object per dataclass).
parser = HfArgumentParser([Args])
args = parser.parse_args_into_dataclasses()[0]
# Load the PairRM ranker once at module level so the dataset .map() below can
# reuse it across batches.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
def prepare_conversation(conversation):
    """Convert chat turns to the role schema PairRM expects.

    Each turn keeps its content unchanged; the role becomes "USER" for user
    turns and "ASSISTANT" for every other role.
    """
    prepared = []
    for turn in conversation:
        role = "USER" if turn["role"] == "user" else "ASSISTANT"
        prepared.append({"content": turn["content"], "role": role})
    return prepared
def pairRM(row, batch_size=80):
    """Rank the three candidate completions of each example with PairRM.

    Expects a batched `row` mapping with columns `candidate0..candidate2`
    (each a list of conversations, a conversation being a list of
    {"role", "content"} turns) and `candidate0_policy..candidate2_policy`
    (the names of the generating policies).

    Adds a `ranks` column: per example, a "policyA > policyB > policyC"
    string ordered best-to-worst according to `blender.rank`.

    NOTE(review): `batch_size` is currently unused by the active code path;
    kept for interface compatibility.

    Changes vs. previous revision: removed a leftover debug print and a
    large block of dead commented-out pairwise-comparison code.
    """

    def _user_side(conversations):
        # Join the USER turns (even indices) of each conversation, with a
        # placeholder marking where each response slots in.
        return [
            "\n".join(
                "USER: " + turns[i]['content'] + f"\nAssistant: <Response {i//2+1}>"
                for i in range(0, len(turns), 2)
            )
            for turns in (prepare_conversation(item) for item in conversations)
        ]

    def _assistant_side(conversations):
        # Join the ASSISTANT turns (odd indices) of each conversation.
        return [
            "\n".join(
                f"<Response {i//2+1}>: " + turns[i]['content']
                for i in range(1, len(turns), 2)
            )
            for turns in (prepare_conversation(item) for item in conversations)
        ]

    # The prompt (user) side is identical across candidates, so build it once
    # from candidate0.
    inputs = _user_side(row["candidate0"])
    candidate_texts = [_assistant_side(row[f"candidate{k}"]) for k in range(3)]

    results = blender.rank(
        inputs,
        list(zip(*candidate_texts)),
    )
    # Each `item` is treated as a 1-indexed ordering over the candidates;
    # map it back to the policy names, best first.
    ranks = [
        " > ".join(row[f'candidate{p-1}_policy'][i] for p in item)
        for i, item in enumerate(results)
    ]
    row["ranks"] = ranks
    return row
# Load the requested split; optionally truncate to the first `max_samples`
# examples (-1 keeps the full split).
ds = load_dataset(args.path, split=args.split)
if args.max_samples > 0:
    ds = ds.select(range(args.max_samples))
def modify(row):
    """Randomly permute the three candidates (and matching policy labels).

    Shuffling removes any positional bias before PairRM ranking: candidate
    slot 0/1/2 no longer correlates with the generating policy.
    """
    order = [0, 1, 2]
    random.shuffle(order)
    shuffled_cands = [row[f"candidate{k}"] for k in order]
    shuffled_pols = [row[f"candidate{k}_policy"] for k in order]
    for slot, (cand, pol) in enumerate(zip(shuffled_cands, shuffled_pols)):
        row[f"candidate{slot}"] = cand
        row[f"candidate{slot}_policy"] = pol
    return row
# Shuffle candidate positions per-example; parallelize across CPUs unless
# debugging (num_proc=1 keeps tracebacks readable).
ds = ds.map(modify, load_from_cache_file=False, num_proc=1 if args.debug else multiprocessing.cpu_count())
# ds = ds.remove_columns(
#     [
#         'system_prompt', 'model', 'avatarUrl', 'conversations', 'title',
#         'skip_prompt_formatting', 'idx', 'hash', 'views', 'custom_instruction',
#         'language', 'id', 'model_name', 'chosen_policy', 'chosen',
#         'token_length', 'rejected', 'rejected_policy',
#     ]
# )
df = ds.to_pandas()
# print(df["candidate0_policy"][:10])
# print(df["candidate0"][0])
# Process only this worker's shard, run the PairRM ranking in batches, and
# save the result under a shard-specific path so shards don't collide.
shard = ds.shard(num_shards=args.num_shards, index=args.shard_index)
pairrm_shard = shard.map(pairRM, batched=True, batch_size=args.batch_size, load_from_cache_file=False)
pairrm_shard.save_to_disk(f"{args.output_path}_{args.split}_{args.shard_index}")
# visualization
df = pairrm_shard.to_pandas()
# print(df["candidate0_policy"][:10])
# print(df["candidate0"][0])
# Summarize how often each policy ordering occurred across the shard.
print(args.path)
print(df["ranks"].value_counts())
# print(df["chosen_policy"][:10])
# row = ds[0]
# for comb in list(itertools.combinations(range(3), 2)):
#     policy0, policy1 = row[f"candidate{comb[0]}_policy"], row[f"candidate{comb[1]}_policy"]
#     print(df[f"{policy0} 🆚 {policy1}"].value_counts())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment