Skip to content

Instantly share code, notes, and snippets.

@theptrk
Last active May 29, 2024 16:18
Show Gist options
  • Save theptrk/47bb9bff60ef411f3929cd50a7decdcb to your computer and use it in GitHub Desktop.
Save theptrk/47bb9bff60ef411f3929cd50a7decdcb to your computer and use it in GitHub Desktop.
"""Example of a simple nurse scheduling problem."""
from ortools.sat.python import cp_model
class Nurse:
    """A nurse together with a day-by-day availability list.

    Attributes:
        name: Identifier of the nurse; used as part of each shift-variable key.
        start: Per-day availability entries indexed by day, where the value
            "on" marks a day the nurse can be scheduled.
    """

    def __init__(self, name, start):
        # Both values are stored as given; no copy or validation is performed.
        self.name, self.start = name, start
class NursesPartialSolutionPrinter(cp_model.CpSolverSolutionCallback):
    """Print each solution the solver reports and stop after a fixed number."""

    def __init__(self, shifts, all_nurses, all_days, all_shift_types, limit):
        super().__init__()
        self._shifts = shifts
        self._all_nurses = all_nurses
        self._all_days = all_days
        self._all_shift_types = all_shift_types
        self._solution_limit = limit
        self._solution_count = 0

    def on_solution_callback(self):
        """Print the schedule of the current solution, day by day."""
        self._solution_count += 1
        print(f"Solution {self._solution_count}")
        for day in self._all_days:
            print(f"Day {day}")
            for nurse in self._all_nurses:
                name = nurse.name
                # Shift types whose variable is true in the current solution.
                worked = [
                    shift
                    for shift in self._all_shift_types
                    if self.value(self._shifts[(name, day, shift)])
                ]
                for shift in worked:
                    print(f" Nurse {name} works shift {shift}")
                if not worked:
                    print(f" Nurse {name} does not work")
        # Ask the solver to stop once the requested number has been printed.
        if self._solution_count >= self._solution_limit:
            print(f"Stop search after {self._solution_limit} solutions")
            self.stop_search()

    def solutionCount(self):
        """Return how many solutions have been printed so far."""
        return self._solution_count
class Modeler:
    """Build the CP-SAT model for the nurse scheduling example.

    On construction, one boolean variable per (nurse name, day, shift type)
    is stored in ``self.shifts`` and every constraint plus the objective is
    added to ``self.model``.

    Args:
        all_nurses: Objects exposing ``name`` and ``start`` (a per-day
            availability sequence in which ``"on"`` means available).
        all_days: Ordered day indices; used as variable keys and as indices
            into ``start`` and into the rows of ``coverage_needed``.
        all_shift_types: Shift type names. Names starting with ``"E"`` are
            treated as evening shifts, every other name as a day shift.
        coverage_needed: Maps each shift type to a per-day sequence in which
            ``1`` marks a shift that should be staffed.
    """

    def __init__(self, all_nurses, all_days, all_shift_types, coverage_needed):
        self.model = cp_model.CpModel()
        self.all_nurses = all_nurses
        self.all_days = all_days
        self.all_shift_types = all_shift_types
        self.coverage_needed = coverage_needed
        # shifts[(n, d, s)] is true when nurse 'n' works shift 's' on day 'd'.
        self.shifts = {}
        self.create_shift_variables()
        self.enf_employee_availability()
        self.enf_coverage_and_one_nurse()
        self.enf_single_shift_assigned_per_day()
        self.enf_preventing_eve_to_day()
        self.enf_preventing_eve_to_on_unassigned_to_day()
        self.maximum_assigned_shifts()

    def _split_shift_types(self):
        """Return ``(evening_types, day_types)`` based on the "E" name prefix."""
        evening_types = [s for s in self.all_shift_types if s.startswith("E")]
        day_types = [s for s in self.all_shift_types if not s.startswith("E")]
        return evening_types, day_types

    def create_shift_variables(self):
        """Create one boolean variable per (nurse name, day, shift type)."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                for s in self.all_shift_types:
                    self.shifts[(n, d, s)] = self.model.new_bool_var(
                        f"shift_n{n}_d{d}_s{s}"
                    )

    def enf_employee_availability(self):
        """Force every shift variable to 0 on days a nurse is not "on"."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                # Shifts may only be assigned when the nurse is available.
                if nurse_obj.start[d] != "on":
                    for s in self.all_shift_types:
                        self.model.add(self.shifts[(n, d, s)] == 0)

    def enf_coverage_and_one_nurse(self):
        """Allow at most one nurse on a needed shift and none on the others."""
        for d in self.all_days:
            for s in self.all_shift_types:
                every_shift_for_d_s = [
                    self.shifts[(nurse_obj.name, d, s)]
                    for nurse_obj in self.all_nurses
                ]
                if self.coverage_needed[s][d] == 1:
                    # At most one (not exactly one): the shift may be
                    # impossible to fill, and the objective rewards filling it.
                    self.model.add_at_most_one(every_shift_for_d_s)
                else:
                    # No coverage requested: nobody may be assigned.
                    self.model.add(sum(every_shift_for_d_s) == 0)

    def enf_single_shift_assigned_per_day(self):
        """Forbid double booking: at most one shift per nurse per day."""
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            for d in self.all_days:
                self.model.add_at_most_one(
                    self.shifts[(n, d, s)] for s in self.all_shift_types
                )

    def enf_preventing_eve_to_day(self):
        """Forbid a day shift on the day right after an evening shift."""
        evening_types, day_types = self._split_shift_types()
        if not evening_types or not day_types:
            return
        days = list(self.all_days)
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            # Pair each day with its successor; the last day has none.
            for today, tomorrow in zip(days, days[1:]):
                tomorrows_day_shifts = [
                    self.shifts[(n, tomorrow, s_next)] for s_next in day_types
                ]
                for s in evening_types:
                    self.model.add(
                        sum(tomorrows_day_shifts) == 0
                    ).only_enforce_if(self.shifts[(n, today, s)])

    def enf_preventing_eve_to_on_unassigned_to_day(self):
        """Forbid evening shift -> available but unassigned day -> day shift.

        The rule is only posted when the nurse is "on" for the middle day.
        When the middle day is not "on" no constraint is added here.
        """
        evening_types, day_types = self._split_shift_types()
        if not evening_types or not day_types:
            return
        days = list(self.all_days)
        for nurse_obj in self.all_nurses:
            n = nurse_obj.name
            # The pattern spans three consecutive days, so the last two days
            # cannot start it.
            for today, tomorrow, day_after in zip(days, days[1:], days[2:]):
                # The availability that matters is tomorrow's, not today's.
                if nurse_obj.start[tomorrow] != "on":
                    continue
                tomorrows_shifts = [
                    self.shifts[(n, tomorrow, s_tom)]
                    for s_tom in self.all_shift_types
                ]
                day_after_day_shifts = [
                    self.shifts[(n, day_after, s_next)] for s_next in day_types
                ]
                for s in evening_types:
                    # A nurse works at most one shift per day, so the right
                    # side is 0 or 1: when tomorrow is unassigned (0) every
                    # day shift on the following day is forced to 0, and when
                    # tomorrow is assigned (1) nothing extra is imposed.
                    self.model.add(
                        sum(day_after_day_shifts) <= sum(tomorrows_shifts)
                    ).only_enforce_if(self.shifts[(n, today, s)])

    def maximum_assigned_shifts(self):
        """Set the objective: maximize the number of needed shifts filled."""
        shifts_filled = 0
        for d in self.all_days:
            for s in self.all_shift_types:
                if self.coverage_needed[s][d] == 1:
                    every_shift_for_d_s = [
                        self.shifts[(nurse_obj.name, d, s)]
                        for nurse_obj in self.all_nurses
                    ]
                    shifts_filled += sum(every_shift_for_d_s)
        self.model.maximize(shifts_filled)
def get_data():
    """Return the hard-coded example instance.

    Returns:
        A tuple ``(all_nurses, all_days, all_shift_types, coverage_needed)``:
        the list of Nurse objects, ``range(num_days)``, the list of shift
        type names (the keys of ``coverage_needed``), and the dict mapping
        each shift type to its per-day coverage list.

    Raises:
        ValueError: If a nurse's availability list or a coverage row does not
            contain exactly ``num_days`` entries.
    """
    num_days = 3

    # Alternative single-nurse scenario (test case 1), kept for experiments:
    #   all_nurses = [Nurse("Alice", ["on", "on", "on"])]
    #   coverage_needed = {"A": [0, 1, 1], "EVE-A": [1, 0, 0]}

    # Test case 2: three nurses, every shift type needed on every day.
    all_nurses = [
        Nurse("Alice", ["on", "on", "on"]),
        Nurse("Bobby", ["on", "off", "on"]),
        Nurse("Cindy", ["on", "on", "off"]),
    ]
    coverage_needed = {"A": [1, 1, 1], "EVE-A": [1, 1, 1], "B": [1, 1, 1]}

    # Explicit checks instead of assert, which is stripped under `python -O`.
    for n_obj in all_nurses:
        if len(n_obj.start) != num_days:
            raise ValueError(
                f"Nurse {n_obj.name} has {len(n_obj.start)} availability "
                f"entries, expected {num_days}"
            )
    for shift_type, per_day in coverage_needed.items():
        if len(per_day) != num_days:
            raise ValueError(
                f"Shift type {shift_type} has {len(per_day)} coverage "
                f"entries, expected {num_days}"
            )

    all_days = range(num_days)
    # The shift types of a day are exactly the keys of the coverage table.
    all_shift_types = list(coverage_needed)
    return all_nurses, all_days, all_shift_types, coverage_needed
def main() -> None:
    """Build the example model, solve it, and print solutions and statistics."""
    nurses, days, shift_types, coverage = get_data()
    modeler = Modeler(nurses, days, shift_types, coverage)

    # The printer stops the search once this many solutions were displayed.
    solution_limit = 3
    solution_printer = NursesPartialSolutionPrinter(
        modeler.shifts, nurses, days, shift_types, solution_limit
    )

    # Create the solver, configure it, and run it with the printing callback.
    solver = cp_model.CpSolver()
    solver.parameters.linearization_level = 0
    solver.parameters.enumerate_all_solutions = True
    solver.solve(modeler.model, solution_printer)

    # Statistics, printed one line per entry in the original order.
    report = [
        "\nStatistics",
        f"Maximum of objective function: {solver.objective_value}\n",
        f" - conflicts : {solver.num_conflicts}",
        f" - branches : {solver.num_branches}",
        f" - wall time : {solver.wall_time} s",
        f" - solutions found: {solution_printer.solutionCount()}",
    ]
    for line in report:
        print(line)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment