Skip to content

Instantly share code, notes, and snippets.

@arrieta
Created December 6, 2017 21:48
Show Gist options
  • Save arrieta/2d2c2d99d665e95593258ea7df368cc9 to your computer and use it in GitHub Desktop.
Save arrieta/2d2c2d99d665e95593258ea7df368cc9 to your computer and use it in GitHub Desktop.
Solve small knapsack problems by different methods
"""
knapsack.py
Solve small knapsack problems by different methods.
(C) 2017 J. Arrieta, Nabla Zero Labs
MIT License
"""
def knapsack_01(values, weights, capacity):
    """Return the best total value of a 0/1 knapsack.

    Each item ``(values[i], weights[i])`` may be packed at most once; the
    result is the largest total value whose total weight does not exceed
    ``capacity`` (0 when nothing fits).  Weights and capacity must be
    integers because they index the DP table.  Items are paired with
    ``zip``, so surplus entries in the longer list are ignored.

    ``best[q]`` holds the best value achievable with the items processed
    so far and total weight <= q.  Scanning q from high to low means that
    ``best[q - weight]`` still refers to the table *before* the current
    item was considered, so each item is used at most once.

    Raises:
        ValueError: if any weight is negative.
    """
    items = list(zip(values, weights))
    if any(weight < 0 for _, weight in items):
        raise ValueError("weights must be non-negative")
    if capacity < 0:
        # No item set fits in a negative capacity.
        return 0

    best = [0] * (capacity + 1)
    for value, weight in items:
        # Capacities below `weight` cannot take this item, so they keep the
        # previous row's value ("drop" case) simply by not being touched.
        for q in range(capacity, weight - 1, -1):
            best[q] = max(best[q], best[q - weight] + value)
    return best[capacity]
def infinite_knapsack(values, weights, capacity):
    """Return the best total value of an unbounded knapsack.

    Each item ``(values[i], weights[i])`` may be packed any number of times;
    the result is the largest total value whose total weight does not exceed
    ``capacity`` (0 when nothing fits).  Items are paired with ``zip``, so
    surplus entries in the longer list are ignored.  Weights and capacity
    need not be integers.

    The search walks remaining capacities depth-first.  The best extra value
    obtainable from a remaining capacity ``q`` does not depend on how ``q``
    was reached, so every distinct ``q`` is solved once and memoised in
    ``best``.  Keying the search on ``(value, capacity)`` pairs instead would
    re-solve the same sub-problem for every distinct accumulated value, which
    grows exponentially.

    Raises:
        ValueError: if any weight is zero or negative; such an item never
            shrinks the remaining capacity, so the search would not end.
    """
    items = list(zip(values, weights))
    if any(weight <= 0 for _, weight in items):
        raise ValueError("weights must be positive")

    best = {}  # remaining capacity -> best extra value obtainable from it
    stack = [capacity]
    while stack:
        q = stack[-1]
        if q in best:
            stack.pop()
            continue
        # Post-order: solve every strictly smaller sub-problem before q.
        pending = [q - w for _, w in items if w < q and q - w not in best]
        if pending:
            stack.extend(pending)
            continue
        stack.pop()
        value = 0  # stopping here (adding nothing) is always allowed
        for v, w in items:
            if w < q:
                value = max(value, v + best[q - w])
            elif w == q:
                # Item fills the remaining capacity exactly.
                value = max(value, v)
        best[q] = value
    return best[capacity]
def main(N=10):
    """Time both knapsack solvers on a few instances and print the results.

    Runs two fixed textbook instances plus one random instance with ``N``
    items (values and weights drawn from 1..100, capacity from 1..100*N).
    For every instance, the items and capacity are printed first, then each
    solver's result and wall-clock time.
    """
    from time import perf_counter
    from random import randint

    def draw_items():
        # N independent draws from 1..100.
        return [randint(1, 100) for _ in range(N)]

    problems = [
        ([60, 100, 120], [10, 20, 30], 50),
        ([10, 40, 50, 70], [1, 3, 4, 5], 8),
    ]
    # Draw order (values, weights, capacity) is kept for reproducible seeds.
    random_values = draw_items()
    random_weights = draw_items()
    random_capacity = randint(1, 100 * N)
    problems.append((random_values, random_weights, random_capacity))

    solvers = (
        ("inf knapsack", infinite_knapsack),
        ("0/1 knapsack", knapsack_01),
    )

    for values, weights, capacity in problems:
        header = (f"v = {values}", f"w = {weights}", f"Q = {capacity}")
        print("\n".join(header))
        for label, solve in solvers:
            started = perf_counter()
            result = solve(values, weights, capacity)
            finished = perf_counter()
            duration = finished - started
            print(f"{label} => {result:6,d}",
                  f"[elapsed: {duration:8.4f} sec]")
if __name__ == "__main__":
    import sys

    # Optional single positional argument: number of random items.
    cli_args = sys.argv[1:]
    item_count = int(cli_args[0]) if len(cli_args) == 1 else 10
    sys.exit(main(item_count))
@arrieta
Copy link
Author

arrieta commented Dec 6, 2017

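Sample run (note: for the second instance the true 0/1 optimum is 110, from the items of weight 3 and 5, so the value 100 reported below is incorrect):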

$ python3 knapsack.py 20
v = [60, 100, 120]
w = [10, 20, 30]
Q = 50
inf knapsack =>    300 [elapsed:   0.0000 sec]
0/1 knapsack =>    220 [elapsed:   0.0001 sec]
v = [10, 40, 50, 70]
w = [1, 3, 4, 5]
Q = 8
inf knapsack =>    110 [elapsed:   0.0000 sec]
0/1 knapsack =>    100 [elapsed:   0.0000 sec]
v = [96, 12, 55, 14, 72, 69, 74, 14, 96, 36, 71, 39, 59, 53, 100, 22, 43, 50, 18, 87]
w = [38, 9, 76, 96, 83, 24, 15, 33, 69, 22, 66, 47, 73, 56, 25, 71, 6, 29, 22, 71]
Q = 537
inf knapsack =>  3,827 [elapsed:   9.5587 sec]
0/1 knapsack =>    861 [elapsed:   0.0091 sec]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment