# Created August 23, 2021 15:15
# COGS reverse parser
import json
import os
import re
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torchtext
from torch.nn.utils.rnn import pad_sequence as pad
class Preprocessor():
    """Convert COGS graph annotations to token tensors and predictions back to logical forms.

    The class owns two vocabularies (node labels and edge labels) read from
    ``nodes.txt`` / ``edges.txt`` in ``data_dir``. In both vocabularies the
    ``null`` label must map to 0 and ``<pad>`` to 1; ``__init__`` asserts this.

    NOTE(review): the reviewed copy of this class had lost its indentation and
    a number of short lines (docstring quotes, ``try``/``except``/``else``,
    ``continue``, ``append`` and ``return`` statements). Those were
    reconstructed from context; places where the reconstruction is a guess are
    marked with ``NOTE(review)`` below and should be confirmed.
    """

    def __init__(self, data_dir='cogs/graph/', graph_layers=1):
        """Load the vocabularies from ``data_dir`` and check the special-token indices.

        Args:
            data_dir: directory containing ``nodes.txt`` and ``edges.txt``.
            graph_layers: stored on the instance; not used inside this class.
        """
        self.data_dir = data_dir
        self.node_vocab, self.edge_vocab = self.tokenize()
        self.graph_layers = graph_layers

        assert self.word_to_token('<pad>') == 1, '<pad> token needs to be mapped to 1'
        assert self.word_to_token('null') == 0, 'Null token needs to be mapped to 0'
        assert self.word_to_token('<pad>', 'edges') == 1, '<pad> token needs to be mapped to 1'
        assert self.word_to_token('null', 'edges') == 0, 'Null token needs to be mapped to 0'

    def _load_vocab(self, file_name):
        """Build a vocabulary from the labels listed one per line in ``data_dir/file_name``."""
        with open(os.path.join(self.data_dir, file_name), 'r') as f:
            labels = [line.rstrip() for line in f]
        # vocab will use null token:0, <pad>_token:1
        # NOTE(review): Vocab(Counter, specials=...) is the legacy torchtext
        # constructor; confirm the pinned torchtext version still provides it.
        return torchtext.vocab.Vocab(Counter(labels), specials=['null', '<pad>'])

    def tokenize(self):
        """Return ``(node_vocab, edge_vocab)`` built from ``nodes.txt`` and ``edges.txt``."""
        node_vocab = self._load_vocab('nodes.txt')
        edge_vocab = self._load_vocab('edges.txt')
        return node_vocab, edge_vocab

    def word_to_token(self, word, vocab='nodes'):
        """Map a label to its index; any ``vocab`` value other than 'nodes' selects the edge vocab."""
        if vocab == 'nodes':
            return self.node_vocab.stoi[word]
        return self.edge_vocab.stoi[word]

    def token_to_word(self, token, vocab='nodes'):
        """Map an index to its label; any ``vocab`` value other than 'nodes' selects the edge vocab."""
        if vocab == 'nodes':
            return self.node_vocab.itos[token]
        return self.edge_vocab.itos[token]

    def nodes_to_node_seq(self, nodes, in_seq):
        """
        Takes nodes dict and turns it into a sequence of node tags for each input token.

        Args:
            nodes = {'x_1': node1, 'x_2': node2}
            in_seq: ['some', input, sentence], or the same sentence as one
                space-separated string
        Returns:
            - node token sequence (torch.LongTensor) ( corresponding to [null, node1, node2] )
        """
        if not isinstance(in_seq, list):
            in_seq = in_seq.split(' ')

        node_sequence = torch.full((len(in_seq),), self.word_to_token('null'), dtype=torch.long)
        for node, node_label in nodes.items():
            # node ids look like 'x_<position>'; the position indexes the input tokens
            node_sequence[int(node[2:])] = self.word_to_token(node_label)
        return node_sequence

    def edges_to_edge_seq(self, edges, edge_labels, in_seq):
        """
        Takes edges and turns them into an N x N matrix of edge tags (N = number of input tokens).

        Args:
            edges = [['x_1', 'x_0'], ['x_2', 'x_1']]
            edge_labels = [theme, agent]
            in_seq: [some, nice, input, sentence], or the same sentence as one
                space-separated string
        Returns:
            torch.LongTensor of shape (N, N); entry [i, j] holds the label token
            of the edge x_i -> x_j, or the null token when there is no edge.
        """
        if not isinstance(in_seq, list):
            in_seq = in_seq.split(' ')
        N = len(in_seq)

        # default edge type: null (no edge)
        final_edges = torch.full((N, N), self.word_to_token('null', vocab='edges'), dtype=torch.long)
        for edge, label in zip(edges, edge_labels):
            e1, e2 = [int(e[2:]) for e in edge]
            final_edges[e1, e2] = self.word_to_token(label, vocab='edges')
        return final_edges

    def pad_sequences(self, node_targets, edge_targets, device='cuda'):
        """Pad a batch of node sequences and edge matrices to a common length.

        Args:
            node_targets: list of 1-D LongTensors.
            edge_targets: list of square 2-D LongTensors, one per node sequence.
            device: device the padded tensors are moved to.
        Returns:
            (padded_nodes, padded_edges) of shapes (B, L) and (B, L, L), where
            L is the longest node sequence and padding uses the '<pad>' tokens.

        NOTE(review): the original return statement was missing from the
        reviewed copy; returning both tensors on ``device`` is inferred from
        the otherwise unused ``device`` parameter.
        """
        padded_nodes = pad(node_targets, batch_first=True,
                           padding_value=self.word_to_token('<pad>', vocab='nodes'))
        max_len = padded_nodes.shape[-1]
        edge_pad = self.word_to_token('<pad>', vocab='edges')

        res = []
        for e in edge_targets:
            tmp = torch.full((max_len, max_len), edge_pad, dtype=torch.long)
            tmp[:e.shape[0], :e.shape[1]] = e
            res.append(tmp)
        padded_edges = torch.stack(res)
        return padded_nodes.to(device), padded_edges.to(device)

    def generate_target_sequences(self, data='train', file_name='_parsed.json'):
        """
        Takes dataset in graph format, and parses it to a sequence of tokens
        both for nodes and edges.

        Reads ``<data_dir>/<data><file_name>`` (graph JSON) and
        ``./cogs/lambda/<data>.tsv`` (original COGS split) and joins them on
        the input sentence.

        Returns:
            - node token sequence (torch LongTensor), one per example
            - edge token sequence (torch LongTensor), one per example
        """
        file_path = os.path.join(self.data_dir, data + file_name)
        with open(file_path) as f:
            graph = json.load(f)
        graph_df = pd.DataFrame(graph)
        graph_df['inp'] = graph_df['original_inp']

        with open(f'./cogs/lambda/{data}.tsv') as f:
            df = pd.read_csv(f, sep='\t', header=None)
        df.columns = ['inp', 'out', 'type']
        df = df.reset_index()
        assert len(df) == len(graph_df), (
            'Graph output should exist for each input, but '
            f'original is of len {len(df)}, while graph is of len {len(graph_df)}.')

        joined_df = pd.merge(df, graph_df, on='inp', how='left')
        # NOTE(review): a disabled check here used to require that rows of
        # type 'primitive' are filtered out beforehand.

        joined_df['node_seq'] = joined_df.apply(
            lambda x: self.nodes_to_node_seq(x['nodes'], x['inp']), axis=1)
        joined_df['edge_seq'] = joined_df.apply(
            lambda x: self.edges_to_edge_seq(x['edges'], x['edge_labels'], x['inp']), axis=1)
        return joined_df.node_seq.values.tolist(), joined_df.edge_seq.values.tolist()

    def split_and_sort_clauses(self, lambd):
        """Split a logical form on ' AND ' / ' ; ' and return its clauses sorted."""
        return sorted(re.split(' AND | ; ', lambd))

    def scores_to_graph(self, node_preds, edge_preds, node_targets, edge_targets):
        """ returns the batch of nodes, edges and edge labels where all three include null nodes and null edges, but exclude pad predictions.

        Args:
            node_preds, node_targets: (batch, seq_len) token tensors.
            edge_preds, edge_targets: (batch, seq_len, seq_len) token tensors.
        Returns:
            node_sets: list of {'x_<idx>': label} dicts,
            edge_sets: list of [['x_<i>', 'x_<j>'], ...] lists,
            edge_label_sets: list of label lists parallel to ``edge_sets``.
        """
        assert node_preds.shape[0] == edge_preds.shape[0], \
            f'Expected batch first, but got {node_preds.shape[0]}, {edge_preds.shape[0]}'
        assert node_preds.shape == node_targets.shape

        node_pad = self.word_to_token('<pad>', vocab='nodes')
        node_sets = []
        for sequence, target_seq in zip(node_preds, node_targets):  # BS, seq_len
            nodes = {}
            for idx, (token, target_token) in enumerate(zip(sequence, target_seq)):  # seq_len
                node = self.token_to_word(int(token), vocab='nodes')
                # if the target is a pad token, or if the prediction is pad, ignore it
                if node == '<pad>' or int(target_token) == node_pad:
                    continue
                nodes[f'x_{idx}'] = node
            node_sets.append(nodes)

        edge_sets = []
        edge_label_sets = []
        for sequence, target_seq in zip(edge_preds, edge_targets):  # BS, seq_len, seq_len
            edges = []
            edge_labels = []
            size = sequence.shape[0]
            for token_i in range(size):
                for token_j in range(size):
                    edge_label = self.token_to_word(int(sequence[token_i, token_j]), vocab='edges')
                    true_label = self.token_to_word(int(target_seq[token_i, token_j]), vocab='edges')
                    # same pad filter as for the nodes
                    if edge_label == '<pad>' or true_label == '<pad>':
                        continue
                    edges.append([f'x_{token_i}', f'x_{token_j}'])
                    edge_labels.append(edge_label)
            edge_sets.append(edges)
            edge_label_sets.append(edge_labels)
        return node_sets, edge_sets, edge_label_sets

    def graph_to_lambda(self, node_sets, edge_sets, edge_labels):
        """Serialize a batch of graphs into COGS-style logical forms.

        Null nodes and null edges are dropped before serialization. The result
        has exactly one entry per input graph.

        NOTE(review): the handler for graphs without usable edges was missing
        from the reviewed copy. Such graphs (and graphs that cannot be
        serialized because of a missing key or an empty label) yield 'NA' here,
        matching what ``_graph_to_lambda`` returns for an empty result and
        keeping the output aligned with the batch. Confirm against the caller.
        """
        lambdas = []
        for nodes, edges, labels in zip(node_sets, edge_sets, edge_labels):
            # removes null nodes
            nodes = {k: v for k, v in nodes.items() if v != 'null'}
            # removes null edges
            kept = [(edge, label) for edge, label in zip(edges, labels) if label != 'null']
            if not kept:
                lambdas.append('NA')
                continue
            edges, labels = zip(*kept)
            try:
                lambdas.append(self._graph_to_lambda(nodes, edges, labels))
            except (KeyError, IndexError):
                lambdas.append('NA')
        return lambdas

    def _graph_to_lambda(self, nodes, edges, edge_labels):
        """
        Turns a graph into a set of clauses serialized in the same way as COGS does.

        Args:
            nodes: {'x_<idx>': label}; a '*' node marks the following node as definite.
            edges: sequence of [head_id, dependent_id] pairs.
            edge_labels: labels parallel to ``edges``.
        Returns:
            the serialized logical form, or 'NA' when no clause was produced.
        """
        ## hacky, but here we combine the * nodes into the next node's label
        merged = {}
        definite = False
        for pos, label in sorted(nodes.items(), key=lambda item: int(item[0][2:])):
            if definite:
                merged[pos] = '* ' + label
                definite = False
            elif label == '*':
                definite = True
            else:
                merged[pos] = label
        nodes = merged

        # containing * word clauses - always goes to the beginning
        definite_clauses = {}
        # containing rest of the clauses - ordered primarily by arg1 and by arg2, if available.
        # Buckets are created on demand so sentences of any length are supported.
        event_clauses = {}

        for (e1, e2), label in zip(edges, edge_labels):
            # retrieves node labels, except if they are null nodes
            if e1 not in nodes or e2 not in nodes:
                continue
            node_label1 = nodes[e1]
            node_label2 = nodes[e2]
            # reformat to x_1, x_3
            e1 = e1.replace(' ', '')
            e2 = e2.replace(' ', '')

            first_is_name = node_label1[0].isupper()
            second_is_name = node_label2[0].isupper()

            # if edge contains named entity, only produce one clause
            if first_is_name and second_is_name:
                # Edge connects two named entities. Unexpected behavior.
                # NOTE(review): the body of this branch was missing in the
                # reviewed copy; skipping the edge is an assumption.
                continue
            if first_is_name:
                event_clauses.setdefault(e2, {})[e1] = \
                    node_label2 + ' . ' + label + ' (' + e2 + ', ' + node_label1 + ') '
                continue
            if second_is_name:
                event_clauses.setdefault(e1, {})[e2] = \
                    node_label1 + ' . ' + label + ' (' + e1 + ', ' + node_label2 + ') '
                continue

            # if not named entity
            # if * word appears, modify its name in the arg
            first_is_definite = '*' in node_label1
            second_is_definite = '*' in node_label2
            if first_is_definite:
                definite_clauses[e1] = node_label1 + ' (' + e1 + ')'
                node_label1 = node_label1[2:]  # drop the leading '* '
            if second_is_definite:
                definite_clauses[e2] = node_label2 + ' (' + e2 + ')'

            event_clauses.setdefault(e1, {})[e2] = \
                node_label1 + ' . ' + label + ' (' + e1 + ', ' + e2 + ')'
            if not second_is_definite:
                # indefinite dependents get their own unary clause under '_'.
                # NOTE(review): this replaces any clauses already stored for e2
                # (kept as in the original; looks order-dependent).
                event_clauses[e2] = {'_': node_label2 + ' (' + e2 + ')'}

        # produces serialized output
        def position(key):
            # '_' (the unary clause) sorts before every 'x_<idx>' argument
            return -1 if key == '_' else int(key[2:])

        prefix = ''.join(
            definite_clauses[key] + ' ; ' for key in sorted(definite_clauses, key=position))

        parts = []
        for key in sorted(event_clauses, key=position):
            clauses = event_clauses[key]
            has_nmod = any('nmod' in val for val in clauses.values())
            for clause_key in sorted(clauses, key=position):
                # a node that heads its own (non-nmod) clauses gets no unary clause
                if clause_key == '_' and len(clauses) > 1 and not has_nmod:
                    continue
                parts.append(clauses[clause_key])

        final = prefix + ' AND '.join(parts)
        if not final:
            return 'NA'
        # named-entity clauses end in a single trailing space
        if final[-1] == ' ':
            final = final[:-1]
        return final.replace('x_', ' x _ ').replace(')', ' )').replace('  ', ' ').replace(',', ' ,')
