@piiswrong
Created September 26, 2016 22:57
import os
import logging
from datetime import datetime

import mxnet as mx


def fit(args, network, data_loader, batch_end_callback=None):
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank,)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_dir' in args and args.log_dir is not None:
        logging.basicConfig(level=logging.DEBUG, format=head)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        if args.log_file is None:
            log_file = ((save_model_prefix if save_model_prefix else '')
                        + datetime.now().strftime('_%Y_%m_%d-%H_%M.log'))
            log_file = log_file.replace('/', '-')
        else:
            log_file = args.log_file
        log_file_full_name = os.path.join(args.log_dir, log_file)
        handler = logging.FileHandler(log_file_full_name, mode='w')
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logging.getLogger().addHandler(handler)
        logging.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)

    # optionally resume from a saved checkpoint
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        # load_checkpoint returns a (symbol, arg_params, aux_params) tuple
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.load_epoch)
        model_args = {'arg_params': arg_params,
                      'aux_params': aux_params,
                      'begin_epoch': args.load_epoch}

    # save model after every epoch
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # devices to train on
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # batches per epoch, per worker; only used by the lr scheduler below
    # (mx.mod.Module.fit, unlike the old FeedForward API, takes no epoch_size)
    epoch_size = args.num_examples // args.batch_size
    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    model = mx.mod.Module(symbol=network, context=devs)

    optim = {'learning_rate': args.lr, 'wd': 1e-4, 'momentum': 0.9}
    if 'lr_factor' in args and args.lr_factor < 1:
        optim['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)
    if 'clip_gradient' in args and args.clip_gradient is not None:
        optim['clip_gradient'] = args.clip_gradient
    model_args['optimizer_params'] = optim
    model_args['num_epoch'] = args.num_epochs
    model_args['initializer'] = mx.init.Xavier(rnd_type='gaussian',
                                               factor_type='in', magnitude=2)

    eval_metric = [args.metric]  # , mx.metric.create('top_k_accuracy', top_k=5)

    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    model.fit(
        train_data=train,
        eval_data=val,
        eval_metric=eval_metric,
        kvstore=kv,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=checkpoint,
        optimizer='nag',
        **model_args)
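
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). Everything below is
# hypothetical: the toy MLP, toy_loader, and argparse defaults only illustrate
# which fields fit() reads from `args` (kv_store, model_prefix,
# save_model_prefix, load_epoch, lr, gpus, batch_size, num_examples,
# num_epochs, metric) and the (train, val) iterator pair the data_loader
# callable is expected to return.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import numpy as np

    parser = argparse.ArgumentParser()
    parser.add_argument('--kv-store', default='local')
    parser.add_argument('--model-prefix', default=None)
    parser.add_argument('--save-model-prefix', default=None)
    parser.add_argument('--load-epoch', type=int, default=None)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--gpus', default=None)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--num-examples', type=int, default=1000)
    parser.add_argument('--num-epochs', type=int, default=2)
    parser.add_argument('--metric', default='acc')
    args = parser.parse_args()

    # a toy two-layer network standing in for the real `network` symbol
    data = mx.sym.Variable('data')
    fc1 = mx.sym.FullyConnected(data, num_hidden=64, name='fc1')
    act1 = mx.sym.Activation(fc1, act_type='relu', name='relu1')
    fc2 = mx.sym.FullyConnected(act1, num_hidden=10, name='fc2')
    net = mx.sym.SoftmaxOutput(fc2, name='softmax')

    def toy_loader(args, kv):
        # random data; a real loader would shard by kv.rank / kv.num_workers
        x = np.random.uniform(size=(args.num_examples, 100)).astype('float32')
        y = np.random.randint(0, 10, size=(args.num_examples,)).astype('float32')
        train = mx.io.NDArrayIter(x, y, batch_size=args.batch_size, shuffle=True)
        val = mx.io.NDArrayIter(x, y, batch_size=args.batch_size)
        return train, val

    fit(args, net, toy_loader)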