Skip to content

Instantly share code, notes, and snippets.

@harshildarji
Created March 4, 2022 12:27
Show Gist options
  • Save harshildarji/7933ec58c266b58ca8522baac2c8b789 to your computer and use it in GitHub Desktop.
Save harshildarji/7933ec58c266b58ca8522baac2c8b789 to your computer and use it in GitHub Desktop.
Generate NER predictions
import os
import pickle
import warnings
from functools import reduce
from operator import add
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BertForTokenClassification, logging
# Suppress every warning category (Warning is the base class of them all)
# so the console output stays limited to the script's own messages.
warnings.simplefilter(action="ignore", category=Warning)
# `logging` here is transformers.logging: only report errors from the library.
logging.set_verbosity(logging.ERROR)
# Device used as `map_location` when the fine-tuned checkpoint is loaded.
device = torch.device("cpu")
def analyze(text):
    """Tag ``text`` with the token-classification model and return word-level results.

    Relies on the module-level ``tokenizer``, ``model`` and ``tag_values``
    objects, which the ``__main__`` section of this script creates before
    calling this function.

    Args:
        text: The string to annotate.

    Returns:
        A ``(tokens, labels)`` tuple of two equally long lists. Word pieces
        prefixed with ``##`` are glued back onto the preceding token, which
        keeps the label predicted for its first piece. A ``"."`` token whose
        label is not ``"O"`` is appended to the preceding token and its own
        label is dropped. Tokens added by the tokenizer itself are returned
        as-is; the caller is expected to filter them.
    """
    input_ids = torch.tensor([tokenizer.encode(text)])
    with torch.no_grad():
        output = model(input_ids)
    # output[0] holds the per-token scores; take the best label per position
    # for the single sequence in the batch.
    label_indices = np.argmax(output[0].numpy(), axis=2)[0]
    pieces = tokenizer.convert_ids_to_tokens(input_ids.numpy()[0])

    # Phase 1: merge "##" continuation pieces into the token they continue.
    words, word_labels = [], []
    for piece, label_idx in zip(pieces, label_indices):
        # The `words` guard avoids an IndexError should a sequence ever start
        # with a continuation piece; such a piece is kept as its own token.
        if piece.startswith("##") and words:
            words[-1] += piece[2:]
        else:
            words.append(piece)
            word_labels.append(tag_values[label_idx])

    # Phase 2: fold periods that were tagged as part of an entity into the
    # previous token. Building the result in one pass (instead of editing by
    # index and deleting afterwards) means a leading "." can no longer wrap
    # around to the last token, and consecutive tagged periods are all kept
    # rather than being appended to a token that is itself removed.
    new_tokens, new_labels = [], []
    for word, label in zip(words, word_labels):
        if word == "." and label != "O" and new_tokens:
            new_tokens[-1] += "."
        else:
            new_tokens.append(word)
            new_labels.append(label)
    return new_tokens, new_labels
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst``, each at most ``n`` items long.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``;
    an empty ``lst`` yields nothing.
    """
    offsets = range(0, len(lst), n)
    yield from (lst[start : start + n] for start in offsets)
if __name__ == "__main__":
    # Script entry point: read the "tenor" strings from the metadata CSV,
    # fetch and load the fine-tuned NER model, then append one
    # "token label" line per token to tenor.conll, with a blank line after
    # each annotated string.
    print("[+] Reading data")
    data = pd.read_csv("../metadata.csv")
    # Each "tenor" cell holds several "|"-separated strings. Flatten them into
    # one set, dropping empty pieces and duplicates. The comprehension is
    # linear (repeated list concatenation is quadratic) and also copes with a
    # column that has no non-null rows.
    tenor_cells = data["tenor"].dropna().reset_index(drop=True).str.split("|").tolist()
    tenor = {piece for cell in tenor_cells for piece in cell if piece}

    print("[+] Downloading model, tag_values, and tokenizer")
    artifact_urls = [
        "https://www.dropbox.com/s/vos8pqwmlbqe0wf/model.pt",
        "https://www.dropbox.com/s/u2oojgmmprt0a9d/tag_values.pkl",
        "https://www.dropbox.com/s/uj15pab78emefoq/tokenizer.pkl",
    ]

    def _local_name(url):
        # wget stores each download under the last path component of its URL.
        return url.rsplit("/", 1)[-1]

    # Only fetch what is not already present: wget would otherwise save a
    # second copy under a ".1" suffix while the existing file keeps being
    # the one that is loaded below.
    to_fetch = [url for url in artifact_urls if not os.path.exists(_local_name(url))]
    if to_fetch:
        os.system("wget -q " + " ".join(to_fetch))
    still_missing = [
        _local_name(url) for url in artifact_urls if not os.path.exists(_local_name(url))
    ]
    if still_missing:
        # Fail here with a clear message rather than on a later open() call.
        raise FileNotFoundError(
            "Download failed, missing file(s): " + ", ".join(still_missing)
        )

    # NOTE(security): pickle.load can execute arbitrary code; these files must
    # come from a trusted source.
    with open("tokenizer.pkl", "rb") as handle:
        tokenizer = pickle.load(handle)
    with open("tag_values.pkl", "rb") as handle:
        tag_values = pickle.load(handle)

    model = BertForTokenClassification.from_pretrained(
        "bert-base-german-cased",
        num_labels=len(tag_values),
        output_attentions=False,
        output_hidden_states=False,
    )
    model.load_state_dict(torch.load("model.pt", map_location=device))
    # Make inference mode explicit so dropout is guaranteed to be disabled.
    model.eval()

    print("[+] NER annotation")
    # Append mode is kept from the original script, so re-running adds to an
    # existing file. UTF-8 is set explicitly so non-ASCII characters are
    # written the same way on every platform.
    with open("tenor.conll", "a+", encoding="utf-8") as conll:
        for datum in tqdm(tenor):
            tokens, labels = [], []
            # Strings longer than 512 characters are annotated in 512-character
            # slices; shorter strings come back from chunks() as one slice.
            # NOTE(review): the split is by characters, not model tokens, so a
            # slice boundary can fall inside a word.
            for piece in chunks(datum, 512):
                piece_tokens, piece_labels = analyze(piece)
                tokens.extend(piece_tokens)
                labels.extend(piece_labels)
            conll.writelines(
                f"{token} {label}\n"
                for token, label in zip(tokens, labels)
                if token not in ("[CLS]", "[SEP]")
            )
            # A blank line separates consecutive annotated strings.
            conll.write("\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment