Skip to content

Instantly share code, notes, and snippets.

@fabianp
Created June 10, 2016 06:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fabianp/8b62d905ae254fadf763cf1909904e3a to your computer and use it in GitHub Desktop.
Save fabianp/8b62d905ae254fadf763cf1909904e3a to your computer and use it in GitHub Desktop.
1D total variation (also known as fused lasso) proximal operator
from numba import njit
@njit
def prox_tv1d(w, stepsize):
    """
    Proximal operator of the 1D total variation (fused lasso) penalty.

    Overwrites ``w`` in place with the minimizer over ``x`` of
    ``0.5 * ||x - w||^2 + stepsize * sum_i |x[i + 1] - x[i]|``,
    computed in a single forward pass with Condat's direct algorithm.

    Parameters
    ----------
    w: array
        1D vector of coefficients; modified in place. Use a floating-point
        array: results are written back into ``w`` itself, so an integer
        array would truncate them.
    stepsize: float
        step size (sometimes denoted gamma) in proximal objective function.
        Must be non-negative; otherwise ``w`` is left unchanged.

    Returns
    -------
    None
        The result is stored in ``w``. An empty ``w`` is a no-op.

    References
    ----------
    Condat, Laurent. "A direct algorithm for 1D total variation denoising."
    IEEE Signal Processing Letters (2013)
    """
    width = w.size
    # Bail out on an empty array (w[0] would be an invalid access) and on an
    # invalid step size (negative, or NaN since the comparison is then False);
    # w is left untouched in both cases.
    if not (width > 0 and stepsize >= 0):
        return

    k, k0 = 0, 0  # k: current sample location, k0: beginning of current segment
    umin = stepsize  # u is the dual variable
    umax = -stepsize
    vmin = w[0] - stepsize
    vmax = w[0] + stepsize  # bounds for the segment's value
    kplus = 0
    kminus = 0  # last positions where umax=-lambda, umin=lambda, respectively
    twolambda = 2.0 * stepsize  # auxiliary variable
    minlambda = -stepsize  # auxiliary variable

    # NOTE: every `while True: ... break` fill loop below is a do-while: it
    # always writes at least one sample before testing its exit condition.
    # Writes only touch indices below the next sample to be read, which is
    # what makes the in-place update safe.
    while True:  # simple loop, the exit test is inside
        while k == width - 1:  # we use the right boundary condition
            if umin < 0.0:  # vmin is too high -> negative jump necessary
                while True:
                    w[k0] = vmin
                    k0 += 1
                    if k0 > kminus:
                        break
                # Start a new segment at k0; only the lower bound is reset.
                k = k0
                kminus = k
                vmin = w[kminus]
                umin = stepsize
                umax = vmin + umin - vmax
            elif umax > 0.0:  # vmax is too low -> positive jump necessary
                while True:
                    w[k0] = vmax
                    k0 += 1
                    if k0 > kplus:
                        break
                # Start a new segment at k0; only the upper bound is reset.
                k = k0
                kplus = k
                vmax = w[kplus]
                umax = minlambda
                umin = vmax + umax - vmin
            else:
                # Last segment: settle its value, write it out and finish.
                vmin += umin / (k - k0 + 1)
                while True:
                    w[k0] = vmin
                    k0 += 1
                    if k0 > k:
                        break
                return

        umin += w[k + 1] - vmin
        if umin < minlambda:  # negative jump necessary
            while True:
                w[k0] = vmin
                k0 += 1
                if k0 > kminus:
                    break
            k = k0
            kminus = k
            kplus = kminus
            vmin = w[kplus]
            vmax = vmin + twolambda
            umin = stepsize
            umax = minlambda
        else:
            umax += w[k + 1] - vmax
            if umax > stepsize:  # positive jump necessary
                while True:
                    w[k0] = vmax
                    k0 += 1
                    if k0 > kplus:
                        break
                k = k0
                kminus = k
                kplus = kminus
                vmax = w[kplus]
                vmin = vmax - twolambda
                umin = stepsize
                umax = minlambda
            else:  # no jump necessary, we continue
                k += 1
                if umin >= stepsize:  # update of vmin
                    kminus = k
                    vmin += (umin - stepsize) / (kminus - k0 + 1)
                    umin = stepsize
                if umax <= minlambda:  # update of vmax
                    kplus = k
                    vmax += (umax + stepsize) / (kplus - k0 + 1)
                    umax = minlambda
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment