Last active
February 10, 2023 17:56
-
-
Save crowsonkb/140b5c433b205edb969e80bfc1b85730 to your computer and use it in GitHub Desktop.
JAX implementation of the 2D DWT and IDWT.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""JAX implementation of the 2D DWT and IDWT.""" | |
from einops import rearrange | |
import jax | |
import jax.numpy as jnp | |
import pywt | |
def get_filter_bank(wavelet, dtype=jnp.float32):
    """Return the filter bank of a named PyWavelets wavelet as a JAX array.

    Args:
        wavelet: Wavelet name, passed straight to ``pywt.Wavelet``.
        dtype: dtype of the returned array.

    Returns:
        The wavelet's ``filter_bank`` as an array. The rest of this file uses
        rows 0 and 1 as the decomposition lowpass/highpass filters and rows 2
        and 3 as the reconstruction lowpass/highpass filters.
    """
    bank = pywt.Wavelet(wavelet).filter_bank
    return jnp.array(bank, dtype)
def make_kernel(lo, hi):
    """Build the four 2D subband kernels from a 1D lowpass/highpass pair.

    Both filters are reversed, then every (row filter, column filter)
    combination is formed as an outer product.

    Args:
        lo: 1D lowpass filter.
        hi: 1D highpass filter (same length as ``lo``).

    Returns:
        An array of shape (4, 1, K, K), K being the filter length, holding the
        kernels in the order lo x lo, hi x lo, lo x hi, hi x hi (first factor
        along rows, second along columns).
    """
    lo_rev, hi_rev = jnp.flip(lo), jnp.flip(hi)
    # (row filter, column filter) for each subband, in output order.
    pairs = (
        (lo_rev, lo_rev),
        (hi_rev, lo_rev),
        (lo_rev, hi_rev),
        (hi_rev, hi_rev),
    )
    bands = jnp.stack([jnp.outer(rows, cols) for rows, cols in pairs], 0)
    # Insert the singleton axis at position 1 -> (4, 1, K, K).
    return bands[:, None]
def wavelet_dec_once(x, filt, channels):
    """Run a single DWT level on the leading ``channels`` channels of ``x``.

    Args:
        x: NHWC array. The first ``channels`` channels are the current
            approximation; any remaining channels are detail coefficients
            produced by earlier levels.
        filt: Filter bank; rows 0 and 1 are used as the lowpass and highpass
            decomposition filters.
        channels: Number of approximation channels to transform.

    Returns:
        An NHWC array at half the spatial resolution (for even H and W): the
        four subbands of the approximation followed by the repacked existing
        detail channels.
    """
    approx, detail = x[..., :channels], x[..., channels:]

    # One copy of the four subband kernels per channel, for the grouped conv.
    bank = jnp.tile(make_kernel(filt[0], filt[1]), [channels, 1, 1, 1])

    # Circular padding of (kernel size - 1) samples per spatial axis; when that
    # total is odd the extra sample goes on the trailing side.
    before, odd = divmod(bank.shape[-1] - 1, 2)
    spatial_pad = (before, before + odd)
    approx = jnp.pad(approx, ((0, 0), spatial_pad, spatial_pad, (0, 0)), "wrap")

    approx = jax.lax.conv_general_dilated(
        lhs=approx,
        rhs=bank,
        window_strides=(2, 2),
        padding=((0, 0), (0, 0)),
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
        feature_group_count=channels,
    )
    # The grouped conv emits channels as (channel, subband); reorder them to
    # (subband, channel) so the new approximation occupies the first channels.
    approx = rearrange(approx, "n h w (c1 c2) -> n h w (c2 c1)", c2=4)

    # Fold each 2x2 spatial block of the older detail coefficients into the
    # channel axis so they match the halved resolution.
    detail = rearrange(
        detail,
        "n (h h2) (w w2) (c c2) -> n h w (c h2 w2 c2)",
        h2=2,
        w2=2,
        c2=channels,
    )
    return jnp.concatenate([approx, detail], axis=-1)
def wavelet_rec_once(x, filt, channels):
    """Run a single IDWT level on the leading ``4 * channels`` channels of ``x``.

    Args:
        x: NHWC array. The first ``4 * channels`` channels are the four
            subbands to merge; any remaining channels are detail coefficients
            belonging to finer levels.
        filt: Filter bank; rows 2 and 3 are used as the lowpass and highpass
            reconstruction filters.
        channels: Number of channels of the reconstructed approximation.

    Returns:
        An NHWC array at twice the spatial resolution: the reconstructed
        approximation followed by the unfolded remaining detail channels.
    """
    subband_count = channels * 4
    coarse, detail = x[..., :subband_count], x[..., subband_count:]

    bank = jnp.tile(make_kernel(filt[2], filt[3]), [1, channels, 1, 1])

    # Zero padding handed to the dilated convolution.
    conv_lo, odd = divmod(bank.shape[-1], 2)
    conv_hi = conv_lo + odd
    # Circular padding applied beforehand at the coarse resolution ...
    wrap_lo, wrap_odd = divmod(conv_lo, 2)
    wrap_hi = wrap_lo + wrap_odd
    # ... and the matching crop afterwards, doubled to the fine resolution.
    crop_lo, crop_hi = 2 * wrap_lo, 2 * wrap_hi

    # Go from (subband, channel) channel order back to (channel, subband) so
    # each conv group sees the four subbands of one channel.
    coarse = rearrange(coarse, "n h w (c1 c2) -> n h w (c2 c1)", c1=4)
    wrap_pad = (wrap_lo, wrap_hi)
    coarse = jnp.pad(coarse, ((0, 0), wrap_pad, wrap_pad, (0, 0)), "wrap")
    conv_pad = (conv_lo, conv_hi)
    coarse = jax.lax.conv_general_dilated(
        lhs=coarse,
        rhs=bank,
        window_strides=(1, 1),
        padding=(conv_pad, conv_pad),
        lhs_dilation=(2, 2),
        dimension_numbers=("NHWC", "IOHW", "NHWC"),
        feature_group_count=channels,
    )
    coarse = coarse[:, crop_lo:-crop_hi, crop_lo:-crop_hi, :]

    # Unfold the remaining detail coefficients back into 2x2 spatial blocks.
    detail = rearrange(
        detail,
        "n h w (c h2 w2 c2) -> n (h h2) (w w2) (c c2)",
        h2=2,
        w2=2,
        c2=channels,
    )
    return jnp.concatenate([coarse, detail], axis=-1)
def wavelet_dec(x, filt, levels):
    """Compute a multi-level 2D DWT by applying ``wavelet_dec_once`` repeatedly.

    Args:
        x: Input image (NHWC layout).
        filt: Filter bank.
        levels: Number of decomposition levels; 0 returns ``x`` unchanged.

    Returns:
        The DWT coefficients, with shape
        (N, H // 2 ** levels, W // 2 ** levels, C * 4 ** levels).
    """
    # Every level transforms the same number of leading channels: those of the
    # original image.
    image_channels = x.shape[-1]
    coeffs = x
    for _ in range(levels):
        coeffs = wavelet_dec_once(coeffs, filt, image_channels)
    return coeffs
def wavelet_rec(x, filt, levels):
    """Do the IDWT for a given number of levels.

    Args:
        x: Input array of DWT coefficients (NHWC layout), with a channel count
            that is a multiple of ``4 ** levels``.
        filt: Filter bank.
        levels: Number of levels; 0 returns ``x`` unchanged.

    Returns:
        The reconstructed array, with shape
        (N, H * 2 ** levels, W * 2 ** levels, C // 4 ** levels),
        where (N, H, W, C) is the shape of ``x``.

    Raises:
        ValueError: If the channel count of ``x`` is not a multiple of
            ``4 ** levels``.
    """
    total_channels = x.shape[-1]
    block = 4**levels
    # Fail early with a clear message instead of letting the floor division
    # below silently pick a wrong channel count that breaks deep inside the
    # per-level reconstruction.
    if total_channels % block:
        raise ValueError(
            f"the number of channels ({total_channels}) must be a multiple of "
            f"4 ** levels ({block})"
        )
    channels = total_channels // block
    for _ in range(levels):
        x = wavelet_rec_once(x, filt, channels)
    return x
def unpack(x, levels):
    """Unpack the DWT coefficients into a pywavelets wavedec2() coefficients list.

    Args:
        x: Input array of DWT coefficients (NHWC layout), with a channel count
            that is a multiple of ``4 ** levels``.
        levels: Number of levels.

    Returns:
        A pywavelets wavedec2() coefficients list: the approximation array
        followed by one 3-tuple of detail arrays per level, coarsest level
        first.

    Raises:
        ValueError: If the channel count of ``x`` is not a multiple of
            ``4 ** levels``.
    """
    total_channels = x.shape[-1]
    block = 4**levels
    # Without this check the floor division below would silently drop the
    # trailing channels and return a truncated coefficient list.
    if total_channels % block:
        raise ValueError(
            f"the number of channels ({total_channels}) must be a multiple of "
            f"4 ** levels ({block})"
        )
    channels = total_channels // block
    coeffs = [x[..., :channels]]
    for i in range(levels):
        # Packed detail coefficients of the i-th coarsest level.
        detail = x[..., channels * 4**i : channels * 4 ** (i + 1)]
        # They were folded into the channel axis once per coarser level;
        # unfold them back to their own spatial resolution.
        for _ in range(i):
            detail = rearrange(
                detail,
                "n h w (c h2 w2 c2) -> n (h h2) (w w2) (c c2)",
                h2=2,
                w2=2,
                c2=channels,
            )
        coeffs.append(tuple(jnp.split(detail, 3, axis=-1)))
    return coeffs
def pack(x):
    """Pack a pywavelets wavedec2() coefficients list into one coefficients array.

    Args:
        x: Input pywavelets wavedec2() coefficients list: the approximation
            array followed by one tuple of detail arrays per level (NHWC).

    Returns:
        A DWT coefficients array, with all levels concatenated along the
        channel axis at the resolution of the approximation.
    """
    packed = x[0]
    for depth, bands in enumerate(x[1:]):
        merged = jnp.concatenate(bands, axis=-1)
        # Finer levels have a higher resolution; fold 2x2 spatial blocks into
        # the channel axis once per level of difference.
        for _ in range(depth):
            merged = rearrange(
                merged,
                "n (h h2) (w w2) (c c2) -> n h w (c h2 w2 c2)",
                h2=2,
                w2=2,
                c2=x[0].shape[-1],
            )
        packed = jnp.concatenate([packed, merged], axis=-1)
    return packed
def main():
    """Benchmark the DWT/IDWT and check them against PyWavelets.

    Parses the command line, builds a test batch from the scikit-image
    astronaut image, times the forward and backward passes of both transforms,
    then prints the reconstruction error, the difference from
    ``pywt.wavedec2`` in periodization mode, and checks that ``pack`` undoes
    ``unpack``.
    """
    # Test case
    import argparse
    from functools import partial
    import math
    import time

    import skimage

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--batch-size", "-bs", type=int, default=1, help="the batch size"
    )
    parser.add_argument(
        "--channels", "-c", type=int, default=3, help="the number of channels"
    )
    parser.add_argument(
        "--dtype", type=jnp.dtype, default=jnp.dtype("float32"), help="the dtype"
    )
    parser.add_argument(
        "--levels", type=int, default=3, help="the number of decomposition levels"
    )
    parser.add_argument("-n", type=int, default=100, help="the number of iterations")
    parser.add_argument(
        "--size", type=int, nargs=2, default=(512, 512), help="the image size"
    )
    parser.add_argument("--wavelet", type=str, default="bior4.4", help="the wavelet")
    args = parser.parse_args()

    print(f"Using dtype: {args.dtype}")
    print(f"Number of decomposition levels: {args.levels}")
    print(f"Number of iterations: {args.n}")
    print(f"Using wavelet: {args.wavelet}")
    filt = get_filter_bank(args.wavelet, args.dtype)
    print(f"Kernel size: {filt.shape[1]}x{filt.shape[1]}")

    # Build the input batch: resize the image, then tile it to the requested
    # batch size and channel count.
    x = skimage.util.img_as_float32(skimage.data.astronaut())[None]
    x = jnp.array(x, args.dtype)
    x = jax.image.resize(
        x,
        (x.shape[0], args.size[0], args.size[1], x.shape[3]),
        jax.image.ResizeMethod.LINEAR,
    )
    x = jnp.tile(x, [args.batch_size, 1, 1, math.ceil(args.channels / x.shape[3])])
    x = x[..., : args.channels]
    print(f"Input shape: {x.shape}")

    def benchmark(label, fn):
        """Time ``fn`` over ``args.n`` calls, print the rate, return the last result."""
        # Warm-up call. ``fn`` blocks until its result is ready, so neither
        # compilation nor pending async work leaks into the timed loop.
        result = fn()
        start = time.perf_counter()
        for _ in range(args.n):
            result = fn()
        time_taken = (time.perf_counter() - start) / args.n
        print(f"Time for {label}: {time_taken:g} s/it ({1 / time_taken:g} it/s)")
        return result

    # Benchmark DWT forward pass
    jit_down = jax.jit(partial(wavelet_dec, filt=filt, levels=args.levels))
    y = benchmark("DWT forward", lambda: jit_down(x).block_until_ready())

    # Benchmark IDWT forward pass
    jit_up = jax.jit(partial(wavelet_rec, filt=filt, levels=args.levels))
    z = benchmark("IDWT forward", lambda: jit_up(y).block_until_ready())

    # Benchmark DWT backward pass
    vjp_down = jax.jit(jax.vjp(jit_down, x)[1])
    benchmark("DWT backward", lambda: vjp_down(y)[0].block_until_ready())

    # Benchmark IDWT backward pass
    vjp_up = jax.jit(jax.vjp(jit_up, y)[1])
    benchmark("IDWT backward", lambda: vjp_up(z)[0].block_until_ready())

    # Compute reconstruction error
    mse = jnp.mean(jnp.square(jnp.float32(x - z)))
    rms = jnp.sqrt(mse)
    psnr = -10 * jnp.log10(mse)
    print(f"RMS reconstruction error: {rms.item():g} (PSNR: {psnr.item():g} dB)")

    # Compare with PyWavelets
    y_pywt = pywt.wavedec2(
        x, args.wavelet, level=args.levels, mode="periodization", axes=(1, 2)
    )
    y_unpack = unpack(y, args.levels)
    # jax.tree_map is deprecated (and removed in newer JAX releases); use the
    # jax.tree_util spelling, like the tree_reduce call below.
    sq_norms = jax.tree_util.tree_map(
        lambda a, b: jnp.sum((a - b) ** 2), y_pywt, y_unpack
    )
    mse = jax.tree_util.tree_reduce(jnp.add, sq_norms) / x.size
    rms = jnp.sqrt(mse)
    psnr = -10 * jnp.log10(mse)
    print(f"RMS diff from pywavelets: {rms.item():g} (PSNR: {psnr.item():g} dB)")

    # Test pack()
    y_pack = pack(y_unpack)
    assert jnp.all(y == y_pack), "pack(unpack(y)) does not reproduce y"


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment