Created
February 15, 2024 15:20
-
-
Save vwxyzjn/c4b2d0142deca9450fbc7836d99c8e9a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rank three candidate responses per prompt with llm-blender's PairRM reward model
# and record the resulting policy ranking per row.
from dataclasses import dataclass
from datasets import load_dataset
import llm_blender
from transformers import HfArgumentParser
import multiprocessing
import random
import itertools
import warnings
from collections import defaultdict

# Suppress library warnings (transformers/tokenizer chatter) for cleaner batch logs.
warnings.filterwarnings("ignore")
@dataclass
class Args:
    """CLI arguments for the PairRM ranking script.

    Parsed by HfArgumentParser, which exposes each field as a ``--flag`` and
    uses the bare-string docstring under each field as its ``--help`` text.
    """

    path: str = "HuggingFaceH4/openhermes_2.5_dpo_v0"
    """Path to the dataset"""
    split: str = "train_prefs"
    """Dataset split to use"""
    output_path: str = "openhermes_2.5_dpo_pairrm_v0"
    """Save to disk path"""
    batch_size: int = 512
    """Batch size for dataset mapping function"""
    num_shards: int = 1
    """Number of shards to split the data"""
    shard_index: int = 0
    """Index of the shard to use"""
    max_samples: int = 1024
    """The maximum number of samples to generate (use -1 for all)"""
    debug: bool = False
    """Debug mode"""
# Parse command-line flags into the Args dataclass.
parser = HfArgumentParser([Args])
args = parser.parse_args_into_dataclasses()[0]
# Load the PairRM ranker once at module level; pairRM() below reads this global.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
def prepare_conversation(conversation):
    """Relabel conversation turns for PairRM.

    Each turn keeps its content; the role becomes "USER" when the original
    role is "user", and "ASSISTANT" for every other role.
    """
    converted = []
    for turn in conversation:
        label = "USER" if turn["role"] == "user" else "ASSISTANT"
        converted.append({"content": turn["content"], "role": label})
    return converted
def pairRM(row, batch_size=80):
    """Rank the three candidate conversations of each example with PairRM.

    Designed for ``Dataset.map(..., batched=True)``: each value in ``row`` is a
    *column* (list over the batch). Uses the module-level ``blender`` ranker.
    Adds a ``"ranks"`` column like ``"policyA > policyB > policyC"`` per example.

    Note: ``batch_size`` is currently unused; it is kept for interface
    compatibility with earlier ``blender.compare_conversations`` code.
    """

    def _response_texts(conversations):
        # Assistant turns sit at the odd indices; label each with its 1-based
        # response number so it matches the "<Response k>" slots in `inputs`.
        return [
            "\n".join(
                f"<Response {i // 2 + 1}>: " + conv[i]["content"]
                for i in range(1, len(conv), 2)
            )
            for conv in conversations
        ]

    # Prepare candidate0 once (the original mapped prepare_conversation over
    # row["candidate0"] twice — for `inputs` and for `cand0_texts`).
    cand0 = [prepare_conversation(item) for item in row["candidate0"]]

    # The shared "input" context is built from the user turns (even indices);
    # all candidates answer the same user prompts, so candidate0's suffice.
    inputs = [
        "\n".join(
            "USER: " + conv[i]["content"] + f"\nAssistant: <Response {i // 2 + 1}>"
            for i in range(0, len(conv), 2)
        )
        for conv in cand0
    ]
    cand0_texts = _response_texts(cand0)
    cand1_texts = _response_texts([prepare_conversation(item) for item in row["candidate1"]])
    cand2_texts = _response_texts([prepare_conversation(item) for item in row["candidate2"]])

    results = blender.rank(
        inputs,
        list(zip(cand0_texts, cand1_texts, cand2_texts)),
    )
    print(results)  # debug output, kept from the original
    # results[i] is a 1-based ranking over the three candidates; translate each
    # rank position back to the generating policy's name.
    ranks = [
        " > ".join(row[f"candidate{p - 1}_policy"][i] for p in item)
        for i, item in enumerate(results)
    ]
    row["ranks"] = ranks
    return row
# Load the preference dataset; cap the row count unless max_samples <= 0 ("all").
ds = load_dataset(args.path, split=args.split)
if args.max_samples > 0:
    ds = ds.select(range(args.max_samples))
def modify(row):
    """Randomly permute the three candidates (and their policy labels in
    lockstep) within a row, to remove any positional bias before ranking."""
    order = [0, 1, 2]
    random.shuffle(order)
    # Snapshot (response, policy) pairs in the shuffled order, then write them
    # back into slots 0..2 so each policy stays attached to its response.
    picks = [(row[f"candidate{j}"], row[f"candidate{j}_policy"]) for j in order]
    for slot, (response, policy) in enumerate(picks):
        row[f"candidate{slot}"] = response
        row[f"candidate{slot}_policy"] = policy
    return row
# Shuffle candidate order per row to avoid positional bias; single process in debug mode.
ds = ds.map(modify, load_from_cache_file=False, num_proc=1 if args.debug else multiprocessing.cpu_count())
# ds = ds.remove_columns(
# [
# 'system_prompt', 'model', 'avatarUrl', 'conversations', 'title',
# 'skip_prompt_formatting', 'idx', 'hash', 'views', 'custom_instruction',
# 'language', 'id', 'model_name', 'chosen_policy', 'chosen',
# 'token_length', 'rejected', 'rejected_policy',
# ]
# )
df = ds.to_pandas()
# print(df["candidate0_policy"][:10])
# print(df["candidate0"][0])
# Process only this worker's shard, rank batches with PairRM, and persist the result.
shard = ds.shard(num_shards=args.num_shards, index=args.shard_index)
pairrm_shard = shard.map(pairRM, batched=True, batch_size=args.batch_size, load_from_cache_file=False)
pairrm_shard.save_to_disk(f"{args.output_path}_{args.split}_{args.shard_index}")
# visualization
df = pairrm_shard.to_pandas()
# print(df["candidate0_policy"][:10])
# print(df["candidate0"][0])
print(args.path)
print(df["ranks"].value_counts())
# print(df["chosen_policy"][:10])
# row = ds[0]
# for comb in list(itertools.combinations(range(3), 2)):
#     policy0, policy1 = row[f"candidate{comb[0]}_policy"], row[f"candidate{comb[1]}_policy"]
#     print(df[f"{policy0} 🆚 {policy1}"].value_counts())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment