Last active
April 7, 2021 03:37
-
-
Save xmodar/60fe0d0e3536adc28778448419908f47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, Optional, Union, List | |
import torch | |
import torch.nn as nn | |
# Public API of this module (the names exposed by a star-import).
__all__ = [
    'dot',
    'get_neighbors',
    'gather_features',
    'point_sparsity',
    'weighted_sampling',
]
# @torch.jit.script | |
def dot(inputs: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Batched version of torch.dot (the last dimensions must match).

    The leading (batch) dimensions of both tensors are broadcast against
    each other by the underlying matrix multiplication.

    Args:
        inputs: tensor of shape [*, F]
        other: tensor of shape [*, F]

    Returns:
        dot products along the last dimension [*]
        (a 0-dim tensor when both arguments are 1-D vectors)
    """
    # [*, 1, F] @ [*, F, 1] -> [*, 1, 1]
    product = inputs.unsqueeze(-2) @ other.unsqueeze(-1)
    # Drop the two singleton dimensions explicitly; `flatten(-3)` would
    # need at least one batch dimension and fail on plain 1-D vectors.
    return product.squeeze(-1).squeeze(-1)
# @torch.jit.script | |
def get_neighbors(
    num_neighbors: int,
    features: torch.Tensor,
    neighbors: Optional[torch.Tensor] = None,
    p_norm: float = 2,
    farthest: bool = False,
    ordered: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find a fixed number of neighbors for every query point.

    Args:
        num_neighbors: how many neighbors to select per query point
        features: query points [*, N, F]
        neighbors: candidate points to select from [*, M, F]
            (the query points themselves are used when None)
        p_norm: the p of the L_p norm used to measure distances
        farthest: select the farthest neighbors instead of the nearest
        ordered: sort the selection by distance
            (descending if `farthest`, ascending otherwise)

    Returns:
        (distances, indices) both of shape [*, N, `num_neighbors`]
    """
    candidates = features if neighbors is None else neighbors
    # Pairwise distances between every query and every candidate [*, N, M]
    pairwise = torch.cdist(features, candidates, p=p_norm)
    return torch.topk(
        pairwise,
        k=num_neighbors,
        dim=-1,
        largest=farthest,
        sorted=ordered,
    )
# @torch.jit.script | |
def gather_features(
    features: torch.Tensor,
    indices: torch.Tensor,
    transposed: bool = False,
    # TODO: investigate whether we should set sparse_grad to True by default
    sparse_grad: bool = False,
) -> torch.Tensor:
    """Collect, for every point, the feature vectors selected by `indices`.

    Args:
        features: tensor of shape [*, N, F]
        indices: long tensor of shape [*, N, K]
        transposed: put the feature dimension before the K dimension
        sparse_grad: whether to use a sparse tensor for the gradient

    Returns:
        gathered_features [*, N, K, F] (or [*, N, F, K] if transposed)
    """
    # Insert the singleton axes so that both tensors broadcast to the
    # output layout: [*, N, F, K] when transposed, [*, N, K, F] otherwise.
    feature_axis, index_axis = (-1, -2) if transposed else (-2, -1)
    source, index = torch.broadcast_tensors(
        features.unsqueeze(feature_axis),
        indices.unsqueeze(index_axis),
    )
    # The index values select along the point dimension (third from last).
    return torch.gather(source, -3, index, sparse_grad=sparse_grad)
# @torch.jit.script | |
def point_sparsity(
    num_neighbors: int,
    features: torch.Tensor,
    neighbors: Optional[torch.Tensor] = None,
    p_norm: float = 2,
) -> torch.Tensor:
    """Measure how sparse the neighborhood of every point is.

    The sparsity of a point in `features` is the mean distance to its
    `num_neighbors` nearest points in `neighbors`.

    Args:
        num_neighbors: how many nearest neighbors to average over
        features: query points [*, N, F]
        neighbors: candidate points to select from [*, M, F]
            (the query points themselves are used when None)
        p_norm: the p of the L_p norm used to measure distances

    Returns:
        sparsity [*, N]
    """
    nearest = get_neighbors(num_neighbors, features, neighbors, p_norm)
    # Only the distances are needed; the neighbor indices are discarded.
    nearest_distances = nearest[0]
    return nearest_distances.mean(dim=-1)
# @torch.jit.script # cannot script this (generator and output is Union) | |
def weighted_sampling(
    num_samples: int,
    weights: torch.Tensor,
    replacement: bool = False,
    ordered: bool = False,
    need_weights: bool = False,
    generator: Optional[Union[int, torch.Generator]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Get sample indices from multinomial distribution given weights.

    Args:
        num_samples: number of output samples
        weights: tensor of non-negative finite weights (non-zero sum) [*, N]
        replacement: whether to sample with replacement
        ordered: whether to sort the samples by their weights (descending)
        need_weights: whether to return the weights along with the indices
        generator: optional random number generator (can provide int seed)

    Returns:
        (weights, indices) if `need_weights` else indices [*, `num_samples`]
        where `indices` always index the last dimension of the input
        `weights` and the returned weights are the ones of the samples
    """
    # Collapse all leading dimensions into a single batch dimension for
    # the sampling; the original leading shape is restored on the way out.
    shape = weights.shape[:-1]
    weights = weights.unsqueeze(0).flatten(0, -2)
    if isinstance(generator, int):
        generator = torch.Generator(weights.device).manual_seed(generator)
    # np.random.choice: https://github.com/pytorch/pytorch/pull/18624
    indices = weights.multinomial(num_samples,
                                  replacement,
                                  generator=generator)
    if ordered or need_weights:
        # Weights of the drawn samples [B, `num_samples`]
        weights = weights.gather(dim=-1, index=indices)
        if ordered:
            # `sort` yields positions *within the samples*, not indices into
            # the input weights, so reorder the sampled indices with them.
            weights, order = weights.sort(dim=-1, descending=True)
            indices = indices.gather(dim=-1, index=order)
    if need_weights:
        return weights.view(*shape, -1), indices.view(*shape, -1)
    return indices.view(*shape, -1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment