Skip to content

Instantly share code, notes, and snippets.

@vwxyzjn
Created February 26, 2024 18:05
Show Gist options
Save vwxyzjn/e083cdde4c401bb5d4029aa096dd0213 to your computer and use it in GitHub Desktop.
from datasets import load_dataset
import numpy as np
import pandas as pd
from rich.console import Console
from rich.table import Table
# Shared rich console used for every table rendered by this script.
console = Console()

# Load the full training split, then narrow it to a random subset.
ds = load_dataset(
    "HuggingFaceH4/OpenHermesPreferences",
    split="train",
)
# Draw 1000 distinct row positions; no seed is set here, so the sampled
# subset differs from run to run.
idxs = np.random.choice(len(ds), size=1000, replace=False)
ds = ds.select(idxs)
def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table:
    """Print *df* as a rich table beneath a bold red rule headed by *title*.

    Every cell is rendered as ``str`` and rows are separated by lines.

    Args:
        title: Text shown in the rule printed above the table.
        df: Frame to render; its column labels become the table headers.
        console: Console that receives both the rule and the table.

    Returns:
        The ``Table`` that was printed (the annotation always promised this,
        but the previous version implicitly returned ``None``).
    """
    table = Table(show_lines=True)
    for column in df.columns:
        # Column labels are not guaranteed to be strings; headers should be.
        table.add_column(str(column))
    # Convert column-by-column: iterrows() builds one Series per row and
    # upcasts all-numeric rows (an int column would print as "1.0").
    for row in df.astype(str).itertuples(index=False):
        table.add_row(*row)
    console.rule(f"[bold red]{title}")
    console.print(table)
    return table
def _render_conversation(messages):
    """Join a message list into alternating '😁 User:' / '🤖 Assistant:' lines.

    Messages are consumed two at a time: the message at an even position is
    labelled as the user, the following one as the assistant. Only each
    message's ``"content"`` is read. NOTE(review): any role field is ignored,
    so this assumes the list strictly alternates user/assistant starting with
    the user; a leading system message would be mislabelled -- confirm
    against the data.

    A final message without a reply is still emitted as a user line instead
    of being dropped. An empty list yields ``""``.
    """
    turns = []
    for start in range(0, len(messages), 2):
        turn = "😁 User: " + messages[start]["content"]
        if start + 1 < len(messages):
            turn += "\n🤖 Assistant: " + messages[start + 1]["content"]
        turns.append(turn)
    return "\n".join(turns)


def modify(x):
    """Add flattened ``chosen_text`` and ``rejected_text`` fields to *x*.

    *x* is one example with ``"chosen"`` and ``"rejected"`` lists of message
    dicts, each carrying a ``"content"`` string. The example is updated in
    place and returned, as the ``ds.map`` callback in this script expects.

    Turns are paired with a stride of 2; the previous stride-1 loop produced
    overlapping pairs for multi-turn conversations, labelling assistant
    replies as "User" and user turns as "Assistant".
    """
    x["chosen_text"] = _render_conversation(x["chosen"])
    x["rejected_text"] = _render_conversation(x["rejected"])
    return x
# Flatten each conversation into display strings; load_from_cache_file=False
# makes `datasets` recompute instead of reusing previously cached results.
ds = ds.map(modify, load_from_cache_file=False)

# Character cap applied to both flattened texts; longer examples are skipped.
MAX_TEXT_CHARS = 1000
ds = ds.filter(
    lambda x: len(x["chosen_text"]) < MAX_TEXT_CHARS
    and len(x["rejected_text"]) < MAX_TEXT_CHARS,
    load_from_cache_file=False,
)

# Keep only comparisons where both sides come from one of these policies.
interested = ["teknium/OpenHermes-2.5", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
ds = ds.filter(
    lambda x: x["chosen_policy"] in interested and x["rejected_policy"] in interested,
    load_from_cache_file=False,
)

df = ds.to_pandas()
df = df[["chosen_policy", "rejected_policy", "chosen_text", "rejected_text"]]

if df.empty:
    console.print("No rows matched the filters.")

# Page through the rows one at a time: Enter advances; EOF or Ctrl-C stops.
for i in range(len(df)):
    # iloc[i:i+1] keeps a one-row DataFrame (not a Series) for the printer.
    print_rich_table(f"Row {i}", df.iloc[i:i+1], console)
    try:
        input("Press Enter to continue...")
    except (EOFError, KeyboardInterrupt):
        # stdin closed (non-interactive run) or user aborted: stop paging
        # instead of exiting with a traceback.
        break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment