Skip to content

Instantly share code, notes, and snippets.

@zimmerrol
Created January 9, 2020 19:46
Show Gist options
  • Save zimmerrol/5dcc0d730ecf1506cc7b56fc16d3af8a to your computer and use it in GitHub Desktop.
Save zimmerrol/5dcc0d730ecf1506cc7b56fc16d3af8a to your computer and use it in GitHub Desktop.
def ihvp(f, w, v, n, alpha):
    """Approximate an inverse-Hessian-vector product with a Neumann series.

    Each iteration evaluates ``torch.autograd.grad(f, w, grad_outputs=v)``.
    When ``f`` holds the gradients of a loss w.r.t. ``w`` (built with
    ``create_graph=True``, as ``hg`` does), that call is the Hessian-vector
    product ``H v``. The loop then accumulates
    ``alpha * sum_{k=0..n} (I - alpha * H)^k v``.

    Args:
        f: Tensor or sequence of tensors to differentiate; must still carry
            an autograd graph w.r.t. ``w``.
        w: Iterable of tensors to differentiate with respect to. May be a
            generator (e.g. ``model.parameters()``).
        v: Iterable of tensors, one per entry of ``f``, to multiply with the
            inverse Hessian. May be a generator.
        n: Number of Neumann iterations. With ``n == 0`` the input vector is
            returned unchanged (detached and *not* scaled by ``alpha``).
        alpha: Scaling factor that makes ``I - alpha * H`` contractive so the
            series converges.

    Returns:
        A list of detached tensors, one per entry of ``v``.
    """
    # Materialize both iterables up front: they are traversed again on every
    # iteration, so a generator would already be exhausted on first reuse.
    w = list(w)
    v = list(v)

    # Zeroth-order term of the series; the list copy keeps `p` independent
    # of later rebinding of `v`.
    p = list(v)
    for _ in range(n):
        grads = torch.autograd.grad(f, w, grad_outputs=v, retain_graph=True)
        # The alpha makes the Hessian contractive, which is required for the
        # Neumann series to converge.
        v = [vi - alpha * gi for vi, gi in zip(v, grads)]
        p = [pi + vi for pi, vi in zip(p, v)]

    # Undo the scaling by alpha that the loop applies implicitly to keep the
    # Hessian contractive.
    if n > 0:
        p = [alpha * pi for pi in p]

    return [pi.detach() for pi in p]
def hg(w, l, loss_v, loss_t, n, alpha):
    """Compute the gradient of ``loss_v`` w.r.t. the hyperparameters ``l``.

    The result combines the direct gradient of ``loss_v`` w.r.t. ``l`` with
    an indirect term: the gradient of ``loss_v`` w.r.t. ``w`` is passed
    through ``ihvp`` (skipped when ``n == 0``) and then back-propagated
    through the gradient of ``loss_t`` w.r.t. ``w`` onto ``l``. The indirect
    term is subtracted from the direct one.

    Args:
        w: Iterable of parameter tensors.
        l: Iterable of hyperparameter tensors.
        loss_v: Scalar loss differentiated w.r.t. both ``w`` and ``l``.
        loss_t: Scalar loss whose gradient w.r.t. ``w`` is differentiated
            again (w.r.t. ``w`` inside ``ihvp`` and w.r.t. ``l`` here).
        n: Number of ``ihvp`` iterations; ``0`` uses the raw gradient of
            ``loss_v`` w.r.t. ``w`` instead.
        alpha: Scaling factor forwarded to ``ihvp``.

    Returns:
        A list with one entry per hyperparameter: a detached tensor, or
        ``None`` when neither term depends on that hyperparameter.
    """
    params = list(w)
    hparams = list(l)

    def _detached(tensors):
        # Strip autograd history while keeping the "unused" (None) markers.
        return [None if t is None else t.detach() for t in tensors]

    def _minus(lhs, rhs):
        # Subtraction in which None stands for a missing (zero) gradient.
        if rhs is None:
            return lhs
        if lhs is None:
            return -rhs
        return lhs - rhs

    direct_term = torch.autograd.grad(
        loss_v, hparams, allow_unused=True, retain_graph=True
    )
    val_grad_w = torch.autograd.grad(loss_v, params, retain_graph=True)
    # create_graph=True keeps this gradient differentiable for the
    # second-order calls below.
    train_grad_w = torch.autograd.grad(
        loss_t, params, retain_graph=True, create_graph=True
    )

    if n != 0:
        vector = ihvp(train_grad_w, params, val_grad_w, n, alpha)
    else:
        vector = list(val_grad_w)

    indirect_term = torch.autograd.grad(
        train_grad_w,
        hparams,
        grad_outputs=vector,
        retain_graph=True,
        allow_unused=True,
    )

    if any(g is None for g in indirect_term):
        warnings.warn('del_w_del_l is null; this happens only if loss_t does '
                      'not depend on parameters l. Please check if this is '
                      'intended or a mistake.')

    return [
        _minus(direct, indirect)
        for direct, indirect in zip(
            _detached(direct_term), _detached(indirect_term)
        )
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment