-
-
Save VictorSanh/6cfce8bad8a80d3ba1cd1c95aba2216d to your computer and use it in GitHub Desktop.
Clean format (spaces)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
# Convert ACE2005 annotations (.sgm + .apf.xml) into CoNLL-like files for
# entity mention detection (EMD), relation extraction and coreference.
import glob
import xml.etree.ElementTree as ET
import sys, os, re
import copy
import spacy
from typing import Dict, List, Tuple

# spaCy English model used for all tokenization in this script.
nlp = spacy.load("en_core_web_sm")
# CLI usage: python <script> <path_to_ace2005_root> <output_directory>
path_to_ace2005 = sys.argv[1]
saving_path = sys.argv[2]
class MentionContainer:
    """
    Bundle of everything known about one family of spans — either whole
    mention extents ("whole") or mention heads ("head").

    Attributes:
        container_type: either "whole" or "head".
        storing: span_id -> [entity_type, start, end, text].
        starts: character index -> list of span ids starting there.
        ends: character index -> list of span ids ending there.
        duplicate_mapping: when duplicates were filtered out, maps the id of
            a duplicated span onto the id of the original it duplicates
            (None when no de-duplication information is carried).
    """

    def __init__(self, container_type, storing, starts, ends, duplicate_mapping=None):
        # Only two span granularities exist in this pipeline.
        assert container_type in ["whole", "head"]
        self.container_type = container_type
        self.storing = storing
        self.starts = starts
        self.ends = ends
        self.duplicate_mapping = duplicate_mapping
class CoreferenceContainer:
    """Thin holder for coreference clusters: entity_id -> list of mention ids."""

    def __init__(self, coref_dic):
        # Mapping from an entity id to all mention ids referring to it.
        self.coref_dic = coref_dic
def extract_relation(apf_root,
                     head_container):
    """
    Extract head-level relations from a parsed APF annotation tree.

    Parameters:
        apf_root: root Element of the parsed .apf.xml file.
        head_container: MentionContainer over heads; `storing` holds the valid
            head ids and `duplicate_mapping` redirects duplicated head ids to
            their originals.

    Returns:
        Dict mapping relation_mention id -> [relation_id, relation_type,
        arg1_head_id, arg2_head_id]. Relations whose argument points at an
        unknown (duplicated, ignored) head, and exact duplicates of an
        already-seen (type, arg1, arg2) triple, are skipped with a note on
        stderr.
    """
    heads = head_container.storing
    head_map = head_container.duplicate_mapping
    # Storing the relations, keyed by relation_mention id.
    relations: Dict[str, List[str]] = {}
    # Already-emitted [type, arg1, arg2] triples, to drop duplicated relations.
    check_relations: List[List[str]] = []
    for relation in apf_root.iter('relation'):
        relation_type: str = relation.attrib["TYPE"]
        for mention in relation.iter('relation_mention'):
            # There can definitely be more than 1 relation_mention in a relation.
            relation_id: str = mention.attrib["ID"]
            relation_list: List[str] = [relation_id, relation_type, "", ""]
            ignore: bool = False
            for arg in mention.iter('relation_mention_argument'):
                # Only the two arguments with ROLE Arg-1 / Arg-2 matter.
                arg_id: str = arg.attrib["REFID"]
                if arg.attrib["ROLE"] != "Arg-1" and arg.attrib["ROLE"] != "Arg-2":
                    continue
                # Arg id should be the original head id, not the duplicated one.
                if arg_id in head_map:
                    arg_id = head_map[arg_id]
                # "Arg-1" fills slot 2, "Arg-2" fills slot 3.
                relation_list[int(arg.attrib["ROLE"][-1]) + 1] = arg_id
                if arg_id not in heads:
                    ignore = True  # argument head was dropped as a duplicate
            if ignore:
                sys.stderr.write("ignored relation %s\n" % (relation_id))
                continue
            if relation_list[1:] in check_relations:
                sys.stderr.write("duplicated relation %s\n" % (relation_id))
                continue
            check_relations.append(relation_list[1:])
            relations[relation_id] = relation_list
    return relations
def extract_mentions_and_relations_and_coref(apf_xml_file_path: str):
    """
    Extract Mentions (heads and whole mention), relations between heads and coreference clusters.

    Parses one ACE2005 .apf.xml annotation file and returns a 4-tuple:
      - MentionContainer over whole mention extents,
      - MentionContainer over mention heads,
      - relations dict as produced by extract_relation,
      - coref_clusters: Dict[entity_id, List[mention_id]].
    Duplicated mentions (same type/start/end/text) and duplicated heads
    (same start/end/text, regardless of entity type) are dropped; a mapping
    from the duplicate id to the original id is kept in each container's
    duplicate_mapping.
    """
    apf_tree = ET.parse(apf_xml_file_path)
    apf_root = apf_tree.getroot()
    #Store the heads and the mentions: mention_id -> [entity_type, start, end, text]
    heads: Dict[str, List] = {}
    mentions: Dict[str, List] = {}
    #Seen (type, start, end, text) tuples, mapped back to the first mention ID
    #that produced them, to detect duplicated heads/mentions with different IDs
    check_heads: Dict[Tuple[str, int, int, str], str] = {}
    check_mentions: Dict[Tuple[str, int, int, str], str] = {}
    #If duplicate heads: mapping the duplicate id to the original head id
    head_map: Dict[str, str] = {}
    mention_map: Dict[str, str] = {}
    #Index of the starts and ends (character index) of mentions and heads
    mention_starts: Dict[int, List[str]] = {}
    mention_ends: Dict[int, List[str]] = {}
    head_starts: Dict[int, List[str]] = {}
    head_ends: Dict[int, List[str]] = {}
    #Storing the coreference clusters
    coref_clusters: Dict[str, List[str]] = {}
    for entity in apf_root.iter('entity'):
        entity_type:str = entity.attrib["TYPE"] #Type of the entity
        entity_id:str = entity.attrib["ID"] #Id of the entity (for coreference)
        for mention in entity.iter('entity_mention'):
            mention_id:str = mention.attrib["ID"]
            for child in mention:
                #An entity_mention always has 2 children: one head and one extent
                #NOTE(review): the code relies on both being present — a mention
                #missing either child would leave the *_start/*_end/*_text
                #variables stale from the previous iteration. Confirm against data.
                if child.tag == "head":
                    for charseq in child: #always 1 child: charseq
                        head_start: int = int(charseq.attrib["START"])
                        # APF END is inclusive; +1 makes the span end-exclusive.
                        head_end: int = int(charseq.attrib["END"])+1
                        head_text: str = re.sub(r"\n", r" ", charseq.text)
                if child.tag == "extent":
                    for charseq in child: #always 1 child: charseq
                        mention_start: int = int(charseq.attrib["START"])
                        mention_end: int = int(charseq.attrib["END"])+1
                        mention_text: str = re.sub(r"\n", r" ", charseq.text)
            #Storing the mention
            double_mention:bool = False
            mention_tuple: Tuple[str, int, int, str] = (entity_type, mention_start, mention_end, mention_text)
            if mention_tuple in check_mentions:
                sys.stderr.write("duplicated mention entity %s\n" % (mention_id))
                mention_map[mention_id] = check_mentions[mention_tuple]
                double_mention = True
            else:
                check_mentions[mention_tuple] = mention_id
                mentions[mention_id] = list(mention_tuple)
            if not double_mention:
                # Index the mention by its start/end character positions.
                if not mention_start in mention_starts:
                    mention_starts[mention_start] = []
                mention_starts[mention_start].append(mention_id)
                if not mention_end in mention_ends:
                    mention_ends[mention_end] = []
                mention_ends[mention_end].append(mention_id)
                #Store the coreference link only when the mention is not a duplicate
                if entity_id not in coref_clusters: coref_clusters[entity_id] = []
                coref_clusters[entity_id].append(mention_id)
            #Store the head
            double_head = False
            head_tuple: Tuple[str, int, int, str] = (entity_type, head_start, head_end, head_text)
            if (head_start, head_end, head_text) in check_heads:
                #some cases where same (head_start, head_end, head_text) has two different entity_type ->
                #so not checking on the head_tuple but the (head_start, head_end, head_text)
                #it is not possible when considering the whole mention....
                sys.stderr.write("duplicated head entity %s\n" % (mention_id))
                head_map[mention_id] = check_heads[(head_start, head_end, head_text)]
                double_head = True
            else:
                check_heads[(head_start, head_end, head_text)] = mention_id
                heads[mention_id] = list(head_tuple)
            if not double_head:
                # Index the head by its start/end character positions.
                if not head_start in head_starts:
                    head_starts[head_start] = []
                head_starts[head_start].append(mention_id)
                if not head_end in head_ends:
                    head_ends[head_end] = []
                head_ends[head_end].append(mention_id)
    #everything related to mentions
    mention_container = MentionContainer("whole", mentions, mention_starts, mention_ends, mention_map)
    #everything related to heads
    head_container = MentionContainer("head", heads, head_starts, head_ends, head_map)
    #Extract relation
    relations = extract_relation(apf_root, head_container)
    return mention_container, head_container, relations, coref_clusters
def lstrip_doc(doc: str,
               mention_container,
               head_container):
    """
    Strip surrounding whitespace from the document and shift every stored
    character offset left by the amount of leading whitespace removed.

    Mutates the containers' span lists in place and rebuilds their
    start/end indexes, then returns (doc, mention_container, head_container).
    """
    shift: int = len(doc) - len(doc.lstrip())
    doc = doc.strip()
    # Same adjustment for whole mentions and for heads: shift the [start, end)
    # slots of every span, then rebuild the position -> ids indexes.
    for container in (mention_container, head_container):
        for span in container.storing.values():
            span[1] -= shift
            span[2] -= shift
        container.starts = {pos - shift: ids for pos, ids in container.starts.items()}
        container.ends = {pos - shift: ids for pos, ids in container.ends.items()}
    return doc, mention_container, head_container
def offset_mentions(offset_to_apply, flag, mentions, init_mentions):
    """
    Shift mention boundaries lying at or after character position `flag`.

    The test is done against the ORIGINAL positions in `init_mentions`
    (a snapshot taken before any shifting), while the mutation is applied
    to the live `mentions` dict. Slots 1 and 2 are start and end.
    """
    for key, original in init_mentions.items():
        for slot in (1, 2):
            if original[slot] >= flag:
                mentions[key][slot] += offset_to_apply
def regulate_offset_with_mentions(doc: str,
                                  mention_container,
                                  head_container):
    """
    Regulate offset with mentions. Especially if head is part of a bigger token -> split
    into two tokens so that we can only tag the head (and not the other part).

    Walks the document character by character; whenever a head starts or ends
    in the middle of a token (no whitespace right before position i), a space
    is inserted there, and all head/mention offsets at or after that point are
    shifted accordingly. Returns the rewritten doc and the updated containers
    (also left-stripped via lstrip_doc).
    """
    offset: int = 0          # total number of spaces inserted so far
    size: int = len(doc)
    current:int = 0          # start of the doc region not yet copied to `regions`
    regions: List[str] = []  # doc pieces, later re-joined with single spaces
    # Rebuilt position -> ids indexes, filled with the shifted offsets.
    mention_starts_regu: Dict[int, List[str]] = {}
    mention_ends_regu: Dict[int, List[str]] = {}
    head_starts_regu: Dict[int, List[str]] = {}
    head_ends_regu: Dict[int, List[str]] = {}
    heads = head_container.storing
    mentions = mention_container.storing
    # Snapshot of original mention offsets: offset_mentions tests against these.
    init_mentions = copy.deepcopy(mentions)
    # NOTE(review): mention_map is never used below — dead assignment.
    mention_map = mention_container.duplicate_mapping
    for i in range(size):
        if i in head_container.starts or i in head_container.ends:
            inc = 0
            # A head boundary inside a token (non-space on both sides): cut the
            # doc here so that " ".join later inserts a space at this position.
            if (doc[i-1] != " " and doc[i-1] != "\n") and (doc[i] != " " and doc[i] != "\n"):
                regions.append(doc[current:i])
                inc = 1
                current = i
            if i in head_container.starts:
                for mention_id in head_container.starts[i]:
                    # Head start moves past the inserted space (offset + inc).
                    heads[mention_id][1] += offset + inc
                    if not heads[mention_id][1] in head_starts_regu:
                        head_starts_regu[heads[mention_id][1]] = []
                    head_starts_regu[heads[mention_id][1]].append(mention_id)
            if i in head_container.ends:
                for mention_id in head_container.ends[i]:
                    # Head end stays before any space inserted at this position
                    # (shifted by `offset` only, not `inc`).
                    heads[mention_id][2] += offset
                    if not heads[mention_id][2] in head_ends_regu:
                        head_ends_regu[heads[mention_id][2]] = []
                    head_ends_regu[heads[mention_id][2]].append(mention_id)
            #Offsets the mentions whose original boundaries lie at/after i
            if inc > 0: offset_mentions(1, i, mentions, init_mentions)
            offset += inc
    regions.append(doc[current:])
    doc = " ".join(regions)
    #Update the head indexes with the regulated positions
    head_container.starts, head_container.ends = head_starts_regu, head_ends_regu
    # Rebuild the mention indexes from the (possibly shifted) mention spans.
    for mention_id, mention in mentions.items():
        start, end = mention[1:3]
        if start not in mention_starts_regu: mention_starts_regu[start] = []
        mention_starts_regu[start].append(mention_id)
        if end not in mention_ends_regu: mention_ends_regu[end] = []
        mention_ends_regu[end].append(mention_id)
    mention_container.starts, mention_container.ends = mention_starts_regu, mention_ends_regu
    # Finally strip surrounding whitespace and re-shift everything accordingly.
    doc, mention_container, head_container = lstrip_doc(doc, mention_container, head_container)
    return doc, mention_container, head_container
def read_doc_and_replace(sgm_file_name):
    """
    Read an .sgm document and perform basic cleanup substitutions.

    Removes all SGML/XML tags, then re-joins a line break occurring in the
    middle of running text (a non-space before the newline and a non-space not
    followed by ':' after it) into a single space.
    Returns the cleaned document as one string.
    """
    # `with` guarantees the file handle is closed (the original leaked it).
    with open(sgm_file_name) as sgm_file:
        doc = sgm_file.read()
    # Drop all markup tags.
    doc = re.sub(r"<[^>]+>", "", doc)
    # Glue intra-paragraph line breaks back into spaces.
    doc = re.sub(r"(\S+)\n(\S[^:])", r"\1 \2", doc)
    return doc
def filter_overlapping_head_relations(relations: Dict[str, List[str]],
                                      head_container):
    """
    Drop relations whose two argument heads overlap in character space,
    including the degenerate case where both arguments are the exact same
    span (a head "in relation with itself"). Only a handful of such cases
    (<10) exist in the corpus.
    """
    kept: Dict[str, List[str]] = {}
    for relation_id, relation in relations.items():
        first = head_container.storing[relation[2]]
        second = head_container.storing[relation[3]]
        # The heads are disjoint when one starts strictly after the other ends.
        disjoint = (second[1] > first[2]) or (first[1] > second[2])
        # Guard against arg1 and arg2 being the very same span.
        same_span = (first[1], first[2]) == (second[1], second[2])
        if disjoint and not same_span:
            kept[relation_id] = relation
    return kept
def filter_out_uni_mention_cluster(coref_clusters):
    """
    Keep only the coreference clusters containing at least two mentions.

    Clusters with a single mention carry no coreference signal; what remains
    are real clusters whose mentions all refer to the same real-life entity.
    """
    return {
        entity_id: mention_list
        for entity_id, mention_list in coref_clusters.items()
        if len(mention_list) > 1
    }
# Main conversion loop: one (.sgm, .apf.xml) pair per iteration. For each
# document we write a CoNLL-like file with EMD + relation columns, and a
# second file with EMD + a coreference column.
for sgm_file_name in glob.iglob(path_to_ace2005 + 'data/English/*/timex2norm/*.sgm', recursive=True):
    apf_xml_file_name = sgm_file_name.replace(".sgm", ".apf.xml")
    ##### Extraction ####
    mention_container, head_container, relations, coref_clusters = extract_mentions_and_relations_and_coref(apf_xml_file_name)
    doc = read_doc_and_replace(sgm_file_name)
    # Re-align character offsets after splitting tokens that glue a head to other text.
    doc, mention_container, head_container = regulate_offset_with_mentions(doc, mention_container, head_container)
    doc = doc.replace("\n", " ")
    #Filter the relations (overlapping / self-referential argument heads)
    relations = filter_overlapping_head_relations(relations, head_container)
    # Index the surviving relations by each argument's mention id.
    ids_arg1_of_relations, ids_arg2_of_relations = {}, {} #Dict[mention_id:str, List[relation_id:str]]
    for relation_id, relation in relations.items():
        if relation[2] not in ids_arg1_of_relations: ids_arg1_of_relations[relation[2]] = []
        ids_arg1_of_relations[relation[2]].append(relation_id)
        if relation[3] not in ids_arg2_of_relations: ids_arg2_of_relations[relation[3]] = []
        ids_arg2_of_relations[relation[3]].append(relation_id)
    #Filter out the uni-mention clusters
    coref_clusters = filter_out_uni_mention_cluster(coref_clusters)
    #Map mention_ids that are coreferent (i.e. are not part of a uni-mention cluster) to entity_ids
    #Also create a list of the mention_ids which are part of a real cluster.
    mention_id_2_entity_id: Dict[str, str] = {}
    coreferent_mention_ids = []
    for entity_id, mention_id_list in coref_clusters.items():
        for mention_id in mention_id_list:
            mention_id_2_entity_id[mention_id] = entity_id
            coreferent_mention_ids.append(mention_id)
    # Placeholder so the customized spaCy tokenizer does not split on "&".
    doc = doc.replace("&", "ZZZZZ")#quick fix on spacy tokenizer customed
    document = nlp(doc)
    #####################################################
    idx = 0  # token index within the sentence being built
    # EMD stuffs
    emd_label_stack: List[str] = []  # pending labels for a multi-token head
    emd_dump: List[str] = []         # output lines of the current sentence
    # Relation stuffs
    max_col: int = 0  # number of relation columns used so far in this sentence
    # Label stack for arg1s and arg2s.
    # Dict[mention_id:str, Dict[relation_id:str, List[str]]]
    arg1_labels: Dict[str, Dict[str, List[str]]] = {}
    arg2_labels: Dict[str, Dict[str, List[str]]] = {}
    # On which column to write the relations (one column per relation if there are multiple relations in one sentence).
    # It is also a way to track the current relations in the stack.
    # Values hold: "col" (column index), "done" (set of finished args),
    # "current_arg"/"current_id" (argument currently being unstacked).
    relation_id_2_column: Dict[str, dict] = {}
    relation_dump = []  # per-token list of relation labels, one per column
    head_starts = head_container.starts
    heads = head_container.storing
    mention_starts = mention_container.starts
    mentions = mention_container.storing
    #### Writing ####
    save_file_name_rel = saving_path + os.path.basename(sgm_file_name) + ".like_conll"
    save_file_name_coref = saving_path + os.path.basename(sgm_file_name) + ".coref"
    document_rel_dump = []
    document_coref_dump = []
    #First pass over the tokens: EMD labels + relation columns
    for token in document:
        # Collect labels for EMD and Relation since they are both on the head level
        if token.idx in head_starts:
            # NOTE(review): only the FIRST head starting at this character
            # offset is used; co-starting heads are ignored — confirm intended.
            mention_id = head_starts[token.idx][0]
            head = heads[mention_id]
            emd_label, token_stack = head[0], head[3]
            # Re-tokenize the head text so the label stack length matches it.
            token_stack = [token.text for token in nlp(token_stack)]
            # EMD stuffs#################
            # Fill the EMD label stacks: "(TYPE)" for one token, else
            # "(TYPE*", "*", ..., "*)" spanning the head.
            if len(token_stack) == 1:
                emd_label_stack = ["(%s)" % emd_label]
            else:
                emd_label_stack = ["(%s*" % emd_label] + (len(token_stack)-2)*["*"] + ["*)"]
            #############################
            #Relation Thing#######################################################
            #Fill the label stacks (arg1_labels and arg2_labels)
            if (mention_id in ids_arg1_of_relations) or (mention_id in ids_arg2_of_relations):
                # WARNING a mention can be both arg1 and arg2 in two different relations
                if (mention_id in ids_arg1_of_relations):
                    arg1_labels[mention_id] = {}
                    for relation_id in ids_arg1_of_relations[mention_id]:
                        _, relation_type, _, _ = relations[relation_id]
                        if len(token_stack) == 1:
                            arg1_label_stack = ["(ARG1_%s*)" % relation_type]
                        else:
                            arg1_label_stack = ["(ARG1_%s*" % relation_type] + (len(token_stack)-2)*["*"] + ["*)"]
                        arg1_labels[mention_id][relation_id] = arg1_label_stack
                        # Assign a column to relation_id if it does not have one yet.
                        if relation_id not in relation_id_2_column:
                            relation_id_2_column[relation_id] = {"col": max_col, "done": set()}
                            max_col += 1
                        relation_id_2_column[relation_id]["current_arg"] = "arg1"
                        relation_id_2_column[relation_id]["current_id"] = mention_id
                if (mention_id in ids_arg2_of_relations):
                    arg2_labels[mention_id] = {}
                    for relation_id in ids_arg2_of_relations[mention_id]:
                        _, relation_type, _, _ = relations[relation_id]
                        if len(token_stack) == 1:
                            arg2_label_stack = ["(ARG2_%s*)" % relation_type]
                        else:
                            arg2_label_stack = ["(ARG2_%s*" % relation_type] + (len(token_stack)-2)*["*"] + ["*)"]
                        arg2_labels[mention_id][relation_id] = arg2_label_stack
                        if relation_id not in relation_id_2_column:
                            relation_id_2_column[relation_id] = {"col": max_col, "done": set()}
                            max_col += 1
                        relation_id_2_column[relation_id]["current_arg"] = "arg2"
                        relation_id_2_column[relation_id]["current_id"] = mention_id
                #################################################################
        # EMD stuff #########################
        # Unstack EMD label for current token
        if emd_label_stack != []:
            labelled_token = token_stack[0]
            emd_label = emd_label_stack[0]
            token_stack = token_stack[1:]
            emd_label_stack = emd_label_stack[1:]
        else:
            labelled_token = token.text
            emd_label = "*"
        #####################################
        # Relation stuff ##############################
        # Unstack one relation label per active relation column.
        relation_labels = ["*"]*max_col
        for relation_id, rela_dump_info in relation_id_2_column.items():
            current_arg = rela_dump_info["current_arg"]
            current_id = rela_dump_info["current_id"]
            if current_arg is not None:
                col = rela_dump_info["col"]
                if current_arg == "arg1":
                    rel_label = arg1_labels[current_id][relation_id][0]
                    # NOTE(review): the pop below indexes with `mention_id`
                    # (last head seen) rather than `current_id`; they coincide
                    # unless a new head starts while a previous argument is
                    # still unstacking — confirm.
                    arg1_labels[mention_id][relation_id] = arg1_labels[mention_id][relation_id][1:]
                    if arg1_labels[mention_id][relation_id] == []:
                        rela_dump_info["current_arg"] = None
                        rela_dump_info["done"].add(current_arg)
                    relation_labels[col] = rel_label
                if current_arg == "arg2":
                    rel_label = arg2_labels[current_id][relation_id][0]
                    arg2_labels[mention_id][relation_id] = arg2_labels[mention_id][relation_id][1:]
                    if arg2_labels[mention_id][relation_id] == []:
                        rela_dump_info["current_arg"] = None
                        rela_dump_info["done"].add(current_arg)
                    relation_labels[col] = rel_label
        # Retire relations whose two arguments have both been fully written.
        keys_to_delete = []
        for relation_id, rela_dump_info in relation_id_2_column.items():
            if len(rela_dump_info["done"])==2: keys_to_delete.append(relation_id)
        for key in keys_to_delete: del relation_id_2_column[key]
        ################################################
        # BUG(review): str.replace returns a new string and the result is
        # discarded, so the "ZZZZZ" placeholder is never restored to "&" in
        # the output — should be `labelled_token = labelled_token.replace(...)`.
        labelled_token.replace("ZZZZZ", "&") #quick fix on spacy tokenizer customed
        if token.is_space: continue
        emd_dump.append("{}\t{}\t{}".format(idx, labelled_token, emd_label))
        relation_dump.append(relation_labels)
        idx += 1
        #Dump EMD and Relation
        # Sentence boundary: flush once no relation is still pending.
        if (token.text in [".", "?", "!"]) and (not relation_id_2_column):
            if max_col > 0:
                # Right-pad every token's relation labels to the final column count.
                padded_relation_dump = [a + ["*"] * (max_col - len(a)) for a in relation_dump]
                for i in range(len(padded_relation_dump)):
                    emd_dump[i] = emd_dump[i] + "\t" + "\t".join(padded_relation_dump[i])
            for i in range(len(emd_dump)):
                emd_dump[i] += "\t-"
            # NOTE(review): assumes the sentence is non-empty; an empty
            # emd_dump here would raise IndexError — confirm.
            emd_dump[-1] += '\n'
            document_rel_dump.extend(emd_dump)
            idx = 0
            emd_dump = []
            relation_dump, max_col = [], 0
    document_rel_dump.append("#end document")
    with open(save_file_name_rel, "w") as file:
        file.write("#begin document ({});\n".format(os.path.basename(sgm_file_name)))
        file.write("\n".join(document_rel_dump))
    #################################
    # Second pass over the same tokens: EMD labels + coreference column.
    idx = 0
    # EMD stuffs
    emd_label_stack: List[str] = []
    emd_dump: List[str] = []
    # Coref stuffs
    coref_label_stack: List[str] = []  # pending coref labels for upcoming tokens
    coref_dump: List[str] = []
    entity_id_2_idx: Dict[str, int] = {}  # entity id -> dense cluster index for output
    entity_idx_count = 0
    coref_dump = []
    for token in document:
        # Collect labels for EMD and Relation since they are both on the head level
        if token.idx in head_starts:
            mention_id = head_starts[token.idx][0]
            head = heads[mention_id]
            emd_label, token_stack = head[0], head[3]
            token_stack = [token.text for token in nlp(token_stack)]
            # EMD stuffs#################
            # Fill the EMD label stacks
            if len(token_stack) == 1:
                emd_label_stack = ["(%s)" % emd_label]
            else:
                emd_label_stack = ["(%s*" % emd_label] + (len(token_stack)-2)*["*"] + ["*)"]
            #############################
        # Collect coref label since it is on the mention level (and not the head level)
        if token.idx in mention_starts:
            mentions_starting_here = mention_starts[token.idx]
            for coref_mention_id in mentions_starting_here:
                #First check that this mention_id is really part of a non-uni-mention cluster
                if coref_mention_id not in coreferent_mention_ids:
                    continue
                start, end = mentions[coref_mention_id][1:3]
                coref_mention_text = doc[start:end]
                # Length of the mention in tokens (re-tokenized with spaCy).
                len_mention = len(nlp(coref_mention_text))
                entity_id = mention_id_2_entity_id[coref_mention_id]
                # Assign a new dense cluster index on first sight of the entity.
                if entity_id not in entity_id_2_idx:
                    entity_idx_count += 1
                    entity_id_2_idx[entity_id] = entity_idx_count
                entity_idx: int = entity_id_2_idx[entity_id]
                # "(N)" for a one-token mention, else "(N", "-", ..., "N)".
                if len_mention == 1:
                    coref_labels = ["(%d)" % entity_idx]
                else:
                    coref_labels = ["(%d" % entity_idx] + (len_mention-2)*["-"] + ["%d)" % entity_idx]
                #Fusion coref_labels with coref_label_stack: merge this
                #mention's labels into the pending stack slot by slot,
                #joining overlapping bracket labels with "|".
                if len(coref_label_stack) < len(coref_labels):
                    coref_label_stack = coref_label_stack + ["-"] * (len_mention - len(coref_label_stack))
                for i, current_coref_label in enumerate(coref_label_stack):
                    if current_coref_label == "-":
                        coref_label_stack[i] = coref_labels[i]
                    elif "(" in current_coref_label and ")" in current_coref_label:
                        if "(" in coref_labels[i]:
                            coref_label_stack[i] = coref_labels[i] + "|" + coref_label_stack[i]
                        elif ")" in coref_labels[i]:
                            coref_label_stack[i] = coref_label_stack[i] + "|" + coref_labels[i]
                        elif "-" == coref_labels[i]:
                            pass
                    elif "(" in current_coref_label:
                        if coref_labels[i] != "-":
                            coref_label_stack[i] = coref_label_stack[i] + "|" + coref_labels[i]
                    elif ")" in current_coref_label:
                        if coref_labels[i] != "-":
                            coref_label_stack[i] = coref_labels[i] + "|" + coref_label_stack[i]
                    # Only the first len_mention slots belong to this mention.
                    if i+1 >= len_mention:
                        break
        # EMD stuff #########################
        # Unstack EMD label for current token
        if emd_label_stack != []:
            labelled_token = token_stack[0]
            emd_label = emd_label_stack[0]
            token_stack = token_stack[1:]
            emd_label_stack = emd_label_stack[1:]
        else:
            labelled_token = token.text
            emd_label = "*"
        #####################################
        # Coref stuff ###############
        # unstack label for current token
        if coref_label_stack != []:
            coref_labels = coref_label_stack[0]
            coref_label_stack = coref_label_stack[1:]
        else:
            coref_labels = "-"
        ############################
        # BUG(review): result of str.replace discarded (same issue as in the
        # first pass) — the "ZZZZZ" placeholder stays in the output.
        labelled_token.replace("ZZZZZ", "&") #quick fix on spacy tokenizer customed
        if token.is_space: continue
        emd_dump.append("{}\t{}\t{}".format(idx, labelled_token, emd_label))
        coref_dump.append(coref_labels)
        idx += 1
        #Dump EMD and Coref
        # Sentence boundary: flush once no coref label is still pending.
        if (token.text in [".", "?", "!"]) and (coref_label_stack == []):
            for i in range(len(emd_dump)):
                emd_dump[i] = emd_dump[i] + "\t" + coref_dump[i]
            emd_dump[-1] += '\n'
            document_coref_dump.extend(emd_dump)
            idx = 0
            emd_dump = []
            coref_dump = []
    document_coref_dump.append("#end document")
    with open(save_file_name_coref, "w") as file:
        file.write("#begin document ({});\n".format(os.path.basename(sgm_file_name)))
        file.write("\n".join(document_coref_dump))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment