DILATE_cuda
"""
DILATE_fast
DILATE cuda implementation
DILATE: DIstortion Loss with shApe and tImE
WARNING:
- does NOT work for larger batch sizes
- if you're dumpster diving for loss functions in other peoples dirty gists, then you deserve what you get
- look at the TODO's
- I'm not responsible for broken devices, dead experiment, thermonuclear war, or you getting fired because
- YOU are choosing to use this code, and if you point the finger at me for messing up, I will laugh at you
by github.com/thompsonn
paper: DILATE: DIstortion Loss with shApe and tImE
url: https://gist.github.com/wassname/e47ecca09fe6c2ec0daac2d3236a1d6a/edit
"""
from numba import cuda, njit, prange
from torch.autograd import Function
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import logging
# TODO move these around
def pairwise_distances(x, y, norm=2):
    # (B, T, 1) x (B, T, 1) -> (B, T, T); entry [b, i, j] = (x[b, j] - y[b, i])**norm
    assert x.size(2) == y.size(2) == 1
    return (x.transpose(1, 2) - y).pow(norm)

@njit(parallel=True)
def compute_dtw_cpu(D, w):
    """Banded (|i - j| <= w) hard DTW over a batch of cost matrices D: (B, N, M).

    Returns the accumulated-cost matrix R and a 0/1 alignment matrix A marking the optimal path.
    """
    B, N, M = D.shape
    R = np.full((B, N + 1, M + 1), np.inf)
    R[:, 0, 0] = 0
    A = np.zeros((B, N + 1, M + 1))
    for k in prange(B):
        Dk, Rk, Ak = D[k], R[k], A[k]
        # Forward pass: accumulate costs inside the band.
        for j in range(1, M + 1):
            for i in range(max(1, j - w), min(N, j + w) + 1):
                Rk[i, j] = Dk[i - 1, j - 1] + min(Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1])
        # Backtrack from (N, M) to (1, 1), marking the optimal path in Ak.
        i, j = N, M
        for s in range(2*max(N, M) - 2, 0, -1):
            Ak[i, j] = 1
            if i == 1 and j == 1:
                break  # reached the start; going further would wrap around into the matrix
            a, b, c = Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1]
            if b > a < c:
                i, j = i - 1, j - 1
            elif b < c:
                i, j = i - 1, j
            else:
                i, j = i, j - 1
    return R[:, 1:, 1:], A[:, 1:, 1:]
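
# e.g. (CPU path, hypothetical shapes):
#   D = np.random.rand(2, 24, 24)
#   R, A = compute_dtw_cpu(D, w=6)  # R: accumulated costs, A: 0/1 optimal alignment path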

@cuda.jit
def compute_dtw_cuda_kernel(D, w, R, A):
    """CUDA twin of compute_dtw_cpu: one block per batch item, one thread per row, swept by anti-diagonals."""
    _, N, M = D.shape
    # TODO: Support larger sizes?
    tx = cuda.threadIdx.x
    ty = cuda.blockIdx.x
    Dk, Rk, Ak = D[ty], R[ty], A[ty]
    i = tx + 1
    # Forward pass: on anti-diagonal s, all of a cell's dependencies are already done.
    for s in range(2*max(N, M) - 1):
        j = s - tx + 1
        # TODO: Check bandwidth
        if min(i, j) >= 1 and i <= N and j <= M and abs(i - j) <= w:  # j >= 1 so the boundary column stays inf
            Rk[i, j] = Dk[i - 1, j - 1] + min(Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1])
        cuda.syncthreads()
    # Backtrack on a single thread, marking the optimal path in Ak.
    i, j = N, M
    for s in range(2*max(N, M) - 2, 0, -1):
        # TODO: Check this
        if tx == 0:
            Ak[i, j] = 1
            if i == 1 and j == 1:
                break  # reached the start
            a, b, c = Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1]
            if b > a < c:
                i, j = i - 1, j - 1
            elif b < c:
                i, j = i - 1, j
            else:
                i, j = i, j - 1

def compute_dtw_cuda(D, w):
    B, N, M = D.shape
    R = torch.full((B, N + 1, M + 1), np.inf, device=D.device)
    R[:, 0, 0] = 0
    A = torch.zeros((B, N + 1, M + 1), device=D.device)
    # Launch config: B blocks (one per batch item), max(N, M) threads per block.
    compute_dtw_cuda_kernel[B, max(N, M)](cuda.as_cuda_array(D.detach()), w, cuda.as_cuda_array(R), cuda.as_cuda_array(A))
    return R[:, 1:, 1:], A[:, 1:, 1:]
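
# e.g. (assuming a CUDA device, with D already on it):
#   D = torch.rand(4, 24, 24, device="cuda")
#   R, A = compute_dtw_cuda(D, w=6)
# Since each sequence gets max(N, M) threads in a single block, max(N, M) must stay
# under the device's threads-per-block limit (typically 1024).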

class ComputeDTW(Function):
    @staticmethod
    def forward(ctx, D, w, λ):
        ctx.w, ctx.λ = w, λ
        D = D.detach()
        _, A = compute_dtw_cuda(D, w)
        ctx.save_for_backward(D, A)
        return A

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        D, A = ctx.saved_tensors
        # Perturb the costs by λ times the incoming gradient and re-solve DTW.
        Dʹ = torch.clamp(D + ctx.λ*grad_output, min=0)
        # print(f'z={(grad_output < 0).sum()}')
        _, Aλ = compute_dtw_cuda(Dʹ, ctx.w)
        # breakpoint()
        gradient = -(A - Aλ)/ctx.λ
        # print(f'nz={(gradient.abs().sum((1, 2)) > 0).sum()}')
        # print(f'pct={gradient.abs().sum((1, 2))/(4*A.size(1))}')
        return gradient, None, None
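
# The backward above is the estimator from the blackbox-backprop paper referenced in the
# Dilate docstring below (presumably Vlastelica et al., "Differentiation of Blackbox
# Combinatorial Solvers"): treat the DTW solver as a black box and use
#
#   dL/dD ≈ -(A(D) - A(D + λ * dL/dA)) / λ
#
# λ trades interpolation smoothness against informativeness, which is why Dilate can
# adapt it to the magnitude of D.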

def diff(x):
    return x[:, 1:] - x[:, :-1]

class Dilate(object):
    """
    Shape and time distortion loss.
    - α: weighting between the shape and time losses, in (0, 1).
    - w: bandwidth, the maximum deviation in time steps.
        - When applying dynamic time warping, how many time steps DTW may shift by. For example, with data in 1-hour increments you may want matches within the current day but not other days, so you set it to 6 (hours).
    - λ: from the blackbox backprop paper, generally the same magnitude as D. Set to False to set it automatically with exponentially weighted averaging.
    - β: exponential-averaging weight for λ, if enabled.
    Usage:
        loss_fn = Dilate(
            w=6, α=0.5
        )
        pred = torch.rand((32, 24, 1))  # Batch, Time, Channels
        true = torch.rand((32, 24, 1))
        loss = loss_fn(pred, true)
        loss
    """
    def __init__(self, w: float, α: float = 0.5, λ=False, β: float = 0.99):
        self.w = w
        self.α = α
        self.auto_λ = (λ is None) or (λ is False)
        self.λ = 20.0 if self.auto_λ else λ
        print('auto_λ', self.auto_λ, self.λ, λ)
        self.β = β

    def __call__(self, x, y):
        # D = pairwise_distances(diff(x), diff(y))/(x.size(0) - 1)
        D = pairwise_distances(x, y)/x.size(0)
        # Ω penalises temporal distortion: squared normalised index distance ((i - j)/(T - 1))**2.
        idx = torch.linspace(0, 1, D.size(1), device=x.device).view(1, -1, 1)
        Ω = pairwise_distances(idx, idx, norm=2).repeat(x.size(0), 1, 1)
        # D = torch.sigmoid(8*(D - 1/2))
        A = ComputeDTW.apply(D, self.w, self.λ)
        if self.auto_λ:
            with torch.no_grad():
                # Idea: λ is a hyperparameter that should be roughly the same magnitude as D.
                # Here we try an exponentially weighted mean to set it automatically. This is
                # experimental; it may make more sense in the log domain.
                denom = ((1 - self.α) * A * Ω).sum() + 1e-5
                λʹ = (self.α * A * D).sum() / denom
                # logger.info(f"λʹ={λʹ:2.2f} λ={self.λ:2.2f} {denom:2.2f}")
                if not torch.isfinite(λʹ).all() or λʹ > self.λ*100:
                    logging.warning("λʹ unstable, clipping before updating λ")
                λʹ = λʹ.cpu().numpy().clip(1e-5, self.λ*100)
                if np.isfinite(λʹ):  # don't let a NaN/inf step poison the running average
                    self.λ = self.β*self.λ + (1 - self.β)*λʹ
        # Shape term (A·D) plus time-distortion term (A·Ω), summed over the alignment.
        L = (A*(self.α*D + (1 - self.α)*Ω)).sum((1, 2))
        # L = (A.sum((1, 2)) - D.size(1))/D.size(1)
        # I = torch.diag(torch.ones(D.size(1))).unsqueeze(0).repeat(x.size(0), 1, 1).to(x.device)
        # L = (A - I).abs().mean((1, 2))
        # L = (A*Ω).sum((1, 2))
        return L.mean()
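
# Minimal smoke-test sketch: run the loss on random data and backprop through it.
# Assumes a CUDA-capable GPU, since ComputeDTW calls the numba CUDA kernel above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        loss_fn = Dilate(w=6, α=0.5)
        pred = torch.rand((8, 24, 1), device="cuda", requires_grad=True)  # Batch, Time, Channels
        true = torch.rand((8, 24, 1), device="cuda")
        loss = loss_fn(pred, true)
        loss.backward()
        print("loss", loss.item(), "grad norm", pred.grad.norm().item())
    else:
        print("CUDA not available, skipping smoke test")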
@thomsonn might as well share this, feel free to move it to your account
