@spezold
Last active December 23, 2021 15:33
*Update (2020-10-28)*: With the introduction of torch.quantile (https://pytorch.org/docs/stable/generated/torch.quantile.html) in PyTorch 1.7, this gist has become largely obsolete.

Calculate percentiles of a PyTorch tensor's values, similar to numpy.percentile.
```python
from typing import Union

import numpy as np
import torch


def percentile(t: torch.Tensor, q: float) -> Union[int, float]:
    """
    Return the ``q``-th percentile of the flattened input tensor's data.

    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., interpolation="nearest")``.

    :param t: Input tensor.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value (scalar).
    """
    # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
    # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
    # so that ``round()`` returns an integer, even if q is a np.float32.
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    # ``view(-1)`` needs a memory layout that can be flattened without a copy
    # (e.g. a contiguous tensor); ``reshape(-1)`` would also accept other inputs.
    result = t.view(-1).kthvalue(k).values.item()
    return result


if __name__ == "__main__":
    q_s = np.r_[0, 100 * np.random.uniform(size=8), 100.]
    a = np.random.uniform(size=(3, 4, 5))
    t = torch.from_numpy(a)
    for q in q_s:
        p_t = percentile(t, q)
        # NumPy >= 1.22 deprecates ``interpolation`` in favor of ``method``.
        p_a = np.percentile(a, q, interpolation="nearest")
        print("q={}, PyTorch result: {}".format(q, p_t))
        print("q={}, NumPy result: {}".format(q, p_a))
        assert p_t == p_a
```
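Following the update note, the built-in route can be sketched as below. This assumes PyTorch >= 1.11, where `torch.quantile` is assumed to have gained the `interpolation` keyword (earlier versions interpolate linearly only). The function names `percentile_kthvalue` and `percentile_builtin` are hypothetical: the first repeats the gist's index arithmetic, the second rescales `q` to [0, 1] for `torch.quantile`. For the chosen percentiles, both should select the same element.

```python
import torch


def percentile_kthvalue(t: torch.Tensor, q: float) -> float:
    """Hypothetical helper: the gist's index arithmetic (0 <= q <= 100)."""
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    # reshape(-1) instead of the gist's view(-1), so non-contiguous inputs work too.
    return t.reshape(-1).kthvalue(k).values.item()


def percentile_builtin(t: torch.Tensor, q: float) -> float:
    """Hypothetical helper; assumes PyTorch >= 1.11 for ``interpolation``."""
    # torch.quantile expects q in [0, 1], so rescale the percentile.
    return torch.quantile(t, q / 100, interpolation="nearest").item()


torch.manual_seed(0)
t = torch.rand(3, 4, 5, dtype=torch.float64)  # 60 elements, as in the gist's demo
for q in (0.0, 10.0, 33.0, 90.0, 100.0):      # ranks 0, 5.9, 19.47, 53.1, 59: no .5 ties
    print(q, percentile_kthvalue(t, q), percentile_builtin(t, q))
```

Because `.01 * q` and `q / 100` can differ in the last bit, the two could disagree when the rank lands exactly on a .5 tie; the percentiles above were picked away from such ties.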

spezold commented May 24, 2019

To get a substitute for numpy.quantile() instead, simply leave out the factor .01 in the calculation of k (q then has to lie between 0 and 1 inclusive).
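The index arithmetic behind the quantile variant can be illustrated without PyTorch. In this sketch, the helper `nearest_rank_quantile` is hypothetical and a plain Python list stands in for the tensor: sorting and indexing replace `kthvalue()`, so the zero-based index is one less than the gist's one-based `k`.

```python
def nearest_rank_quantile(values, q):
    """Hypothetical helper: q-quantile (0 <= q <= 1) without interpolation."""
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be between 0 and 1 inclusive")
    ordered = sorted(values)
    # Zero-based rank, rounded to the nearest integer. Python's round()
    # sends exact halves to the even neighbour. The gist's one-based
    # k for kthvalue() would be index + 1.
    index = round(float(q) * (len(ordered) - 1))
    return ordered[index]


data = [7.0, 1.0, 4.0, 9.0, 3.0]  # sorted: [1.0, 3.0, 4.0, 7.0, 9.0]
for q in (0.0, 0.5, 0.625, 0.9, 1.0):
    print(q, nearest_rank_quantile(data, q))
# 0.5   -> rank 2.0 -> index 2 -> 4.0
# 0.625 -> rank 2.5 (a tie) -> even index 2 -> 4.0
# 0.9   -> rank 3.6 -> index 4 -> 9.0
```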


spezold commented Jun 25, 2020

As one-liners:

```python
percentile = lambda t, q: t.view(-1).kthvalue(1 + round(.01 * float(q) * (t.numel() - 1))).values.item()
quantile = lambda t, q: t.view(-1).kthvalue(1 + round(float(q) * (t.numel() - 1))).values.item()
```

DuaneNielsen commented:

Thanks for the pointer to torch.quantile... works great!

deathbymath commented:

How can I make this code work along a specific axis?


spezold commented Dec 23, 2021

@deathbymath Note that the gist is outdated and you can (and maybe should) use torch.quantile instead.

In any case, both torch.quantile (which is what I would recommend) and torch.kthvalue (which I use in the code above) have a dim parameter for operating along a single axis only. To make the gist work this way, you would need to

  1. calculate k along the desired dimension (i.e. replace t.numel() with t.shape[d] for the desired axis/dimension d), and
  2. replace t.view(-1).kthvalue(k) with t.kthvalue(k, dim=d).

I didn't test this, so maybe more adjustments would be necessary.
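The two steps above can be sketched as follows. Beyond them, `.item()` also has to go, because the result now holds one value per slice rather than a single scalar. The helper name `percentile_along_dim` is hypothetical, and the cross-check against `torch.quantile(..., interpolation="nearest")` assumes PyTorch >= 1.11.

```python
import torch


def percentile_along_dim(t: torch.Tensor, q: float, dim: int) -> torch.Tensor:
    """Hypothetical helper: q-th percentile (0 <= q <= 100) along ``dim``."""
    # Step 1: k is computed from the size of ``dim`` instead of t.numel().
    k = 1 + round(.01 * float(q) * (t.shape[dim] - 1))
    # Step 2: kthvalue() gets ``dim``; .item() is dropped, since the
    # result is a tensor with ``dim`` removed from the shape.
    return t.kthvalue(k, dim=dim).values


torch.manual_seed(0)
t = torch.rand(3, 4, 5, dtype=torch.float64)
p = percentile_along_dim(t, 40.0, dim=2)
print(p.shape)  # torch.Size([3, 4])
# Cross-check with the built-in (assumes PyTorch >= 1.11 for ``interpolation``).
print(torch.equal(p, torch.quantile(t, 0.4, dim=2, interpolation="nearest")))
```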
