NB_NAME = 'path_lstm_clean_multigpu' | |
import os | |
import nltk | |
from nltk.corpus import treebank | |
import numpy as np | |
from math import ceil | |
#import matplotlib.pyplot as plt | |
#%matplotlib inline | |
import itertools | |
from collections import namedtuple, defaultdict | |
flatten = itertools.chain.from_iterable | |
class record: | |
def __init__(self, name, d): | |
if any(type(x) != str for x in d.keys()): | |
raise ValueError("Keys must be strings") | |
self.container = namedtuple(name, d.keys())(**d) | |
def __getattr__(self, name): | |
return getattr(self.container, name) | |
@classmethod | |
def from_local(cls, name, local, mapping): | |
if type(mapping) == list: | |
export_names = mapping | |
return record(name, {name:local[name] for name in export_names}) | |
elif type(mapping) is dict: | |
new_dict = {sub_name:record.from_local('sub', local, sub_map) | |
for sub_name, sub_map in mapping.items()} | |
return record(name, new_dict) | |
else: | |
raise ValueError("record mapping must be either dict() or list()") | |
def get(self, name): | |
return getattr(self.container, name) | |
def keys(self): | |
return self.container._asdict().keys() | |
def test_record(): | |
a = record('rec', {'a': 1, 'b': 2}) | |
assert a.a == 1 | |
assert a.get('b') == 2 | |
local = {'a': 2, 'b': 3} | |
mapping = ['a', 'b'] | |
b = record.from_local('rec', local, mapping) | |
assert b.a == 2 | |
assert b.get('b') == 3 | |
mapping = {'input': ['a'], 'output': ['b']} | |
c = record.from_local('rec', local, mapping) | |
assert c.input.a == 2 | |
assert c.get('output').get('b') == 3 | |
assert set(c.keys()) == set(['input', 'output']) | |
test_record() | |
PRIMARY_DATA_DIR = '/scratch2/dumps' | |
print(NB_NAME) | |
print(PRIMARY_DATA_DIR) | |
import pickle | |
DEFAULT_SHARD_SIZE = 100 | |
class PickleShardIterator: | |
""" | |
takes a generator over a large list and turns it into a sharded iterator
to consume less memory and work faster than a plain generator
""" | |
DEFAULT_DIR = None | |
def __init__(self, name, n_shards, pickle_dir=None): | |
partial_pickle_dir = self.get_partial_dir(pickle_dir) | |
self.full_pickle_dir = os.path.join(partial_pickle_dir, name) | |
self.n_shards = n_shards | |
self.n_current_shard = 0 | |
self.current_shard_iterator = None | |
def __iter__(self): | |
self.n_current_shard = 0 | |
self.current_shard_iterator = None | |
return self | |
def __next__(self): | |
if self.n_current_shard >= self.n_shards: | |
raise StopIteration | |
if self.current_shard_iterator is None:
shard_file_name = os.path.join(self.full_pickle_dir, str(self.n_current_shard)+'.pickle') | |
if os.path.isfile(shard_file_name): | |
with open(shard_file_name, 'rb') as f: | |
self.current_shard_iterator = iter(pickle.load(f)) | |
try: | |
return next(self.current_shard_iterator) | |
except StopIteration: | |
self.n_current_shard += 1 | |
self.current_shard_iterator = None | |
return self.__next__() | |
@classmethod | |
def set_default_dir(cls, folder): | |
print('Setting default dir to', folder) | |
cls.DEFAULT_DIR = folder | |
@classmethod | |
def get_partial_dir(cls, pickle_dir): | |
partial_pickle_dir = cls.DEFAULT_DIR or pickle_dir | |
if not partial_pickle_dir: | |
raise ValueError("No pickle dir specified, use" | |
"`PickleShardIterator.set_default_folder()`" | |
"or `pickle_dir=..` argument") | |
return partial_pickle_dir | |
@classmethod | |
def from_iterator(cls, iterator, name, shard_size=DEFAULT_SHARD_SIZE, size_func=len, pickle_dir=None): | |
partial_pickle_dir = cls.get_partial_dir(pickle_dir) | |
full_pickle_dir = os.path.join(partial_pickle_dir, name) | |
if not os.path.exists(full_pickle_dir): | |
os.makedirs(full_pickle_dir) | |
iterator_exhausted = False | |
n_current_shard = 0 | |
while not iterator_exhausted: | |
current_shard_to_pack = [] | |
while size_func(current_shard_to_pack) < shard_size: | |
try: | |
current_shard_to_pack.append(next(iterator)) | |
except StopIteration: | |
iterator_exhausted = True | |
break | |
shard_file_name = os.path.join(full_pickle_dir, str(n_current_shard)+'.pickle') | |
with open(shard_file_name, 'wb') as f: | |
pickle.dump(current_shard_to_pack, f) | |
n_current_shard += 1 | |
return PickleShardIterator(name, n_current_shard, partial_pickle_dir) | |
PickleShardIterator.set_default_dir(os.path.join(PRIMARY_DATA_DIR, 'tf_dumps/pickle_shard', NB_NAME)) | |
def pickle_load_cached_function(func, prefix=''): | |
dump_path = PRIMARY_DATA_DIR+'/tf_dumps/func_cache/' + NB_NAME + '/' | |
if not os.path.exists(dump_path): | |
os.makedirs(dump_path) | |
func_name = func.__name__ | |
file_name = dump_path+func_name+prefix+'.pickle' | |
if os.path.isfile(file_name): | |
print('loading', func_name ,'from', file_name) | |
try: | |
with open(file_name, 'rb') as f: | |
data = pickle.load(f) | |
return data | |
except: | |
print('something failed while fetching', func_name, 'running actual function') | |
pass # run actual function otherwise | |
else: | |
print('no dump found, running', func_name) | |
data = func() | |
with open(file_name, 'wb') as f: | |
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) | |
print('writing', func_name ,'to', file_name) | |
return data | |
## TESTING | |
def test_pickle_sharder():
range_iterator = range(1000) | |
sharded_iter = PickleShardIterator.from_iterator(iter(range_iterator), 'range_iter') | |
for i in range(5): | |
print(next(sharded_iter)) | |
for i, x in enumerate(sharded_iter): | |
if i % 300 == 0: | |
print(x) | |
## we want this iterator to be picklable too
with open('/tmp/test_range_sharded.pickle', 'wb') as f:
pickle.dump(sharded_iter, f)
with open('/tmp/test_range_sharded.pickle', 'rb') as f:
sharded_iter_loaded = pickle.load(f) | |
for i, x in enumerate(sharded_iter_loaded): | |
if i % 300 == 0: | |
print(x) | |
def test_pickle_load(): | |
pass | |
# test_pickle_sharder()
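## A minimal usage sketch of the sharder above (illustrative only; assumes the
## default pickle dir set a few lines earlier exists and is writable). Shards
## land as <dir>/<name>/0.pickle, 1.pickle, ... and are re-read lazily.
def pickle_sharder_usage_sketch():
    it = PickleShardIterator.from_iterator(iter(range(10)), 'sketch_range', shard_size=4)
    assert list(it) == list(range(10))  # sharding is transparent to the consumer
    assert it.n_shards == 3             # 4 + 4 + 2 items
# pickle_sharder_usage_sketch()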
import numpy as np | |
class MultiEdgeTreeNode: | |
""" Tree that has two types of edges: bottom - top-bottom, right - left-right""" | |
def __init__(self, content, tree_id, pos_id=None, print_mode=False): | |
self.top = None | |
self.bottom = None | |
self.right = None | |
self.left = None | |
self.parse_node = content | |
self.print_mode = print_mode # monkey patching :( | |
self.tree_id = tree_id | |
self.pos_id = pos_id | |
self.max_pos_over_subtree_cache = None | |
@property | |
def parent(self): | |
if self.is_first_in_row: | |
return self.top | |
else: | |
return self.left.parent | |
@property
def is_terminal(self):
return self.bottom is None
@property
def is_root(self):
return (self.right is None
and self.left is None
and self.top is None)
@property
def is_first_in_row(self):
return self.left is None
@property
def is_last_in_row(self):
return self.right is None
@staticmethod | |
def top_bottom_bound(a, b): | |
a.bottom = b | |
b.top = a | |
@staticmethod | |
def left_right_bound(a, b): | |
a.right = b | |
b.left = a | |
def pprint(self): | |
tmp_self_top = self.top | |
self._pprint(shift=0, top_str='ROOT') | |
@property | |
def content_label(self): | |
cont_label = self.parse_node.label() | |
if self.is_terminal: | |
cont_label += (' (%s)' % self.parse_node[0]) | |
return cont_label | |
@property | |
def full_content_label(self): | |
return (self.parse_node.label() | |
+ ' (%s)' % (' '.join(self.parse_node.leaves()))) | |
def _pprint(self, shift, top_str=None): | |
if self.is_first_in_row: | |
print('|'*shift, top_str or self.top.content_label, | |
' -down>- ', self.content_label) | |
if not self.is_terminal: | |
first_child = self.bottom | |
first_child._pprint(shift+1) | |
if not self.is_last_in_row: | |
print('|'*shift, self.content_label, | |
' -right>- ', self.right.content_label) | |
self.right._pprint(shift) | |
def __repr__(self): | |
if not self.print_mode: | |
return "<MENode %s>" % self.content_label | |
else: | |
return "<MENode %s>" % repr(self.parse_node) | |
## for word feature column | |
@property | |
def sentence_pos_id(self): | |
if self.pos_id: | |
return self.pos_id | |
if self.is_first_in_row: | |
if self.is_root: | |
return 0 | |
else: | |
return self.top.sentence_pos_id | |
else: | |
return self.left.max_pos_over_subtree + 1 | |
@property | |
def max_pos_over_subtree(self): | |
if self.max_pos_over_subtree_cache: | |
return self.max_pos_over_subtree_cache | |
if self.is_terminal: | |
## weird lost nodes (rare, like, "NP with removed *-2 under it") | |
cur_id = self.pos_id if self.pos_id else self.parent.sentence_pos_id | |
else: | |
last_child = self.bottom | |
while not last_child.is_last_in_row: | |
last_child = last_child.right | |
cur_id = last_child.max_pos_over_subtree | |
self.max_pos_over_subtree_cache = cur_id | |
return self.max_pos_over_subtree_cache | |
def find_leaf_pos(tree, leaf): | |
for pos in tree.treepositions(): | |
if id(tree[pos]) == id(leaf): | |
tp = pos | |
break | |
for i in range(len(tree.leaves())): | |
if tree.leaf_treeposition(i) == (*tp, 0): | |
return i | |
def find_aligned_leaf_pos(tree, leaf, word_seq): | |
## caching is required here, a lot of repeated work is done otherwise
original_pos = find_leaf_pos(tree, leaf) | |
alignment = {} | |
j = 0 | |
for i, word in enumerate(tree.leaves()): | |
if word_seq[j] == word: | |
alignment[i] = j | |
j += 1 | |
if j == len(word_seq): | |
break | |
return alignment[original_pos] | |
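## Toy check of the alignment helper above (illustrative only): the word
## sequence excludes trace tokens, so a leaf's position in the full tree maps
## onto a smaller index in the trace-free sequence.
def aligned_leaf_pos_sketch():
    t = nltk.Tree.fromstring("(S (NP-SBJ (-NONE- *-1)) (VP (VB go) (RB home)))")
    rb_leaf = t[1][1]  # the (RB home) subtree: third leaf overall, second real word
    assert find_aligned_leaf_pos(t, rb_leaf, ['go', 'home']) == 1
# aligned_leaf_pos_sketch()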
## PARSE -> ME | |
def ptb_label_strip(labl): | |
if '=' in labl: | |
return labl.split('=')[0] | |
if '-' in labl: | |
split = labl.split('-') | |
if split[-1].isdigit(): | |
return '-'.join(split[:-1]) | |
return labl | |
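## Quick checks of the stripper above: '=' splits off the gap coindex,
## a trailing numeric coindex after '-' is dropped, functional tags stay.
assert ptb_label_strip('NP-1') == 'NP'
assert ptb_label_strip('NP-SBJ') == 'NP-SBJ'
assert ptb_label_strip('NP-SBJ=2') == 'NP-SBJ'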
def words_from_tree(tree): | |
return [word for word in tree.leaves() if word[0] != '*'] | |
def parse_tree_into_me_tree(node, tree_id, top_tree=None): | |
if not top_tree: | |
top_tree = node | |
node.set_label(ptb_label_strip(node.label())) | |
node_me = MultiEdgeTreeNode(node, tree_id=tree_id) | |
if type(node[0]) != str and node[0].label() != '-NONE-': # not terminal | |
first_child = node[0] | |
first_node_me = parse_tree_into_me_tree(first_child, tree_id, top_tree) | |
MultiEdgeTreeNode.top_bottom_bound(node_me, first_node_me) | |
prev_child_me_node = first_node_me | |
for child in node[1:]: | |
if child.label() == '-NONE-': | |
continue | |
new_node_me = parse_tree_into_me_tree(child, tree_id, top_tree) | |
MultiEdgeTreeNode.left_right_bound(prev_child_me_node, new_node_me) | |
prev_child_me_node = new_node_me | |
if type(node[0]) == str: | |
node_me.pos_id = find_aligned_leaf_pos(top_tree, node, words_from_tree(top_tree)) | |
return node_me | |
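## A tiny end-to-end check of the PTB-tree -> multi-edge-tree conversion
## (hand-written toy parse, illustrative only; the real data comes from the
## PTB reader further below):
def me_tree_conversion_sketch():
    toy = nltk.Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))")
    me_root = parse_tree_into_me_tree(toy, tree_id=0)
    assert me_root.is_root
    dt_node = me_root.bottom.bottom         # S -bottom> NP -bottom> DT
    assert dt_node.is_terminal and dt_node.pos_id == 0
    vbz_node = me_root.bottom.right.bottom  # NP -right> VP -bottom> VBZ
    assert vbz_node.pos_id == 2
# me_tree_conversion_sketch()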
## BRANCHERS | |
def me_tree_into_topdown_branches(node_me): | |
sub_branches = [] | |
if not node_me.is_terminal: | |
for under_sub_branch in me_tree_into_topdown_branches(node_me.bottom): | |
sub_branches.append( [(node_me, 'bottom')]+under_sub_branch ) | |
if not node_me.is_last_in_row: | |
for after_sub_branch in me_tree_into_topdown_branches(node_me.right): | |
sub_branches.append( [(node_me, 'right')]+after_sub_branch ) | |
if node_me.is_last_in_row and node_me.is_terminal: | |
## must be STOP here, but bottom for simplicity | |
sub_branches = [[(node_me, 'stop')]] | |
return sub_branches | |
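## What the top-down brancher yields on the toy parse: each branch is a list
## of (node, command) pairs walking bottom/right edges and ending in 'stop'
## at a terminal (illustrative check only):
def topdown_branches_sketch():
    toy = nltk.Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))")
    branches = me_tree_into_topdown_branches(parse_tree_into_me_tree(toy, tree_id=0))
    assert len(branches) == 2
    assert all(branch[-1][1] == 'stop' for branch in branches)
    assert [cmd for _, cmd in branches[0]] == ['bottom', 'bottom', 'right', 'stop']
# topdown_branches_sketch()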
def me_tree_into_positive_branches(root, n_allowed_forward_span=1, n_allowed_duplicates=0): | |
## [ list of branches [.. (node_me, command) ..] ]
root_branches = me_tree_into_topdown_branches(root) | |
me_leave_s = [root_branch[-1][0] for root_branch in root_branches] | |
all_terminal_s = [root] + me_leave_s | |
all_leaves_branches = [root_branches] | |
for me_leave in me_leave_s: | |
climb_up_node = me_leave | |
leave_branches = [] | |
prefix = [] | |
while not climb_up_node.is_root: | |
if climb_up_node.is_first_in_row: | |
to_branch_from = climb_up_node.top | |
prefix += [(climb_up_node, 'top')] | |
else: | |
to_branch_from = climb_up_node.left | |
prefix += [(climb_up_node, 'left')] | |
for from_branch in me_tree_into_topdown_branches(to_branch_from): | |
# don't want to just go back, like | |
# node1 -> parent -> .. -> node1 | |
if from_branch[-1][0] != me_leave: | |
leave_branches.append( prefix + from_branch ) | |
climb_up_node = to_branch_from | |
all_leaves_branches.append(leave_branches) | |
all_positive_time_branches_dict = defaultdict(list) | |
for leave_branches in all_leaves_branches: | |
for branch in leave_branches: | |
fr, to = branch[0][0], branch[-1][0] | |
## span < n_allowed_span or it starts in S (root) | |
if ( all_terminal_s.index(fr) - all_terminal_s.index(to) <= n_allowed_forward_span | |
or fr == root): | |
nodes_list = [x[0] for x in branch] | |
if len(nodes_list) - len(set(nodes_list)) > n_allowed_duplicates: | |
continue | |
all_positive_time_branches_dict[to].append(branch) | |
return all_leaves_branches, all_positive_time_branches_dict | |
def reverse_branch(branch): | |
nodes, cmds = zip(*branch) | |
nodes_new = nodes[::-1] | |
pairs = [('top', 'bottom'), ('right', 'left')] | |
cmd_inv_dict = dict(pairs + [pair[::-1] for pair in pairs]) | |
cmds_new = [cmd_inv_dict[cmd] for cmd in cmds[:-1][::-1]] + ['stop'] | |
return zip(nodes_new, cmds_new) | |
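## reverse_branch flips a path end-for-end and inverts every edge command
## (top<->bottom, left<->right), keeping the trailing 'stop'. The node objects
## are opaque to it, so plain strings suffice for a quick check:
assert (list(reverse_branch([('a', 'bottom'), ('b', 'right'), ('c', 'stop')]))
        == [('c', 'left'), ('b', 'top'), ('a', 'stop')])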
def me_tree_into_positive_negative_branches(me_tree, n_allowed_forward_span, | |
n_allowed_duplicates): | |
kw = {'n_allowed_forward_span':n_allowed_forward_span, | |
'n_allowed_duplicates':n_allowed_duplicates} | |
_, all_positive_branches_dict = me_tree_into_positive_branches(me_tree, **kw) | |
tree_branch_list_p1 = list(flatten(all_positive_branches_dict.values())) | |
tree_branch_iter = [*[reverse_branch(br) for br in tree_branch_list_p1], | |
*tree_branch_list_p1] | |
return tree_branch_iter | |
def me_tree_to_bfs(me_node): | |
current_sub_dfs = [] | |
if not me_node.is_terminal: | |
current_sub_dfs.append((me_node, 'bottom')) | |
current_sub_dfs.extend( me_tree_to_bfs(me_node.bottom) ) | |
if not me_node.is_last_in_row: | |
current_sub_dfs.append((me_node, 'right')) | |
child_sub_dfs = me_tree_to_bfs(me_node.right) | |
child_sub_dfs_fixed = child_sub_dfs[:-1] + [(child_sub_dfs[-1][0], 'left')] | |
current_sub_dfs.extend( child_sub_dfs_fixed ) | |
current_sub_dfs.append((me_node, 'top')) | |
return current_sub_dfs | |
def main_tree_s_into_branches(trees, mode, **args): | |
""" | |
modes: | |
't' topdown | |
'tp' topdown + positive [n_allowed_forward_span=?, n_allowed_duplicates=?] | |
'tpr' topdown + positive + reverse positive | |
'tpb' topdown + positive + bfs | |
'tprb' .. (all) | |
""" | |
if mode not in ['t', 'tp', 'tpr', 'tpb', 'tprb']: | |
raise ValueError("Mode not supported", mode) | |
tree_s_map = lambda f, trees: [f(parse_tree_into_me_tree(x, i)) for i, x in enumerate(trees)] | |
if 'p' in mode: | |
argv = [args['n_allowed_forward_span'], args['n_allowed_duplicates']] | |
if mode == 't': | |
brs = tree_s_map(me_tree_into_topdown_branches, trees)
elif mode.startswith('tpr'): | |
map_func = lambda x: me_tree_into_positive_negative_branches(x, *argv) | |
brs = tree_s_map(map_func, trees) | |
elif mode.startswith('tp'): | |
map_func = lambda x: me_tree_into_positive_branches(x, *argv) | |
bunch = tree_s_map(map_func, trees) | |
brs = [list(flatten(x[1].values())) for x in bunch] | |
if 'b' in mode: | |
brs = [tree_brs + [me_tree_to_bfs(parse_tree_into_me_tree(tree, i))] | |
for i, (tree_brs, tree) in enumerate(zip(brs, trees))] | |
return brs | |
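## A hedged usage sketch of the dispatcher above (toy input; 'tprb' is the
## mode actually used when building the dataset further below):
def branch_dispatcher_sketch():
    toys = [nltk.Tree.fromstring("(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))")]
    brs = main_tree_s_into_branches(toys, mode='tprb',
                                    n_allowed_forward_span=1, n_allowed_duplicates=0)
    assert len(brs) == 1  # one list of branches per input tree
    assert all(cmd in ('top', 'bottom', 'left', 'right', 'stop')
               for branch in brs[0] for _, cmd in branch)
# branch_dispatcher_sketch()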
## /BRANCHERS | |
def branches_into_feature_lists(tree_s_branches): | |
feature_map_funcs = [ | |
lambda x: x.is_first_in_row, lambda x: x.is_last_in_row, | |
lambda x: x.is_terminal, lambda x: x.parse_node.label(), | |
lambda x: (x.parse_node[0] # get token | |
if type(x.parse_node[0]) is str | |
else '<%s>'%x.parse_node.label()), | |
lambda x: x.sentence_pos_id, | |
lambda x: x.tree_id, | |
# + command few lines above | |
] | |
features_lists = [] | |
for tree_branches in tree_s_branches: | |
branch_lists = [] | |
for branch in tree_branches: | |
branch_m = [[f(node) for f in feature_map_funcs] + [command,] | |
for node, command in branch] | |
branch_lists.append(branch_m) | |
features_lists.append( branch_lists ) | |
return features_lists | |
def columns(ll, idxs): | |
return list(zip(*ll))[idxs] | |
def build_feature_map(data_list, start_index = 1): | |
ls = list(set(data_list)) | |
feature_map = dict(zip(ls, range(start_index, len(ls)+start_index))) # boy: 1, toy: 2 | |
backward_map = dict(zip(range(start_index, len(ls)+start_index), ls)) # 1: boy, 2: toy | |
return {'forward': feature_map, 'backward': backward_map} | |
def onehot_list(n, size): | |
if n >= size: | |
raise ValueError("Can't embed %d into vec of len %d" % (n, size)) | |
return [n == i for i in range(size)] | |
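## Tiny checks of the two id helpers above: feature ids start from 1
## (start_index default), which leaves 0 free as a padding value later on,
## and onehot_list is 0-indexed:
fm_sketch = build_feature_map(['boy', 'toy'])
assert set(fm_sketch['forward'].values()) == {1, 2}
assert fm_sketch['backward'][fm_sketch['forward']['boy']] == 'boy'
assert onehot_list(2, size=4) == [False, False, True, False]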
def features_lists_to_data_matrixes(features_lists, maps, use_word_flag): | |
sentence_matrices = [] | |
for tree_features in features_lists: | |
bool_features = columns(tree_features, slice(0, 3)) | |
tree_pos_features = columns(tree_features, slice(5, 7)) | |
label_ids = [maps['label']['forward'][x] for x in columns(tree_features, 3)] | |
token_ids = [maps['token']['forward'][x] for x in columns(tree_features, 4)] | |
cmd_ids = [maps['cmd']['forward'].get(x, -1) for x in columns(tree_features, 7)] | |
cmd_n = len(maps['cmd']['forward']) # ! 4 not 2 | |
cmd_features = zip(*[onehot_list(cmd_id-1, size=cmd_n) for cmd_id in cmd_ids]) | |
to_merge_pre = [*bool_features, *cmd_features, label_ids, token_ids] | |
to_merge = [*to_merge_pre, *tree_pos_features] if use_word_flag else to_merge_pre | |
feature_matrix = np.array(to_merge, dtype='int32') | |
sentence_matrices.append(feature_matrix.T) | |
return sentence_matrices | |
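## Column layout of the per-branch matrices produced above (one row per path
## step), as consumed by extract_ij_data_point and the X slicing in the model:
##   cols 0-2 : is_first_in_row, is_last_in_row, is_terminal
##   cols 3-6 : one-hot of the edge command
##   col 7    : label id,  col 8 : token id
##   cols 9-10: sentence position id and tree id (only when use_word_flag;
##              extract_ij_data_point later reuses these two slots for the
##              word-sequence mask and word ids)
## pad_stack_with_mask then prepends a 0/1 row mask, giving the 12 input
## columns expected by the X placeholder in build_model.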
## PERTURBATIONS AND GENERATOR | |
from math import ceil | |
def extract_ij_data_point(i, j, sentence_matrixes, word_matrixes, | |
max_row_len, use_word_flag, stub_word_id): | |
## i - number of path within ALL paths, j - number of position | |
x_feature_slice = slice(0, -2) if use_word_flag else slice(None) | |
low_bound = lambda j: max(0, j-max_row_len) | |
X_single = sentence_matrixes[i][low_bound(j):j, x_feature_slice] | |
y_single = sentence_matrixes[i][j, [0, 1, 2, 7, 8]] | |
if not use_word_flag: | |
return X_single, y_single | |
## which tree the i-th path corresponds to, and which terminal position
tree_id = sentence_matrixes[i][j, -1] | |
target_pos_id = sentence_matrixes[i][j, -2] | |
relevant_word_ids = word_matrixes[tree_id][:target_pos_id] # depends on mode | |
relevant_word_ids_trimmed = relevant_word_ids[-max_row_len:] | |
word_seq_len = len(relevant_word_ids_trimmed) | |
if word_seq_len == 0: | |
relevant_word_ids_trimmed = [stub_word_id] | |
word_seq_len = 1 | |
max_hights = max(X_single.shape[0], word_seq_len) | |
X_final = np.zeros((max_hights, sentence_matrixes[i].shape[1])) | |
X_final[:X_single.shape[0], :X_single.shape[1]] = X_single | |
X_final[:word_seq_len, -2] = np.ones((word_seq_len, )) | |
X_final[:word_seq_len, -1] = np.array(relevant_word_ids_trimmed) | |
return X_final, y_single | |
def pad_stack_with_mask(to_stack, mask=True, max_row_len=None): | |
max_n_rows = max_row_len or max([x.shape[0] for x in to_stack]) | |
new_to_stack = [] | |
for m in to_stack: | |
pad_shape = list(m.shape) | |
pad_shape[0] = max_n_rows-pad_shape[0] | |
padded_m = np.concatenate([m, np.zeros(pad_shape)], axis=0) | |
if mask: | |
mask_column = np.zeros((max_n_rows, 1)) | |
mask_column[:m.shape[0]] = 1 | |
padded_m = np.concatenate([mask_column, padded_m], axis=1) | |
new_to_stack.append( padded_m ) | |
return np.stack(new_to_stack) | |
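## Small check of the padder above: matrices are zero-padded to a common
## number of rows and a leading 0/1 mask column marks the real rows:
def pad_stack_sketch():
    stacked = pad_stack_with_mask([np.ones((2, 3)), np.ones((4, 3))])
    assert stacked.shape == (2, 4, 4)  # 2 matrices, 4 rows, mask + 3 feature cols
    assert stacked[0, :, 0].tolist() == [1, 1, 0, 0]
# pad_stack_sketch()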
def feature_mxs_to_dataset_generator(sentence_matrices, word_matrixes, batch_size, | |
max_row_len, data_permutation, use_word_flag, | |
stub_word_id): | |
n_batches = int(len(data_permutation)/batch_size) | |
argv = { | |
'sentence_matrixes': sentence_matrices, | |
'word_matrixes': word_matrixes, | |
'max_row_len': max_row_len, | |
'use_word_flag': use_word_flag, | |
'stub_word_id': stub_word_id, | |
} | |
for ij_sub in np.array_split(data_permutation, n_batches): | |
X_batch, y_batch = zip(*[extract_ij_data_point(i, j, **argv) for i, j in ij_sub]) | |
X_s_tensor = pad_stack_with_mask(X_batch, max_row_len=max_row_len).astype('int32') | |
y_s_matrix = np.array(y_batch, dtype='int32') | |
yield {'X': X_s_tensor, 'y': y_s_matrix} | |
# from tqdm import tqdm_notebook as tq | |
from tqdm import tqdm as tq | |
def read_ptb_trees(dummy=False): | |
from nltk.corpus import ptb | |
data_sections_ranges = [ | |
('train', range(2, 23) if not dummy else [2]), | |
('valid', [24]), | |
('test', [23]) | |
] | |
data_sections_dict = {} | |
for data_name, data_range in tq(data_sections_ranges, desc='ptb_read'): | |
file_names = filter(lambda x: int(x[4:6]) in data_range, ptb.fileids()) | |
trees = [tree for fn in file_names for tree in ptb.parsed_sents(fn)] | |
data_sections_dict[data_name] = trees | |
return data_sections_dict | |
def build_ptb_datatset(batch_size=1000, | |
n_allowed_forward_span=1, n_allowed_duplicates=0, | |
use_word_input=False, dummy_dataset=False): | |
np.random.seed(1) | |
tree_dicts_n = 3 | |
if type(batch_size) is int: | |
batch_size = [batch_size]*tree_dicts_n | |
def compute_tree_flat_branches_dict_and_maps(): | |
def read_trees(): | |
return read_ptb_trees(dummy_dataset) | |
tree_dicts = pickle_load_cached_function(read_trees) | |
tree_flat_branches_dict = {} | |
token_set, label_set, cmd_set = set(), set(), set() | |
for data_name, trees in tq(tree_dicts.items(), desc='branch'): | |
argv_branching = {'n_allowed_forward_span': n_allowed_forward_span, | |
'n_allowed_duplicates':n_allowed_duplicates} | |
branches = main_tree_s_into_branches(trees, mode='tprb', **argv_branching) | |
features_lists = branches_into_feature_lists(branches) | |
flat_feature_lists = list(itertools.chain.from_iterable(features_lists)) | |
tree_flat_branches_dict[data_name] = flat_feature_lists | |
label_set |= set([label for tree_features in flat_feature_lists | |
for label in columns(tree_features, 3)]) | |
token_set |= set([token for tree_features in flat_feature_lists | |
for token in columns(tree_features, 4)]) | |
cmd_set |= set([cmd for tree_features in flat_feature_lists | |
for cmd in columns(tree_features, 7)]) | |
cmd_set.remove('stop') # it should _not_ be in the map; never indexed | |
maps = {'token': build_feature_map(token_set), | |
'label': build_feature_map(label_set), | |
'cmd': build_feature_map(cmd_set)} | |
print('maps sizes: ', {name:len(d['forward']) for name, d in maps.items()}) | |
return tree_flat_branches_dict, maps | |
def compute_sentence_matrixes_and_word_matrix_dicts(): | |
tree_flat_branches_dict, maps = pickle_load_cached_function(compute_tree_flat_branches_dict_and_maps) | |
def build_sentence_matrixes(): | |
sentence_matrixes_dicts = {} | |
for data_name, flat_branches in tq(tree_flat_branches_dict.items(), desc='sent_br_matrix'): | |
argkw = {'features_lists':flat_branches, 'maps': maps, 'use_word_flag': use_word_input} | |
sentence_matrixes_dicts[data_name] = features_lists_to_data_matrixes(**argkw) | |
return sentence_matrixes_dicts | |
sentence_matrixes_dicts = pickle_load_cached_function(build_sentence_matrixes) | |
# tree-word matrixes generation below | |
def tree_word_matrixes(): | |
tree_dicts = pickle_load_cached_function(read_ptb_trees) | |
word_maxtrixes_dicts = {} | |
for data_name, trees in tq(tree_dicts.items(), desc='word_matrix'): | |
token_id_lists = [[maps['token']['forward'][word] | |
for word in words_from_tree(tree) | |
if word in maps['token']['forward']] | |
for tree in trees] | |
word_maxtrixes_dicts[data_name] = token_id_lists | |
return word_maxtrixes_dicts | |
word_maxtrixes_dicts = pickle_load_cached_function(tree_word_matrixes) | |
return sentence_matrixes_dicts, word_maxtrixes_dicts, maps | |
max_row_len = 100 | |
ret = pickle_load_cached_function(compute_sentence_matrixes_and_word_matrix_dicts) | |
sentence_matrixes_dicts, word_maxtrixes_dicts, maps = ret | |
def compute_lens_and_permutations(): | |
data_length = {} | |
data_permutation = {} | |
for i, (data_name, sml) in tq(enumerate(sentence_matrixes_dicts.items()), | |
total=3, desc='permutations'): | |
total_n = sum([len(sml[i]) for i in range(len(sml))]) | |
data_length[data_name] = int(total_n / batch_size[i]) | |
# pairs of learning point: (i_matrix, j_row) -> y_k | |
sent_id_line_id = [(i, j) for i in range(len(sml)) | |
for j in range(1, len(sml[i]))] | |
np.random.seed(1) | |
# np.random.shuffle(sent_id_line_id) | |
data_permutation[data_name] = sent_id_line_id | |
return data_length, data_permutation | |
data_length, data_permutation = pickle_load_cached_function(compute_lens_and_permutations) | |
def build_sharded_dataset_iterator(): | |
dataset = {} | |
for i, data_name in enumerate(sentence_matrixes_dicts.keys()): | |
feature_mxs_args = { | |
'sentence_matrices': sentence_matrixes_dicts[data_name], | |
'word_matrixes': word_maxtrixes_dicts[data_name], | |
'batch_size': batch_size[i], 'max_row_len': max_row_len, | |
'data_permutation': data_permutation[data_name], | |
'use_word_flag': use_word_input, | |
'stub_word_id': maps['token']['forward']['<S>'], | |
} | |
dataset_iterator = feature_mxs_to_dataset_generator(**feature_mxs_args) | |
shard_argkw = {'name': data_name+'_dataset', 'shard_size': 1000} | |
sharded_data_iterator = PickleShardIterator.from_iterator(dataset_iterator, **shard_argkw) | |
dataset[data_name] = sharded_data_iterator | |
return dataset | |
sharded_dataset_iterator = pickle_load_cached_function(build_sharded_dataset_iterator) | |
## dataset is {train: [{X:.., y:..}, {X:.., y:..}, ..], test:..} | |
specs = {'seq_len': max_row_len, | |
'token_voc_size': len(maps['token']['forward']), | |
'label_voc_size': len(maps['label']['forward']), | |
'data_length': data_length, # dict(n_batches) | |
} | |
return sharded_dataset_iterator, specs, maps | |
def build_main_data(): | |
dataset_args = { | |
'batch_size': 1000, | |
'n_allowed_forward_span': 2, | |
'n_allowed_duplicates': 1, | |
'use_word_input': True, | |
'dummy_dataset': False, | |
} | |
sharded_dataset_iterator, specs, maps = build_ptb_datatset(**dataset_args) | |
main_data = {'dataset_iterator': sharded_dataset_iterator, 'specs': specs, 'maps': maps} | |
return main_data | |
main_data = pickle_load_cached_function(build_main_data) | |
## MODEL DEFINITION | |
import tensorflow as tf | |
from math import log10, ceil | |
def last_relevant(output, length): | |
batch_size = tf.shape(output)[0] | |
max_length = int(output.get_shape()[1]) | |
output_size = int(output.get_shape()[2]) | |
index = tf.range(0, batch_size) * max_length + (length - 1) | |
flat = tf.reshape(output, [-1, output_size]) | |
relevant = tf.gather(flat, index) | |
return relevant | |
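## The index arithmetic behind last_relevant, spelled out with numpy: for each
## sequence in the batch, pick the output at its last valid timestep. This is
## a numpy analogue for illustration only; the function above does the same
## thing with tf.gather on the flattened output.
def last_relevant_numpy_sketch(output, length):
    batch_size, max_length, output_size = output.shape
    index = np.arange(batch_size) * max_length + (length - 1)
    return output.reshape(-1, output_size)[index]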
def build_model(config): | |
X = tf.placeholder(tf.int32, [None, config['seq_len'], 12], name='X') | |
y = tf.placeholder(tf.int32, [None, 5], name='y') | |
use_dropout = tf.placeholder_with_default(tf.constant(1.0), [], name='use_dropout') | |
dropout_rate_const = tf.constant(config['dropout_keep_rate']) | |
dropout_keep_rate = dropout_rate_const*use_dropout + (1-use_dropout) | |
# X_mask = X[:, :, 0] # [batch_n, seq_len] # broken | |
X_bools = X[:, :, 1:4] # [batch_n, seq_len, 3] | |
X_cmd = X[:, :, 4:8] # [batch_n, seq_len, 4] | |
X_label = X[:, :, 8] # [batch_n, seq_len] | |
X_token = X[:, :, 9] # [batch_n, seq_len] | |
X_word = X[:, :, 11] | |
y_bool = y[:, 0:3] # [batch_n, 3] | |
y_label = y[:, 3] # [batch_n, ] | |
y_token = y[:, 4] # [batch_n, ] | |
batch_size = tf.shape(X)[0] | |
terminal_mask = tf.cast(y[:, 2], tf.float32) | |
data_lengths = tf.reduce_sum(tf.cast(tf.not_equal(X_token, 0), tf.int32), -1) | |
word_data_len = tf.reduce_sum(X[:, :, 10], -1) | |
# [batch_n, seq_len, him_dim] | |
with tf.variable_scope("embeddings", initializer=tf.contrib.layers.xavier_initializer()): | |
E_labl = tf.get_variable('E_labl', (config['label_voc_size'], config['label_hid_dim'])) | |
X_embedded_labels = tf.nn.embedding_lookup(E_labl, X_label-1) | |
E_tokn = tf.get_variable('E_tokn', (config['token_voc_size'], config['token_hid_dim'])) | |
X_embedded_tokens = tf.nn.embedding_lookup(E_tokn, X_token-1) | |
X_embedded_words = tf.nn.embedding_lookup(E_tokn, X_word-1) | |
X_path_total = tf.concat(2, [tf.cast(X_bools, tf.float32), | |
tf.cast(X_cmd, tf.float32), | |
X_embedded_labels, X_embedded_tokens]) | |
X_word_total = X_embedded_words | |
with tf.variable_scope('path_rnn'): | |
output_path, _ = tf.nn.dynamic_rnn( | |
tf.nn.rnn_cell.DropoutWrapper( | |
tf.nn.rnn_cell.GRUCell(config['hid_dim']), | |
input_keep_prob=dropout_keep_rate, | |
output_keep_prob=dropout_keep_rate | |
), | |
inputs=X_path_total, | |
dtype=tf.float32, | |
sequence_length=data_lengths, | |
) | |
with tf.variable_scope('word_rnn'): | |
output_word, _ = tf.nn.dynamic_rnn( | |
tf.nn.rnn_cell.DropoutWrapper( | |
tf.nn.rnn_cell.GRUCell(config['word_hid_dim']), | |
input_keep_prob=dropout_keep_rate, | |
output_keep_prob=dropout_keep_rate | |
), | |
inputs=X_word_total, | |
dtype=tf.float32, | |
sequence_length=word_data_len, | |
) | |
path_h = tf.tanh(last_relevant(output_path, data_lengths)) # [batch_size, hid_dim] | |
word_h = tf.tanh(last_relevant(output_word, word_data_len)) # [batch_size, word_hid_dim] | |
total_h = tf.concat(1, [path_h, word_h]) # [batch_size, hid_dim + word_hid_dim] | |
## rnn output -> logits | |
fcc = tf.contrib.layers.fully_connected | |
argkw = {'inputs': total_h, 'activation_fn': tf.tanh} | |
repr_bool = fcc(num_outputs=3, scope='repr2bool', **argkw) | |
repr_labl = fcc(num_outputs=config['label_hid_dim'], scope='repr2labl', **argkw) | |
repr_tokn = fcc(num_outputs=config['token_hid_dim'], scope='repr2tokn', **argkw) | |
with tf.variable_scope("embeddings", reuse=True): | |
logits_labl = tf.matmul(repr_labl, tf.transpose(tf.get_variable('E_labl'))) | |
logits_tokn = tf.matmul(repr_tokn, tf.transpose(tf.get_variable('E_tokn'))) | |
## logits -> losses | |
_y = tf.cast(y_bool, tf.float32) | |
cr_ent_bool = tf.nn.sigmoid_cross_entropy_with_logits(repr_bool, _y) | |
loss_bool = tf.reduce_mean(cr_ent_bool) | |
cr_ent_labl = tf.nn.sparse_softmax_cross_entropy_with_logits(logits_labl, y_label-1) | |
loss_labl = tf.reduce_mean( cr_ent_labl ) | |
cr_ent_tokn = tf.nn.sparse_softmax_cross_entropy_with_logits(logits_tokn, y_token-1) | |
loss_tokn = tf.reduce_mean( cr_ent_tokn ) | |
term_mask_normalizer = tf.reduce_sum(terminal_mask) | |
loss_tokn_term = tf.reduce_sum( cr_ent_tokn * terminal_mask ) / term_mask_normalizer | |
## loss | |
loss = loss_bool + loss_labl + loss_tokn_term | |
## errs | |
pred_bool = tf.cast(tf.greater(repr_bool, 0.5), tf.int32) | |
err_bool_mean = tf.reduce_mean(tf.cast(tf.not_equal(pred_bool, y_bool), tf.float32)) | |
pred_labl = tf.cast(tf.argmax(logits_labl, 1), tf.int32) | |
err_labl_mean = tf.reduce_mean(tf.cast(tf.not_equal(pred_labl, y_label-1), tf.float32)) | |
pred_tokn = tf.cast(tf.argmax(logits_tokn, 1), tf.int32) | |
err_tokn = tf.cast(tf.not_equal(pred_tokn, y_token-1), tf.float32) | |
err_tokn_mean = tf.reduce_mean(err_tokn) | |
err_tokn_term = tf.reduce_sum(tf.cast(terminal_mask, tf.float32)*err_tokn) | |
err_term_mean = err_tokn_term / tf.reduce_sum(tf.cast(terminal_mask, tf.float32)) | |
prob_bool = tf.sigmoid(repr_bool) | |
prob_labl = tf.nn.softmax(logits_labl) | |
prob_tokn = tf.nn.softmax(logits_tokn) | |
export_map = { | |
'inputs': ['X', 'y', 'use_dropout'], | |
'outputs': [ | |
'pred_bool', 'pred_labl', 'pred_tokn', | |
'prob_bool', 'prob_labl', 'prob_tokn', | |
], | |
'stat': [ | |
'loss', 'loss_bool', 'loss_tokn_term', | |
'err_bool_mean', 'err_labl_mean', | |
'err_tokn_mean', 'err_term_mean', | |
], | |
'other': [ | |
'cr_ent_tokn', 'terminal_mask', | |
'logits_labl' | |
] | |
} | |
return record.from_local('Model', locals(), export_map) | |
### SAVE ROUTINE | |
from contextlib import contextmanager | |
def recursive_dict_update(d, u, depth=-1): | |
""" | |
Recursively merge or update dict-like objects. | |
>>> recursive_dict_update({'k1': {'k2': 2}}, {'k1': {'k2': {'k3': 3}}, 'k4': 4})
{'k1': {'k2': {'k3': 3}}, 'k4': 4} | |
""" | |
for k, v in u.items(): | |
if isinstance(v, type({})) and not depth == 0: | |
r = recursive_dict_update(d.get(k, {}), v, depth=max(depth - 1, -1)) | |
d[k] = r | |
elif isinstance(d, type({})): | |
d[k] = u[k] | |
else: | |
d = {k: u[k]} | |
return d | |
@contextmanager | |
def run_loaded_model(build_model_func, model_name, *, prefix='', full_path_override=None, config_override=None, verbose=True): | |
dir_name = full_path_override or (PRIMARY_DATA_DIR+"/tf_dumps/models/"+NB_NAME) | |
full_model_fn = "%s/%s_%s_model.ckpt" % (dir_name, model_name, prefix) | |
full_config_fn = "%s/%s_%s_config.pickle" % (dir_name, model_name, prefix) | |
if verbose: | |
print('loading model from ', full_model_fn) | |
if os.path.isfile(full_config_fn): | |
with open(full_config_fn, 'rb') as handle: | |
config = pickle.load(handle) | |
else: | |
raise IOError("Can't find config file at ", full_config_fn) | |
if config_override is not None: | |
config = recursive_dict_update(config, config_override) | |
graph = tf.Graph() | |
with graph.as_default(), tf.device('/cpu:0'): | |
gd = tf.train.AdamOptimizer(config['model']['learning_rate']) | |
with tf.device(config['exec']['device_id']): | |
seed = config['exec']['seed'] if 'seed' in config['exec'] else 1 | |
tf.set_random_seed(seed) | |
model = build_model_func({**config['model'], **config['data']['specs']}) | |
train_op = gd.minimize(model.stat.loss) | |
# tf.reset_default_graph() | |
# with tf.device(config['exec']['device_id']): | |
# seed = config['exec']['seed'] if 'seed' in config['exec'] else 1 | |
# tf.set_random_seed(seed) | |
# model = build_model_func({**config['model'], **config['data']['specs']}) | |
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7, allow_growth=True) | |
sess_cfg = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options) | |
with graph.as_default(), tf.Session(config=sess_cfg) as s: | |
tf.train.Saver().restore(s, full_model_fn) | |
yield s, model, config | |
def recursive_dict_make_pickleable(d, replace_by=None): | |
## dummy here! only replaces specific 'bad' field | |
d_copy = d.copy() | |
d_copy['eval']['score_postprocs'] = replace_by | |
return d_copy | |
def save_model_full(session, config, prefix=''): | |
model_name = config['exec']['model_name'] | |
dir_name = PRIMARY_DATA_DIR+"/tf_dumps/models/" + NB_NAME | |
full_model_fn = "%s/%s_%s_model.ckpt" % (dir_name, model_name, prefix) | |
if not os.path.exists(dir_name): | |
os.makedirs(dir_name) | |
tf.train.Saver().save(session, full_model_fn) | |
config_picklable = recursive_dict_make_pickleable(config) | |
with open("%s/%s_%s_config.pickle" % (dir_name, model_name, prefix), 'wb') as f: | |
pickle.dump(config_picklable, f) | |
## PRINT AND SCORE ROUTINE | |
class TablePrinter: | |
def __init__(self, print_every_iter, print_every_epoch, column_width=13, n_epoch=100, print_during=None): | |
self.print_every_iter = print_every_iter | |
self.print_every_epoch = print_every_epoch | |
self.column_width = column_width | |
self.fmt = "%{0}.{1}g".format(self.column_width, self.column_width-6) | |
epoch_digits = ceil(log10(n_epoch)) | |
self.epoch_fmt = '%'+str(epoch_digits)+'s %5s' | |
self.print_during = print_during or ['train'] | |
def float_list_fmt(self, float_list): | |
pattern = self.fmt*len(float_list) | |
return pattern % tuple(float(x) for x in float_list) | |
def headers(self, field_names): | |
if self.print_every_epoch == 0: | |
return '' | |
short_field_names = [fname[:self.column_width-2] for fname in field_names] | |
field_name_fmt = ('%{}s'.format(self.column_width))*len(field_names) | |
pstr = (self.epoch_fmt % ('', 'data')) + field_name_fmt%tuple(short_field_names) | |
return pstr | |
def iter_condition(self, epoch_i, iter_i, data_name): | |
return (self.print_every_iter > 0 and iter_i % self.print_every_iter == 0 | |
and data_name in self.print_during) | |
def iter_print(self, epoch_i, mean_scores, data_name): | |
print(self.epoch_fmt % (str(epoch_i), data_name), end='') | |
print(self.float_list_fmt(mean_scores)) | |
def epoch_print(self, epoch_i, mean_scores, data_name, new_best_flag): | |
if self.print_every_epoch > 0 and epoch_i % self.print_every_epoch == 0: | |
print(self.epoch_fmt % (str(epoch_i), data_name), end='') | |
print(self.float_list_fmt(mean_scores), end='') | |
if data_name == 'valid' and new_best_flag: | |
print(' # <-!', end='') | |
print() | |
class ScoreProcessingHandler: | |
def __init__(self, model, train_op, config): | |
self.config = config | |
step_funcs = [getattr(model.stat, fn) for fn in config['eval']['step_funcs']] | |
score_funcs = [getattr(model.stat, fn) for fn in config['eval']['score_funcs']] | |
self.data_split_funcs = [ | |
('train', [train_op] + step_funcs + score_funcs), | |
('valid', score_funcs), | |
('test', score_funcs), | |
] | |
postproc_names = list(zip(*config['eval']['score_postprocs']))[0] | |
self.score_postprocessors = list(zip(*config['eval']['score_postprocs']))[1] | |
self.validation_id = self.get_validation_function_id(postproc_names) | |
self.output_field_names = config['eval']['score_funcs'] + list(postproc_names) | |
def get_validation_function_id(self, postproc_names): | |
validation_field_name = self.config['eval']['validation_field_name'] | |
if validation_field_name in self.config['eval']['score_funcs']: | |
validation_id = self.config['eval']['score_funcs'].index(validation_field_name) | |
elif validation_field_name in postproc_names: | |
validation_id = postproc_names.index(validation_field_name) | |
validation_id += len(self.config['eval']['score_funcs']) | |
else: | |
raise KeyError("Could not find %s in fields" % validation_field_name) | |
return validation_id | |
def postprocess_scores(self, mean_scores):
validation_score = mean_scores[self.validation_id]
return mean_scores, validation_score
def flush_scores(self): | |
self.iter_score_arr = [] | |
def update_scores(self, run_result): | |
n_numbers_useful = len(self.config['eval']['score_funcs']) | |
scores = run_result[-n_numbers_useful:] | |
self.iter_score_arr.append(scores) | |
def get_mean_scores(self, use_window=False): | |
print_every_iter = self.config['exec']['print_every_iter']
score_mean_window = self.config['exec'].get('score_mean_window', print_every_iter)
score_arr_windowed = (self.iter_score_arr[-score_mean_window:] | |
if use_window else self.iter_score_arr) | |
mean_scores = np.mean(score_arr_windowed, axis=0).tolist() | |
for i, postproc in enumerate(self.score_postprocessors): | |
post_num = postproc(score_arr_windowed) | |
if not np.isscalar(post_num): | |
print("Postprocessors must return scalars") | |
mean_scores.append(post_num) | |
return mean_scores | |
def get_validation_value(self, mean_scores): | |
return mean_scores[self.validation_id] | |
## GRAPH | |
def multigpu_batch_composer(tower_input_records): | |
def tmp(batch_iterator, other_vars=None): | |
while True: | |
input_dict = {} | |
for input_placeholder_record in tower_input_records: | |
data_batch = next(batch_iterator) | |
if other_vars is not None: | |
data_batch.update(other_vars) | |
for data_part_name, data_part_val in data_batch.items(): | |
var = input_placeholder_record.get(data_part_name) | |
input_dict[var] = data_part_val | |
yield input_dict | |
return tmp | |
def average_gradients(tower_grads): | |
average_grads = [] | |
for grad_and_vars in zip(*tower_grads): | |
grads = [] | |
for g, _ in grad_and_vars: | |
expanded_g = tf.expand_dims(g, 0, name='grad_avg_expand_dims') | |
grads.append(expanded_g) | |
grad = tf.concat(0, grads) | |
grad = tf.reduce_mean(grad, 0, name='average_gradients') | |
v = grad_and_vars[0][1] | |
grad_and_var = (grad, v) | |
average_grads.append(grad_and_var) | |
return average_grads | |
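## Shape of the data flowing through average_gradients: tower_grads is a list
## with one entry per tower, each a list of (gradient, variable) pairs in the
## same variable order, e.g.
##   [[(g_v0_tower0, v0), (g_v1_tower0, v1), ...],
##    [(g_v0_tower1, v0), (g_v1_tower1, v1), ...]]
## zip(*tower_grads) regroups the pairs per variable, and each variable's
## gradients are stacked and averaged over towers before being applied once.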
from contextlib import ExitStack | |
@contextmanager | |
def with_all(context_managers): | |
with ExitStack() as stack: | |
yield [stack.enter_context(context) for context in context_managers] | |
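## with_all enters a whole list of context managers at once; a toy check
## (illustrative only) with in-memory streams:
def with_all_sketch():
    import io
    with with_all([io.StringIO('a'), io.StringIO('b')]) as handles:
        assert [h.read() for h in handles] == ['a', 'b']
# with_all_sketch()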
def build_model_graph(build_model_func, config): | |
arg_scope = tf.contrib.framework.arg_scope | |
create_var_op = tf.contrib.framework.python.ops.variables.variable | |
build_config = {**config['model'], **config['data']['specs']} | |
graph = tf.Graph() | |
with graph.as_default(), tf.device('/cpu:0'): | |
gd = tf.train.AdamOptimizer(config['model']['learning_rate']) | |
if type(config['exec']['device_id']) is str: | |
with tf.device(config['exec']['device_id']): | |
seed = config['exec']['seed'] if 'seed' in config['exec'] else 1 | |
tf.set_random_seed(seed) | |
model = build_model_func(build_config) | |
train_op = gd.minimize(model.stat.loss) | |
batch_composer_init = multigpu_batch_composer([model.inputs]) | |
return model, train_op, batch_composer_init, graph | |
elif type(config['exec']['device_id']) is list: | |
tower_models = [] | |
tower_grads = [] | |
for tower_id, device_name in enumerate(config['exec']['device_id']): | |
context_managers = [ | |
tf.device(device_name), | |
tf.name_scope('tower_%d' % tower_id), | |
arg_scope([create_var_op], device='/cpu:0') | |
] | |
with with_all(context_managers): | |
tower_models.append(build_model_func(build_config)) | |
tower_grads.append(gd.compute_gradients(tower_models[-1].stat.loss)) # on GPU! | |
tf.get_variable_scope().reuse_variables() | |
tower_input_records = [model.inputs for model in tower_models] | |
batch_composer_init = multigpu_batch_composer(tower_input_records) | |
train_op = gd.apply_gradients(average_gradients(tower_grads)) | |
stat_keys = tower_models[0].stat.keys() | |
stat_vals = [[model.stat.get(stat_key) for stat_key in model.stat.keys()] | |
for model in tower_models] | |
stat_dict = {stat_name:tf.reduce_mean(stat_list) | |
for stat_name, stat_list in zip(stat_keys, zip(*stat_vals))} | |
avg_stat_record = record('stat', stat_dict) | |
new_model = record('MultiGPUModel', { | |
'inputs': None, | |
'outputs': None, | |
'stat': avg_stat_record, | |
}) | |
return new_model, train_op, batch_composer_init, graph | |
else: | |
raise ValueError("device_id must be either str() or list() of str()") | |
## EXECUTION | |
def run_model_learning(build_model_func, main_data, config): | |
model, train_op, batch_composer_init, graph = build_model_graph(build_model_func, config) | |
proc_handler = ScoreProcessingHandler(model, train_op, config) | |
request_early_stop = False | |
best_val_score = np.inf | |
best_val_test_score = np.inf | |
best_epoch_i = None | |
trainable_n = None | |
fmtr = TablePrinter(config['exec']['print_every_iter'], config['exec']['print_every_epoch']) | |
print(fmtr.headers(proc_handler.output_field_names)) | |
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7, allow_growth=True) | |
sess_cfg = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options) | |
with graph.as_default(), tf.Session(config=sess_cfg) as s: | |
s.run(tf.initialize_all_variables()) | |
trainable_n = sum([np.prod(s.run(tf.shape(x))) for x in tf.trainable_variables()]) | |
epoch_rq_args = {'desc': 'epoch', 'leave': False} | |
epoch_iterator = tq(range(config['exec']['n_epoch']), **epoch_rq_args) | |
for epoch_i in epoch_iterator: | |
for data_name, funcs_to_execute in proc_handler.data_split_funcs: | |
proc_handler.flush_scores() | |
other_input_vars = {'use_dropout': int(data_name == 'train')} | |
raw_data_iterator = iter(main_data['dataset_iterator'][data_name]) | |
batch_composer_iter = batch_composer_init(raw_data_iterator, other_input_vars) | |
n_batches = config['data']['specs']['data_length'][data_name] | |
tq_args = {'leave': False, 'total': n_batches, 'desc': '%s-%d'%(data_name, epoch_i)} | |
for iter_i, composed_data_batch in tq(enumerate(batch_composer_iter), **tq_args): | |
ret = s.run(funcs_to_execute, composed_data_batch) | |
proc_handler.update_scores(ret) | |
if fmtr.iter_condition(epoch_i, iter_i, data_name): | |
mean_scores = proc_handler.get_mean_scores(use_window=True) | |
fmtr.iter_print(epoch_i, mean_scores, data_name) | |
## end of any ("train\test") epoch, scoring | |
mean_scores = proc_handler.get_mean_scores() | |
new_validation_score = proc_handler.get_validation_value(mean_scores) | |
## update best validation score | |
new_best_flag = False | |
if data_name == 'valid': | |
if new_validation_score < best_val_score: | |
best_val_score = new_validation_score | |
best_epoch_i = epoch_i | |
new_best_flag = True | |
if config['exec']['save_model']: | |
save_model_full(s, config) | |
elif epoch_i - best_epoch_i > config['exec']['early_stop_tolerance']: | |
request_early_stop = True | |
elif data_name == 'test' and best_epoch_i == epoch_i: | |
## test + last validation was successful
best_val_test_score = new_validation_score | |
epoch_iterator.set_description('%9g '%best_val_test_score) | |
fmtr.epoch_print(epoch_i, mean_scores, data_name, new_best_flag) | |
## "main" epoch for loop | |
if request_early_stop: | |
print("Early stopping") | |
break | |
print('best validation / test scores were', best_val_score, best_val_test_score)
return locals() | |
def compute_ppx_postproc(index): | |
## post-processor that adds perplexity (computed from the loss column at position index) to the scores
def tmp(score_arr_windowed): | |
losses = np.array(score_arr_windowed)[:, index] | |
ppx = np.exp(np.mean(losses)) | |
return ppx | |
return tmp | |
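## Quick numeric check of the perplexity post-processor: it exponentiates the
## mean of the selected loss column (column 1 here, matching its use below):
ppx_sketch_fn = compute_ppx_postproc(index=1)
assert np.isclose(ppx_sketch_fn([[0.0, np.log(8.0)], [0.0, np.log(2.0)]]), 4.0)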
def first_n(it, n): | |
for i in range(n): | |
yield next(it) | |
def test_dummy_data():
batches_used_number = 10 | |
dummy_data_iterator = {data_name:first_n(iter(main_data['dataset_iterator'][data_name]), batches_used_number) | |
for data_name in main_data['dataset_iterator'].keys()} | |
dummy_main_data = main_data.copy() | |
dummy_main_data['dataset_iterator'] = dummy_data_iterator | |
config = { | |
'model': { | |
'token_hid_dim': 300, | |
'label_hid_dim': 150, | |
'hid_dim': 75, | |
'word_hid_dim': 75, | |
'learning_rate': 0.001, | |
'dropout_keep_rate': 0.8 | |
}, | |
'eval': { | |
'step_funcs': [], | |
'score_funcs': ['loss', 'loss_tokn_term', | |
'err_bool_mean', 'err_labl_mean', | |
'err_tokn_mean', 'err_term_mean'], | |
'score_postprocs': [ | |
('ppx_tokn_term', compute_ppx_postproc(index=1)), | |
], | |
'validation_field_name': 'ppx_tokn_term' | |
}, | |
'exec': { | |
'model_name': 'dummy', | |
'save_model': True, | |
'n_epoch': 1, | |
'early_stop_tolerance': 5, | |
'print_every_iter': 1, | |
'print_every_epoch': 1, | |
# 'device_id': ['/gpu:0', '/gpu:1'], | |
'device_id': '/gpu:0', | |
# 'device_id': ['/gpu:0'], | |
}, | |
'data': { | |
'specs': main_data['specs'], | |
'maps': main_data['maps'] | |
} | |
} | |
ret = run_model_learning(build_model, dummy_main_data, config) | |
# test_dummy_data()
def test_model_load(): | |
print('tadam') | |
with run_loaded_model(build_model, 'dummy') as (s, model, config): | |
batch = next(iter(main_data['dataset_iterator']['test'])) | |
tf_run_args = { | |
model.inputs.X: batch['X'], | |
model.inputs.use_dropout: 0, | |
} | |
print('before running') | |
print('run 1: logits') | |
logits = s.run(model.other.logits_labl, tf_run_args) | |
print(logits.shape) | |
print(np.max(logits), np.min(logits)) | |
print('run 2: probs in np') | |
probs = s.run(tf.nn.softmax(logits)) | |
print(probs.shape) | |
print(np.sum(probs, axis=1)) | |
print('run 3: probs on cpu') | |
with tf.device('/cpu:0'): | |
t = tf.nn.softmax(model.other.logits_labl) | |
print(s.run(t, tf_run_args)) | |
print('run 4: probs on gpu') | |
with tf.device('/gpu:0'): | |
t2 = tf.nn.softmax(model.other.logits_labl) | |
print(s.run(t2, tf_run_args)) | |
test_model_load() |