Research Benchmark Framework
# Adapted excerpt from: https://gist.github.com/wolf1986/1759c9b8bcfb09e92faf7e43797e4aef
import io
import traceback
from pprint import pformat
from time import perf_counter
from typing import List, Iterable, Callable, Any

from log_utils.data_logger import DataLogger

import numpy as np
class CaseBase:
    def __init__(self):
        self.info = dict()
        self.title = '[Untitled Case]'

    def __str__(self, *args, **kwargs):
        return self.title

    # noinspection PyMethodMayBeStatic
    def visualize(self):
        return None
class ResultBase:
    def __init__(self):
        self.info = dict()

    def __str__(self):
        return pformat(self.info, compact=True)
class CaseResultBase(ResultBase):
    def __init__(self):
        super().__init__()
        self.case = None  # type: CaseBase
        self.is_successful = False
        self.time_total_ms = 0

    def __str__(self):
        prefix = 'PASS' if self.is_successful else 'FAIL'
        return '{} - {}'.format(prefix, str(self.info))

    # noinspection PyMethodMayBeStatic
    def visualize(self):
        return None
class BenchmarkBase:
    def __init__(
            self,
            case_result_factory: Callable[[], CaseResultBase],
            case_executor: Callable[[CaseBase, CaseResultBase], Any],
            case_result_analyzer: Callable[[CaseResultBase], Any] = None
    ):
        self._case_result_factory = case_result_factory
        self._case_result_analyzer = case_result_analyzer
        self._case_executor = case_executor
        self.logger = DataLogger(self.__class__.__name__)

    def run_case(self, case: CaseBase) -> ResultBase:
        result = self._case_result_factory()
        assert isinstance(result, CaseResultBase)
        result.case = case

        self.logger.info('Case: {}'.format(case))

        # Visualize case
        self.logger.debug('Case ' + str(case), data=case.visualize)

        time_start = perf_counter()
        is_exception_occurred = False

        # noinspection PyBroadException
        try:
            self._case_executor(case, result)
            result.time_total_ms = (perf_counter() - time_start) * 1000

            if self._case_result_analyzer:
                self._case_result_analyzer(result)
        except Exception as e:
            is_exception_occurred = True
            result.info['exception'] = str(e)
            result.info['traceback'] = traceback.format_exc()

        if is_exception_occurred:
            result.time_total_ms = (perf_counter() - time_start) * 1000

        self.logger.info(str(result))
        self.logger.debug('* Run-Time: {:.2f} [sec]'.format(result.time_total_ms / 1000))

        if not is_exception_occurred:
            # Visualize case & results
            self.logger.debug('Case Result ' + str(case), data=result.visualize)

        return result

    def run_cases(self, iter_cases: Iterable[CaseBase]):
        self.logger.info('Executing Benchmark...')

        # Returns a lazy map; callers should materialize it (e.g. with list()) to run the cases.
        # noinspection PyTypeChecker
        return map(self.run_case, iter_cases)
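

# --- Hedged usage sketch (not part of the original gist) ---
# A minimal illustration of how the three callables passed to BenchmarkBase are
# expected to cooperate. The names SortCase, SortResult and execute_sort are
# hypothetical and exist only for this example.
def _example_run_single_case():
    class SortCase(CaseBase):
        def __init__(self, values):
            super().__init__()
            self.title = 'Sort {} values'.format(len(values))
            self.values = values

    class SortResult(CaseResultBase):
        pass

    def execute_sort(case: SortCase, result: SortResult):
        # The executor mutates the result in place; run_case handles timing and exceptions.
        result.info['sorted'] = sorted(case.values)
        result.is_successful = True

    benchmark = BenchmarkBase(
        case_result_factory=SortResult,
        case_executor=execute_sort,
    )
    return benchmark.run_case(SortCase([3, 1, 2]))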
class PrintStream:
    def __init__(self, stream=None):
        if not stream:
            stream = io.StringIO()

        self.stream = stream

    def __call__(self, *args, **kwargs):
        print(*args, file=self.stream, **kwargs)

    def __str__(self):
        return self.stream.getvalue()
def report_case_results(results: List[CaseResultBase], full_stacktrace=False) -> str:
    printf = PrintStream(io.StringIO())
    printf('Case Results:')

    if len(results) == 0:
        printf('* No Cases')

    for result in results:
        if 'exception' in result.info:
            str_exc = result.info['traceback'] if full_stacktrace else result.info['exception']
            result_str = 'Exception: {}'.format(str_exc)
        else:
            result_str = str(result)

        printf('* {} - {}'.format(result.case.title, result_str))

    return printf.stream.getvalue()
def collect_metric(metric_id, case_results) -> np.ndarray:
    return np.array([
        res.info[metric_id]
        for res in case_results
    ])
def format_statistics(metric_iterable, float_format='.2f') -> str:
    metric_array = np.array(list(metric_iterable))
    ff = float_format

    return 'Min: {}; Max: {}; Mean: {}; Median: {}; Std: {};'.format(
        format(metric_array.min(), ff),
        format(metric_array.max(), ff),
        format(metric_array.mean(), ff),
        format(np.median(metric_array), ff),
        format(np.std(metric_array), ff),
    )
def report_result_statistics(benchmark_results: List[CaseResultBase], metric_ids=None) -> str:
    if metric_ids is None:
        metric_ids = []

    printf = PrintStream(io.StringIO())
    printf('Result Statistics:')

    if len(benchmark_results) == 0:
        printf('* No Cases')
        return printf.stream.getvalue()

    results_succeeded = [result for result in benchmark_results if result.is_successful]
    printf('* Success Rate: {:.1f}% ({} / {})'.format(
        len(results_succeeded) / len(benchmark_results) * 100,
        len(results_succeeded), len(benchmark_results)
    ))

    run_times = np.array([result.time_total_ms for result in benchmark_results])
    printf('* run_time_ms: {}'.format(format_statistics(run_times)))

    for metric_id in metric_ids:
        printf('* {}: {}'.format(
            metric_id,
            format_statistics(map(lambda r: r.info[metric_id], benchmark_results))
        ))

    return printf.stream.getvalue()
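

# --- Hedged end-to-end sketch (not part of the original gist) ---
# TimingCase, TimingResult and execute_sum are hypothetical names used only to
# show how run_cases, report_case_results and report_result_statistics fit together.
if __name__ == '__main__':
    class TimingCase(CaseBase):
        def __init__(self, size):
            super().__init__()
            self.title = 'Sum of {} numbers'.format(size)
            self.size = size

    class TimingResult(CaseResultBase):
        pass

    def execute_sum(case: TimingCase, result: TimingResult):
        result.info['total'] = sum(range(case.size))
        result.is_successful = True

    benchmark = BenchmarkBase(case_result_factory=TimingResult, case_executor=execute_sum)

    # run_cases is lazy - materialize it so every case actually executes.
    results = list(benchmark.run_cases(TimingCase(n) for n in (10, 100, 1000)))

    print(report_case_results(results, full_stacktrace=False))
    print(report_result_statistics(results, metric_ids=['total']))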