# COGS reverse parser
import json
import os
import re
from collections import Counter
import numpy as np
import pandas as pd
import torch
import torchtext
from torch.nn.utils.rnn import pad_sequence as pad


class Preprocessor():
    def __init__(self, data_dir='cogs/graph/', graph_layers=1):
        self.data_dir = data_dir
        self.node_vocab, self.edge_vocab = self.tokenize()
        self.graph_layers = graph_layers
        assert self.word_to_token('<pad>') == 1, '<pad> token needs to be mapped to 1'
        assert self.word_to_token('null') == 0, 'Null token needs to be mapped to 0'
        assert self.word_to_token('<pad>', 'edges') == 1, '<pad> token needs to be mapped to 1'
        assert self.word_to_token('null', 'edges') == 0, 'Null token needs to be mapped to 0'

    def tokenize(self):
        with open(os.path.join(self.data_dir, 'nodes.txt'), 'r') as f:
            nodes = [line.rstrip() for line in f]
        # vocab will use null token: 0, <pad> token: 1
        node_vocab = torchtext.vocab.Vocab(Counter(nodes), specials=['null', '<pad>'])
        with open(os.path.join(self.data_dir, 'edges.txt'), 'r') as f:
            edges = [line.rstrip() for line in f]
        # vocab will use null token: 0, <pad> token: 1
        edge_vocab = torchtext.vocab.Vocab(Counter(edges), specials=['null', '<pad>'])
        return node_vocab, edge_vocab

    def word_to_token(self, word, vocab='nodes'):
        if vocab == 'nodes':
            return self.node_vocab.stoi[word]
        return self.edge_vocab.stoi[word]

    def token_to_word(self, token, vocab='nodes'):
        if vocab == 'nodes':
            return self.node_vocab.itos[token]
        return self.edge_vocab.itos[token]

    def nodes_to_node_seq(self, nodes, in_seq):
        """
        Takes a nodes dict and turns it into a sequence of node tags, one per input token.
        nodes: {'x_1': node1, 'x_2': node2}
        in_seq: ['some', 'input', 'sentence']
        Returns:
            - node token sequence (torch.LongTensor), corresponding to [null, node1, node2]
        (See the usage sketch at the end of the file.)
        """
        if not isinstance(in_seq, list):
            in_seq = in_seq.split(' ')
        node_sequence = torch.LongTensor(
            [self.word_to_token('null') for _ in range(len(in_seq))])
        for node, node_label in nodes.items():
            node_sequence[int(node[2:])] = self.word_to_token(node_label)
        return node_sequence

    def edges_to_edge_seq(self, edges, edge_labels, in_seq):
        """
        Takes edges and turns them into an (N x N) matrix of edge tags over the input tokens.
        nodes: {'x_0': word0, 'x_1': word1, 'x_2': Levi}
        edges: [['x_1', 'x_0'], ['x_2', 'x_1']]
        in_seq: ['some', 'nice', 'input', 'sentence']
        edge_labels: ['theme', 'agent']
        Returns:
            torch.LongTensor of shape (N, N), e.g. [[1,1,0,0], [1,1,0,0], [0,0,1,1], [0,0,1,0]]
        """
        if not isinstance(in_seq, list):
            in_seq = in_seq.split(' ')
        N = len(in_seq)
        # default edge type: null (no edge)
        final_edges = torch.LongTensor([self.word_to_token('null', vocab='edges')] * N * N).view(N, N)
        for edge, label in zip(edges, edge_labels):
            e1, e2 = [int(e[2:]) for e in edge]
            label_token = self.word_to_token(label, vocab='edges')
            final_edges[e1, e2] = label_token
        return final_edges

    def pad_sequences(self, node_targets, edge_targets, device='cuda'):
        padded_nodes = pad(node_targets, batch_first=True,
                           padding_value=self.word_to_token('<pad>', vocab='nodes'))
        max_len = padded_nodes.shape[-1]
        res = []
        for e in edge_targets:
            tmp = torch.LongTensor(
                [self.word_to_token('<pad>', vocab='edges')] * max_len * max_len).view(max_len, max_len)
            tmp[:e.shape[0], :e.shape[1]] = e
            res.append(tmp)
        padded_edges = torch.stack(res)
        return padded_nodes.to(device), padded_edges.to(device)

    def generate_target_sequences(self, data='train', file_name='_parsed.json'):
        """
        Takes a dataset in graph format and parses it into sequences of tokens,
        both for nodes and edges.
        Returns:
            - node token sequences (list of torch.LongTensor)
            - edge token sequences (list of torch.LongTensor)
        """
        file_path = os.path.join(self.data_dir, data + file_name)
        graph = json.load(open(file_path))
        graph_df = pd.DataFrame(graph)
        graph_df['inp'] = graph_df['original_inp']
        with open(f'./cogs/lambda/{data}.tsv') as f:
            df = pd.read_csv(f, sep='\t', header=None)
        df.columns = ['inp', 'out', 'type']
        df = df.reset_index()
        assert len(df) == len(graph_df), \
            f'Graph output should exist for each input, but original is of ' \
            f'len {len(df)}, while graph is of len {len(graph_df)}.'
        joined_df = pd.merge(df, graph_df, on='inp', how='left')
        # assert joined_df[
        #     joined_df.type == 'primitive'].size == 0, 'Primitives are not filtered. You need to do this first.'
        joined_df['node_seq'] = joined_df.apply(
            lambda x: self.nodes_to_node_seq(x['nodes'], x['inp']), axis=1)
        joined_df['edge_seq'] = joined_df.apply(
            lambda x: self.edges_to_edge_seq(x['edges'], x['edge_labels'], x['inp']), axis=1)
        return joined_df.node_seq.values.tolist(), joined_df.edge_seq.values.tolist()

    def split_and_sort_clauses(self, lambd):
        return sorted(re.split(' AND | ; ', lambd))

    def scores_to_graph(self, node_preds, edge_preds, node_targets, edge_targets):
        """
        Returns the batch of nodes, edges and edge labels, where all three include
        null nodes and null edges but exclude pad predictions.
        """
        assert node_preds.shape[0] == edge_preds.shape[0], \
            f'Expected batch first, but got {node_preds.shape[0]}, {edge_preds.shape[0]}'
        assert node_preds.shape == node_targets.shape
        node_sets = []
        for sequence, target_seq in zip(node_preds, node_targets):  # BS, seq_len
            nodes = {}
            for idx, (token, target_token) in enumerate(zip(sequence, target_seq)):  # seq_len
                token = token.item()
                node = self.token_to_word(token, vocab='nodes')
                # if the target is a pad token, or if the prediction is pad, ignore it
                if node == '<pad>' or target_token == self.word_to_token('<pad>', vocab='nodes'):
                    continue
                nodes[f'x_{idx}'] = node
            node_sets.append(nodes)
        edge_sets = []
        edge_label_sets = []
        for sequence, target_seq in zip(edge_preds, edge_targets):  # BS, seq_len, seq_len
            edges = []
            edge_labels = []
            for token_i in range(sequence.shape[0]):
                for token_j in range(sequence.shape[0]):
                    predicted_token = sequence[token_i, token_j]
                    edge_label = self.token_to_word(predicted_token, vocab='edges')
                    true_token = target_seq[token_i, token_j]
                    true_label = self.token_to_word(true_token, vocab='edges')
                    # ignore positions where either prediction or target is padding
                    if edge_label == '<pad>' or true_label == '<pad>':
                        continue
                    edges.append([f'x_{token_i}', f'x_{token_j}'])
                    edge_labels.append(edge_label)
            edge_sets.append(edges)
            edge_label_sets.append(edge_labels)
        return node_sets, edge_sets, edge_label_sets

    def graph_to_lambda(self, node_sets, edge_sets, edge_labels):
        lambdas = []
        for nodes, edges, labels in zip(node_sets, edge_sets, edge_labels):
            # removes null nodes
            nodes = {k: v for k, v in nodes.items() if v != 'null'}
            try:
                # removes null edges
                edges, labels = zip(*[(edge, label) for edge, label in zip(edges, labels) if label != 'null'])
                lambdas.append(self._graph_to_lambda(nodes, edges, labels))
            except Exception:
                # skip the example if nodes or edges are empty (or serialization fails)
                lambdas.append('NA')
        return lambdas

    def _graph_to_lambda(self, nodes, edges, edge_labels):
        """
        Turns a graph into a set of clauses serialized the same way COGS does.
        """
        # hacky, but here we fold the * (definite article) nodes into the next node's label
        new = {}
        definite = False
        for pos, label in sorted(nodes.items(), key=lambda r: int(r[0][2:])):
            if definite:
                new[pos] = '* ' + label
                definite = False
            else:
                if label == '*':
                    definite = True
                else:
                    new[pos] = label
        nodes = new
        # clauses containing * words - these always go to the beginning
        definite_clauses = {}
        # the rest of the clauses - ordered primarily by arg1 and by arg2, if available
        event_clauses = {f'x_{idx}': {} for idx in range(40)}
        for edge, label in zip(edges, edge_labels):
            # e.g. 'x _ 1', 'x _ 3'
            e1, e2 = edge
            # retrieves node labels, skipping the edge if either endpoint is a null node
            try:
                node_label1 = nodes[e1]
                node_label2 = nodes[e2]
            except KeyError:
                continue
            # reformat to x_1, x_3
            e1 = e1.replace(' ', '')
            e2 = e2.replace(' ', '')
            # if the edge connects two named entities, skip it (unexpected behavior)
            if node_label1[0].isupper() and node_label2[0].isupper():
                continue
            # if the edge contains a named entity, only produce one clause
            if node_label1[0].isupper():
                named_entity = node_label1
                event_clauses[e2][e1] = node_label2 + ' . ' + label + ' (' + e2 + ', ' + named_entity + ') '
            elif node_label2[0].isupper():
                named_entity = node_label2
                event_clauses[e1][e2] = node_label1 + ' . ' + label + ' (' + e1 + ', ' + named_entity + ') '
            # if neither endpoint is a named entity
            else:
                # if a * word appears, strip the marker from its label inside the clause
                if '*' in node_label1 and '*' in node_label2:
                    # two definite articles connected - unexpected behavior
                    definite_clauses[e1] = node_label1 + ' (' + e1 + ')'
                    definite_clauses[e2] = node_label2 + ' (' + e2 + ')'
                    node_label1 = node_label1[2:]
                    # node_label2 = node_label2[2:]
                    event_clauses[e1][e2] = node_label1 + ' . ' + label + ' (' + e1 + ', ' + e2 + ')'
                elif '*' in node_label1:
                    definite_clauses[e1] = node_label1 + ' (' + e1 + ')'
                    node_label1 = node_label1[2:]
                    event_clauses[e1][e2] = node_label1 + ' . ' + label + ' (' + e1 + ', ' + e2 + ')'
                    event_clauses[e2] = {'_': node_label2 + ' (' + e2 + ')'}
                elif '*' in node_label2:
                    definite_clauses[e2] = node_label2 + ' (' + e2 + ')'
                    node_label2 = node_label2[2:]
                    event_clauses[e1][e2] = node_label1 + ' . ' + label + ' (' + e1 + ', ' + e2 + ')'
                else:
                    event_clauses[e1][e2] = node_label1 + ' . ' + label + ' (' + e1 + ', ' + e2 + ')'
                    event_clauses[e2] = {'_': node_label2 + ' (' + e2 + ')'}
        # produces the serialized output
        final = ''
        definite_clause_keys = sorted([int(k[2:]) for k in definite_clauses.keys()])
        for key in definite_clause_keys:
            final += definite_clauses['x_' + str(key)] + ' ; '
        event_clause_keys = sorted([int(k[2:]) for k in event_clauses.keys()])
        for key in event_clause_keys:
            subkeys = sorted([int(k[2:]) if k != '_' else -1 for k in event_clauses['x_' + str(key)].keys()])
            for clause_key in subkeys:
                clause_key = 'x_' + str(clause_key) if clause_key != -1 else '_'
                if clause_key == '_':
                    if len(event_clauses['x_' + str(key)]) > 1 and \
                            not np.any(['nmod' in val for val in event_clauses['x_' + str(key)].values()]):
                        continue
                final += event_clauses['x_' + str(key)][clause_key] + ' AND '
        # removes ' AND ' from the end
        try:
            final = final[:-5]
            if final[-1] == ' ':
                final = final[:-1]
        except IndexError:
            return 'NA'
        # re-space the variables and punctuation to match the COGS surface form,
        # collapsing the double spaces introduced by the substitutions
        return final.replace('x_', ' x _ ').replace(')', ' )').replace('  ', ' ').replace(',', ' ,')
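

# Usage sketch (illustrative, not part of the original gist). A minimal,
# hedged example of how the Preprocessor above might be driven end to end.
# It assumes the vocab files ('nodes.txt', 'edges.txt') exist under
# 'cogs/graph/' and that the toy labels below ('boy', 'eat', 'cake',
# 'agent', 'theme') appear in those files; otherwise the vocab lookups fail.
if __name__ == '__main__':
    prep = Preprocessor(data_dir='cogs/graph/')

    # A made-up sentence with its graph annotation.
    sentence = 'a boy ate the cake'.split(' ')
    nodes = {'x_1': 'boy', 'x_2': 'eat', 'x_4': 'cake'}
    edges = [['x_2', 'x_1'], ['x_2', 'x_4']]
    edge_labels = ['agent', 'theme']

    # Per-token node tags (shape: (5,)) and edge-label matrix (shape: (5, 5)).
    node_seq = prep.nodes_to_node_seq(nodes, sentence)
    edge_seq = prep.edges_to_edge_seq(edges, edge_labels, sentence)

    # Batch the targets (here a batch of one) and pad them.
    padded_nodes, padded_edges = prep.pad_sequences([node_seq], [edge_seq], device='cpu')

    # Reverse direction: treat the gold targets as predictions and serialize
    # them back into a COGS-style logical form.
    node_sets, edge_sets, edge_label_sets = prep.scores_to_graph(
        padded_nodes, padded_edges, padded_nodes, padded_edges)
    print(prep.graph_to_lambda(node_sets, edge_sets, edge_label_sets))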