@lewtun
Created January 21, 2022 09:27
Correction to page 334
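def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    # Parameters whose names contain any of the `no_decay` substrings
    # (biases and LayerNorm weights) are excluded from weight decay.
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    # `args` is the training configuration defined elsewhere; it must
    # provide the `weight_decay` value used for the decayed group.
    return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
            {'params': params_without_wd, 'weight_decay': 0.0}]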
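A minimal, self-contained sketch of how the corrected function splits parameters. So that it runs without PyTorch, it assumes a hypothetical `FakeModel` whose `named_parameters()` yields `(name, name)` pairs in place of the `(name, tensor)` pairs that `torch.nn.Module.named_parameters()` produces. It also takes `weight_decay` as an explicit argument instead of reading the global `args.weight_decay`. The parameter names are illustrative BERT-style names, not taken from a specific model.

```python
# Hypothetical stand-in for a torch.nn.Module: get_grouped_params only needs
# named_parameters(), so each "parameter" here is simply its own name string.
class FakeModel:
    def __init__(self, names):
        self._names = names

    def named_parameters(self):
        for name in self._names:
            yield name, name


# Same logic as the correction above, but weight_decay is passed explicitly
# (an assumption for this sketch) rather than read from a global `args`.
def get_grouped_params(model, weight_decay, no_decay=("bias", "LayerNorm.weight")):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        # Substring match: any name containing "bias" or "LayerNorm.weight"
        # goes into the group without weight decay.
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


# Illustrative parameter names (hypothetical).
model = FakeModel([
    "encoder.layer.0.attention.self.query.weight",
    "encoder.layer.0.attention.self.query.bias",
    "encoder.layer.0.output.LayerNorm.weight",
    "encoder.layer.0.output.LayerNorm.bias",
])

groups = get_grouped_params(model, weight_decay=0.1)
print(len(groups[0]["params"]), len(groups[1]["params"]))  # prints: 1 3
```

Only the query weight matrix is decayed; its bias and both LayerNorm parameters land in the zero-decay group. With PyTorch, the returned list can be passed directly as the first argument of `torch.optim.AdamW`, which accepts a list of parameter-group dicts and applies each group's own `weight_decay`.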