Last active
September 2, 2021 09:22
-
-
Save domdomdom12/97da39baec1a938a0390cd478b85e369 to your computer and use it in GitHub Desktop.
Simple Graph Colouring model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple graph colouring problem (GCP) model built on the OR-Tools CP-SAT solver
import numpy as np | |
from ortools.sat.python import cp_model | |
class GraphColouringCP:
    """Minimum-colour graph colouring formulated as an OR-Tools CP-SAT model.

    Intended call order: construct the object, then ``initialise_model()``,
    ``solve_model()`` and finally ``get_node_colours()``.

    Nodes are identified by the integers ``0 .. num_nodes - 1``; each row of
    ``edge_array`` holds the two endpoints of an edge in columns 0 and 1.
    """

    def __init__(
        self,
        edge_array: np.ndarray,
        verbose: bool = True
    ):
        """Store the edge list and derive the node count from it.

        Args:
            edge_array: array whose rows are edges; columns 0 and 1 are the
                endpoint node indices.
            verbose: when True, ``solve_model`` prints a short solve summary.
        """
        self.edge_array = edge_array
        self.verbose = verbose
        self.num_nodes = self.get_num_nodes()

    def get_num_nodes(self) -> int:
        """Return the number of nodes implied by the edge array.

        Node variables are keyed ``0 .. num_nodes - 1``, so the count has to
        be ``largest index + 1``. Counting the distinct indices instead would
        under-count whenever an index in that range has no incident edge, and
        building the edge constraints would then fail on the missing key.
        An empty edge array yields 0.
        """
        if self.edge_array.size == 0:
            return 0
        return int(self.edge_array.max()) + 1

    def initialise_model(self):
        """Build the CP-SAT model: variables, constraints and objective."""
        self.model = cp_model.CpModel()

        # A graph never needs more colours than it has nodes.
        self.max_colours = self.num_nodes

        # One integer variable per node holding the index of its colour.
        self.node_colour_variables = {
            node_index: self.model.NewIntVar(
                0, self.max_colours - 1, 'node_%i' % node_index
            )
            for node_index in range(self.num_nodes)
        }

        # One boolean per colour: true when the colour may be used.
        self.colour_used_variables = {
            colour_index: self.model.NewBoolVar('colour_%i' % colour_index)
            for colour_index in range(self.max_colours)
        }

        # Link the two families: a colour flagged as unused cannot be
        # assigned to any node.
        for node_var in self.node_colour_variables.values():
            for colour_index, used_var in self.colour_used_variables.items():
                self.model.Add(node_var != colour_index).OnlyEnforceIf(
                    used_var.Not()
                )

        # Adjacent nodes must receive different colours.
        for edge in self.edge_array:
            first_node, second_node = int(edge[0]), int(edge[1])
            self.model.Add(
                self.node_colour_variables[first_node]
                != self.node_colour_variables[second_node]
            )

        # Objective variable: the number of colours flagged as used.
        self.obj_val_var = self.model.NewIntVar(
            0, self.max_colours, 'num_distinct_colours'
        )
        self.model.Add(
            self.obj_val_var == sum(self.colour_used_variables.values())
        )
        self.model.Minimize(self.obj_val_var)

    def solve_model(self, max_solve_time: int = 60):
        """Solve the model built by ``initialise_model``.

        The solver and the resulting status are stored on ``self.solver``
        and ``self.status``. When ``verbose`` is set, a summary is printed.

        Args:
            max_solve_time: solver time limit in seconds.
        """
        # Imported locally because the module-level imports do not include it.
        import time

        # solve model using some standard parameters
        self.solver = cp_model.CpSolver()
        parameters = self.solver.parameters
        parameters.max_time_in_seconds = max_solve_time
        parameters.search_branching = cp_model.FIXED_SEARCH
        parameters.num_search_workers = 16
        parameters.randomize_search = True

        start = time.perf_counter()
        self.status = self.solver.Solve(self.model)
        end = time.perf_counter()

        if self.verbose:
            # The objective value can only be read back when the solver
            # actually holds a solution.
            if self.status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
                print(f'obj val: {self.solver.Value(self.obj_val_var)}')
            print(f'Solve time: {end-start} seconds')
            if self.status == cp_model.OPTIMAL:
                print('Optimal')
            elif self.status == cp_model.FEASIBLE:
                print('Feasible')
            elif self.status == cp_model.INFEASIBLE:
                print('Infeasible')
            else:
                # Backslash escaped so the literal has no invalid escape
                # sequence; the printed text is unchanged.
                print('¯\\_(ツ)_/¯')

    def get_node_colours(self):
        """Return a dict mapping each node index to its solved colour index.

        Reads the values from ``self.solver``, so call it only after
        ``solve_model`` has found a solution. The mapping is also stored on
        ``self.node_colour_dict``.
        """
        self.node_colour_dict = {
            node_index: self.solver.Value(node_colour_var)
            for node_index, node_colour_var in self.node_colour_variables.items()
        }
        return self.node_colour_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment