@willprice · Created February 17, 2021 17:39
A little helper for finetuning networks
from typing import List

import torch
from torch import nn


def filter_parameters_for_finetuning(module: nn.Module) -> List[torch.Tensor]:
    """
    Args:
        module: A :py:class:`nn.Module` object where some of the children may
            have a boolean attribute ``finetune``; if the attribute exists and
            is ``False``, the parameters of that submodule are excluded from
            the result.

    Returns:
        A list of the parameters in the module to finetune.
    """
    # Start with the parameters registered directly on this module, excluding
    # those belonging to its children.
    params = list(module.parameters(recurse=False))
    # Look at each direct child of the current module and recurse into those
    # that don't have ``finetune=False``.
    for child in module.children():
        if hasattr(child, "finetune") and not child.finetune:
            # Don't recurse into this part of the module subtree since we want
            # to freeze all of its parameters.
            continue
        # The child is going to be finetuned, so collect its parameters and
        # those of all its descendants that also lack ``finetune=False``.
        params.extend(filter_parameters_for_finetuning(child))
    return params

def demo():
    class SubSubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(5, 7)

    class SubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(10, 15)
            self.m = SubSubSubModule()
            self.finetune = False

    class SubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(20, 30)
            self.m = SubSubModule()

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(40, 60)
            self.m = SubModule()

    net = Net()
    print([
        param.shape for param in filter_parameters_for_finetuning(net)
    ])
    # Outputs:
    # [torch.Size([60, 40]),  <- Net.lin.weight
    #  torch.Size([60]),      <- Net.lin.bias
    #  torch.Size([30, 20]),  <- Net.m.lin.weight
    #  torch.Size([30])]      <- Net.m.lin.bias
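
For reference, here is a minimal sketch of how the helper might be used in a training script (this usage is not part of the gist, and the function name and learning rate are made up for illustration): freeze every parameter, re-enable gradients on the filtered subset, and hand only that subset to the optimiser.

def finetune_usage_sketch(net: nn.Module) -> torch.optim.Optimizer:
    finetune_params = filter_parameters_for_finetuning(net)

    # Freeze everything first so the excluded submodules accumulate no
    # gradients during training ...
    for param in net.parameters():
        param.requires_grad_(False)
    # ... then re-enable gradients for just the parameters we want to
    # finetune, and give only those to the optimiser.
    for param in finetune_params:
        param.requires_grad_(True)
    return torch.optim.SGD(finetune_params, lr=1e-3)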
@willprice (Author) commented:
Currently this doesn't support overriding the finetuning of parts of a submodule that has finetune=False, e.g. if SubSubSubModule had finetune=True it would still be excluded, because its parent, SubSubModule, is skipped before the recursion ever reaches it.
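
A minimal sketch of one way to support such overrides (an assumption on my part, not the gist's implementation): thread the inherited finetune state through the recursion and let each submodule's own ``finetune`` attribute override it, so finetune=True deep inside a frozen subtree re-enables collection there. The name filter_parameters_with_overrides is hypothetical.

def filter_parameters_with_overrides(
    module: nn.Module, finetune: bool = True
) -> List[torch.Tensor]:
    # A module's own ``finetune`` attribute, when present, overrides the
    # state inherited from its parent.
    finetune = getattr(module, "finetune", finetune)
    params = list(module.parameters(recurse=False)) if finetune else []
    for child in module.children():
        params.extend(filter_parameters_with_overrides(child, finetune))
    return params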
