#!/usr/bin/env python

import os
import logging

import mxnet as mx
import numpy as np

logging.getLogger().setLevel(logging.INFO)
############################################
# Parameter sharing between BucketingModules
############################################

# Partly derived from:
# https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/lstm_bucketing.py
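
# Two BucketingModules are built below with differently prefixed
# parameters ('net1_*' vs. 'net2_*'). Selected parameters are then kept
# in sync by copying their values from one module to the other with
# get_params()/set_params(), using an explicit name map.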
def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    if not os.path.isfile(fname):
        raise IOError("Please use get_ptb_data.sh to download required file (data/ptb.train.txt)")
    with open(fname) as f:
        lines = f.readlines()
    # drop empty tokens; wrap filter() in list() so this also works on Python 3
    lines = [list(filter(None, line.split(' '))) for line in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label,
                                               start_label=start_label)
    return sentences, vocab
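
# bucket sizes: each sentence is assigned to the smallest bucket that
# fits it, and the symbol is unrolled to that length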
buckets = [10, 20, 30, 40, 50, 60]

start_label = 1
invalid_label = 0

num_hidden = 16
batch_size = 32
num_embed = 16

contexts = mx.cpu(0)
# Emulate several sources of data and sequence iterators.
train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label, invalid_label=invalid_label)

train_iter1 = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, invalid_label=invalid_label,
                                        data_name='net1_data', label_name='net1_softmax_label')
train_iter2 = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, invalid_label=invalid_label,
                                        data_name='net2_data', label_name='net2_softmax_label')
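# Both iterators read the same PTB sentences; they differ only in the
# data/label names, which must match the input variables of the network
# they feed.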
# build first module
stack1 = mx.rnn.SequentialRNNCell()
for i in range(2):
    stack1.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='net1_lstm_l%d_' % i))

def sym_gen1(seq_len):
    data = mx.sym.Variable('net1_data')
    label = mx.sym.Variable('net1_softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                             output_dim=num_embed, name='net1_embed')

    stack1.reset()
    outputs, states = stack1.unroll(seq_len, inputs=embed, merge_outputs=True)

    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='net1_pred')

    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='net1_softmax')

    return pred, ('net1_data',), ('net1_softmax_label',)
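# (BucketingModule expects sym_gen to return the symbol together with
# the data names and label names for the requested bucket, hence the
# 3-tuple.)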
# build second module
stack2 = mx.rnn.SequentialRNNCell()
for i in range(1):
    stack2.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='net2_lstm_l%d_' % i))

def sym_gen2(seq_len):
    data = mx.sym.Variable('net2_data')
    label = mx.sym.Variable('net2_softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                             output_dim=num_embed, name='net2_embed')

    stack2.reset()
    outputs, states = stack2.unroll(seq_len, inputs=embed, merge_outputs=True)

    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='net2_pred')

    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='net2_softmax')

    return pred, ('net2_data',), ('net2_softmax_label',)
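# Note: net2 has a single LSTM layer while net1 has two. The modules do
# not need identical architectures, since only explicitly mapped
# parameters are copied between them.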
mod1 = mx.mod.BucketingModule(sym_gen=sym_gen1,
                              default_bucket_key=train_iter1.default_bucket_key,
                              context=contexts)
mod2 = mx.mod.BucketingModule(sym_gen=sym_gen2,
                              default_bucket_key=train_iter2.default_bucket_key,
                              context=contexts)

mod1.bind(data_shapes=train_iter1.provide_data,
          label_shapes=train_iter1.provide_label,
          for_training=True,
          force_rebind=True,
          grad_req="write")
mod1.init_params(initializer=mx.init.Uniform(scale=.1), force_init=True)

mod2.bind(data_shapes=train_iter2.provide_data,
          label_shapes=train_iter2.provide_label,
          for_training=True,
          force_rebind=True,
          grad_req="write")
mod2.init_params(initializer=mx.init.Uniform(scale=.1))
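# Both modules are initialized independently, so at this point even the
# parameters we intend to share hold different random values (see the
# "BEFORE TRAINING" output below).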
def print_arg_names(mod):
    arg_params, aux_params = mod.get_params()
    print("arg_param names: ", list(arg_params.keys()))
    print("aux_param names: ", list(aux_params.keys()))

print_arg_names(mod1)
print_arg_names(mod2)
def sync_params(copy_from, copy_to, name_map, verbose=False):
    """
    Copies parameter values from one module to another.

    copy_from: mx.mod.Module
    copy_to: mx.mod.Module
    name_map: dict with entries "name in copy_from": "name in copy_to"
    """
    from_arg_params, from_aux_params = copy_from.get_params()
    to_arg_params, to_aux_params = {}, {}

    for from_name, from_array in from_arg_params.items():
        try:
            to_name = name_map[from_name]
        except KeyError:
            # parameters that are not in the map are not copied
            continue
        if verbose:
            print("[sync_params] Setting '%s' to the value of '%s'" % (to_name, from_name))
        to_arg_params[to_name] = from_array

    # same for aux params
    for from_name, from_array in from_aux_params.items():
        try:
            to_name = name_map[from_name]
        except KeyError:
            continue
        to_aux_params[to_name] = from_array

    copy_to.set_params(arg_params=to_arg_params,
                       aux_params=to_aux_params,
                       allow_missing=True,
                       allow_extra=False)
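# In sync_params, allow_missing=True is required because only the mapped
# subset of the target module's parameters is supplied; all others keep
# their current values. allow_extra=False makes set_params raise an
# error if the name map produces a name the target module does not have.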
def compare_params(mod1, mod2, param1, param2, when=None):
    """
    Compares the values of two parameters, potentially from
    different modules.

    mod1: mx.mod.Module
    mod2: mx.mod.Module
    param1: str
    param2: str
    """
    if when:
        print(when)
    array1 = mod1.get_params()[0][param1].asnumpy()
    array2 = mod2.get_params()[0][param2].asnumpy()

    if np.array_equal(array1, array2):
        print("'%s' and '%s' are identical." % (param1, param2))
    else:
        print("'%s' and '%s' are different." % (param1, param2))
# optimizer
mod1.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
mod2.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
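# Note that sync_params copies parameter values only; each module keeps
# its own optimizer state.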
# parameters to keep in sync between mod1 and mod2
name_map = {'net1_embed_weight': 'net2_embed_weight',
            'net1_lstm_l0_h2h_weight': 'net2_lstm_l0_h2h_weight'}
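# Only the embedding weights and the first layer's hidden-to-hidden
# weights are shared; all other parameters remain independent.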
# compare before training
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="BEFORE TRAINING:")

# train mod1 with one batch
batch1 = next(train_iter1)
mod1.forward_backward(batch1)
mod1.update()

# compare before syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH1, BEFORE SYNC:")

sync_params(mod1, mod2, name_map)

# compare after syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH1, AFTER SYNC:")

# train mod2 with one batch
batch2 = next(train_iter2)
mod2.forward_backward(batch2)
mod2.update()

# compare before syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH2, BEFORE SYNC:")

# reverse sync: this time, copy from mod2 back to mod1
name_map_reverse = {v: k for k, v in name_map.items()}
sync_params(mod2, mod1, name_map_reverse)

# compare after syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH2, AFTER SYNC:")
# Output:
#
# arg_param names:  ['net1_lstm_l0_i2h_bias', 'net1_pred_bias', 'net1_pred_weight', 'net1_lstm_l0_h2h_weight', 'net1_lstm_l1_h2h_bias', 'net1_lstm_l0_h2h_bias', 'net1_lstm_l1_h2h_weight', 'net1_embed_weight', 'net1_lstm_l1_i2h_bias', 'net1_lstm_l0_i2h_weight', 'net1_lstm_l1_i2h_weight']
# aux_param names:  []
# arg_param names:  ['net2_lstm_l0_h2h_bias', 'net2_pred_weight', 'net2_pred_bias', 'net2_lstm_l0_h2h_weight', 'net2_lstm_l0_i2h_bias', 'net2_lstm_l0_i2h_weight', 'net2_embed_weight']
# aux_param names:  []
# BEFORE TRAINING:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# BEFORE TRAINING:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH1, BEFORE SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# AFTER BATCH1, BEFORE SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH1, AFTER SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are identical.
# AFTER BATCH1, AFTER SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are identical.
# AFTER BATCH2, BEFORE SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# AFTER BATCH2, BEFORE SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH2, AFTER SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are identical.
# AFTER BATCH2, AFTER SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are identical.