# Research Benchmark Framework
# Adapted excerpt from: https://gist.github.com/wolf1986/1759c9b8bcfb09e92faf7e43797e4aef
import io
import traceback
from pprint import pformat
from time import perf_counter
from typing import List, Iterable, Callable, Any

from log_utils.data_logger import DataLogger
import numpy as np


class CaseBase:
    """Base class for a single benchmark case; subclasses add their own input fields."""

    def __init__(self):
        self.info = dict()
        self.title = '[Untitled Case]'

    def __str__(self, *args, **kwargs):
        return self.title

    # noinspection PyMethodMayBeStatic
    def visualize(self):
        return None


class ResultBase:
    def __init__(self):
        self.info = dict()

    def __str__(self):
        return pformat(self.info, compact=True)


class CaseResultBase(ResultBase):
    def __init__(self):
        super().__init__()
        self.case = None  # type: CaseBase
        self.is_successful = False
        self.time_total_ms = 0

    def __str__(self):
        prefix = 'PASS' if self.is_successful else 'FAIL'
        return '{} - {}'.format(prefix, str(self.info))

    # noinspection PyMethodMayBeStatic
    def visualize(self):
        return None


class BenchmarkBase:
    """Runs benchmark cases: builds a result per case, times the executor, and logs case / result visualizations."""

    def __init__(
            self,
            case_result_factory: Callable[[], CaseResultBase],
            case_executor: Callable[[CaseBase, CaseResultBase], Any],
            case_result_analyzer: Callable[[CaseResultBase], Any] = None
    ):
        self._case_result_factory = case_result_factory
        self._case_result_analyzer = case_result_analyzer
        self._case_executor = case_executor

        self.logger = DataLogger(self.__class__.__name__)

    def run_case(self, case: CaseBase) -> ResultBase:
        result = self._case_result_factory()
        assert isinstance(result, CaseResultBase)
        result.case = case

        self.logger.info('Case: {}'.format(case))

        # Visualize case
        self.logger.debug('Case ' + str(case), data=case.visualize)

        time_start = perf_counter()
        is_exception_occurred = False

        # noinspection PyBroadException
        try:
            self._case_executor(case, result)
            result.time_total_ms = (perf_counter() - time_start) * 1000

            if self._case_result_analyzer:
                self._case_result_analyzer(result)
        except Exception as e:
            is_exception_occurred = True
            result.info['exception'] = str(e)
            result.info['traceback'] = traceback.format_exc()

        if is_exception_occurred:
            result.time_total_ms = (perf_counter() - time_start) * 1000

        self.logger.info(str(result))
        self.logger.debug('* Run-Time: {:.2f} [sec]'.format(result.time_total_ms / 1000))

        if not is_exception_occurred:
            # Visualize case & results
            self.logger.debug('Case Result ' + str(case), data=result.visualize)

        return result
    def run_cases(self, iter_cases: Iterable[CaseBase]):
        self.logger.info('Executing Benchmark...')

        # noinspection PyTypeChecker
        return map(self.run_case, iter_cases)
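
# Note: run_cases returns a lazy map, so cases only execute as the returned
# iterator is consumed - e.g. results = list(benchmark.run_cases(cases)), where
# `benchmark` and `cases` are placeholders for a configured BenchmarkBase and an
# iterable of CaseBase objects.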


class PrintStream:
    """Callable print-like helper that writes to a stream (defaults to an in-memory StringIO)."""

    def __init__(self, stream=None):
        if not stream:
            stream = io.StringIO()

        self.stream = stream

    def __call__(self, *args, **kwargs):
        print(*args, file=self.stream, **kwargs)

    def __str__(self):
        return self.stream.getvalue()


def report_case_results(results: List[CaseResultBase], full_stacktrace=False) -> str:
    printf = PrintStream(io.StringIO())
    printf('Case Results:')

    if len(results) == 0:
        printf('* No Cases')

    for result in results:
        if 'exception' in result.info:
            str_exc = result.info['traceback'] if full_stacktrace else result.info['exception']
            result_str = 'Exception: {}'.format(str_exc)
        else:
            result_str = str(result)

        printf('* {} - {}'.format(result.case.title, result_str))

    return printf.stream.getvalue()


def collect_metric(metric_id, case_results) -> np.ndarray:
    return np.array([
        res.info[metric_id]
        for res in case_results
    ])


def format_statistics(metric_iterable, float_format='.2f') -> str:
    metric_array = np.array(list(metric_iterable))
    ff = float_format

    return 'Min: {}; Max: {}; Mean: {}; Median: {}; Std: {};'.format(
        format(metric_array.min(), ff),
        format(metric_array.max(), ff),
        format(metric_array.mean(), ff),
        format(np.median(metric_array), ff),
        format(np.std(metric_array), ff),
    )


def report_result_statistics(benchmark_results: List[CaseResultBase], metric_ids=None) -> str:
    if metric_ids is None:
        metric_ids = []

    printf = PrintStream(io.StringIO())
    printf('Result Statistics:')

    if len(benchmark_results) == 0:
        printf('* No Cases')
        return printf.stream.getvalue()

    results_succeeded = [result for result in benchmark_results if result.is_successful]
    printf('* Success Rate: {:.1f}% ({} / {})'.format(
        len(results_succeeded) / len(benchmark_results) * 100,
        len(results_succeeded), len(benchmark_results)
    ))

    run_times = np.array([result.time_total_ms for result in benchmark_results])
    printf('* run_time_ms: {}'.format(format_statistics(run_times)))

    for metric_id in metric_ids:
        printf('* {}: {}'.format(
            metric_id,
            format_statistics(map(lambda r: r.info[metric_id], benchmark_results))
        ))

    return printf.stream.getvalue()
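

# Illustrative usage sketch - `SquareCase`, `SquareResult` and `execute_square`
# are hypothetical names, not part of the framework above; logger output depends
# on how DataLogger / standard logging handlers are configured (possibly none).
if __name__ == '__main__':
    class SquareCase(CaseBase):
        def __init__(self, value):
            super().__init__()
            self.value = value
            self.title = 'square({})'.format(value)

    class SquareResult(CaseResultBase):
        pass

    def execute_square(case: SquareCase, result: SquareResult):
        # The executor fills the shared result object; run_case measures its runtime
        result.info['square'] = float(case.value ** 2)
        result.is_successful = True

    benchmark = BenchmarkBase(
        case_result_factory=SquareResult,
        case_executor=execute_square,
    )

    # Consume the iterator returned by run_cases to actually execute all cases
    case_results = list(benchmark.run_cases(SquareCase(v) for v in range(5)))

    print(report_case_results(case_results))
    print(report_result_statistics(case_results, metric_ids=['square']))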