Zero-shot relation extractor for CLUTRR
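Zero-shot relation inference over the CLUTRR AMT templates, using either a local Flan-T5 checkpoint (via HuggingFace Transformers) or the OpenAI completion API. The script writes one predicted relation chain per template to clutrr_amt_<split>_<markers>_<model>.csv, in a new column named after the model key. A sketch of the intended invocation, assuming the file is saved as extract_relations.py (a placeholder name, not part of the gist):

python extract_relations.py --model google/flan-t5-xl --split test --names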
# Zero shot inference to verify AMT templates
# Using Flan T5 (or the OpenAI completion API)
import argparse
import os
import pickle as pkl
import time
from pathlib import Path

import openai
import pandas as pd
import torch
from sacremoses import MosesDetokenizer, MosesTokenizer
from tqdm.auto import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

openai.api_key = os.getenv("OPENAI_API_KEY")

moses_detokenizer = MosesDetokenizer()
moses_tokenizer = MosesTokenizer()

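# get_common_names pulls a public US baby-names dataset and keeps the five
# most frequent names per gender; these later stand in for the abstract
# CLUTRR entity placeholders so the prompts read as natural sentences.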
def get_common_names():
    df = pd.read_csv(
        "https://github.com/PhantomInsights/baby-names-analysis/raw/master/data/data.csv"
    )
    top_female_names = (
        df[df.gender == "F"].sort_values(by="count", ascending=False).name.unique()[0:5]
    )
    top_male_names = (
        df[df.gender == "M"].sort_values(by="count", ascending=False).name.unique()[0:5]
    )
    return top_male_names, top_female_names

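# convert_to_common_names swaps placeholders for concrete names. It assumes
# the Moses tokenizer splits a placeholder such as ENT_1_female into five
# tokens ("ENT", "_", "1", "_", "female"); each distinct placeholder gets the
# next unused name of the matching gender, reused on repeat mentions.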
def convert_to_common_names(sent, top_male_names, top_female_names):
    entity_map = {}
    names = {"male": list(top_male_names), "female": list(top_female_names)}
    tokens = moses_tokenizer.tokenize(sent)
    start = 0
    new_sent = []
    while start < len(tokens):
        cur_token = tokens[start]
        if cur_token == "ENT":
            ent_id = tokens[start + 2]
            ent_gender = "female" if tokens[start + 4].startswith("female") else "male"
            key = f"ENT_{ent_id}_{ent_gender}"
            if key in entity_map:
                new_sent.append(entity_map[key])
            else:
                ent_name = names[ent_gender].pop(0)
                entity_map[key] = ent_name
                new_sent.append(ent_name)
            # Skip past the five tokens of the placeholder
            start += 5
        else:
            new_sent.append(cur_token)
            start += 1
    return moses_detokenizer.detokenize(new_sent), entity_map

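# Model keys are routed either to the OpenAI completion API or to a local
# HuggingFace T5 checkpoint; extend this list to cover other OpenAI models.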
def isOpenAIAPI(model_key):
    return model_key in ["text-davinci-003"]

def setup_model(model_key="google/flan-t5-xl"):
    if isOpenAIAPI(model_key):
        return None, None
    tokenizer = T5Tokenizer.from_pretrained(model_key)
    model = T5ForConditionalGeneration.from_pretrained(
        model_key,
        # device_map="auto" already places the weights on the GPU(s), so an
        # explicit .cuda() call is unnecessary (and fails on offloaded models)
        device_map="auto",
    )
    return model, tokenizer

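# The zero-shot prompt casts relation extraction as a cloze question, e.g.
# "<story> \n Fill in the blank: ENT_1_female is the _ of ENT_0_male."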
def apply_prompt(text_or_array_of_text, head_entity, tail_entity):
    if isinstance(text_or_array_of_text, str):
        text_or_array_of_text = [text_or_array_of_text]
        head_entity = [head_entity]
        tail_entity = [tail_entity]
    return [
        f"{text} \n Fill in the blank: {h} is the _ of {t}."
        for text, h, t in zip(text_or_array_of_text, head_entity, tail_entity)
    ]

# openAI only allows 20 parallel prompts
def predict_relations(
    model_key, df, model, tokenizer, convert_names=False, batch_size=20
):
    if not isOpenAIAPI(model_key):
        # The local T5 branch below decodes a single output per call
        batch_size = 1
    if convert_names:
        top_male_names, top_female_names = get_common_names()
    else:
        top_male_names, top_female_names = [], []

    def infer_relations(text, head_entity, tail_entity):
        """
        Infer relations with Flan T5 (one prompt) or the OpenAI API (a batch).
        """
        if convert_names:
            if isinstance(text, list):
                text = [
                    convert_to_common_names(t, top_male_names, top_female_names)[0]
                    for t in text
                ]
            else:
                text, _ = convert_to_common_names(
                    text, top_male_names, top_female_names
                )
        input_text = apply_prompt(text, head_entity, tail_entity)
        assert isinstance(input_text, list)
        if not isOpenAIAPI(model_key):
            assert len(input_text) == 1
            input_text = input_text[0]
            with torch.no_grad():
                input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(
                    "cuda"
                )
                outputs = model.generate(input_ids)
            return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            time.sleep(2)  # crude rate limiting between API calls
            response = openai.Completion.create(
                model=model_key,
                prompt=input_text,
                temperature=0,
                max_tokens=60,
                top_p=1,
                frequency_penalty=0.5,
                presence_penalty=0,
            )
            return [row["text"].strip().lower() for row in response["choices"]]
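    # Each row's gender_comb (e.g. "female-male-female") lists the entity
    # genders in order; we build one query per adjacent entity pair, asking
    # how entity i+1 relates to entity i, tagged with row and pair indices.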
    predicted = []
    queries = []
    for i, row in df.iterrows():
        gender_comb = row["gender_comb"]
        entities = [f"ENT_{j}_{gen}" for j, gen in enumerate(gender_comb.split("-"))]
        for ei in range(len(entities) - 1):
            head = entities[ei + 1]
            tail = entities[ei]
            queries.append((row["template"], head, tail, i, ei))

    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]
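    # Simple checkpointing: partial results are pickled after every batch, so
    # an interrupted run resumes from tmp_results.pkl instead of re-querying.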
    results = []
    tmp_results = "tmp_results.pkl"
    pb = tqdm(total=len(queries))
    if Path(tmp_results).exists():
        print("Loading previously saved results...")
        results = pkl.load(open(tmp_results, "rb"))
        batch_of_queries = list(chunks(queries[len(results) :], batch_size))
        pb.update(len(results))
    else:
        batch_of_queries = list(chunks(queries, batch_size))
    for bq in batch_of_queries:
        texts = [row[0] for row in bq]
        heads = [row[1] for row in bq]
        tails = [row[2] for row in bq]
        res = infer_relations(texts, heads, tails)
        if isinstance(res, list):
            results.extend(res)
        else:
            results.append(res)
        pkl.dump(results, open(tmp_results, "wb"))
        pb.update(len(bq))
    pb.close()
    assert len(results) == len(queries)
    os.remove(tmp_results)
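    # Stitch the flat predictions back into one "-"-joined relation chain per
    # dataframe row, mirroring the gender_comb format.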
    buffer = []
    last_row = -1
    for qi, query in enumerate(queries):
        if last_row != query[3] and len(buffer) > 0:
            predicted.append("-".join(buffer))
            buffer = []
        buffer.append(results[qi])
        last_row = query[3]
    if len(buffer) > 0:
        predicted.append("-".join(buffer))
    df[model_key] = predicted
    return df

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default="google/flan-t5-xl")
    parser.add_argument("-s", "--split", default="test")
    parser.add_argument("-n", "--names", action="store_true", default=False)
    parser.add_argument("-d", "--debug", action="store_true", default=False)
    args = parser.parse_args()
    print("Loading model ...")
    model, tokenizer = setup_model(args.model)
    # Load the CLUTRR AMT templates for the chosen split
    df = pd.read_csv(
        f"https://raw.githubusercontent.com/facebookresearch/clutrr/develop/clutrr/templates/amt/{args.split}.csv"
    )
    if args.debug:
        print("In debug mode, testing 2 examples")
        df = df.iloc[:2]
    print(f"Extracting relations on the {args.split} split...")
    df = predict_relations(args.model, df, model, tokenizer, args.names)
    extra_markers = []
    if not isOpenAIAPI(args.model):
        extra_markers.append("fp32")
    if args.names:
        extra_markers.append("named")
    if args.debug:
        extra_markers.append("debug")
    df.to_csv(
        f"clutrr_amt_{args.split}_{'_'.join(extra_markers)}_{args.model.split('/')[-1]}.csv"
    )
    print("Done")