Created
July 27, 2021 10:58
-
-
Save nmichlo/1e2027a47f96b885a094123558d0d1e4 to your computer and use it in GitHub Desktop.
task handler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
import inspect | |
from functools import lru_cache | |
from typing import Any | |
from typing import Dict | |
from typing import Set | |
from typing import Tuple | |
from typing import Union | |
# ========================================================================= #
# Task Builders                                                             #
# - I know this is overkill... but I was having fun...                      #
# ========================================================================= #
# Unique sentinel objects used as *default values* in task function signatures.
# They are only ever compared by identity (`is`):
# - IN:   the parameter is a required input symbol supplied by the caller.
# - TASK: the parameter is the result of another (parent) task of that name.
IN, TASK = object(), object()
@dataclasses.dataclass
class Task:
    """
    Parsed description of a single task function.

    The parameter names are grouped by the kind of default value that each
    parameter declares in the signature of `fn`.
    """

    # the wrapped task function
    fn: callable
    # name of the task, this is also the symbol that its result is stored under
    name: str
    # all parameter names, in signature order
    params: Tuple[str, ...]
    # parameter names whose default is the `IN` sentinel (required inputs)
    params_inputs: Tuple[str, ...]
    # parameter names whose default is the `TASK` sentinel (results of parent tasks)
    params_parents: Tuple[str, ...]
    # parameter names with any other default value (optional symbols)
    params_optional: Tuple[str, ...]
@lru_cache()
def _task_handler_get_task(fn) -> Task:
    """
    Inspect a task function and build its `Task` description.

    The task name is the function name with a leading `task__` or `_task__`
    prefix removed. Every parameter must declare a default value, and is
    classified by that default:
    - `TASK` sentinel -> parent task result
    - `IN` sentinel   -> required input symbol
    - anything else   -> optional symbol

    Results are memoised with `lru_cache`, so `fn` must be hashable.

    :raises ValueError: if the name is empty after removing the prefix.
    :raises RuntimeError: if a parameter has no default value.
    """
    # derive the task name by stripping the first matching prefix
    name = fn.__name__
    for prefix in ('task__', '_task__'):
        if name.startswith(prefix):
            name = name[len(prefix):]
            break
    if not name:
        raise ValueError(f'task function has empty name: {repr(fn.__name__)}')
    # sort the parameters into groups based on their default values
    params, inputs, parents, optional = [], [], [], []
    signature = inspect.signature(fn)
    for arg_name, param in signature.parameters.items():
        default = param.default
        if default is inspect.Parameter.empty:
            raise RuntimeError(f'task {repr(name)} has non-keyword argument: {repr(arg_name)}')
        if default is TASK:
            group = parents
        elif default is IN:
            group = inputs
        else:
            group = optional
        group.append(arg_name)
        params.append(arg_name)
    # package everything as immutable tuples
    return Task(
        fn=fn,
        name=name,
        params=tuple(params),
        params_inputs=tuple(inputs),
        params_parents=tuple(parents),
        params_optional=tuple(optional),
    )
@lru_cache()
def _task_handler_get_parents(
    task_names: Tuple[str, ...],
    task_fns: Tuple[callable, ...],
) -> Tuple[Tuple[callable, ...], Tuple[callable, ...]]:
    """
    Find every task that must run to produce the requested tasks.

    Starting from `task_names`, the parent tasks (parameters that default to
    the `TASK` sentinel) are followed transitively.

    Results are memoised with `lru_cache`, so both arguments must be hashable
    (tuples, not lists).

    :param task_names: names of the tasks whose results are wanted.
    :param task_fns: all the available task functions.
    :return: a pair `(compute, task_fns_minimal)`:
        - `compute`: the functions registered under the required task names.
        - `task_fns_minimal`: the entries of `task_fns` whose task name is
          required, in their original order.
    :raises ValueError: if `task_fns` is empty.
    :raises KeyError: if a requested task, or a parent of a required task,
        is not the name of any function in `task_fns`.
    """
    if not task_fns:
        raise ValueError(f'No task functions were given: {task_fns}')
    # map the task names to their descriptions, if two functions share a name the last one wins
    task_map = {task.name: task for task in (_task_handler_get_task(fn) for fn in task_fns)}
    # check that the requested names are valid, only reporting the names that are actually unknown
    unprocessed, compute = set(task_names), set()
    missing = unprocessed - task_map.keys()
    if missing:
        raise KeyError(f'Specified task names do not exist: {sorted(missing)}, valid task names are: {sorted(task_map.keys())}')
    # add all parents
    while unprocessed:
        name = unprocessed.pop()
        compute.add(name)
        parents = set(task_map[name].params_parents)
        # fail with context instead of a bare KeyError on the lookup in a later iteration
        unknown = parents - task_map.keys()
        if unknown:
            raise KeyError(f'Task {repr(name)} depends on parent tasks that do not exist: {sorted(unknown)}, valid task names are: {sorted(task_map.keys())}')
        # only queue parents that have not been visited, otherwise a task that
        # (directly or indirectly) depends on itself would loop forever.
        # - cyclic dependencies are NOT diagnosed here, this only guarantees termination.
        unprocessed.update(parents - compute)
    # done!
    task_fns_minimal = tuple(fn for fn in task_fns if _task_handler_get_task(fn).name in compute)
    # follow the dict insertion order so that the result is deterministic, iterating
    # over the set of names directly would depend on string hashing.
    compute_fns = tuple(task.fn for name, task in task_map.items() if name in compute)
    return compute_fns, task_fns_minimal
@lru_cache()
def _task_handler_check_arguments(
    compute: Tuple[callable, ...],
    task_fns: Tuple[callable, ...],
    input_symbol_names: Tuple[str, ...],
    strict: bool = True,
    disable_options: bool = True,
):
    """
    Validate the parameters of the task functions against the given symbol names.

    Results are memoised with `lru_cache`, so all arguments must be hashable.

    :param compute: the task functions that are going to be evaluated.
    :param task_fns: all the available task functions.
    :param input_symbol_names: names of the symbols supplied by the caller.
    :param strict: if True, every function in `task_fns` is checked, not only
        those in `compute`.
    :param disable_options: if True, any parameter with a default value other
        than the `IN` or `TASK` sentinels is an error.
    :raises RuntimeError: if any of the checks fail.
    """
    # strict mode checks all parameters even if not needed
    checked_fns = task_fns if strict else compute
    # gather the parameter names of every function that is checked
    inputs, parents, optional = set(), set(), set()
    for fn in checked_fns:
        task = _task_handler_get_task(fn)
        inputs.update(task.params_inputs)
        parents.update(task.params_parents)
        optional.update(task.params_optional)
    # options are not allowed if they are disabled
    if disable_options and optional:
        raise RuntimeError(f'Optional symbols have been disabled: {sorted(optional)}, set `disable_options=False` to skip this error.')
    # a name may only belong to one of the three kinds of symbols
    clash = inputs & parents
    if clash:
        raise RuntimeError(f'An input symbol has the same name as a parent symbol: {sorted(clash)}')
    clash = optional & parents
    if clash:
        raise RuntimeError(f'An optional symbol has the same name as a parent symbol: {sorted(clash)}')
    clash = inputs & optional
    if clash:
        raise RuntimeError(f'An input symbol has the same name as an optional symbol: {sorted(clash)}')
    # compare against the symbols that were actually given
    given = set(input_symbol_names)
    shadowed = given & parents
    if shadowed:
        raise RuntimeError(f'A given argument has the same name as a parent symbol: {sorted(shadowed)}')
    absent = inputs - given
    if absent:
        raise RuntimeError(f'All the required inputs have not been passed as arguments: {sorted(absent)}')
    unexpected = given - (inputs | optional)
    if unexpected:
        raise RuntimeError(f'Invalid arguments were found that are not input or optional symbols: {sorted(unexpected)}')
    # all checks passed, nothing is returned
class TaskHandler(object):
    """
    Evaluates a set of requested tasks together with the parent tasks that
    they depend on.

    The result of each task is stored in the symbol table under the name of
    the task, where later tasks read it as a keyword argument.

    - Task functions are evaluated in the order that they appear in
      `task_fns`, so parents must be listed before their children.
    """

    def __init__(
        self,
        task_names: Union[str, Tuple[str, ...]],
        task_fns: Tuple[callable, ...],
        symbols: Dict[str, Any] = None,
        strict: bool = True,
        disable_options: bool = True
    ):
        """
        :param task_names: a single task name, or a sequence of task names to compute.
        :param task_fns: all the available task functions.
        :param symbols: initial symbols that are passed to the tasks as keyword
            arguments, the mapping is copied and not modified.
        :param strict: check the parameters of every task function, not only
            the functions that are needed.
        :param disable_options: raise an error if a task function has a
            parameter with a default that is not the `IN` or `TASK` sentinel.
        """
        self._task_names_orig = task_names
        self._task_names: Tuple[str, ...] = (task_names,) if isinstance(task_names, str) else tuple(task_names)
        # the helper functions are memoised with `lru_cache` and need hashable
        # arguments, normalising to a tuple means that lists are accepted too.
        self._task_fns: Tuple[callable, ...] = tuple(task_fns)
        # get defaults
        symbols = {} if (symbols is None) else dict(symbols)
        # get compute graph
        compute, self._task_fns_min = _task_handler_get_parents(task_names=self._task_names, task_fns=self._task_fns)
        _task_handler_check_arguments(compute=compute, task_fns=self._task_fns, input_symbol_names=tuple(sorted(symbols.keys())), strict=strict, disable_options=disable_options)
        # dispatch variables
        self._compute: Set[callable] = set(compute)
        self._compute_all: Set[callable] = set(self._task_fns)
        self._symbols: Dict[str, Any] = symbols

    def dispatch_all(self):
        """Dispatch every required task function, in `task_fns` order. Returns `self`."""
        for task in self._task_fns_min:
            self.dispatch(task)
        return self

    def dispatch(self, fn):
        """
        Evaluate a task function if it is required, storing its result under
        the name of the task. Registered functions that are not required are
        silently skipped. Returns `self`.

        :raises RuntimeError: if a parent of the task has not been computed yet.
        :raises KeyError: if `fn` is not one of the registered task functions.
        """
        if fn in self._compute:
            task = _task_handler_get_task(fn)
            # a parent without a stored result would otherwise fall back to its
            # default value, silently passing the `TASK` sentinel into the function.
            missing = [name for name in task.params_parents if name not in self._symbols]
            if missing:
                raise RuntimeError(f'task {repr(task.name)} was dispatched before its parent tasks were computed: {missing}, make sure that parents are listed before their children in `task_fns`.')
            # optional parameters without a given symbol keep their defaults
            kwargs = {name: self._symbols[name] for name in task.params if name in self._symbols}
            self._symbols[task.name] = task.fn(**kwargs)
        elif fn not in self._compute_all:
            raise KeyError(f'tried to dispatch function that has not been registered: {_task_handler_get_task(fn).name}')
        return self

    def result(self):
        """
        Get the results of the requested tasks.

        :return: a single value if `task_names` was given as a string,
            otherwise a tuple with one value per requested task name.
        """
        if isinstance(self._task_names_orig, str):
            return self._symbols[self._task_names_orig]
        else:
            return tuple(self._symbols[name] for name in self._task_names)

    @staticmethod
    def compute(task_names: Union[str, Tuple[str, ...]], task_fns: Tuple[callable, ...], symbols: Dict[str, Any] = None, strict: bool = True, disable_options: bool = True) -> Any:
        """
        Build a handler, evaluate all the required tasks and return the results.
        The parameters are the same as those of the constructor.

        :return: a single value if `task_names` is a string, otherwise a tuple of values.
        """
        # create the compute graph
        handler = TaskHandler(
            task_names=task_names,
            task_fns=task_fns,
            symbols=symbols,
            strict=strict,
            disable_options=disable_options,
        )
        # these may or may not be evaluated!
        handler.dispatch_all()
        # return the results
        return handler.result()
# ========================================================================= #
# Dataset Visualisation / Traversals -- HELPER                              #
# ========================================================================= #
class _TraversalTasks:
    """
    Namespace for the task functions of the traversal pipeline.

    - The task name is the method name without the `task__` prefix.
    - Parameters that default to `IN` are inputs given by the caller.
    - Parameters that default to `TASK` are results of the task with that name.
    """

    @staticmethod
    def task__factor_idxs(gt_data=IN, factor_names=IN):
        """Resolve `factor_names` with `get_factor_idxs`."""
        return get_factor_idxs(gt_data, factor_names)

    @staticmethod
    def task__factors(factor_idxs=TASK, gt_data=IN, seed=IN, base_factors=IN, num=IN, traverse_mode=IN):
        """Sample one random factor traversal per factor index, stacked along axis 0."""
        # NOTE(review): TempNumpySeed presumably seeds numpy only inside this block -- confirm
        with TempNumpySeed(seed):
            traversals = []
            for f_idx in factor_idxs:
                traversal = gt_data.sample_random_factor_traversal(f_idx, base_factors=base_factors, num=num, mode=traverse_mode)
                traversals.append(traversal)
            return np.stack(traversals, axis=0)

    @staticmethod
    def task__raw_grid(factors=TASK, gt_data=IN, data_mode=IN):
        """Get a batch from the dataset for each entry of `factors`, returned as a list."""
        batches = []
        for f in factors:
            batches.append(gt_data.dataset_batch_from_factors(f, mode=data_mode))
        return batches

    @staticmethod
    def task__aug_grid(raw_grid=TASK, augment_fn=IN):
        """Apply `augment_fn` to each batch, or pass `raw_grid` through if it is None."""
        if augment_fn is None:
            return raw_grid
        return [augment_fn(batch) for batch in raw_grid]

    @staticmethod
    def task__grid(aug_grid=TASK):
        """Stack the batches along a new leading axis."""
        return np.stack(aug_grid, axis=0)

    @staticmethod
    def task__image(grid=TASK, num=IN, pad=IN, border=IN, bg_color=IN):
        """Join the grid along axis 0 and lay it out as an image with `num` columns."""
        tiles = np.concatenate(grid, axis=0)
        return make_image_grid(tiles, pad=pad, border=border, bg_color=bg_color, num_cols=num)

    @staticmethod
    def task__animation(grid=TASK, pad=IN, border=IN, bg_color=IN):
        """Build an animated image grid, the number of columns is not specified."""
        frames = np.stack(grid, axis=0)
        return make_animated_image_grid(frames, pad=pad, border=border, bg_color=bg_color, num_cols=None)

    @staticmethod
    def task__image_wandb(image=TASK):
        """Wrap the image in a `wandb.Image`, wandb is only imported when this runs."""
        import wandb
        return wandb.Image(image)

    @staticmethod
    def task__animation_wandb(animation=TASK):
        """Wrap the animation in a 5 fps mp4 `wandb.Video`, wandb is only imported when this runs."""
        import wandb
        # NOTE(review): presumably moves a trailing channel axis in front of the spatial axes -- confirm
        frames = np.transpose(animation, [0, 3, 1, 2])
        return wandb.Video(frames, fps=5, format='mp4')

    @staticmethod
    def task__image_plt(image=TASK):
        """Display the image with `plt_imshow` and return whatever it returns."""
        return plt_imshow(img=image)
# ========================================================================= #
# Dataset Visualisation / Traversals                                        #
# ========================================================================= #
def dataset_traversal_tasks(
    gt_data: Union[GroundTruthData, GroundTruthDataset],
    # task settings
    tasks: Union[str, Tuple[str, ...]] = 'grid',
    # inputs
    factor_names: Optional[NonNormalisedFactors] = None,
    num: int = 9,
    seed: int = 777,
    base_factors=None,
    traverse_mode='cycle',
    # images & animations
    pad: int = 4,
    border: bool = True,
    bg_color: Number = None,
    # augment
    augment_fn: callable = None,
    data_mode: str = 'raw',
):
    """
    Compute one or more stages of the dataset & factor traversal pipeline.
    - Only the requested tasks and the tasks that they depend on are evaluated.
    - If `tasks` is a string, a single result is returned, otherwise a
      tuple with one result per requested task is returned.

    Available tasks:
        - factor_idxs
        - factors
        - raw_grid
        - aug_grid
        - grid
        - image
        - image_wandb
        - image_plt
        - animation
        - animation_wandb
    """
    # make sure that we are always working with a dataset wrapper
    if not isinstance(gt_data, GroundTruthDataset):
        gt_data = GroundTruthDataset(gt_data)
    # all the available tasks, with parents listed before their children
    task_fns = (
        _TraversalTasks.task__factor_idxs,
        _TraversalTasks.task__factors,
        _TraversalTasks.task__raw_grid,
        _TraversalTasks.task__aug_grid,
        _TraversalTasks.task__grid,
        _TraversalTasks.task__image,
        _TraversalTasks.task__animation,
        _TraversalTasks.task__image_wandb,
        _TraversalTasks.task__animation_wandb,
        _TraversalTasks.task__image_plt,
    )
    # the input symbols that are made available to the tasks
    symbols = {
        'gt_data': gt_data,
        # inputs
        'factor_names': factor_names,
        'num': num,
        'seed': seed,
        'base_factors': base_factors,
        'traverse_mode': traverse_mode,
        # animation & images
        'pad': pad,
        'border': border,
        'bg_color': bg_color,
        # augment
        'augment_fn': augment_fn,
        'data_mode': data_mode,
    }
    # strict mode validates the arguments of every task, not only those that are computed
    return TaskHandler.compute(
        task_names=tasks,
        task_fns=task_fns,
        symbols=symbols,
        strict=True,
        disable_options=True,
    )
# ========================================================================= #
# END                                                                       #
# ========================================================================= #
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment