Created
December 6, 2018 12:15
-
-
Save cosmincatalin/6dcd8716c61356e89df545092f5d76e3 to your computer and use it in GitHub Desktop.
Training script used by SageMaker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import logging | |
import os | |
from pickle import load | |
import mxnet as mx | |
import numpy as np | |
from mxnet import autograd, nd, gluon | |
from mxnet.contrib import onnx as onnx_mxnet | |
from mxnet.gluon.loss import L2Loss | |
from mxnet.gluon.nn import Dense, Dropout, HybridSequential | |
from mxnet.gluon.trainer import Trainer | |
from mxnet.initializer import Xavier | |
# Configure the root logger so training progress shows up in SageMaker's
# CloudWatch log stream.
logging.basicConfig(level=logging.INFO)
def train(data_dir, num_gpus, batch_size=64, epochs=5):
    """Train a small regression MLP on pickled datasets and return it.

    Parameters
    ----------
    data_dir : str
        Directory containing ``train/data.p`` and ``test/data.p`` pickles
        (each a dataset consumable by ``gluon.data.DataLoader``).
    num_gpus : int
        Number of GPUs available; any value > 0 selects the GPU context.
    batch_size : int, optional
        Mini-batch size for both loaders (default 64, the original value).
    epochs : int, optional
        Number of passes over the training set (default 5, the original value).

    Returns
    -------
    HybridSequential
        The trained network.
    """
    mx.random.seed(42)  # reproducible init and shuffling

    # `f`, not `pickle`: avoid shadowing the stdlib module name.
    with open("{}/train/data.p".format(data_dir), "rb") as f:
        train_nd = load(f)
    with open("{}/test/data.p".format(data_dir), "rb") as f:
        test_nd = load(f)

    train_data = gluon.data.DataLoader(train_nd, batch_size, shuffle=True)
    # Evaluation order does not matter; shuffling the validation set is wasted work.
    validation_data = gluon.data.DataLoader(test_nd, batch_size, shuffle=False)

    net = HybridSequential()
    with net.name_scope():
        net.add(Dense(9))
        net.add(Dropout(.25))
        net.add(Dense(16))
        net.add(Dropout(.25))
        net.add(Dense(1))
    net.hybridize()

    ctx = mx.gpu() if num_gpus > 0 else mx.cpu()

    # Xavier is also known as Glorot initialization.
    net.collect_params().initialize(Xavier(magnitude=2.24), ctx=ctx)
    loss = L2Loss()
    trainer = Trainer(net.collect_params(), optimizer="adam")

    smoothing_constant = .01

    for e in range(epochs):
        moving_loss = 0
        for i, (data, label) in enumerate(train_data):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data)
                loss_result = loss(output, label)
            loss_result.backward()
            # Normalize the gradient by the ACTUAL batch size: the last batch
            # can be smaller than `batch_size`, and the original hard-coded
            # step(64) scaled its update incorrectly.
            trainer.step(data.shape[0])
            curr_loss = nd.mean(loss_result).asscalar()
            # Exponentially weighted moving average of the batch loss,
            # seeded with the very first batch's loss.
            moving_loss = (curr_loss if ((i == 0) and (e == 0))
                           else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)

        test_mae = measure_performance(net, ctx, validation_data)
        train_mae = measure_performance(net, ctx, train_data)
        print("Epoch %s. Loss: %s, Train_mae %s, Test_mae %s" % (e, moving_loss, train_mae, test_mae))

    return net
def measure_performance(model, ctx, data_iter):
    """Return the mean absolute error of `model` over every batch in `data_iter`."""
    metric = mx.metric.MAE()
    for batch, labels in data_iter:
        predictions = model(batch.as_in_context(ctx))
        metric.update(preds=predictions, labels=labels.as_in_context(ctx))
    # metric.get() yields a (name, value) pair; only the value is needed.
    return metric.get()[1]
def save(net, model_dir):
    """Serialize the trained network and convert it to ONNX.

    ``net.export`` writes ``model-symbol.json`` and ``model-0004.params``
    into the current working directory; those two files are then converted
    into a single ONNX file placed in ``model_dir``.
    """
    net.export("model", epoch=4)
    # NOTE(review): input_shape assumes one sample with 4 features — confirm
    # this matches the training data's feature count.
    onnx_path = "{}/model.onnx".format(model_dir)
    onnx_mxnet.export_model(
        sym="model-symbol.json",
        params="model-0004.params",
        input_shape=[(1, 4)],
        input_type=np.float32,
        onnx_file_path=onnx_path,
        verbose=True,
    )
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--model-dir', type=str, default=os.environ["SM_MODEL_DIR"]) | |
parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"]) | |
parser.add_argument("--gpus", type=int, default=os.environ["SM_NUM_GPUS"]) | |
args, _ = parser.parse_known_args() | |
net = train(args.data_dir, args.gpus) | |
save(net, args.model_dir) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment