# coding: utf-8
import glob
import xml.etree.ElementTree as ET
import sys, os, re
import copy
import spacy
from typing import Dict, List, Tuple
nlp = spacy.load("en_core_web_sm")
path_to_ace2005 = sys.argv[1]
saving_path = sys.argv[2]
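
# Illustrative sketch (assumed, not copied from the corpus): the apf.xml files
# parsed below roughly look like the following. Entities carry typed mentions,
# each with an extent and a head whose charseq children hold inclusive
# START/END character offsets; relations point to entity_mention ids:
#
#   <entity TYPE="PER" ID="DOC-E1">
#     <entity_mention ID="DOC-E1-1">
#       <extent><charseq START="42" END="61">the former president</charseq></extent>
#       <head><charseq START="53" END="61">president</charseq></head>
#     </entity_mention>
#   </entity>
#   <relation TYPE="ORG-AFF" ID="DOC-R1">
#     <relation_mention ID="DOC-R1-1">
#       <relation_mention_argument REFID="DOC-E1-1" ROLE="Arg-1"/>
#       <relation_mention_argument REFID="DOC-E2-1" ROLE="Arg-2"/>
#     </relation_mention>
#   </relation>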

class MentionContainer:
    def __init__(self,
                 container_type,
                 storing,
                 starts,
                 ends,
                 duplicate_mapping=None):
        assert container_type in ["whole", "head"]
        self.container_type = container_type
        self.storing = storing
        self.starts = starts
        self.ends = ends
        # If we are filtering out the duplicates: mapping from the id of the
        # duplicated head we are trying to add to the original one
        self.duplicate_mapping = duplicate_mapping

class CoreferenceContainer:
    def __init__(self, coref_dic):
        self.coref_dic = coref_dic

def extract_relation(apf_root, head_container):
    heads = head_container.storing
    head_map = head_container.duplicate_mapping
    # Storing the relations
    relations: Dict[str, List[str]] = {}
    # Do not keep duplicated relations
    check_relations: List[List[str]] = []
    for relation in apf_root.iter('relation'):
        relation_type: str = relation.attrib["TYPE"]
        for mention in relation.iter('relation_mention'):
            # There can definitely be more than one relation_mention in a relation
            relation_id: str = mention.attrib["ID"]
            relation_list: List[str] = [relation_id, relation_type, "", ""]
            ignore: bool = False
            for arg in mention.iter('relation_mention_argument'):
                # There are only 2 arguments, with ROLE = Arg-1 and Arg-2
                arg_id: str = arg.attrib["REFID"]
                if arg.attrib["ROLE"] not in ("Arg-1", "Arg-2"):
                    continue
                # arg_id should be the original head id, not the duplicated one
                if arg_id in head_map:
                    arg_id = head_map[arg_id]
                relation_list[int(arg.attrib["ROLE"][-1]) + 1] = arg_id
                if arg_id not in heads:
                    ignore = True  # ignore duplicated entity
            if ignore:
                sys.stderr.write("ignored relation %s\n" % relation_id)
                continue
            if relation_list[1:] in check_relations:
                sys.stderr.write("duplicated relation %s\n" % relation_id)
                continue
            check_relations.append(relation_list[1:])
            relations[relation_id] = relation_list
    return relations
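
# Shape of one stored relation (hypothetical ids, matching the sketch above):
#   relations["DOC-R1-1"] == ["DOC-R1-1", "ORG-AFF", "DOC-E1-1", "DOC-E2-1"]
# i.e. [relation_id, relation_type, arg1 head mention id, arg2 head mention id].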

def extract_mentions_and_relations_and_coref(apf_xml_file_path: str):
    """
    Extract mentions (heads and whole mentions), relations between heads and
    coreference clusters.
    """
    apf_tree = ET.parse(apf_xml_file_path)
    apf_root = apf_tree.getroot()
    # Store the heads and the mentions
    heads: Dict[str, List] = {}
    mentions: Dict[str, List] = {}
    # Also store, for heads and mentions, the mapping from each duplicate ID to
    # the original one, to avoid duplicated heads or mentions with different IDs
    check_heads: Dict[Tuple[str, int, int, str], str] = {}
    check_mentions: Dict[Tuple[str, int, int, str], str] = {}
    # If duplicate heads: mapping from the duplicate id to the original head id
    head_map: Dict[str, str] = {}
    mention_map: Dict[str, str] = {}
    # Index of the starts and ends (character index) of mentions and heads
    mention_starts: Dict[int, List[str]] = {}
    mention_ends: Dict[int, List[str]] = {}
    head_starts: Dict[int, List[str]] = {}
    head_ends: Dict[int, List[str]] = {}
    # Storing the coreference clusters
    coref_clusters: Dict[str, List[str]] = {}
    for entity in apf_root.iter('entity'):
        entity_type: str = entity.attrib["TYPE"]  # Type of the entity
        entity_id: str = entity.attrib["ID"]  # Id of the entity (for coreference)
        for mention in entity.iter('entity_mention'):
            mention_id: str = mention.attrib["ID"]
            for child in mention:
                # An entity_mention always has 2 children: one head and one extent
                if child.tag == "head":
                    for charseq in child:  # always 1 child: charseq
                        head_start: int = int(charseq.attrib["START"])
                        head_end: int = int(charseq.attrib["END"]) + 1
                        head_text: str = re.sub(r"\n", r" ", charseq.text)
                if child.tag == "extent":
                    for charseq in child:  # always 1 child: charseq
                        mention_start: int = int(charseq.attrib["START"])
                        mention_end: int = int(charseq.attrib["END"]) + 1
                        mention_text: str = re.sub(r"\n", r" ", charseq.text)
            # Storing the mention
            double_mention: bool = False
            mention_tuple: Tuple[str, int, int, str] = (entity_type, mention_start, mention_end, mention_text)
            if mention_tuple in check_mentions:
                sys.stderr.write("duplicated mention entity %s\n" % mention_id)
                mention_map[mention_id] = check_mentions[mention_tuple]
                double_mention = True
            else:
                check_mentions[mention_tuple] = mention_id
                mentions[mention_id] = list(mention_tuple)
            if not double_mention:
                if mention_start not in mention_starts:
                    mention_starts[mention_start] = []
                mention_starts[mention_start].append(mention_id)
                if mention_end not in mention_ends:
                    mention_ends[mention_end] = []
                mention_ends[mention_end].append(mention_id)
                # Store the coreference link if not double_mention
                if entity_id not in coref_clusters:
                    coref_clusters[entity_id] = []
                coref_clusters[entity_id].append(mention_id)
            # Store the head
            double_head = False
            head_tuple: Tuple[str, int, int, str] = (entity_type, head_start, head_end, head_text)
            if (head_start, head_end, head_text) in check_heads:
                # In some cases the same (head_start, head_end, head_text) has two
                # different entity_types, so we check on (head_start, head_end, head_text)
                # rather than on the full head_tuple; this cannot happen when
                # considering the whole mention.
                sys.stderr.write("duplicated head entity %s\n" % mention_id)
                head_map[mention_id] = check_heads[(head_start, head_end, head_text)]
                double_head = True
            else:
                check_heads[(head_start, head_end, head_text)] = mention_id
                heads[mention_id] = list(head_tuple)
            if not double_head:
                if head_start not in head_starts:
                    head_starts[head_start] = []
                head_starts[head_start].append(mention_id)
                if head_end not in head_ends:
                    head_ends[head_end] = []
                head_ends[head_end].append(mention_id)
    # Everything related to mentions
    mention_container = MentionContainer("whole", mentions, mention_starts, mention_ends, mention_map)
    # Everything related to heads
    head_container = MentionContainer("head", heads, head_starts, head_ends, head_map)
    # Extract the relations
    relations = extract_relation(apf_root, head_container)
    return mention_container, head_container, relations, coref_clusters
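
# Shapes of the returned objects (hypothetical values; ends are exclusive):
#   heads["DOC-E1-1"]        == ["PER", 53, 62, "president"]
#   mentions["DOC-E1-1"]     == ["PER", 42, 62, "the former president"]
#   coref_clusters["DOC-E1"] == ["DOC-E1-1", "DOC-E1-2"]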

def lstrip_doc(doc: str,
               mention_container,
               head_container):
    """
    Left-strip the document and update the heads and mentions accordingly.
    """
    offset_lstrip: int = len(doc) - len(doc.lstrip())
    doc = doc.strip()
    # Whole mentions
    for _, mention in mention_container.storing.items():
        mention[1] -= offset_lstrip
        mention[2] -= offset_lstrip
    # Heads
    for _, head in head_container.storing.items():
        head[1] -= offset_lstrip
        head[2] -= offset_lstrip
    mention_starts_offset: Dict[int, List[str]] = {}
    mention_ends_offset: Dict[int, List[str]] = {}
    head_starts_offset: Dict[int, List[str]] = {}
    head_ends_offset: Dict[int, List[str]] = {}
    for key, value in mention_container.starts.items():
        mention_starts_offset[key - offset_lstrip] = value
    for key, value in mention_container.ends.items():
        mention_ends_offset[key - offset_lstrip] = value
    for key, value in head_container.starts.items():
        head_starts_offset[key - offset_lstrip] = value
    for key, value in head_container.ends.items():
        head_ends_offset[key - offset_lstrip] = value
    # Update the containers
    mention_container.starts, mention_container.ends = mention_starts_offset, mention_ends_offset
    head_container.starts, head_container.ends = head_starts_offset, head_ends_offset
    return doc, mention_container, head_container

def offset_mentions(offset_to_apply, flag, mentions, init_mentions):
    for key, mention in init_mentions.items():
        if mention[1] >= flag:
            mentions[key][1] += offset_to_apply
        if mention[2] >= flag:
            mentions[key][2] += offset_to_apply

def regulate_offset_with_mentions(doc: str,
                                  mention_container,
                                  head_container):
    """
    Regulate offsets with mentions. In particular, if a head is part of a
    bigger token, split it into two tokens so that we can tag only the head
    (and not the other part).
    """
    offset: int = 0
    size: int = len(doc)
    current: int = 0
    regions: List[str] = []
    mention_starts_regu: Dict[int, List[str]] = {}
    mention_ends_regu: Dict[int, List[str]] = {}
    head_starts_regu: Dict[int, List[str]] = {}
    head_ends_regu: Dict[int, List[str]] = {}
    heads = head_container.storing
    mentions = mention_container.storing
    init_mentions = copy.deepcopy(mentions)
    mention_map = mention_container.duplicate_mapping
    for i in range(size):
        if i in head_container.starts or i in head_container.ends:
            inc = 0
            if (doc[i-1] != " " and doc[i-1] != "\n") and (doc[i] != " " and doc[i] != "\n"):
                regions.append(doc[current:i])
                inc = 1
                current = i
            if i in head_container.starts:
                for mention_id in head_container.starts[i]:
                    heads[mention_id][1] += offset + inc
                    if heads[mention_id][1] not in head_starts_regu:
                        head_starts_regu[heads[mention_id][1]] = []
                    head_starts_regu[heads[mention_id][1]].append(mention_id)
            if i in head_container.ends:
                for mention_id in head_container.ends[i]:
                    heads[mention_id][2] += offset
                    if heads[mention_id][2] not in head_ends_regu:
                        head_ends_regu[heads[mention_id][2]] = []
                    head_ends_regu[heads[mention_id][2]].append(mention_id)
            # Offset the mentions
            if inc > 0:
                offset_mentions(1, i, mentions, init_mentions)
            offset += inc
    regions.append(doc[current:])
    doc = " ".join(regions)
    # Update the containers
    head_container.starts, head_container.ends = head_starts_regu, head_ends_regu
    for mention_id, mention in mentions.items():
        start, end = mention[1:3]
        if start not in mention_starts_regu:
            mention_starts_regu[start] = []
        mention_starts_regu[start].append(mention_id)
        if end not in mention_ends_regu:
            mention_ends_regu[end] = []
        mention_ends_regu[end].append(mention_id)
    mention_container.starts, mention_container.ends = mention_starts_regu, mention_ends_regu
    doc, mention_container, head_container = lstrip_doc(doc, mention_container, head_container)
    return doc, mention_container, head_container
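
# Worked example (hypothetical document): with doc == "the Iraqi-controlled zone"
# and an annotated head "Iraqi" whose end falls inside the token
# "Iraqi-controlled", a space is inserted at the head boundary, yielding
# "the Iraqi -controlled zone". The tokenizer then emits "Iraqi" as its own
# token, so only the head gets tagged, and all stored character offsets are
# shifted to stay aligned with the edited text.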

def read_doc_and_replace(sgm_file_name):
    """
    Read the doc and perform basic substitutions.
    """
    doc = open(sgm_file_name).read()
    doc = re.sub(r"<[^>]+>", "", doc)
    doc = re.sub(r"(\S+)\n(\S[^:])", r"\1 \2", doc)
    return doc
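
# Illustrative effect of the two substitutions (hypothetical input):
#   "<TEXT>\nBaghdad fell\nquickly.</TEXT>"  ->  "\nBaghdad fell quickly."
# The first sub drops the SGML tags; the second re-joins hard-wrapped lines,
# except when the next line opens like a speaker tag such as "Q:".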

def filter_overlapping_head_relations(relations: Dict[str, List[str]],
                                      head_container):
    """
    Filter out the overlapping heads in relations, and more specifically the
    heads that are in relation with themselves. There are only a few such
    cases (<10) to filter out.
    """
    filtered_relations: Dict[str, List[str]] = {}
    for relation_id, relation in relations.items():
        arg1 = head_container.storing[relation[2]]
        arg2 = head_container.storing[relation[3]]
        # End of arg1 does not overlap with start of arg2
        verslavant = arg2[1] > arg1[2]
        # End of arg2 does not overlap with start of arg1
        verslarriere = arg1[1] > arg2[2]
        # Check we don't have arg1 == arg2
        exact_match = (arg1[1], arg1[2]) == (arg2[1], arg2[2])
        # We want relations between non-overlapping, non-identical heads
        if (verslavant or verslarriere) and not exact_match:
            filtered_relations[relation_id] = relation
    return filtered_relations

def filter_out_uni_mention_cluster(coref_clusters):
    """
    Filter out the clusters that contain only a single mention of the entity.
    The remaining clusters are thus real clusters, with at least two mentions
    referring to the same real-life entity.
    """
    filtered_coref_clusters = {}
    for entity_id, mention_list in coref_clusters.items():
        if len(mention_list) > 1:
            filtered_coref_clusters[entity_id] = mention_list
    return filtered_coref_clusters

for sgm_file_name in glob.iglob(path_to_ace2005 + 'data/English/*/timex2norm/*.sgm', recursive=True):
    apf_xml_file_name = sgm_file_name.replace(".sgm", ".apf.xml")
    #### Extraction ####
    mention_container, head_container, relations, coref_clusters = extract_mentions_and_relations_and_coref(apf_xml_file_name)
    doc = read_doc_and_replace(sgm_file_name)
    doc, mention_container, head_container = regulate_offset_with_mentions(doc, mention_container, head_container)
    doc = doc.replace("\n", " ")
    # Filter the relations
    relations = filter_overlapping_head_relations(relations, head_container)
    ids_arg1_of_relations, ids_arg2_of_relations = {}, {}  # Dict[mention_id:str, List[relation_id:str]]
    for relation_id, relation in relations.items():
        if relation[2] not in ids_arg1_of_relations:
            ids_arg1_of_relations[relation[2]] = []
        ids_arg1_of_relations[relation[2]].append(relation_id)
        if relation[3] not in ids_arg2_of_relations:
            ids_arg2_of_relations[relation[3]] = []
        ids_arg2_of_relations[relation[3]].append(relation_id)
    # Filter out the uni-mention clusters
    coref_clusters = filter_out_uni_mention_cluster(coref_clusters)
    # Map mention_ids that are coreferent (i.e. not part of a uni-mention cluster) to entity_ids.
    # Also build the list of mention_ids that are part of a real cluster.
    mention_id_2_entity_id: Dict[str, str] = {}
    coreferent_mention_ids = []
    for entity_id, mention_id_list in coref_clusters.items():
        for mention_id in mention_id_list:
            mention_id_2_entity_id[mention_id] = entity_id
            coreferent_mention_ids.append(mention_id)
    doc = doc.replace("&amp;", "ZZZZZ")  # quick fix for the customized spaCy tokenizer (undone below)
    document = nlp(doc)
    #####################################################
    idx = 0
    # EMD stuff
    emd_label_stack: List[str] = []
    emd_dump: List[str] = []
    # Relation stuff
    max_col: int = 0
    # Label stacks for arg1s and arg2s.
    # Dict[mention_id:str, Dict[relation_id:str, List[str]]]
    arg1_labels: Dict[str, Dict[str, List[str]]] = {}
    arg2_labels: Dict[str, Dict[str, List[str]]] = {}
    # On which column to write each relation (one column per relation if there
    # are multiple relations in one sentence). It is also a way to track the
    # current relations in the stack.
    # Dict[relation_id:str, {"col": int, "done": set, "current_arg": str, "current_id": str}]
    relation_id_2_column: Dict[str, Dict] = {}
    relation_dump = []
    head_starts = head_container.starts
    heads = head_container.storing
    mention_starts = mention_container.starts
    mentions = mention_container.storing
    #### Writing ####
    save_file_name_rel = saving_path + os.path.basename(sgm_file_name) + ".like_conll"
    save_file_name_coref = saving_path + os.path.basename(sgm_file_name) + ".coref"
    document_rel_dump = []
    document_coref_dump = []
    # EMD and relations
    for token in document:
        # Collect labels for EMD and relations since they are both on the head level
        if token.idx in head_starts:
            mention_id = head_starts[token.idx][0]
            head = heads[mention_id]
            emd_label, token_stack = head[0], head[3]
            token_stack = [tok.text for tok in nlp(token_stack)]
            # EMD: fill the EMD label stack
            if len(token_stack) == 1:
                emd_label_stack = ["(%s)" % emd_label]
            else:
                emd_label_stack = ["(%s*" % emd_label] + (len(token_stack) - 2) * ["*"] + ["*)"]
            # Relations: fill the label stacks (arg1_labels and arg2_labels)
            if (mention_id in ids_arg1_of_relations) or (mention_id in ids_arg2_of_relations):
                # WARNING: a mention can be both arg1 and arg2 in two different relations
                if mention_id in ids_arg1_of_relations:
                    arg1_labels[mention_id] = {}
                    for relation_id in ids_arg1_of_relations[mention_id]:
                        _, relation_type, _, _ = relations[relation_id]
                        if len(token_stack) == 1:
                            arg1_label_stack = ["(ARG1_%s*)" % relation_type]
                        else:
                            arg1_label_stack = ["(ARG1_%s*" % relation_type] + (len(token_stack) - 2) * ["*"] + ["*)"]
                        arg1_labels[mention_id][relation_id] = arg1_label_stack
                        # Assign a column to relation_id if it does not have one yet
                        if relation_id not in relation_id_2_column:
                            relation_id_2_column[relation_id] = {"col": max_col, "done": set()}
                            max_col += 1
                        relation_id_2_column[relation_id]["current_arg"] = "arg1"
                        relation_id_2_column[relation_id]["current_id"] = mention_id
                if mention_id in ids_arg2_of_relations:
                    arg2_labels[mention_id] = {}
                    for relation_id in ids_arg2_of_relations[mention_id]:
                        _, relation_type, _, _ = relations[relation_id]
                        if len(token_stack) == 1:
                            arg2_label_stack = ["(ARG2_%s*)" % relation_type]
                        else:
                            arg2_label_stack = ["(ARG2_%s*" % relation_type] + (len(token_stack) - 2) * ["*"] + ["*)"]
                        arg2_labels[mention_id][relation_id] = arg2_label_stack
                        if relation_id not in relation_id_2_column:
                            relation_id_2_column[relation_id] = {"col": max_col, "done": set()}
                            max_col += 1
                        relation_id_2_column[relation_id]["current_arg"] = "arg2"
                        relation_id_2_column[relation_id]["current_id"] = mention_id
        # EMD: unstack the label for the current token
        if emd_label_stack != []:
            labelled_token = token_stack[0]
            emd_label = emd_label_stack[0]
            token_stack = token_stack[1:]
            emd_label_stack = emd_label_stack[1:]
        else:
            labelled_token = token.text
            emd_label = "*"
        # Relations: unstack the relation labels
        relation_labels = ["*"] * max_col
        for relation_id, rela_dump_info in relation_id_2_column.items():
            current_arg = rela_dump_info["current_arg"]
            current_id = rela_dump_info["current_id"]
            if current_arg is not None:
                col = rela_dump_info["col"]
                if current_arg == "arg1":
                    rel_label = arg1_labels[current_id][relation_id][0]
                    arg1_labels[current_id][relation_id] = arg1_labels[current_id][relation_id][1:]
                    if arg1_labels[current_id][relation_id] == []:
                        rela_dump_info["current_arg"] = None
                        rela_dump_info["done"].add(current_arg)
                    relation_labels[col] = rel_label
                if current_arg == "arg2":
                    rel_label = arg2_labels[current_id][relation_id][0]
                    arg2_labels[current_id][relation_id] = arg2_labels[current_id][relation_id][1:]
                    if arg2_labels[current_id][relation_id] == []:
                        rela_dump_info["current_arg"] = None
                        rela_dump_info["done"].add(current_arg)
                    relation_labels[col] = rel_label
        keys_to_delete = []
        for relation_id, rela_dump_info in relation_id_2_column.items():
            if len(rela_dump_info["done"]) == 2:
                keys_to_delete.append(relation_id)
        for key in keys_to_delete:
            del relation_id_2_column[key]
        labelled_token = labelled_token.replace("ZZZZZ", "&")  # undo the tokenizer quick fix
        if token.is_space:
            continue
        emd_dump.append("{}\t{}\t{}".format(idx, labelled_token, emd_label))
        relation_dump.append(relation_labels)
        idx += 1
        # Dump EMD and relation columns at the end of each sentence
        if (token.text in [".", "?", "!"]) and (not relation_id_2_column):
            if max_col > 0:
                padded_relation_dump = [a + ["*"] * (max_col - len(a)) for a in relation_dump]
                for i in range(len(padded_relation_dump)):
                    emd_dump[i] = emd_dump[i] + "\t" + "\t".join(padded_relation_dump[i])
            for i in range(len(emd_dump)):
                emd_dump[i] += "\t-"
            emd_dump[-1] += '\n'
            document_rel_dump.extend(emd_dump)
            idx = 0
            emd_dump = []
            relation_dump, max_col = [], 0
    document_rel_dump.append("#end document")
    with open(save_file_name_rel, "w") as file:
        file.write("#begin document ({});\n".format(os.path.basename(sgm_file_name)))
        file.write("\n".join(document_rel_dump))
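
    # Layout of one .like_conll line (illustrative, tab-separated): token index,
    # token, EMD label, one column per relation in the sentence, and a final "-":
    #   "3\tpresident\t(PER)\t(ARG1_ORG-AFF*)\t-"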
    #################################
    idx = 0
    # EMD stuff
    emd_label_stack: List[str] = []
    emd_dump: List[str] = []
    # Coref stuff
    coref_label_stack: List[str] = []
    coref_dump: List[str] = []
    entity_id_2_idx: Dict[str, int] = {}
    entity_idx_count = 0
    for token in document:
        # Collect the EMD label since it is on the head level
        if token.idx in head_starts:
            mention_id = head_starts[token.idx][0]
            head = heads[mention_id]
            emd_label, token_stack = head[0], head[3]
            token_stack = [tok.text for tok in nlp(token_stack)]
            # EMD: fill the EMD label stack
            if len(token_stack) == 1:
                emd_label_stack = ["(%s)" % emd_label]
            else:
                emd_label_stack = ["(%s*" % emd_label] + (len(token_stack) - 2) * ["*"] + ["*)"]
        # Collect the coref labels since they are on the mention level (not the head level)
        if token.idx in mention_starts:
            mentions_starting_here = mention_starts[token.idx]
            for coref_mention_id in mentions_starting_here:
                # First check that this mention_id is really part of a non-uni-mention cluster
                if coref_mention_id not in coreferent_mention_ids:
                    continue
                start, end = mentions[coref_mention_id][1:3]
                coref_mention_text = doc[start:end]
                len_mention = len(nlp(coref_mention_text))
                entity_id = mention_id_2_entity_id[coref_mention_id]
                if entity_id not in entity_id_2_idx:
                    entity_idx_count += 1
                    entity_id_2_idx[entity_id] = entity_idx_count
                entity_idx: int = entity_id_2_idx[entity_id]
                if len_mention == 1:
                    coref_labels = ["(%d)" % entity_idx]
                else:
                    coref_labels = ["(%d" % entity_idx] + (len_mention - 2) * ["-"] + ["%d)" % entity_idx]
                # Merge coref_labels into coref_label_stack
                if len(coref_label_stack) < len(coref_labels):
                    coref_label_stack = coref_label_stack + ["-"] * (len_mention - len(coref_label_stack))
                for i, current_coref_label in enumerate(coref_label_stack):
                    if current_coref_label == "-":
                        coref_label_stack[i] = coref_labels[i]
                    elif "(" in current_coref_label and ")" in current_coref_label:
                        if "(" in coref_labels[i]:
                            coref_label_stack[i] = coref_labels[i] + "|" + coref_label_stack[i]
                        elif ")" in coref_labels[i]:
                            coref_label_stack[i] = coref_label_stack[i] + "|" + coref_labels[i]
                        elif "-" == coref_labels[i]:
                            pass
                    elif "(" in current_coref_label:
                        if coref_labels[i] != "-":
                            coref_label_stack[i] = coref_label_stack[i] + "|" + coref_labels[i]
                    elif ")" in current_coref_label:
                        if coref_labels[i] != "-":
                            coref_label_stack[i] = coref_labels[i] + "|" + coref_label_stack[i]
                    if i + 1 >= len_mention:
                        break
        # EMD: unstack the label for the current token
        if emd_label_stack != []:
            labelled_token = token_stack[0]
            emd_label = emd_label_stack[0]
            token_stack = token_stack[1:]
            emd_label_stack = emd_label_stack[1:]
        else:
            labelled_token = token.text
            emd_label = "*"
        # Coref: unstack the label for the current token
        if coref_label_stack != []:
            coref_labels = coref_label_stack[0]
            coref_label_stack = coref_label_stack[1:]
        else:
            coref_labels = "-"
        labelled_token = labelled_token.replace("ZZZZZ", "&")  # undo the tokenizer quick fix
        if token.is_space:
            continue
        emd_dump.append("{}\t{}\t{}".format(idx, labelled_token, emd_label))
        coref_dump.append(coref_labels)
        idx += 1
        # Dump EMD and coref columns at the end of each sentence
        if (token.text in [".", "?", "!"]) and (coref_label_stack == []):
            for i in range(len(emd_dump)):
                emd_dump[i] = emd_dump[i] + "\t" + coref_dump[i]
            emd_dump[-1] += '\n'
            document_coref_dump.extend(emd_dump)
            idx = 0
            emd_dump = []
            coref_dump = []
    document_coref_dump.append("#end document")
    with open(save_file_name_coref, "w") as file:
        file.write("#begin document ({});\n".format(os.path.basename(sgm_file_name)))
        file.write("\n".join(document_coref_dump))
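
# Assumed invocation (script name and paths are placeholders):
#   python ace2005_to_conll.py /path/to/ace_2005_td_v7/ /path/to/output_dir/
# The first argument must contain data/English/*/timex2norm/*.sgm files (with
# their .apf.xml counterparts); the second should end with a trailing slash
# since it is concatenated directly with the file name. One .like_conll and
# one .coref file is written per source document.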