@USM-F
Last active April 19, 2024 01:45
Dynamic programming solution of Multiple-Choice Knapsack Problem (MCKP) in Python
# groups is a list of integers in ascending order without gaps
def getRow(lists, row):
    res = []
    for l in lists:
        for i in range(len(l)):
            if i==row:
                res.append(l[i])
    return res

def multipleChoiceKnapsack(W, weights, values, groups):
    n = len(values)
    K = [[0 for x in range(W+1)] for x in range(n+1)]
    for w in range(W+1):
        for i in range(n+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif weights[i-1]<=w:
                sub_max = 0
                prev_group = groups[i-1]-1
                sub_K = getRow(K, w-weights[i-1])
                for j in range(n+1):
                    if groups[j-1]==prev_group and sub_K[j]>sub_max:
                        sub_max = sub_K[j]
                K[i][w] = max(sub_max+values[i-1], K[i-1][w])
            else:
                K[i][w] = K[i-1][w]
    return K[n][W]

#Example
values = [60, 100, 120]
weights = [10, 20, 30]
groups = [0, 1, 2]
W = 50
print(multipleChoiceKnapsack(W, weights, values, groups)) #220
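The function relies on items being ordered by group, with group ids 0, 1, 2, ... and no gaps (the comment at the top). If your group labels are arbitrary, a pre-processing step like the one below could reorder and relabel them first. This is a minimal sketch; `normalize_groups` is a hypothetical helper, not part of the gist, and it numbers groups in order of first appearance.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def normalize_groups(
    weights: Sequence[int],
    values: Sequence[int],
    labels: Sequence[Hashable],
) -> Tuple[List[int], List[int], List[int]]:
    """Reorder items so groups are contiguous and relabel them 0, 1, 2, ... without gaps."""
    group_id: Dict[Hashable, int] = {}
    for label in labels:
        if label not in group_id:
            group_id[label] = len(group_id)  # next unused id, so no gaps
    # Stable sort: items keep their relative order inside each group.
    order = sorted(range(len(labels)), key=lambda i: group_id[labels[i]])
    return (
        [weights[i] for i in order],
        [values[i] for i in order],
        [group_id[labels[i]] for i in order],
    )


print(normalize_groups([30, 10, 20], [120, 60, 100], ["y", "x", "y"]))
# ([30, 20, 10], [120, 100, 60], [0, 0, 1])
```

The three returned lists can then be passed as `weights`, `values`, and `groups` to `multipleChoiceKnapsack`.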
@Arsalan-Vosough

hey, this code is amazing, but it's also toooooo complicated :( I tried so hard but I couldn't understand it at all. Please help me. Where can I find an explanation of it? Any related paper or thesis? What is the name of the algorithm? Thank you a lot <3

@USM-F
Author

USM-F commented Dec 24, 2020

Hello, it's just an extension of the typical algorithm for the 0-1 Knapsack Problem, for which dynamic programming is the standard approach. You should first learn about these, and then come back to my gist and try to understand it.

But if you have questions specifically about my code, please be more specific: which lines are tough for you?
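The classic 0-1 knapsack DP referred to above can be sketched as follows (an illustration, not code from the gist; the function name is hypothetical). `K[i][w]` is the best value using the first `i` items with capacity `w`. The gist's MCKP version keeps the same "skip" branch, `K[i-1][w]`, but changes the "take" branch: instead of `K[i-1][w - weights[i-1]]`, it uses the best `K[j][w - weights[i-1]]` over the rows `j` of the previous group, which is what limits the answer to at most one item per group.

```python
from typing import List, Sequence


def knapsack_01(W: int, weights: Sequence[int], values: Sequence[int]) -> int:
    """Classic 0-1 knapsack: K[i][w] = best value using the first i items with capacity w."""
    n = len(values)
    K: List[List[int]] = [[0] * (W + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for w in range(W + 1):
            K[i][w] = K[i - 1][w]  # skip item i-1
            if weights[i - 1] <= w:
                # take item i-1 on top of the best result that excludes it
                K[i][w] = max(K[i][w], K[i - 1][w - weights[i - 1]] + values[i - 1])
    return K[n][W]


print(knapsack_01(50, [10, 20, 30], [60, 100, 120]))  # 220
```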

@nacitar

nacitar commented Apr 26, 2021

Nice gist! However, I'm curious... what if some items within these groups had a property that increased or decreased the overall capacity (W) by some percentage? Like, "selecting this item, which weighs 3 units, also increases the overall capacity by 10%"... so this example item would make your max weight increase from 50 to 55. Do you have any ideas how you would incorporate such a thing? I'm working on a similar problem and these "special" items are causing me issues... so if you have any suggestions I'd love to hear them.

@USM-F
Author

USM-F commented Jun 15, 2021

It's not clear to me: should each such item be treated as an additional knapsack (which seems more natural), or should its bonus simply be added to the common weight W? In my view, the second extension breaks the classic DP approach for this task; the first could be implemented, at first sight.
For clarity, I would recommend that you formalize the problem in Integer Programming terms and try to implement it in CPLEX, Gurobi, Google OR-Tools, etc.
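To make that suggestion concrete, here is one possible 0-1 integer program (an assumed formulation, not the author's). Here x_i = 1 if item i is taken, G_g is the set of items in group g, and p_i is the fractional capacity bonus of item i (0 for ordinary items, 0.10 for the 10% example). It assumes bonuses are percentages of the base capacity W that add up rather than compound.

```latex
\begin{aligned}
\max_{x}\quad & \sum_{i=1}^{n} v_i x_i \\
\text{s.t.}\quad & \sum_{i=1}^{n} w_i x_i \le W \Bigl(1 + \sum_{i=1}^{n} p_i x_i\Bigr) \\
& \sum_{i \in G_g} x_i \le 1 \qquad \text{for every group } g \\
& x_i \in \{0, 1\} \qquad i = 1, \dots, n
\end{aligned}
```

Moving the bonus term to the left gives the linear constraint \(\sum_i (w_i - p_i W)\, x_i \le W\). Under this assumption, a bonus item behaves like an ordinary item with effective weight \(w_i - p_i W\), which can be negative. A DP over integer weights additionally needs every \(p_i W\) to be an integer.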

@jmont96

jmont96 commented Dec 13, 2021

Hey this code is great! Do you know if your code picks exactly 1 item from each group? Thanks!

@USM-F
Author

USM-F commented Dec 14, 2021

Hello, thank you. It picks one item or none from each group, subject to the max weight constraint.

#Example
values = [60, 100, 120]
weights = [10, 20, 30]
groups = [0, 1, 2]
W = 50

Answer:
one element from group 1
one element from group 2
total weight=50
value=220

You can check it by brute force.
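A brute-force checker for small inputs could look like the sketch below. It is not from the thread, and `brute_force_mckp` and its `exactly_one` flag are hypothetical. It tries every combination of one item (or, by default, no item) per group. With `exactly_one=True`, it shows that the example has no feasible "exactly one per group" solution at W = 50. This is consistent with the answer that the gist picks one item or none per group.

```python
from itertools import product
from typing import Dict, List, Optional, Sequence, Tuple


def brute_force_mckp(
    W: int,
    weights: Sequence[int],
    values: Sequence[int],
    groups: Sequence[int],
    exactly_one: bool = False,
) -> Optional[Tuple[int, List[int]]]:
    """Return (best value, chosen item indices), or None if nothing fits.

    Exponential in the number of groups: only meant for checking small cases.
    """
    by_group: Dict[int, List[int]] = {}
    for idx, g in enumerate(groups):
        by_group.setdefault(g, []).append(idx)
    # Each group contributes one of its items, or None ("skip this group").
    options = [
        members if exactly_one else [None] + members
        for _, members in sorted(by_group.items())
    ]
    best: Optional[Tuple[int, List[int]]] = None
    for combo in product(*options):
        chosen = [i for i in combo if i is not None]
        if sum(weights[i] for i in chosen) > W:
            continue  # over capacity
        value = sum(values[i] for i in chosen)
        if best is None or value > best[0]:
            best = (value, chosen)
    return best


print(brute_force_mckp(50, [10, 20, 30], [60, 100, 120], [0, 1, 2]))  # (220, [1, 2])
print(brute_force_mckp(50, [10, 20, 30], [60, 100, 120], [0, 1, 2], exactly_one=True))  # None
```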

@raphaottoni
Copy link

Great implementation! Thank you for sharing this!

Quick question: would it work if I have some items with negative weights? I am currently modelling a problem where, depending on the chosen item in the group, it could count negatively towards the total capacity (in other words, a choice in a group could increase the capacity of the knapsack).

@USM-F
Author

USM-F commented Jan 21, 2022

Hello, this code will not work with negative weights, but it's possible to change and generalize it for your case:

  • calculate the offset: the absolute sum of the negative weights
  • initialize K by -offset
  • iterate over range(offset + W + 1)
  • add a check that the total weight <= offset + W

Since it is a popular question, I might implement it later, stay tuned!
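The steps above describe an offset-array change to the gist. The sketch below takes a different route and is not the author's planned implementation. It keeps a dictionary from exact total weight to the best value, one group at a time, so negative weights need no array offset. Because a later negative item can pull a partial sum back under W, the capacity check is applied only at the end. Partial sums are pruned only when even all remaining negative weights could not bring them back under W. The input format (a list of groups, each a list of `(weight, value)` pairs) and the function name are assumptions.

```python
from typing import Dict, Sequence, Tuple


def mckp_with_negative_weights(
    capacity: int,
    groups: Sequence[Sequence[Tuple[int, int]]],
) -> int:
    """Best total value taking at most one (weight, value) pair from each group.

    Weights may be negative; assumes capacity >= 0. `states` maps an exact
    total weight to the best value reaching it.
    """
    def credit(group: Sequence[Tuple[int, int]]) -> int:
        # Most capacity a single pick from this group can free up.
        return max([0] + [-w for w, _ in group])

    remaining_credit = sum(credit(g) for g in groups)
    states: Dict[int, int] = {0: 0}
    for group in groups:
        remaining_credit -= credit(group)  # credit of the groups after this one
        new_states = dict(states)  # option: take nothing from this group
        for total, value in states.items():
            for w, v in group:
                t = total + w
                # Prune totals that no remaining negative weights could fix.
                if t - remaining_credit > capacity:
                    continue
                if t not in new_states or new_states[t] < value + v:
                    new_states[t] = value + v
        states = new_states
    # Only now enforce the capacity on the final total weight.
    return max(v for t, v in states.items() if t <= capacity)


print(mckp_with_negative_weights(50, [[(10, 60)], [(20, 100)], [(30, 120)], [(-10, 5)]]))  # 285
```

In the usage line, the fourth group holds a single item of weight -10. Taking all four items gives total weight 50 and value 285, which would not fit without it.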

@nacitar

nacitar commented Jan 21, 2022

I believe the question @raphaottoni asked is effectively equivalent to the situation I described. If your capacity is 110, for example, and there's an item that weighs 1 unit but also increases your base capacity by 10%, it would cost 1 unit but increase capacity by 11 units, so it would effectively be equivalent to an item with a -10 unit weight. However, the equivalent would vary depending on your base capacity: if your capacity were 210, for example, the equivalent would be -20 units of weight.

@xRakun

xRakun commented May 15, 2022

Hi, first of all, thanks for sharing this. Can you explain how the getRow function works? I'm having a hard time converting this to Java :)
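The sketch below shows what `getRow` does. It is not the author's answer, and `get_column` is a hypothetical name. Despite its name, `getRow(K, c)` collects `K[j][c]` for every row `j`, so it returns column `c` of the table: the best values at remaining capacity `c`. In Java, that is a plain loop filling an `int[]` with `K[j][c]`. Inside the knapsack loop, you can also skip the copy and read `K[j][w - weights[i-1]]` directly.

```python
from typing import List, Sequence


def getRow(lists: Sequence[Sequence[int]], row: int) -> List[int]:
    # Copied from the gist: collects element `row` of every inner list.
    res = []
    for l in lists:
        for i in range(len(l)):
            if i == row:
                res.append(l[i])
    return res


def get_column(table: Sequence[Sequence[int]], col: int) -> List[int]:
    # Equivalent one-liner: K is indexed K[item][capacity], so this is
    # "the value at capacity `col` for every item prefix".
    return [r[col] for r in table]


K = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(getRow(K, 1))      # [1, 4, 7]
print(get_column(K, 1))  # [1, 4, 7]
```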

@redmoe

redmoe commented Aug 14, 2022

Thanks for the script, it's very tidy! I'm attempting to convert it to JavaScript, but I'm having some trouble.

It seems my conversion is incorrectly able to reuse items, rather than treating them as 0-1, and so it gets 260.

I'm not exactly sure what I need to fix, and was hoping I could get some help. Thanks!

I've marked with //! the parts that are possibly the issue, as they hold different data than in the Python version.

  function getRow(lists, row) {
      let res = []
      for (let l of lists) {
          for (let i = 0; i < l.length; i++) {
              if (i === row) {
                  res.push(l[i]);
              }
          }
      }
      return res
  }
  function multipleChoiceKnapsack(W, weights, values, groups) {
      let n = values.length;
      let K = new Array(n + 1).fill(new Array(W + 1).fill(0));
      for (let w = 0; w < W + 1; w++) {
          for (let i = 0; i < n + 1; i++) {
              if (i === 0 || w === 0) {
                  K[i][w] = 0;
              }
              else if (weights.at(i - 1) <= w) {
                  let sub_max = 0;
                  let prev_group = groups.at(i - 1) - 1;
                  let sub_K = getRow(
                      K,//!
                      w - weights.at(i - 1)
                  );
                  for (let j = 0; j < n + 1; j++) {
                      if (groups.at(j - 1) === prev_group
                          && sub_K[j] > sub_max//!
                      ) {
                          sub_max = sub_K[j];//!
                      }
                  }
                  K[i][w] = Math.max(
                      sub_max + values.at(i - 1),//!
                      K.at(i - 1)[w]
                  );
              } else {
                  K[i][w] = K.at(i - 1)[w];
              }
          }
      }
      return K[n][W];
  }
let values = [60, 100, 120] 
let weights = [10, 20, 30] 
let groups = [0, 1, 2]
let W = 50
console.log(multipleChoiceKnapsack(W, weights, values, groups)) //260? should be 220
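The 260 is most likely caused by the initialization of K, not by the loop. `new Array(n + 1).fill(new Array(W + 1).fill(0))` evaluates the inner array once and stores that same array object in every slot. All rows of K therefore alias one array, and each write to `K[i][w]` shows up in every row. Building one array per row, e.g. `Array.from({ length: n + 1 }, () => new Array(W + 1).fill(0))`, avoids this. Python has the identical pitfall with `[[0] * (W + 1)] * (n + 1)`. The Python analog below (hypothetical function names; the gist's recurrence with `getRow` inlined) reproduces 260 with aliased rows and 220 with independent rows.

```python
from typing import Callable, List, Sequence

Table = List[List[int]]


def aliased_table(rows: int, cols: int) -> Table:
    # Pitfall: one inner list referenced `rows` times (like JS .fill(new Array(...))).
    return [[0] * cols] * rows


def independent_table(rows: int, cols: int) -> Table:
    # Correct: a fresh inner list for every row.
    return [[0] * cols for _ in range(rows)]


def mckp(W: int, weights: Sequence[int], values: Sequence[int],
         groups: Sequence[int], make_table: Callable[[int, int], Table]) -> int:
    n = len(values)
    K = make_table(n + 1, W + 1)
    for w in range(W + 1):
        for i in range(n + 1):
            if i == 0 or w == 0:
                K[i][w] = 0
            elif weights[i - 1] <= w:
                prev_group = groups[i - 1] - 1
                rest = w - weights[i - 1]
                sub_max = 0
                for j in range(1, n + 1):  # row 0 is the empty prefix
                    if groups[j - 1] == prev_group and K[j][rest] > sub_max:
                        sub_max = K[j][rest]
                K[i][w] = max(sub_max + values[i - 1], K[i - 1][w])
            else:
                K[i][w] = K[i - 1][w]
    return K[n][W]


args = (50, [10, 20, 30], [60, 100, 120], [0, 1, 2])
print(mckp(*args, make_table=aliased_table))      # 260
print(mckp(*args, make_table=independent_table))  # 220
```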

@pabloroldan98

Hi, thanks for the script, looks amazing. Do you know by any chance what its time complexity is?

@USM-F
Author

USM-F commented Jan 12, 2023

Hi @pabloroldan98, this code is a bit suboptimal: its complexity is about O(Wn^3). It could be reduced to O(Wn^2), but not further.

@EklipZgit

EklipZgit commented Aug 18, 2023

@USM-F, you're incorrect that O(Wn^2) is the best possible reduction. Here is a much, much improved version of this original gist, with more readable code. It accurately predicts the runtime before starting. That estimate is calibrated on my machine, so you'll have to tweak the magic number for whatever IPC your machine gets. In time-sensitive applications, this lets you throw an error and fall back to a less optimal algorithm if you want.

This is O(capacity * n * maxGroupSize).
Its average runtime is O(capacity * n * sqrt(maxGroupSize)). Honestly, I can't explain why: I would have expected * avgGroupSize, but for whatever reason it trends towards sqrt(maxGroupSize). If someone figures it out, do explain, because it sure looks like it should be avgGroupSize. But the empirical estimates do not lie; this trends towards sqrt.

It also outputs the actual items that were included in the max solution. This is trickier than in 0-1 knapsack: if you reuse that solution's backtracking, it will incorrectly think that multiple items from the same group are included, so an additional check that skips to the next group must be added.

There were two key performance improvements.

First, the getRow method in the original here created completely unnecessary sublists. I replaced that with direct indexes into K for a 4x speedup.

Second,

for j in range(n+1):
    if groups[j-1]==prev_group and sub_K[j]>sub_max:

was not necessary; it can be replaced by a precomputed list of group start and end indexes, so that the n^2 loop becomes

if prev_group > -1:
    prevGroupStart, prevGroupEnd = groupStartEnds[prev_group]
    for j in range(prevGroupStart + 1, prevGroupEnd + 1):

This has two upsides. First, it turns the groupCount=1 case into a single scan that runs hundreds of times faster than it would with the n^2 loop. Second, it brings that n^2 down to n*maxGroupSize in the worst case.

Some runtime comparisons after removing getRow (the 4x speedup) but before the n^2 reduction, with the new runtime in parentheses:
capacity 750 with 200 items = 570ms (now 80ms)
capacity 750 with 100 items = 138ms (now 27ms)
capacity 75 with 400 items = 138ms (now 20ms)
capacity 75 with 400 items, avg weight 10 = 217ms (now 15ms)
capacity 75 with 400 items, avg weight 2 = 246ms (now 20ms)

Enjoy:

import logging
import math
import time
import typing


def solve_multiple_choice_knapsack(
        items: typing.List[object],
        capacity: int,
        weights: typing.List[int],
        values: typing.List[int],
        groups: typing.List[int],
        noLog: bool = True,
        forceLongRun: bool = False
) -> typing.Tuple[int, typing.List[object]]:
    """
    Solves knapsack where you need to knapsack a bunch of things, but must pick at most one thing from each group of things
    #Example
    items = ['a', 'b', 'c']
    values = [60, 100, 120]
    weights = [10, 20, 30]
    groups = [0, 1, 1]
    capacity = 50
    maxValue, itemList = solve_multiple_choice_knapsack(items, capacity, weights, values, groups)

    Extensively optimized by Travis Drake / EklipZgit by an order of magnitude, original implementation cloned from: https://gist.github.com/USM-F/1287f512de4ffb2fb852e98be1ac271d

    @param items: list of the items to be maximized in the knapsack. Can be a list of literally anything, just used to return the chosen items back as output.
    @param capacity: the capacity of weights that can be taken.
    @param weights: list of the items weights, in same order as items
    @param values: list of the items values, in same order as items
    @param groups: list of the items group id number, in same order as items. MUST start with 0, and cannot skip group numbers.
    @return: returns a tuple of the maximum value that was found to fit in the knapsack, along with the list of optimal items that reached that max value.
    """

    timeStart = time.perf_counter()
    groupStartEnds: typing.List[typing.Tuple[int, int]] = []
    if groups[0] != 0:
        raise AssertionError('Groups must start with 0 and increment by one for each new group. Items should be ordered by group.')

    lastGroup = -1
    lastGroupIndex = 0
    maxGroupSize = 0
    curGroupSize = 0
    for i, group in enumerate(groups):
        if group > lastGroup:
            if curGroupSize > maxGroupSize:
                maxGroupSize = curGroupSize
            if lastGroup > -1:
                groupStartEnds.append((lastGroupIndex, i))
                curGroupSize = 0
            if group > lastGroup + 1:
                raise AssertionError('Groups must have no gaps. if you have group 0, and 2, group 1 must be included between them.')
            lastGroupIndex = i
            lastGroup = group

        curGroupSize += 1

    groupStartEnds.append((lastGroupIndex, len(groups)))
    if curGroupSize > maxGroupSize:
        maxGroupSize = curGroupSize

    # if BYPASS_TIMEOUTS_FOR_DEBUGGING:
    for value in values:
        if not isinstance(value, int):
            raise AssertionError('values are all required to be ints or this algo will not function')

    n = len(values)
    K = [[0 for x in range(capacity + 1)] for x in range(n + 1)]
    """knapsack max values"""

    maxGrSq = math.sqrt(maxGroupSize)
    estTime = n * capacity * math.sqrt(maxGroupSize) * 0.00000022
    """rough approximation of the time it will take on MY machine, I set an arbitrary warning threshold"""
    if maxGroupSize == n:
        # single-group special case: every item has prev_group == -1, so the inner scan in the loop below never runs and the runtime doesn't multiply by max group size at all.
        estTime = n * capacity * 0.00000022

    if estTime > 0.010 and not forceLongRun:
        raise AssertionError(f"The inputs (n {n} * capacity {capacity} * math.sqrt(maxGroupSize {maxGroupSize}) {maxGrSq}) are going to result in a substantial runtime, maybe try a different algorithm")
    if not noLog:
        logging.info(f'estimated knapsack time: {estTime:.3f} (n {n} * capacity {capacity} * math.sqrt(maxGroupSize {maxGroupSize}) {maxGrSq})')

    for curCapacity in range(capacity + 1):
        for i in range(n + 1):
            if i == 0 or curCapacity == 0:
                K[i][curCapacity] = 0
            elif weights[i - 1] <= curCapacity:
                sub_max = 0
                prev_group = groups[i - 1] - 1
                subKRow = curCapacity - weights[i - 1]
                if prev_group > -1:
                    prevGroupStart, prevGroupEnd = groupStartEnds[prev_group]
                    for j in range(prevGroupStart + 1, prevGroupEnd + 1):
                        if groups[j - 1] == prev_group and K[j][subKRow] > sub_max:
                            sub_max = K[j][subKRow]
                K[i][curCapacity] = max(sub_max + values[i - 1], K[i - 1][curCapacity])
            else:
                K[i][curCapacity] = K[i - 1][curCapacity]

    res = K[n][capacity]
    timeTaken = time.perf_counter() - timeStart
    if not noLog:
        logging.info(f"Value Found {res} in {timeTaken:.3f}")
    includedItems = []
    includedGroups = []
    w = capacity
    lastTakenGroup = -1
    for i in range(n, 0, -1):
        if res <= 0:
            break
        if i == 0:
            raise AssertionError(f"i == 0 in knapsack items determiner?? res {res} i {i} w {w}")
        if w < 0:
            raise AssertionError(f"w < 0 in knapsack items determiner?? res {res} i {i} w {w}")
        # either the result comes from the
        # top (K[i-1][w]) or from (val[i-1]
        # + K[i-1] [w-wt[i-1]]) as in Knapsack
        # table. If it comes from the latter
        # one, it means the item is included.
        # THIS IS WHY VALUE MUST BE INTS
        if res == K[i - 1][w]:
            continue

        group = groups[i - 1]
        if group == lastTakenGroup:
            continue

        includedGroups.append(group)
        lastTakenGroup = group
        # This item is included.
        if not noLog:
            logging.info(
                f"item at index {i - 1} with value {values[i - 1]} and weight {weights[i - 1]} was included... adding it to output. (Res {res})")
        includedItems.append(items[i - 1])

        # Since this weight is included
        # its value is deducted
        res = res - values[i - 1]
        w = w - weights[i - 1]

    uniqueGroupsIncluded = set(includedGroups)
    if len(uniqueGroupsIncluded) != len(includedGroups):
        raise AssertionError("Yo, the multiple choice knapsacker failed to be distinct by groups")

    if not noLog:
        logging.info(
            f"multiple choice knapsack completed on {n} items for capacity {capacity} finding value {K[n][capacity]} in Duration {time.perf_counter() - timeStart:.3f}")

    return K[n][capacity], includedItems

Here, have free tests, too:

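import random
import time
import typing

# TestBase and KnapsackUtils are not defined in this comment; they come from the surrounding project.
class SearchUtils_MCKP_Tests(TestBase):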


    def test_multiple_choice_knapsack_solver__more_capacity_than_items__0_1_base_case__includes_all(self):
        groupItemWeightValues = [
            (0, 'a', 1, 1),
            (0, 'b', 1, 1),
            (0, 'c', 1, 1),
            (0, 'd', 1, 1),
            (0, 'e', 1, 1),
            (0, 'f', 1, 1),
            (0, 'g', 1, 1),
            (0, 'h', 1, 1),
            (0, 'i', 1, 1),
            (0, 'j', 1, 1),
            (0, 'k', 1, 1),
            (0, 'l', 1, 1),
            (0, 'm', 1, 1),
            (0, 'n', 1, 1),
            (0, 'o', 1, 1),
            (0, 'p', 1, 1),
            (0, 'q', 1, 1),
            (0, 'r', 1, 1),
            (0, 's', 1, 1),
            (0, 't', 1, 1),
            (0, 'u', 1, 1),
            (0, 'v', 1, 1),
            (0, 'w', 1, 1),
            (0, 'x', 1, 1),
            (0, 'y', 1, 1),
            (0, 'z', 1, 1)
        ]

        # give each own group for this test
        i = 0
        for groupItemWeightValue in groupItemWeightValues:
            (group, item, weight, value) = groupItemWeightValue
            groupItemWeightValues[i] = i, item, weight, value
            i += 1

        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=50)

        #should have included every letter once, as this boils down to the 0-1 knapsack problem
        self.assertEqual(26, maxValue)
        self.assertEqual(26, len(items))

    def test_multiple_choice_knapsack_solver__respects_groups(self):
        groupItemWeightValues = [
            (0, 'a', 1, 2),
            (0, 'b', 1, 1),
            (0, 'c', 1, 1),
            (0, 'd', 1, 1),
            (0, 'e', 1, 1),
            (0, 'f', 1, 1),
            (0, 'g', 1, 1),
            (0, 'h', 1, 1),
            (0, 'i', 1, 1),
            (0, 'j', 1, 1),
            (0, 'k', 1, 1),
            (0, 'l', 1, 1),
            (0, 'm', 1, 1),
            (0, 'n', 1, 1),
            (0, 'o', 1, 1),
            (0, 'p', 1, 1),
            (0, 'q', 1, 1),
            (0, 'r', 1, 1),
            (0, 's', 1, 1),
            (0, 't', 1, 1),
            (0, 'u', 1, 1),
            (0, 'v', 1, 1),
            (0, 'w', 1, 1),
            (0, 'x', 1, 1),
            (0, 'y', 1, 1),
            (0, 'z', 1, 1)
        ]

        # give first 10 own group, rest are 0
        i = 0
        for groupItemWeightValue in groupItemWeightValues:
            if i >= 10:
                break
            (group, item, weight, value) = groupItemWeightValue
            groupItemWeightValues[i] = i, item, weight, value
            i += 1


        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=50)

        #should have included every letter from first 10 once, and a should be the group '0' entry worth 2, so value 11
        self.assertEqual(11, maxValue)
        self.assertEqual(10, len(items))


    def test_multiple_choice_knapsack_solver__with_constrained_capacity__finds_best_group_subset(self):
        groupItemWeightValues = [
            (0, 'a', 7, 10),
            (0, 'b', 5, 8),
            (0, 'c', 2, 5),
            (1, 'd', 5, 5),
            (1, 'e', 2, 3),
        ]

        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=7)

        # at 7 capacity, we expect it to pick b (5, 8) and e (2, 3) for 11 at weight 7
        self.assertEqual(11, maxValue)
        self.assertEqual(2, len(items))

    def test_multiple_choice_knapsack_solver__should_not_have_insane_time_complexity(self):
        # envision our worst case gather scenario, let's say a depth 75 gather. We might want to run this 50 times against 200 items from maybe 40 groups each time
        # gathers may have a large max value, so let's generate values on the scale of 150
        simulatedItemCount = 200
        simulatedGroupCount = 50
        maxValuePerItem = 150
        maxWeightPerItem = 5
        capacity = 75

        groupItemWeightValues = self.generate_item_test_set(simulatedItemCount, simulatedGroupCount, maxWeightPerItem, maxValuePerItem)

        start = time.perf_counter()
        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=capacity)
        self.assertGreater(maxValue, 0)
        # doubt we ever find less than this
        self.assertGreater(len(items), 20)
        endTime = time.perf_counter()
        duration = endTime - start
        self.assertLess(duration, 0.03)

    def test_multiple_choice_knapsack_solver__should_not_have_insane_time_complexity__high_item_count_low_group_count(self):
        # envision our worst case gather scenario, let's say a depth 75 gather. We might want to run this 50 times against 200 items from maybe 40 groups each time
        # gathers may have a large max value, so let's generate values on the scale of 150
        simulatedItemCount = 400
        simulatedGroupCount = 21
        maxValuePerItem = 150
        maxWeightPerItem = 20
        capacity = 75

        groupItemWeightValues = self.generate_item_test_set(simulatedItemCount, simulatedGroupCount, maxWeightPerItem, maxValuePerItem)

        start = time.perf_counter()
        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=capacity)
        self.assertGreater(maxValue, 0)
        # doubt we ever find less than this
        self.assertGreater(len(items), 20)
        endTime = time.perf_counter()
        duration = endTime - start
        self.assertLess(duration, 0.03)

    def test_multiple_choice_knapsack_solver__should_not_have_insane_time_complexity__low_item_count_same_group_count(self):
        # TIME PERFORMANCE NOTES:
        # Scales superlinearly with item count: 100 items = 15ms, 200 items = 50ms, 400 items = 160ms
        # Lower numbers of groups slightly increase the runtime by like 20% linearly (so, 50 ms 0-1 to 56ms ish with a pure 50/50 split on groups).
        # value does not matter at all to runtime
        # average item weight does slightly change it, see below for how it interacts with capacity
        # capacity is the big one: the fewer items you can fit in the solution, regardless of input size, the faster it runs.
        # capacity 750 with 200 items = 570ms (now 80ms)
        # capacity 750 with 100 items = 138ms (now 27ms)
        # capacity 75 with 400 items = 138ms (now 20ms)
        # capacity 75 with 400 items, avg weight 10 = 217ms (now 15ms)
        # capacity 75 with 400 items, avg weight 2 = 246ms (now 20ms)
        simulatedItemCount = 400
        simulatedGroupCount = 1
        maxValuePerItem = 150
        maxWeightPerItem = 5
        capacity = 750

        groupItemWeightValues = self.generate_item_test_set(simulatedItemCount, simulatedGroupCount, maxWeightPerItem, maxValuePerItem)

        start = time.perf_counter()
        maxValue, items = self.execute_multiple_choice_knapsack_with_tuples(groupItemWeightValues, capacity=capacity)
        self.assertGreater(maxValue, 0)
        # doubt we ever find less than this
        self.assertEqual(len(items), simulatedGroupCount)
        endTime = time.perf_counter()
        duration = endTime - start

        self.assertLess(duration, 0.15)

    def execute_multiple_choice_knapsack_with_tuples(
            self,
            groupItemWeightValues: typing.List[typing.Tuple[int, object, int, int]],
            capacity: int):
        groupItemWeightValues = [t for t in sorted(groupItemWeightValues)]
        items = []
        groups = []
        weights = []
        values = []

        for (group, item, weight, value) in groupItemWeightValues:
            items.append(item)
            groups.append(group)
            weights.append(weight)
            values.append(value)

        return KnapsackUtils.solve_multiple_choice_knapsack(items, capacity, weights, values, groups, noLog=False, forceLongRun=True)

    def generate_item_test_set(self, simulatedItemCount, simulatedGroupCount, maxWeightPerItem, maxValuePerItem):
        groupItemWeightValues = []
        r = random.Random()

        # at least one per group
        for i in range(simulatedGroupCount):
            item = i
            group = i
            value = r.randint(0, maxValuePerItem)
            weight = r.randint(1, maxWeightPerItem)
            groupItemWeightValues.append((group, item, weight, value))

        # then random groups after that
        for i in range(simulatedItemCount - simulatedGroupCount):
            item = i + simulatedGroupCount
            group = r.randint(0, simulatedGroupCount - 1)
            value = r.randint(0, maxValuePerItem)
            weight = r.randint(1, maxWeightPerItem)
            groupItemWeightValues.append((group, item, weight, value))

        return groupItemWeightValues
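A further observation, offered as a hedged sketch rather than something claimed in this thread. In both the original and the improved version, every cell satisfies `K[i][w] >= K[i-1][w]`, because each branch either copies `K[i-1][w]` or takes a max with it. Within the previous group, the largest `K[j][w - weights[i-1]]` is therefore always the one in that group's last row. The inner scan over the group can thus be replaced by a single lookup, which would make the DP O(n * capacity). The function below has a hypothetical name and returns the value only, without item reconstruction. It assumes the same input rules as the gist: groups sorted ascending, starting at 0, without gaps.

```python
from typing import Dict, List, Sequence


def mckp_last_row(W: int, weights: Sequence[int], values: Sequence[int],
                  groups: Sequence[int]) -> int:
    """The gist's recurrence, with the previous-group scan replaced by one lookup."""
    n = len(values)
    # last_row[g] = K-row index (1-based) of the last item in group g
    last_row: Dict[int, int] = {}
    for i, g in enumerate(groups, start=1):
        last_row[g] = i
    K: List[List[int]] = [[0] * (W + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        wt, val, g = weights[i - 1], values[i - 1], groups[i - 1]
        prev = last_row.get(g - 1)  # None for group 0
        for w in range(1, W + 1):
            K[i][w] = K[i - 1][w]  # skip item i-1
            if wt <= w:
                # K is nondecreasing in i, so the previous group's last row holds its max.
                sub_max = K[prev][w - wt] if prev is not None else 0
                K[i][w] = max(K[i][w], sub_max + val)
    return K[n][W]


print(mckp_last_row(50, [10, 20, 30], [60, 100, 120], [0, 1, 2]))  # 220
```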
