Last active
September 2, 2021 09:21
Star
You must be signed in to star a gist
Complex Graph Colouring model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Souped-up graph colouring problem (GCP) model using the OR-Tools CP-SAT solver.
import time

import numpy as np
from ortools.sat.python import cp_model
class GraphColouringCP:
    """Minimum graph colouring model built on the OR-Tools CP-SAT solver.

    Each node gets an integer colour variable and each candidate colour gets a
    Boolean "used" flag. The model forbids adjacent nodes from sharing a colour
    and minimises the number of colours flagged as used.

    Typical use: ``initialise_model()``, then ``solve_model()``, then
    ``get_node_colours()``.
    """

    def __init__(
        self,
        edge_array: np.ndarray,
        use_symmetry_breaking: bool = False,
        use_node_ordering: bool = False,
        use_greedy_bounding: bool = False,
        verbose: bool = True,
    ):
        """Store the graph and the model options.

        Args:
            edge_array: array with one edge per row; columns 0 and 1 hold the
                two endpoint node labels (integers). Labels need not be
                contiguous or start at 0.
            use_symmetry_breaking: require colour ``i + 1`` to be used only if
                colour ``i`` is used.
            use_node_ordering: add a fixed decision strategy that branches on
                node colours in the order returned by
                ``order_nodes_by_neighbours_order_descending``.
            use_greedy_bounding: bound the number of colours with
                ``greedy_colouring`` instead of the node count.
            verbose: print the greedy bound and solve statistics.
        """
        self.edge_array = edge_array
        self.use_symmetry_breaking = use_symmetry_breaking
        self.use_node_ordering = use_node_ordering
        self.use_greedy_bounding = use_greedy_bounding
        self.verbose = verbose
        self.num_nodes = self.get_num_nodes()

    def _node_labels(self) -> list:
        """Return the sorted distinct node labels in the edge array as Python ints."""
        # np.unique flattens its input, so every endpoint in every row is considered.
        return [int(label) for label in np.unique(self.edge_array)]

    def get_num_nodes(self) -> int:
        """Return the number of distinct node labels appearing in the edge array."""
        return len(self._node_labels())

    def order_nodes_by_neighbours_order_descending(self) -> list:
        """Return node labels sorted by the summed degree of their neighbours.

        For every node, the degrees of all its neighbours are added up. Each
        edge contributes one neighbour entry, so duplicate edges count
        repeatedly. Nodes are returned largest total first. Ties keep the order
        in which nodes first appear in the edge array, because ``sorted`` is
        stable.
        """
        neighbours = {}
        for row in self.edge_array:
            u, v = int(row[0]), int(row[1])
            neighbours.setdefault(u, []).append(v)
            neighbours.setdefault(v, []).append(u)

        score = {
            node: sum(len(neighbours[other]) for other in adjacent)
            for node, adjacent in neighbours.items()
        }
        return sorted(score, key=score.__getitem__, reverse=True)

    def initialise_model(self):
        """Build the CP-SAT model: variables, colouring constraints and objective.

        Sets ``self.model``, ``self.max_colours``,
        ``self.node_colour_variables`` (node label -> IntVar),
        ``self.colour_used_variables`` (colour index -> BoolVar) and
        ``self.obj_val_var`` (number of used colours, which is minimised).
        """
        self.model = cp_model.CpModel()

        # Upper bound on the number of colours: a greedy colouring's count if
        # requested, otherwise the trivial bound of one colour per node.
        if self.use_greedy_bounding:
            # NOTE(review): greedy_colouring is not defined in the visible
            # source; presumably it returns the number of colours a greedy
            # heuristic uses for edge_array. Confirm it is importable here.
            self.max_colours = greedy_colouring(self.edge_array)
            if self.verbose:
                print(f'Greedy algo bound: {self.max_colours}')
        else:
            self.max_colours = self.num_nodes

        # Key colour variables by the real node label so edge endpoints can be
        # looked up directly, even when labels are not 0..num_nodes-1.
        self.node_colour_variables = {
            node: self.model.NewIntVar(0, self.max_colours - 1, 'node_%i' % node)
            for node in self._node_labels()
        }
        self.colour_used_variables = {
            colour: self.model.NewBoolVar('colour_%i' % colour)
            for colour in range(self.max_colours)
        }

        # A colour whose "used" flag is false may not be assigned to any node.
        for node_var in self.node_colour_variables.values():
            for colour, used in self.colour_used_variables.items():
                self.model.Add(node_var != colour).OnlyEnforceIf(used.Not())

        # Adjacent nodes must receive different colours. A self-loop (u, u)
        # makes the model infeasible.
        for row in self.edge_array:
            u, v = int(row[0]), int(row[1])
            self.model.Add(self.node_colour_variables[u] != self.node_colour_variables[v])

        if self.use_symmetry_breaking:
            # Colours are interchangeable. Forcing them to be used in index
            # order removes equivalent permutations of the same colouring.
            for colour in range(self.max_colours - 1):
                self.model.Add(
                    self.colour_used_variables[colour] >= self.colour_used_variables[colour + 1]
                )

        if self.use_node_ordering:
            self.model.AddDecisionStrategy(
                [
                    self.node_colour_variables[node]
                    for node in self.order_nodes_by_neighbours_order_descending()
                ],
                cp_model.CHOOSE_FIRST,
                cp_model.SELECT_MIN_VALUE,
            )

        # The objective variable equals the number of colours flagged as used.
        self.obj_val_var = self.model.NewIntVar(0, self.max_colours, 'num_distinct_colours')
        self.model.Add(self.obj_val_var == sum(self.colour_used_variables.values()))
        self.model.Minimize(self.obj_val_var)

    def solve_model(self, max_solve_time: float = 60, num_search_workers: int = 16):
        """Solve the model built by ``initialise_model``.

        Args:
            max_solve_time: solver time limit in seconds.
            num_search_workers: number of parallel CP-SAT search workers.

        Sets ``self.solver`` and ``self.status``. When ``verbose`` is set, it
        prints the objective value (only if a solution was found), the elapsed
        time and the solver status.
        """
        self.solver = cp_model.CpSolver()
        self.solver.parameters.max_time_in_seconds = max_solve_time
        # FIXED_SEARCH: branch using the decision strategy declared in the
        # model (see use_node_ordering), if any.
        self.solver.parameters.search_branching = cp_model.FIXED_SEARCH
        self.solver.parameters.num_search_workers = num_search_workers
        self.solver.parameters.randomize_search = True

        start = time.perf_counter()
        self.status = self.solver.Solve(self.model)
        elapsed = time.perf_counter() - start

        if self.verbose:
            # Variable values can only be read when a solution was found;
            # INFEASIBLE or UNKNOWN (timed out) results have none.
            if self.status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
                print(f'obj val: {self.solver.Value(self.obj_val_var)}')
            print(f'Solve time: {elapsed} seconds')
            status_labels = {
                cp_model.OPTIMAL: 'Optimal',
                cp_model.FEASIBLE: 'Feasible',
                cp_model.INFEASIBLE: 'Infeasible',
            }
            # Raw string: '\_' is not a valid escape sequence.
            print(status_labels.get(self.status, r'¯\_(ツ)_/¯'))

    def get_node_colours(self) -> dict:
        """Return a dict mapping each node label to its colour in the solution.

        The dict is also stored on ``self.node_colour_dict``.

        Raises:
            RuntimeError: if ``solve_model`` has not produced a feasible solution.
        """
        if getattr(self, 'status', None) not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            raise RuntimeError(
                'No solution available: call solve_model() and check that it '
                'found a feasible solution.'
            )
        self.node_colour_dict = {
            node: self.solver.Value(node_var)
            for node, node_var in self.node_colour_variables.items()
        }
        return self.node_colour_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment