Last active August 29, 2015 14:26
-
-
Save scheidan/83026b9f4b5f4c7c7079 to your computer and use it in GitHub Desktop.
Problem with DecayOnValidation learning rate policy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal reproduction of a problem with Mocha's DecayOnValidation learning-rate
# policy: fit a single InnerProductLayer to noisy synthetic linear data with SGD,
# running validation and snapshot coffee breaks so the policy gets exercised.
# Written against Julia 0.3 and the Mocha API of that time (uses `srand`).

use_cuda = false

using Mocha

############################################################
# Prepare Random Data
############################################################

# Seed once so the synthetic data set is reproducible.
srand(1234)

modeldir = "modeltest"   # snapshots and statistics are read from / written to here

N = 1000          # number of samples
M = 20            # input dimension
P = 10            # output dimension
batch_size = 500  # samples per SGD mini-batch

X = rand(M, N)                # inputs, one sample per column (M x N)
W = rand(M, P)                # ground-truth weights
B = rand(P, 1)                # ground-truth bias, broadcast across columns
Y = (W'*X .+ B)               # targets, P x N
Y = Y + 0.01*randn(size(Y))   # add small Gaussian noise to the targets

############################################################
# Define network
############################################################

backend = use_cuda ? GPUBackend() : CPUBackend()
init(backend)

# NOTE(review): the data layer's default tops are presumably [:data, :label],
# matching the bottoms consumed below — TODO confirm against the Mocha version used.
data_layer   = MemoryDataLayer(batch_size=batch_size, data=Array[X, Y])
weight_layer = InnerProductLayer(name="ip", output_dim=P, tops=[:pred], bottoms=[:data])
loss_layer   = SquareLossLayer(name="loss", bottoms=[:pred, :label])

net = Net("TEST", backend, [loss_layer, weight_layer, data_layer])
println(net)

############################################################
# Solve
############################################################

nepochs = 10

# NOTE(review): max_iter presumably counts mini-batch iterations, not samples.
# With batch_size = 500, one pass over N = 1000 samples is 2 iterations, so
# N*nepochs = 10000 iterations is ~5000 passes rather than `nepochs`. It is kept
# unchanged because the every_n_iter=2000 coffee breaks below rely on it to fire.
# DecayOnValidation args are presumably (base_lr, statistics key, decay ratio).
lr_policy = LRPolicy.DecayOnValidation(0.001, "loss-square-loss", 0.5)
params = SolverParameters(regu_coef=0.0005, mom_policy=MomPolicy.Fixed(0.9),
                          max_iter=N*nepochs, lr_policy=lr_policy,
                          load_from=modeldir)
solver = SGD(params)

setup_coffee_lounge(solver, save_into="$modeldir/statistics.jld", every_n_iter=1000)

# Validation runs on the training net itself; no separate test net is built.
val_performance = ValidationPerformance(net)
## add_coffee_break(solver, val_performance, every_n_epoch=2) # same at every_n_iter=2
add_coffee_break(solver, val_performance, every_n_iter=2000)

# Register the lr policy as a listener so it is notified of each validation result.
setup(params.lr_policy, val_performance, solver)

## add_coffee_break(solver, Snapshot(modeldir), every_n_epoch=2) # same at every_n_iter=2
add_coffee_break(solver, Snapshot(modeldir), every_n_iter=2000)
add_coffee_break(solver, TrainingSummary(show_obj_val=true, show_lr=true), every_n_iter=500)

try
    solve(solver, net)
finally
    # Release the network and the backend even if training throws.
    destroy(net)
    shutdown(backend)
end
The error message disappears as of commit 0bbfc835e349d5729039a145362b612e3e8ef59c.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Results in the following error message on Julia v"0.3.10", Manjaro Linux: