Skip to content

Instantly share code, notes, and snippets.

@jhrmnn
Created September 1, 2022 08:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jhrmnn/86f0b061a9933c271dceab121dfdb0a0 to your computer and use it in GitHub Desktop.
Save jhrmnn/86f0b061a9933c271dceab121dfdb0a0 to your computer and use it in GitHub Desktop.
import torch
# Empty export list: `from <this module> import *` exports nothing, so callers
# must import `solve_lap` by name. NOTE(review): presumably intentional (module
# treated as internal) — confirm, otherwise list ("solve_lap",) here.
__all__ = ()
@torch.no_grad()
def solve_lap(C):
    """Solve a batch of rectangular linear assignment problems.

    Batched PyTorch port of the shortest-augmenting-path algorithm used by
    SciPy's ``linear_sum_assignment`` (variable names and the numbered steps
    follow that implementation). Every row of each problem is assigned a
    distinct column so that the summed cost is minimal.

    All problems advance in lock-step: rows are inserted one at a time for the
    whole batch, and the inner searches keep running (masked by ``_bx``) until
    every problem in the batch has finished that step.

    Args:
        C: cost tensor of shape ``(batch, n_rows, n_cols)`` with
            ``n_cols >= n_rows`` and non-negative, non-NaN entries. Non
            floating-point costs (e.g. integer) are promoted to float64
            internally, because the search needs ``+inf`` as a sentinel.

    Returns:
        ``(batch_idx, col4row)``: two long tensors of shape
        ``(batch, n_rows)`` on ``C.device``. ``batch_idx[b, r] == b`` and
        ``col4row[b, r]`` is the column assigned to row ``r`` of problem
        ``b``.

    Raises:
        AssertionError: if ``C`` is not 3-D, has more rows than columns,
            has negative/NaN entries, or some row can only reach columns of
            infinite cost (infeasible problem). These are ``assert``
            statements and are skipped under ``python -O``.
    """
    assert len(C.shape) == 3
    assert C.shape[2] >= C.shape[1]
    assert (C >= 0).all()
    if not C.is_floating_point():
        # +inf is used as a sentinel below; integer dtypes cannot hold it
        C = C.to(torch.float64)
    # step 1, initialize
    bs, NR, NC = C.shape
    u, v = (C.new_zeros(bs, n) for n in (NR, NC))  # dual variables
    # -1 marks an unassigned row / column
    col4row, row4col = (
        C.new_full((bs, n), -1, dtype=torch.long) for n in (NR, NC)
    )
    path = C.new_empty(bs, NC, dtype=torch.long)
    _ib = torch.arange(bs, device=C.device)
    _inf = C.new_tensor(torch.inf)
    for curRow in range(NR):
        # step 2, prepare for augmentation
        shortestPathCosts = C.new_full((bs, NC), torch.inf)
        SR, SC = (C.new_zeros(bs, n, dtype=torch.bool) for n in (NR, NC))
        sink = C.new_full((bs,), -1, dtype=torch.long)
        minVal = C.new_zeros(bs)
        i = C.new_full((bs,), curRow, dtype=torch.long)
        # step 3, find the shortest augmenting path; a problem drops out of
        # the active mask _bx as soon as its sink (a free column) is found
        while True:
            _bx = sink == -1
            if not _bx.any():
                break
            _b = _ib[_bx]
            _rows = i[_bx]
            SR[_b, _rows] = True
            # columns not yet scanned: per active problem (_free) and the same
            # entries as a full-batch mask (_upd) for masked_scatter_
            _free = ~SC[_bx]
            _upd = _bx[:, None] & ~SC
            _cur = shortestPathCosts[_bx][_free]
            path_costs = (
                minVal[_bx][:, None]
                + C[_b, _rows]
                - u[_b, _rows][:, None]
                - v[_bx]
            )[_free]
            _jx = path_costs < _cur
            shortestPathCosts.masked_scatter_(
                _upd, torch.where(_jx, path_costs, _cur)
            )
            path.masked_scatter_(
                _upd,
                torch.where(
                    _jx, _rows[:, None].expand_as(_free)[_free], path[_bx][_free]
                ),
            )
            j = shortestPathCosts[_bx].where(_free, _inf).min(dim=-1).indices
            lowest = shortestPathCosts[_b, j]
            assert (lowest < _inf).all()  # otherwise the problem is infeasible
            SC[_b, j] = True
            minVal[_bx] = lowest
            sink[_bx] = torch.where(row4col[_b, j] == -1, j, sink[_bx])
            # continue from the row that currently owns column j (-1 once the
            # sink is reached, but such problems are inactive from now on)
            i[_bx] = row4col[_b, j]
        # step 4, update the dual variables
        u[:, curRow] += minVal
        SR[:, curRow] = False
        # every other row in SR was reached through row4col, so it is assigned
        # and col4row[SR] holds valid column indices
        u[SR] += (
            minVal[:, None].expand_as(u)[SR]
            - shortestPathCosts[_ib[:, None].expand_as(u)[SR], col4row[SR]]
        )
        v[SC] += -minVal[:, None].expand_as(v)[SC] + shortestPathCosts[SC]
        # step 5, augment the previous solution by walking the path back from
        # the sink to curRow, flipping assignments along the way
        j = sink
        _bx = C.new_ones(bs, dtype=torch.bool)
        while True:
            i = path[_ib[_bx], j]
            row4col[_ib[_bx], j] = i
            temp = col4row[_ib[_bx], i]
            col4row[_ib[_bx], i] = j
            _bx_temp = _bx.clone()
            _bx[_bx_temp] = i != curRow
            if not _bx.any():
                break
            j = temp[_bx[_bx_temp]]
        # step 6, loop (next curRow)
    return _ib[:, None].expand_as(col4row), col4row
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment