Last active
April 4, 2019 18:29
-
-
Save syxolk/f185b8c1f2ea689d969b6190a5b42f45 to your computer and use it in GitHub Desktop.
OR-Tools nondeterministic solution_hint bug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ortools.sat.python import cp_model | |
# Number of discrete time steps in every Slot / Handler table.
COUNTS = 10
# Lower / upper bound on the summed slot durations of a single Handler.
MIN_VAL, MAX_VAL = 2, 4
class Slot:
    """One optional, contiguous run of used time steps in a CP-SAT model.

    Attributes (all variables are created on ``model``):
        model: the model passed to the constructor.
        start: int var in [-1, COUNTS - 1]; first used step (inclusive),
            forced to -1 when the slot is unused.
        duration: int var in [0, COUNTS]; number of used steps, forced to 0
            when the slot is unused and > 0 when it is used.
        end: int var in [-1, COUNTS]; equals start + duration, i.e. one past
            the last used step (exclusive).
        is_used: bool var; true iff the slot covers at least one step.
        is_used_table: COUNTS bool vars; entry t is true iff step t lies in
            [start, end) of a used slot.
    """

    def __init__(self, model):
        self.model = model
        self.start = model.NewIntVar(-1, COUNTS - 1, name="")  # inclusive
        self.duration = model.NewIntVar(0, COUNTS, "")
        self.end = model.NewIntVar(-1, COUNTS, name="")  # exclusive
        self.is_used = model.NewBoolVar("")
        self.is_used_table = [model.NewBoolVar("") for _ in range(COUNTS)]

        active = self.is_used
        inactive = self.is_used.Not()
        cells = self.is_used_table

        # An unused slot parks at start == -1 with zero duration; a used slot
        # has a real start, a positive duration and a non-empty interval.
        model.Add(self.start == -1).OnlyEnforceIf(inactive)
        model.Add(self.start != -1).OnlyEnforceIf(active)
        model.Add(self.duration == 0).OnlyEnforceIf(inactive)
        model.Add(self.duration > 0).OnlyEnforceIf(active)
        model.Add(self.start < self.end).OnlyEnforceIf(active)
        model.Add(self.end == self.start + self.duration)

        # The first true cell pins `start` (nothing is used before it).
        # These and the `end` constraints below are implied by the counting
        # and bounding constraints further down; presumably they are kept to
        # strengthen propagation.
        for step in range(1, COUNTS):
            leading_zeros = [cell.Not() for cell in cells[:step]]
            model.Add(self.start == step).OnlyEnforceIf(
                leading_zeros + [cells[step]]
            )

        # The last true cell pins `end` (nothing is used after it).
        for step in range(1, COUNTS):
            trailing_zeros = [cell.Not() for cell in cells[step:]]
            model.Add(self.end == step).OnlyEnforceIf(
                trailing_zeros + [cells[step - 1]]
            )

        # A used slot must mark at least one cell, otherwise the table could
        # stay all zeroes ...
        model.Add(sum(cells) >= active)
        # ... and the number of marked cells must equal the duration.
        model.Add(sum(cells) == self.duration)

        # Every marked cell of a used slot lies inside [start, end).
        for step, cell in enumerate(cells):
            model.Add(self.start <= step).OnlyEnforceIf([cell, active])
            model.Add(self.end > step).OnlyEnforceIf([cell, active])
class Handler:
    """A horizon of COUNTS steps covered by up to ``list_count`` ordered Slots.

    Attributes:
        model: the model passed to the constructor.
        is_used_table: COUNTS bool vars; entry t equals the number of slots
            covering step t, so at most one slot may cover any step.
        used_changes: bool vars filled lazily by ``get_used_changes``; entry i
            is true iff ``is_used_table[i + 1] != is_used_table[i]``.
        possible_slots: the Slot objects, with used slots first.
    """

    def __init__(self, model, list_count):
        self.model = model
        self.is_used_table = [model.NewBoolVar(name="") for _ in range(COUNTS)]
        self.used_changes = []
        self.possible_slots = []
        for index in range(list_count):
            current = Slot(model)
            self.possible_slots.append(current)
            if index == 0:
                continue
            previous = self.possible_slots[index - 1]
            # Slots are ordered first to last; unused slots trail at the end,
            # so a slot may only be used if its predecessor is.
            model.Add(current.is_used == 0).OnlyEnforceIf(previous.is_used.Not())
            # When both are used, the next slot starts strictly after the
            # previous one ends (leaving at least one free step between them)
            # and strictly after the previous one starts.
            both_used = [previous.is_used, current.is_used]
            model.Add(current.start > previous.end).OnlyEnforceIf(both_used)
            model.Add(current.start > previous.start).OnlyEnforceIf(both_used)

        # Combine the slot tables: each step is covered by exactly one slot
        # (table entry true) or by none (table entry false).
        for step in range(COUNTS):
            covering = sum(slot.is_used_table[step] for slot in self.possible_slots)
            model.Add(covering == self.is_used_table[step])

    def __count_used_changes(self):
        """Create one bool var per adjacent pair of steps, true iff they differ."""
        if self.used_changes:
            raise ValueError("count_used_changes can only be called once")
        table = self.is_used_table
        for previous, current in zip(table, table[1:]):
            changed = self.model.NewBoolVar("")
            self.model.Add(current != previous).OnlyEnforceIf(changed)
            self.model.Add(current == previous).OnlyEnforceIf(changed.Not())
            self.used_changes.append(changed)

    def get_used_changes(self):
        """Return the per-step change indicators, creating them on first use."""
        if not self.used_changes:
            self.__count_used_changes()
        return self.used_changes
def _append_model_dump(model, path):
    """Append the text form of ``model``'s proto, between separator lines, to ``path``."""
    with open(path, "a", encoding="utf-8") as f:
        f.write("######################################## \n")
        f.write(str(model.Proto()))
        f.write("######################################## \n")


def refine_x_days(n_d, force_solution=False, print_to_file=None):
    """Solve an ``n_d``-handler model, then re-solve it hinted and minimized.

    The first solve looks for any feasible assignment. Its values are then
    written into the model proto as a full ``solution_hint``. The model is
    re-solved minimizing the total number of 0/1 changes across every
    handler's ``is_used_table``.

    Args:
        n_d: Number of Handler instances (each with two Slots) to create.
        force_solution: If true, additionally constrain the total number of
            changes to exactly ``n_d * 2``.
        print_to_file: Optional path prefix. When given, the model proto is
            appended to ``<prefix>_before.txt`` before the first solve and to
            ``<prefix>_after.txt`` before the second.

    Raises:
        AssertionError: if a solve returns neither FEASIBLE nor OPTIMAL, or
            the final number of changes is outside the expected range.
    """
    model = cp_model.CpModel()
    handlers = [Handler(model, 2) for _ in range(n_d)]
    total_changes = []
    for wd in handlers:
        total_changes.extend(wd.get_used_changes())
        # No used step at either end of the horizon, so every handler adds at
        # least two changes.
        model.Add(wd.is_used_table[0] == 0)
        model.Add(wd.is_used_table[-1] == 0)
        # MIN_VAL <= total slot duration <= MAX_VAL. Written as two plain
        # bounds because AddSumConstraint is not available in newer
        # OR-Tools releases.
        used_time = sum(ws.duration for ws in wd.possible_slots)
        model.Add(used_time >= MIN_VAL)
        model.Add(used_time <= MAX_VAL)

    if print_to_file is not None:
        _append_model_dump(model, print_to_file + "_before.txt")

    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = 30
    solver.parameters.num_search_workers = 8
    result = solver.Solve(model)
    # With no objective, CP-SAT reports a found solution as FEASIBLE in older
    # releases and as OPTIMAL in newer ones; accept both.
    assert result in (cp_model.FEASIBLE, cp_model.OPTIMAL)

    # Copy the first solution into the proto as a hint for every variable,
    # in variable-index order.
    proto = model.Proto()
    solution = solver.ResponseProto().solution
    proto.solution_hint.Clear()
    for index in range(len(proto.variables)):
        proto.solution_hint.vars.append(index)
        proto.solution_hint.values.append(solution[index])

    if force_solution:
        model.Add(sum(total_changes) == n_d * 2)
    model.Minimize(sum(total_changes))
    if print_to_file is not None:
        _append_model_dump(model, print_to_file + "_after.txt")

    result = solver.Solve(model)
    print('Problem = ', n_d, '\n', solver.ResponseStats())
    assert result in (cp_model.FEASIBLE, cp_model.OPTIMAL)
    # Each handler's table starts and ends at 0 but must contain a used step
    # (given MIN_VAL > 0), so it contributes at least 2 changes. One
    # contiguous run gives exactly 2; two separated slots give at most 4.
    changes = solver.Value(sum(total_changes))
    if result == cp_model.OPTIMAL:
        assert changes == n_d * 2
    else:
        assert changes <= n_d * 2 * 2
# test_f_*: force_solution=True pins the total number of changes to two per
# handler. test_*: the solver only minimizes the changes.


def test_f_one():
    refine_x_days(1, force_solution=True)


def test_f_two():
    refine_x_days(2, force_solution=True)


def test_f_three():
    refine_x_days(3, force_solution=True)


def test_f_four():
    refine_x_days(4, force_solution=True)


def test_f_five():
    refine_x_days(5, force_solution=True)


def test_f_six():
    refine_x_days(6, force_solution=True)


def test_f_seven():
    refine_x_days(7, force_solution=True)


def test_one():
    refine_x_days(1, force_solution=False)


def test_two():
    refine_x_days(2, force_solution=False)


def test_three():
    refine_x_days(3, force_solution=False)


def test_four():
    refine_x_days(4, force_solution=False)


def test_five():
    refine_x_days(5, force_solution=False)


def test_six():
    refine_x_days(6, force_solution=False)


def test_seven():
    refine_x_days(7, force_solution=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment