Last active June 24, 2020 17:44
A very scruffy script to showcase SIREN (sinusoidal representation) networks.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from ranger import Ranger
import numpy as np
import random
import cv2
import math
import kornia
from functools import partial
# Convenience aliases for kornia image filters. Neither name is referenced by
# the rest of this script.
laplace_filter = partial(
    kornia.filters.laplacian,
    kernel_size=3,  # 3x3 Laplacian kernel
)
gradient_filter = kornia.filters.sobel  # kornia's Sobel filter, used as-is
def make_dataset(filname=None, centered=False):
    """Load an image and turn it into a coordinate -> colour regression dataset.

    The image is read with OpenCV (BGR channel order) and downsampled once with
    ``cv2.pyrDown`` to keep the number of pixels manageable.

    Args:
        filname: path of the image to load. Defaults to ``'dali.jpg'``.
        centered: if True, colours are mapped from [0, 1] to [-1, 1].

    Returns:
        x: (H*W, 2) float tensor of (row, col) coordinates, each in [0, 1].
        y: (H*W, C) float tensor of pixel values in [0, 1] (or [-1, 1] when
            ``centered``).
        g: (H*W, C, 2) per-pixel, per-channel spatial gradient of the image.
        height, width: size of the downsampled image.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    path = 'dali.jpg' if filname is None else filname
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError("could not read image {!r}".format(path))
    img = cv2.pyrDown(img)
    height, width, c = img.shape

    # Dense grid of normalised coordinates; xv varies along rows, yv along columns.
    xv, yv = torch.meshgrid([torch.linspace(0, 1, height), torch.linspace(0, 1, width)])

    y = torch.from_numpy(img) / 255.0  # (H, W, C) uint8 -> float in [0, 1]

    # kornia returns (B, C, 2, H, W). Move the pixel axes first and keep the
    # (channel, direction) pairing intact -> (H*W, C, 2), matching the
    # (B, D_y, D_x) layout of a per-sample Jacobian.
    # NOTE(review): per kornia's convention direction 0 is the horizontal
    # (column) derivative, i.e. the reverse of x's (row, col) order -- confirm
    # before comparing g with a Jacobian w.r.t. x.
    g = kornia.filters.spatial_gradient(y.permute(2, 0, 1)[None])
    g = g[0].permute(2, 3, 0, 1).reshape(-1, c, 2)

    y = y.reshape(-1, c)
    if centered:
        y = y * 2 - 1  # [0, 1] -> [-1, 1]

    x = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)], dim=1)
    return x, y, g, height, width
def siren_init(tensor, use_this_fan_in=None):
    """Siren initialization of a tensor, in place. To initialize a nn.Module use 'apply_siren_init'.

    Values are drawn from U(-sqrt(6 / fan_in), sqrt(6 / fan_in)). This is
    equivalent to torch.nn.init.kaiming_uniform_ with mode='fan_in' and the
    same gain as the 'ReLU' nonlinearity.

    Args:
        tensor: tensor to fill in place.
        use_this_fan_in: fan-in to use instead of the one computed from
            ``tensor``. Needed for tensors with fewer than 2 dimensions
            (e.g. biases), whose fan cannot be computed from their own shape.

    Returns:
        ``tensor``, filled in place.
    """
    if use_this_fan_in is not None:
        fan_in = use_this_fan_in
    else:
        # Only derive the fan when no override is given: the override must win,
        # and _calculate_correct_fan raises for tensors with fewer than 2 dims.
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
    bound = math.sqrt(6.0 / fan_in)
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
def apply_siren_init(layer: nn.Module):
    """Apply SIREN initialization to a layer's weight and bias, in place.

    Both the weight and the bias are drawn from
    U(-sqrt(6 / fan_in), sqrt(6 / fan_in)), the same bound siren_init uses.
    fan_in is taken from the weight, because a bias has no fan of its own.
    Modules without a weight of at least 2 dimensions (activations, norms, ...)
    are left untouched, so this can be passed to ``nn.Module.apply``.
    """
    weight = getattr(layer, "weight", None)
    if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
        return
    fan_in = nn.init._calculate_correct_fan(weight, "fan_in")
    bound = math.sqrt(6.0 / fan_in)
    with torch.no_grad():
        weight.uniform_(-bound, bound)
        bias = getattr(layer, "bias", None)
        if bias is not None:
            bias.uniform_(-bound, bound)
class SinusLayer(nn.Module):
    """Linear map followed by a sine, with a learnable input frequency ``w0``.

    Computes ``sin(linear(w0 * x))``: the input is scaled by ``w0`` before the
    affine map, so the bias is not scaled. The weight is re-initialised from
    U(-sqrt(6 / fan_in), sqrt(6 / fan_in)); the bias keeps nn.Linear's default.
    """

    def __init__(self, cin, cout, w0=1):
        super().__init__()
        self.linear = nn.Linear(cin, cout)
        # SIREN-style weight init: w_i ~ U(-c/sqrt(n), c/sqrt(n)), c = sqrt(6), n = fan_in.
        n_in = nn.init._calculate_correct_fan(self.linear.weight, "fan_in")
        limit = math.sqrt(6.0 / n_in)
        self.w0 = nn.Parameter(torch.tensor(w0, dtype=torch.float32))
        nn.init.uniform_(self.linear.weight, a=-limit, b=limit)

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(self.linear(scaled))
class Siren(nn.Module):
    """Coordinate MLP: a stack of SinusLayer blocks and a sigmoid-activated linear head.

    Args:
        cin: number of input coordinates.
        cout: number of output channels; outputs lie in (0, 1) because of the sigmoid.
        hiddens: widths of the sine layers. The first layer uses w0=30 and the
            others use SinusLayer's default. Any sequence of ints is accepted.
            The default is a tuple, so it cannot be mutated across instances.
    """

    def __init__(self, cin=2, cout=3, hiddens=(128, 64, 64, 32, 16)):
        super().__init__()
        widths = list(hiddens)
        self.prepare = SinusLayer(cin, widths[0], w0=30)
        # NOTE: despite the name, these layers are applied strictly in sequence
        # (no skip connections); the attribute name is kept for state_dict
        # compatibility.
        self.residuals = nn.ModuleList(
            SinusLayer(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])
        )
        self.out = nn.Linear(widths[-1], cout)

    def forward(self, x):
        x = self.prepare(x)
        for layer in self.residuals:
            x = layer(x)
        return torch.sigmoid(self.out(x))
def show_pred(pred, height, width, centered=False):
    """Convert a flat (H*W, 3) colour tensor into a displayable uint8 image.

    The values are min-max normalised to [0, 255], so any value range is accepted.

    Args:
        pred: tensor with height * width * 3 elements. It may live on the GPU
            and may require grad; it is never modified.
        height, width: image size.
        centered: if True, values are first mapped from [-1, 1] to [0, 1].

    Returns:
        (height, width, 3) numpy uint8 array. A constant input yields an
        all-zero image instead of NaNs.
    """
    # Detach, move to the CPU and copy, so the in-place ops below never write
    # back into the caller's tensor.
    img = pred.detach().cpu().reshape(height, width, 3).numpy().astype(np.float32)
    if centered:
        # Inverse of the [0, 1] -> [-1, 1] centering. The min-max step below
        # would undo any positive affine map anyway; kept for clarity.
        img += 1
        img /= 2
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span > 0:
        img = (img - lo) / span
    else:
        img = np.zeros_like(img)  # constant image: avoid 0/0 -> NaN
    return (img * 255).astype(np.uint8)
def jacobian_in_batch(y, x):
    """Compute the per-sample Jacobian dy/dx for a batch.

    Each row is obtained by back-propagating the batch sum of one output
    component. That equals the per-sample derivative only when y[b] depends on
    x[b] alone (true for a network applied independently to each sample).

    Args:
        y: output tensor of shape (B, ...) computed from ``x``.
        x: input tensor of shape (B, ...) with ``requires_grad=True``.

    Returns:
        Tensor of shape (B, D_y, D_x), where D_y / D_x are the flattened
        per-sample sizes of y / x. The graph is kept (create_graph=True), so the
        result can itself be used in a loss.
    """
    batch = y.shape[0]
    y = y.reshape(batch, -1)
    single_y_size = y.shape[1]
    # y[:, i] is not a scalar, so autograd needs explicit grad_outputs;
    # ones make it differentiate sum_b y[b, i].
    vector = torch.ones(batch).to(y)
    # Compute the Jacobian row by row: (B,) -> (B, D_x) for each i, stacked to (B, D_y, D_x).
    jac = [
        torch.autograd.grad(y[:, i], x, grad_outputs=vector, create_graph=True)[0].reshape(batch, -1)
        for i in range(single_y_size)
    ]
    return torch.stack(jac, dim=1)
# ---- experiment setup ----
batch_size = 1024
x, y, gradient, height, width = make_dataset()
img_dataset = show_pred(y.clone(), height, width)  # ground-truth preview image
N = len(x)
# Use the GPU when one is available, otherwise run on the CPU.
cuda = int(torch.cuda.is_available())
net = Siren(cout=3, hiddens=[256, 128, 64, 64, 32])
# criterion = torch.nn.MSELoss()
# NOTE: not used by the training loop, which computes a per-sample MSE by hand.
criterion = torch.nn.SmoothL1Loss()
if cuda:
    # The model must live on the same device as the data fed to it.
    net = net.cuda()
    x = x.cuda()
    y = y.cuda()
    gradient = gradient.cuda()
#gradient = (gradient-gradient.mean())/(gradient.std()+1e-5)
#gradient = (gradient-gradient.min())/(gradient.max()-gradient.min())
# NOTE(review): magic rescaling of the image gradient; it only matters for the
# (disabled) gradient-supervision loss -- confirm the intended scale.
gradient *= 7
# Created after the model is moved, so the optimizer sees the final parameters.
opt = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
# opt = Ranger(net.parameters(), lr=0.001)
idx = np.arange(0, len(y))  # pixel indices, for sequential (non-sampled) batching
probas = np.full(N, 1.0 / N)  # uniform initial sampling distribution over pixels
# ---- training: fit the network to the image, sampling hard pixels more often ----
for epoch in range(100):
    for i in range(0, N, batch_size):
        # Importance sampling: draw a batch according to the per-pixel probabilities.
        jdx = np.random.choice(np.arange(0, len(y)), size=batch_size, p=probas)
        # Alternative, sequential batches: jdx = idx[i:i + batch_size]
        bx = x[jdx]
        by = y[jdx]
        bg = gradient[jdx]  # ground-truth gradient, only used by the disabled gradient loss
        # Input gradients are only needed for the (disabled) gradient-supervision loss.
        bx = bx.requires_grad_(True)
        out = net(bx)
        # Optional gradient supervision (disabled):
        # jacob = jacobian_in_batch(out, bx)  # (B, 3, 2)
        # grad_errors = ((jacob - bg) ** 2).mean(dim=[1, 2])
        errors = ((out - by) ** 2).mean(dim=1)  # per-sample MSE
        loss = errors.mean()

        # Optimisation step; without it the network never learns.
        opt.zero_grad()
        loss.backward()
        opt.step()

        # NOTE(review): this update was truncated in the source copy. Here each
        # sampled pixel's probability is reset proportionally to its current
        # error, then the distribution is renormalised -- confirm the intent.
        probas[jdx] = errors.detach().cpu().numpy() / len(errors)
        probas /= probas.sum()

        if i % 10 == 0:
            print('loss1: ', loss.item())
            with torch.no_grad():  # the full-image preview needs no autograd graph
                pred = net(x)
            img = show_pred(pred, height, width)
            cv2.imshow('ground_truth', img_dataset)
            cv2.imshow('prediction', img)
            cv2.waitKey(1)  # give OpenCV's GUI a chance to repaint the windows
# ---- final preview: sample the trained network on a 2x denser coordinate grid ----
# The network is a continuous function of (row, col), so it can be queried at
# any resolution.
with torch.no_grad():
    height2, width2 = height * 2, width * 2
    xv, yv = torch.meshgrid([torch.linspace(0, 1, height2), torch.linspace(0, 1, width2)])
    # Put the query grid on whatever device the training data (and model) use.
    x2 = torch.cat([xv.reshape(-1, 1), yv.reshape(-1, 1)], dim=1).to(x.device)
    img2 = show_pred(net(x2), height2, width2)
    img = show_pred(net(x), height, width)  # native resolution, final weights
cv2.imshow('ground_truth', img_dataset)
# Bicubic upsampling baseline; cv2.pyrUp would apply a Gaussian-pyramid filter,
# not bicubic interpolation.
cv2.imshow('ground_truth_x2_bicubic',
           cv2.resize(img_dataset, (width2, height2), interpolation=cv2.INTER_CUBIC))
cv2.imshow('prediction_x2', img2)
cv2.imshow('prediction_x1', img)
cv2.waitKey(0)  # keep the windows open until a key is pressed
