Last active: December 23, 2021 15:33
Calculate percentile of a PyTorch tensor's values, similar to numpy.percentile

*Update (2020-10-28)*: With the introduction of torch.quantile (https://pytorch.org/docs/stable/generated/torch.quantile.html) in PyTorch 1.7, this gist has become largely obsolete.
```python
from typing import Union

import numpy as np
import torch


def percentile(t: torch.Tensor, q: float) -> Union[int, float]:
    """
    Return the ``q``-th percentile of the flattened input tensor's data.

    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., method="nearest")`` (spelled
       ``interpolation="nearest"`` before NumPy 1.22).
     * ``t.view(-1)`` fails for non-contiguous tensors; call
       ``t.contiguous()`` first (or use ``t.reshape(-1)``) in that case.

    :param t: Input tensor.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value (scalar).
    """
    # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
    # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
    # so that ``round()`` returns an integer, even if q is a np.float32.
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    result = t.view(-1).kthvalue(k).values.item()
    return result


if __name__ == "__main__":
    q_s = np.r_[0, 100 * np.random.uniform(size=8), 100.]
    a = np.random.uniform(size=(3, 4, 5))
    t = torch.from_numpy(a)
    for q in q_s:
        p_t = percentile(t, q)
        # ``method=`` needs NumPy >= 1.22; older versions use ``interpolation=``.
        p_a = np.percentile(a, q, method="nearest")
        print("q={}, PyTorch result: {}".format(q, p_t))
        print("q={}, NumPy result: {}".format(q, p_a))
        assert p_t == p_a
```
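The index arithmetic in `percentile()` can be checked without PyTorch. The following is a minimal pure-Python sketch, not part of the gist: `nearest_rank_percentile` is a hypothetical helper that sorts a list and picks the element that `kthvalue(k)` would return for a 1-D tensor, using the same formula for `k`. Note that Python's `round()` rounds exact ties to the nearest even integer, which is visible in the last example.

```python
def nearest_rank_percentile(values, q):
    """Pure-Python stand-in for the gist's percentile(): sort the values, then
    pick the k-th smallest one (1-based), where k = 1 + round(.01 * q * (n - 1))."""
    if not 0 <= q <= 100:
        raise ValueError("q must be between 0 and 100 inclusive")
    n = len(values)
    k = 1 + round(.01 * float(q) * (n - 1))
    # sorted(values)[k - 1] is the value torch.kthvalue(k) returns for 1-D input.
    return sorted(values)[k - 1]


data = [50, 10, 40, 20, 30]
print(nearest_rank_percentile(data, 30))    # 0.3 * 4 = 1.2 -> k = 2 -> 20
print(nearest_rank_percentile(data, 37.5))  # 0.375 * 4 = 1.5 -> round() gives 2 -> k = 3 -> 30
print(nearest_rank_percentile(data, 62.5))  # 0.625 * 4 = 2.5 -> round() gives 2 (tie to even) -> k = 3 -> 30
```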
As one-liners:

```python
percentile = lambda t, q: t.view(-1).kthvalue(1 + round(.01 * float(q) * (t.numel() - 1))).values.item()
quantile = lambda t, q: t.view(-1).kthvalue(1 + round(float(q) * (t.numel() - 1))).values.item()
```
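For comparison, here is a minimal sketch of the `torch.quantile` route mentioned in the update note. It assumes PyTorch >= 1.11, because the `interpolation` keyword (needed to reproduce the gist's non-interpolating behavior) is not part of the 1.7 signature. `torch.quantile` expects `q` in [0, 1] and a floating-point input, hence the division by 100 and the `.double()` cast; the helper name `percentile_via_quantile` is hypothetical.

```python
import torch


def percentile_via_quantile(t: torch.Tensor, q: float) -> float:
    """Nearest-rank q-th percentile (q in [0, 100]) of the flattened tensor,
    computed with torch.quantile instead of torch.kthvalue."""
    # With the default dim=None, torch.quantile works on the flattened input.
    return torch.quantile(t.double(), q / 100, interpolation="nearest").item()


if __name__ == "__main__":
    t = torch.randperm(21).to(torch.float64)  # the values 0..20 in random order
    for q in (0, 25, 50, 100):
        print(q, percentile_via_quantile(t, q))
```

With 21 values, q=25 selects index 0.25 * 20 = 5 of the sorted data, i.e. the value 5.0, which is what the gist's `percentile()` returns as well.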
Thanks for the pointer to torch.quantile... works great!
How can I make this code work along a specific axis?
@deathbymath Note that the gist is outdated and you can (and maybe should) use `torch.quantile` instead.

In any case, both `torch.quantile` (which is what I would recommend) and `torch.kthvalue` (which I use in the code above) have a `dim` parameter for applying the computation along a single axis only. To make the gist work this way, you would need to

- calculate `k` along the desired dimension (i.e. replace `t.numel()` with `t.shape[d]` for the desired axis/dimension `d`), and
- replace `t.view(-1).kthvalue(k)` with `t.kthvalue(k, dim=d)`.

I didn't test this, so maybe more adjustments would be necessary.
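The two steps above can be sketched as follows. This is one possible reading of them, not code from the gist: `percentile_along_dim` is a hypothetical name, and the result follows `kthvalue`'s default of dropping the reduced dimension (`keepdim=False`). The usage block cross-checks each result against sorting along the same dimension.

```python
import torch


def percentile_along_dim(t: torch.Tensor, q: float, d: int) -> torch.Tensor:
    """Nearest-rank q-th percentile (q in [0, 100]) of t along dimension d."""
    # Step 1: compute k from the size of dimension d instead of t.numel().
    # kthvalue is one-based, hence the leading 1.
    k = 1 + round(.01 * float(q) * (t.shape[d] - 1))
    # Step 2: apply kthvalue along d instead of on the flattened tensor.
    return t.kthvalue(k, dim=d).values


if __name__ == "__main__":
    t = torch.rand(3, 4, 5, dtype=torch.float64)
    for d in range(t.dim()):
        p = percentile_along_dim(t, 50, d)
        # Cross-check: sort along d and take the (k - 1)-th entry (zero-based).
        k = 1 + round(.5 * (t.shape[d] - 1))
        expected = t.sort(dim=d).values.select(d, k - 1)
        assert torch.equal(p, expected)
        print(d, tuple(p.shape))  # the reduced dimension is dropped
```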
To get a substitute for `numpy.quantile()` instead, simply leave out the factor `.01` in the calculation of `k`.