# Source: GitHub Gist by @bckim92 (gist c7fe090fe943e5da64526afe2a5ba9d5),
# last active March 24, 2019 11:31.
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
import shutil
import math
import bisect
import os
# Module-level logger; CheckpointTracker.update logs through it when it
# copies a newly ranked checkpoint.
logger = logging.getLogger(__name__)
class TopK(object):
    """
    Keep the ``k`` pairs with the largest keys among all ``(key, value)``
    pairs offered to :meth:`update`.

    ``sorted_tuples`` holds the retained pairs in ascending key order (the
    smallest retained key at index 0, the largest at index -1), and
    ``sorted_keys`` mirrors their keys so the insertion point can be found
    with :func:`bisect.bisect_right`.

    Duplicate keys are allowed: a new pair whose key equals existing keys is
    placed after them, so it ranks above them.
    """

    def __init__(self,
                 k: int) -> None:
        # Retained pairs, ascending by key.
        self.sorted_tuples: List[Tuple[Any, Any]] = []
        # Keys of ``sorted_tuples`` in the same order, used for bisection.
        self.sorted_keys: List[Any] = []
        # Maximum number of pairs to retain.
        self.k = k

    def __str__(self):
        from pprint import pformat
        return pformat(self.sorted_tuples, indent=4)

    def __repr__(self):
        return super().__repr__() + "(\n" + self.__str__() + "\n)"

    def update(self, keyvalue: Tuple[Any, Any]) -> Tuple[bool, int]:
        """
        Offer a ``(key, value)`` pair.

        O(log(n)) to locate the slot plus O(n) for the list insert/delete.

        Returns:
            ``(True, rank)`` if the pair is now retained, where ``rank`` is its
            1-based position counted from the largest key (1 = largest).
            ``(False, -1)`` if the structure already holds ``k`` pairs and
            ``key`` is smaller than every retained key (the pair is dropped).
        """
        key, _value = keyvalue
        size = len(self.sorted_tuples)
        insert_location = bisect.bisect_right(self.sorted_keys, key)
        is_full = size >= self.k
        # bisect_right only returns 0 when key is strictly below every
        # retained key; once full, such a pair would be evicted immediately.
        if is_full and insert_location == 0:
            return False, -1
        self.sorted_tuples.insert(insert_location, keyvalue)
        self.sorted_keys.insert(insert_location, key)
        if is_full:
            # Evict the smallest pair to stay at k entries. insert_location > 0
            # here, so the pair just inserted is never the one removed.
            del self.sorted_tuples[0]
            del self.sorted_keys[0]
        # In a list of size+1 the new pair sits at insert_location, i.e.
        # size + 1 - insert_location from the top; removing the bottom
        # element does not change a rank counted from the top.
        return True, size + 1 - insert_location

    def kth_largest(self, k: int) -> Tuple[Any, Any]:
        """
        Return the retained pair with the ``k``-th largest key (``k=1`` is
        the largest). O(1).

        Raises:
            IndexError: if ``k`` is outside ``1..self.k`` or fewer than ``k``
                pairs are currently retained.
        """
        # k == 0 would otherwise index [-0] == [0], i.e. the smallest pair.
        if not 1 <= k <= self.k:
            raise IndexError(f"rank {k} is out of range 1..{self.k}")
        return self.sorted_tuples[-k]
class CheckpointTracker(object):
    """
    Maintain copies of the best-scoring checkpoints in a separate directory.

    Scores are ranked by a ``TopK`` of size ``max_to_keep``. A checkpoint
    that makes the cut is copied from ``checkpoint_path`` (files
    ``{model_name}.ckpt-{step}.index``, ``.data-00000-of-00001`` and
    ``.meta``) into ``checkpoint_path/save_path_name`` as
    ``{rank}th_best.ckpt.*``, alongside a ``{rank}th_info.txt`` file that
    records its step and score. ``rank`` is the value reported by
    ``TopK.update`` (1 for the highest score).
    """

    def __init__(self,
                 checkpoint_path: str,
                 model_name: str = 'model',
                 save_path_name: str = 'best_checkpoints',
                 max_to_keep: int = 1) -> None:
        # Directory the original checkpoints are read from.
        self._src_path = checkpoint_path
        # Directory holding the rank-named copies; created in update().
        self._tgt_path = os.path.join(checkpoint_path, save_path_name)
        self._model_name = model_name
        self._max_to_keep = max_to_keep
        self._tracker_state: TopK = TopK(max_to_keep)

    def update(self, score, step) -> bool:
        """
        Rank ``score`` for checkpoint ``step`` and copy that checkpoint if it
        is among the best ``max_to_keep`` seen so far.

        Stored copies displaced by the new checkpoint are renamed one rank
        down. When the tracker is full, the previously last-ranked copy is
        overwritten. This keeps file names in step with current ranks.

        Returns:
            True if the checkpoint was copied, False otherwise.
        """
        tgt_path = self._tgt_path
        num_tracked_before = len(self._tracker_state.sorted_tuples)
        is_update, kth_largest = self._tracker_state.update((score, step))
        os.makedirs(tgt_path, exist_ok=True)
        if not is_update:
            return False

        logger.info("%s-th best score so far. Copying weights to '%s'.",
                    kth_largest, tgt_path)

        # Previously stored ranks kth_largest..num_tracked_before each slide
        # down by one. A rank that would pass max_to_keep is evicted. Move
        # from the worst rank upwards so every file is renamed before its
        # slot is reused.
        worst_rank_to_shift = min(num_tracked_before, self._max_to_keep - 1)
        for rank in range(worst_rank_to_shift, kth_largest - 1, -1):
            self._move_rank(rank, rank + 1)

        src_fnames = self._get_src_ckpt_name(step)
        tgt_fnames = self._get_tgt_ckpt_name(kth_largest)
        for src_fname, tgt_fname in zip(src_fnames, tgt_fnames):
            shutil.copyfile(src_fname, tgt_fname)
        with open(self._get_tgt_info_name(kth_largest), 'w') as fp:
            fp.write(f"Step: {step}, Score: {score}")
        return True

    def _move_rank(self, src_rank, dst_rank):
        """Rename the stored files of ``src_rank`` to ``dst_rank``, replacing any existing ones."""
        src_fnames = (*self._get_tgt_ckpt_name(src_rank),
                      self._get_tgt_info_name(src_rank))
        dst_fnames = (*self._get_tgt_ckpt_name(dst_rank),
                      self._get_tgt_info_name(dst_rank))
        for src_fname, dst_fname in zip(src_fnames, dst_fnames):
            if os.path.exists(src_fname):
                os.replace(src_fname, dst_fname)
            elif os.path.exists(dst_fname):
                # Nothing to move into this slot (e.g. an earlier copy failed).
                # Drop the stale file so it is not mistaken for this rank.
                os.remove(dst_fname)

    def _get_src_ckpt_name(self, step):
        """Return the paths of the three source checkpoint files for ``step``."""
        model_name = self._model_name
        fname_templates = (
            f'{model_name}.ckpt-{step}.index',
            f'{model_name}.ckpt-{step}.data-00000-of-00001',
            f'{model_name}.ckpt-{step}.meta')
        return tuple(os.path.join(self._src_path, x) for x in fname_templates)

    def _get_tgt_ckpt_name(self, kth_best):
        """Return the paths of the three stored checkpoint files for rank ``kth_best``."""
        fname_templates = (
            f'{kth_best}th_best.ckpt.index',
            f'{kth_best}th_best.ckpt.data-00000-of-00001',
            f'{kth_best}th_best.ckpt.meta')
        return tuple(os.path.join(self._tgt_path, x) for x in fname_templates)

    def _get_tgt_info_name(self, kth_best):
        """Return the path of the step/score info file for rank ``kth_best``."""
        return os.path.join(self._tgt_path, f'{kth_best}th_info.txt')
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment