@saswatac
Created August 30, 2018 21:23

import functools

import mxnet as mx
import numpy as np
from scipy.sparse import coo_matrix
from sklearn import datasets


def batch_row_ids(data_batch):
    """Generate row ids for the current mini-batch: the CSR batch's .indices
    are the column ids of its non-zeros, i.e. the weight rows it touches."""
    return {'weight': data_batch.data[0].indices}


def all_row_ids(data_batch, num_features):
    """Generate row ids for all rows of the weight matrix."""
    all_rows = mx.nd.arange(0, num_features, dtype='int64')
    return {'weight': all_rows}
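
# Note: Module.prepare() invokes sparse_row_id_fn with a single data_batch
# argument, so num_features is bound at the call site below with
# functools.partial.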


def linear_model(num_features):
    # data with csr storage type to enable feeding data with CSRNDArray
    x = mx.symbol.Variable("data", stype='csr')
    norm_init = mx.initializer.Normal(sigma=0.01)
    # weight with row_sparse storage type to enable sparse gradient updates
    weight = mx.symbol.Variable("weight", shape=(num_features, 2),
                                init=norm_init, stype='row_sparse')
    bias = mx.symbol.Variable("bias", shape=(2,))
    dot = mx.symbol.sparse.dot(x, weight)
    pred = mx.symbol.broadcast_add(dot, bias)
    y = mx.symbol.Variable("softmax_label")
    model = mx.sym.SoftmaxOutput(pred, label=y)
    return model
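

# The graph built above computes softmax(x @ weight + bias) over two output
# classes; a rough dense NDArray sketch of the same forward pass
# (hypothetical variable names, for illustration only):
#
#   scores = mx.nd.softmax(mx.nd.dot(x_dense, w_dense) + b_dense)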


if __name__ == '__main__':
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.INFO, format=head)

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_features", type=int, required=True)
    args = parser.parse_args()

    batch_size = 1000
    nnz = 100
    num_features = args.num_features
    optimizer = 'adam'
    rows = 10000

    # train data: each of the `rows` examples has `nnz` random non-zeros in
    # its first `nnz` columns, dumped to disk in LibSVM format
    row = np.array([[i for _ in range(nnz)] for i in range(rows)])
    row = row.flatten()
    col = np.array([i for i in range(nnz)] * rows)
    data = np.random.random(size=rows * nnz)
    train_data = coo_matrix((data, (row, col)), shape=(rows, num_features))
    train_label = np.random.randint(2, size=rows)
    datasets.dump_svmlight_file(train_data, train_label, "/tmp/libsvm.data")
    train_data = "/tmp/libsvm.data"

    # data iterator
    train_data = mx.io.LibSVMIter(data_libsvm=train_data, data_shape=(num_features,),
                                  batch_size=batch_size)

    # model
    model = linear_model(num_features)

    # module
    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
    mod.init_params()
    mod.init_optimizer(optimizer=optimizer)
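    # Note: init_optimizer defaults to kvstore='local'; for distributed
    # training you would typically create a distributed kvstore (e.g.
    # mx.kvstore.create('dist_async')) and pass it here, which is when the
    # sparse_row_id_fn pulls below become necessary.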

    # start profiler
    mx.profiler.set_config(profile_all=True, filename='profile_output.json',
                           aggregate_stats=True)
    mx.profiler.set_state('run')

    logging.info('Training started ...')
    for batch in train_data:
        # for distributed training, we need to manually pull sparse weights from kvstore
        mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
        mod.forward_backward(batch)
        # update all parameters (including the weight parameter)
        mod.update()

    # pull back all rows of the sparse weight before checkpointing
    mod.prepare(None, functools.partial(all_row_ids, num_features=num_features))
    mod.save_checkpoint("checkpoint", 0)
    logging.info('Training completed.')

    mx.profiler.set_state('stop')
    print(mx.profiler.dumps())
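
# Usage sketch (assumes this file is saved as sparse_linear.py; the feature
# count is an arbitrary example value):
#   python sparse_linear.py --num_features 1000000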