
@fzimmermann89
Last active September 7, 2021 09:47
lut
import torch
from torch import nn
from typing import Tuple, Callable


class LUT(nn.Module):
    def __init__(self, f: Callable, dx: float, xrange: Tuple[float, float], mode: str = "linear"):
"""
LUT of values of a function
f: function to use, does not need to be differentiable
dx: resolution of x values
xrange: range of x values (xmin,xmax)
mode: interpolation mode, linear/cubic
TODO: cubic grad behaves badly at edges of range
"""
        super().__init__()
        with torch.no_grad():
            # Sample f on a regular grid; the table is stored as a constant tensor.
            xs = torch.arange(*xrange, dx)
            self.LUT = torch.as_tensor(f(xs))
        self.dx = dx
        self.xrange = xrange
        if mode not in ("linear", "cubic"):
            raise ValueError("mode should be linear or cubic")
        self.mode = mode

    def forward(self, x):
        # Fractional position of x between two neighbouring grid points, measured from
        # xrange[0] so that it stays consistent with the integer index computed below.
        frac = (x - self.xrange[0]) / self.dx
        frac = frac - torch.floor(frac)
        # Append singleton dimensions so frac broadcasts against the table entries.
        frac = frac[(...,) + (None,) * (self.LUT.ndim - 1)]
        if self.mode == "linear":
            with torch.no_grad():
                # Index of the grid point at or below x, clamped to the table bounds.
                rounded = torch.floor((x - self.xrange[0]) / self.dx).long()
                y1, y2 = (self.LUT[torch.clamp(rounded + offset, 0, self.LUT.shape[0] - 1), ...] for offset in (0, 1))
            # Linear interpolation between the two neighbouring samples;
            # the gradient with respect to x flows only through frac.
            return torch.lerp(y1, y2, frac)
        elif self.mode == "cubic":
            with torch.no_grad():
                rounded = torch.floor((x - self.xrange[0]) / self.dx).long()
                # Four neighbouring samples y0..y3 around x, clamped at the table edges.
                y0, y1, y2, y3 = (self.LUT[torch.clamp(rounded + offset, 0, self.LUT.shape[0] - 1), ...] for offset in (-1, 0, 1, 2))
            # Catmull-Rom cubic coefficients; the index clamping above is the likely
            # reason the gradient behaves badly at the edges of the range (see TODO).
            a0 = -0.5 * y0 + 1.5 * y1 - 1.5 * y2 + 0.5 * y3
            a1 = y0 - 2.5 * y1 + 2 * y2 - 0.5 * y3
            a2 = -0.5 * y0 + 0.5 * y2
            a3 = y1
            return a0 * frac ** 3 + a1 * frac ** 2 + a2 * frac + a3
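
# Minimal usage sketch (not part of the original gist; torch.special.erf is only a
# placeholder for an expensive or non-differentiable function):
#
#     lut = LUT(torch.special.erf, dx=0.01, xrange=(-4.0, 4.0), mode="cubic")
#     x = torch.linspace(-3.0, 3.0, 1000, requires_grad=True)
#     y = lut(x)            # interpolated erf values
#     y.sum().backward()    # gradients flow through the interpolation, approximating erf'(x)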


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    class FunctionWithNoBackward(torch.autograd.Function):
        """
        Example of a Function that doesn't have a gradient - calling backward on it raises an exception.
        """

        @staticmethod
        def forward(ctx, x):
            result = torch.stack((x.sin(), x.cos())).T
            return result

        @staticmethod
        def backward(ctx, grad_output):
            raise NotImplementedError

    f = FunctionWithNoBackward.apply
    L = LUT(f, dx=0.25, xrange=(-5, 5), mode="linear")
    x = torch.linspace(-6, 6, 10000).requires_grad_()
    y = L(x)[:, 0]
    torch.sum(y).backward()  # gradients come from the LUT interpolation, not from f
    plt.plot(x.detach(), y.detach(), label="LUT(x)[:, 0]")
    plt.plot(x.detach(), x.grad, label="d LUT / dx")
    plt.plot(x.detach(), torch.sin(x.detach()), label="sin(x)")
    plt.plot(x.detach(), torch.cos(x.detach()), label="cos(x)")
    plt.legend()
    plt.show()
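
    # A minimal extra check (not part of the original gist): the demo tabulates sin/cos,
    # so the gradient of the first LUT channel should approximate cos(x) inside xrange.
    # Repeating it in "cubic" mode shows the TODO about edge behaviour near x = -5 and 5,
    # while the error well inside the range stays small.
    L_cubic = LUT(f, dx=0.25, xrange=(-5, 5), mode="cubic")
    x2 = torch.linspace(-6, 6, 10000).requires_grad_()
    torch.sum(L_cubic(x2)[:, 0]).backward()
    inside = x2.detach().abs() < 4  # stay away from the edges of xrange
    err = (x2.grad - torch.cos(x2.detach())).abs()
    print(f"cubic mode, max |dLUT/dx - cos(x)| inside the range: {err[inside].max().item():.4f}")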