Last active
March 24, 2019 11:31
-
-
Save bckim92/c7fe090fe943e5da64526afe2a5ba9d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict, List, Optional, Tuple, Union | |
import logging | |
import shutil | |
import math | |
import bisect | |
import os | |
logger = logging.getLogger(__name__) | |
class TopK(object):
    """
    Maintain the ``k`` largest ``(key, value)`` pairs seen so far.

    Pairs are stored in ascending key order in two parallel lists so that
    ``bisect`` can locate the insertion point on the keys alone.
    Duplicate keys are allowed; a new pair is placed after existing pairs
    with an equal key, so it ranks ahead of them.
    """

    def __init__(self,
                 k: int) -> None:
        # Kept pairs, ascending by key (the smallest kept key sits at index 0).
        self.sorted_tuples: List[Tuple[Any, Any]] = []
        # Keys of ``sorted_tuples`` in the same order; this list is bisected.
        self.sorted_keys: List[Any] = []
        self.k = k

    def __str__(self):
        from pprint import pformat
        return pformat(self.sorted_tuples, indent=4)

    def __repr__(self):
        return super().__repr__() + "(\n" + self.__str__() + "\n)"

    def update(self, keyvalue: Tuple[Any, Any]) -> Tuple[bool, int]:
        """
        Offer a ``(key, value)`` pair to the tracker.

        The insertion point is found in O(log(n)) with ``bisect``; the list
        insertion itself is O(n).

        Returns:
            ``(is_update, kth_largest)``. ``is_update`` is True when the pair
            was kept. ``kth_largest`` is then its 1-based rank among the kept
            pairs (1 = largest key); it is -1 when the pair was rejected.
        """
        key, _ = keyvalue
        list_size = len(self.sorted_tuples)
        insert_location = bisect.bisect_right(self.sorted_keys, key)
        is_full = list_size >= self.k

        # A full tracker only accepts keys that are not below its current
        # minimum. This also rejects everything when k <= 0.
        if is_full and insert_location == 0:
            return False, -1

        self.sorted_tuples.insert(insert_location, keyvalue)
        self.sorted_keys.insert(insert_location, key)
        if is_full:
            # Evict the smallest pair to get back to k entries.
            del self.sorted_tuples[0]
            del self.sorted_keys[0]

        # ``list_size - insert_location`` kept keys are larger than the new
        # one. Evicting the minimum never changes that count.
        kth_largest = list_size + 1 - insert_location
        return True, kth_largest

    def kth_largest(self, k: int) -> Tuple[Any, Any]:
        """
        Return the pair with the ``k``-th largest key (1-based). O(1).

        Raises:
            IndexError: if ``k`` is smaller than 1 or larger than the number
                of pairs currently stored.
        """
        assert k <= self.k
        if k < 1:
            # ``-0 == 0`` would silently return the smallest pair instead.
            raise IndexError(f"k must be at least 1, got {k}")
        return self.sorted_tuples[-k]
class CheckpointTracker(object):
    """
    Maintain copies of the best-scoring checkpoints.

    Each call to ``update`` offers a ``(score, step)`` pair to a ``TopK``
    tracker. When the pair is accepted, the three checkpoint files of that
    step are copied from ``checkpoint_path`` into
    ``checkpoint_path/save_path_name`` under rank-based names
    (``1th_best.ckpt.*``, ``2th_best.ckpt.*``, ...) together with a
    ``<rank>th_info.txt`` file recording the step and score.
    """

    def __init__(self,
                 checkpoint_path: str,
                 model_name: str = 'model',
                 save_path_name: str = 'best_checkpoints',
                 max_to_keep: int = 1) -> None:
        self._src_path = checkpoint_path
        self._tgt_path = os.path.join(checkpoint_path, save_path_name)
        self._model_name = model_name
        self._max_to_keep = max_to_keep
        self._tracker_state: TopK = TopK(max_to_keep)

    def update(self, score, step) -> bool:
        """
        Offer the checkpoint written at ``step`` with the given ``score``.

        Returns:
            True when the checkpoint was accepted and its files were copied
            into the target directory, False otherwise.
        """
        tgt_path = self._tgt_path
        is_update, kth_largest = self._tracker_state.update((score, step))
        os.makedirs(tgt_path, exist_ok=True)
        if not is_update:
            return False

        # Single-line message: a backslash continuation inside the literal
        # would embed the source indentation in the log output.
        logger.info("%s-th best score so far. Copying weights to '%s'.",
                    kth_largest, tgt_path)

        # Make room first so the files currently stored at this rank (and
        # below) keep describing the right checkpoints.
        self._shift_lower_ranks(kth_largest)

        src_fnames = self._get_src_ckpt_name(step)
        tgt_fnames = self._get_tgt_ckpt_name(kth_largest)
        for src_fname, tgt_fname in zip(src_fnames, tgt_fnames):
            shutil.copyfile(src_fname, tgt_fname)
        with open(self._get_info_name(kth_largest), 'w') as fp:
            fp.write(f"Step: {step}, Score: {score}")
        return True

    def _shift_lower_ranks(self, kth_best):
        """Rename stored files of ranks ``kth_best..max_to_keep - 1`` one rank down."""
        # Walk from the worst rank upwards so no file is overwritten before
        # it has been moved; whatever sat at rank ``max_to_keep`` is dropped.
        for rank in range(self._max_to_keep - 1, kth_best - 1, -1):
            old_names = self._get_tgt_ckpt_name(rank) + [self._get_info_name(rank)]
            new_names = (self._get_tgt_ckpt_name(rank + 1)
                         + [self._get_info_name(rank + 1)])
            for old_name, new_name in zip(old_names, new_names):
                # Ranks that were never written have no files to move.
                if os.path.exists(old_name):
                    os.replace(old_name, new_name)

    def _get_info_name(self, kth_best):
        """Return the path of the info file for rank ``kth_best``."""
        return os.path.join(self._tgt_path, f'{kth_best}th_info.txt')

    def _get_src_ckpt_name(self, step):
        """Return the paths of the three source checkpoint files for ``step``."""
        model_name = self._model_name
        fname_templates = (
            f'{model_name}.ckpt-{step}.index',
            f'{model_name}.ckpt-{step}.data-00000-of-00001',
            f'{model_name}.ckpt-{step}.meta')
        return [os.path.join(self._src_path, fname) for fname in fname_templates]

    def _get_tgt_ckpt_name(self, kth_best):
        """Return the paths of the three target files for rank ``kth_best``."""
        fname_templates = (
            f'{kth_best}th_best.ckpt.index',
            f'{kth_best}th_best.ckpt.data-00000-of-00001',
            f'{kth_best}th_best.ckpt.meta')
        return [os.path.join(self._tgt_path, fname) for fname in fname_templates]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment