Skip to content

Instantly share code, notes, and snippets.

@innocenat
Created January 25, 2020 19:30
Show Gist options
  • Save innocenat/87e5c24d6be25094301188ea5cdb8bca to your computer and use it in GitHub Desktop.
Save innocenat/87e5c24d6be25094301188ea5cdb8bca to your computer and use it in GitHub Desktop.
Just some helper tools I created to help me save all research data, ever.
import json
import os
import re
import time
from os import path
from typing import TextIO, List, Dict, Callable, AnyStr, Any, Tuple, Pattern, Optional
import matplotlib.pyplot as plt
import numpy as np
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy arrays and NumPy scalars.

    Arrays are written as (nested) lists; NumPy integer, floating and boolean
    scalars are written as the matching native Python value. Anything else is
    delegated to :class:`json.JSONEncoder`, which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Called by the ``json`` machinery only for objects it cannot encode
        natively.

        :param obj: the object the standard encoder could not handle.
        :return: a list, int, float or bool equivalent of ``obj``.
        :raises TypeError: if ``obj`` is not a supported NumPy value.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Check the abstract scalar families rather than individual widths so
        # that int8..int64, uint*, float16..float64 and bool_ are all covered.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)
class LRUCache:
    """Placeholder for an in-memory cache.

    The interface is in place but nothing is stored yet: ``cache`` discards
    its arguments and ``read`` always reports a miss.
    """

    def __init__(self):
        """Create the cache; there is currently no state to set up."""

    def cache(self, data_key: str, data: Any, metadata: Any) -> None:
        """Accept an entry for ``data_key`` and discard it (no-op)."""
        return None

    def read(self, data_key: str) -> Any:
        """Look up ``data_key``; always returns ``None`` for now."""
        return None
class Serializer:
    """Reads and writes data payloads as JSON on text streams.

    Metadata is part of the interface but is not supported yet: writing with
    metadata is rejected and reading always yields ``None`` for it.
    """

    @staticmethod
    def write(fp: TextIO, data: Any, metadata: Any) -> None:
        """Dump ``data`` to ``fp`` as JSON, encoding NumPy values on the way.

        :param fp: writable text stream.
        :param data: payload to serialize.
        :param metadata: must be ``None``.
        :raises Exception: if ``metadata`` is anything other than ``None``.
        """
        has_metadata = metadata is not None
        if has_metadata:
            raise Exception('Cannot store metadata yet')
        json.dump(data, fp, cls=NumpyEncoder)

    @staticmethod
    def read(fp: TextIO) -> Tuple[Any, Any]:
        """Parse JSON from ``fp`` and return it as ``(data, metadata)``.

        The metadata slot is always ``None``.
        """
        payload = json.load(fp)
        return (payload, None)
class Provider:
    """Describes a callable that produces one or more named pieces of data.

    Attributes:
        function: callable invoked with a ``DataFlowInstanced`` to do the work.
        provides: data names this provider is allowed to produce.
        dataset:  compiled regex matched (with ``match``, i.e. from the start)
                  against the dataset name, or ``None`` to apply to every
                  dataset.
        depends:  data names this provider may request.
        requires: option names this provider may read.
        consumes: option names dropped from the computed requirement list.
        type:     ``TYPE_DATA`` or ``TYPE_GRAPH``.
    """

    TYPE_DATA = 0  # Default
    TYPE_GRAPH = 1

    function: Optional[Callable[['DataFlowInstanced'], None]]
    provides: List[str]
    dataset: Optional[Pattern[str]]
    depends: List[str]
    requires: List[str]
    consumes: List[str]
    type: int

    def __init__(self, function: Optional[Callable[['DataFlowInstanced'], None]], provides: List[str],
                 dataset: Optional[str] = None, depends: Optional[List[str]] = None,
                 requires: Optional[List[str]] = None, consumes: Optional[List[str]] = None,
                 provider_type: int = 0):
        """Build a provider.

        :param function: callable that performs the work.
        :param provides: data names the callable produces.
        :param dataset: regex source restricting the datasets served, or
            ``None`` for no restriction.
        :param depends: data names the callable may request (default: none).
        :param requires: option names the callable may read (default: none).
        :param consumes: option names to drop from the requirement list
            (default: none).
        :param provider_type: ``Provider.TYPE_DATA`` or ``Provider.TYPE_GRAPH``.
        """
        self.function = function
        self.provides = provides
        self.dataset = re.compile(dataset) if dataset is not None else None
        self.depends = depends if depends is not None else []
        self.requires = requires if requires is not None else []
        self.consumes = consumes if consumes is not None else []
        self.type = provider_type

    def can_provide(self, dataset: str, data_name: str) -> bool:
        """Tell whether this provider serves ``data_name`` for ``dataset``.

        Always returns a real ``bool`` (never a match object or ``None``).
        """
        if data_name not in self.provides:
            return False
        if self.dataset is None:
            return True
        return self.dataset.match(dataset) is not None

    def do_provide(self, flow: 'DataFlowInstanced') -> None:
        """Run the wrapped callable against ``flow``."""
        self.function(flow)

    def required_options_name(self, flow: 'DataFlow') -> List[str]:
        """Collect the option names needed to produce this provider's data.

        The result is this provider's own ``requires`` followed by the
        requirements of every dependency (resolved recursively through
        ``flow.provider``), with duplicates removed and every name listed in
        ``consumes`` dropped.

        :param flow: object whose ``provider(name)`` resolves dependencies.
        :return: ordered list of unique option names.
        :raises Exception: if a dependency has no provider.
        """
        options = list(self.requires)
        for data_name in self.depends:
            provider = flow.provider(data_name)
            if provider is None:
                raise Exception('Cannot find provider for {}'.format(data_name))
            options.extend(provider.required_options_name(flow))
        # The same option can arrive from several dependencies; removing just
        # one occurrence would leave a consumed option still required, so
        # de-duplicate (order-preserving) before filtering.
        consumed = set(self.consumes)
        return [name for name in dict.fromkeys(options) if name not in consumed]

    def is_depend_on(self, data_name: str) -> bool:
        """Tell whether ``data_name`` is a declared dependency."""
        return data_name in self.depends

    def has_option(self, option_name: str) -> bool:
        """Tell whether ``option_name`` is a declared required option."""
        return option_name in self.requires
class DataFlow:
    """Registry of providers plus the on-disk layout for their results.

    Files live at ``<data_directory>/<dataset>/<data_name>[__opt-val...].<ext>``.
    The active dataset starts as ``'__default__'`` and is switched with
    :meth:`dataset`.
    """

    _data_directory: str
    _dataset: str
    providers: List['Provider']
    _remap: Dict[str, str]
    _reverse_remap: Dict[str, str]
    _options: Dict[str, Any]
    _plt_current_fig: Any
    _plt_plot_options: Dict[str, Any]

    def __init__(self, data_directory: str):
        """Create a flow rooted at ``data_directory`` with no providers."""
        self._data_directory = data_directory
        self._dataset = '__default__'
        self.providers = []
        self._remap = {}
        self._reverse_remap = {}
        self._options = {}
        self._plt_current_fig = None
        self._plt_plot_options = {}

    def add_provider(self, provider: 'Provider') -> None:
        """Register ``provider``; earlier registrations win on lookup."""
        self.providers.append(provider)

    def provider(self, data_name: str) -> Optional['Provider']:
        """Return the first provider serving ``data_name`` in the active dataset.

        ``data_name`` is first translated through the remap table.
        Returns ``None`` when no registered provider matches.
        """
        data_name = self._actual_data_name(data_name)
        for candidate in self.providers:
            if candidate.can_provide(self._dataset, data_name):
                return candidate
        return None

    def dataset(self, dataset: str) -> None:
        """Switch the active dataset and make sure its directory exists."""
        self._dataset = dataset
        data_path = "{}/{}".format(self._data_directory, dataset)
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(data_path, exist_ok=True)

    def options(self, options: Dict[str, Any]) -> None:
        """Replace the option values used by top-level requests."""
        self._options = options

    def remap(self, src: str, dst: str) -> None:
        """Make requests for ``dst`` resolve to the data name ``src``."""
        self._remap[dst] = src
        self._reverse_remap[src] = dst

    def dataset_match(self, regexp: Pattern[AnyStr]) -> bool:
        """Tell whether the active dataset name matches ``regexp`` (``re.match``)."""
        return re.match(regexp, self._dataset) is not None

    def request(self, data_name: str, data_options: Dict[str, Any] = None) -> Any:
        """Request ``data_name`` from a root instance that has no provider.

        :param data_name: name of the data to obtain.
        :param data_options: per-request overrides merged over the flow options.
        :return: whatever the root instance's ``request`` returns.
        """
        return DataFlowInstanced(self, None, self._options).request(data_name, data_options)

    def plot(self, data_name: str, data_options: Dict[str, Any] = None):
        """Run the plotter for ``data_name``, save the figure as EPS and show it.

        :raises Exception: if the plotter never obtained a figure through
            :meth:`get_plt_axe`.
        """
        plt.clf()
        self._plt_current_fig = None
        self.request(data_name, data_options)
        fig = self._plt_current_fig
        if fig is None:
            raise Exception('Plotter {} did not plot'.format(data_name))
        target = self.filepath(data_name, self._plt_plot_options, 'eps')
        # The dataset directory only exists if dataset() was called; create it
        # here so plotting into the default dataset does not fail on save.
        os.makedirs(path.dirname(target), exist_ok=True)
        fig.savefig(target, dpi=fig.dpi)
        plt.show()

    def get_plt_axe(self, nrows=1, ncols=1):
        """Create a figure via ``plt.subplots`` and remember it for :meth:`plot`.

        :return: the ``(fig, ax)`` pair from ``plt.subplots(nrows, ncols)``.
        """
        fig, ax = plt.subplots(nrows, ncols)
        self._plt_current_fig = fig
        return fig, ax

    def _actual_data_name(self, data_name: str) -> str:
        """Translate ``data_name`` through the remap table, if it is mapped."""
        return self._remap.get(data_name, data_name)

    def filepath(self, data_name: str, data_options: Dict[str, Any], ext: str = 'dat') -> str:
        """Build the file path for ``data_name`` under the active dataset.

        Each option contributes a ``key-value`` fragment; fragments are sorted
        and joined with ``__`` so the same options always map to the same file.

        :param data_name: base name of the file.
        :param data_options: option values encoded into the name; ``None`` is
            treated as no options.
        :param ext: file extension without the dot.
        """
        fragments = sorted('{}-{}'.format(k, v) for k, v in (data_options or {}).items())
        option_string = ''
        if fragments:
            option_string = '__' + '__'.join(fragments)
        return "{}/{}/{}{}.{}".format(self._data_directory, self._dataset, data_name, option_string, ext)
class DataFlowInstanced:
    """Execution context handed to a provider while it runs.

    It binds a flow, the provider being executed (``None`` for the root
    context created by ``DataFlow.request``) and the option values resolved
    for that provider, and enforces that the provider only touches the data
    and options it declared.
    """

    _flow: 'DataFlow'
    _provider: Optional['Provider']
    _options: Dict[str, Any]
    # Plotting environment
    is_graphics: bool

    def __init__(self, flow: 'DataFlow', provider: Optional['Provider'], options: Dict[str, Any]):
        """Bind ``flow``, ``provider`` and ``options`` into a context.

        For a graph provider a copy of ``options`` is also recorded on the
        flow as the options of the current plot.
        """
        self._flow = flow
        self._provider = provider
        self._options = options
        self.is_graphics = provider is not None and provider.type == Provider.TYPE_GRAPH
        if self.is_graphics:
            flow._plt_plot_options = dict(options)

    def store(self, data_name: str, data: Any) -> None:
        """Persist ``data`` as the result for ``data_name``.

        :raises Exception: if there is no provider, the provider does not
            declare ``data_name``, or the provider is a graph provider.
        """
        if self._provider is None:
            raise Exception('Cannot store data without associated provider')
        if data_name not in self._provider.provides:
            raise Exception('Provider for {} cannot provide data {}'.format(self._provider.provides, data_name))
        if self.is_graphics:
            raise Exception('Cannot store graphics data {}'.format(data_name))
        filepath = self._flow.filepath(data_name, self._options)
        # The dataset directory may not exist yet (e.g. the default dataset).
        os.makedirs(path.dirname(filepath), exist_ok=True)
        # Write to a sibling temp file and swap it in, so a serialization
        # failure never leaves a truncated file that later loads would read.
        tmp_path = filepath + '.tmp'
        try:
            with open(tmp_path, "w") as fp:
                Serializer.write(fp, data, None)
        except Exception:
            if path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.replace(tmp_path, filepath)

    def _load(self, data_name: str, data_options: Dict[str, Any]) -> Any:
        """Read a previously stored result, or return ``None`` if absent.

        :raises Exception: if the bound provider did not declare
            ``data_name`` as a dependency.
        """
        if self._provider is not None and not self._provider.is_depend_on(data_name):
            raise Exception('Data {} cannot have dependency on {}'.format(self._provider.provides, data_name))
        filepath = self._flow.filepath(data_name, data_options)
        if not path.isfile(filepath):
            return None
        with open(filepath, "r") as fp:
            data, _metadata = Serializer.read(fp)
        return data

    def get_plt(self, nrows=1, ncols=1):
        """Return a ``(fig, ax)`` pair from the flow for a graph provider.

        :raises Exception: if this context is not for a graph provider.
        """
        if not self.is_graphics:
            raise Exception('Cannot get plotting environment for non-graphics data')
        return self._flow.get_plt_axe(nrows, ncols)

    def request(self, data_name: str, data_options: Dict[str, Any] = None) -> Any:
        """Obtain ``data_name``, running its provider when nothing is stored.

        Options are this context's options overlaid with ``data_options``;
        only those the target provider requires are passed on. Data providers
        are skipped when a stored result exists; graph providers always run.

        :return: the loaded data, or ``None`` for graph providers.
        :raises Exception: on an undeclared dependency, an unknown data name,
            or a required option that is missing.
        """
        if self._provider is not None and not self._provider.is_depend_on(data_name):
            raise Exception('Data {} cannot have dependency on {}'.format(self._provider.provides, data_name))
        current_options = self._options.copy()
        if data_options is not None:
            current_options.update(data_options)
        provider = self._flow.provider(data_name)
        if provider is None:
            raise Exception('Cannot find provider for data "{}"'.format(data_name))
        target_options = {}
        for target_option_name in provider.required_options_name(self._flow):
            if target_option_name not in current_options:
                raise Exception("Option {} not provided for data {}".format(target_option_name, data_name))
            target_options[target_option_name] = current_options[target_option_name]
        is_graph = provider.type == Provider.TYPE_GRAPH
        if not is_graph:
            data = self._load(data_name, target_options)
            if data is not None:
                # TODO Validate last modified chain
                return data
        sub_instance = DataFlowInstanced(self._flow, provider, target_options)
        print('Executing {}...'.format(data_name))
        t0 = time.time()
        provider.do_provide(sub_instance)
        print('Elapsed ({}): {} seconds.'.format(data_name, time.time() - t0))
        if is_graph:
            return None
        return self._load(data_name, target_options)

    def requires(self, options_list: List[str]) -> List[Any]:
        """Return the values of several options, in the order requested."""
        return [self.option(name) for name in options_list]

    def option(self, option_name: str) -> Any:
        """Return the value of ``option_name``.

        :raises Exception: if the bound provider did not declare the option,
            or no value was supplied for it.
        """
        if self._provider is not None and not self._provider.has_option(option_name):
            raise Exception('Data {} cannot have option {}'.format(self._provider.provides, option_name))
        if option_name not in self._options:
            raise Exception('Option {} not provided.'.format(option_name))
        return self._options[option_name]
if __name__ == '__main__':
    # Usage guide: a three-stage pipeline (base data -> scaled data -> plot).

    def make_base(container: DataFlowInstanced):
        """Store a fixed list as 'data-1'."""
        base_values = [1, 2, 3, 4, 5]
        container.store('data-1', base_values)

    def make_scaled(container: DataFlowInstanced):
        """Scale 'data-1' by the 'multiple' option and store it as 'data-2'."""
        source = container.request('data-1')
        factor = container.option('multiple')
        container.store('data-2', [value * factor for value in source])

    def draw(container: DataFlowInstanced):
        """Plot 'data-2' against its indices."""
        fig, ax = container.get_plt()
        series = container.request('data-2')
        positions = list(range(len(series)))
        ax.plot(positions, series)

    flow = DataFlow('/tmp/dataset')
    flow.dataset('dataset-1')

    # Register each stage together with what it depends on and which options it reads.
    stages = [
        Provider(function=make_base, provides=['data-1']),
        Provider(function=make_scaled, provides=['data-2'],
                 depends=['data-1'], requires=['multiple']),
        Provider(function=draw, provides=['plt-data'],
                 depends=['data-2'], provider_type=Provider.TYPE_GRAPH),
    ]
    for stage in stages:
        flow.add_provider(stage)

    flow.options({'multiple': 3})
    flow.plot('plt-data')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment