Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active February 10, 2023 17:56
Show Gist options
  • Save crowsonkb/140b5c433b205edb969e80bfc1b85730 to your computer and use it in GitHub Desktop.
Save crowsonkb/140b5c433b205edb969e80bfc1b85730 to your computer and use it in GitHub Desktop.
JAX implementation of the 2D DWT and IDWT.
"""JAX implementation of the 2D DWT and IDWT."""
from einops import rearrange
import jax
import jax.numpy as jnp
import pywt
def get_filter_bank(wavelet, dtype=jnp.float32):
    """Build a JAX array holding a PyWavelets wavelet's filter bank.

    Args:
        wavelet: Wavelet name understood by ``pywt.Wavelet`` (e.g. "bior4.4").
        dtype: Element type of the returned array.

    Returns:
        A 2D array whose rows are the wavelet's ``filter_bank`` entries.
        PyWavelets orders these as (dec_lo, dec_hi, rec_lo, rec_hi), which is
        how the transforms in this module index them.
    """
    bank = pywt.Wavelet(wavelet).filter_bank
    return jnp.array(bank, dtype)
def make_kernel(lo, hi):
    """Build the four separable 2D subband kernels from a pair of 1D filters.

    Args:
        lo: 1D lowpass filter taps.
        hi: 1D highpass filter taps.

    Returns:
        Array of shape (4, 1, k, k) holding the [ll, lh, hl, hh] kernels, where
        k is the filter length. In each kernel the first filter named varies
        along the second-to-last axis and the second along the last axis.
    """
    # The taps are reversed because lax convolutions compute a
    # cross-correlation; flipping turns that into a true convolution.
    lo_rev, hi_rev = jnp.flip(lo), jnp.flip(hi)
    pairs = (
        (lo_rev, lo_rev),  # ll
        (hi_rev, lo_rev),  # lh
        (lo_rev, hi_rev),  # hl
        (hi_rev, hi_rev),  # hh
    )
    bands = [jnp.outer(first, second) for first, second in pairs]
    # Insert a singleton "input channel" axis after the subband axis.
    return jnp.stack(bands)[:, None]
def wavelet_dec_once(x, filt, channels):
    """Apply a single level of the 2D DWT.

    The first ``channels`` channels of ``x`` are the current approximation
    band. They are filtered with the analysis filters ``filt[0]`` (lowpass) and
    ``filt[1]`` (highpass) at stride 2, using wrap-around padding. The remaining
    channels (detail bands from earlier levels) are not filtered. Instead they
    are folded space-to-depth by a factor of 2 so that they match the new,
    halved spatial resolution.

    Args:
        x: NHWC array whose first ``channels`` channels hold the current
            approximation. Height and width are expected to be even.
        filt: Filter bank as returned by ``get_filter_bank()``.
        channels: Number of channels of the original (undecomposed) image.

    Returns:
        Array with half the height and width of ``x`` and four times as many
        channels. It holds the new ll, lh, hl, hh bands (each ``channels``
        wide, band-major) followed by the folded older detail bands.
    """
    approx = x[..., :channels]
    details = x[..., channels:]

    # Repeat the 4 subband kernels once per input channel so the grouped
    # convolution filters every channel independently.
    weights = jnp.tile(make_kernel(filt[0], filt[1]), [channels, 1, 1, 1])
    taps = weights.shape[-1] - 1
    pad_before = taps // 2
    pad_after = taps - pad_before
    approx = jnp.pad(
        approx,
        ((0, 0), (pad_before, pad_after), (pad_before, pad_after), (0, 0)),
        "wrap",
    )
    approx = jax.lax.conv_general_dilated(
        lhs=approx,
        rhs=weights,
        window_strides=(2, 2),
        padding=((0, 0), (0, 0)),
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
        feature_group_count=channels,
    )
    # The grouped conv emits (channel, band) order. Switch to (band, channel)
    # so that each subband occupies a contiguous slice of channels.
    approx = rearrange(approx, "n h w (c1 c2) -> n h w (c2 c1)", c2=4)

    details = rearrange(
        details,
        "n (h h2) (w w2) (c c2) -> n h w (c h2 w2 c2)",
        h2=2,
        w2=2,
        c2=channels,
    )
    return jnp.concatenate([approx, details], axis=-1)
def wavelet_rec_once(x, filt, channels):
    """Invert a single level of the 2D DWT (the inverse of ``wavelet_dec_once``).

    The first ``4 * channels`` channels of ``x`` (the ll, lh, hl, hh bands,
    band-major) are upsampled and filtered with the synthesis filters
    ``filt[2]`` and ``filt[3]``. The remaining channels are unfolded
    depth-to-space by a factor of 2 to match the doubled resolution.

    Args:
        x: NHWC array of packed coefficients for the current level.
        filt: Filter bank as returned by ``get_filter_bank()``.
        channels: Number of channels of the original (undecomposed) image.

    Returns:
        Array with twice the height and width of ``x`` and a quarter as many
        channels. The reconstructed approximation comes first, followed by the
        unfolded remaining detail bands.
    """
    approx = x[..., : channels * 4]
    details = x[..., channels * 4 :]

    weights = jnp.tile(make_kernel(filt[2], filt[3]), [1, channels, 1, 1])
    size = weights.shape[-1]
    # Padding applied by the dilated convolution itself.
    conv_lo = size // 2
    conv_hi = size - conv_lo
    # Periodic pre-padding of the coarse input, and the matching amount to
    # crop from the (2x larger) output afterwards.
    pre_lo = conv_lo // 2
    pre_hi = conv_lo - pre_lo
    crop_lo = 2 * pre_lo
    crop_hi = 2 * pre_hi

    # Return to (channel, band) order so each conv group sees one channel's
    # four subbands together.
    approx = rearrange(approx, "n h w (c1 c2) -> n h w (c2 c1)", c1=4)
    approx = jnp.pad(
        approx, ((0, 0), (pre_lo, pre_hi), (pre_lo, pre_hi), (0, 0)), "wrap"
    )
    approx = jax.lax.conv_general_dilated(
        lhs=approx,
        rhs=weights,
        window_strides=(1, 1),
        padding=((conv_lo, conv_hi), (conv_lo, conv_hi)),
        lhs_dilation=(2, 2),
        dimension_numbers=("NHWC", "IOHW", "NHWC"),
        feature_group_count=channels,
    )
    approx = approx[:, crop_lo:-crop_hi, crop_lo:-crop_hi, :]

    details = rearrange(
        details,
        "n h w (c h2 w2 c2) -> n (h h2) (w w2) (c c2)",
        h2=2,
        w2=2,
        c2=channels,
    )
    return jnp.concatenate([approx, details], axis=-1)
def wavelet_dec(x, filt, levels):
    """Do the DWT for a given number of levels.

    Applies ``wavelet_dec_once()`` ``levels`` times. Each pass decomposes the
    current approximation band and folds every older detail band to the new
    resolution, so the result is a single packed array.

    Args:
        x: Input image (NHWC layout). Height and width are expected to be
            divisible by ``2 ** levels``.
        filt: Filter bank (see ``get_filter_bank()``).
        levels: Number of levels.

    Returns:
        The packed DWT coefficients, with shape
        (N, H // 2 ** levels, W // 2 ** levels, C * 4 ** levels).
    """
    channels = x.shape[-1]
    coeffs = x
    for _ in range(levels):
        coeffs = wavelet_dec_once(coeffs, filt, channels)
    return coeffs
def wavelet_rec(x, filt, levels):
    """Do the IDWT for a given number of levels (inverse of ``wavelet_dec``).

    Args:
        x: Packed DWT coefficients in NHWC layout, as produced by
            ``wavelet_dec()``. The channel count must be a multiple of
            ``4 ** levels``.
        filt: Filter bank (see ``get_filter_bank()``). The synthesis filters
            ``filt[2]`` and ``filt[3]`` are used.
        levels: Number of levels to invert.

    Returns:
        The reconstructed image. If ``x`` has shape (N, H, W, K), the result
        has shape (N, H * 2 ** levels, W * 2 ** levels, K // 4 ** levels).

    Raises:
        ValueError: If the channel count of ``x`` is not divisible by
            ``4 ** levels``. Such input previously failed later inside the
            per-level reconstruction with an obscure error.
    """
    factor = 4**levels
    if x.shape[-1] % factor:
        raise ValueError(
            f"expected a channel count divisible by 4 ** levels = {factor}, "
            f"got {x.shape[-1]}"
        )
    channels = x.shape[-1] // factor
    # Each pass inverts one level, coarsest first. The packed layout always
    # keeps the current level's bands at the front, so no index is needed.
    for _ in range(levels):
        x = wavelet_rec_once(x, filt, channels)
    return x
def unpack(x, levels):
    """Split packed DWT coefficients into a pywavelets ``wavedec2()``-style list.

    Args:
        x: Packed DWT coefficients as returned by ``wavelet_dec()`` (NHWC).
        levels: Number of decomposition levels used to produce ``x``.

    Returns:
        A list holding the approximation array, followed by one tuple of three
        detail arrays per level. Levels run coarsest first, matching the order
        of ``pywt.wavedec2()``.
    """
    channels = x.shape[-1] // 4**levels
    coeffs = [x[..., :channels]]
    for level in range(levels):
        start = channels * 4**level
        stop = channels * 4 ** (level + 1)
        band = x[..., start:stop]
        # Finer-level details were folded space-to-depth once per later
        # decomposition pass, so unfold them ``level`` times.
        for _ in range(level):
            band = rearrange(
                band,
                "n h w (c h2 w2 c2) -> n (h h2) (w w2) (c c2)",
                h2=2,
                w2=2,
                c2=channels,
            )
        coeffs.append(tuple(jnp.split(band, 3, axis=-1)))
    return coeffs
def pack(x):
    """Pack a pywavelets ``wavedec2()``-style list into one DWT coefficients array.

    This is the inverse of ``unpack()``.

    Args:
        x: Sequence holding the approximation array followed by one tuple of
            three detail arrays per level, coarsest level first.

    Returns:
        A single NHWC array in the packed layout produced by ``wavelet_dec()``.
    """
    packed = x[0]
    for level, bands in enumerate(x[1:]):
        block = jnp.concatenate(bands, axis=-1)
        # Fold finer-level details down to the coarsest resolution, once per
        # level between them and the top of the list.
        for _ in range(level):
            block = rearrange(
                block,
                "n (h h2) (w w2) (c c2) -> n h w (c h2 w2 c2)",
                h2=2,
                w2=2,
                c2=x[0].shape[-1],
            )
        packed = jnp.concatenate([packed, block], axis=-1)
    return packed
def _benchmark(fn, n):
    """Return the mean wall-clock seconds per call of a zero-argument JAX callable.

    Args:
        fn: Callable taking no arguments and returning a JAX array.
        n: Number of timed calls (must be positive).

    Returns:
        Mean seconds per call over ``n`` calls.
    """
    import time

    # Warm up (triggers JIT compilation) and wait for the result, so that
    # neither compilation nor the still-running warm-up leaks into the timing.
    fn().block_until_ready()
    start = time.perf_counter()
    for _ in range(n):
        fn().block_until_ready()
    return (time.perf_counter() - start) / n


def _report(label, time_taken):
    """Print one timing line in seconds/iteration and iterations/second."""
    print(f"Time for {label}: {time_taken:g} s/it ({1 / time_taken:g} it/s)")


def main():
    """Benchmark the DWT/IDWT and check them against PyWavelets.

    Times the forward and backward (VJP) passes of both transforms and reports
    the round-trip reconstruction error. It also compares the coefficients with
    ``pywt.wavedec2(mode="periodization")`` and checks that ``pack()`` inverts
    ``unpack()``.
    """
    import argparse
    from functools import partial
    import math

    import skimage

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--batch-size", "-bs", type=int, default=1, help="the batch size"
    )
    parser.add_argument(
        "--channels", "-c", type=int, default=3, help="the number of channels"
    )
    parser.add_argument(
        "--dtype", type=jnp.dtype, default=jnp.dtype("float32"), help="the dtype"
    )
    parser.add_argument(
        "--levels", type=int, default=3, help="the number of decomposition levels"
    )
    parser.add_argument("-n", type=int, default=100, help="the number of iterations")
    parser.add_argument(
        "--size", type=int, nargs=2, default=(512, 512), help="the image size"
    )
    parser.add_argument("--wavelet", type=str, default="bior4.4", help="the wavelet")
    args = parser.parse_args()
    # Fail early with clear messages instead of a ZeroDivisionError in the
    # timing code or a shape error deep inside the transform.
    if args.n < 1:
        parser.error("-n must be at least 1")
    if any(s % 2**args.levels for s in args.size):
        parser.error(f"--size must be divisible by 2 ** levels = {2**args.levels}")

    print(f"Using dtype: {args.dtype}")
    print(f"Number of decomposition levels: {args.levels}")
    print(f"Number of iterations: {args.n}")
    print(f"Using wavelet: {args.wavelet}")
    filt = get_filter_bank(args.wavelet, args.dtype)
    print(f"Kernel size: {filt.shape[1]}x{filt.shape[1]}")

    # Test image: resize, then tile/crop along batch and channel axes.
    x = skimage.util.img_as_float32(skimage.data.astronaut())[None]
    x = jnp.array(x, args.dtype)
    x = jax.image.resize(
        x,
        (x.shape[0], args.size[0], args.size[1], x.shape[3]),
        jax.image.ResizeMethod.LINEAR,
    )
    x = jnp.tile(x, [args.batch_size, 1, 1, math.ceil(args.channels / x.shape[3])])
    x = x[..., : args.channels]
    print(f"Input shape: {x.shape}")

    # DWT forward pass.
    jit_down = jax.jit(partial(wavelet_dec, filt=filt, levels=args.levels))
    y = jit_down(x)
    _report("DWT forward", _benchmark(lambda: jit_down(x), args.n))

    # IDWT forward pass.
    jit_up = jax.jit(partial(wavelet_rec, filt=filt, levels=args.levels))
    z = jit_up(y)
    _report("IDWT forward", _benchmark(lambda: jit_up(y), args.n))

    # DWT backward pass (VJP of the forward transform; y is the cotangent).
    vjp_down = jax.jit(jax.vjp(jit_down, x)[1])
    _report("DWT backward", _benchmark(lambda: vjp_down(y)[0], args.n))

    # IDWT backward pass (VJP of the inverse transform; z is the cotangent).
    vjp_up = jax.jit(jax.vjp(jit_up, y)[1])
    _report("IDWT backward", _benchmark(lambda: vjp_up(z)[0], args.n))

    # Round-trip reconstruction error, computed in float32.
    mse = jnp.mean(jnp.square(jnp.float32(x - z)))
    rms = jnp.sqrt(mse)
    psnr = -10 * jnp.log10(mse)
    print(f"RMS reconstruction error: {rms.item():g} (PSNR: {psnr.item():g} dB)")

    # Compare with PyWavelets.
    y_pywt = pywt.wavedec2(
        x, args.wavelet, level=args.levels, mode="periodization", axes=(1, 2)
    )
    y_unpack = unpack(y, args.levels)
    # jax.tree_map is deprecated (and removed in recent JAX releases); use the
    # jax.tree_util spelling, matching the tree_reduce call below.
    sq_norms = jax.tree_util.tree_map(
        lambda ref, ours: jnp.sum((ref - ours) ** 2), y_pywt, y_unpack
    )
    mse = jax.tree_util.tree_reduce(jnp.add, sq_norms) / x.size
    rms = jnp.sqrt(mse)
    psnr = -10 * jnp.log10(mse)
    print(f"RMS diff from pywavelets: {rms.item():g} (PSNR: {psnr.item():g} dB)")

    # pack() must exactly invert unpack().
    y_pack = pack(y_unpack)
    assert jnp.all(y == y_pack)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment