Skip to content

Instantly share code, notes, and snippets.

@dongkwan-kim
Created July 15, 2021 13:24
Show Gist options
  • Save dongkwan-kim/d8dd9393b8d78e5708360a316171c08d to your computer and use it in GitHub Desktop.
Save dongkwan-kim/d8dd9393b8d78e5708360a316171c08d to your computer and use it in GitHub Desktop.
from typing import Dict, Any, List, Tuple
import torch
from torch import Tensor
def auto_index_select(value_tensor: Tensor, index_tensor: Tensor):
    """Index ``value_tensor`` along the first dimension whose size equals the index length.

    The index is squeezed first, so shapes such as ``[K, 1]`` or ``[1, K]``
    are treated as a vector of ``K`` indices.

    Args:
        value_tensor: Tensor to select from.
        index_tensor: Integer index tensor; after squeezing, its first
            dimension is compared against every dimension of ``value_tensor``.

    Returns:
        ``torch.index_select(value_tensor, dim, index)`` for the first ``dim``
        whose size equals the number of indices.

    Raises:
        IndexError: If no dimension of ``value_tensor`` has that size.

    NOTE: when several dimensions share the matching size (e.g. a square
    tensor), the lowest dimension wins.
    """
    index_tensor = index_tensor.squeeze()
    if index_tensor.dim() == 0:
        # squeeze() turns a single-element index into a scalar, for which
        # size(0) is undefined; restore the 1-D shape so it can still match
        # a dimension of size 1.
        index_tensor = index_tensor.unsqueeze(0)

    num_indices = index_tensor.size(0)
    for dim, dim_size in enumerate(value_tensor.size()):
        if dim_size == num_indices:
            return torch.index_select(value_tensor, dim, index_tensor)

    raise IndexError(
        "no dimension of value_tensor with size {} matches index length {}".format(
            tuple(value_tensor.size()), num_indices))
def auto_index_select_v2(value_tensor: Tensor, index_tensor_list: List[Tensor]):
    """Apply the first usable index from ``index_tensor_list`` to ``value_tensor``.

    Entries that are not tensors (e.g. ``None``) are skipped. Each remaining
    candidate is handed to ``auto_index_select``; a candidate that makes it
    raise ``IndexError`` is dropped and the next one is tried.

    Returns:
        The result of ``auto_index_select`` for the first candidate that works.

    Raises:
        IndexError: If no candidate in the list could be applied.
    """
    tensor_candidates = (c for c in index_tensor_list if torch.is_tensor(c))
    for candidate in tensor_candidates:
        try:
            return auto_index_select(value_tensor, candidate)
        except IndexError:
            # This candidate matches no dimension; fall through to the next.
            continue
    raise IndexError
def _unique_in_first_occurrence_order(values: Tensor) -> Tensor:
    """Return the distinct entries of 1-D ``values`` ordered by first appearance."""
    if values.numel() == 0:
        return values
    # A stable sort keeps equal entries in their original relative order, so
    # the first element of every run of equal values carries the earliest
    # position at which that value occurs.
    sorted_values, positions = torch.sort(values, stable=True)
    starts_run = torch.ones_like(sorted_values, dtype=torch.bool)
    starts_run[1:] = sorted_values[1:] != sorted_values[:-1]
    first_positions = torch.sort(positions[starts_run])[0]
    return values[first_positions]


def sort_and_relabel(tensor_to_sort: Tensor,
                     tensors_to_relabel: List[Tensor],
                     tensors_to_follow: List[Tensor] = None,
                     tensors_to_follow_and_relabel: List[Tensor] = None,
                     index_for_relabel: Tensor = None,
                     max_num: int = None
                     ) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """Sort a tensor and build an old-label -> new-label lookup that follows the sort.

    Two modes, chosen by ``index_for_relabel``:

    * ``None``: the sorted elements themselves are relabeled, i.e. the element
      that ends up at sorted position ``i`` gets new label ``i``.
    * 2-D tensor: it is first permuted with the sort order (through
      ``auto_index_select``), then the labels it contains are renumbered
      ``0, 1, ...`` in order of first appearance when read column by column
      (e.g. ``[[2, 0], [1, 2]]`` is read as ``[2, 1, 0, 2]`` and relabeled to
      ``[0, 1, 2, 0]``). Labels below the lookup size that never appear are
      numbered after the used ones, in ascending order of their old label.
      NOTE(review): presumably this is a ``[2, E]`` edge index with one sort
      key per edge -- confirm against callers.

    Args:
        tensor_to_sort: Tensor providing the sort keys; it is squeezed before
            sorting and the sorted values are reshaped back to its size.
        tensors_to_relabel: Label tensors mapped through the lookup as-is.
        tensors_to_follow: Tensors permuted with the sort order (or, if that
            does not fit, with the label order) via ``auto_index_select_v2``.
        tensors_to_follow_and_relabel: Label tensors permuted with the sort
            order and then mapped through the lookup.
        index_for_relabel: ``None`` or a 2-D label tensor, see above.
        max_num: Size of the lookup. A falsy value (``None`` or ``0``) means
            "infer": largest label in the relabel lists plus one, or, when
            those lists are empty, from ``index_for_relabel`` / the number of
            sorted elements.

    Returns:
        ``(sorted_tensor, relabel_index, relabeled_tensors, followed_tensors,
        followed_and_relabeled_tensors)``; the last three are tuples with one
        entry per input tensor. ``relabel_index[old_label]`` is the new label
        (``-1`` for labels that were never assigned).

    Raises:
        ValueError: If ``index_for_relabel`` is given but is not 2-D.
    """
    if index_for_relabel is not None and index_for_relabel.dim() != 2:
        raise ValueError("index_for_relabel.dim() should not be {}".format(index_for_relabel.dim()))

    tensors_to_relabel = tensors_to_relabel or []
    tensors_to_follow = tensors_to_follow or []
    tensors_to_follow_and_relabel = tensors_to_follow_and_relabel or []
    # Keep every helper tensor on the input's device so indexing also works
    # for non-CPU inputs.
    device = tensor_to_sort.device

    original_size = tensor_to_sort.size()
    sorted_tensor, sorted_idx = torch.sort(tensor_to_sort.squeeze())
    sorted_tensor = sorted_tensor.view(original_size)

    label_sources = tensors_to_relabel + tensors_to_follow_and_relabel
    if max_num:
        max_val_to_relabel = max_num
    elif label_sources:
        max_val_to_relabel = max(t.max().item() for t in label_sources) + 1
    elif index_for_relabel is not None:
        # No label tensors to inspect: size the lookup from the index itself.
        max_val_to_relabel = int(index_for_relabel.max().item()) + 1
    else:
        # No label tensors to inspect: one label per sorted element.
        max_val_to_relabel = sorted_idx.size(0)

    relabel_index = torch.full((max_val_to_relabel,), -1, dtype=torch.long, device=device)

    if index_for_relabel is not None:
        index_for_relabel = auto_index_select(index_for_relabel, sorted_idx)
        # e.g., [[2, 0], [1, 2]] --> [2, 1, 0, 2]
        flatten_index_for_relabel = index_for_relabel.t().flatten()
        # First-appearance order is computed explicitly; the element order of
        # torch.unique(sorted=False) is not specified and must not be relied on.
        # e.g., [2, 1, 0, 2] --> unique [2, 1, 0] --> labels 2->0, 1->1, 0->2
        stably_ordered_unique_index = _unique_in_first_occurrence_order(flatten_index_for_relabel)
        unique_index_size = stably_ordered_unique_index.size(0)
        relabel_index[stably_ordered_unique_index] = torch.arange(unique_index_size, device=device)

        if unique_index_size != max_val_to_relabel:
            not_used_index = torch.nonzero(relabel_index < 0).flatten()
            relabel_index[not_used_index] = (
                torch.arange(not_used_index.size(0), device=device) + unique_index_size)
            # Append in the same (ascending) order the new labels were handed
            # out, so that relabel_index[stably_ordered_unique_index] == arange.
            stably_ordered_unique_index = torch.cat([stably_ordered_unique_index, not_used_index])
    else:
        relabel_index[sorted_idx] = torch.arange(sorted_idx.size(0), device=device)
        stably_ordered_unique_index = None

    relabeled_tensors = tuple(relabel_index[t] for t in tensors_to_relabel)
    followed_tensors = tuple(
        auto_index_select_v2(t, [sorted_idx, stably_ordered_unique_index])
        for t in tensors_to_follow)
    followed_and_relabeled_tensors = tuple(
        relabel_index[auto_index_select(t, sorted_idx)]
        for t in tensors_to_follow_and_relabel)
    return sorted_tensor, relabel_index, relabeled_tensors, followed_tensors, followed_and_relabeled_tensors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment