# Source: GitHub Gist stdray/ea378bd9c8362a289ac582f9fa6d3465
# ("Instantly share code, notes, and snippets." — save it to your computer
# and use it in GitHub Desktop.)
import math
from dataclasses import dataclass
from typing import Generic, Iterable, List, TypeVar
# Type parameters for employee-typed generics. The bound is given as the
# string 'Employee' (a forward reference) because these TypeVars are created
# before the Employee class is defined; referencing the class object here
# raised NameError. A bound (rather than constraints Employee/Engineer/Manager)
# matches the hierarchy: Engineer and Manager are subclasses of Employee.
T_co = TypeVar('T_co', bound='Employee', covariant=True)
T_contra = TypeVar('T_contra', bound='Employee', contravariant=True)
@dataclass
class Employee:
    """Base payroll record shared by every kind of employee."""

    # Base salary; every bonus rule in this module is a multiple of it.
    salary: int
    # Years with the company; bonus rules grant an extra component above 5.
    years_in_company: int
@dataclass
class Engineer(Employee):
    """Employee subtype with no extra fields; bonus rules are keyed on 'Engineer'."""
@dataclass
class Manager(Employee):
    """Employee subtype that also records whether KPIs were met."""

    # When true, the manager bonus rules add an extra component.
    has_achieved_kpis: bool
class EmployeeBonusCalculationStrategy(Generic[T_contra]):
    """Interface for computing the bonus of one kind of employee.

    Parameterised by the employee type it accepts. ``T_contra`` is declared
    contravariant, so for static type checkers a strategy for a broader
    employee type is acceptable where a narrower one is expected.
    """

    def calculate_bonus(self, employee: T_contra) -> float:
        """Return the bonus for *employee*.

        The parameter is typed with the class's type parameter rather than a
        concrete subclass, so each subclass narrows it (e.g. to ``Manager``).

        Raises:
            NotImplementedError: always; concrete strategies must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement calculate_bonus()"
        )
class ManagersBonusCalculationStrategy(EmployeeBonusCalculationStrategy[Manager]):
    """Bonus rules for managers.

    The bonus starts at 1.05 x salary. It adds 1.10 x salary when the manager
    has more than 5 years in the company, and 1.20 x salary when KPIs were
    achieved.
    """

    def calculate_bonus(self, employee: Manager) -> float:
        salary = employee.salary
        total = salary * 1.05
        # (condition, multiplier) pairs, applied in this fixed order so the
        # floating-point additions happen in the same sequence every time.
        extras = (
            (employee.years_in_company > 5, 1.10),
            (employee.has_achieved_kpis, 1.20),
        )
        for applies, multiplier in extras:
            if applies:
                total += salary * multiplier
        return total
class EngineerBonusCalculationStrategy(EmployeeBonusCalculationStrategy[Engineer]):
    """Bonus rules for engineers.

    The bonus is 1.10 x salary, plus another 1.10 x salary once the engineer
    has more than 5 years in the company.
    """

    def calculate_bonus(self, employee: Engineer) -> float:
        salary = employee.salary
        result = salary * 1.10
        is_veteran = employee.years_in_company > 5
        if is_veteran:
            result += salary * 1.10
        return result
class SalaryCalculator:
    """Compute total pay (salary plus bonus) for individual employees.

    Bonus strategies are looked up by class name in
    ``bonus_calculation_strategy``.
    """

    # Class-level registry, shared by all instances, keyed by class name.
    bonus_calculation_strategy = {
        'Engineer': EngineerBonusCalculationStrategy(),
        'Manager': ManagersBonusCalculationStrategy(),
    }

    def calculate_salary(self, employee: Employee) -> float:
        """Return ``employee.salary`` plus the employee's bonus."""
        return employee.salary + self.get_employee_bonus(employee)

    def get_employee_bonus(self, employee: Employee) -> float:
        """Return the bonus from the strategy registered for *employee*'s type.

        The lookup walks the employee's class hierarchy (most specific class
        first), so a subclass of ``Engineer`` or ``Manager`` reuses its
        parent's strategy. The previous exact-name lookup raised KeyError
        for such subclasses.

        Raises:
            KeyError: if no class in the employee's MRO has a registered
                strategy (e.g. a plain ``Employee``).
        """
        for cls in type(employee).__mro__:
            strategy = self.bonus_calculation_strategy.get(cls.__name__)
            if strategy is not None:
                return strategy.calculate_bonus(employee)
        raise KeyError(
            f"no bonus strategy registered for {type(employee).__name__!r}"
        )
def calculate_department_salary(employees: Iterable[Employee]) -> float:
    """Return the summed pay (salary plus bonus) of *employees*.

    Args:
        employees: The department's staff. Each must be of a type that
            ``SalaryCalculator`` has a bonus strategy for.

    Returns:
        Total pay; ``0`` for an empty iterable.
    """
    salary_calculator = SalaryCalculator()
    # A generator is enough here; there is no need to build an intermediate list.
    return sum(
        salary_calculator.calculate_salary(employee) for employee in employees
    )
# Example department: three engineers and four managers.
worker1 = Engineer(100, 7)
worker2 = Engineer(120, 4)
worker3 = Engineer(200, 6)
worker4 = Manager(100, 4, False)
worker5 = Manager(120, 6, True)
worker6 = Manager(100, 3, True)
worker7 = Manager(110, 9, True)

department = [worker1, worker2, worker3, worker4, worker5, worker6, worker7]

total_salary = calculate_department_salary(department)
print(total_salary)
# Hand check of salary + bonus per worker:
#   engineers 320 + 252 + 640, managers 205 + 522 + 325 + 478.5 -> 2742.5.
# The previous `total_salary == 2742` could never hold. Compare with a
# tolerance because the bonuses are floating-point products.
assert math.isclose(total_salary, 2742.5)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.