@jinpan
Created June 18, 2020 19:17
from functools import partial

# debias, listify, StatefulOptimizer, AverageGrad, AverageSqrGrad, StepCount, fix_seeds,
# get_learn_run, conv_layer, cbfs, nfs and data are assumed to come from the fastai
# course (part 2) notebook environment this gist was written against.

def my_adam_step(p, lr, mom, mom_damp, step, sqr_mom, sqr_damp, grad_avg, sqr_avg, eps, wd, **kwargs):
    debias1 = debias(mom, mom_damp, step)      # bias correction for the running gradient average
    debias2 = debias(sqr_mom, sqr_damp, step)  # bias correction for the running squared-gradient average
    grad_avg_debiased = grad_avg / debias1
    grad_var_debiased = sqr_avg / debias2
    # AdamW-style update: debiased Adam step plus decoupled weight decay (wd * p.data)
    p.data = p.data - lr * (grad_avg_debiased / (grad_var_debiased.sqrt() + eps) + wd * p.data)
    return p
my_adam_step._defaults = dict(eps=1e-5, wd=0.)
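
# debias itself is defined in the course notebooks, not in this gist. If running this
# standalone, a minimal sketch of it (assuming the usual definition; with dampening
# damp = 1 - mom it reduces to Adam's bias-correction factor 1 - mom**step):
def debias(mom, damp, step):
    return damp * (1 - mom**step) / (1 - mom)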
# Optimizer factory: a StatefulOptimizer that tracks a dampened running gradient average,
# a running squared-gradient average and the step count, then applies my_adam_step
# (plus any extra steppers passed in).
def my_adam_opt(xtra_step=None, **kwargs):
    return partial(StatefulOptimizer, steppers=[my_adam_step] + listify(xtra_step),
                   stats=[AverageGrad(dampening=True), AverageSqrGrad(), StepCount()], **kwargs)
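
# fix_seeds (used below) also comes from the notebook environment; if it is not defined
# there, a hypothetical stand-in would seed the Python, NumPy and PyTorch RNGs so the
# three runs are repeatable:
import random
import numpy as np
import torch

def fix_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)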
# Train for 10 epochs under three different seeds to check that the results are stable.
for seed in (7, 42, 10914):
    fix_seeds(seed)
    learn, run = get_learn_run(nfs, data, 0.001, conv_layer, cbs=cbfs, opt_func=my_adam_opt(wd=1e-2))
    run.fit(10, learn)
    print('-' * 80)
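
# Standalone sanity check of my_adam_step on a plain tensor, maintaining the optimizer
# state by hand instead of going through StatefulOptimizer / AverageGrad; every name
# below (p, grad, grad_avg, sqr_avg) is illustrative only.
p = torch.nn.Parameter(torch.randn(3))
grad = torch.randn(3)
mom, sqr_mom = 0.9, 0.99
mom_damp, sqr_damp = 1 - mom, 1 - sqr_mom    # dampening factors, as AverageGrad(dampening=True) uses
grad_avg, sqr_avg = torch.zeros(3), torch.zeros(3)
for step in range(1, 4):
    grad_avg = grad_avg * mom + mom_damp * grad          # dampened running mean of the gradient
    sqr_avg = sqr_avg * sqr_mom + sqr_damp * grad ** 2   # running mean of the squared gradient
    my_adam_step(p, lr=1e-3, mom=mom, mom_damp=mom_damp, step=step,
                 sqr_mom=sqr_mom, sqr_damp=sqr_damp, grad_avg=grad_avg, sqr_avg=sqr_avg,
                 eps=1e-5, wd=1e-2)
print(p.data)  # the parameter has taken three small debiased-Adam + weight-decay steps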