Skip to content

Instantly share code, notes, and snippets.

@etienne87
Last active June 24, 2020 17:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save etienne87/e65b6bb2493213f436bf4a5b43b943ca to your computer and use it in GitHub Desktop.
A very scruffy script to showcase SIREN networks.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from ranger import Ranger
import numpy as np
import random
import cv2
import math
import kornia
from functools import partial
# kornia image-derivative operators. Neither name is referenced elsewhere in
# this script; presumably kept for experimenting with gradient / Laplacian
# supervision.
gradient_filter = kornia.filters.sobel
laplace_filter = partial(kornia.filters.laplacian, kernel_size=3)
def make_dataset(filname=None, centered=False):
    """
    Build a (coordinate -> colour) regression dataset from one image.

    The image is read with cv2.imread (BGR channel order) and downsampled
    once with cv2.pyrDown.

    Args:
        filname: path of the image to load. ``None`` keeps the original
            hard-coded file ``'dali.jpg'``.
        centered: if True, colours are mapped from [0, 1] to [-1, 1] after
            the gradients have been computed. Note that Siren ends in a
            sigmoid, so it cannot reach negative targets.

    Returns:
        x: (H*W, 2) coordinates in [0, 1]; column 0 varies along the image
           height, column 1 along the width.
        y: (H*W, C) colours in [0, 1] (or [-1, 1] when ``centered``).
        g: (H*W, C, 2) spatial gradient of every pixel and channel.
        height, width: size of the downsampled image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    path = 'dali.jpg' if filname is None else filname
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError('could not read image: {}'.format(path))
    img = cv2.pyrDown(img)
    height, width, c = img.shape

    # torch.meshgrid's default indexing is 'ij': xv follows rows, yv columns.
    xv, yv = torch.meshgrid([torch.linspace(0, 1, height), torch.linspace(0, 1, width)])

    y = torch.from_numpy(img) / 255.0
    # kornia: (1, C, H, W) in -> (1, C, 2, H, W) out.
    g = kornia.spatial_gradient(y.permute(2, 0, 1)[None])
    # (C, 2, H, W) -> (H, W, C, 2) -> (H*W, C, 2). The previous
    # .view(-1, 2, 3).permute(0, 2, 1) regrouped each pixel's six
    # (channel, direction) values in the wrong order; reshape keeps them paired.
    # NOTE(review): directions 0/1 follow kornia's derivative order (believed
    # to be dx then dy, i.e. width then height), which is the opposite of x's
    # (row, column) order -- confirm before comparing with jacobian_in_batch.
    g = g[0].permute(2, 3, 0, 1).reshape(-1, c, 2)

    y = y.reshape(-1, c)
    if centered:
        y = y * 2 - 1

    x = torch.stack([xv.reshape(-1), yv.reshape(-1)], dim=1)
    return x, y, g, height, width
def siren_init(tensor, use_this_fan_in=None):
    """
    Fill ``tensor`` in place from U(-sqrt(6 / fan_in), sqrt(6 / fan_in)) and return it.

    This matches torch.nn.init.kaiming_uniform_ with mode='fan_in' and the
    ReLU gain (sqrt(2) * sqrt(3 / fan_in) == sqrt(6 / fan_in)).

    Args:
        tensor: tensor to overwrite.
        use_this_fan_in: explicit fan-in to use instead of the one computed
            from ``tensor`` (needed for 1-D tensors such as biases).

    To initialize a whole nn.Module use 'apply_siren_init'.
    """
    fan_in = (
        use_this_fan_in
        if use_this_fan_in is not None
        else nn.init._calculate_correct_fan(tensor, "fan_in")
    )
    limit = math.sqrt(6.0 / fan_in)
    with torch.no_grad():
        initialized = tensor.uniform_(-limit, limit)
    return initialized
def apply_siren_init(layer: nn.Module):
    """
    Re-initialise ``layer.weight`` (and ``layer.bias`` when present) in place
    with siren_init.

    The bias bound is computed from the weight's fan-in, because a 1-D bias
    has no fan-in of its own.
    """
    weight = layer.weight
    siren_init(weight)
    bias = layer.bias
    if bias is None:
        return
    weight_fan_in = nn.init._calculate_correct_fan(weight, "fan_in")
    siren_init(bias, use_this_fan_in=weight_fan_in)
class SinusLayer(nn.Module):
    """
    Linear layer followed by a sine: ``sin(linear(w0 * x))``.

    ``w0`` is a trainable scalar parameter that scales the layer *input*, so
    the bias is not multiplied by it. The weight is redrawn from
    U(-sqrt(6 / fan_in), sqrt(6 / fan_in)); the bias keeps nn.Linear's default
    initialisation.

    Args:
        cin: number of input features.
        cout: number of output features.
        w0: initial value of the frequency scale.
    """

    def __init__(self, cin, cout, w0=1):
        super().__init__()
        self.linear = nn.Linear(cin, cout)
        weight = self.linear.weight
        # Same rule as siren_init: w ~ U(-c/sqrt(n), c/sqrt(n)) with c = sqrt(6).
        limit = math.sqrt(6.0 / nn.init._calculate_correct_fan(weight, "fan_in"))
        self.w0 = nn.Parameter(torch.tensor(w0, dtype=torch.float32))
        nn.init.uniform_(weight, -limit, limit)
        # NOTE(review): the SIREN paper draws the first layer from U(-1/n, 1/n);
        # here every layer, including the w0=30 input layer, uses sqrt(6/n).
        # Left as-is because changing it changes training behaviour.

    def forward(self, x):
        """Return ``sin(linear(w0 * x))``."""
        scaled = x * self.w0
        return self.linear(scaled).sin()
class Siren(nn.Module):
    """
    Coordinate MLP built from SinusLayer blocks, with a sigmoid output.

    Args:
        cin: number of input coordinates.
        cout: number of output channels.
        hiddens: widths of the sine layers, in order. The first layer is
            built with w0=30, the others with SinusLayer's default w0.
            Any non-empty sequence of ints (tuple by default, to avoid a
            shared mutable default argument).

    Raises:
        ValueError: if ``hiddens`` is empty.
    """

    def __init__(self, cin=2, cout=3, hiddens=(128, 64, 64, 32, 16)):
        super(Siren, self).__init__()
        hiddens = list(hiddens)
        if not hiddens:
            raise ValueError("hiddens must contain at least one layer width")
        self.prepare = SinusLayer(cin, hiddens[0], w0=30)
        # Despite the name, these layers are applied one after another with no
        # skip connections; the attribute name is kept so state_dict keys stay
        # stable. Layers are created in order, so RNG consumption is unchanged.
        self.residuals = nn.ModuleList(
            SinusLayer(width_in, width_out)
            for width_in, width_out in zip(hiddens[:-1], hiddens[1:])
        )
        self.out = nn.Linear(hiddens[-1], cout)

    def forward(self, x):
        """Map (..., cin) coordinates to (..., cout) values in (0, 1)."""
        x = self.prepare(x)
        for layer in self.residuals:
            x = layer(x)
        return torch.sigmoid(self.out(x))
def show_pred(pred, height, width, centered=False):
    """
    Convert a flat per-pixel prediction into a displayable uint8 image.

    Args:
        pred: tensor holding ``height * width`` rows of per-pixel channel
            values (3 channels in this script). It may live on any device and
            may require grad; it is never modified.
        height, width: output image size.
        centered: if True, values are first mapped from [-1, 1] to [0, 1].
            The min-max stretch that follows is invariant to that positive
            affine map, so the flag has no visible effect beyond float
            rounding; it is kept for API compatibility.

    Returns:
        numpy uint8 array of shape (height, width, channels), min-max
        stretched to [0, 255]. A constant image maps to all zeros.
    """
    # Copy: on CPU, .numpy() shares memory with the tensor, and the in-place
    # ops below must not write back into the caller's prediction.
    img = pred.detach().cpu().numpy().reshape(height, width, -1).copy()
    if centered:
        img += 1
        img /= 2
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span > 0:
        img = (img - lo) / span
    else:
        # Avoid 0/0 = NaN, whose cast to uint8 is undefined.
        img = np.zeros_like(img)
    return (img * 255).astype(np.uint8)
def jacobian_in_batch(y, x):
    '''
    Compute the per-sample Jacobian dy/dx for a batch.

    Args:
        y: output tensor whose first dimension is the batch (B, ...).
        x: input tensor (B, ...) that y was computed from; must require grad.

    Returns:
        Tensor of shape (B, D_y, D_x), where D_y / D_x are the numbers of
        elements per sample of y / x. The autograd graph is kept
        (create_graph=True), so the result can be used inside a loss.

    Each row is obtained by back-propagating sum_b y[b, i]. That equals the
    per-sample gradient only when samples do not interact, which holds for a
    point-wise network such as Siren.
    '''
    batch = y.shape[0]
    # reshape (not view): y may be non-contiguous. Reading D_y from the
    # flattened shape also yields a plain int for 1-D y, where np.prod(())
    # used to give the float 1.0, which range() rejects.
    y_flat = y.reshape(batch, -1)
    out_dim = y_flat.shape[1]
    # Seeding with ones gives every sample's row i in a single backward pass.
    ones = torch.ones(batch).to(y)
    rows = [
        torch.autograd.grad(
            y_flat[:, i], x,
            grad_outputs=ones,
            retain_graph=True,
            create_graph=True,
        )[0].reshape(batch, -1)
        for i in range(out_dim)
    ]
    return torch.stack(rows, dim=1)
# ---- Setup: data, model, optimiser ----------------------------------------
batch_size = 1024
x, y, gradient, height, width = make_dataset()
img_dataset = show_pred(y.clone(), height, width)
N = len(x)
# Use the GPU only when one is available, so CPU-only machines can still run.
cuda = torch.cuda.is_available()
net = Siren(cout=3, hiddens=[256, 128, 64, 64, 32])
if cuda:
    x = x.cuda()
    y = y.cuda()
    gradient = gradient.cuda()
    net.cuda()
# Scale of the ground-truth gradients. Only relevant if a gradient-matching
# term is enabled; the training loop currently fits colours with a per-pixel
# MSE computed by hand, so no criterion object is needed.
gradient *= 7
net.train()
opt = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
# opt = Ranger(net.parameters(), lr=0.001)   # alternative optimiser
# Per-pixel sampling probabilities: start uniform; the training loop
# re-weights them by each pixel's most recent reconstruction error.
probas = np.full(N, 1.0 / N)
# ---- Training: error-weighted pixel sampling ------------------------------
n_batches = (N + batch_size - 1) // batch_size
for epoch in range(100):
    for batch_idx in range(n_batches):
        # Draw pixels with probability proportional to their last recorded
        # error, so badly reconstructed regions are revisited more often.
        jdx = np.random.choice(N, size=batch_size, p=probas)
        bx = x[jdx]
        by = y[jdx]
        opt.zero_grad()
        # NOTE: a gradient-matching term would need bx.requires_grad_(True)
        # and jacobian_in_batch(out, bx) compared against gradient[jdx].
        out = net(bx)
        # Per-pixel MSE, kept unreduced so it can drive the sampler.
        errors = ((out - by) ** 2).mean(dim=1)
        loss = errors.mean()
        probas[jdx] = errors.detach().cpu().numpy() / len(errors)
        probas /= probas.sum()
        loss.backward()
        opt.step()
        # Log every 10 batches (the old `i % 10` on a 1024-step index fired
        # every 5 batches).
        if batch_idx % 10 == 0:
            print('loss1: ', loss.item())
    # Once per epoch, render the full image. no_grad avoids building an
    # autograd graph over every pixel just for display.
    with torch.no_grad():
        pred = net(x)
    img = show_pred(pred, height, width)
    cv2.imshow('ground_truth', img_dataset)
    cv2.imshow('prediction', img)
    cv2.waitKey(5)
# ---- Evaluation: query the network on a 2x denser coordinate grid ----------
net.eval()
with torch.no_grad():
    height2, width2 = height * 2, width * 2
    # Same (row, column) coordinate layout as the training grid, twice as dense.
    xv, yv = torch.meshgrid([torch.linspace(0, 1, height2), torch.linspace(0, 1, width2)])
    x2 = torch.stack([xv.reshape(-1), yv.reshape(-1)], dim=1)
    # Follow the training data's device instead of assuming CUDA.
    x2 = x2.to(x.device)
    pred = net(x2)
img2 = show_pred(pred, height2, width2)
cv2.imshow('ground_truth', img_dataset)
cv2.imshow('ground_truth_x2_bicubic', cv2.pyrUp(img_dataset))
cv2.imshow('prediction_x2', img2)
cv2.imshow('prediction_x1', img)
cv2.waitKey(0)
@etienne87
Copy link
Author

mickey_prediction

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment