
@napsternxg
Created August 15, 2023 00:41
T5 CausalLM Constrained Generation Using Tries
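The two files below implement constrained decoding for a T5 classifier: every allowed class name is stored as a token-id sequence in a character-encoded marisa trie, and `generate` is restricted via `prefix_allowed_tokens_fn` so each decoding step may only emit tokens that continue one of those sequences. The model can therefore only ever produce a valid class name. The first file is the batched, multi-GPU prediction script (via Accelerate); the second, `t5_training_utils.py`, holds the trie and the tokenization helpers.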
import functools

import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from t5_training_utils import (
    GenerationType,
    build_prefix_allowed_tokens_fn,
    convert_to_features,
    get_gen_type_attributes,
    get_model_full_name,
    get_prediction_name,
)

torch_dtype = "auto"
model_ckpt = "t5-base"
gen_type = GenerationType.ALL_TOKENS
input_max_length = 512
label_max_length = 6
use_task_prefix = True

class_names = [
    "Soccer",
    "Cricket",
    "Handball",
    "Snow Cycling",
]
non_eligible_classes = {
    "Snow Cycling",
}
non_eligible_idx = [
    i for i, c in enumerate(class_names) if c in non_eligible_classes
]
num_classes = len(class_names)

# Model training
### Uncomment a config section for the model type
## For small test run
train_batch_size = 8
eval_batch_size = 8
epochs = 30
save_every_k_epochs = 5
seed = 3333
torch.manual_seed(seed)
logging_steps = 100  # len(squad["train"]) // batch_size
eval_step = 100
learning_rate = 2e-5
weight_decay = 0.01

data_version = "guidelines-fixed-occasion"
model_full_name = get_model_full_name(model_ckpt, gen_type, epochs, data_version)
def get_model(model_local_ckpt):
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_local_ckpt)
    return model.eval()


def get_dataset(
    data_path, tokenizer, class_text_map, task_prefix, accelerator, nrows=1024, offset=0
):
    df_data = pd.read_csv(data_path, sep="\t", nrows=nrows + offset).rename(
        columns={"query": "text"}
    )
    df_data = df_data.iloc[offset : offset + nrows]
    print(df_data)
    dataset = Dataset.from_pandas(df_data)
    dataset.reset_format()
    # Tokenize on the main process first so the other processes reuse the cache.
    with accelerator.main_process_first():
        dataset = dataset.map(
            functools.partial(
                convert_to_features,
                class_text_map=class_text_map,
                task_prefix=task_prefix,
                query_key="text",
                label_key=None,  # inference only: no labels to encode
                tokenizer=tokenizer,
            )
        )
    return dataset, df_data
def get_predictions_accelerate(data_path, model_local_ckpt, nrows=1024, offset=0):
    accelerator = Accelerator()
    device = accelerator.device
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
    class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
        gen_type, tokenizer, class_names
    )
    task_prefix = task_prefix if use_task_prefix else ""
    dataset, df_data = get_dataset(
        data_path,
        tokenizer,
        class_text_map,
        task_prefix,
        accelerator,
        nrows=nrows,
        offset=offset,
    )
    model = get_model(model_local_ckpt)
    model = model.to(device)
    # Every allowed output sequence starts with T5's decoder start token (0).
    allowed_sequences = [[0] + tokenizer.encode(x) for x in class_text_map.values()]
    # Build the trie-backed constraint function once, rather than once per batch.
    prefix_allowed_tokens_fn = build_prefix_allowed_tokens_fn(allowed_sequences)
    dataset.set_format("pt")
    # Keep the dataloader unshuffled: predictions are assigned back to df_data
    # by position below, so the row order must be preserved.
    custom_dataloader = DataLoader(
        dataset, shuffle=False, batch_size=eval_batch_size, num_workers=4
    )
    model, custom_dataloader = accelerator.prepare(model, custom_dataloader)
    preds = []
    with torch.no_grad():
        for batch in tqdm(
            custom_dataloader, disable=not accelerator.is_local_main_process
        ):
            batch_input_ids = batch["input_ids"].to(device)
            batch_attention_mask = batch["attention_mask"].to(device)
            # For DDP models use accelerator.unwrap_model(model).generate(inputs)
            # Taken from: https://github.com/huggingface/transformers/issues/18974
            batch_outs = accelerator.unwrap_model(model).generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                max_length=max_decoding_length,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            )
            batch_outs = accelerator.pad_across_processes(
                batch_outs, dim=1, pad_index=tokenizer.pad_token_id
            )
            batch_outs = accelerator.gather_for_metrics(batch_outs).cpu().numpy()
            preds.extend(tokenizer.batch_decode(batch_outs, skip_special_tokens=True))
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if len(preds) != len(dataset):
            raise ValueError(
                f"Predictions and labels have different lengths. preds: {len(preds)} "
                f"labels: {len(dataset)}"
            )
        pred_col = get_prediction_name(model_full_name)
        df_data[pred_col] = preds
        # Map the generated class text back to the raw class name.
        class_text_map_reversed = {val: key for key, val in class_text_map.items()}
        df_data[pred_col] = df_data[pred_col].apply(lambda x: class_text_map_reversed[x])
        # eligible = ~df_test[pred_col].isin(non_eligible_classes)
        eligible = ~df_data[pred_col].isin(
            {v for v in non_eligible_classes if v != "Cricket"}
        )
        df_data["eligible_pred"] = eligible
        output_path = data_path.replace(".tsv", f".predicted.{offset}.{nrows}.tsv")
        print(df_data)
        print(f"Writing df_data with predictions to {output_path}")
        df_data.to_csv(output_path, sep="\t", index=False)
    return df_data
def main():
    data_path = "data.tsv"
    offset = 400_000
    nrows = 153  # 600_000
    model_local_ckpt = "./model_path/checkpoint-2830"
    print(data_path)
    print(nrows)
    print(model_local_ckpt)
    get_predictions_accelerate(data_path, model_local_ckpt, nrows=nrows, offset=offset)


if __name__ == "__main__":
    main()
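Because the script constructs an `Accelerator` and gathers predictions across processes, it runs unchanged on one or many GPUs; assuming the file is saved as, say, `predict.py` (the gist does not give a file name), a multi-GPU run would typically be launched with `accelerate launch predict.py`.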
# t5_training_utils.py (the helper module imported above)
import enum
import os
import string
from typing import List, Mapping

import marisa_trie
import torch
class GenerationType(enum.Enum):
    ALL_TOKENS = "all"     # decode the full class name
    THREE_TOKENS = "gen3"  # decode the class name truncated to 3 tokens
    TWO_TOKENS = "gen2"    # ... truncated to 2 tokens
    ONE_TOKEN = "gen1"     # ... truncated to 1 token
    OPTION_ID = "abcd"     # decode a single option letter (A, B, C, ...)
def get_model_full_name(
    model_ckpt, gen_type: GenerationType, epochs: int, data_version: str = "0"
):
    model_base_name = model_ckpt.split("/")[-1]
    gen_type = gen_type.value
    model_full_name = f"{model_base_name}_{gen_type}_ep{epochs}_dt{data_version}"
    return model_full_name


def get_outdir(model_name: str):
    return os.path.join("./classifers/", model_name)


def get_prediction_name(model_name: str):
    return f"{model_name}_predicted"
class MarisaTrie(object):
    def __init__(
        self,
        sequences: List[List[int]] = [],
        cache_first_branch=True,
        max_token_id=256001,
    ):
        # Encode each token id as one unicode character so that token-id
        # sequences can be stored as strings in a marisa_trie.Trie. The
        # surrogate range is skipped, hence the jump from 55000 to 65000.
        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        self.char2int = {self.int2char[i]: i for i in range(max_token_id)}
        self.cache_first_branch = cache_first_branch
        if self.cache_first_branch:
            # All sequences share the same first token (the decoder start
            # token), so the first two trie levels are cached as plain lists.
            self.zero_iter = list({sequence[0] for sequence in sequences})
            assert len(self.zero_iter) == 1
            self.first_iter = list({sequence[1] for sequence in sequences})
        self.trie = marisa_trie.Trie(
            "".join([self.int2char[i] for i in sequence]) for sequence in sequences
        )
    def get(self, prefix_sequence: List[int]):
        if self.cache_first_branch and len(prefix_sequence) == 0:
            return self.zero_iter
        elif (
            self.cache_first_branch
            and len(prefix_sequence) == 1
            and self.zero_iter == prefix_sequence
        ):
            return self.first_iter
        else:
            # The next allowed tokens are the characters that can extend `key`
            # along any stored sequence.
            key = "".join([self.int2char[i] for i in prefix_sequence])
            return list(
                {
                    self.char2int[e[len(key)]]
                    for e in self.trie.keys(key)
                    if len(e) > len(key)
                }
            )
    def __iter__(self):
        for sequence in self.trie.iterkeys():
            yield [self.char2int[e] for e in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)
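For intuition, a quick sketch of how the trie answers prefix queries. The token ids below are made up for illustration, and the returned lists come from sets, so their order may vary:

trie = MarisaTrie([[0, 5, 9], [0, 5, 12], [0, 7, 9]], max_token_id=100)
trie.get([])      # [0]: the shared start token
trie.get([0])     # [5, 7]: tokens that may follow the start token
trie.get([0, 5])  # [9, 12]: continuations of the prefix [0, 5]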
def map_class_name(tokenizer, class_raw_name, num_tokens=None, delim=" "):
    if num_tokens is None:
        return class_raw_name
    # Find the shortest word prefix of the class name that tokenizes to
    # exactly `num_tokens` tokens.
    class_raw_words = class_raw_name.split(delim)
    for i in range(1, len(class_raw_words) + 1):
        class_name_candidate = " ".join(class_raw_words[:i])
        tokens = tokenizer.tokenize(class_name_candidate)
        if len(tokens) == num_tokens:
            return class_name_candidate
    raise ValueError(
        f"Cannot find class name at the specified num_tokens: {class_raw_name}, {num_tokens}"
    )
def create_class_text_map(tokenizer, class_raw_names, num_tokens, delim="_"):
    res = {}
    for raw_name in class_raw_names:
        res[raw_name] = map_class_name(tokenizer, raw_name, num_tokens, delim=delim)
    return res
def get_task_prefix(gen_type, class_text_map):
    task_prefix = "Classify query intent into one of the following categories: "
    if gen_type in (
        GenerationType.ALL_TOKENS,
        GenerationType.THREE_TOKENS,
        GenerationType.TWO_TOKENS,
        GenerationType.ONE_TOKEN,
    ):
        # Free-text modes: list the (possibly truncated) class texts inline.
        classes = [f"'{x}'" for x in class_text_map.values()]
        task_prefix += ", ".join(classes)
        task_prefix += ". query: "
    elif gen_type == GenerationType.OPTION_ID:
        # Option-id mode: list one "<letter>: <class name>" line per class.
        classes = [
            f"{val}: {' '.join(key.split('_')[:-1])}"
            for key, val in class_text_map.items()
        ]
        task_prefix += "\n" + "\n".join(classes)
        task_prefix += "\nquery: "
    return task_prefix
def get_gen_type_attributes(gen_type, tokenizer, class_names):
    if gen_type == GenerationType.THREE_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 3)
        max_decoding_length = 3
    elif gen_type == GenerationType.TWO_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 2)
        max_decoding_length = 2
    elif gen_type == GenerationType.ONE_TOKEN:
        class_text_map = create_class_text_map(tokenizer, class_names, 1)
        max_decoding_length = 1
    elif gen_type == GenerationType.ALL_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, None)
        # Longest class sequence, counting the prepended decoder start token.
        max_decoding_length = max(
            len([0] + tokenizer.encode(x)) for x in class_text_map.values()
        )
    elif gen_type == GenerationType.OPTION_ID:
        class_text_map = {}
        for i, raw_name in enumerate(class_names):
            class_text_map[raw_name] = string.ascii_uppercase[i]
        max_decoding_length = 2
    else:
        raise ValueError(f"Unsupported `gen_type`: {gen_type}")
    task_prefix = get_task_prefix(gen_type, class_text_map)
    return class_text_map, max_decoding_length, task_prefix
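To make the returned attributes concrete, this is what the ALL_TOKENS setting produces for the class list in the prediction script; the snippet is illustrative (the exact max_decoding_length depends on the tokenizer):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-base")
class_text_map, max_len, prefix = get_gen_type_attributes(
    GenerationType.ALL_TOKENS, tokenizer, ["Soccer", "Cricket", "Handball", "Snow Cycling"]
)
# class_text_map is the identity map for ALL_TOKENS: {"Soccer": "Soccer", ...}
# prefix == "Classify query intent into one of the following categories: "
#           "'Soccer', 'Cricket', 'Handball', 'Snow Cycling'. query: "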
def convert_to_features(
    example_batch,
    class_text_map: Mapping[str, str],
    task_prefix: str,
    input_max_length=512,
    label_max_length=16,
    query_key="query",
    label_key="expected_single",
    class_names=None,
    tokenizer=None,
):
    q = example_batch[query_key]
    example_batch["input_text"] = f"{task_prefix}{q}"
    input_encodings = tokenizer(
        example_batch["input_text"],
        padding="max_length",
        max_length=input_max_length,
        truncation=True,
    )
    encodings = {
        "inputs": example_batch["input_text"],
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
    }
    # Labels are only built when a label column exists (inference passes
    # label_key=None).
    if label_key:
        label = class_text_map[class_names[example_batch[label_key]]]
        example_batch["target_text"] = f"{label}"
        target_encodings = tokenizer(
            example_batch["target_text"],
            padding="max_length",
            max_length=label_max_length,
            truncation=True,
        )
        encodings["labels"] = target_encodings["input_ids"]
    return encodings
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels
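This helper is meant to be passed to a Hugging Face Trainer so that only the argmax token ids, not the full logits, are accumulated during evaluation. A minimal hookup sketch, assuming `import transformers`; the model, datasets, and output directory are placeholders, not defined in the gist:

trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=transformers.Seq2SeqTrainingArguments(output_dir="./out"),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)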
def build_prefix_allowed_tokens_fn(allowed_sequences):
    """Returns a function that provides next allowed tokens based on the prefix `seq`."""
    t = MarisaTrie(allowed_sequences)

    def fn(unused_batch_id, seq):
        # `generate` passes the decoded-so-far ids as a tensor; convert it to
        # a plain list before querying the trie.
        return t.get(seq.tolist())

    return fn
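End to end, constrained generation then mirrors the prediction script above; a minimal standalone sketch, assuming t5-base and a made-up query:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-base")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").eval()

# 0 is T5's decoder start token; tokenizer.encode appends the EOS token.
allowed = [[0] + tokenizer.encode(x) for x in ["Soccer", "Cricket", "Handball"]]
inputs = tokenizer("query: who won the league final", return_tensors="pt")
outs = model.generate(
    **inputs,
    max_length=max(len(s) for s in allowed),
    prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(allowed),
)
print(tokenizer.batch_decode(outs, skip_special_tokens=True))  # one of the three names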
def process_golden_labels(example_batch, class_text_map, class_names):
    def fn(expected_text):
        return [class_text_map[y.strip()] for y in expected_text.split(",")]

    # example_batch['golden_labels'] = fn(example_batch['expected'])
    example_batch["golden_labels"] = fn(class_names[example_batch["Label"]])
    return example_batch
def compute_metrics(preds):
    return {}