A set of functions to update a PyTorch optimizer's param groups with respect to learning rate and to unfreeze the last n layers of a PyTorch model.
import numpy as np

# Imports for annotations
from typing import Optional, Union, List
from torch.nn import Module
from torch.optim import Optimizer

def get_layers(model: Module, rgrad: bool = False):
    """
    Yields all parametric layers of the model that
    have no children, i.e. it will yield the contents
    of a Sequential but not the Sequential itself.
    rgrad : yield only layers whose parameters `requires_grad`.
    """
    for l in model.modules():
        if len([*l.children()]) == 0:
            params = [*l.parameters()]
            if len(params) > 0:
                if rgrad:
                    if params[0].requires_grad:
                        yield l
                    else:
                        continue
                else:
                    yield l
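
# For instance (a hypothetical model, not from the original gist), given
# nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)) this yields the
# two Linear layers only: the ReLU has no parameters and the Sequential
# itself is skipped because it has children.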

def freeze(model: Module) -> None:
    # Freeze all layers in the model.
    for params in model.parameters():
        params.requires_grad = False


def unfreeze(model: Module) -> None:
    # Unfreeze all layers in the model.
    for params in model.parameters():
        params.requires_grad = True

def unf_last_n(model: Module, n: Optional[int] = None) -> None:
    """
    Unfreeze the last `n` parametric layers of the model.
    If `n is None` then all layers are unfrozen.
    """
    # Freeze all the layers.
    freeze(model)
    # Unfreeze only the required layers.
    if n is None:
        unfreeze(model)
    else:
        layers = [*get_layers(model)][::-1][:n]
        for layer in layers:
            unfreeze(layer)

def get_lrs(lr: slice, count: Optional[int] = None):
    """
    Exponentially increasing lrs from
    slice.start to slice.stop.
    If `count is None` then count = int(stop / start).
    """
    lr1 = lr.start
    lr2 = lr.stop
    if count is None:
        count = int(lr2 / lr1)
    if count == 1:
        # A single layer gets the lower bound; there is nothing to spread.
        return [lr1]
    incr = np.exp(np.log(lr2 / lr1) / (count - 1))
    return [lr1 * incr ** i for i in range(count)]
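
# A small illustration (values are approximate and not part of the original
# gist): four steps from 1e-5 to 1e-2 grow by a factor of 10 each.
#   get_lrs(slice(1e-5, 1e-2), count=4)  ->  [1e-05, 1e-04, 1e-03, 1e-02]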

def configure_optimizer(model: Module, optimizer: Optimizer,
                        lr: Optional[Union[List[float], slice, float]] = None,
                        unlock: Optional[Union[bool, int]] = None):
    """
    model     : a pytorch nn.Module whose params are to be optimized.
    optimizer : a pytorch optimizer whose param groups have to
                be configured.
    lr        : if `lr` is a `slice`, spread the lrs exponentially
                over all the unlocked layers of the model; if a list,
                use its last values, one per unlocked layer; if a
                float, apply it to all unlocked layers; if None,
                reuse the optimizer's current lr.
    unlock    : if `unlock` is True, unlock all the layers;
                if it is a number, unlock only the last `unlock` layers.
    """
    # Save each param group's hyperparameters (momentum, weight decay, etc.)
    # so they can be reattached to the new per-layer param groups.
    pgdicts = []
    param_groups = optimizer.param_groups
    for param_group in param_groups:
        pgdict = {}
        for key in param_group:
            if key not in ['lr', 'initial_lr', 'params']:
                pgdict[key] = param_group[key]
        pgdicts.append(pgdict)

    # If no learning rate is given, apply the optimizer's current
    # learning rate to all unlocked layers.
    if lr is None:
        lr = param_groups[0]['lr']
        optimizer.param_groups.clear()
        layers = [*get_layers(model, True)]
        for i, layer in enumerate(layers):
            # Fall back to the first saved param group's hyperparameters
            # when layers and saved groups don't line up one to one.
            if len(layers) != len(pgdicts):
                i = 0
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': lr,
                'initial_lr': lr,
                **pgdicts[i]
            })
    # If lr is not None, unlock the requested layers and apply it.
    else:
        optimizer.param_groups.clear()
        if unlock is not None:
            if unlock is True:
                # Unfreeze all the layers.
                unf_last_n(model)
            else:
                # Unlock only the last n layers.
                unf_last_n(model, n=unlock)
        # Attach learning rates to the unfrozen layers.
        layers = [*get_layers(model, True)]
        l = len(layers)
        if isinstance(lr, slice):
            lrs = get_lrs(lr, count=l)
        elif isinstance(lr, list):
            llay = len(layers)
            llrs = len(lr)
            if llrs < llay:
                print("insufficient lrs")
                return
            # Use the last `llay` learning rates of the list.
            d = llrs - llay
            lrs = lr[d:]
        else:
            lrs = [lr] * l
        for i, (lr, layer) in enumerate(zip(lrs, layers)):
            if len(layers) != len(pgdicts):
                i = 0
            optimizer.add_param_group({
                'params': layer.parameters(),
                'lr': lr,
                'initial_lr': lr,
                **pgdicts[i]
            })

def print_lr_layer(model: Module, optimizer: Optimizer):
    """
    Print each unfrozen layer along with its learning rate.
    """
    layers = [*get_layers(model, True)]
    pgroup = optimizer.param_groups
    if len(pgroup) != len(layers):
        if len(pgroup) > 1:
            print('param_group, unfrozen layer length mismatch')
        elif len(pgroup) == 1:
            lr = pgroup[0]['lr']
            print(f"lr: {lr}, for: ")
            for l in layers:
                print(l)
    else:
        for pg, layer in zip(pgroup, layers):
            lr = pg['lr']
            print(f"lr: {lr:0.10f} :: {layer}")