Skip to content

Instantly share code, notes, and snippets.

Last active September 27, 2020 19:47
Show Gist options
  • Save jcklie/bba26c8807c616ffbc1f17f2d7d687dd to your computer and use it in GitHub Desktop.
Save jcklie/bba26c8807c616ffbc1f17f2d7d687dd to your computer and use it in GitHub Desktop.
from collections import Counter
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Checkpoint identifiers are named separately so it is obvious which download
# feeds the model and which feeds the tokenizer.
ner_checkpoint = "dbmdz/bert-large-cased-finetuned-conll03-english"
tokenizer_checkpoint = "bert-base-cased"

# Token-classification model that produces the per-sub-token label scores.
model = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
# NOTE(review): the tokenizer is loaded from a different checkpoint than the
# model; presumably both share the same cased BERT vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
# Label vocabulary. The position of each entry is the class id that the
# prediction step uses to index into this list.
label_list = [
    "O",  # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
    "I-MISC",  # Miscellaneous entity
    "B-PER",  # Beginning of a person's name right after another person's name
    "I-PER",  # Person's name
    "B-ORG",  # Beginning of an organisation right after another organisation
    "I-ORG",  # Organisation
    "B-LOC",  # Beginning of a location right after another location
    "I-LOC",  # Location
]

# Example input. Adjacent string literals are joined at compile time.
text = (
    "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very "
    "close to the Manhattan Bridge."
)

# Whitespace tokenization: punctuation stays attached to its neighbouring word
# (e.g. "Inc.", "DUMBO,").
my_tokens = text.split()
# Tokenize for transformers
# Each whitespace token is encoded on its own so that the number of sub-tokens
# it produces can be recorded; the alignment step later uses these counts to
# map sub-token predictions back onto the original tokens.
grouped_inputs = [torch.LongTensor([tokenizer.cls_token_id])]
subtokens_per_token = []

for token in my_tokens:
    # Special tokens are added once around the whole sequence, not per word.
    token_ids = tokenizer.encode(token, add_special_tokens=False)
    tokens = torch.LongTensor(token_ids)
    grouped_inputs.append(tokens)
    subtokens_per_token.append(len(tokens))

# Close the sequence with [SEP]; the alignment step strips exactly one leading
# and one trailing special token from the predictions.
grouped_inputs.append(torch.LongTensor([tokenizer.sep_token_id]))

# Concatenate the per-token id tensors into one sequence, then add a leading
# batch dimension of size 1.
flattened_inputs = torch.cat(grouped_inputs)
flattened_inputs = torch.unsqueeze(flattened_inputs, 0)
# Predict
# Inference only: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    # The first element of the model output is the score tensor; it is
    # reduced over dim 2 below to pick one label id per position.
    predictions_tensor = model(flattened_inputs)[0]

# Highest-scoring label id per position; [0] selects the single batch entry.
predictions_tensor = torch.argmax(predictions_tensor, dim=2)[0]

# Map label ids to label strings (plain ints are used for list indexing).
predictions = [label_list[prediction] for prediction in predictions_tensor.tolist()]
# Align tokens
# Remove special tokens [CLS] and [SEP]
predictions = predictions[1:-1]

aligned_predictions = []

# Sanity check: every remaining prediction belongs to exactly one original token.
assert len(predictions) == sum(subtokens_per_token)

# Walk through the flat prediction list, cutting off one slice per original
# token whose length is that token's sub-token count.
ptr = 0
for size in subtokens_per_token:
    group = predictions[ptr:ptr + size]
    assert len(group) == size

    aligned_predictions.append(group)
    ptr += size

assert len(my_tokens) == len(aligned_predictions)

for token, prediction_group in zip(my_tokens, aligned_predictions):
    # The label reported for a token is the most frequent label among its
    # sub-tokens; the full per-sub-token group is printed alongside it.
    majority_label = Counter(prediction_group).most_common(1)[0][0]
    print("{0:>12}\t{1:>5}\t{2}".format(token, majority_label, prediction_group))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment