Last active
October 4, 2022 13:08
-
-
Save wassname/91245a26b10c02498f6adfd8f73007f7 to your computer and use it in GitHub Desktop.
ordered quantile loss for machine learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Sometimes we want to use a quantile loss in machine learning, but the predicted quantiles are not guaranteed to be ordered. This is sometimes called the quantile crossover (or quantile crossing) problem.
Surely it would help to impose the constraint that the quantiles must be ordered? | |
What's the best way to do this? | |
Well, it seems to me that we should predict differences from the median,
and apply a softplus to make sure the differences are only in one direction. | |
Note this will NOT work well for very small target values. Because we are using a softplus, the model must output very large negative
logits to produce very small differences between quantiles. This means it will have difficulty with small y values.
This is a mixed blessing: it avoids collapse of the quantiles, but it means you must scale your y values.
I've also applied smooth l1 to the quantile loss as it seems to improve performance in the presence of large and small errors. | |
There's not much thought on this topic that I can find. Perhaps I didn't use the right keywords, | |
but let me know how it works for you, or if you have suggestions. | |
author: wassname | |
url: https://gist.github.com/wassname/91245a26b10c02498f6adfd8f73007f7 | |
license: WTFPL | |
References: | |
- https://github.com/jdb78/pytorch-forecasting/blob/master/pytorch_forecasting/metrics/quantile.py#L9 | |
- https://github.com/maxmarketit/Auto-PyTorch/blob/develop/examples/quantiles/Quantiles.ipynb | |
- http://ethen8181.github.io/machine-learning/ab_tests/quantile_regression/quantile_regression.html | |
- alternative approach using a penalty https://arxiv.org/pdf/1909.12122.pdf | |
""" | |
import torch | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
from torch import nn | |
import torch.nn.functional as F | |
class QuantileLoss(nn.Module):
    """
    Pinball (quantile) loss whose predicted quantiles cannot cross.

    The raw network outputs are not used as quantiles directly. The column at
    the median index is taken as the median; every other column is passed
    through a softplus and accumulated outwards from the median (see
    ``order``), so the resulting quantiles are ordered along the last axis.

    The per-quantile pinball loss is
    ``max(q * (y - y_pred), (q - 1) * (y - y_pred))``. For ``q=0.5`` this is
    half of the absolute error, which is why ``loss`` and ``smooth_l1_loss``
    multiply it by 2.

    Usage:
        qloss = QuantileLoss()
        nb_q = len(qloss.quantiles)
        x = torch.rand((200, nb_q))*0
        y = torch.arange(-100, 100, 1)[:, None]/50
        l = qloss.smooth_l1_loss(x, y)
    """

    def __init__(
        self,
        quantiles: Union[List[float], Tuple[float, ...]] = (
            0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98,
        ),
        **kwargs,
    ):
        """
        Quantile loss

        Args:
            quantiles: strictly increasing quantile levels. Must contain
                ``0.5`` and have an odd number of entries.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the quantiles do not contain the median, have
                an even length, or are not strictly increasing.
        """
        super().__init__()
        # Build the tensor exactly once: calling torch.tensor() on an existing
        # tensor emits a "copy construct" UserWarning.
        if isinstance(quantiles, torch.Tensor):
            levels = quantiles.detach().clone()
        else:
            levels = torch.tensor(quantiles)
        assert 0.5 in levels, 'should have median in it'
        assert len(levels) % 2 == 1, 'should be odd number of quantiles'
        assert (torch.diff(levels) > 0).all(), 'quantiles should be ordered'
        # A buffer follows module.to(device) / .cuda(); persistent=False keeps
        # it out of the state_dict so existing checkpoints still load.
        self.register_buffer("quantiles", levels, persistent=False)
        self.med_i = levels.tolist().index(0.5)

    def order(self, y_pred):
        """
        Order the outputs of a model.

        We model the positive (softplus) differences relative to the median:
        columns above the median index become cumulative positive steps added
        to the median, columns below become cumulative positive steps
        subtracted from it.

        with input from @thomsonn

        Args:
            y_pred: raw model outputs with the quantiles on the last axis.

        Returns:
            Tensor of the same shape holding the ordered quantiles.
        """
        median = y_pred[..., self.med_i][..., None]
        steps_up = F.softplus(y_pred[..., self.med_i + 1:])
        steps_down = F.softplus(y_pred[..., :self.med_i])
        top = median + steps_up.cumsum(-1)
        # Reverse cumsum for the bottom quantiles: the column next to the
        # median is one step away, the outermost column is the sum of all.
        reversed_cumsum = torch.flip(steps_down, [-1]).cumsum(-1)
        bottom = median - torch.flip(reversed_cumsum, [-1])
        return torch.cat([bottom, median, top], -1)

    def _loss(self, y_pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Element-wise pinball loss on the ordered quantiles (no reduction).

        see: http://ethen8181.github.io/machine-learning/ab_tests/quantile_regression/quantile_regression.html

        Args:
            y_pred_logits: raw model outputs, quantiles on the last axis.
            target: target values; must broadcast against the ordered
                predictions (e.g. a trailing axis of size 1).

        Returns:
            Tensor of per-element, per-quantile pinball losses.
        """
        y_pred = self.order(y_pred_logits)
        # Keep the quantile levels on the same device as the predictions.
        self.to_device_as(y_pred)
        errors = target - y_pred
        return torch.max(self.quantiles * errors, (self.quantiles - 1) * errors)

    def loss(self, y_pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Regular pinball loss, similar profile to MAE.

        Returns:
            Scalar tensor: twice the pinball loss, averaged over all elements.
        """
        losses = 2 * self._loss(y_pred_logits, target)
        return losses.mean()

    def smooth_pinball_loss(self, x, y, a=0.2):
        """
        Smooth approximation of the pinball loss.

        see https://arxiv.org/abs/1909.12122

        Computes ``q * d + a * log(1 + exp(-d / a))`` with ``d = y - x``.
        ``x`` is used as given (it is not passed through ``order``) and no
        reduction is applied.

        Args:
            x: predicted quantiles (tensor), quantiles on the last axis.
            y: target tensor; must broadcast against ``x``.
            a: smoothing parameter.

        Returns:
            Tensor of element-wise smooth pinball losses.
        """
        d = y - x
        self.to_device_as(d)
        # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
        return self.quantiles * d + a * F.softplus(-d / a)

    def smooth_l1_loss(self, y_pred_logits, target, beta=1.0):
        """
        Pinball loss with smooth l1 scaling applied.

        Losses below ``beta`` are squared (``0.5 * l**2 / beta``), larger
        ones are linear (``l - 0.5 * beta``).

        see https://mmdetection.readthedocs.io/en/latest/_modules/mmdet/models/losses/smooth_l1_loss.html

        Args:
            y_pred_logits: raw model outputs, quantiles on the last axis.
            target: target values; must broadcast against the predictions.
            beta: threshold between the quadratic and the linear regime.

        Returns:
            Tensor of losses averaged over the last (quantile) axis only.
        """
        diff = 2 * self._loss(y_pred_logits, target)
        smooth_l1_loss = torch.where(
            diff < beta,
            0.5 * diff * diff / beta,
            diff - 0.5 * beta,
        )
        return smooth_l1_loss.mean(-1)

    def to_median_value(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction using the median.

        The median column is left unchanged by ``order``, so it is read
        straight from the raw outputs.

        Args:
            y_pred_logits: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        return y_pred_logits[..., self.med_i]

    def to_device_as(self, x):
        """Move ``self.quantiles`` to the device of ``x`` if they differ."""
        if x.device != self.quantiles.device:
            self.quantiles = self.quantiles.to(x.device)

    def to_expected_value(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert quantiles to the expected value of the truncated distribution
        (quantile regression can't fit the 0% and 100% tails).

        The quantile function is integrated with the trapezoidal rule between
        the lowest and highest quantile level and divided by the probability
        mass covered, i.e. the mean conditional on lying between those two
        quantiles.

        Args:
            y_pred_logits: prediction output of network

        Returns:
            torch.Tensor: expected value, with the quantile axis reduced.
        """
        y_pred = self.order(y_pred_logits)
        self.to_device_as(y_pred)
        widths = self.quantiles[1:] - self.quantiles[:-1]
        if widths.numel() == 0:
            # Only the median is available: nothing to integrate over.
            return y_pred[..., 0]
        midpoints = (y_pred[..., 1:] + y_pred[..., :-1]) / 2
        # Without this normalisation the result is scaled by the covered
        # mass (q[-1] - q[0]) instead of being a mean.
        return (widths * midpoints).sum(-1) / widths.sum()

    def to_quantiles(self, y_pred_logits: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Args:
            y_pred_logits: prediction output of network

        Returns:
            torch.Tensor: prediction quantiles
        """
        return self.order(y_pred_logits)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment