Created
March 14, 2022 13:10
-
-
Save nmichlo/ecca073c9b1d23ebd6e5cb5d6a1e49ef to your computer and use it in GitHub Desktop.
torchsort wrappers that allow specifying the dimension(s) to sort or rank over
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache | |
from typing import Tuple | |
from typing import Union | |
import numpy as np | |
import torch | |
# ========================================================================= #
# helper functions for moving dimensions                                    #
# ========================================================================= #
@lru_cache(maxsize=32) | |
def _get_2d_reshape_info(shape: Tuple[int, ...], dims: Union[int, Tuple[int, ...]] = -1): | |
if isinstance(dims, int): | |
dims = (dims,) | |
# number of dimensions & remove negatives | |
ndim = len(shape) | |
dims = tuple((ndim + d) if d < 0 else d for d in dims) | |
# check that we have at least 2 dims & that all values are valid | |
assert all(0 <= d < ndim for d in dims) | |
# return new shape | |
if ndim == 1: | |
return [0], (1, *shape) | |
# check resulting shape | |
assert ndim >= 2 | |
# get dims | |
dims_X = set(dims) | |
dims_B = set(range(ndim)) - dims_X | |
# sort dims | |
dims_X = sorted(dims_X) | |
dims_B = sorted(dims_B) | |
# compute shape | |
shape = np.array(shape) | |
size_B = int(np.prod(shape[dims_B])) | |
size_X = int(np.prod(shape[dims_X])) | |
# variables | |
moved_end_dims = tuple(dims_X[::-1]) | |
reshape_size = (size_B, size_X) | |
# sort dims | |
return moved_end_dims, reshape_size | |
def torch_dims_at_end_2d(tensor: torch.Tensor, dims: Union[int, Tuple[int, ...]] = -1, return_undo_data=True): | |
# get dim info | |
moved_end_dims, reshape_size = _get_2d_reshape_info(tensor.shape, dims=dims) | |
# move all axes | |
for d in moved_end_dims: | |
tensor = torch.moveaxis(tensor, d, -1) | |
moved_shape = tensor.shape | |
# reshape | |
tensor = torch.reshape(tensor, reshape_size) | |
# return all info | |
if return_undo_data: | |
return tensor, moved_shape, moved_end_dims | |
else: | |
return tensor | |
def torch_undo_dims_at_end_2d(tensor: torch.Tensor, moved_shape, moved_end_dims): | |
# reshape | |
tensor = torch.reshape(tensor, moved_shape) | |
# undo moving of dims | |
for d in moved_end_dims[::-1]: | |
tensor = torch.moveaxis(tensor, -1, d) | |
# reshape | |
return tensor | |
# ========================================================================= #
# Improved differentiable torchsort functions                               #
# ========================================================================= #
def torch_soft_sort(
    tensor: torch.Tensor,
    dims: Union[int, Tuple[int, ...]] = -1,
    regularization='l2',
    regularization_strength=1.0,
    leave_dims_at_end=False,
):
    """Soft-sort ``tensor`` jointly over one or more ``dims`` using ``torchsort``.

    The selected dims are collapsed into the last axis of a 2D tensor by
    ``torch_dims_at_end_2d``, and ``torchsort.soft_sort`` is applied to that
    2D tensor.

    :param tensor: input tensor.
    :param dims: dim or tuple of dims to sort over together.
    :param regularization: forwarded unchanged to ``torchsort.soft_sort``.
    :param regularization_strength: forwarded unchanged to ``torchsort.soft_sort``.
    :param leave_dims_at_end: if True, return the 2D ``(batch, selected)``
        tensor directly instead of restoring the original layout.
    :return: the soft-sorted tensor.
    """
    # optional dependency: only required when this function is actually called
    import torchsort

    flat, moved_shape, moved_end_dims = torch_dims_at_end_2d(tensor, dims=dims, return_undo_data=True)
    sorted_flat = torchsort.soft_sort(
        flat,
        regularization=regularization,
        regularization_strength=regularization_strength,
    )
    if not leave_dims_at_end:
        return torch_undo_dims_at_end_2d(sorted_flat, moved_shape=moved_shape, moved_end_dims=moved_end_dims)
    return sorted_flat
def torch_soft_rank(
    tensor: torch.Tensor,
    dims: Union[int, Tuple[int, ...]] = -1,
    regularization='l2',
    regularization_strength=1.0,
    leave_dims_at_end=False,
):
    """Soft-rank ``tensor`` jointly over one or more ``dims`` using ``torchsort``.

    The selected dims are collapsed into the last axis of a 2D tensor by
    ``torch_dims_at_end_2d``, and ``torchsort.soft_rank`` is applied to that
    2D tensor.

    :param tensor: input tensor.
    :param dims: dim or tuple of dims to rank over together.
    :param regularization: forwarded unchanged to ``torchsort.soft_rank``.
    :param regularization_strength: forwarded unchanged to ``torchsort.soft_rank``.
    :param leave_dims_at_end: if True, return the 2D ``(batch, selected)``
        tensor directly instead of restoring the original layout.
    :return: the soft ranks, with the same layout as ``tensor`` unless
        ``leave_dims_at_end`` is True.
    """
    # imported locally so that torchsort is only required when this is called
    import torchsort
    # collapse the selected dims into the last axis of a 2D tensor
    tensor, moved_shape, moved_end_dims = torch_dims_at_end_2d(tensor, dims=dims, return_undo_data=True)
    # rank (not sort) the entries along the last dimension of the 2D tensor
    tensor = torchsort.soft_rank(tensor, regularization=regularization, regularization_strength=regularization_strength)
    # either keep the 2D layout or restore the original dimension order
    if leave_dims_at_end:
        return tensor
    return torch_undo_dims_at_end_2d(tensor, moved_shape=moved_shape, moved_end_dims=moved_end_dims)
# ========================================================================= #
# end                                                                       #
# ========================================================================= #
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment