Created
October 7, 2022 09:33
-
-
Save henhuy/2f94df3f80dd5c8b01c739218df9f82e to your computer and use it in GitHub Desktop.
Suggestion for oemof postprocessing refactoring
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc
from typing import List, Optional, Type, Union
class CalculationError(Exception):
    """Error signalled by the postprocessing calculator framework.

    Used, for example, when a calculation is registered twice or when an
    object that is not a calculation is handed to the calculator.
    """
class Calculator:
    """Registry holding the input data and every calculation run against it.

    Calculations are stored by class name, so each calculation class is
    registered at most once per calculator. Dependent calculations read each
    other's results through :meth:`get_result`.
    """

    def __init__(self, scalar_params, scalar_results, sequences_params, sequences_results):
        self.calculations = {}  # class name -> registered calculation instance
        self.scalar_params = scalar_params
        self.scalar_results = scalar_results
        self.sequences_params = sequences_params
        self.sequences_results = sequences_results

    def add(self, calculation):
        """Register a calculation instance, or instantiate and register a calculation class.

        Parameters
        ----------
        calculation : Calculation or Calculation subclass
            An already-constructed calculation, or a calculation class. A class
            is instantiated with this calculator unless a calculation with the
            same class name is already registered, in which case the call is a
            no-op (this is how shared dependencies are reused).

        Raises
        ------
        CalculationError
            If an instance is added whose class name is already registered, or
            if ``calculation`` is neither a calculation instance nor a
            calculation class.
        """
        if isinstance(calculation, Calculation):
            name = calculation.__class__.__name__
            if name in self.calculations:
                raise CalculationError(f"Calculation '{name}' already exists in calculator")
            self.calculations[name] = calculation
            return

        # Validate before touching ``__name__`` / ``issubclass`` so arbitrary
        # objects produce CalculationError instead of AttributeError/TypeError.
        if not (isinstance(calculation, type) and issubclass(calculation, Calculation)):
            raise CalculationError("Can only add Calculation instances or classes")

        if calculation.__name__ in self.calculations:
            return
        # Constructing a Calculation already registers it via ``add``; the
        # assignment keeps the registration explicit at this call site.
        self.calculations[calculation.__name__] = calculation(self)

    def get_result(self, dependency_name):
        """Return the result of the calculation registered under ``dependency_name``.

        The result is produced by that calculation's ``result`` attribute.

        Raises
        ------
        KeyError
            If no calculation with that name has been registered.
        """
        return self.calculations[dependency_name].result
class Calculation(abc.ABC):
    """Base class for a single postprocessing calculation.

    Subclasses implement :meth:`calculate_result` and may declare the
    calculation classes they need via ``depends_on`` (a single class or a
    list of classes). On construction the instance registers itself with the
    given calculator and makes sure every dependency is registered too. The
    result is computed on first access of :attr:`result` and cached afterwards.
    """

    # Calculation class(es) whose results this calculation needs.
    depends_on: Optional[Union[Type["Calculation"], List[Type["Calculation"]]]] = None

    # Private marker meaning "not computed yet"; unlike ``None`` it can never
    # collide with a legitimate result value.
    __NOT_CALCULATED = object()

    def __init__(self, calculator: "Calculator"):
        super().__init__()
        # Initialise the cache before the instance becomes visible in the calculator.
        self.__result = self.__NOT_CALCULATED
        self.calculator = calculator
        self.calculator.add(self)
        self.__add_dependencies()

    def __add_dependencies(self):
        """Register every class listed in ``depends_on`` with the calculator."""
        if not self.depends_on:
            return
        if isinstance(self.depends_on, list):
            dependencies = self.depends_on
        else:
            dependencies = [self.depends_on]
        for dependency in dependencies:
            self.calculator.add(dependency)

    def dependency(self, index=None):
        """Return the result of one of the declared dependencies.

        Parameters
        ----------
        index : int, optional
            Position in ``depends_on`` when it is a list (defaults to 0).
            Ignored when ``depends_on`` is a single class.

        Raises
        ------
        CalculationError
            If this calculation declares no dependencies.
        """
        if not self.depends_on:
            raise CalculationError(f"Calculation '{type(self).__name__}' has no dependencies")
        if isinstance(self.depends_on, list):
            dependency_name = self.depends_on[index or 0].__name__
        else:
            dependency_name = self.depends_on.__name__
        return self.calculator.get_result(dependency_name)

    @abc.abstractmethod
    def calculate_result(self):
        """Compute and return this calculation's result; must be overridden in subclasses."""

    @property
    def result(self):
        """Result of :meth:`calculate_result`, computed on first access and then cached."""
        # Compare against the sentinel rather than testing truthiness. Falsy
        # results (0, empty containers) would otherwise be recomputed on every
        # access, and objects with an ambiguous truth value (e.g. pandas
        # Series/DataFrame) would raise on the second access.
        if self.__result is self.__NOT_CALCULATED:
            self.__result = self.calculate_result()
        return self.__result

    @property
    def scalar_params(self):
        """Scalar parameters held by the owning calculator."""
        return self.calculator.scalar_params

    @property
    def scalar_results(self):
        """Scalar results held by the owning calculator."""
        return self.calculator.scalar_results

    @property
    def sequences_params(self):
        """Sequence parameters held by the owning calculator."""
        return self.calculator.sequences_params

    @property
    def sequences_results(self):
        """Sequence results held by the owning calculator."""
        return self.calculator.sequences_results
# ----------------------------------------------------------------------------- | |
# EXAMPLE CALCULATIONS | |
class SimpleCalculation(Calculation):
    """Dummy example: scale the ``"bus"`` scalar parameter by five."""

    def calculate_result(self):
        bus_value = self.scalar_params["bus"]
        return bus_value * 5  # Only dummy calculation!
class DependencyCalculation(Calculation):
    """Dummy example: multiply the SimpleCalculation result by the sequence results."""

    depends_on = SimpleCalculation

    def calculate_result(self):
        upstream = self.dependency()
        return upstream * self.sequences_results  # Only dummy calculation!
class MultipleDependenciesCalculation(Calculation):
    """Dummy example: combine the scalar results with both listed dependencies."""

    depends_on = [SimpleCalculation, DependencyCalculation]

    def calculate_result(self):
        # Evaluated strictly left to right, exactly like a chained product.
        partial = self.scalar_results
        partial = partial * self.dependency(0)
        return partial * self.dependency(1)  # Only dummy calculation!
# ----------------------------------------------------------------------------- | |
# EXAMPLE USAGE
if __name__ == "__main__":
    # Positional inputs: scalar_params={"bus": 5}, scalar_results=2,
    # sequences_params=3, sequences_results=4.
    demo_calculator = Calculator({"bus": 5}, 2, 3, 4)
    combined = MultipleDependenciesCalculation(demo_calculator)
    # scalar_results * SimpleCalculation * DependencyCalculation
    # = 2 * (5 * 5) * ((5 * 5) * 4)
    assert combined.result == 5 * 5 * 4 * 5 * 5 * 2
    # Constructing one calculation also registered its two dependencies.
    assert len(demo_calculator.calculations) == 3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment