Skip to content

Instantly share code, notes, and snippets.

@jtanios
Last active February 9, 2017 23:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jtanios/711ce7dc0f8410e8a11d6486ff612bd6 to your computer and use it in GitHub Desktop.
Save jtanios/711ce7dc0f8410e8a11d6486ff612bd6 to your computer and use it in GitHub Desktop.
MXNet How-Tos

Resume training from a saved checkpoint

# Load the symbol plus argument/auxiliary parameters from the checkpoint
# identified by `model_prefix` and `epoch`.
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
# Rebuild the model around the restored symbol and parameters; the trailing
# `...` stands for the remaining constructor arguments (illustrative only).
model = mx.model.FeedForward(symbol=sym, arg_params=arg_params, aux_params=aux_params, ctx=devices, ...)
# Continue training; `...` stands for the usual fit arguments.
model.fit(...)

Fine-tuning a model (transfer learning)

The helper functions used below (get_feature_symbol and finetune) are defined in fine_tune.py.

# Get pre-trained model
# NOTE(review): the second argument is presumably the checkpoint epoch to
# load -- confirm that 0 is the intended epoch.
model = mx.model.FeedForward.load(prefix, 0)

# Get cutoff layer: the internal symbol the new layers are attached to.
# No top_layer is passed, so get_feature_symbol picks its default layer.
cutoff = get_feature_symbol(model)

# Add new layers: a 10-unit fully connected layer on top of the cutoff,
# followed by a softmax output.
new_net = mx.sym.FullyConnected(data=cutoff, num_hidden=10, name="new_fc")
new_net = mx.sym.SoftmaxOutput(data=new_net, name="softmax")

# Train: build a model for new_net whose initializer loads the pre-trained
# model's parameters; X, y and num_epoch are forwarded through **kwargs.
# NOTE(review): num_epoch=0 looks like a placeholder -- confirm the intended
# number of training epochs.
new_model = finetune(symbol=new_net, model=model, X=X, y=y, num_epoch=0)
def get_feature_symbol(model, top_layer=None):
    """Get the feature (cut-off) symbol from a model.

    .. note::
        If ``top_layer`` is not given, the third entry from the end of the
        model's internal outputs whose names end in ``"output"`` is
        returned. For a network whose last two such outputs are a classifier
        layer and its softmax, that is the output listed just before them.

    Parameters
    ----------
    model: mx.model.FeedForward or mx.symbol.Symbol
        Model (or bare symbol) the feature symbol is extracted from.
    top_layer: str, optional
        Name of the internal output to use as the feature symbol.

    Returns
    -------
    internals[top_layer]: mx.symbol.Symbol
        Feature symbol

    Raises
    ------
    TypeError
        If ``top_layer`` is given but is not a string.
    ValueError
        If ``top_layer`` is not one of the symbol's ``*output`` internals, or
        if no ``top_layer`` is given and the symbol does not have more than
        three ``*output`` internals to choose a default from.
    """
    # isinstance (rather than an exact type match) also accepts Symbol
    # subclasses; anything else is expected to expose its symbol as `.symbol`.
    if isinstance(model, mx.symbol.Symbol):
        internals = model.get_internals()
    else:
        internals = model.symbol.get_internals()

    # Reversed so that index 0 is the last output listed by the symbol.
    outputs = [name for name in reversed(internals.list_outputs())
               if name.endswith("output")]

    if top_layer is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(outputs) <= 3:
            raise ValueError(
                "cannot choose a default top_layer: symbol has only %d "
                "internal outputs, more than 3 are required. Candidates:\n%s"
                % (len(outputs), "\n".join(outputs)))
        return internals[outputs[2]]

    if not isinstance(top_layer, str):
        error_msg = "top_layer must be a string in following candidates:\n %s" % "\n".join(outputs)
        raise TypeError(error_msg)
    if top_layer not in outputs:
        error_msg = "%s not exists in symbol. Possible choice:\n%s" \
            % (top_layer, "\n".join(outputs))
        raise ValueError(error_msg)
    return internals[top_layer]
def finetune(symbol, model, **kwargs):
    """Create and train a FeedForward model for fine-tuning.

    .. note::
        Parameters that do not exist in ``model`` (e.g. newly added layers)
        are initialized with ``mx.init.Uniform(0.001)``.

    Parameters
    ----------
    symbol: mx.symbol.Symbol
        Symbol of the new network to be fine-tuned.
    model: mx.model.FeedForward
        Model whose ``arg_params`` (and ``aux_params``, when present) are
        used to initialize the matching parameters of ``symbol``.
    kwargs: kwargs
        Forwarded unchanged to ``mx.model.FeedForward.create``.

    Returns
    -------
    new_model: mx.model.FeedForward
        The model returned by ``mx.model.FeedForward.create``.

    Raises
    ------
    ValueError
        If ``model.arg_params`` is ``None``, i.e. there is nothing to
        fine-tune from.

    Examples
    --------
    Load a model

    >>> model = mx.model.FeedForward.load("model", 9)

    Make new symbol for finetune

    >>> feature = get_feature_symbol(model)
    >>> net = mx.sym.FullyConnected(data=feature, num_hidden=10, name="new_fc")
    >>> net = mx.sym.SoftmaxOutput(data=net, name="softmax")

    Finetune the model

    >>> new_model = finetune(symbol=net, model=model, num_epoch=2, learning_rate=1e-3,
    ...                      X=train, eval_data=val,
    ...                      batch_end_callback=mx.callback.Speedometer(100))
    """
    if model.arg_params is None:
        raise ValueError("model has no arg_params to fine-tune from")

    # Copy so the source model's parameter dict is never mutated.
    pretrained = dict(model.arg_params)
    # The Load initializer only knows the dict it is handed, so auxiliary
    # states must be merged in; otherwise they would silently fall back to
    # default_init instead of keeping their pre-trained values.
    aux_params = getattr(model, "aux_params", None)
    if aux_params:
        pretrained.update(aux_params)

    initializer = mx.init.Load(param=pretrained, default_init=mx.init.Uniform(0.001))
    return mx.model.FeedForward.create(symbol=symbol, initializer=initializer, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment