Last active
April 21, 2023 16:36
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''Convert a CONLLU file with coreference data to a JSONL file with clusters.

Usage:
    python parse_dacoref.py <input_file>

Author:
    Dan Saattrup Nielsen (dan.nielsen@alexandra.dk)
'''
import conllu | |
import click | |
from pathlib import Path | |
from collections import defaultdict | |
import pandas as pd | |
def _extract_clusters(sent) -> list[tuple[int, ...]]:
    """Extract the coreference clusters from a single CONLLU sentence.

    Each token's ``coref_rel`` value uses bracket notation: "(id" opens a
    mention of cluster ``id`` at this token, "id)" closes one, "(id)" is a
    single-token mention, and "-" marks a token outside any mention.
    Several relations on one token are separated by "|".

    Args:
        sent:
            An iterable of CONLLU tokens, each a mapping with at least an
            ``id`` (1-based token index) and a ``coref_rel`` string.

    Returns:
        One tuple of 0-based token indices per cluster containing more
        than one token; singleton clusters are dropped.
    """
    # Cluster IDs are the strings between the brackets, so the keys here
    # are strings, not ints.
    cluster_dict: dict[str, list[int]] = defaultdict(list)
    # IDs of mentions that have been opened but not yet closed
    active_clusters: list[str] = []

    for token in sent:
        token_idx = token['id'] - 1
        closing: list[str] = []
        for part in token['coref_rel'].split('|'):
            if part.startswith('(') and part.endswith(')'):
                # A mention that both begins and ends at this token
                cluster_dict[part.strip('()')].append(token_idx)
            elif part.startswith('('):
                # A mention begins here; it stays open until its "id)" part
                active_clusters.append(part.strip('('))
            elif part.endswith(')'):
                # A mention ends here; defer closing it until after this
                # token has been added to all open mentions below
                closing.append(part.strip(')'))

        # The token belongs to *every* mention open at this position,
        # including mentions that begin or end at this very token — a plain
        # elif-chain would skip the other open mentions for such tokens.
        for cluster_id in active_clusters:
            cluster_dict[cluster_id].append(token_idx)
        for cluster_id in closing:
            active_clusters.remove(cluster_id)

    # Keep only clusters that actually corefer (more than one token)
    return [
        tuple(cluster)
        for cluster in cluster_dict.values()
        if len(cluster) > 1
    ]


@click.command()
@click.argument('input_file', type=click.Path(exists=True))
def main(input_file: Path | str) -> None:
    """Convert a coreference-annotated CONLLU file to a JSONL file.

    The output is written next to the input file with the same stem and a
    ``.jsonl`` suffix. Each output line holds one sentence with its ID,
    document ID, raw text, tokens, and coreference clusters (tuples of
    0-based token indices).

    Args:
        input_file:
            Path to the CONLLU input file. Must exist.
    """
    # Names of the CONLLU columns in the input file: the ten standard
    # CONLL-U fields followed by four corpus-specific extras.
    fields = [
        "id",
        "form",
        "lemma",
        "upos",
        "xpos",
        "feats",
        "head",
        "deprel",
        "deps",
        "misc",
        "coref_id",
        "coref_rel",
        "doc_id",
        "qid",
    ]

    # Open the input file and parse it into a list of sentences
    with Path(input_file).open() as f:
        parsed_sents = conllu.parse(f.read(), fields=fields)

    # Build one record per sentence
    records = []
    for sent in parsed_sents:
        records.append(
            dict(
                sent_id=sent.metadata['sent_id'],
                # doc_id is stored per-token; every token in a sentence
                # shares it, so the first token's value suffices
                doc_id=sent[0]['doc_id'],
                text=sent.metadata['text'],
                tokens=[token['form'] for token in sent],
                clusters=_extract_clusters(sent),
            )
        )

    # Store the records as a JSONL file next to the input file
    df = pd.DataFrame.from_records(records)
    output_file = Path(input_file).with_suffix('.jsonl')
    df.to_json(output_file, orient='records', lines=True)
# Script entry point: invoke the click command when executed directly
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment