@koustuvsinha
Created December 27, 2022 01:49
Zero Shot relation extractor for CLUTRR
# Zero-shot inference to verify CLUTRR AMT templates,
# using Flan-T5 locally or an OpenAI completion model over the API
import argparse
import os
import pickle as pkl
import time
from pathlib import Path

import openai
import pandas as pd
import torch
from sacremoses import MosesDetokenizer, MosesTokenizer
from tqdm.auto import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

openai.api_key = os.getenv("OPENAI_API_KEY")

moses_detokenizer = MosesDetokenizer()
moses_tokenizer = MosesTokenizer()

def get_common_names():
    """Return the five most frequent male and female names from a public US baby-names dataset."""
    df = pd.read_csv(
        "https://github.com/PhantomInsights/baby-names-analysis/raw/master/data/data.csv"
    )
    top_female_names = (
        df[df.gender == "F"].sort_values(by="count", ascending=False).name.unique()[:5]
    )
    top_male_names = (
        df[df.gender == "M"].sort_values(by="count", ascending=False).name.unique()[:5]
    )
    return top_male_names, top_female_names

def convert_to_common_names(sent, top_male_names, top_female_names):
    """Replace CLUTRR entity placeholders with common first names.

    A placeholder such as ENT_0_female is split by the Moses tokenizer into
    five tokens ("ENT", "_", "0", "_", "female"), which the loop below
    consumes in a single step.
    """
    entity_map = {}
    names = {"male": list(top_male_names), "female": list(top_female_names)}
    tokens = moses_tokenizer.tokenize(sent)
    start = 0
    new_sent = []
    while start < len(tokens):
        cur_token = tokens[start]
        if cur_token == "ENT":
            ent_id = tokens[start + 2]
            ent_gender = "female" if tokens[start + 4].startswith("female") else "male"
            key = f"ENT_{ent_id}_{ent_gender}"
            if key in entity_map:
                new_sent.append(entity_map[key])
            else:
                # Assign the next unused common name of the matching gender.
                ent_name = names[ent_gender].pop(0)
                entity_map[key] = ent_name
                new_sent.append(ent_name)
            start += 5
        else:
            new_sent.append(cur_token)
            start += 1
    return moses_detokenizer.detokenize(new_sent), entity_map
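
# Illustrative example (hypothetical story text, with the name lists truncated to one each):
#   convert_to_common_names("ENT_0_female took her son ENT_1_male to school.",
#                           ["James"], ["Mary"])
#   -> ("Mary took her son James to school.",
#       {"ENT_0_female": "Mary", "ENT_1_male": "James"})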

def isOpenAIAPI(model_key):
    """True when model_key is served through the OpenAI API rather than run locally."""
    return model_key in ["text-davinci-003"]

def setup_model(model_key="google/flan-t5-xl"):
    if isOpenAIAPI(model_key):
        # OpenAI models are queried over the API; there is nothing to load locally.
        return None, None
    tokenizer = T5Tokenizer.from_pretrained(model_key)
    model = T5ForConditionalGeneration.from_pretrained(
        model_key,
        device_map="auto",
    )
    # device_map="auto" already dispatches the weights onto the available GPU(s),
    # so no explicit model.cuda() call is needed (it can conflict with the dispatch).
    return model, tokenizer

def apply_prompt(text_or_array_of_text, head_entity, tail_entity):
    """Build fill-in-the-blank prompts from parallel lists (or single strings)."""
    if isinstance(text_or_array_of_text, str):
        text_or_array_of_text = [text_or_array_of_text]
        head_entity = [head_entity]
        tail_entity = [tail_entity]
    return [
        f"{text} \n Fill in the blank: {h} is the _ of {t}."
        for text, h, t in zip(text_or_array_of_text, head_entity, tail_entity)
    ]
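
# For example (illustrative inputs):
#   apply_prompt("Mary took her son James to school.", "James", "Mary")
# returns:
#   ["Mary took her son James to school. \n Fill in the blank: James is the _ of Mary."]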

# The OpenAI API accepts at most 20 parallel prompts per completion request
def predict_relations(
    model_key, df, model, tokenizer, convert_names=False, batch_size=20
):
    if not isOpenAIAPI(model_key):
        # The local Flan-T5 path below decodes one prompt at a time,
        # so force per-example batches (the OpenAI path handles up to 20).
        batch_size = 1
    if convert_names:
        top_male_names, top_female_names = get_common_names()
    else:
        top_male_names, top_female_names = [], []

    def infer_relations(text, head_entity, tail_entity):
        """Infer relations for parallel lists of templates and entity pairs,
        using Flan-T5 locally or an OpenAI completion model over the API."""
        if convert_names:
            # Swap the entity placeholders for common names, template by template.
            text = [
                convert_to_common_names(t, top_male_names, top_female_names)[0]
                for t in text
            ]
        input_text = apply_prompt(text, head_entity, tail_entity)
        assert isinstance(input_text, list)
        if not isOpenAIAPI(model_key):
            assert len(input_text) == 1
            input_text = input_text[0]
            with torch.no_grad():
                input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(
                    "cuda"
                )
                outputs = model.generate(input_ids)
            return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            # Throttle requests to stay under the OpenAI rate limit.
            time.sleep(2)
            response = openai.Completion.create(
                model=model_key,
                prompt=input_text,
                temperature=0,
                max_tokens=60,
                top_p=1,
                frequency_penalty=0.5,
                presence_penalty=0,
            )
            return [row["text"].strip().lower() for row in response["choices"]]

    predicted = []

    # Prepare one (template, head, tail) query per adjacent entity pair,
    # remembering the source row (i) and pair position (ei).
    queries = []
    for i, row in df.iterrows():
        gender_comb = row["gender_comb"]
        entities = [f"ENT_{j}_{gen}" for j, gen in enumerate(gender_comb.split("-"))]
        for ei in range(len(entities) - 1):
            head = entities[ei + 1]
            tail = entities[ei]
            queries.append((row["template"], head, tail, i, ei))

    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    # Run inference batch by batch, checkpointing partial results to disk
    # so that an interrupted run can resume where it left off.
    results = []
    tmp_results = "tmp_results.pkl"
    pb = tqdm(total=len(queries))
    if Path(tmp_results).exists():
        print("Loading previously saved results...")
        results = pkl.load(open(tmp_results, "rb"))
        batch_of_queries = list(chunks(queries[len(results) :], batch_size))
        pb.update(len(results))
    else:
        batch_of_queries = list(chunks(queries, batch_size))
    for bq in batch_of_queries:
        texts = [row[0] for row in bq]
        heads = [row[1] for row in bq]
        tails = [row[2] for row in bq]
        res = infer_relations(texts, heads, tails)
        if isinstance(res, list):
            results.extend(res)
        else:
            results.append(res)
        pkl.dump(results, open(tmp_results, "wb"))
        pb.update(len(bq))
    pb.close()
    assert len(results) == len(queries)
    os.remove(tmp_results)

    # Regroup the flat per-pair predictions into one hyphen-joined
    # relation chain per dataframe row.
    buffer = []
    last_row = -1
    for qi, query in enumerate(queries):
        if last_row != query[3] and len(buffer) > 0:
            predicted.append("-".join(buffer))
            buffer = []
        buffer.append(results[qi])
        last_row = query[3]
    if len(buffer) > 0:
        predicted.append("-".join(buffer))
    df[model_key] = predicted
    return df

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default="google/flan-t5-xl")
    parser.add_argument("-s", "--split", default="test")
    parser.add_argument("-n", "--names", action="store_true", default=False)
    parser.add_argument("-d", "--debug", action="store_true", default=False)
    args = parser.parse_args()

    print("Loading model ...")
    model, tokenizer = setup_model(args.model)

    # Run the extraction on the full set of CLUTRR AMT templates
    df = pd.read_csv(
        f"https://raw.githubusercontent.com/facebookresearch/clutrr/develop/clutrr/templates/amt/{args.split}.csv"
    )
    if args.debug:
        print("In debug mode, testing 2 examples")
        df = df.iloc[:2]

    print(f"Extracting relations from the {args.split} set...")
    df = predict_relations(args.model, df, model, tokenizer, args.names)

    extra_markers = []
    if not isOpenAIAPI(args.model):
        extra_markers.append("fp32")
    if args.names:
        extra_markers.append("named")
    if args.debug:
        extra_markers.append("debug")
    df.to_csv(
        f"clutrr_amt_{args.split}_{'_'.join(extra_markers)}_{args.model.split('/')[-1]}.csv"
    )
    print("Done")