Skip to content

Instantly share code, notes, and snippets.

@mberr
Last active February 28, 2022 18:20
Show Gist options
  • Save mberr/7f08a37a56addb083258adfbca12b837 to your computer and use it in GitHub Desktop.
Several Similarity Matrix Normalization Methods written in PyTorch
"""Several similarity matrix normalization methods."""
from typing import Optional

import torch
def csls(
    sim: torch.FloatTensor,
    k: Optional[int] = 1,
) -> torch.FloatTensor:
    """
    Apply CSLS normalization to a similarity matrix.

    .. math::
        csls[i, j] = 2*sim[i, j] - avg(top_k(sim[i, :])) - avg(top_k(sim[:, j]))

    :param sim: shape: (d1, ..., dk)
        Similarity matrix.
    :param k:
        The number of top-k elements to use for correction. If None, the
        similarity matrix is returned unchanged.

    :return:
        The normalized similarity matrix.
    """
    # No correction requested -> identity.
    if k is None:
        return sim
    # Empty similarity matrix: nothing to normalize.
    if sim.numel() < 1:
        return sim
    old_sim = sim
    # Scale by the number of modes to compensate for the subtractions below:
    # one top-k average is subtracted per tensor mode.
    sim = sim.ndimension() * sim
    # Subtract the average over the top-k similarities for each mode of the tensor.
    # k is clamped to the mode's size so small tensors do not raise.
    for dim, size in enumerate(sim.size()):
        sim = sim - old_sim.topk(k=min(k, size), dim=dim, largest=True, sorted=False).values.mean(dim=dim, keepdim=True)
    return sim
def sinkhorn_knopp(
    similarities: torch.FloatTensor,
    eps: float = 1.0e-04,
    max_iter: int = 1000,
) -> torch.FloatTensor:
    """
    Normalize similarities to be double stochastic using the Sinkhorn-Knopp algorithm.

    :param similarities: shape: (n, n)
        The similarities.
    :param eps:
        A tolerance for the convergence check (norm of the log-space update).
    :param max_iter:
        A maximum number of fix-point iterations.

    :return:
        The normalized similarities (in log space!).

    .. seealso ::
        http://www.cerfacs.fr/algor/reports/2006/TR_PA_06_42.pdf
        https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch/blob/8fbc8cf4b97f5bafd18776b5497e3f724d60cc0a/my_sinkhorn_ops.py#L36
    """
    # Double stochasticity is only attainable for square matrices.
    num_rows = similarities.shape[0]
    is_square = similarities.ndimension() == 2 and similarities.shape[1] == num_rows
    if not is_square:
        raise ValueError(f'similarities have to be a square matrix, but have shape: {similarities.shape}')
    # Alternate row- and column-normalization in log space (logsumexp is the
    # log-space analogue of dividing by the row/column sum) until the update
    # becomes small or the iteration budget is exhausted.
    log_sim = similarities
    for _ in range(max_iter):
        previous = log_sim
        log_sim = log_sim - log_sim.logsumexp(dim=-1, keepdim=True)
        log_sim = log_sim - log_sim.logsumexp(dim=-2, keepdim=True)
        if (previous - log_sim).norm() < eps:
            break
    return log_sim
def bidirectional_alignment(
    similarities: torch.FloatTensor,
    normalize: bool = False,
) -> torch.FloatTensor:
    """
    Compute bi-directional alignment scores.

    .. note ::
        This operation is non-differentiable.

    .. seealso ::
        https://www.aclweb.org/anthology/D19-1075.pdf

    :param similarities: shape: (n, m)
        The similarity scores.
    :param normalize:
        Use the normalized rank instead of the rank; also use mean instead of sum over both directions. Guarantees
        that the output has value range [0, 1].

    :return: shape: (n, m)
        The new similarity scores.
    """
    # The rank of entry (i, j) within its column / row requires a *double*
    # argsort: a single argsort yields sorting indices ("which element comes
    # next in sorted order"), not the rank of each element — which is what the
    # docstring contract ("normalized rank") calls for.
    left_to_right, right_to_left = [
        similarities.argsort(dim=dim).argsort(dim=dim).float()
        for dim in (0, 1)
    ]
    if normalize:
        left_to_right = 0.5 * left_to_right / similarities.shape[0]
        right_to_left = 0.5 * right_to_left / similarities.shape[1]
    return left_to_right + right_to_left
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment