Skip to content

Instantly share code, notes, and snippets.

@vikranth22446
Created June 17, 2021 01:56
Show Gist options
  • Save vikranth22446/f7ab82ae87e0011d432adaec2cf69c0a to your computer and use it in GitHub Desktop.
Save vikranth22446/f7ab82ae87e0011d432adaec2cf69c0a to your computer and use it in GitHub Desktop.
Sample mixed-integer programming (MIP) model that finds a shortest path using a single-commodity flow formulation
from mip import *
import networkx as nx
from dataclasses import dataclass
def solve(graph, start_node, end_node, *, max_mip_gap=0.05, max_seconds=300):
    """Find a fewest-edges path from ``start_node`` to ``end_node`` with a MIP.

    Single-commodity flow formulation:

    * every edge ``(u, v)`` has a binary "taken" variable and an integer flow
      variable that may be non-zero only when the edge is taken;
    * every node other than ``start_node`` has a binary ``node_taken``
      variable; a taken node absorbs exactly one unit of flow and is entered
      by exactly one taken edge, while an untaken node absorbs no flow and is
      entered by no taken edge;
    * ``start_node`` leaves along exactly one taken edge and emits one unit of
      flow per taken node; ``end_node`` is entered by exactly one taken edge.

    Reached nodes can only be supplied with flow that originates at
    ``start_node``, so the taken edges cannot form a cycle detached from it.
    Minimising the number of taken edges then leaves a start-to-end path.

    Args:
        graph: directed graph providing ``nodes()``, ``edges()``,
            ``in_edges(node)`` and ``out_edges(node)`` (e.g. a
            ``networkx.DiGraph``).
        start_node: node the path starts at.
        end_node: node the path must reach.
        max_mip_gap: relative optimality gap at which the solver may stop.
            With a non-zero gap (default 5%) the returned path is feasible
            but not guaranteed to be the shortest one.
        max_seconds: time limit, in seconds, passed to ``Model.optimize``.

    Returns:
        List of ``(u, v)`` edges selected by the solver. Each one is also
        printed as ``edges taken (u, v)``. An empty list is returned, after
        printing a message, when no feasible solution is available.
    """
    m = Model(sense=MINIMIZE)
    # The start node supplies at most one unit of flow per other node, so no
    # edge ever needs to carry more than this.
    total_flow = len(graph.nodes()) - 1
    path_taken = {}
    flow_constraints = {}
    node_taken = {}

    # node_taken[n] == 1 means node n absorbs one unit of flow (it is reached).
    # The start node is the flow source, so it gets no such variable.
    for node in graph.nodes():
        if node == start_node:
            continue
        # f-string so that each variable gets its own, distinct solver name.
        node_taken[node] = m.add_var(f"flow_dropped_on_{node}", var_type=BINARY)

    for (u, v) in graph.edges():
        path_taken[(u, v)] = m.add_var(f"path_taken_{u}_{v}", var_type=BINARY)
        # 0 <= flow <= total_flow, expressed as variable bounds.
        flow_constraints[(u, v)] = m.add_var(
            f"flow_on_edge_{u}_{v}", var_type=INTEGER, lb=0, ub=total_flow
        )
        # Flow may only travel along edges that are taken.
        m += flow_constraints[(u, v)] <= total_flow * path_taken[(u, v)]

    start_outgoing_edges = list(graph.out_edges(start_node))
    end_incoming_edges = list(graph.in_edges(end_node))
    # Without these edges the "exactly one edge" constraints below would be
    # empty sums forced to equal 1, which no path can satisfy.
    if not start_outgoing_edges or not end_incoming_edges:
        print("no feasible path found")
        return []

    # The start node leaves along exactly one taken edge ...
    m += xsum(path_taken[(u, v)] for (u, v) in start_outgoing_edges) == 1
    # ... and emits one unit of flow for every node that ends up reached.
    m += xsum(flow_constraints[(u, v)] for (u, v) in start_outgoing_edges) == xsum(
        node_taken.values()
    )
    # The end node must be entered by exactly one taken edge.
    m += xsum(path_taken[(u, v)] for (u, v) in end_incoming_edges) == 1

    for node in node_taken:  # every node except the start node
        incoming_edges = list(graph.in_edges(node))
        outgoing_edges = list(graph.out_edges(node))
        # A reached node keeps one unit of the flow passing through it; an
        # unreached node passes on exactly what it receives (i.e. nothing).
        m += (
            xsum(flow_constraints[(u, v)] for (u, v) in incoming_edges)
            - xsum(flow_constraints[(u, v)] for (u, v) in outgoing_edges)
            == node_taken[node]
        )
        # A reached node is entered by exactly one taken edge, an unreached
        # node by none.
        m += xsum(path_taken[(u, v)] for (u, v) in incoming_edges) == node_taken[node]

    m.objective = xsum(path_taken.values())
    m.max_mip_gap = max_mip_gap
    m.optimize(max_seconds=max_seconds)

    # Infeasible model, or time limit reached before any solution: Var.x is
    # None in that case, so stop instead of comparing None with a float.
    if m.num_solutions == 0:
        print("no feasible path found")
        return []

    taken_edges = []
    for edge, var in path_taken.items():
        # Binary values may come back as e.g. 0.9999999, so compare with a margin.
        if var.x > 0.99:
            print("edges taken", edge)
            taken_edges.append(edge)
    return taken_edges
@dataclass(init=True, repr=True, eq=True)
class Function:
    """Plain record that identifies a function by its name."""

    # Human-readable name of the function, e.g. "rooted" or "cat".
    name: str
@dataclass(init=True, repr=True, eq=True)
class EdgeConfig:
    """Plain record describing an edge from ``source`` to ``end`` with a weight."""

    source: str  # name of the node the edge starts at
    end: str  # name of the node the edge points to
    weight: float  # numeric weight attached to the edge
# Sample data: each key names a function and maps to its Function record.
function_data = {
    "root": Function("rooted"),
    "a": Function("at"),
    "b": Function("bat"),
    "c": Function("cat"),
    "d": Function("dog"),
    "e": Function("eat"),
}
functions = function_data.keys()

# Caller -> list of callees. Functions with no dependencies ("c", "e") have no entry.
function_dependencies = {
    "root": ["a"],
    "a": ["b", "e"],
    "b": ["c", "d"],
    "d": ["e"],
}

# Directed graph: one node per function, one caller -> callee edge per dependency.
graph = nx.DiGraph()
graph.add_nodes_from(functions)
graph.add_edges_from(
    (caller, callee)
    for caller in functions
    if caller in function_dependencies
    for callee in function_dependencies[caller]
)

# Solve with the MIP model, then print networkx's own answer for comparison.
solve(graph, "root", "e")
print("Networkx shortest path", nx.shortest_path(graph, "root", "e"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment