@csJd
Created July 13, 2019 03:09
import torch
from eval import gen_sentence_tensors
from dataset import ExhaustiveDataset
from utils.path_util import from_project_root


def gen_new_records(records):
    """ Assign a nesting layer to each entity record.

    Args:
        records: dict of records in format {(start, end): type, }

    Returns:
        new_records in format [[start, end, type, layer], ], sorted by
        (start, -end); a record whose start falls inside another record's
        span is assigned to layer 2, all others to layer 1
    """
    new_records = list()
    for rec in records:
        layer = 1
        for rect in records:
            # rec starts inside another record, so it belongs to the inner layer
            if rect != rec and rect[0] <= rec[0] <= rect[1]:
                layer = 2
                break
        new_records.append([rec[0], rec[1], records[rec], layer])
    new_records = list(sorted(new_records, key=lambda x: (x[0], -x[1])))
    return new_records
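
# A minimal, hypothetical sanity check (the labels 'PER'/'LOC' are placeholders,
# not part of the original gist): the nested entity (1, 2) starts inside (0, 2),
# so it is pushed to layer 2.
# >>> gen_new_records({(0, 2): 'PER', (1, 2): 'LOC'})
# [[0, 2, 'PER', 1], [1, 2, 'LOC', 2]]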


def predict(model, sentences, labels, data_url):
    """ Predict NER results for a list of sentences.

    Args:
        model: trained model
        sentences: sentences to be predicted
        labels: list of entity labels, indexed by the model's output classes
        data_url: data url to locate vocab files

    Returns:
        list(dict): predicted records for each sentence,
            in format {(start, end): type, }
    """
    max_region = model.max_region
    device = next(model.parameters()).device
    tensors = gen_sentence_tensors(sentences, device, data_url)
    # class index of every candidate region, for every sentence in the batch
    pred_regions_list = torch.argmax(model.forward(*tensors), dim=1).cpu()
    lengths = tensors[1]
    pred_sentence_records = []
    for pred_regions, length in zip(pred_regions_list, lengths):
        pred_records = {}
        ind = 0
        # enumerate candidate regions size by size; ind walks the model output
        # in the same order
        for region_size in range(1, max_region + 1):
            # note: region starts are enumerated against lengths[0]; in this
            # script predict() is always called with a single sentence, so
            # lengths[0] == length
            for start in range(0, lengths[0] - region_size + 1):
                if 0 < pred_regions[ind] < len(labels):
                    pred_records[(start, start + region_size)] = \
                        labels[pred_regions[ind]]
                ind += 1
        pred_sentence_records.append(pred_records)
    return pred_sentence_records
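
# Illustration of the region enumeration above (a sketch, assuming
# lengths[0] == 4 and max_region == 2): ind 0..6 maps to the candidate regions
#   size 1: (0, 1), (1, 2), (2, 3), (3, 4)
#   size 2: (0, 2), (1, 3), (2, 4)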


def records_to_tags(records, length, max_layer=2):
    """ Transform entity records into per-token BIO tags.

    Args:
        records (dict): records in dict format, {(start, end): type}
        length (int): sequence length
        max_layer (int, optional): max nested layer, defaults to 2

    Returns:
        list(str): tab-separated tags for each token, one column per layer
    """
    n_records = len(records)
    records = gen_new_records(records)
    tags = [''] * length
    for k in range(1, max_layer + 1):
        rec_ind = 0
        in_tag = False
        for i in range(length):
            # get next record in k-th layer
            while rec_ind < n_records and records[rec_ind][3] != k:
                rec_ind += 1
            if rec_ind >= n_records or i < records[rec_ind][0]:
                tag = 'O'
                in_tag = False
            else:
                rec = records[rec_ind]
                tag = ('I-' if in_tag else 'B-') + rec[2]
                in_tag = True
                # the record ends at this token, move on to the next one
                if i >= rec[1] - 1:
                    rec_ind += 1
                    in_tag = False
            tags[i] += '\t' + tag
    return tags
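
# A minimal, hypothetical example (labels 'PER'/'LOC' are placeholders): the
# first tag column is layer 1 (the outer entity), the second is layer 2.
# >>> records_to_tags({(0, 2): 'PER', (1, 2): 'LOC'}, length=3)
# ['\tB-PER\tO', '\tI-PER\tB-LOC', '\tO\tO']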


def predict_to_germ(model, iob_url):
    """ Predict on an iob2 file and save the results.

    Args:
        model: trained model
        iob_url: url to the iob2 file
    """
    save_url = iob_url.replace('.iob2', '.germ.txt')
    print("predicting on {}\n the result will be saved in {}".format(
        iob_url, save_url))
    test_set = ExhaustiveDataset(iob_url, device=next(model.parameters()).device)
    model.eval()
    with open(save_url, 'w', encoding='utf-8', newline='\n') as save_file:
        for sentence, true_records in test_set:
            length = len(sentence)
            pred_records = predict(model, [sentence], test_set.label_list, iob_url)[0]
            pred_tags = records_to_tags(pred_records, length)
            true_tags = records_to_tags(true_records, length)
            # one line per token: index, word, predicted tags, gold tags
            for i, word in enumerate(sentence):
                save_file.write("{}\t{}{}{}\n".format(
                    i + 1, word, pred_tags[i], true_tags[i]))
            save_file.write('\n')
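
# Each saved line follows the GermEval-style column layout written above:
# token index, word, predicted layer-1/layer-2 tags, then gold layer-1/layer-2
# tags, with a blank line between sentences. For example (hypothetical tokens):
#   1   Aachen  B-LOC   O   B-LOC   O
#   2   ist     O       O   O       O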


def main():
    model_url = from_project_root("data/model/germeval_exhaustive_model_epoch8_0.665445.pt")
    print("loading model from", model_url)
    model = torch.load(model_url)
    test_url = from_project_root("data/germeval/test.iob2")
    predict_to_germ(model, test_url)


if __name__ == "__main__":
    main()