Skip to content

Instantly share code, notes, and snippets.

@facundoq
Last active April 19, 2017 14:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save facundoq/93a9d90c52c94aa9b329c47a4150d288 to your computer and use it in GitHub Desktop.
Save facundoq/93a9d90c52c94aa9b329c47a4150d288 to your computer and use it in GitHub Desktop.
Training a LeNet on MNIST using MXNet.jl, without using the `fit` function
using MXNet
"""
    accuracy(predicted_probability::mx.NDArray, label::mx.NDArray)

Return the number of entries for which the arg-max of `predicted_probability`
(computed with `axis=1`) equals the matching entry of `label`.

Both operands are passed through `copy` before the element-wise comparison.
Despite the name, the result is a raw count of matches rather than a ratio.
"""
function accuracy(predicted_probability::mx.NDArray, label::mx.NDArray)
    guesses = copy(mx.argmax(predicted_probability, axis=1))
    targets = copy(label)
    return sum(guesses .== targets)
end
"""
    initialize_model(ex, model)

Initialise, in place, the argument arrays of the bound executor `ex`.

A Xavier initializer (uniform distribution, average regularization, magnitude 3)
is applied to every argument that has a gradient array and whose name does not
end in `"label"`. All other arguments are left untouched.

# Arguments
- `ex`: executor exposing `arg_arrays` and `grad_arrays`, both ordered like the
  argument list of `model`.
- `model`: the symbolic network `ex` was bound from; used only to obtain the
  argument names.

# Returns
`nothing`.
"""
function initialize_model(ex, model)
    # TODO: drop `model` as a parameter and obtain the argument names from `ex`.
    initializer = mx.XavierInitializer(distribution = mx.xv_uniform,
                                       regularization = mx.xv_avg,
                                       magnitude = 3)
    arg_names = mx.list_arguments(model)
    # Walk the arguments in declaration order instead of through a Dict, whose
    # iteration order is unspecified; this keeps the sequence of random draws
    # made by the initializer well defined.
    for (name, weight, grad) in zip(arg_names, ex.arg_arrays, ex.grad_arrays)
        # `!==` is an identity test; `!=` could dispatch to an overloaded `==`.
        if grad !== nothing && !endswith(string(name), "label")
            mx.init(initializer, name, weight)
        end
    end
    return nothing
end
"""
    update_weights(updater::Function, model, ex, iteration::Int)

Apply one optimisation step to the arguments of executor `ex`.

For each argument that has a gradient array and whose name does not end in
`"label"`, call `updater(index, gradient, weight)`, where `index` is the
argument's 1-based position in `mx.list_arguments(model)`.

# Arguments
- `updater`: callable invoked as `updater(index, grad, weight)`.
- `model`: the symbolic network `ex` was bound from; used only for its
  argument names.
- `ex`: executor exposing `arg_arrays` and `grad_arrays`, both ordered like the
  argument list of `model`.
- `iteration`: currently unused; kept so existing callers keep working.

# Returns
`nothing`.
"""
function update_weights(updater::Function, model, ex, iteration::Int)
    arg_names = mx.list_arguments(model)
    # `enumerate` yields each argument's position directly, so no per-parameter
    # `findfirst` search is needed (its two-argument value form also no longer
    # exists in current Julia).
    for (weight_index, (name, weight, grad)) in
            enumerate(zip(arg_names, ex.arg_arrays, ex.grad_arrays))
        # `!==` is an identity test; `!=` could dispatch to an overloaded `==`.
        if grad !== nothing && !endswith(string(name), "label")
            updater(weight_index, grad, weight)
        end
    end
    return nothing
end
#--------------------------------------------------------------------------------
# LeNet definition, written as explicit nested calls: every layer receives the
# symbol produced by the previous layer as its first positional argument.

# network input
data = mx.Variable(:data)

# convolution stage 1: 5x5 conv (20 filters) -> tanh -> 2x2 max-pool, stride 2
conv1 = mx.Pooling(
    mx.Activation(
        mx.Convolution(data, kernel=(5,5), num_filter=20),
        act_type=:tanh),
    pool_type=:max, kernel=(2,2), stride=(2,2))

# convolution stage 2: 5x5 conv (50 filters) -> tanh -> 2x2 max-pool, stride 2
conv2 = mx.Pooling(
    mx.Activation(
        mx.Convolution(conv1, kernel=(5,5), num_filter=50),
        act_type=:tanh),
    pool_type=:max, kernel=(2,2), stride=(2,2))

# fully-connected stage 1: flatten -> 500 hidden units -> tanh
fc1 = mx.Activation(
    mx.FullyConnected(mx.Flatten(conv2), num_hidden=500),
    act_type=:tanh)

# fully-connected stage 2: 10 outputs
fc2 = mx.FullyConnected(fc1, num_hidden=10)

# softmax loss on top of the last layer
lenet = mx.SoftmaxOutput(fc2, name=:softmax)
#--------------------------------------------------------------------------------
# load data
batch_size = 512
include("mnist-data.jl")
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
#--------------------------------------------------------------------------------
# bind the network to the CPU with a fixed input shape and initialise its weights
data_shape = (28, 28, 1, batch_size)
ex = mx.simple_bind(lenet, mx.cpu(), data=data_shape)
initialize_model(ex, lenet)

# SGD optimizer; its state object is updated by hand below because `fit` is not used
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
op_state = mx.OptimizationState(batch_size)
optimizer.state = op_state
updater = mx.get_updater(optimizer)

n_epoch = 10
# `1:n_epoch` rather than `range(1, n_epoch)`: the meaning of the second
# positional argument of `range` differs between Julia versions.
for epoch in 1:n_epoch
    op_state.curr_epoch = epoch
    op_state.curr_batch = 0
    println("Epoch $epoch")

    # ---- training pass ----
    # Correct-prediction counts stay integers; ratios are computed separately so
    # no variable changes type part-way through the loop.
    train_correct = 0
    nbatch = 0
    for batch in train_provider
        # Named `batch_data` so the global `data` symbol defined with the
        # network is neither overwritten nor shadowed.
        batch_data = mx.get(train_provider, batch, :data)
        ex.arg_dict[:data][:] = batch_data
        label = mx.get(train_provider, batch, :softmax_label)
        ex.arg_dict[:softmax_label][:] = label

        mx.forward(ex, is_train=true)
        predicted_probability = ex.outputs[1]
        mx.backward(ex)
        update_weights(updater, lenet, ex, nbatch)

        train_correct += accuracy(predicted_probability, label)
        nbatch += 1
        op_state.curr_batch += 1
    end
    train_total = nbatch * batch_size
    println("Train accuracy: $train_correct/$train_total")
    println("Train accuracy: $(train_correct / train_total)")

    # ---- validation pass (forward only, no weight update) ----
    val_correct = 0
    nbatch = 0
    for batch in eval_provider
        ex.arg_dict[:data][:] = mx.get(eval_provider, batch, :data)
        label = mx.get(eval_provider, batch, :softmax_label)
        ex.arg_dict[:softmax_label][:] = label

        mx.forward(ex, is_train=false)
        predicted_probability = ex.outputs[1]

        val_correct += accuracy(predicted_probability, label)
        nbatch += 1
    end
    println("Val accuracy: $(val_correct / (nbatch * batch_size))")
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment