Skip to content

Instantly share code, notes, and snippets.

@xvdp
Last active July 27, 2022 16:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xvdp/baf3b1610cedb2b7a2684a187cc2645f to your computer and use it in GitHub Desktop.
Save xvdp/baf3b1610cedb2b7a2684a187cc2645f to your computer and use it in GitHub Desktop.
cumdiv() and cumdif(): inverses of the torch cumulative functions cumprod() and cumsum(), respectively
"""@xvdp
reciprocals for torch.cumsum and torch.cumprod
I noticed that torch has cumsum and cumprod but not their inverses,
even though cumdif and cumdiv have meanings in analysis and probability
and are routinely used.
Are these interpretations correct?
> cumsum can be thought of as a discrete integral
> cumdif as discrete derivative
> cumprod is useful as the joint probability of a sequence of events
> cumdiv then gives the conditional probability of each event given the preceding ones
cumprod and cumdiv could also be expressed as exp(cumsum(log(x))) and exp(cumdif(log(x))) for positive x
see test_explog_interpretation()
"""
from typing import Optional
import torch
from torch.types import _dtype
import torch.nn.functional as F
from torch import Tensor
# pylint: disable=no-member
# pylint: disable=invalid-name
# pylint: disable=suppressed-message
def cumdiv(x: Tensor,
           dim: int = 0,
           *,
           keepsize: bool = True,
           dtype: Optional[_dtype] = None,
           **kwargs) -> Tensor:
    """ inverse of cumprod: divides each element by its predecessor along dim
    similar to exp(cumdif(log(x))) for positive x
    if x holds the joint probabilities of a growing sequence of events
    (eg. the output of cumprod), the result is the conditional probability of
    each event given the ones before it, ie. it undoes the chain rule
    Args
        x           Tensor
        axis | dim  int [0], dimension to operate along; axis= overrides dim
        keepsize    bool [True]
            True:  x.shape == out.shape; the first slice along dim is divided
                   by 1, in this case cumdiv is the inverse of cumprod
            False: x[1:] / x[:-1] along dim, reducing that dimension by 1
        dtype       torch.dtype [None], if given, x is cast to dtype before
                    dividing (as torch.cumprod does) and the result is dtype
    Returns
        Tensor; a zero predecessor yields inf / nan as in any tensor division
    """
    if dtype is not None:
        # cast before the op, like torch.cumprod, so the division runs at the
        # requested precision rather than at the input's precision
        x = x.to(dtype)
    denom, front_slice = _cuminv(x, dim, keepsize=keepsize, prod=1, **kwargs)
    # index with a tuple of slices; list-of-slices indexing is legacy behaviour
    out = x[tuple(front_slice)].div(denom)
    # true division promotes integer inputs to float: honour an explicit dtype
    return out if dtype is None else out.to(dtype)
def cumdif(x: Tensor,
           dim: int = 0,
           *,
           keepsize: bool = True,
           dtype: Optional[_dtype] = None,
           **kwargs) -> Tensor:
    """ inverse of cumsum: subtracts from each element its predecessor along dim
    can be thought of as a discrete derivative
    Args
        x           Tensor
        axis | dim  int [0], dimension to operate along; axis= overrides dim
        keepsize    bool [True]
            True:  x.shape == out.shape; 0 is subtracted from the first slice
                   along dim, in this case cumdif is the inverse of cumsum
            False: x[1:] - x[:-1] along dim, reducing that dimension by 1
        dtype       torch.dtype [None], if given, x is cast to dtype before
                    subtracting (as torch.cumsum does), eg. to avoid wrap-around
                    of narrow integer types
    Returns
        Tensor with the dtype of x (after the optional cast)
    """
    if dtype is not None:
        # cast before the op: subtracting in a narrow integer dtype can wrap
        # around before a later cast gets a chance to widen the result
        x = x.to(dtype)
    prev, front_slice = _cuminv(x, dim, keepsize=keepsize, prod=0, **kwargs)
    # index with a tuple of slices; list-of-slices indexing is legacy behaviour
    return x[tuple(front_slice)].sub(prev)
def _cuminv(x: Tensor,
            dim: int = 0,
            *,
            keepsize: bool = True,
            prod: int = 1,
            **kwargs) -> "tuple[Tensor, tuple[slice, ...]]":
    """ shared helper for cumdif and cumdiv: builds the shifted operand
    Args
        x           Tensor with at least one dimension
        axis | dim  int [0], dimension to operate along; negative values count
                    from the end as in torch.cumsum; axis= overrides dim
        keepsize    bool [True]
            True:  `other` is x shifted one step forward along dim, with the
                   vacated first slot filled with `prod`; other.shape == x.shape
            False: `other` is x[:-1] along dim (no padding)
        prod        int [1], fill value used when keepsize is True:
                    1 for cumdiv (division), 0 for cumdif (subtraction)
    Returns
        (other, front_slice): the shifted tensor, and a tuple of slices that,
        applied to x, selects the elements to combine with `other`
    Raises
        IndexError if dim is out of range for x
    """
    axis = kwargs.get('axis')
    dim = dim if axis is None else axis
    if not -x.ndim <= dim < x.ndim:
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-x.ndim}, {x.ndim - 1}], but got {dim})")
    # normalize negative dims: the F.pad offset below assumes 0 <= dim < ndim
    dim = dim % x.ndim

    back_slice = [slice(None)] * x.ndim
    back_slice[dim] = slice(0, -1)
    front_slice = [slice(None)] * x.ndim
    # tuple indexing: a list of slices is legacy numpy-style indexing
    other = x[tuple(back_slice)]
    if keepsize:
        # F.pad takes (left, right) pairs starting from the LAST dimension,
        # so the "left" entry for `dim` sits at index 2 * (ndim - 1 - dim)
        pads = [0] * (2 * x.ndim)
        pads[2 * (x.ndim - 1 - dim)] = 1
        other = F.pad(other, pads, value=prod)
    else:
        front_slice[dim] = slice(1, None)
    return other, tuple(front_slice)
###
# tests that validate for 1,2,3 dimensions that these functions
# are reciprocals to the torch functions
#
def test_all(device=None):
    """ run every check in this file
    Args
        device  torch device for the test tensors [None: default device]
    """
    test_cumdiv(device)
    # the 3-d tensor built by test_cumdif is reused by the remaining checks
    x = test_cumdif(device)
    test_axis_dim(x)
    test_explog_interpretation(x)
def test_cumdiv(device=None):
    """ test that cumdiv inverts cumprod on 1-d, 2-d and 3-d tensors, and that
    keepsize=False on cumprod(x) matches x[1:] along every dim
    Args
        device  torch device for the test tensor [None: default device]
    Returns
        the 3-d test tensor, so callers can reuse it
    """
    x = torch.linspace(0.1, 1, 10).to(device=device)
    assert torch.allclose(cumdiv(torch.cumprod(x, 0), 0), x)
    x = torch.stack((x, x.flip(0)))
    for i in range(x.ndim):
        assert torch.allclose(cumdiv(torch.cumprod(x, i), i), x)
    x = torch.stack((x, x * 2))
    for i in range(x.ndim):
        assert torch.allclose(cumdiv(torch.cumprod(x, i), i), x)
    # check keepsize=False; ~ x[1:] / x[:-1]
    for i in range(x.ndim):
        z = cumdiv(torch.cumprod(x, i), i, keepsize=False)
        # index with a tuple: a list of slices is legacy numpy-style indexing
        tail = (slice(None),) * i + (slice(1, None),)
        assert torch.allclose(z, x[tail])
    return x
def test_cumdif(device=None):
    """ test that cumdif inverts cumsum on 1-d, 2-d and 3-d tensors, and that
    keepsize=False on cumsum(x) matches x[1:] along every dim
    Args
        device  torch device for the test tensor [None: default device]
    Returns
        the 3-d test tensor, so callers can reuse it
    """
    x = torch.linspace(0.1, 1, 10).to(device=device)
    assert torch.allclose(cumdif(torch.cumsum(x, 0), 0), x)
    x = torch.stack((x, x.flip(0)))
    for i in range(x.ndim):
        assert torch.allclose(cumdif(torch.cumsum(x, i), i), x)
    x = torch.stack((x, x * 2))
    for i in range(x.ndim):
        assert torch.allclose(cumdif(torch.cumsum(x, i), i), x)
    # check keepsize=False; ~ x[1:] - x[:-1]
    for i in range(x.ndim):
        z = cumdif(torch.cumsum(x, i), i, keepsize=False)
        # index with a tuple: a list of slices is legacy numpy-style indexing
        tail = (slice(None),) * i + (slice(1, None),)
        assert torch.allclose(z, x[tail])
    return x
def test_axis_dim(x):
    """ check the `axis` keyword alias, mirroring torch.cumsum / torch.cumprod
    axis == dim; in this implementation axis overrides dim when both are given
    Args
        x   Tensor to test on; the cumdiv comparison uses equal_nan=True so
            zero entries (which produce nan / inf) do not fail the check
    """
    for i in range(x.ndim):
        by_dim = cumdif(x, dim=i)
        assert torch.allclose(by_dim, cumdif(x, axis=i))
        # axis must win over an explicit, different dim
        assert torch.allclose(by_dim, cumdif(x, dim=(i + 1) % x.ndim, axis=i))
        assert torch.allclose(cumdiv(x, dim=i), cumdiv(x, axis=i), equal_nan=True)
def test_explog_interpretation(x):
    """ check the log-domain identities along every dim of x:
    cumprod(x) == exp(cumsum(log(x))) and cumdiv(x) == exp(cumdif(log(x)))
    intended for positive x, where log is defined
    """
    log_x = torch.log(x)
    for axis in range(x.ndim):
        expected_prod = torch.exp(torch.cumsum(log_x, axis))
        assert torch.allclose(torch.cumprod(x, axis), expected_prod)
        expected_div = torch.exp(cumdif(log_x, axis))
        assert torch.allclose(cumdiv(x, axis), expected_div)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment