Created
March 2, 2022 14:54
-
-
Save nhamilakis/d0c5cd1d5e04d6262fb49e9c4e66ee8d to your computer and use it in GitHub Desktop.
A sanitize script for InfTrain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script to sanitize all argument files for infSim model data.
- checkpoint_args.json & args.json: files used to save the arguments used to train the model.
  ::> they contain absolute paths that need to be converted to the current machine's layout after moving the model data.
""" | |
import argparse | |
import json | |
from pathlib import Path | |
from typing import Union, Any | |
def apply(obj, fn):
    """Apply a function to every leaf of a nested list/dict structure.

    Lists and dicts are traversed recursively and updated in place; every
    other value is replaced by ``fn(value)``. Dictionary keys are left
    untouched.

    :param obj: a scalar, or an arbitrarily nested structure of lists/dicts.
    :param fn: callable taking one leaf value and returning its replacement.
    :return: ``obj`` itself (mutated) when it is a list or dict,
        otherwise ``fn(obj)``.
    """
    if isinstance(obj, dict):
        # Re-assigning existing keys does not resize the dict, so this is
        # safe to do while iterating over its items.
        for key, value in obj.items():
            obj[key] = apply(value, fn)
        return obj
    if isinstance(obj, list):
        for index, value in enumerate(obj):
            obj[index] = apply(value, fn)
        return obj
    # Leaf value: hand it to the user function.
    return fn(obj)
def make_path_sanitizer(old_root: Path, new_root: Path):
    """Build a ``(test, sanitize)`` pair for re-rooting paths after a migration.

    :param old_root: root directory the stored paths were created under.
    :param new_root: root directory the data lives under now.
    :return: a tuple ``(test, sanitize)`` where ``test(item)`` tells whether
        an item is a candidate for sanitizing and ``sanitize(path)`` returns
        the re-rooted path as a string.
    """

    def sanitize(path: Union[str, Path], _old: Path = old_root, _new: Path = new_root):
        """Return ``path`` as a string, moved from ``_old`` to ``_new`` when possible.

        Relative paths, and absolute paths that are not located under
        ``_old``, are returned unchanged (as ``str(Path(path))``).
        """
        path = Path(path)  # accept both str and Path inputs
        if path.is_absolute():
            try:
                return str(_new / path.relative_to(_old))
            except ValueError:
                # ``test`` only checks that the old root occurs somewhere in
                # the string, so an absolute path outside the old root can
                # reach this point; leave it as it is instead of crashing.
                pass
        return str(path)

    def test(item: Any, root=str(old_root)):
        """Allow sanitizing only for strings that contain the old root dir."""
        return isinstance(item, str) and root in item

    return test, sanitize
def make_serial_sanitizer(sanitizer_list):
    """Combine several ``(test, sanitize)`` pairs into one sanitizer function.

    :param sanitizer_list: iterable of ``(test, sanitize)`` tuples; ``test``
        decides whether ``sanitize`` should be applied to a value.
    :return: a function taking one item and returning it after all
        applicable sanitizers have been applied, in list order.
    """

    def serial_sanitizer(item):
        clean_item = item
        for test, sanitize in sanitizer_list:
            # Feed each sanitizer the output of the previous one so that the
            # sanitizers really run in series instead of the last match
            # overwriting the results of the earlier ones.
            if test(clean_item):
                clean_item = sanitize(clean_item)
        return clean_item

    return serial_sanitizer
def sanitizer(options: argparse.Namespace):
    """Build the combined sanitizer function for one set of options.

    Reads ``old_model_dir`` and ``new_model_dir`` from ``options``, creates
    the path sanitizer for them and wraps it into a serial sanitizer.
    """
    path_rule = make_path_sanitizer(options.old_model_dir, options.new_model_dir)
    return make_serial_sanitizer([path_rule])
def clean_json(location: Path, sanitizer_fn, debug: bool = False):
    """Run a sanitizer over every JSON file found below ``location``.

    Each ``*.json`` file is loaded, every leaf value is passed through
    ``sanitizer_fn`` and the result is written back with an indent of 2.
    Files that cannot be parsed as JSON are reported and skipped.

    :param location: directory that is searched recursively for JSON files.
    :param sanitizer_fn: function applied to every leaf value of the data.
    :param debug: when true, write the result next to the source file with a
        ``.new.json`` suffix instead of overwriting the source file.
    """
    # Snapshot the matches first: in debug mode new '*.new.json' files are
    # created below `location` and must not be picked up by this same walk.
    for file in list(location.rglob("*.json")):
        # read
        try:
            with file.open(encoding="utf-8") as fp:
                data = json.load(fp)
        except json.decoder.JSONDecodeError:
            print(f"Warn File {file} had issues with json encoder !!")
            # Nothing usable was loaded; continuing would either fail on an
            # unbound `data` or overwrite this file with the previous file's
            # content, so skip it.
            continue

        # clean
        data = apply(data, sanitizer_fn)

        # debug safety: never overwrite the source file while debugging
        target = file.with_suffix('.new.json') if debug else file

        # write
        with target.open('w', encoding="utf-8") as fp:
            json.dump(data, fp, indent=2)
if __name__ == '__main__':
    # Options describing the migration: where the models were trained and
    # where they are stored now.
    migration = argparse.Namespace(
        old_model_dir=Path('/scratch1/projects/InfTrain/models'),
        new_model_dir=Path('/data/infquery/models'),
    )
    cleaners = [sanitizer(migration)]

    # Directory whose JSON files get rewritten.
    where = Path('/data/infquery/models')
    for cleaner in cleaners:
        clean_json(location=where, sanitizer_fn=cleaner)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment