Skip to content

Instantly share code, notes, and snippets.

@lewtun
Last active February 28, 2024 11:28
Show Gist options
  • Save lewtun/1160cef2ae45be20a95bf4a1dc5de40f to your computer and use it in GitHub Desktop.
Save lewtun/1160cef2ae45be20a95bf4a1dc5de40f to your computer and use it in GitHub Desktop.
View LightEval predictions
"""
First install: pip install datasets pandas rich transformers
Usage:
# Loglikelihood evals
python view_details.py --filepath path/to/parquet/details
# Generative evals
python view_details.py --filepath path/to/parquet/details --is_generative
"""
from dataclasses import dataclass, field

import pandas as pd
from datasets import load_dataset
from rich.console import Console
from rich.table import Table
from rich.text import Text
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
    """Command-line options for the details viewer.

    The class is handed to ``HfArgumentParser`` in ``main``, so each field
    becomes a ``--<field_name>`` command-line option and the ``help`` metadata
    is what ``--help`` prints for it.
    """

    # Passed to ``load_dataset("parquet", data_files=[...])`` in ``main``.
    filepath: str = field(
        metadata={"help": "Path to the parquet details file to display."},
    )
    # Selects which set of columns ``main`` shows for each row.
    is_generative: bool = field(
        default=False,
        metadata={
            "help": (
                "Show the generative-eval columns (full_prompt, predictions) "
                "instead of the loglikelihood ones (full_prompt, choices, "
                "pred_logits, gold_index)."
            )
        },
    )
def _print_rich_table(title: str, df: pd.DataFrame, console: Console) -> None:
    """Render ``df`` as a rich table beneath a titled horizontal rule.

    Args:
        title: Text shown in the rule printed above the table.
        df: Frame to display; each column becomes a table column and every
            cell is converted to ``str`` before printing.
        console: Rich console that receives the output.
    """
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        # Wrap each cell in Text so bracketed prompt content such as "[INST]"
        # or "[/INST]" is printed literally; a plain str would be parsed as
        # rich console markup and either lose the tag or raise MarkupError.
        table.add_row(*(Text(cell) for cell in row.astype(str).tolist()))
    console.rule(f"[bold red]{title}")
    console.print(table)


def main() -> None:
    """Parse the CLI arguments, load the parquet file, and page through its rows.

    One row is printed per table; the user presses Enter to advance to the
    next row. Ctrl-C or Ctrl-D at the prompt stops paging without a traceback.

    Raises:
        KeyError: If the file lacks a column required by the selected mode.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    console = Console()
    ds = load_dataset("parquet", data_files=[args.filepath], split="train")

    df = ds.to_pandas()
    if args.is_generative:
        columns = ["full_prompt", "predictions"]
    else:
        columns = ["full_prompt", "choices", "pred_logits", "gold_index"]

    # Fail with the list of available columns rather than pandas' bare
    # "not in index" message when the file does not match the selected mode.
    missing = [name for name in columns if name not in df.columns]
    if missing:
        raise KeyError(
            f"Columns {missing} not found in {args.filepath}; "
            f"available columns: {list(df.columns)}"
        )
    df = df[columns]

    for i in range(len(df)):
        # Slice (rather than index) so the helper receives a one-row DataFrame.
        _print_rich_table(f"Row {i}", df.iloc[i : i + 1], console)
        try:
            input("Press Enter to continue...")
        except (EOFError, KeyboardInterrupt):
            # End the session quietly when the user quits at the prompt.
            console.print()
            break
# Run the viewer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment