#!/usr/bin/env python
import os
import logging
import mxnet as mx
import numpy as np
logging.getLogger().setLevel(logging.INFO)
############################################
# Parameter sharing between BucketingModules
############################################
# Partly derived from:
# https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/lstm_bucketing.py
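#
# Idea: build two BucketingModules whose parameters carry distinct name
# prefixes ('net1_*' and 'net2_*'), train each module on its own batches,
# and keep a chosen subset of weights tied by explicitly copying values
# from one module to the other with get_params()/set_params().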
def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    if not os.path.isfile(fname):
        raise IOError("Please use get_ptb_data.sh to download required file (data/ptb.train.txt)")
    lines = open(fname).readlines()
    lines = [filter(None, i.split(' ')) for i in lines]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab, invalid_label=invalid_label,
                                               start_label=start_label)
    return sentences, vocab
buckets = [10, 20, 30, 40, 50, 60]
start_label = 1
invalid_label = 0
num_hidden = 16
batch_size = 32
num_embed = 16
contexts = mx.cpu(0)
# Emulate several sources of data and sequence iterators.
train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label, invalid_label=invalid_label)
train_iter1 = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, invalid_label=invalid_label, data_name='net1_data', label_name='net1_softmax_label')
train_iter2 = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets, invalid_label=invalid_label, data_name='net2_data', label_name='net2_softmax_label')
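# Both iterators read the same sentences, but each one exposes data and label
# names that match the variable names of its own network ('net1_*' vs. 'net2_*').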
# build first module
stack1 = mx.rnn.SequentialRNNCell()
for i in range(2):
    stack1.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='net1_lstm_l%d_' % i))
def sym_gen1(seq_len):
    data = mx.sym.Variable('net1_data')
    label = mx.sym.Variable('net1_softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                             output_dim=num_embed, name='net1_embed')
    stack1.reset()
    outputs, states = stack1.unroll(seq_len, inputs=embed, merge_outputs=True)
    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='net1_pred')
    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='net1_softmax')
    return pred, ('net1_data',), ('net1_softmax_label',)
# build second module
stack2 = mx.rnn.SequentialRNNCell()
for i in range(1):
    stack2.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='net2_lstm_l%d_' % i))
def sym_gen2(seq_len):
    data = mx.sym.Variable('net2_data')
    label = mx.sym.Variable('net2_softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                             output_dim=num_embed, name='net2_embed')
    stack2.reset()
    outputs, states = stack2.unroll(seq_len, inputs=embed, merge_outputs=True)
    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='net2_pred')
    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='net2_softmax')
    return pred, ('net2_data',), ('net2_softmax_label',)
mod1 = mx.mod.BucketingModule(sym_gen=sym_gen1,
                              default_bucket_key=train_iter1.default_bucket_key,
                              context=contexts)
mod2 = mx.mod.BucketingModule(sym_gen=sym_gen2,
                              default_bucket_key=train_iter2.default_bucket_key,
                              context=contexts)
mod1.bind(data_shapes=train_iter1.provide_data,
          label_shapes=train_iter1.provide_label,
          for_training=True,
          force_rebind=True,
          grad_req="write")
mod1.init_params(initializer=mx.init.Uniform(scale=.1), force_init=True)
mod2.bind(data_shapes=train_iter2.provide_data,
          label_shapes=train_iter2.provide_label,
          for_training=True,
          force_rebind=True,
          grad_req="write")
mod2.init_params(initializer=mx.init.Uniform(scale=.1))
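# Each module is initialised independently from Uniform(0.1), so at this point
# even the parameters we intend to share hold different values (see the
# "BEFORE TRAINING" lines in the output below).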
def print_arg_names(mod):
    arg_params, aux_params = mod.get_params()
    print "arg_param names: ", arg_params.keys()
    print "aux_param names: ", aux_params.keys()

print_arg_names(mod1)
print_arg_names(mod2)
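# The printed names (see the output at the bottom) are the per-module,
# prefix-qualified parameter names; they are exactly the keys and values
# needed to build the name_map used for syncing further down.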
def sync_params(copy_from, copy_to, name_map, verbose=False):
    """
    copy_from: mx.mod.Module
    copy_to: mx.mod.Module
    name_map: dict with entries "name in copy_from": "name in copy_to"
    """
    from_arg_params, from_aux_params = copy_from.get_params()
    to_arg_params, to_aux_params = {}, {}
    for from_name, from_array in from_arg_params.items():
        try:
            to_name = name_map[from_name]
        except KeyError:
            continue
        if verbose:
            print "[sync_params] Setting '%s' to the value of '%s'" % (to_name, from_name)
        to_arg_params[to_name] = from_array
    # same for aux params
    for from_name, from_array in from_aux_params.items():
        try:
            to_name = name_map[from_name]
        except KeyError:
            continue
        to_aux_params[to_name] = from_array
    copy_to.set_params(arg_params=to_arg_params,
                       aux_params=to_aux_params,
                       allow_missing=True,
                       allow_extra=False)
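# Note on set_params: with allow_missing=True, parameters of copy_to that do
# not appear in the copied dicts should simply keep their current values;
# with allow_extra=False, a mapped name that does not exist in copy_to should
# raise an error instead of being silently dropped.
#
# Minimal usage sketch (names taken from the modules defined above): to copy
# only the embedding weights from mod1 to mod2 one could call
#   sync_params(mod1, mod2, {'net1_embed_weight': 'net2_embed_weight'})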
def compare_params(mod1, mod2, param1, param2, when=None):
    """
    Compares the values of two parameters, potentially from
    different modules.
    mod1: mx.mod.Module
    mod2: mx.mod.Module
    param1: str
    param2: str
    """
    if when:
        print when
    array1 = mod1.get_params()[0][param1].asnumpy()
    array2 = mod2.get_params()[0][param2].asnumpy()
    if np.array_equal(array1, array2):
        print "'%s' and '%s' are identical." % (param1, param2)
    else:
        print "'%s' and '%s' are different." % (param1, param2)
# optimizer
mod1.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
mod2.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# sync params with mod2
name_map = {'net1_embed_weight': 'net2_embed_weight',
            'net1_lstm_l0_h2h_weight': 'net2_lstm_l0_h2h_weight'}
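# Only the embedding weights and the first LSTM layer's h2h weights are tied
# here; all other parameters (biases, i2h weights, the output layers, and
# net1's second LSTM layer, which has no counterpart in the single-layer
# stack2) remain specific to their own module.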
# compare before training
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="BEFORE TRAINING:")
# train mod1 with one batch
batch1 = next(train_iter1)
mod1.forward_backward(batch1)
mod1.update()
# compare before syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH1, BEFORE SYNC:")
sync_params(mod1, mod2, name_map)
# compare after syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH1, AFTER SYNC:")
# train mod2 with one batch
batch2 = next(train_iter2)
mod2.forward_backward(batch2)
mod2.update()
# compare before syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH2, BEFORE SYNC:")
# reverse sync
name_map_reverse = {v: k for k, v in name_map.iteritems()}
sync_params(mod2, mod1, name_map_reverse)
# compare after syncing
for name1, name2 in name_map.items():
    compare_params(mod1, mod2, name1, name2, when="AFTER BATCH2, AFTER SYNC:")
# Output:
# arg_param names: ['net1_lstm_l0_i2h_bias', 'net1_pred_bias', 'net1_pred_weight', 'net1_lstm_l0_h2h_weight', 'net1_lstm_l1_h2h_bias', 'net1_lstm_l0_h2h_bias', 'net1_lstm_l1_h2h_weight', 'net1_embed_weight', 'net1_lstm_l1_i2h_bias', 'net1_lstm_l0_i2h_weight', 'net1_lstm_l1_i2h_weight']
# aux_param names: []
# arg_param names: ['net2_lstm_l0_h2h_bias', 'net2_pred_weight', 'net2_pred_bias', 'net2_lstm_l0_h2h_weight', 'net2_lstm_l0_i2h_bias', 'net2_lstm_l0_i2h_weight', 'net2_embed_weight']
# aux_param names: []
# BEFORE TRAINING:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# BEFORE TRAINING:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH1, BEFORE SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# AFTER BATCH1, BEFORE SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH1, AFTER SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are identical.
# AFTER BATCH1, AFTER SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are identical.
# AFTER BATCH2, BEFORE SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are different.
# AFTER BATCH2, BEFORE SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are different.
# AFTER BATCH2, AFTER SYNC:
# 'net1_embed_weight' and 'net2_embed_weight' are identical.
# AFTER BATCH2, AFTER SYNC:
# 'net1_lstm_l0_h2h_weight' and 'net2_lstm_l0_h2h_weight' are identical.