Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save max-kuk/463163cee2530e87a9109781feacc581 to your computer and use it in GitHub Desktop.
Save max-kuk/463163cee2530e87a9109781feacc581 to your computer and use it in GitHub Desktop.
import torch
from collections import OrderedDict
from typing import List
# Paths of the candidate checkpoints; each one is read with torch.load below.
# Placeholder (`...`): fill in before running. The selection loop visits the
# checkpoints in this order, so the order matters.
checkpoints_weights_paths: List[str] = ... # sorted in descending order by score
# Model that receives every averaged state dict (via load_state_dict) before
# it is evaluated. Placeholder (`...`): fill in before running.
model: torch.nn.Module = ...
def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """Return the element-wise mean of several state dicts.

    The key set and key order are taken from the first state dict. For each
    key, the values of all state dicts are summed and the sum is divided by
    the number of state dicts.

    Args:
        state_dicts: Non-empty list of mappings that share the same keys and
            whose values support ``+`` and division by an ``int`` (tensors,
            plain numbers, ...).

    Returns:
        An ``OrderedDict`` mapping every key of the first state dict to the
        mean of that key's values.

    Raises:
        ValueError: If ``state_dicts`` is empty.
        KeyError: If a later state dict lacks a key of the first one.
    """
    if not state_dicts:
        raise ValueError("average_weights() requires at least one state dict")

    count = len(state_dicts)
    average_dict = OrderedDict()
    for key in state_dicts[0]:
        # NOTE(review): `/` is true division, so integer-valued entries
        # (e.g. counters stored in a state dict) come back as floats --
        # confirm that is acceptable for the model being loaded.
        average_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / count
    return average_dict
# Greedy checkpoint averaging: visit the candidate checkpoints in the given
# order and keep one only if adding it to the running average improves the
# evaluation score.

# NOTE(security): torch.load unpickles the file contents; only load
# checkpoints that come from a trusted source.
all_weights = [torch.load(path) for path in checkpoints_weights_paths]

# Start below every possible score so the first candidate is always accepted,
# even when the metric can be zero or negative (e.g. a negated loss).
best_score = float("-inf")
best_weights = []
for w in all_weights:
    # Tentatively add this checkpoint to the ones accepted so far.
    current_weights = best_weights + [w]
    average_dict = average_weights(current_weights)
    model.load_state_dict(average_dict)
    score = evaluate_model(model, ...)
    # Keep the checkpoint only when the averaged model scores strictly better.
    if score > best_score:
        best_score = score
        best_weights.append(w)

# The loop leaves the model holding the last *evaluated* average, which may
# have been rejected; reload the best accepted average so the model ends up
# with the weights that produced best_score.
if best_weights:
    model.load_state_dict(average_weights(best_weights))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment