@vwxyzjn
Created February 29, 2024 23:03
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from rich.console import Console
from rich.table import Table
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    GenerationConfig,
)

######
# Utility functions
######
def generate(lm_backbone, queries, tokenizer, generation_config):
    """generate in a way that does not affect padding tokens"""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(),  # generation collapsed if this was turned on. TODO: why does generation collapse with this?
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1)
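
# Note: `generate` presumably expects left-padded queries (padding tokens before the
# prompt), which is the convention suggested by the `ppo_left_padding` checkpoint used
# below. Padding positions are zeroed in input_ids and excluded via attention_mask, and
# the returned tensor is the original padded query concatenated with only the newly
# generated tokens.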

def forward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )
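
# Quick sanity check (illustration only; `_example_mask` is a made-up name) of the
# position_ids trick used in `forward`: with left padding, real tokens are numbered
# from 0 regardless of how much padding precedes them, and pad positions stay at 0.
_example_mask = torch.tensor([[0, 0, 1, 1, 1]])
assert (_example_mask.cumsum(1) - _example_mask).tolist() == [[0, 0, 0, 1, 2]]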

def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table:
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        table.add_row(*row.astype(str).tolist())
    console.rule(f"[bold red]{title}")
    console.print(table)
    return table
######
# Start
######
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
response_length = 128
validation_generation_config = GenerationConfig(
    max_new_tokens=response_length,
    temperature=(0.01 + 1e-7),
    top_k=0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
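# With temperature ~0.01, top_p=1.0, and top_k=0 (which disables top-k filtering in
# Hugging Face generation), sampling is effectively greedy even though do_sample=True.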
sft_dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144")
base_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b-deduped").to(device)
# https://wandb.ai/costa-huang/tldr_summarize/runs/a0rutstb
# https://huggingface.co/vwxyzjn/EleutherAI_pythia-1b-deduped__sft__tldr/tree/sft__55513__1706646024
sft_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    "vwxyzjn/EleutherAI_pythia-1b-deduped__sft__tldr",
    revision="sft__55513__1706646024",
    trust_remote_code=True,
).to(device)
# https://wandb.ai/costa-huang/tldr_summarize/runs/ulekmmac
# https://huggingface.co/vwxyzjn/EleutherAI_pythia-1b-deduped__ppo_left_padding__tldr/tree/ppo_left_padding__55513__1706746254
ppo_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    "vwxyzjn/EleutherAI_pythia-1b-deduped__ppo_left_padding__tldr",
    revision="ppo_left_padding__55513__1706746254",
    trust_remote_code=True,
).to(device)
compared_models = {
    "base_model": base_model,
    "aligned_model": sft_model,
}
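# "aligned_model" currently points at the SFT checkpoint; swap in `ppo_model` here to
# inspect the PPO-trained policy with the same loop below.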
console = Console()
for i in range(len(sft_dataset["validation"])):
    table = defaultdict(list)
    query = torch.Tensor(sft_dataset["validation"][i]["query_token"]).to(device).long()
    query_reference_response = torch.Tensor(sft_dataset["validation"][i]["query_reference_response_token"]).to(device).long()
    with torch.no_grad():
        aligned_model = compared_models["aligned_model"]
        aligned_query_response = generate(aligned_model, query.unsqueeze(0), tokenizer, validation_generation_config)
        aligned_response = aligned_query_response[:, query.shape[0]:]
        ppo_query_response = generate(ppo_model, query.unsqueeze(0), tokenizer, validation_generation_config)
        ppo_response = ppo_query_response[:, query.shape[0]:]

    # print results
    table["type"].append("Query")
    table["content"].append(tokenizer.decode(query, skip_special_tokens=True))
    table["type"].append("SFT response")
    table["content"].append(tokenizer.decode(aligned_response[0]))
    table["type"].append("PPO response")
    table["content"].append(tokenizer.decode(ppo_response[0]))
    df = pd.DataFrame(table)
    print_rich_table("Results", df, console)
    if input("Continue? (press `n` to stop) ") == "n":
        break
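
# A minimal sketch (commented out; assumes the loop above ran at least once) of how the
# otherwise-unused `forward` helper could score a generated sequence, e.g. to compare
# per-token log-probabilities between models:
#
#     with torch.no_grad():
#         out = forward(sft_model, aligned_query_response, tokenizer)
#         logits = out.logits[:, query.shape[0] - 1 : -1]  # logits that predict the response tokens
#         logprobs = torch.log_softmax(logits, dim=-1)
#         response = aligned_query_response[:, query.shape[0]:]
#         token_logprobs = torch.gather(logprobs, 2, response.unsqueeze(-1)).squeeze(-1)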