Last active
September 7, 2021 09:47
-
-
Save fzimmermann89/1958d757c1eb4a7cb7d4e82fa49bf86c to your computer and use it in GitHub Desktop.
lut
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from typing import Tuple, Callable | |
class LUT(nn.Module):
    def __init__(self, f: Callable, dx: float, xrange: Tuple[float, float], mode: str = "linear"):
        """
        LUT (lookup table) of the values of a function, interpolated when called.

        f: function to tabulate. It is called once, under torch.no_grad(), with the
            1-D tensor of sample points, so it does not need to be differentiable.
            Its result is converted with torch.as_tensor; the first dimension must
            index the sample points (any trailing dimensions are carried through).
        dx: spacing (resolution) of the sample points
        xrange: range of x values (xmin, xmax). Samples are torch.arange(xmin, xmax, dx),
            so xmax itself is not sampled.
        mode: interpolation mode, "linear" or "cubic" (Catmull-Rom)
        Raises ValueError if mode is not "linear" or "cubic" (checked before f is called).
        TODO: cubic grad behaves badly at edges of range
        """
        super().__init__()
        # Validate before tabulating so a bad mode does not cost a (possibly expensive) call to f.
        if mode not in ("linear", "cubic"):
            raise ValueError("mode should be linear or cubic")
        with torch.no_grad():
            xs = torch.arange(*xrange, dx)
            # A buffer follows module.to(device) / .cuda(), so the table lives on the
            # same device as the module; non-persistent keeps state_dict unchanged.
            self.register_buffer("LUT", torch.as_tensor(f(xs)), persistent=False)
        self.dx = dx
        self.xrange = xrange
        self.mode = mode

    def forward(self, x):
        """
        Interpolate the table at the query points x.

        x: tensor of query points; the result has shape x.shape + LUT.shape[1:].
        Gradients w.r.t. x flow only through the fractional position between two
        samples (the table values are constants), so the result is differentiable
        even if f is not. Sample indices are clamped to the table, so queries
        outside the sampled range reuse the first/last samples.
        Raises ValueError if self.mode is not "linear" or "cubic".
        """
        # Position of x in units of samples, measured from the first sample at xmin.
        pos = (x - self.xrange[0]) / self.dx
        base = torch.floor(pos)  # floor has zero gradient
        # Index and fractional weight both come from the same pos, so they always
        # agree -- also when xmin is not a multiple of dx.
        idx = base.long()
        frac = pos - base
        # Append singleton dims so the weight broadcasts over the table's trailing dims.
        frac = frac[(...,) + (None,) * (self.LUT.ndim - 1)]
        if self.mode == "linear":
            y1, y2 = self._gather(idx, (0, 1))
            return torch.lerp(y1, y2, frac)
        if self.mode == "cubic":
            y0, y1, y2, y3 = self._gather(idx, (-1, 0, 1, 2))
            # Catmull-Rom cubic: equals y1 at frac=0 and y2 at frac=1.
            a0 = -0.5 * y0 + 1.5 * y1 - 1.5 * y2 + 0.5 * y3
            a1 = y0 - 2.5 * y1 + 2 * y2 - 0.5 * y3
            a2 = -0.5 * y0 + 0.5 * y2
            a3 = y1
            return a0 * frac ** 3 + a1 * frac ** 2 + a2 * frac + a3
        raise ValueError("mode should be linear or cubic")

    def _gather(self, idx, offsets):
        """Return the table rows at idx + offset, clamped to the table, for each offset."""
        last = self.LUT.shape[0] - 1
        with torch.no_grad():
            return tuple(self.LUT[torch.clamp(idx + offset, 0, last), ...] for offset in offsets)
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    class FunctionWithNoBackward(torch.autograd.Function):
        """
        Example of a Function without a gradient: calling backward through it raises NotImplementedError.

        For a 1-D input x its forward returns a (len(x), 2) tensor with columns sin(x) and cos(x).
        """

        @staticmethod
        def forward(ctx, x):
            result = torch.stack((x.sin(), x.cos())).T
            return result

        @staticmethod
        def backward(ctx, grad_output):
            raise NotImplementedError

    f = FunctionWithNoBackward.apply
    # Tabulating f makes it usable with autograd: LUT only calls f under torch.no_grad().
    L = LUT(f, dx=0.25, xrange=(-5, 5), mode="linear")
    # Query beyond the tabulated range (-5, 5) to show the behaviour at the edges.
    x = torch.linspace(-6, 6, 10000).requires_grad_()
    y = L(x)[:, 0]  # interpolated sin(x)
    # x.grad becomes the derivative of the interpolant (approximately cos(x) inside the range).
    torch.sum(y).backward()
    xd = x.detach()
    plt.plot(xd, y.detach(), label="LUT sin(x)")
    plt.plot(xd, x.grad, label="grad of LUT sin(x)")
    plt.plot(xd, torch.sin(xd), label="sin(x)")
    plt.plot(xd, torch.cos(xd), label="cos(x)")
    plt.legend()
    # Without show() nothing is displayed when the file is run as a script.
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment