Skip to content

Instantly share code, notes, and snippets.

@appgurueu
Created June 28, 2023 10:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save appgurueu/f2192d081cd6e1d7dba4ee5f8a3fe1b6 to your computer and use it in GitHub Desktop.
Save appgurueu/f2192d081cd6e1d7dba4ee5f8a3fe1b6 to your computer and use it in GitHub Desktop.
Count inversions summing up to `x`
--- Count the inversions of `list` whose two elements sum to exactly `x`.
-- An "x-inversion" is a pair of positions i < j with list[i] > list[j] and
-- list[i] + list[j] == x. A modified merge sort counts them without checking
-- every pair (see `count_x_inversions_naive` for the O(n^2) reference).
--
-- Side effect: `list` is sorted in place (ascending), so it has no
-- inversions left after the call.
-- @param list sequence of mutually comparable numbers; duplicates are
--   allowed (equal values never form an inversion, since `>` is strict)
-- @param x target sum
-- @return number of x-inversions in the original order of `list`
local function count_x_inversions(
    list, -- of nums (duplicates allowed)
    x -- target sum
)
    -- Merge the sorted sequences `left` and `right` into `result` (positions
    -- 1 .. #left + #right) and return the number of x-inversions formed by an
    -- element of `left` together with an element of `right`.
    local function merge(result, left, right)
        -- `left` is sorted, so equal values occupy one contiguous index run;
        -- record where each value's run starts and ends.
        local first_idx, last_idx = {}, {}
        for idx, v in ipairs(left) do
            if first_idx[v] == nil then
                first_idx[v] = idx
            end
            last_idx[v] = idx
        end
        local inversions = 0
        local i, j, k = 1, 1, 1
        while i <= #left and j <= #right do
            -- Compare the "head" elements and emit the smaller one. Ties go to
            -- `left`, so equal values are never counted as inversions.
            if right[j] < left[i] then
                -- Every remaining left[i..] came before right[j] in the original
                -- order and is larger than it, so each forms an inversion with
                -- right[j]. Plain inversion counting would add (#left - i + 1).
                -- Here only the remaining elements equal to x - right[j] count.
                -- They form a contiguous run, so their number follows in O(1)
                -- (expected, hash lookup) from the run bounds. A binary search
                -- on left[i..] would instead give a hard O(log n) per step.
                local partner = x - right[j]
                local last = last_idx[partner]
                if last ~= nil and last >= i then
                    inversions = inversions + (last - math.max(first_idx[partner], i) + 1)
                end
                result[k] = right[j]
                j = j + 1
            else
                result[k] = left[i]
                i = i + 1
            end
            k = k + 1
        end
        -- At most one list still has elements; copy the remainder over.
        for offset = 0, #left - i do
            result[k + offset] = left[i + offset]
        end
        for offset = 0, #right - j do
            result[k + offset] = right[j + offset]
        end
        return inversions
    end

    -- Sort list[from..to] into `list_to_sort` (positions 1 .. to - from + 1)
    -- and return the number of x-inversions inside that range.
    local function mergesort(list_to_sort, from, to)
        if from > to then
            return 0 -- empty range
        end
        if from == to then
            list_to_sort[1] = list[from]
            return 0
        end
        local mid = math.floor((to + from) / 2)
        local left = {}
        local left_inversions = mergesort(left, from, mid)
        local right = {}
        local right_inversions = mergesort(right, mid + 1, to)
        return merge(list_to_sort, left, right) + left_inversions + right_inversions
    end

    -- Using `list` itself as the output sorts it in place. This is safe
    -- because every read of `list` happens inside the recursive calls, before
    -- the top-level merge starts writing into it.
    return mergesort(list, 1, #list)
end
--- Brute-force reference for `count_x_inversions`.
-- Examines every pair of positions i < j and counts those where the earlier
-- element is strictly larger and the two sum to exactly `x`. O(n^2); does not
-- modify `list`.
-- @param list sequence of numbers
-- @param x target sum
-- @return number of such pairs
local function count_x_inversions_naive(list, x)
    local n = #list
    local count = 0
    for i = 1, n - 1 do
        local earlier = list[i]
        for j = i + 1, n do
            local later = list[j]
            if earlier > later and earlier + later == x then
                count = count + 1
            end
        end
    end
    return count
end
-- Self-checks: the merge-sort count must agree with the brute-force count.
do
    -- Fixed example. Each call gets its own table literal because
    -- count_x_inversions reorders its argument.
    assert(count_x_inversions({2, 0, 5, 1}, 3) == count_x_inversions_naive({2, 0, 5, 1}, 3))

    -- Fisher-Yates shuffle of `seq`, performed in place.
    local function shuffle_in_place(seq)
        local len = #seq
        local pos = 1
        while pos < len do
            local other = math.random(pos, len)
            seq[pos], seq[other] = seq[other], seq[pos]
            pos = pos + 1
        end
    end

    -- Randomized comparison: shuffled permutations of 1..n, 10 rounds.
    for _ = 1, 10 do
        local perm = {}
        local n = math.random(10, 1e3)
        for v = 1, n do
            perm[v] = v
        end
        shuffle_in_place(perm)
        local x
        if math.random() < 0.5 then
            x = n + 1 -- each pair (v, n + 1 - v) sums to this
        else
            x = math.random(math.ceil(n / 4), math.floor(3 * n / 4))
        end
        -- Take the naive count first: count_x_inversions sorts `perm` in
        -- place, leaving no inversions behind.
        local expected = count_x_inversions_naive(perm, x)
        assert(expected == count_x_inversions(perm, x))
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment