Skip to content

Instantly share code, notes, and snippets.

@cheind
Created November 10, 2021 08:41
Show Gist options
  • Save cheind/8f795e49192bf1da394d099a287c74ff to your computer and use it in GitHub Desktop.
Save cheind/8f795e49192bf1da394d099a287c74ff to your computer and use it in GitHub Desktop.
(Batched) sample entropy in PyTorch for measuring the complexity of time series (see https://en.wikipedia.org/wiki/Sample_entropy)
import torch
def sample_entropy(
    x: torch.Tensor, m: int = 2, r: float = None, stride: int = 1, subsample: int = 1
):
    """Returns the (batched) sample entropy of the given time series.

    Sample entropy is a measure of the complexity of a sequence that can be
    related to its predictability. Sample entropy (SE) is defined as the
    negative logarithm of the ratio

        SE(X, m, r) = -ln(C(X, m+1, r) / C(X, m, r))

    where C(X, m, r) is the number of unordered pairs of distinct length-m
    template vectors of X whose Chebyshev distance is less than r
    (self-matches are excluded).

    SE is non-negative. If no close pair is found for either length, the
    conventional fallback value -ln(2 / [N(N-1)]) is returned instead. Here N
    is the number of length-(m+1) templates that are compared. For stride=1
    and subsample=1, N = T - m, where T is the sequence length.

    Based on
        Richman, J. S., & Moorman, J. R. (2000). Physiological time-series
        analysis using approximate entropy and sample entropy.

    Params
    ------
    x: (B,T) or (T,) tensor
        Batched time-series; a 1-D input is treated as a batch of one.
    m: int
        Embedding length
    r: float or tensor, optional
        Distance threshold, either a scalar or one value per sequence (shape
        (B,)). If None, it is computed per sequence as `0.2 * std(x[b])`, so
        the result for a sequence does not depend on the other sequences in
        the batch.
    stride: int
        Step between the start indices of consecutive template vectors.
    subsample: int
        Keep only every `subsample`-th template vector. All pairwise distances
        are materialized as a (B, N, N) tensor, so this reduces memory and
        compute on long series.

    Returns
    -------
    SE: (B,) tensor
        Sample entropy for each sequence

    Raises
    ------
    ValueError
        If the series yields fewer than two templates of length m+1.
    """
    x = torch.atleast_2d(x)
    T = x.shape[1]

    # Number of length-(m+1) templates actually compared. Their start indices
    # are 0, stride, ..., <= T-m-1, thinned by `subsample`; this mirrors the
    # unfold + slicing below exactly.
    n_next = len(range(0, T - m, stride)[::subsample])
    if n_next < 2:
        raise ValueError(
            f"time series of length {T} yields {n_next} template(s) of length "
            f"{m + 1}; at least 2 are needed"
        )

    if r is None:
        # Per-sequence tolerance of shape (B,1); it broadcasts against the
        # (B,K) pairwise distances.
        r = 0.2 * torch.std(x, dim=1, keepdim=True)
    else:
        # Scalar -> (1,1); per-sequence (B,) -> (B,1).
        r = torch.as_tensor(r, dtype=x.dtype, device=x.device).reshape(-1, 1)

    def _num_close(elen: int) -> torch.Tensor:
        """Count, per sequence, pairs of distinct length-`elen` templates closer than r."""
        unf = x.unfold(1, elen, stride)  # (B, N, elen)
        if subsample > 1:
            unf = unf[:, ::subsample, :]
        n = unf.shape[1]
        d = torch.cdist(unf, unf, p=float("inf"))  # (B, N, N) Chebyshev distances
        # Strict upper triangle: each unordered pair once, no self-matches.
        iu = torch.triu_indices(n, n, 1, device=x.device)
        return (d[:, iu[0], iu[1]] < r).sum(-1)  # (B,)

    A = _num_close(m + 1)
    B = _num_close(m)

    # No regularities found: fall back to the conventional upper-bound value
    # -ln(2 / [N(N-1)]), where N is the real number of compared (m+1)-templates.
    no_match = (A == 0) | (B == 0)
    A[no_match] = 2
    B[no_match] = n_next * (n_next - 1)
    return -torch.log(A / B)
import torch
import complexity
def test_sample_entropy():
    """Check `complexity.sample_entropy` on signals of known relative complexity."""
    import math

    # Uniform white noise is highly irregular, so SE is large. A local seeded
    # generator keeps the test deterministic without touching the global RNG.
    gen = torch.Generator().manual_seed(0)
    x = torch.rand(10, 1024, generator=gen)
    se = complexity.sample_entropy(x)
    assert se.mean() >= 2.0

    # A straight line is perfectly regular, so SE is close to zero.
    x = torch.arange(2 ** 12).float()
    se = complexity.sample_entropy(x).mean()
    assert abs(se) < 1e-3

    # Five full periods of a sine wave are highly predictable, so SE is low.
    x = torch.sin(torch.linspace(0, 10 * math.pi, 2 ** 12))
    se = complexity.sample_entropy(x).mean()
    assert se < 0.2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment