@nmichlo · Created March 14, 2022
torchsort but allow specifying the dimension
from functools import lru_cache
from typing import Tuple
from typing import Union

import numpy as np
import torch


# ========================================================================= #
# helper functions for moving dimensions                                    #
# ========================================================================= #


@lru_cache(maxsize=32)
def _get_2d_reshape_info(shape: Tuple[int, ...], dims: Union[int, Tuple[int, ...]] = -1):
    if isinstance(dims, int):
        dims = (dims,)
    # number of dimensions & remove negatives
    ndim = len(shape)
    dims = tuple((ndim + d) if d < 0 else d for d in dims)
    # check that all the given dims are valid
    assert all(0 <= d < ndim for d in dims)
    # handle the 1D case by prepending a batch dimension of size 1
    if ndim == 1:
        return (0,), (1, *shape)
    # otherwise we need at least 2 dims
    assert ndim >= 2
    # split the dims into the "sort" (X) and "batch" (B) groups
    dims_X = set(dims)
    dims_B = set(range(ndim)) - dims_X
    # sort the dims in each group
    dims_X = sorted(dims_X)
    dims_B = sorted(dims_B)
    # compute the flattened size of each group
    shape = np.array(shape)
    size_B = int(np.prod(shape[dims_B]))
    size_X = int(np.prod(shape[dims_X]))
    # dims to move to the end, in reverse order so the undo loop can restore them
    moved_end_dims = tuple(dims_X[::-1])
    reshape_size = (size_B, size_X)
    # return the info needed to reshape to 2D
    return moved_end_dims, reshape_size
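
# for example (illustrative values, derived from the logic above):
#   _get_2d_reshape_info((2, 3, 4), dims=1)      -> ((1,), (8, 3))
#   _get_2d_reshape_info((2, 3, 4), dims=(1, 2)) -> ((2, 1), (2, 12))
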
def torch_dims_at_end_2d(tensor: torch.Tensor, dims: Union[int, Tuple[int, ...]] = -1, return_undo_data=True):
    # get the dim info (cached per shape/dims combination)
    moved_end_dims, reshape_size = _get_2d_reshape_info(tensor.shape, dims=dims)
    # move all the target axes to the end
    for d in moved_end_dims:
        tensor = torch.moveaxis(tensor, d, -1)
    moved_shape = tensor.shape
    # flatten into a 2D tensor: (batch, sort)
    tensor = torch.reshape(tensor, reshape_size)
    # return everything needed to undo the operation
    if return_undo_data:
        return tensor, moved_shape, moved_end_dims
    else:
        return tensor


def torch_undo_dims_at_end_2d(tensor: torch.Tensor, moved_shape, moved_end_dims):
    # restore the moved (but not yet reordered) shape
    tensor = torch.reshape(tensor, moved_shape)
    # move the axes back to their original positions
    for d in moved_end_dims[::-1]:
        tensor = torch.moveaxis(tensor, -1, d)
    # done!
    return tensor
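
# round-trip illustration (not part of the original gist):
#   # for any tensor t with at least 3 dims:
#   t2d, moved_shape, moved_dims = torch_dims_at_end_2d(t, dims=(1, 2))
#   assert torch.equal(torch_undo_dims_at_end_2d(t2d, moved_shape, moved_dims), t)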


# ========================================================================= #
# Improved differentiable torchsort functions                               #
# ========================================================================= #


def torch_soft_sort(
    tensor: torch.Tensor,
    dims: Union[int, Tuple[int, ...]] = -1,
    regularization='l2',
    regularization_strength=1.0,
    leave_dims_at_end=False,
):
    # import locally so that torchsort is only required if this function is used
    import torchsort
    # move the target dims to the end & flatten into a 2D tensor
    tensor, moved_shape, moved_end_dims = torch_dims_at_end_2d(tensor, dims=dims, return_undo_data=True)
    # soft-sort the last dimension of the 2D tensor
    tensor = torchsort.soft_sort(tensor, regularization=regularization, regularization_strength=regularization_strength)
    # undo the reorder operation
    if leave_dims_at_end:
        return tensor
    return torch_undo_dims_at_end_2d(tensor, moved_shape=moved_shape, moved_end_dims=moved_end_dims)


def torch_soft_rank(
    tensor: torch.Tensor,
    dims: Union[int, Tuple[int, ...]] = -1,
    regularization='l2',
    regularization_strength=1.0,
    leave_dims_at_end=False,
):
    # import locally so that torchsort is only required if this function is used
    import torchsort
    # move the target dims to the end & flatten into a 2D tensor
    tensor, moved_shape, moved_end_dims = torch_dims_at_end_2d(tensor, dims=dims, return_undo_data=True)
    # soft-rank the last dimension of the 2D tensor
    tensor = torchsort.soft_rank(tensor, regularization=regularization, regularization_strength=regularization_strength)
    # undo the reorder operation
    if leave_dims_at_end:
        return tensor
    return torch_undo_dims_at_end_2d(tensor, moved_shape=moved_shape, moved_end_dims=moved_end_dims)


# ========================================================================= #
# end                                                                       #
# ========================================================================= #
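

# ========================================================================= #
# example usage -- a minimal sketch, not part of the original gist; it      #
# assumes `torchsort` is installed (e.g. `pip install torchsort`)           #
# ========================================================================= #


if __name__ == '__main__':
    # soft-sort along dim=1 of a (2, 3, 4) tensor instead of the default last dim
    x = torch.randn(2, 3, 4, requires_grad=True)
    out = torch_soft_sort(x, dims=1, regularization_strength=0.1)
    assert out.shape == x.shape
    # gradients flow back through the relaxed sort
    out.sum().backward()
    assert x.grad is not None
    # soft-rank jointly over the flattened (3, 4) block of each batch element
    # (detached here since we only want the values, not another graph)
    ranks = torch_soft_rank(x.detach(), dims=(1, 2))
    assert ranks.shape == x.shape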