Created July 18, 2021 13:06
-
-
Save adishavit/04c9808253211156c25c97088fe2aba2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import time | |
import cvxpy as cp | |
from cvxpy.lin_ops.lin_utils import ID_COUNTER | |
def hhmmss(secs) -> str:
    """Format a duration in seconds as ``hh:mm:ss.sss`` (approx. ISO 8601).

    The total is rounded to the nearest millisecond *before* it is split into
    fields, so e.g. 59.9996 renders as ``00:01:00.000`` instead of a
    ``60.000`` seconds field.

    Args:
        secs: Non-negative duration in seconds (int or float).

    Returns:
        The duration as ``hh:mm:ss.sss``. Hours are not wrapped at 24, so
        long durations just widen the hour field (``100:00:00.000``).
    """
    total_ms = int(round(secs * 1000))
    hours, rem_ms = divmod(total_ms, 3_600_000)
    # Minutes are the remainder within the current hour, not total minutes.
    minutes, rem_ms = divmod(rem_ms, 60_000)
    # rem_ms < 60_000, so seconds are always < 60; 06.3f zero-pads to "ss.sss".
    return f"{hours:02d}:{minutes:02d}:{rem_ms / 1000:06.3f}"
def undeserialize(arg=None, path='problem.pickle'):
    """Round-trip ``arg`` through a pickle file and return the reloaded copy.

    When ``arg`` is given, it is pickled (highest protocol) to ``path``. The
    file is then read back and the unpickled object is returned. Calling with
    no ``arg`` just reloads whatever was last written to ``path``.

    Args:
        arg: Object to serialize. ``None`` means "don't write, only read".
            Falsy objects such as ``{}`` or ``0`` are written like any other
            value (they are not treated as "no argument").
        path: Pickle file to write and read. Defaults to ``'problem.pickle'``,
            resolved relative to the current working directory.

    Returns:
        The object unpickled from ``path``.

    Raises:
        FileNotFoundError: If ``arg`` is ``None`` and ``path`` does not exist.

    Note:
        ``pickle.load`` can execute arbitrary code. Only read files written by
        this process or another trusted source.
    """
    # Compare against None so falsy payloads are still written instead of
    # silently returning a stale file's contents.
    if arg is not None:
        with open(path, 'wb') as f:
            pickle.dump(arg, f, pickle.HIGHEST_PROTOCOL)
    with open(path, 'rb') as f:
        return pickle.load(f)
def _timed_solve(problem, **solve_kwargs):
    """Solve ``problem`` with ``solve_kwargs`` and print the elapsed time."""
    start = time.perf_counter()  # monotonic, high-resolution timer for durations
    problem.solve(**solve_kwargs)
    print(f"Time: {hhmmss(time.perf_counter() - start)}. ")


def _check_last_at_bound(q, x, tol=1e-3):
    """Raise AssertionError unless ``x.value[-1]`` equals ``q.value`` within ``tol``.

    Solvers return floats, so an integer-valued optimum can come back with
    round-off (e.g. 4999.9999999), and exact ``==`` would spuriously fail.
    Any ``tol`` below 0.5 still pins the exact integer value. An explicit
    raise is used (rather than ``assert``) so the check survives ``python -O``.
    """
    if abs(x.value[-1] - q.value) > tol:
        raise AssertionError(
            f"expected x[-1] == q, got x[-1]={x.value[-1]}, q={q.value}")


def main():
    """Demonstrate reusing a pickled cvxpy compilation cache on a rebuilt problem.

    1. Build and solve a mixed-integer problem. Change parameter ``q`` and
       re-solve it warm.
    2. Pickle a selected set of the problem's internal cache attributes and
       reload them (via ``undeserialize``).
    3. Rebuild an identical problem from scratch, inject the reloaded cache
       into it, and solve it twice (before and after changing ``q``).

    Every solve is timed, and afterwards the last element of ``x`` is checked
    to sit at its upper bound ``q``.
    """

    def gen_problem_inputs(start_id=0):
        """Build the objective, constraints and handles of the demo problem.

        The global cvxpy id counter is reset first, so every call creates its
        objects with the same ids. Presumably the restored cache refers to
        variables and parameters by id, so a rebuilt problem can only reuse it
        when the ids match the original's.
        NOTE(review): ``ID_COUNTER`` is a cvxpy internal; confirm per version.
        """
        ID_COUNTER.count = start_id  # <=== VERY IMPORTANT HACK
        n = 5000
        x = cp.Variable(n, integer=True, name='x')
        y = cp.Variable(n, integer=True)
        q = cp.Parameter(value=n, name='q')  # q is an upper bound on the last element of x
        # cp.sum is a single atom. The builtin sum() would chain n index/add
        # expressions, which is slow to build and canonicalize.
        o = cp.Maximize(cp.sum(x))
        cstrs = ([x <= y, x[-1] <= q]
                 + [x[i] <= x[i + 1] - 1 for i in range(n - 1)])
        return o, cstrs, x, q, x + y

    swargs = {'warm_start': True, 'verbose': 0}

    print('*** Solving full problem')
    o, cstrs, x, q, _ = gen_problem_inputs()
    original_problem = cp.Problem(o, cstrs)
    _timed_solve(original_problem, **swargs)
    _check_last_at_bound(q, x)

    print('*** Change parameter q and solve as warm problem')
    q.value = 200
    _timed_solve(original_problem, **swargs)
    _check_last_at_bound(q, x)

    print('*** Serialize Problem')
    # Instance attributes (cvxpy internals, version-dependent) holding the
    # problem's memoized compilation and solver state.
    keys = {
        '_cache',
        '_solver_cache',
        '_size_metrics',
        '_compilation_time',
        '_solve_time',
        'parameters__cache__',
        'is_dcp__cache__',
        'is_dgp__cache__',
        'is_dqcp__cache__',
        'is_mixed_integer__cache__',
        'constants__cache__',
        'is_dpp__cache__',
    }
    cache = {key: val for key, val in vars(original_problem).items() if key in keys}
    cache = undeserialize(cache)
    # Drop every handle to the original problem so nothing below can reuse it.
    del original_problem, x, q, o, cstrs

    print("*** Create New Problem - but don't solve it.")
    # NEW VARIABLES: same ids as the originals because of the ID_COUNTER reset.
    _o, _cstrs, x, _q, _ = gen_problem_inputs()
    warm_problem = cp.Problem(_o, _cstrs)

    print("*** Update from restored pickle.")
    vars(warm_problem).update(cache)

    print("*** Solve restored problem - Should be pre-Warmed.")
    _timed_solve(warm_problem, **swargs)
    # Must re-initialize `q` to the "other" parameter object with the same id,
    # taken from the warm problem. Presumably the restored 'parameters__cache__'
    # makes parameters() return the unpickled Parameter (still holding the
    # q=200 set above) rather than _q.
    q = next(prm for prm in warm_problem.parameters() if _q.id == prm.id)  # <=== VERY IMPORTANT HACK
    del _q
    _check_last_at_bound(q, x)

    print('*** Change parameter q and solve AGAIN as warm problem')
    q.value = 42
    _timed_solve(warm_problem, **swargs)
    _check_last_at_bound(q, x)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment