from typing import List, Tuple

import torch
from torch import Tensor
def auto_index_select(value_tensor: Tensor, index_tensor: Tensor) -> Tensor:
    """Apply torch.index_select along the first dimension of value_tensor
    whose size matches the length of index_tensor.

    NOTE: if several dimensions share that size, the first match wins.
    """
    index_tensor = index_tensor.squeeze()
    for dim, dim_size in enumerate(value_tensor.size()):
        if index_tensor.size(0) == dim_size:
            return torch.index_select(value_tensor, dim, index_tensor)
    raise IndexError("no dimension of value_tensor matches index_tensor "
                     "of length {}".format(index_tensor.size(0)))
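
# Hypothetical usage sketch (not in the original gist): the target dimension
# is inferred from the index length alone.
# >>> v = torch.arange(12).view(3, 4)
# >>> auto_index_select(v, torch.tensor([2, 0, 1]))     # len 3 -> rows (dim 0)
# >>> auto_index_select(v, torch.tensor([3, 1, 0, 2]))  # len 4 -> cols (dim 1)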
def auto_index_select_v2(value_tensor: Tensor, index_tensor_list: List[Tensor]) -> Tensor:
    """Try auto_index_select with each candidate index in turn; non-tensor
    entries (e.g., None) and size mismatches are skipped."""
    for index_tensor in index_tensor_list:
        if torch.is_tensor(index_tensor):
            try:
                return auto_index_select(value_tensor, index_tensor)
            except IndexError:
                pass
    raise IndexError("no tensor in index_tensor_list matches any dimension "
                     "of value_tensor")
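
# Hypothetical usage sketch: non-tensor placeholders (e.g., None) are skipped,
# so callers can pass candidate indices that may not apply.
# >>> v = torch.arange(12).view(3, 4)
# >>> auto_index_select_v2(v, [None, torch.tensor([1, 0, 2])])  # len 3 -> dim 0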
def sort_and_relabel(tensor_to_sort: Tensor,
                     tensors_to_relabel: List[Tensor],
                     tensors_to_follow: List[Tensor] = None,
                     tensors_to_follow_and_relabel: List[Tensor] = None,
                     index_for_relabel: Tensor = None,
                     max_num: int = None) -> Tuple[Tensor, Tensor, Tuple, Tuple, Tuple]:
    tensors_to_relabel = tensors_to_relabel or []
    tensors_to_follow = tensors_to_follow or []
    tensors_to_follow_and_relabel = tensors_to_follow_and_relabel or []
    SZ = tensor_to_sort.size()
    sorted_tensor, sorted_idx = torch.sort(tensor_to_sort.squeeze())
    sorted_tensor = sorted_tensor.view(SZ)
    # Number of distinct labels; -1 marks labels not assigned yet.
    max_val_to_relabel = max_num or (max([t.max().item()
                                          for t in (tensors_to_relabel + tensors_to_follow_and_relabel)]) + 1)
    relabel_index = torch.full((max_val_to_relabel,), -1,
                               dtype=torch.long, device=tensor_to_sort.device)
    # Case 1: index_for_relabel holds, e.g., edges of shape [2, E].
    if index_for_relabel is not None and index_for_relabel.dim() == 2:
        index_for_relabel = auto_index_select(index_for_relabel, sorted_idx)
        # e.g., [[2, 0], [1, 2]] --> [2, 1, 0, 2]
        flatten_index_for_relabel = index_for_relabel.t().flatten()
        # NOTE: this relies on torch.unique(sorted=False) returning the unique
        # values in *reverse* order of first appearance, a CPU-implementation
        # detail that the docs do not guarantee across PyTorch versions.
        reversely_stably_ordered_unique_index = torch.unique(flatten_index_for_relabel, sorted=False)
        unique_index_size = reversely_stably_ordered_unique_index.size(0)
        # Assign label 0 to the first-appearing value, 1 to the next, and so on,
        # e.g., [2, 1, 0, 2] --> [0, 1, 2, 0].
        relabel_index[reversely_stably_ordered_unique_index] = \
            (unique_index_size - 1) - torch.arange(unique_index_size, device=relabel_index.device)
        if unique_index_size != max_val_to_relabel:
            # Values that never appear take the remaining labels, assigned in
            # descending order so the flip below keeps labels aligned with positions.
            not_used_index = torch.nonzero(relabel_index < 0).flatten()
            relabel_index[not_used_index] = \
                (max_val_to_relabel - 1) - torch.arange(not_used_index.size(0), device=relabel_index.device)
            reversely_stably_ordered_unique_index = torch.cat([not_used_index,
                                                               reversely_stably_ordered_unique_index])
        # Old indices listed in ascending order of their new labels.
        stably_ordered_unique_index = torch.flip(reversely_stably_ordered_unique_index, dims=[0])
    # Case 2: no index_for_relabel, e.g., sorting nodes of shape [N, F]:
    # the new label of each item is simply its rank in the sorted order.
    elif index_for_relabel is None:
        relabel_index[sorted_idx] = torch.arange(sorted_idx.size(0), device=relabel_index.device)
        stably_ordered_unique_index = None
    else:
        raise ValueError("index_for_relabel.dim() should be 2, not {}".format(index_for_relabel.dim()))
    relabeled_tensors = tuple(relabel_index[t] for t in tensors_to_relabel)
    followed_tensors = tuple(auto_index_select_v2(t, [sorted_idx, stably_ordered_unique_index])
                             for t in tensors_to_follow)
    followed_and_relabeled_tensors = tuple(relabel_index[auto_index_select(t, sorted_idx)]
                                           for t in tensors_to_follow_and_relabel)
    return sorted_tensor, relabel_index, relabeled_tensors, followed_tensors, followed_and_relabeled_tensors
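
if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original gist), assuming a
    # small graph with N=3 nodes and E=4 edges: sort the edges by weight, then
    # relabel node ids by their first appearance in the sorted edge list.
    edge_index = torch.tensor([[2, 0, 1, 0],
                               [1, 2, 2, 1]])    # [2, E]
    weight = torch.tensor([0.4, 0.1, 0.3, 0.2])  # [E]
    x = torch.randn(3, 5)                        # [N, F] node features

    sorted_w, relabel_index, (new_edge_index,), (new_x,), _ = sort_and_relabel(
        tensor_to_sort=weight,
        tensors_to_relabel=[edge_index],
        tensors_to_follow=[x],
        index_for_relabel=edge_index,
    )
    print(sorted_w)        # tensor([0.1000, 0.2000, 0.3000, 0.4000])
    print(relabel_index)   # old-id -> new-id map, e.g., tensor([0, 2, 1])
    print(new_edge_index)  # edge_index with relabeled node ids (original edge order)
    print(new_x)           # x reordered so row i holds the features of new id i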