DILATE_cuda
""" | |
DILATE_fast | |
DILATE cuda implementation | |
DILATE: DIstortion Loss with shApe and tImE | |
WARNING: | |
- does NOT work for larger batch sizes | |
- if you're dumpster diving for loss functions in other peoples dirty gists, then you deserve what you get | |
- look at the TODO's | |
- I'm not responsible for broken devices, dead experiment, thermonuclear war, or you getting fired because | |
- YOU are choosing to use this code, and if you point the finger at me for messing up, I will laugh at you | |
by github.com/thompsonn | |
paper: DILATE: DIstortion Loss with shApe and tImE | |
url: https://gist.github.com/wassname/e47ecca09fe6c2ec0daac2d3236a1d6a/edit | |
""" | |
from numba import cuda, njit, prange
from torch.autograd import Function
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import logging


# TODO move these around
def pairwise_distances(x, y, norm=2):
    """Pairwise |x_i - y_j|**norm for univariate series x: (B, N, 1) and y: (B, M, 1)."""
    assert x.size(2) == y.size(2) == 1
    # (B, 1, N) - (B, M, 1) broadcasts to a (B, M, N) distance matrix
    return (x.transpose(1, 2) - y).pow(norm)
@njit(parallel=True)
def compute_dtw_cpu(D, w):
    """DTW over a batch of cost matrices D (B, N, M) with a Sakoe-Chiba band of width w."""
    B, N, M = D.shape
    # R is the accumulated-cost matrix, padded with an inf border so the
    # recursion can read R[i-1, j-1] etc. without bounds checks
    R = np.full((B, N + 1, M + 1), np.inf)
    R[:, 0, 0] = 0
    # A holds the binary alignment (warping path) for each batch element
    A = np.zeros((B, N + 1, M + 1))
    for k in prange(B):
        Dk, Rk, Ak = D[k], R[k], A[k]
        # forward pass: accumulate cost within the band |i - j| <= w
        for j in range(1, M + 1):
            for i in range(max(1, j - w), min(N, j + w) + 1):
                Rk[i, j] = Dk[i - 1, j - 1] + min(Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1])
        # backward pass: trace the minimal-cost path from (N, M) back to (1, 1)
        i, j = N, M
        for s in range(2*max(N, M) - 2, 0, -1):
            # guard stops the trace once the path has passed (1, 1)
            if i >= 1 and j >= 1:
                Ak[i, j] = 1
                a, b, c = Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1]
                if b > a < c:
                    i, j = i - 1, j - 1
                elif b < c:
                    i, j = i - 1, j
                else:
                    i, j = i, j - 1
    return R[:, 1:, 1:], A[:, 1:, 1:]
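# --- CPU usage sketch (not part of the original gist): the numba path runs on plain
# NumPy arrays with no GPU, handy for quick checks of R and A.
def _dtw_cpu_example():
    D = np.random.rand(2, 10, 10)
    R, A = compute_dtw_cpu(D, 3)  # R: accumulated cost; A: binary warping path
    assert A.shape == D.shape and np.isfinite(R[:, -1, -1]).all()
    return R, A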
@cuda.jit
def compute_dtw_cuda_kernel(D, w, R, A):
    """One block per batch element; thread tx owns row i = tx + 1 and sweeps anti-diagonals."""
    _, N, M = D.shape
    # TODO: Support larger sizes?
    tx = cuda.threadIdx.x
    ty = cuda.blockIdx.x
    Dk, Rk, Ak = D[ty], R[ty], A[ty]
    i = tx + 1
    # forward pass: cells on the same anti-diagonal s are independent, so each thread
    # updates one cell per step and the block synchronises between diagonals
    for s in range(2*max(N, M) - 1):
        j = s - tx + 1
        # TODO: Check bandwidth
        if min(i, j) >= 1 and i <= N and j <= M and abs(i - j) <= w:
            Rk[i, j] = Dk[i - 1, j - 1] + min(Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1])
        cuda.syncthreads()
    # backward pass: a single thread traces the minimal-cost path back to (1, 1)
    i, j = N, M
    for s in range(2*max(N, M) - 2, 0, -1):
        # TODO: Check this
        if tx == 0 and i >= 1 and j >= 1:
            Ak[i, j] = 1
            a, b, c = Rk[i - 1, j - 1], Rk[i - 1, j], Rk[i, j - 1]
            if b > a < c:
                i, j = i - 1, j - 1
            elif b < c:
                i, j = i - 1, j
            else:
                i, j = i, j - 1
def compute_dtw_cuda(D, w):
    """Launch the DTW kernel: one block per batch element, max(N, M) threads per block."""
    B, N, M = D.shape
    R = torch.full((B, N + 1, M + 1), np.inf, device=D.device)
    R[:, 0, 0] = 0
    A = torch.zeros((B, N + 1, M + 1), device=D.device)
    # torch tensors expose __cuda_array_interface__, so numba can use them in place
    compute_dtw_cuda_kernel[B, max(N, M)](cuda.as_cuda_array(D.detach()), w, cuda.as_cuda_array(R), cuda.as_cuda_array(A))
    return R[:, 1:, 1:], A[:, 1:, 1:]
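# --- Sanity-check sketch (not part of the original gist): for small inputs the CPU and
# CUDA forward passes should produce the same accumulated-cost matrix R. Assumes a CUDA
# device is available; backtracking ties may still make the alignments A differ slightly.
def _check_dtw_agreement(B=2, N=16, w=4):
    D = torch.rand(B, N, N, device="cuda")
    R_gpu, _ = compute_dtw_cuda(D, w)
    R_cpu, _ = compute_dtw_cpu(D.cpu().numpy().astype(np.float64), w)
    assert np.allclose(R_gpu.cpu().numpy(), R_cpu, atol=1e-4)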
class ComputeDTW(Function):
    """Hard DTW as an autograd Function; gradients come from the blackbox-backprop trick (see λ)."""
    @staticmethod
    def forward(ctx, D, w, λ):
        ctx.w, ctx.λ = w, λ
        D = D.detach()
        _, A = compute_dtw_cuda(D, w)
        ctx.save_for_backward(D, A)
        return A

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        D, A = ctx.saved_tensors
        # blackbox backprop: re-solve DTW on a cost matrix perturbed by the incoming
        # gradient and use the difference of alignments as a surrogate gradient
        Dʹ = torch.clamp(D + ctx.λ*grad_output, min=0)
        # print(f'z={(grad_output < 0).sum()}')
        _, Aλ = compute_dtw_cuda(Dʹ, ctx.w)
        # breakpoint()
        gradient = -(A - Aλ)/ctx.λ
        # print(f'nz={(gradient.abs().sum((1, 2)) > 0).sum()}')
        # print(f'pct={gradient.abs().sum((1, 2))/(4*A.size(1))}')
        return gradient, None, None
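# --- Gradient-flow sketch (not part of the original gist): checks that the surrogate
# gradient -(A(D) - A(D + λ·∂L/∂A))/λ actually reaches D. Assumes a CUDA device.
def _check_dtw_gradient(B=2, N=8, w=3, λ=20.0):
    D = torch.rand(B, N, N, device="cuda", requires_grad=True)
    A = ComputeDTW.apply(D, w, λ)  # binary alignment matrix, shape (B, N, N)
    A.sum().backward()
    assert D.grad is not None and D.grad.shape == D.shape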
def diff(x):
    """First-order difference along the time dimension."""
    return x[:, 1:] - x[:, :-1]


class Dilate(object):
    """
    Shape and time distortion loss.

    - α: weighting between time and shape loss, in (0, 1).
    - w: bandwidth, max deviation in time steps.
        - When applying dynamic time warping, how many time steps DTW may shift by. For example, with data
          in 1-hour increments you may want matches within the current day but not other days, so set it to 6 (hours).
    - λ: from the blackbox backprop paper, generally the same magnitude as D. Set to False to set it
      automatically using exponentially weighted averaging.
    - β: weight for the exponential averaging of λ, if enabled.

    Usage:
        loss_fn = Dilate(w=6, α=0.5)
        pred = torch.rand((32, 24, 1))  # Batch, Time, Channels
        true = torch.rand((32, 24, 1))
        loss = loss_fn(pred, true)
        loss
    """
    def __init__(self, w:int, α:float=0.5, λ=False, β:float=0.99):
        self.w = w
        self.α = α
        self.auto_λ = (λ is None) or (λ is False)
        self.λ = 20.0 if self.auto_λ else λ
        print('auto_λ', self.auto_λ, self.λ, λ)
        self.β = β

    def __call__(self, x, y):
        # D = pairwise_distances(diff(x), diff(y))/(x.size(0) - 1)
        # shape term: pairwise distances between predicted and true values
        D = pairwise_distances(x, y)/x.size(0)
        # time term: Ω penalises how far the alignment strays from the diagonal
        idx = torch.linspace(0, 1, D.size(1), device=x.device).view(1, -1, 1)
        Ω = pairwise_distances(idx, idx, norm=2).repeat(x.size(0), 1, 1)
        # D = torch.sigmoid(8*(D - 1/2))
        A = ComputeDTW.apply(D, self.w, self.λ)
        if self.auto_λ:
            with torch.no_grad():
                # The idea is that λ is a hyperparameter that should be roughly the same magnitude as D.
                # Here we try an exponentially weighted mean to set it automatically. This is experimental;
                # it may make more sense in the log domain.
                denom = ((1 - self.α) * A * Ω).sum() + 1e-5
                λʹ = (self.α * A * D).sum() / denom
                # logger.info(f"λʹ={λʹ:2.2f} λ={self.λ:2.2f} {denom:2.2f}")
                if not torch.isfinite(λʹ).all() or λʹ > self.λ*100:
                    logging.warning("λʹ is unstable; the auto_λ update may be unreliable")
                λʹ = λʹ.cpu().numpy().clip(1e-5, self.λ*100)
                self.λ = self.β*self.λ + (1 - self.β)*λʹ
        # combined loss: alignment-weighted mix of shape (D) and temporal (Ω) penalties
        L = (A*(self.α*D + (1 - self.α)*Ω)).sum((1, 2))
        # L = (A.sum((1, 2)) - D.size(1))/D.size(1)
        # I = torch.diag(torch.ones(D.size(1))).unsqueeze(0).repeat(x.size(0), 1, 1).to(x.device)
        # L = (A - I).abs().mean((1, 2))
        # L = (A*Ω).sum((1, 2))
        return L.mean()
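# --- End-to-end sketch (not part of the original gist): mirrors the usage in the Dilate
# docstring, guarded so it only runs when a GPU is present, since the loss dispatches to
# the numba CUDA kernel.
if __name__ == "__main__" and torch.cuda.is_available():
    loss_fn = Dilate(w=6, α=0.5)
    pred = torch.rand((32, 24, 1), device="cuda", requires_grad=True)  # Batch, Time, Channels
    true = torch.rand((32, 24, 1), device="cuda")
    loss = loss_fn(pred, true)
    loss.backward()
    print("DILATE loss:", loss.item())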
@thomsonn might as well share this, feel free to move it to your account