Skip to content

Instantly share code, notes, and snippets.

@bricksdont
Last active May 17, 2018 14:24
Show Gist options
  • Save bricksdont/20321aa85cfdbb8b20407e80273e8f19 to your computer and use it in GitHub Desktop.
Sharing params between mxnet modules
#! /bin/env python
# Demonstrates sharing parameters between MXNet modules via `shared_module`.
import os
import logging
import mxnet as mx
import numpy as np
# Emit INFO-level messages from MXNet (e.g. initialization notices) on the root logger.
logging.getLogger().setLevel(logging.INFO)
#######################################
# Basic example with usual Modules
#######################################
mx.random.seed(1234)
# Letter-recognition dataset: each CSV row is "LETTER,<16 numeric features>".
fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data')
# Feature matrix: all columns except the leading letter column.
data = np.genfromtxt(fname, delimiter=',')[:, 1:]
# Labels: map 'A'..'Z' to 0..25.
# BUG FIX: the original used `open(fname, 'r')` inline and never closed the
# file handle; a `with` block closes it deterministically.
with open(fname, 'r') as f:
    label = np.array([ord(line.split(',')[0]) - ord('A') for line in f])
# net1: three-layer master network. Its fc-layer names ('net1_fc1',
# 'net1_fc2') are reused by net2/net3 below so their parameters can be
# shared through `shared_module`.
net1 = mx.sym.Variable('net1_data')
net1 = mx.sym.FullyConnected(net1, name='net1_fc1', num_hidden=64)
net1 = mx.sym.Activation(net1, name='relu1', act_type="relu")
net1 = mx.sym.FullyConnected(net1, name='net1_fc2', num_hidden=26)
# BUG FIX: this second activation was also named 'relu1', duplicating the
# op name within the same graph; give it a unique name. Activations carry
# no parameters, so this does not affect parameter sharing.
net1 = mx.sym.Activation(net1, name='relu2', act_type="relu")
net1 = mx.sym.FullyConnected(net1, name='net1_fc3', num_hidden=10)
net1 = mx.sym.SoftmaxOutput(net1, name='net1_softmax')
# net2: two-layer classifier. It deliberately reuses net1's layer names
# ('net1_fc1', 'net1_fc2') so its parameters alias net1's when bound with
# shared_module=mod1 below.
net2 = mx.sym.SoftmaxOutput(
    mx.sym.FullyConnected(
        mx.sym.Activation(
            mx.sym.FullyConnected(
                mx.sym.Variable('net2_data'),
                name='net1_fc1', num_hidden=64),
            name='relu1', act_type="relu"),
        name='net1_fc2', num_hidden=26),
    name='net2_softmax')
# net3: identical architecture to net2, again reusing net1's layer names
# so the FullyConnected weights/biases are shared across all three modules.
net3 = mx.sym.SoftmaxOutput(
    mx.sym.FullyConnected(
        mx.sym.Activation(
            mx.sym.FullyConnected(
                mx.sym.Variable('net3_data'),
                name='net1_fc1', num_hidden=64),
            name='relu1', act_type="relu"),
        name='net1_fc2', num_hidden=26),
    name='net3_softmax')
# Show each network's parameter (argument) names — the overlap between the
# three lists is what makes sharing possible.
# BUG FIX: the original used Python 2 `print` statements, which are a
# syntax error under Python 3; the parenthesized form works under both.
for net in (net1, net2, net3):
    print(net.list_arguments())
    print("")
batch_size = 32
# First 80% of the rows are the training split.
ntrain = int(data.shape[0] * 0.8)


def _make_train_iter(prefix):
    """Build a training NDArrayIter whose data/label names match the
    net called *prefix* (e.g. 'net1' -> 'net1_data'/'net1_softmax_label')."""
    return mx.io.NDArrayIter(data[:ntrain, :],
                             label[:ntrain],
                             batch_size,
                             shuffle=True,
                             data_name='%s_data' % prefix,
                             label_name='%s_softmax_label' % prefix)


# The original built three byte-identical iterators inline; a helper keeps
# the data/label naming convention in exactly one place.
iter1 = _make_train_iter('net1')
iter2 = _make_train_iter('net2')
iter3 = _make_train_iter('net3')
def _make_module(symbol, prefix):
    """CPU Module for *symbol*, wired to the data/label names produced by
    the matching iterator ('<prefix>_data' / '<prefix>_softmax_label')."""
    return mx.mod.Module(symbol=symbol,
                         context=mx.cpu(),
                         data_names=['%s_data' % prefix],
                         label_names=['%s_softmax_label' % prefix])


# The original constructed three near-identical Modules inline; the helper
# removes the duplication while keeping the same names and arguments.
mod1 = _make_module(net1, 'net1')
mod2 = _make_module(net2, 'net2')
mod3 = _make_module(net3, 'net3')
# Bind the master module first: mod2/mod3 are later bound against it via
# shared_module, so mod1 must already be bound at that point.
mod1.bind(data_shapes=iter1.provide_data,
          label_shapes=iter1.provide_label,
          for_training=True,
          force_rebind=True,
          grad_req="write")
# init params of master module before other modules are bound
mod1.init_params(initializer=mx.init.Uniform(scale=.1), force_init=True)
# mod2/mod3 share mod1's parameter memory through shared_module=mod1; their
# layer names ('net1_fc1', 'net1_fc2') match mod1's, which is what ties the
# corresponding weights together.
mod2.bind(data_shapes=iter2.provide_data,
          label_shapes=iter2.provide_label,
          for_training=True,
          shared_module=mod1,
          grad_req="write")
mod3.bind(data_shapes=iter3.provide_data,
          label_shapes=iter3.provide_label,
          for_training=True,
          shared_module=mod1,
          grad_req="write")
# NOTE(review): these calls omit force_init=True, so presumably they leave
# the already-initialized shared parameters untouched and only fill in any
# parameters not shared with mod1 — confirm against the MXNet Module docs.
mod2.init_params(initializer=mx.init.Uniform(scale=.1))
mod3.init_params(initializer=mx.init.Uniform(scale=.1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment