Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active April 7, 2021 03:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/60fe0d0e3536adc28778448419908f47 to your computer and use it in GitHub Desktop.
Save xmodar/60fe0d0e3536adc28778448419908f47 to your computer and use it in GitHub Desktop.
from typing import Tuple, Optional, Union, List
import torch
import torch.nn as nn
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'dot',
    'get_neighbors',
    'gather_features',
    'point_sparsity',
    'weighted_sampling',
]
# @torch.jit.script
def dot(inputs: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Batched version of torch.dot (the last dimensions must match).

    Args:
        inputs: tensor of shape [*, F]
        other: tensor of shape [*, F]; leading dims broadcast with `inputs`
            following the usual `torch.matmul` batch-broadcasting rules

    Returns:
        dot products over the last dimension, shape [*] (a 0-d tensor for
        1-D inputs, matching torch.dot)
    """
    # [*, 1, F] @ [*, F, 1] -> [*, 1, 1]; drop the two singleton dims by
    # indexing (flatten(-3) fails for 1-D inputs whose product is only 2-D)
    return (inputs.unsqueeze(-2) @ other.unsqueeze(-1))[..., 0, 0]
# @torch.jit.script
def get_neighbors(
    num_neighbors: int,
    features: torch.Tensor,
    neighbors: Optional[torch.Tensor] = None,
    p_norm: float = 2,
    farthest: bool = False,
    ordered: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find a fixed number of nearest (or farthest) neighbors per query.

    Args:
        num_neighbors: how many neighbors (K) to return for every query
        features: query points [*, N, F]
        neighbors: candidate points [*, M, F]; the queries themselves if None
        p_norm: order p of the L_p norm passed to `torch.cdist`
        farthest: select the K largest distances instead of the K smallest
        ordered: sort the result (descending if `farthest`, else ascending)

    Returns:
        (distances, indices), both [*, N, K]; indices address the candidate
        set (`neighbors`, or `features` when `neighbors` is None)
    """
    candidates = features if neighbors is None else neighbors
    pairwise = torch.cdist(features, candidates, p=p_norm)
    # keep the K smallest (or largest) distances along the candidate axis
    return pairwise.topk(k=num_neighbors,
                         dim=-1,
                         largest=farthest,
                         sorted=ordered)
# @torch.jit.script
def gather_features(
    features: torch.Tensor,
    indices: torch.Tensor,
    transposed: bool = False,
    # TODO: investigate whether we should set sparse_grad to True by default
    sparse_grad: bool = False,
) -> torch.Tensor:
    """Gather the features specified by indices.

    The indices may address a different point set than the one they are
    laid out over, e.g. the indices returned by
    `get_neighbors(k, queries, neighbors)` can be used directly with
    `gather_features(neighbors, indices)` even when M != N.

    Args:
        features: tensor of shape [*, M, F] (M points with F features each)
        indices: long tensor of shape [*, N, K] with values in [0, M)
        transposed: whether to transpose the output tensor
        sparse_grad: whether to use a sparse tensor for the gradient

    Returns:
        gathered_features [*, N, K, F] (or [*, N, F, K] if transposed)

    Note:
        Leading batch dims (*) are broadcast. If M or N is 1, that axis is
        also broadcast against the other one, as in the previous version.
    """
    if transposed:
        features, indices = features.unsqueeze(-1), indices.unsqueeze(-2)
    else:
        features, indices = features.unsqueeze(-2), indices.unsqueeze(-1)
    # The point axis (-3) is the gathered one, so its sizes need not agree;
    # only a singleton is still broadcast against the other side.
    num_points, num_queries = features.size(-3), indices.size(-3)
    if num_points == 1:
        num_points = num_queries
    if num_queries == 1:
        num_queries = num_points
    # Broadcast every other axis: the batch dims and the trailing (K, F)
    # pair. Slicing the point axis to <= 1 keeps it out of the broadcast.
    common = torch.broadcast_tensors(features[..., :1, :, :],
                                     indices[..., :1, :, :])[0].shape
    feature_shape, index_shape = list(common), list(common)
    feature_shape[-3], index_shape[-3] = num_points, num_queries
    features = features.expand(feature_shape)
    indices = indices.expand(index_shape)
    return features.gather(dim=-3, index=indices, sparse_grad=sparse_grad)
# @torch.jit.script
def point_sparsity(
    num_neighbors: int,
    features: torch.Tensor,
    neighbors: Optional[torch.Tensor] = None,
    p_norm: float = 2,
) -> torch.Tensor:
    """Measure how isolated each point is within a neighborhood.

    The sparsity of a point is the mean L_p distance from it to its
    `num_neighbors` nearest points among `neighbors`.

    Args:
        num_neighbors: number of nearest neighbors averaged per point
        features: query points [*, N, F]
        neighbors: candidate points [*, M, F]; `features` itself if None
        p_norm: order p of the L_p norm used for the distances

    Returns:
        sparsity [*, N]

    Note:
        When `neighbors` is None the queries double as the candidates, so a
        point can be picked as one of its own neighbors (at zero, or
        numerically near-zero, distance) and then counts toward the mean.
    """
    nearest_distances = get_neighbors(num_neighbors,
                                      features,
                                      neighbors=neighbors,
                                      p_norm=p_norm)[0]
    return nearest_distances.mean(-1)
# @torch.jit.script  # cannot script this (generator and output is Union)
def weighted_sampling(
    num_samples: int,
    weights: torch.Tensor,
    replacement: bool = False,
    ordered: bool = False,
    need_weights: bool = False,
    generator: Optional[Union[int, torch.Generator]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Get sample indices from multinomial distribution given weights

    Args:
        num_samples: number of output samples
        weights: tensor of non-negative finite weights (non-zero sum) [*, N]
        replacement: whether to sample with replacement
        ordered: whether to sort the sampled indices by their weights
            (descending)
        need_weights: whether to return the weights along with the indices
        generator: optional random number generator (can provide int seed)

    Returns:
        (weights, indices) if `need_weights` else indices [*, `num_samples`];
        the returned weights are the raw (unnormalized) input weights of the
        sampled indices
    """
    shape = weights.shape[:-1]
    # collapse all batch dims into one row axis (1-D weights become a single
    # row) because torch.multinomial only accepts 1-D or 2-D input
    flat_weights = weights.unsqueeze(0).flatten(0, -2)
    if isinstance(generator, int):
        generator = torch.Generator(flat_weights.device).manual_seed(generator)
    # np.random.choice: https://github.com/pytorch/pytorch/pull/18624
    indices = flat_weights.multinomial(num_samples,
                                       replacement,
                                       generator=generator)
    if not (ordered or need_weights):
        return indices.view(*shape, -1)
    sampled = flat_weights.gather(dim=-1, index=indices)
    if ordered:
        # sort() yields positions along the sample axis, not the sampled
        # indices, so use those positions to reorder the sampled indices
        sampled, order = sampled.sort(dim=-1, descending=True)
        indices = indices.gather(dim=-1, index=order)
    if need_weights:
        return sampled.view(*shape, -1), indices.view(*shape, -1)
    return indices.view(*shape, -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment