from collections import Counter

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Token-classification model fine-tuned on CoNLL-2003 NER. It is based on
# bert-large-cased, which shares the cased WordPiece vocabulary with
# bert-base-cased, so that tokenizer can be used here.
model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
label_list = [
    "O",       # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
    "I-MISC",  # Miscellaneous entity
    "B-PER",   # Beginning of a person's name right after another person's name
    "I-PER",   # Person's name
    "B-ORG",   # Beginning of an organisation right after another organisation
    "I-ORG",   # Organisation
    "B-LOC",   # Beginning of a location right after another location
    "I-LOC",   # Location
]
text = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very " \ | |
"close to the Manhattan Bridge." | |
my_tokens = text.split() | |
# Tokenize for transformers: encode each whitespace token separately so we
# know how many subtokens it produces, then wrap everything in [CLS] ... [SEP]
grouped_inputs = [torch.LongTensor([tokenizer.cls_token_id])]
subtokens_per_token = []

for token in my_tokens:
    tokens = tokenizer.encode(
        token,
        return_tensors="pt",
        add_special_tokens=False,
    ).squeeze(dim=0)
    grouped_inputs.append(tokens)
    subtokens_per_token.append(len(tokens))

grouped_inputs.append(torch.LongTensor([tokenizer.sep_token_id]))

flattened_inputs = torch.cat(grouped_inputs)
flattened_inputs = torch.unsqueeze(flattened_inputs, 0)
# Predict: take the argmax label for every subtoken
with torch.no_grad():
    predictions_tensor = model(flattened_inputs)[0]
predictions_tensor = torch.argmax(predictions_tensor, dim=2)[0]

predictions = [label_list[prediction] for prediction in predictions_tensor]

# Align tokens: first remove the predictions for the special tokens [CLS] and [SEP]
predictions = predictions[1:-1]
aligned_predictions = []

assert len(predictions) == sum(subtokens_per_token)

# Regroup the flat subtoken predictions so that each whitespace token
# gets the list of labels predicted for its subtokens
ptr = 0
for size in subtokens_per_token:
    group = predictions[ptr:ptr + size]
    assert len(group) == size

    aligned_predictions.append(group)
    ptr += size

assert len(my_tokens) == len(aligned_predictions)

# Print each token with the majority label over its subtokens
for token, prediction_group in zip(my_tokens, aligned_predictions):
    print("{0:>12}\t{1:>5}\t{2}".format(token, Counter(prediction_group).most_common(1)[0][0], prediction_group))