Skip to content

Instantly share code, notes, and snippets.

@macleginn
Last active April 6, 2020 19:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save macleginn/5760873bb6ef5b118a7c21752c560ab0 to your computer and use it in GitHub Desktop.
Save macleginn/5760873bb6ef5b118a7c21752c560ab0 to your computer and use it in GitHub Desktop.
from itertools import combinations, permutations
from collections import Counter
import gurobipy as gb
from gurobipy import GRB
def get_pairwise_ordering(all_deprels: set, training_set_constraints: Counter):
    '''
    Solve an integer program for a non-loopy ordering of deprels.

    Each unordered pair {a, b} of deprels gets two binary variables named
    'a->b' and 'b->a', constrained so that exactly one of them is 1. Every
    directed 3-cycle a->b->c->a is forbidden. A complete set of pairwise
    directions with no 3-cycle is transitive, so the solution is a strict
    total order. The objective favours directions with higher counts in
    ``training_set_constraints``. The chosen ordering is meant to be
    converted into indices using an SMT solver.

    Args:
        all_deprels: the deprels to order. They are iterated in sorted order
            (by their string form) so that variable/constraint creation order
            is reproducible across runs.
        training_set_constraints: counts keyed by 'a->b' strings. Because it
            is a Counter, pairs that never occur contribute a count of 0.

    Returns:
        The model's variables (``m.getVars()``) after optimisation, or None
        if a ``gurobipy.GurobiError`` or ``AttributeError`` was raised (the
        error is printed in that case).
    '''
    # Fix the iteration order: set order depends on string hashing, which is
    # randomised per process.
    rels = sorted(all_deprels, key=str)
    try:
        m = gb.Model('ordering_ilp')
        var_dict = {}
        linear_expression = gb.LinExpr(0.0)

        # One binary variable per direction of every pair, exactly one of
        # which is selected; also accumulate the objective terms.
        for rel1, rel2 in combinations(rels, 2):
            key1 = f'{rel1}->{rel2}'  # i,j
            key2 = f'{rel2}->{rel1}'  # j,i
            var_dict[key1] = m.addVar(vtype=GRB.BINARY, name=key1)
            var_dict[key2] = m.addVar(vtype=GRB.BINARY, name=key2)
            # Limit the edges to one direction.
            m.addConstr(
                var_dict[key1] + var_dict[key2] == 1,
                f'{key1} + {key2} == 1'
            )
            # The +1 shifts every feasible objective value by the same
            # constant (one selected edge per pair), so it does not change
            # which ordering is optimal.
            linear_expression.add(
                var_dict[key1], 1.0 + training_set_constraints[key1]
            )
            linear_expression.add(
                var_dict[key2], 1.0 + training_set_constraints[key2]
            )

        # Prohibit 3-cycles. A triple {a, b, c} admits exactly two directed
        # 3-cycles (a->b->c->a and a->c->b->a). Every other permutation is a
        # rotation of one of these, so adding each once suffices.
        for rel1, rel2, rel3 in combinations(rels, 3):
            for a, b, c in ((rel1, rel2, rel3), (rel1, rel3, rel2)):
                key1 = f'{a}->{b}'  # i,j
                key2 = f'{b}->{c}'  # j,k
                key3 = f'{c}->{a}'  # k,i
                m.addConstr(
                    var_dict[key1] + var_dict[key2] + var_dict[key3] <= 2,
                    f'{key1} + {key2} + {key3} <= 2'
                )

        m.setObjective(linear_expression, GRB.MAXIMIZE)
        m.optimize()
        return m.getVars()
    except gb.GurobiError as e:
        print(f'Error code {e.errno}: {e}.')
        return None
    except AttributeError as e:
        print(f'Encountered an attribute error: {e}.')
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment