Created
August 23, 2021 15:15
-
-
Save dorajam/15f2988681437cd9bc905be468f82200 to your computer and use it in GitHub Desktop.
COGS reverse parser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
import re | |
from collections import Counter | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torchtext | |
from torch.nn.utils.rnn import pad_sequence as pad | |
class Preprocessor():
    """Convert graph annotations to token tensors and back to logical forms.

    Two vocabularies (node labels and edge labels) are loaded from
    ``data_dir``.  Both must map ``'null'`` to 0 and ``'<pad>'`` to 1; this is
    checked at construction time.

    Node positions are addressed by string keys of the form ``'x_<idx>'``,
    where ``<idx>`` is the position of the token in the input sentence.
    """

    def __init__(self, data_dir='cogs/graph/', graph_layers=1):
        """Load the vocabularies and verify the special-token indices.

        Args:
            data_dir: directory that holds ``nodes.txt`` and ``edges.txt``
                (and the parsed graph files read by
                ``generate_target_sequences``).
            graph_layers: stored on the instance; not used by any method of
                this class.
        """
        self.data_dir = data_dir
        self.node_vocab, self.edge_vocab = self.tokenize()
        self.graph_layers = graph_layers
        for vocab in ('nodes', 'edges'):
            assert self.word_to_token('<pad>', vocab) == 1, '<pad> token needs to be mapped to 1'
            assert self.word_to_token('null', vocab) == 0, 'Null token needs to be mapped to 0'

    def _read_vocab(self, file_name):
        """Build a vocabulary from the labels listed one per line in ``file_name``."""
        with open(os.path.join(self.data_dir, file_name), 'r') as f:
            labels = [line.rstrip() for line in f]
        # vocab will use null token:0, <pad>_token:1
        # NOTE(review): this constructor signature (Counter + specials) looks
        # like the legacy torchtext Vocab API -- confirm the pinned version.
        return torchtext.vocab.Vocab(Counter(labels), specials=['null', '<pad>'])

    def tokenize(self):
        """Return ``(node_vocab, edge_vocab)`` built from ``nodes.txt`` / ``edges.txt``."""
        node_vocab = self._read_vocab('nodes.txt')
        edge_vocab = self._read_vocab('edges.txt')
        return node_vocab, edge_vocab

    def word_to_token(self, word, vocab='nodes'):
        """Map a label to its index; any ``vocab`` other than ``'nodes'`` uses the edge vocab."""
        if vocab == 'nodes':
            return self.node_vocab.stoi[word]
        return self.edge_vocab.stoi[word]

    def token_to_word(self, token, vocab='nodes'):
        """Map an index to its label; any ``vocab`` other than ``'nodes'`` uses the edge vocab."""
        if vocab == 'nodes':
            return self.node_vocab.itos[token]
        return self.edge_vocab.itos[token]

    def nodes_to_node_seq(self, nodes, in_seq):
        """Turn a nodes dict into one node tag per input token.

        Args:
            nodes: dict such as ``{'x_1': node1, 'x_2': node2}``; the number
                after ``'x_'`` is the token position.
            in_seq: the input sentence, either a list of tokens or a single
                string whose tokens are separated by one space.

        Returns:
            1-D ``torch.long`` tensor of length ``len(in_seq)``; positions
            without a node hold the ``'null'`` token (for the example above:
            ``[null, node1, node2]``).
        """
        # Explicit type check: the previous assert-based check silently
        # stopped splitting strings when Python ran with -O.
        if isinstance(in_seq, str):
            in_seq = in_seq.split(' ')
        node_sequence = torch.full(
            (len(in_seq),), self.word_to_token('null'), dtype=torch.long)
        for node, node_label in nodes.items():
            node_sequence[int(node[2:])] = self.word_to_token(node_label)
        return node_sequence

    def edges_to_edge_seq(self, edges, edge_labels, in_seq):
        """Turn an edge list into an ``N x N`` matrix of edge-label tokens.

        Args:
            edges: list of ``[source, target]`` position keys, e.g.
                ``[['x_1', 'x_0'], ['x_2', 'x_1']]``.
            edge_labels: one label per edge, e.g. ``['theme', 'agent']``.
            in_seq: the input sentence, either a list of tokens or a single
                string whose tokens are separated by one space.

        Returns:
            ``torch.long`` tensor of shape ``(N, N)`` with ``N = len(in_seq)``.
            Entry ``[i, j]`` holds the token of the label of edge
            ``x_i -> x_j``; every other entry holds the edge ``'null'`` token.
        """
        if isinstance(in_seq, str):
            in_seq = in_seq.split(' ')
        N = len(in_seq)
        # default edge type: null (no edge)
        final_edges = torch.full(
            (N, N), self.word_to_token('null', vocab='edges'), dtype=torch.long)
        for edge, label in zip(edges, edge_labels):
            e1, e2 = [int(e[2:]) for e in edge]
            final_edges[e1, e2] = self.word_to_token(label, vocab='edges')
        return final_edges

    def pad_sequences(self, node_targets, edge_targets, device='cuda'):
        """Pad a batch of node sequences and edge matrices to a common size.

        Args:
            node_targets: list of 1-D node-token tensors.
            edge_targets: list of 2-D edge-token tensors.
            device: device both results are moved to.

        Returns:
            ``(padded_nodes, padded_edges)``: the node batch padded to the
            longest sequence, and the edge matrices each embedded in the
            top-left corner of a ``max_len x max_len`` matrix and stacked.
            Padding uses the ``'<pad>'`` token of the respective vocabulary.
        """
        padded_nodes = pad(node_targets, batch_first=True,
                           padding_value=self.word_to_token('<pad>', vocab='nodes'))
        max_len = padded_nodes.shape[-1]
        edge_pad = self.word_to_token('<pad>', vocab='edges')
        res = []
        for e in edge_targets:
            tmp = torch.full((max_len, max_len), edge_pad, dtype=torch.long)
            tmp[:e.shape[0], :e.shape[1]] = e
            res.append(tmp)
        padded_edges = torch.stack(res)
        return padded_nodes.to(device), padded_edges.to(device)

    def generate_target_sequences(self, data='train', file_name='_parsed.json'):
        """Read a dataset split in graph format and tokenise its nodes and edges.

        The graph file ``<data_dir>/<data><file_name>`` is left-joined onto
        ``./cogs/lambda/<data>.tsv`` using the input sentence as key (the
        graph file's ``original_inp`` field).

        Returns:
            ``(node_seqs, edge_seqs)``: two lists with one ``torch.long``
            tensor per joined row, as produced by ``nodes_to_node_seq`` and
            ``edges_to_edge_seq``.
        """
        file_path = os.path.join(self.data_dir, data + file_name)
        with open(file_path) as f:
            graph_df = pd.DataFrame(json.load(f))
        graph_df['inp'] = graph_df['original_inp']

        with open(f'./cogs/lambda/{data}.tsv') as f:
            df = pd.read_csv(f, sep='\t', header=None)
        df.columns = ['inp', 'out', 'type']
        df = df.reset_index()
        assert len(df) == len(graph_df), (
            f'Graph output should exist for each input, but original is of len '
            f'{len(df)}, while graph is of len {len(graph_df)}.'
        )

        joined_df = pd.merge(df, graph_df, on='inp', how='left')
        # A check that no 'primitive' rows remain used to live here but was
        # disabled in the original:
        # assert joined_df[
        #     joined_df.type == 'primitive'].size == 0, 'Primitives are not filtered. You need to do this first.'

        node_seqs = [
            self.nodes_to_node_seq(nodes, inp)
            for nodes, inp in zip(joined_df['nodes'], joined_df['inp'])
        ]
        edge_seqs = [
            self.edges_to_edge_seq(edges, labels, inp)
            for edges, labels, inp in zip(
                joined_df['edges'], joined_df['edge_labels'], joined_df['inp'])
        ]
        return node_seqs, edge_seqs

    def split_and_sort_clauses(self, lambd):
        """Split a logical form on ``' AND '`` / ``' ; '`` and return the clauses sorted."""
        return sorted(re.split(' AND | ; ', lambd))

    def scores_to_graph(self, node_preds, edge_preds, node_targets, edge_targets):
        """Decode batched token predictions into nodes, edges and edge labels.

        All three results include null nodes / null edges but exclude every
        position where either the prediction or the target is ``'<pad>'``.

        Args:
            node_preds: predicted node tokens, batch first (BS, seq_len).
            edge_preds: predicted edge tokens (BS, seq_len, seq_len).
            node_targets: target node tokens, same shape as ``node_preds``.
            edge_targets: target edge tokens, indexed like ``edge_preds``.

        Returns:
            ``(node_sets, edge_sets, edge_label_sets)``, each a list with one
            entry per batch element: a ``{'x_<idx>': label}`` dict, a list of
            ``['x_<i>', 'x_<j>']`` pairs, and the list of matching labels.
        """
        assert node_preds.shape[0] == edge_preds.shape[0], \
            f'Expected batch first, but got {node_preds.shape[0]}, {edge_preds.shape[0]}'
        assert node_preds.shape == node_targets.shape

        node_pad = self.word_to_token('<pad>', vocab='nodes')
        node_sets = []
        for sequence, target_seq in zip(node_preds, node_targets):  # BS, seq_len
            nodes = {}
            # One .tolist() per sequence yields plain ints for the vocab lookup.
            for idx, (token, target_token) in enumerate(
                    zip(sequence.tolist(), target_seq.tolist())):  # seq_len
                node = self.token_to_word(token, vocab='nodes')
                # if the target is a pad token, or if the prediction is pad, ignore it
                if node == '<pad>' or target_token == node_pad:
                    continue
                nodes[f'x_{idx}'] = node
            node_sets.append(nodes)

        edge_sets = []
        edge_label_sets = []
        for sequence, target_seq in zip(edge_preds, edge_targets):  # BS, seq_len, seq_len
            edges = []
            edge_labels = []
            size = sequence.shape[0]
            predicted = sequence.tolist()
            expected = target_seq.tolist()
            for token_i in range(size):
                for token_j in range(size):
                    edge_label = self.token_to_word(predicted[token_i][token_j], vocab='edges')
                    true_label = self.token_to_word(expected[token_i][token_j], vocab='edges')
                    if edge_label == '<pad>' or true_label == '<pad>':
                        continue
                    edges.append([f'x_{token_i}', f'x_{token_j}'])
                    edge_labels.append(edge_label)
            edge_sets.append(edges)
            edge_label_sets.append(edge_labels)
        return node_sets, edge_sets, edge_label_sets

    def graph_to_lambda(self, node_sets, edge_sets, edge_labels):
        """Serialise a batch of graphs; returns one string per graph.

        Null nodes and null edges are dropped first.  A graph that has no
        non-null edge, or that cannot be serialised, yields ``'NA'``.
        """
        lambdas = []
        for nodes, edges, labels in zip(node_sets, edge_sets, edge_labels):
            # removes null nodes
            nodes = {k: v for k, v in nodes.items() if v != 'null'}
            # removes null edges
            kept = [(edge, label) for edge, label in zip(edges, labels) if label != 'null']
            if not kept:
                lambdas.append('NA')
                continue
            kept_edges, kept_labels = zip(*kept)
            try:
                lambdas.append(self._graph_to_lambda(nodes, kept_edges, kept_labels))
            except Exception:
                # Best effort: a malformed predicted graph must not abort the batch.
                lambdas.append('NA')
        return lambdas

    def _graph_to_lambda(self, nodes, edges, edge_labels):
        """Turn one graph into clauses serialised in the same way as COGS does.

        Args:
            nodes: ``{'x_<idx>': label}`` without null nodes.  A ``'*'`` node
                marks the following node as definite.
            edges: iterable of ``(source, target)`` position keys.
            edge_labels: one label per edge.

        Returns:
            The serialised logical form, or ``'NA'`` when no clause could be
            produced.
        """
        def position(key):
            return int(key[2:])

        # hacky, but here we combine the * nodes into the next node's label
        merged = {}
        definite = False
        for pos, label in sorted(nodes.items(), key=lambda r: position(r[0])):
            if definite:
                merged[pos] = '* ' + label
                definite = False
            elif label == '*':
                definite = True
            else:
                merged[pos] = label
        nodes = merged

        # containing * word clauses - always goes to the beginning
        definite_clauses = {}
        # containing rest of the clauses - ordered primarily by arg1 and by arg2, if available.
        # Entries are created on demand so that any sentence length is supported.
        event_clauses = {}

        for (e1, e2), label in zip(edges, edge_labels):
            # skip edges that touch a null (i.e. removed) node
            if e1 not in nodes or e2 not in nodes:
                continue
            node_label1 = nodes[e1]
            node_label2 = nodes[e2]
            # reformat to x_1, x_3
            e1 = e1.replace(' ', '')
            e2 = e2.replace(' ', '')

            # a capitalised label is treated as a named entity
            is_name1 = node_label1[0].isupper()
            is_name2 = node_label2[0].isupper()
            if is_name1 and is_name2:
                # Edge connects two named entities. Unexpected behavior.
                continue
            # if edge contains named entity, only produce one clause
            if is_name1:
                event_clauses.setdefault(e2, {})[e1] = f'{node_label2} . {label} ({e2}, {node_label1}) '
                continue
            if is_name2:
                event_clauses.setdefault(e1, {})[e2] = f'{node_label1} . {label} ({e1}, {node_label2}) '
                continue

            # if * word appears, emit its definite clause and strip the marker
            # from the predicate name
            definite1 = '*' in node_label1
            definite2 = '*' in node_label2
            if definite1:
                definite_clauses[e1] = f'{node_label1} ({e1})'
                node_label1 = node_label1[2:]
            if definite2:
                definite_clauses[e2] = f'{node_label2} ({e2})'
            event_clauses.setdefault(e1, {})[e2] = f'{node_label1} . {label} ({e1}, {e2})'
            if not definite2:
                # NOTE(review): this replaces every clause recorded so far for
                # e2 with the single unary clause -- confirm this is intended.
                event_clauses[e2] = {'_': f'{node_label2} ({e2})'}

        # produces serialized output
        definite_parts = [
            definite_clauses[key] for key in sorted(definite_clauses, key=position)
        ]
        event_parts = []
        for key in sorted(event_clauses, key=position):
            clauses = event_clauses[key]
            has_nmod = any('nmod' in text for text in clauses.values())
            # the unary '_' clause sorts before all binary clauses
            for sub in sorted(clauses, key=lambda k: -1 if k == '_' else position(k)):
                # the unary clause is dropped when other clauses exist, unless
                # one of them is an nmod clause
                if sub == '_' and len(clauses) > 1 and not has_nmod:
                    continue
                event_parts.append(clauses[sub])

        if not event_parts:
            return 'NA'

        final = ''.join(clause + ' ; ' for clause in definite_parts) + ' AND '.join(event_parts)
        # named-entity clauses end with a space; drop it at the very end
        if final.endswith(' '):
            final = final[:-1]
        # space out 'x_', ')' and ',' and collapse the double spaces this creates
        return (final.replace('x_', ' x _ ')
                     .replace(')', ' )')
                     .replace('  ', ' ')
                     .replace(',', ' ,'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment