Last active
November 25, 2021 05:38
-
-
Save louity/f0e5de7fee5791db265ef3ab83c2aeb3 to your computer and use it in GitHub Desktop.
PyTorch implementation of the two-dimensional type-I discrete sine transform (DST-I)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""DST-I using FFT routines, Louis Thiry.
Method 1 is 'naive' and uses FFTs on an input signal twice as large.
Method 2 is more sophisticated and uses an iRFFT on an input half the signal size;
it requires both dimensions of the input to be odd.
The naive method 1 nevertheless seems to be more efficient, and JIT compilation makes little difference.
"""
import numpy as np
import scipy.fft
import scipy.fftpack
import torch
def dstI2D_M1(x, tmp_arr1=None, tmp_arr2=None):
    """Two-dimensional type-I DST of ``x`` over its last two axes (method 1).

    Each 1-D pass embeds the length-N signal into the odd-symmetric signal
    ``[0, -x, 0, flip(x)]`` of length 2N+2, takes its real FFT and keeps the
    imaginary part of coefficients 1..N.  That equals
    ``2 * sum_n x[n] * sin(pi * (k+1) * (n+1) / (N+1))``, the unnormalized
    DST-I also computed by ``scipy.fft.dst(..., type=1)``.  Works for any
    sizes M and N.

    Args:
        x: real tensor of shape (..., M, N).
        tmp_arr1: optional scratch buffer of shape (..., M, 2N+2) with the
            dtype/device of ``x``; allocated when None.  Overwritten.
        tmp_arr2: optional scratch buffer of shape (..., 2M+2, N) with the
            dtype/device of ``x``; allocated when None.  Overwritten.

    Returns:
        Real tensor of shape (..., M, N).
    """
    M, N = x.shape[-2:]
    batch = x.shape[:-2]
    if tmp_arr1 is None:
        tmp_arr1 = x.new_zeros((*batch, M, 2*N + 2))
    if tmp_arr2 is None:
        tmp_arr2 = x.new_zeros((*batch, 2*M + 2, N))

    # Odd extension along the last axis.  The two gap entries must be zero;
    # write them explicitly so a reused or uninitialized buffer cannot leak
    # stale values into the transform.
    tmp_arr1[..., 0] = 0
    tmp_arr1[..., N+1] = 0
    tmp_arr1[..., 1:N+1] = -x
    tmp_arr1[..., N+2:] = torch.flip(x, dims=(-1,))
    dst1D_x = torch.fft.rfft(tmp_arr1, dim=-1)[..., 1:-1].imag

    # Same odd extension along axis -2, applied to the 1-D result.
    tmp_arr2[..., 0, :] = 0
    tmp_arr2[..., M+1, :] = 0
    tmp_arr2[..., 1:M+1, :] = -dst1D_x
    tmp_arr2[..., M+2:, :] = torch.flip(dst1D_x, dims=(-2,))
    return torch.fft.rfft(tmp_arr2, dim=-2)[..., 1:-1, :].imag
def dstI2D_M2(x, sin_vec1, sin_vec2, tmp_arr1, tmp_arr2, j_cplx):
    """Two-dimensional type-I DST of ``x`` over its last two axes (method 2).

    Each 1-D pass packs the length-N signal into (N+1)//2 + 1 complex
    coefficients, runs an unscaled inverse real FFT of length N+1 and
    recombines mirrored outputs with a sine weight.  The packing pairs even
    and odd samples, so it is only valid when both M and N are odd.  The
    result matches ``scipy.fft.dstn(x, type=1, axes=(-2, -1))``.

    Args:
        x: real tensor of shape (..., M, N) with M and N odd.
        sin_vec1: ``1 / (4 * sin(pi * k / (N+1)))`` for k = 1..N, shaped to
            broadcast along the last axis, e.g. (1, N).
        sin_vec2: ``1 / (4 * sin(pi * k / (M+1)))`` for k = 1..M, shaped to
            broadcast along axis -2, e.g. (M, 1).
        tmp_arr1: complex scratch buffer of shape (..., M, (N+1)//2 + 1).
            Overwritten.
        tmp_arr2: complex scratch buffer of shape (..., (M+1)//2 + 1, N).
            Overwritten.
        j_cplx: the imaginary unit as a complex tensor (same dtype as the
            scratch buffers).

    Returns:
        Real tensor of shape (..., M, N).

    Raises:
        ValueError: if M or N is even.
    """
    M, N = x.shape[-2:]
    # The even/odd packing needs N+1 and M+1 to be even.  With an even size
    # the shapes can still line up (e.g. N == 2 broadcasts d[..., 1:] against
    # sin_vec1) and silently return a wrong result, so fail loudly instead.
    if M % 2 == 0 or N % 2 == 0:
        raise ValueError(f'dstI2D_M2 requires odd M and N, got M={M}, N={N}')

    # Pass 1 (last axis): tmp[0] = 2*x[0], tmp[-1] = -2*x[N-1], and for
    # k = 1..(N-1)//2, tmp[k] = x[2k] - x[2k-2] - i*x[2k-1].
    tmp_arr1[...,0] = 2*x[...,0]
    tmp_arr1[...,-1] = -2*x[...,-1]
    tmp_arr1[...,1:-1] = x[...,2::2] - x[...,:-2:2] - j_cplx*x[...,1:-1:2]
    # Inverse real FFT of length N+1; norm='forward' leaves the inverse unscaled.
    d = torch.fft.irfft(tmp_arr1, dim=-1, norm='forward')
    # Combine d[m] with d[N+1-m] for m = 1..N: the antisymmetric part is
    # halved, the symmetric part is weighted by sin_vec1.
    d_flip = torch.flip(d[...,1:], dims=(-1,))
    dst1D_x = 0.5*(d[...,1:] - d_flip) + sin_vec1 * (d[...,1:] + d_flip)

    # Pass 2: the same transform along axis -2, applied to the pass-1 result.
    tmp_arr2[...,0,:] = 2*dst1D_x[...,0,:]
    tmp_arr2[...,-1,:] = -2*dst1D_x[...,-1,:]
    tmp_arr2[...,1:-1,:] = dst1D_x[...,2::2,:] - dst1D_x[...,:-2:2,:] - j_cplx*dst1D_x[...,1:-1:2,:]
    d = torch.fft.irfft(tmp_arr2, dim=-2, norm='forward')
    d_flip = torch.flip(d[...,1:,:], dims=(-2,))
    dst2D_x = 0.5*(d[...,1:,:] - d_flip) + sin_vec2 * (d[...,1:,:] + d_flip)
    return dst2D_x
def make_dst_buffers(x):
    """Build the extra arguments expected by ``dstI2D_M1`` and ``dstI2D_M2``.

    Args:
        x: real tensor of shape (M, N); every buffer reuses its dtype and
            device (complex buffers use the matching complex dtype).

    Returns:
        ``(m1_args, m2_args)`` with ``m1_args = (tmp_arr1, tmp_arr2)`` and
        ``m2_args = (sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx)``, so the
        transforms can be called as ``dstI2D_M1(x, *m1_args)`` and
        ``dstI2D_M2(x, *m2_args)``.
    """
    M, N = x.shape[-2:]
    real = dict(dtype=x.dtype, device=x.device)
    cplx_dtype = torch.complex64 if x.dtype == torch.float32 else torch.complex128
    cplx = dict(dtype=cplx_dtype, device=x.device)
    # pi must carry x's dtype: a default (float32) pi would limit float64
    # twiddle factors to roughly 1e-7 relative accuracy.
    pi = torch.tensor(np.pi, **real)
    tmp_arr1 = torch.zeros(M, 2*N + 2, **real)
    tmp_arr2 = torch.zeros(2*M + 2, N, **real)
    sin_vec1 = (1 / (4*torch.sin(torch.arange(1, N + 1, **real)*pi/(N + 1)))).reshape((1, N))
    sin_vec2 = (1 / (4*torch.sin(torch.arange(1, M + 1, **real)*pi/(M + 1)))).reshape((M, 1))
    tmp_arr3 = torch.zeros(M, (N + 1)//2 + 1, **cplx)
    tmp_arr4 = torch.zeros((M + 1)//2 + 1, N, **cplx)
    j_cplx = torch.tensor(1j, **cplx)
    return (tmp_arr1, tmp_arr2), (sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx)


if __name__ == '__main__':
    # Small correctness demo; sizes are odd because dstI2D_M2 requires it.
    M, N = 3, 7
    x = torch.from_numpy(np.random.rand(M, N))
    m1_args, m2_args = make_dst_buffers(x)
    with np.printoptions(suppress=True, precision=4):
        print(f'input x={x.cpu().numpy()}')
        print(f'scipy 2D DST-I={scipy.fft.dstn(x.cpu().numpy(), type=1, axes=(0,1))}')
        print(f' 2x fft 2D DST-I={dstI2D_M1(x, *m1_args).numpy()}')
        print(f'.5x ifft 2D DST-I={dstI2D_M2(x, *m2_args).numpy()}')

    from timeit import Timer
    import matplotlib.pyplot as plt

    # Benchmark: mean execution time per call on odd square inputs of growing size.
    Ns = [63, 127, 255, 511, 1023, 2047]
    numbers = [1000, 1000, 500, 250, 100, 100]
    scipy_dst, M1, M2, M1_jit, M2_jit = [], [], [], [], []
    device = 'cpu' # 'cuda'
    # CUDA kernels run asynchronously; synchronize after each call so Timer
    # measures the computation rather than only the kernel launch.
    sync = torch.cuda.synchronize if device.startswith('cuda') else (lambda: None)

    def mean_time(fn, number):
        """Return the mean wall-clock time in seconds of ``fn()`` over ``number`` calls."""
        def call():
            fn()
            sync()
        return Timer(call).timeit(number=number) / number

    for N, number in zip(Ns, numbers):
        print(N)
        M = N
        x = torch.from_numpy(np.random.rand(M, N)).to(device)
        m1_args, m2_args = make_dst_buffers(x)
        dstI2D_M1_jit = torch.jit.trace(dstI2D_M1, (x, *m1_args))
        dstI2D_M2_jit = torch.jit.trace(dstI2D_M2, (x, *m2_args))
        x_np = x.cpu().numpy()
        # Each lambda is timed immediately, so late binding of loop variables is harmless.
        scipy_dst.append(mean_time(lambda: scipy.fft.dstn(x_np, type=1, axes=(-2,-1)), number))
        M1.append(mean_time(lambda: dstI2D_M1(x, *m1_args), number))
        M2.append(mean_time(lambda: dstI2D_M2(x, *m2_args), number))
        M1_jit.append(mean_time(lambda: dstI2D_M1_jit(x, *m1_args), number))
        M2_jit.append(mean_time(lambda: dstI2D_M2_jit(x, *m2_args), number))

    plt.figure()
    curves = [
        (scipy_dst, 'scipy'),
        (M1, 'M1 2x'),
        (M2, 'M2 .5x'),
        (M1_jit, 'M1 2x jit'),
        (M2_jit, 'M2 .5x jit'),
    ]
    for times, label in curves:
        plt.plot(Ns, times, label=label)
    # loglog(basex=..., basey=...) was removed in Matplotlib 3.5; `base` is the current keyword.
    plt.xscale('log', base=2)
    plt.yscale('log', base=10)
    plt.legend()
    plt.xlabel('input size')
    plt.ylabel('mean exec. time (sec)')
    plt.title(f'Pytorch implem 2D of typeI DST exec. time on {device}')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment