Skip to content

Instantly share code, notes, and snippets.

@nhamilakis
Created March 2, 2022 14:54
Show Gist options
  • Save nhamilakis/d0c5cd1d5e04d6262fb49e9c4e66ee8d to your computer and use it in GitHub Desktop.
Save nhamilakis/d0c5cd1d5e04d6262fb49e9c4e66ee8d to your computer and use it in GitHub Desktop.
A sanitize script for InfTrain
"""
Script to sanitize all JSON argument files for infSim model data.
- checkpoint_args.json & args.json: files that store the arguments used to train the model.
::> they contain absolute paths that must be rewritten for the current machine after the model data has been moved.
"""
import argparse
import json
from pathlib import Path
from typing import Union, Any
def apply(obj, fn):
    """Apply ``fn`` to every leaf of a nested dict/list structure, recursively.

    Dicts and lists are updated in place: each value/element is replaced by
    its transformed version, and the same container object is returned.
    Any other object is treated as a leaf and replaced by ``fn(obj)``.

    :param obj: a (possibly nested) structure of dicts and lists, or a leaf value.
    :param fn: callable applied to each leaf value.
    :return: the transformed structure (the same object when ``obj`` is a dict/list).
    """
    if isinstance(obj, dict):
        # Re-assigning existing keys while iterating is safe: the dict size never changes.
        for key, value in obj.items():
            obj[key] = apply(value, fn)
        return obj
    if isinstance(obj, list):
        for index, value in enumerate(obj):
            obj[index] = apply(value, fn)
        return obj
    # Leaf value (str, number, bool, None, ...).
    return fn(obj)
def make_path_sanitizer(old_root: Path, new_root: Path):
    """Build a ``(test, sanitize)`` closure pair that relocates paths after a migration.

    ``test(item)`` selects candidate items: strings containing ``old_root``.
    ``sanitize(path)`` rewrites an absolute path located under ``old_root`` so
    that it points to the same relative location under ``new_root``.
    Relative paths, and absolute paths that are not under ``old_root``, are
    returned as (pathlib-normalised) strings instead of raising.

    :param old_root: root directory the data was moved from.
    :param new_root: root directory the data was moved to.
    :return: tuple ``(test, sanitize)``.
    """
    def sanitize(path: Union[str, Path], _old: Path = old_root, _new: Path = new_root) -> str:
        """Return ``path`` re-rooted from ``_old`` to ``_new`` when it lives under ``_old``."""
        path = Path(path)  # accept both str and Path inputs
        if path.is_absolute():
            try:
                return str(_new / path.relative_to(_old))
            except ValueError:
                # `test` only does a substring check, so absolute paths such as
                # ".../models_old/x" or "/mnt<old_root>/x" can reach this point;
                # leave them unchanged rather than aborting the whole run.
                pass
        return str(path)

    def test(item: Any, root=str(old_root)) -> bool:
        """Allow sanitizing only of strings that contain the old root dir."""
        return isinstance(item, str) and root in item

    return test, sanitize
def make_serial_sanitizer(sanitizer_list):
    """Chain several ``(test, sanitize)`` pairs into a single sanitizer function.

    The pairs are applied in order, each one to the output of the previous
    one, so the sanitizers compose: later sanitizers see (and may further
    rewrite) the value produced by earlier ones.

    :param sanitizer_list: iterable of ``(test, sanitize)`` callable pairs.
    :return: a function mapping an item to its sanitized version.
    """
    def serial_sanitizer(item):
        clean_item = item
        for test, sanitize in sanitizer_list:
            # Test and rewrite the current value, not the original item;
            # otherwise each matching sanitizer discards the previous results.
            if test(clean_item):
                clean_item = sanitize(clean_item)
        return clean_item
    return serial_sanitizer
def sanitizer(options: argparse.Namespace):
    """Build the combined sanitizer described by ``options``.

    ``options`` must provide ``old_model_dir`` and ``new_model_dir``. A path
    sanitizer relocating the former onto the latter is wrapped in a serial
    sanitizer, which is returned.
    """
    path_pair = make_path_sanitizer(options.old_model_dir, options.new_model_dir)
    return make_serial_sanitizer([path_pair])
def clean_json(location: Path, sanitizer_fn, debug: bool = False):
    """Run ``sanitizer_fn`` over every leaf value of every ``*.json`` file under ``location``.

    Files are searched recursively. Each file is parsed, every leaf value is
    passed through ``sanitizer_fn`` (via ``apply``), and the result is written
    back with 2-space indentation. Files that fail to parse are reported and
    skipped.

    :param location: root directory to search.
    :param sanitizer_fn: callable mapping a JSON leaf value to its cleaned value.
    :param debug: when True, write results to ``<name>.new.json`` next to each
        file instead of overwriting the original.
    """
    # Materialise the listing up front: debug mode creates new *.json files
    # inside the tree being walked, and they must not be picked up mid-walk.
    for file in sorted(location.rglob("*.json")):
        # read
        with file.open(encoding="utf-8") as fp:
            try:
                data = json.load(fp)
            except (json.decoder.JSONDecodeError, UnicodeDecodeError):
                print(f"Warn File {file} had issues with json encoder !!")
                # Skip this file: `data` would otherwise be undefined or still hold
                # the previous file's content, which would then overwrite this file.
                continue
        # clean
        data = apply(data, sanitizer_fn)
        # debug safety: keep the original file untouched
        f = file.with_suffix('.new.json') if debug else file
        # write
        with f.open('w', encoding="utf-8") as fp:
            json.dump(data, fp, indent=2)
if __name__ == '__main__':
    # Defaults reproduce the original hard-coded migration, so running the
    # script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Rewrite absolute model paths stored in JSON argument files."
    )
    parser.add_argument('--old-model-dir', type=Path,
                        default=Path('/scratch1/projects/InfTrain/models'),
                        help="root directory the model data was moved from")
    parser.add_argument('--new-model-dir', type=Path,
                        default=Path('/data/infquery/models'),
                        help="root directory the model data was moved to")
    parser.add_argument('--location', type=Path, default=None,
                        help="directory to scan for *.json files (defaults to --new-model-dir)")
    parser.add_argument('--debug', action='store_true',
                        help="write <name>.new.json files instead of overwriting originals")
    args = parser.parse_args()

    cleaners = [sanitizer(args)]
    # The files live in the new location after the move, so scan there by default.
    where = args.location if args.location is not None else args.new_model_dir
    for cl in cleaners:
        clean_json(location=where, sanitizer_fn=cl, debug=args.debug)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment