# Created: July 13, 2019 03:09
# Source: GitHub gist csJd/8dd1c74bacb9db2bc88adfb6e9cf2182
# NOTE: the gist page warned that this file contains bidirectional Unicode
# text that may be interpreted or compiled differently than it appears.
# Review it in an editor that reveals hidden Unicode characters.
import torch | |
from eval import gen_sentence_tensors | |
from dataset import ExhaustiveDataset | |
from utils.path_util import from_project_root | |
def gen_new_records(records):
    """
    Calculate which nesting layer every record is in.

    Spans are half-open: ``(start, end)`` covers tokens ``start .. end - 1``
    (the same convention ``predict`` and ``records_to_tags`` use).

    A record is put on layer 2 when either:
      * its start token lies strictly inside another record, or
      * another record starts on the same token and ends later (that other
        record is the outer span).
    Every other record is on layer 1. Two records that merely touch
    (one starts where the other ends) are both on layer 1.

    Args:
        records: dict of records in format {(start, end): type,}

    Returns:
        new_records in format [[start, end, type, layer],], sorted by start
        ascending and then by end descending (outer spans first).
    """
    new_records = []
    for rec, rec_type in records.items():
        start, end = rec[0], rec[1]
        layer = 1
        for other in records:
            if other == rec:
                continue
            other_start, other_end = other[0], other[1]
            # Half-open spans: starting exactly at other_end is adjacency,
            # not nesting, hence the strict upper bound.
            starts_inside = other_start < start < other_end
            # On a shared start token, only the shorter span is nested.
            inner_on_shared_start = other_start == start and other_end > end
            if starts_inside or inner_on_shared_start:
                layer = 2
                break
        new_records.append([start, end, rec_type, layer])
    return sorted(new_records, key=lambda x: (x[0], -x[1]))
def predict(model, sentences, labels, data_url):
    """ predict NER result for sentence list

    Args:
        model: trained model; must expose ``max_region`` and produce scores
            via ``model.forward(*tensors)``
        sentences: sentences to be predicted
        labels: label list; a predicted index ``i`` becomes ``labels[i]``
            only when ``0 < i < len(labels)``, other indices produce no record
        data_url: data url to locate vocab file

    Returns:
        list with one dict per sentence, in format {(start, end): type},
        using half-open spans
    """
    max_region = model.max_region
    device = next(model.parameters()).device
    tensors = gen_sentence_tensors(sentences, device, data_url)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        pred_regions_list = torch.argmax(model.forward(*tensors), dim=1).cpu()
    lengths = tensors[1]
    # Regions are enumerated against lengths[0] for every sentence. This is
    # presumably the padded (max) length, which assumes gen_sentence_tensors
    # orders sentences longest-first -- TODO confirm.
    enum_len = int(lengths[0])

    pred_sentence_records = []
    for pred_regions, length in zip(pred_regions_list, lengths):
        length = int(length)
        pred_records = {}
        ind = 0
        for region_size in range(1, max_region + 1):
            for start in range(0, enum_len - region_size + 1):
                label_ind = int(pred_regions[ind])
                ind += 1
                # Ignore regions that extend past this sentence's real length.
                if start + region_size > length:
                    continue
                if 0 < label_ind < len(labels):
                    pred_records[(start, start + region_size)] = labels[label_ind]
        pred_sentence_records.append(pred_records)
    return pred_sentence_records
def records_to_tags(records, length, max_layer=2):
    """ Convert entity records into one tag string per token.

    Records are first assigned layers with ``gen_new_records``. Then, for
    each layer from 1 to ``max_layer``, every token gets one tag:
    'B-<type>' on the first token of a record, 'I-<type>' on later tokens,
    and 'O' outside any record of that layer.

    Args:
        records (dict): records in dict format, {(start, end): type}
        length (int): sequence length
        max_layer (int, optional): max nested layer, defaults to 2.

    Returns:
        list(str): one string per token, holding that token's tag for each
            layer in order, each tag preceded by a tab character
    """
    layered = gen_new_records(records)
    total = len(layered)
    tags = [''] * length
    for layer in range(1, max_layer + 1):
        cursor = 0
        inside = False
        for pos in range(length):
            # Move the cursor forward to the next record on this layer.
            while cursor < total and layered[cursor][3] != layer:
                cursor += 1
            if cursor < total and pos >= layered[cursor][0]:
                _, end, label, _ = layered[cursor]
                tag = ('I-' if inside else 'B-') + label
                inside = True
                # The last covered token is end - 1 (half-open span).
                if pos >= end - 1:
                    cursor += 1
                    inside = False
            else:
                tag = 'O'
                inside = False
            tags[pos] += '\t' + tag
    return tags
def predict_to_germ(model, iob_url):
    """ predict on iob2 file and save the results

    Each output line is ``<1-based token index>\\t<word>``, followed by the
    predicted tag columns and then the gold tag columns (as built by
    ``records_to_tags``). Sentences are separated by a blank line.

    Args:
        model: trained model
        iob_url: url to iob file. The result is written next to it, with a
            trailing ``.iob2`` replaced by ``.germ.txt``; for any other name
            ``.germ.txt`` is appended.
    """
    suffix = '.iob2'
    if iob_url.endswith(suffix):
        save_url = iob_url[:-len(suffix)] + '.germ.txt'
    else:
        # Never reuse the input path: opening it with 'w' would wipe it.
        save_url = iob_url + '.germ.txt'
    print("predicting on {} \n the result will be saved in {}".format(
        iob_url, save_url))
    test_set = ExhaustiveDataset(iob_url, device=next(
        model.parameters()).device)
    model.eval()
    # Evaluation only: disable gradient tracking for the whole pass.
    with torch.no_grad(), \
            open(save_url, 'w', encoding='utf-8', newline='\n') as save_file:
        for sentence, true_records in test_set:
            length = len(sentence)
            pred_records = predict(model, [sentence], test_set.label_list, iob_url)[0]
            pred_tags = records_to_tags(pred_records, length)
            true_tags = records_to_tags(true_records, length)
            lines = ["{}\t{}{}{}\n".format(i + 1, word, pred_tags[i], true_tags[i])
                     for i, word in enumerate(sentence)]
            lines.append('\n')
            save_file.writelines(lines)
def main(model_url=None, test_url=None):
    """Load a trained model and write GermEval-format predictions for a test file.

    Args:
        model_url: path to the saved model; defaults to
            ``data/model/germeval_exhaustive_model_epoch8_0.665445.pt`` under
            the project root.
        test_url: path to the IOB2 test file; defaults to
            ``data/germeval/test.iob2`` under the project root.
    """
    if model_url is None:
        model_url = from_project_root("data/model/germeval_exhaustive_model_epoch8_0.665445.pt")
    print("loading model from", model_url)
    # NOTE: torch.load unpickles a whole model object -- only load trusted files.
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which cannot restore a
        # fully pickled model, so ask for the full unpickler explicitly.
        model = torch.load(model_url, weights_only=False)
    except TypeError:
        # PyTorch releases older than 1.13 do not accept weights_only.
        model = torch.load(model_url)
    if test_url is None:
        test_url = from_project_root("data/germeval/test.iob2")
    predict_to_germ(model, test_url)
if __name__ == '__main__':
    # Script entry point: run the GermEval prediction pipeline.
    main()
# (End of gist. Commenting on the gist requires signing in to GitHub.)