Skip to content

Instantly share code, notes, and snippets.

@domdomdom12
Last active September 2, 2021 09:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save domdomdom12/97da39baec1a938a0390cd478b85e369 to your computer and use it in GitHub Desktop.
Save domdomdom12/97da39baec1a938a0390cd478b85e369 to your computer and use it in GitHub Desktop.
Simple Graph Colouring model
# simple GCP
import time

import numpy as np
from ortools.sat.python import cp_model
class GraphColouringCP:
    """CP-SAT model that colours a graph with the fewest distinct colours.

    Nodes are the distinct labels found in ``edge_array``. Each row of
    ``edge_array`` contributes an edge between its first two entries.

    Typical use::

        gcp = GraphColouringCP(edges)
        gcp.initialise_model()
        gcp.solve_model()
        colours = gcp.get_node_colours()
    """

    def __init__(
        self,
        edge_array: np.ndarray,
        verbose: bool = True
    ):
        """Store the edges and count the nodes.

        Args:
            edge_array: edges, one per row; columns 0 and 1 hold the endpoint
                labels. Passed through ``np.asarray`` so nested lists work too.
            verbose: if True, ``solve_model`` prints a short summary.
        """
        self.edge_array = np.asarray(edge_array)
        self.verbose = verbose
        self.num_nodes = self.get_num_nodes()

    def get_node_labels(self):
        """Return the distinct node labels in the edge array, sorted, as Python scalars."""
        # np.unique flattens the array; tolist() yields plain Python values so
        # they compare/hash identically to the labels read back from the edges.
        return np.unique(self.edge_array).tolist()

    def get_num_nodes(self):
        """Return the number of distinct node labels in the edge array."""
        return len(self.get_node_labels())

    def initialise_model(self):
        """Build the CP-SAT model minimising the number of colours used.

        Creates one integer colour variable per node, one boolean
        "colour is used" flag per candidate colour, the adjacency
        constraints, and the objective. The model is stored on ``self.model``.
        """
        self.model = cp_model.CpModel()
        # A graph with n nodes never needs more than n colours.
        self.max_colours = self.num_nodes
        # Key the colour variables by the node labels that actually appear in
        # the edges, so labels need not be exactly 0..n-1 (e.g. 1-based ids).
        self.node_colour_variables = {
            node: self.model.NewIntVar(0, self.max_colours - 1, f'node_{node}')
            for node in self.get_node_labels()
        }
        self.colour_used_variables = {
            colour: self.model.NewBoolVar(f'colour_{colour}')
            for colour in range(self.max_colours)
        }
        # An unused colour may not be assigned to any node. The converse is
        # not enforced; minimisation already pushes spare flags down to 0.
        for node_var in self.node_colour_variables.values():
            for colour, used in self.colour_used_variables.items():
                self.model.Add(node_var != colour).OnlyEnforceIf(used.Not())
        # Adjacent nodes must receive different colours.
        for edge in self.edge_array.tolist():
            first, second = edge[0], edge[1]
            self.model.Add(
                self.node_colour_variables[first] != self.node_colour_variables[second]
            )
        self.obj_val_var = self.model.NewIntVar(0, self.max_colours, 'num_distinct_colours')
        # The objective variable equals the number of colours flagged as used.
        self.model.Add(self.obj_val_var == sum(self.colour_used_variables.values()))
        self.model.Minimize(self.obj_val_var)

    def solve_model(self, max_solve_time: int = 60, num_workers: int = 16):
        """Solve the model built by ``initialise_model``.

        Args:
            max_solve_time: solver time limit in seconds.
            num_workers: number of parallel search workers (previously fixed at 16).

        Sets ``self.solver`` and ``self.status``. When ``verbose`` is set,
        prints the objective value (only when a solution was found), the
        elapsed time, and the solve status.
        """
        self.solver = cp_model.CpSolver()
        self.solver.parameters.max_time_in_seconds = max_solve_time
        self.solver.parameters.search_branching = cp_model.FIXED_SEARCH
        self.solver.parameters.num_search_workers = num_workers
        self.solver.parameters.randomize_search = True
        start = time.perf_counter()
        self.status = self.solver.Solve(self.model)
        elapsed = time.perf_counter() - start
        if self.verbose:
            # Only read variable values when the solver actually returned a
            # solution; INFEASIBLE/UNKNOWN results have nothing to read.
            if self.status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
                print(f'obj val: {self.solver.Value(self.obj_val_var)}')
            print(f'Solve time: {elapsed} seconds')
            if self.status == cp_model.OPTIMAL:
                print('Optimal')
            elif self.status == cp_model.FEASIBLE:
                print('Feasible')
            elif self.status == cp_model.INFEASIBLE:
                print('Infeasible')
            else:
                # Raw string: "\_" is not a valid escape sequence.
                print(r'¯\_(ツ)_/¯')

    def get_node_colours(self):
        """Return ``{node_label: colour_index}`` read from the last solve.

        Assumes ``solve_model`` found a solution (status OPTIMAL or FEASIBLE).
        The mapping is also stored on ``self.node_colour_dict``.
        """
        self.node_colour_dict = {
            node: self.solver.Value(colour_var)
            for node, colour_var in self.node_colour_variables.items()
        }
        return self.node_colour_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment