ordered quantile loss for machine learning
"""
Sometimes we want to use quantiles loss in machine learning, but the outputs are not ordered. This is sometimes called the quantile crossover problem.
Surely it would help to impose the constraint that the quantiles must be ordered?
What's the best way to do this?
Well it seems me that we should predict differences from the median,
and apply a softplus to make sure the differences are only in one direction.
Note this will NOT work for very small target values. Because we are using a softplus the model must output very large
logits to get very small numbers. This means it will have difficulty with small y values.
This is a mixed blessing as it avoid collapses of the quantiles, but it means you must scale your y values.
I've also applied smooth l1 to the quantile loss as it seems to improve performance in the presence of large and small errors.
There's not much thought on this topic that I can find. Perhaps I didn't use the right keywords,
but let me know how it works for you, or if you have suggestions.
author: wassname
url: https://gist.github.com/wassname/91245a26b10c02498f6adfd8f73007f7
license: WTFPL
References:
- https://github.com/jdb78/pytorch-forecasting/blob/master/pytorch_forecasting/metrics/quantile.py#L9
- https://github.com/maxmarketit/Auto-PyTorch/blob/develop/examples/quantiles/Quantiles.ipynb
- http://ethen8181.github.io/machine-learning/ab_tests/quantile_regression/quantile_regression.html
- alternative approach using a penalty https://arxiv.org/pdf/1909.12122.pdf
"""
from typing import List

import torch
import torch.nn.functional as F
from torch import nn


class QuantileLoss(nn.Module):
"""
Quantile loss, i.e. a quantile of ``q=0.5`` will give half of the mean absolute error as it is calcualted as
Defined as ``max(q * (y-y_pred_logits), (1-q) * (y_pred_logits-y))``
Usage:
qloss = QuantileLoss()
nb_q = len(qloss.quantiles)
x = torch.rand((200, nb_q))*0
y = torch.arange(-100, 100, 1)[:, None]/50
l = qloss.smooth_l1_loss(x, y)
"""
    def __init__(
        self,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
        **kwargs,
    ):
        """
        Quantile loss

        Args:
            quantiles: quantiles for the metric; must be ordered and include the median
        """
        super().__init__()
        quantiles = torch.tensor(quantiles)
        assert 0.5 in quantiles, 'quantiles should include the median'
        assert len(quantiles) % 2 == 1, 'should be an odd number of quantiles'
        assert (torch.diff(quantiles) > 0).all(), 'quantiles should be ordered'
        self.quantiles = quantiles
        self.med_i = self.quantiles.tolist().index(0.5)
    def order(self, y_pred):
        """
        Order the outputs of a model.

        We model the quantiles as positive (softplus) differences relative to the
        median, cumulatively summed so the outputs cannot cross.
        with input from @thomsonn
        """
        median = y_pred[..., self.med_i][..., None]
        # quantiles above the median: cumulative sum of positive increments
        top = median + F.softplus(y_pred[..., self.med_i + 1:]).cumsum(-1)
        # quantiles below the median: reversed cumulative sum of positive decrements
        bottom = F.softplus(y_pred[..., :self.med_i])
        bottom = torch.flip(bottom, [-1]).cumsum(-1)
        bottom = median - torch.flip(bottom, [-1])
        return torch.cat([bottom, median, top], -1)
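    # For example (hypothetical values): with quantiles [0.25, 0.5, 0.75] and raw
    # outputs [a, m, b], order() returns [m - softplus(a), m, m + softplus(b)].
    # With more quantiles the softplus terms are cumulatively summed, so the
    # ordered outputs are monotonically increasing by construction.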
    def _loss(self, y_pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Pinball loss on the ordered quantiles.
        see: http://ethen8181.github.io/machine-learning/ab_tests/quantile_regression/quantile_regression.html
        """
        y_pred = self.order(y_pred_logits)
        self.to_device_as(y_pred)
        errors = target - y_pred
        return torch.max(self.quantiles * errors, (self.quantiles - 1) * errors)
    def loss(self, y_pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Regular pinball loss; the factor of 2 makes the median term match MAE.
        """
        losses = 2 * self._loss(y_pred_logits, target)
        return losses.mean()
    def smooth_pinball_loss(self, y_pred_logits, target, a=0.2):
        """
        Smooth approximation to the pinball loss, see https://arxiv.org/abs/1909.12122
        """
        d = target - self.order(y_pred_logits)
        # a * softplus(-d / a) == a * log(1 + exp(-d / a)), but numerically stable
        return self.quantiles * d + a * F.softplus(-d / a)
    def smooth_l1_loss(self, y_pred_logits, target, beta=1.0):
        """
        Pinball loss with smooth L1 scaling applied.

        Args:
            y_pred_logits: predicted values (raw network outputs, ordered internally)
            target: target values
            beta: threshold between the quadratic and linear regimes

        see https://mmdetection.readthedocs.io/en/latest/_modules/mmdet/models/losses/smooth_l1_loss.html
        """
        diff = 2 * self._loss(y_pred_logits, target)
        smooth_l1_loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
        return smooth_l1_loss.mean(-1)
    def to_median_value(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction using the robust median value.

        Args:
            y_pred_logits: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        return y_pred_logits[..., self.med_i]
    def to_device_as(self, x):
        """Move the quantiles to the same device as x."""
        if x.device != self.quantiles.device:
            self.quantiles = self.quantiles.to(x.device)
    def to_expected_value(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert quantiles to the expected value of the truncated distribution
        (quantile regression can't fit the 0% and 100% tails).
        """
        y_pred = self.order(y_pred_logits)
        self.to_device_as(y_pred)
        # Trapezoidal Riemann sum over the quantile function:
        # probability widths between adjacent quantiles...
        widths = self.quantiles[1:] - self.quantiles[:-1]
        # ...times the midpoints of adjacent quantile values
        mids = (y_pred[..., 1:] + y_pred[..., :-1]) / 2
        return (widths * mids).sum(-1)
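    # For example (hypothetical values): with quantiles [0.25, 0.5, 0.75] and
    # ordered predictions [q25, q50, q75], the expected value is approximated as
    #   0.25 * (q25 + q50) / 2 + 0.25 * (q50 + q75) / 2
    # i.e. the area under the quantile function between the outermost quantiles.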
    def to_quantiles(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into an ordered quantile prediction.

        Args:
            y_pred_logits: prediction output of network

        Returns:
            torch.Tensor: prediction quantiles
        """
        return self.order(y_pred_logits)
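

# A minimal self-check demo sketch (not part of the API above; the shapes and
# seed are arbitrary). It verifies that order() produces quantiles that never
# cross, and that the losses and point-prediction helpers run end to end.
if __name__ == "__main__":
    torch.manual_seed(42)
    qloss = QuantileLoss()
    nb_q = len(qloss.quantiles)

    # fake network outputs and targets
    y_pred_logits = torch.randn((200, nb_q))
    y = torch.arange(-100, 100, 1)[:, None] / 50

    # the ordered quantiles should be strictly increasing, i.e. no crossover
    y_pred = qloss.order(y_pred_logits)
    assert (torch.diff(y_pred, dim=-1) > 0).all(), "quantiles crossed!"

    print("pinball loss:", qloss.loss(y_pred_logits, y).item())
    print("smooth l1 pinball loss:", qloss.smooth_l1_loss(y_pred_logits, y).mean().item())
    print("median prediction shape:", qloss.to_median_value(y_pred_logits).shape)
    print("expected value shape:", qloss.to_expected_value(y_pred_logits).shape)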