Skip to content

Instantly share code, notes, and snippets.

@epwalsh
Last active February 19, 2021 20:06
Show Gist options
  • Save epwalsh/2e592a266af420d5585fecbf4bcd389d to your computer and use it in GitHub Desktop.
from typing import List, Tuple, Dict, Any
import torch
from allennlp.common.lazy import Lazy
from allennlp.common.params import Params
from allennlp.training.optimizers import Optimizer
# A list of parameter groups. Each group is a (regexes, option_overrides) pair.
# Presumably the regexes select parameters by name and the dict holds
# per-group optimizer options such as "weight_decay" — TODO confirm.
ParameterGroupType = List[Tuple[List[str], Dict[str, Any]]]


@Optimizer.register("regex")
class RegexOptimizer(Optimizer):
    """
    Optimizer registered under the name ``"regex"``.

    It is configured with a list of ``(parameter_groups, optimizer)`` pairs.
    Presumably each pair assigns the model parameters matched by its regexes
    to its own optimizer — TODO confirm against the intended design.

    Construction is not implemented yet. ``from_params`` raises
    ``NotImplementedError`` rather than silently returning ``None``.
    """

    @classmethod
    def from_params(
        cls,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        optimizers: List[Tuple[ParameterGroupType, Lazy[Optimizer]]],
    ) -> "RegexOptimizer":
        """
        Build a ``RegexOptimizer`` from named model parameters and optimizer specs.

        Args:
            model_parameters: ``(name, parameter)`` pairs. Presumably these come
                from ``model.named_parameters()``. A list is expected, so it can
                be scanned more than once.
            optimizers: Entries of ``(parameter_groups, lazy_optimizer)``, where
                ``parameter_groups`` is a ``ParameterGroupType``.

        Raises:
            NotImplementedError: Always. The construction logic has not been
                written yet. Failing loudly is preferable to returning ``None``
                as a bare ``pass`` would.
        """
        # NOTE(review): AllenNLP's registry dispatch may call a subclass's own
        # ``from_params`` with ``params=`` and other keyword arguments (see
        # ``FromParams.from_params``). Confirm this custom signature is
        # compatible before implementing it.
        raise NotImplementedError(
            "RegexOptimizer.from_params is not implemented yet"
        )
# Example: configure the "regex" optimizer from a Params object.
# Each entry of "optimizers" is [parameter_groups, optimizer_config].
# Each parameter group is [regexes, option_overrides].
#
# MyModel is a stand-in module whose parameter names start with
# "tag_projection_layer.", "text_field_embedder." and "encoder.", so they line
# up with the regexes below. Replace it with your real model.
MyModel = torch.nn.ModuleDict(
    {
        "tag_projection_layer": torch.nn.Sequential(torch.nn.Linear(4, 3)),
        "text_field_embedder": torch.nn.Embedding(10, 4),
        "encoder": torch.nn.Linear(4, 4),
    }
)
# NOTE(review): lr values of 1/2/3 look like placeholders that just tell the
# three entries apart. Adjust them for real training.
optimizer = Optimizer.from_params(
    Params(
        {
            "type": "regex",
            "optimizers": [
                [
                    [
                        [[r"^tag_projection_layer\..*\.weight$"], {}],
                    ],
                    {
                        "type": "adam",
                        "lr": 1,
                    },
                ],
                [
                    [
                        [[r"^text_field_embedder.*"], {"weight_decay": 0.001}],
                    ],
                    {
                        "type": "adam",
                        "lr": 2,
                    },
                ],
                [
                    [
                        [[r"^encoder.*"], {"weight_decay": 0.001}],
                    ],
                    {
                        "type": "adam",
                        "lr": 3,
                    },
                ],
            ],  # must be "]": this closes the "optimizers" list, not a dict
        }
    ),
    # named_parameters() returns a one-shot generator. RegexOptimizer.from_params
    # annotates model_parameters as a List, so materialize it here.
    model_parameters=list(MyModel.named_parameters()),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment