Skip to content

Instantly share code, notes, and snippets.

@ixxra
Created December 5, 2014 17:08
Show Gist options
  • Save ixxra/875203bbc56b13dc4ced to your computer and use it in GitHub Desktop.
Save ixxra/875203bbc56b13dc4ced to your computer and use it in GitHub Desktop.
See "Nonlinear total variation based noise removal algorithms" by Rudin, Osher, and Fatemi (Physica D, 1992).
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 5 07:49:22 2014
@author: isra
"""
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
def minmod(a, b):
    """Elementwise minmod limiter.

    Returns the argument of smaller magnitude where ``a`` and ``b`` share a
    sign, and 0 where their signs differ or either one is 0.
    """
    sign_sum = np.sign(a) + np.sign(b)
    smaller = np.minimum(np.abs(a), np.abs(b))
    return sign_sum * smaller / 2
def delta_plus_x(u):
    """Forward difference along axis 0: u[i+1, j] - u[i, j].

    Only interior points are filled; the outermost rows and columns of the
    returned float64 array stay 0.
    """
    mid = slice(1, -1)
    result = np.zeros(u.shape)
    result[mid, mid] = u[2:, mid] - u[mid, mid]
    return result
def delta_minus_x(u):
    """Backward difference along axis 0: u[i, j] - u[i-1, j].

    Only interior points are filled; the outermost rows and columns of the
    returned float64 array stay 0.
    """
    mid = slice(1, -1)
    result = np.zeros(u.shape)
    result[mid, mid] = u[mid, mid] - u[:-2, mid]
    return result
def delta_plus_y(u):
    """Forward difference along axis 1: u[i, j+1] - u[i, j].

    Only interior points are filled; the outermost rows and columns of the
    returned float64 array stay 0.
    """
    mid = slice(1, -1)
    result = np.zeros(u.shape)
    result[mid, mid] = u[mid, 2:] - u[mid, mid]
    return result
def delta_minus_y(u):
    """Backward difference along axis 1: u[i, j] - u[i, j-1].

    Only interior points are filled; the outermost rows and columns of the
    returned float64 array stay 0.
    """
    mid = slice(1, -1)
    result = np.zeros(u.shape)
    result[mid, mid] = u[mid, mid] - u[mid, :-2]
    return result
def lamb(u_n, u_0, n, h, sigma):
    """Return the scalar fidelity multiplier lambda for the current iterate.

    With D+ the interior forward differences and eps = 1e-10::

        lambda = -h / (2 * sigma**2) * sum(
            (|D+ u_n|^2 + eps - D+x u_0 * D+x u_n - D+y u_0 * D+y u_n)
            / sqrt(|D+ u_n|^2 + eps))

    ``n`` (the iteration index) is accepted but not used.
    """
    sq_eps = 1e-10
    grad_x = delta_plus_x(u_n)
    grad_y = delta_plus_y(u_n)
    # eps keeps the square root (and the division below) away from zero
    # where the image is locally flat.
    square_norm = sq_eps + grad_x**2 + grad_y**2
    numerator = square_norm - delta_plus_x(u_0) * grad_x - delta_plus_y(u_0) * grad_y
    summand = numerator / np.sqrt(square_norm)
    scale = -h / (2 * sigma**2)
    return scale * np.sum(summand)
def fwd(u_n, u_0, dt, n, h, sigma, lapl):
    """Advance ``u_n`` in place by one explicit step of the ROF TV flow.

    The step computed is::

        inc = dt/h * (D-x(D+x u / q1) + D-y(D+y u / q2))
              - dt * lam * (u_n - u_0)
        u_n += inc

    with eps = 1e-10 and
    q1 = sqrt(eps + (D+x u)^2 + minmod(D+y u, D-y u)^2),
    q2 = sqrt(eps + (D+y u)^2 + minmod(D+x u, D-x u)^2).
    ``lam = lamb(u_n, u_0, n, h, sigma)`` is evaluated on u_n *before* it
    is updated. Afterwards the outermost rows/columns are overwritten with
    their neighbours (zero-gradient boundary).

    Parameters
    ----------
    u_n : ndarray
        Current iterate, updated in place. It should be floating point;
        numpy refuses an in-place float add into an integer array.
    u_0 : ndarray
        Original image, same shape as ``u_n``.
    dt, h : float
        Time step and grid spacing of the scheme.
    n : int
        Iteration index, forwarded to ``lamb``.
    sigma : float
        Parameter of the fidelity multiplier (see ``lamb``).
    lapl : list
        ``np.std`` of this step's increment is appended to it.

    Returns
    -------
    None
    """
    sq_eps = 1e-10
    dpx = delta_plus_x(u_n)
    dmx = delta_minus_x(u_n)
    dpy = delta_plus_y(u_n)
    dmy = delta_minus_y(u_n)
    quotient1 = np.sqrt(sq_eps + dpx**2 + minmod(dpy, dmy)**2)
    quotient2 = np.sqrt(sq_eps + dpy**2 + minmod(dpx, dmx)**2)
    # Each flux is normalised by its own quotient BEFORE the backward
    # difference, in both directions. Dividing the y second difference by
    # quotient2 afterwards would match neither the ROF scheme nor the x-term.
    increment = delta_minus_x(dpx / quotient1)
    increment += delta_minus_y(dpy / quotient2)
    increment *= dt / h
    # Fidelity term pulling u_n back towards u_0.
    increment -= dt * lamb(u_n, u_0, n, h, sigma) * (u_n - u_0)
    u_n += increment
    # Zero-gradient boundary: copy the adjacent row/column outward.
    u_n[0, :] = u_n[1, :]
    u_n[-1, :] = u_n[-2, :]
    u_n[:, 0] = u_n[:, 1]
    u_n[:, -1] = u_n[:, -2]
    lapl.append(np.std(increment))
# Scheme parameters passed to fwd() on every step.
sigma = 1.0
dt = 1e-15
h = 1e-5
# fwd() appends the standard deviation of each step's increment here.
lapl = []

# Stability check on the explicit scheme. Raise instead of `assert` so the
# check is not stripped under `python -O`.
# NOTE(review): with these values dt / h == 1e-10, so each step changes
# u_n only marginally -- confirm the step sizes are intended.
if not dt / h**2 < 1:
    raise ValueError('CFL failed')

# scipy.misc.imread no longer exists in current SciPy releases; use the
# reader from matplotlib, which this file already imports as plt.
image = plt.imread('image.png')
if np.issubdtype(image.dtype, np.floating):
    # matplotlib returns PNG data as floats in [0, 1]; rescale to the
    # 0-255 intensity range scipy.misc.imread produced for 8-bit images.
    image = image.astype(np.float64) * 255.0
if image.ndim == 3:
    # Colour image: keep only the first channel, as before. Grayscale
    # (2-D) images are used as-is instead of failing on the channel index.
    image = image[:, :, 0]
u_0 = image.astype(np.float64)
u_n = u_0.copy()
for n in range(2000):
    fwd(u_n, u_0, dt, n, h, sigma, lapl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment