Skip to content

Instantly share code, notes, and snippets.

@louity
Last active November 25, 2021 05:38
Show Gist options
  • Save louity/f0e5de7fee5791db265ef3ab83c2aeb3 to your computer and use it in GitHub Desktop.
Save louity/f0e5de7fee5791db265ef3ab83c2aeb3 to your computer and use it in GitHub Desktop.
PyTorch implementation of the two-dimensional type-I discrete sine transform (DST-I)
"""DST I using FFT routines, Louis Thiry
Method 1 is 'naive' and uses FFTs of an input signal twice as large.
Method 2 is more sophisticated and uses inverse real FFTs (irfft) of half the input signal size.
Nevertheless, the naive method 1 seems to be more efficient, and JIT compilation makes little difference.
"""
import numpy as np
import scipy.fftpack
import torch
def dstI2D_M1(x, tmp_arr1=None, tmp_arr2=None):
    """Compute the 2D type-I DST over the last two dims of ``x`` via double-length real FFTs.

    A 1D DST-I of length ``L`` is read off the real FFT of the odd extension
    ``[0, -x, 0, flip(x)]`` (length ``2L+2``): the imaginary part of bins
    ``1..L`` equals ``2 * sum_n x[n] * sin(pi*(n+1)*(k+1)/(L+1))``. This is the
    unnormalized DST-I computed by ``scipy.fft.dstn(x, type=1, axes=(-2, -1))``.
    The transform is applied along the last axis first, then along the
    second-to-last axis.

    Args:
        x: real tensor of shape ``(..., M, N)``.
        tmp_arr1: optional scratch buffer of shape ``(..., M, 2*N+2)`` with the
            dtype/device of ``x``. Columns ``0`` and ``N+1`` are never written
            and must already be zero. Allocated (zero-filled) when ``None``.
        tmp_arr2: optional scratch buffer of shape ``(..., 2*M+2, N)``. Rows
            ``0`` and ``M+1`` must be zero. Allocated when ``None``.

    Returns:
        Real tensor of shape ``(..., M, N)``. Scratch buffers passed in are
        overwritten (except the zero slots noted above).
    """
    M, N = x.shape[-2:]
    if tmp_arr1 is None:
        tmp_arr1 = x.new_zeros((*x.shape[:-1], 2*N + 2))
    if tmp_arr2 is None:
        tmp_arr2 = x.new_zeros((*x.shape[:-2], 2*M + 2, N))

    # DST-I along the last axis: odd extension [0, -x, 0, flip(x)] -> Im(rfft).
    tmp_arr1[..., 1:N+1] = -x
    tmp_arr1[..., N+2:] = torch.flip(x, dims=(-1,))
    dst1D_x = torch.fft.rfft(tmp_arr1, dim=-1)[..., 1:-1].imag

    # Same construction along the second-to-last axis.
    tmp_arr2[..., 1:M+1, :] = -dst1D_x
    tmp_arr2[..., M+2:, :] = torch.flip(dst1D_x, dims=(-2,))
    return torch.fft.rfft(tmp_arr2, dim=-2)[..., 1:-1, :].imag
def _dstI_half_weights(L, dtype, device):
    """Return ``1 / (4*sin(pi*k/(L+1)))`` for ``k = 1..L`` as a 1D tensor."""
    k = torch.arange(1, L + 1, dtype=dtype, device=device)
    return 1 / (4*torch.sin(k*np.pi/(L + 1)))


def dstI2D_M2(x, sin_vec1=None, sin_vec2=None, tmp_arr1=None, tmp_arr2=None,
              j_cplx=None):
    """Compute the 2D type-I DST over the last two dims of ``x`` via half-length inverse real FFTs.

    For each 1D signal of (odd) length ``L``, the samples are packed into a
    complex half-spectrum of length ``(L+1)//2 + 1``:
    ``[2*x[0], x[2m] - x[2m-2] - i*x[2m-1] ..., -2*x[L-1]]``.
    An unscaled (``norm='forward'``) irfft of length ``L+1`` gives ``d``.
    Then ``d[k]`` and ``d[L+1-k]`` are recombined with the weights
    ``1/(4 sin(pi k/(L+1)))``. The result matches
    ``scipy.fft.dstn(x, type=1, axes=(-2, -1))``. The last axis is transformed
    first, then the second-to-last axis.

    Args:
        x: real tensor of shape ``(..., M, N)`` with ``M`` and ``N`` odd.
        sin_vec1: weights ``1/(4 sin(pi k/(N+1)))``, ``k=1..N``, shape ``(1, N)``.
        sin_vec2: weights ``1/(4 sin(pi k/(M+1)))``, ``k=1..M``, shape ``(M, 1)``.
        tmp_arr1: complex scratch buffer of shape ``(..., M, (N+1)//2 + 1)``.
        tmp_arr2: complex scratch buffer of shape ``(..., (M+1)//2 + 1, N)``.
        j_cplx: the imaginary unit as a complex tensor.
        Every argument left as ``None`` is built from ``x``'s shape, dtype and
        device. Scratch buffers are fully overwritten on each call.

    Returns:
        Real tensor of shape ``(..., M, N)``.

    Raises:
        ValueError: if ``M`` or ``N`` is even. For an even length ``L``, the
            irfft below has length ``L`` instead of ``L+1``. This causes a
            broadcasting error or, for ``L == 2``, a silently wrong result.
    """
    M, N = x.shape[-2:]
    if M % 2 == 0 or N % 2 == 0:
        raise ValueError(
            f'dstI2D_M2 requires odd sizes along the last two dims, got (M, N)=({M}, {N})')
    cplx_dtype = torch.complex64 if x.dtype == torch.float32 else torch.complex128
    if sin_vec1 is None:
        sin_vec1 = _dstI_half_weights(N, x.dtype, x.device).reshape(1, N)
    if sin_vec2 is None:
        sin_vec2 = _dstI_half_weights(M, x.dtype, x.device).reshape(M, 1)
    if tmp_arr1 is None:
        tmp_arr1 = torch.zeros((*x.shape[:-1], (N + 1)//2 + 1), dtype=cplx_dtype, device=x.device)
    if tmp_arr2 is None:
        tmp_arr2 = torch.zeros((*x.shape[:-2], (M + 1)//2 + 1, N), dtype=cplx_dtype, device=x.device)
    if j_cplx is None:
        j_cplx = torch.tensor(1j, dtype=cplx_dtype, device=x.device)

    # Stage 1: DST-I along the last axis (length N).
    tmp_arr1[..., 0] = 2*x[..., 0]
    tmp_arr1[..., -1] = -2*x[..., -1]
    tmp_arr1[..., 1:-1] = x[..., 2::2] - x[..., :-2:2] - j_cplx*x[..., 1:-1:2]
    d = torch.fft.irfft(tmp_arr1, dim=-1, norm='forward')  # length N+1, no 1/n scaling
    d_flip = torch.flip(d[..., 1:], dims=(-1,))  # d_flip[k-1] == d[N+1-k]
    dst1D_x = 0.5*(d[..., 1:] - d_flip) + sin_vec1*(d[..., 1:] + d_flip)

    # Stage 2: same recombination along the second-to-last axis (length M).
    tmp_arr2[..., 0, :] = 2*dst1D_x[..., 0, :]
    tmp_arr2[..., -1, :] = -2*dst1D_x[..., -1, :]
    tmp_arr2[..., 1:-1, :] = (dst1D_x[..., 2::2, :] - dst1D_x[..., :-2:2, :]
                              - j_cplx*dst1D_x[..., 1:-1:2, :])
    d = torch.fft.irfft(tmp_arr2, dim=-2, norm='forward')
    d_flip = torch.flip(d[..., 1:, :], dims=(-2,))
    return 0.5*(d[..., 1:, :] - d_flip) + sin_vec2*(d[..., 1:, :] + d_flip)
if __name__ == '__main__':
    # The script calls scipy.fft.dstn, but the file header only imports
    # scipy.fftpack; import scipy.fft explicitly instead of relying on that
    # import pulling it in as a side effect.
    import scipy.fft

    # ---- Sanity check on a small input: print scipy's DST-I next to both methods.
    # Sizes are odd because method 2 (dstI2D_M2) only supports odd M and N.
    M, N = 3, 7
    x = torch.from_numpy(np.random.rand(M, N))
    # Build pi in x's dtype: torch.tensor(np.pi) alone would be float32 and
    # would reduce the precision of the float64 sine weights below.
    pi = torch.tensor(np.pi, dtype=x.dtype)
    # Zero scratch buffers for method 1 (odd extensions of length 2N+2 / 2M+2).
    tmp_arr1 = torch.zeros(M, 2*N+2, dtype=x.dtype)
    tmp_arr2 = torch.zeros(2*M+2, N, dtype=x.dtype)
    # Weights and complex scratch buffers for method 2.
    sin_vec1 = 1 / (4*torch.sin(torch.arange(1, N+1, dtype=x.dtype)*pi/(N+1))).reshape((1, N))
    sin_vec2 = 1 / (4*torch.sin(torch.arange(1, M+1, dtype=x.dtype)*pi/(M+1))).reshape((M, 1))
    cplx_dtype = (torch.complex64 if x.dtype == torch.float32 else torch.complex128)
    tmp_arr3 = torch.zeros(M, (N+1)//2+1, dtype=cplx_dtype)
    tmp_arr4 = torch.zeros((M+1)//2+1, N, dtype=cplx_dtype)
    j_cplx = torch.tensor(1j, dtype=cplx_dtype)
    with np.printoptions(suppress=True, precision=4):
        print(f'input x={x.cpu().numpy()}')
        print(f'scipy 2D DST-I={scipy.fft.dstn(x.cpu().numpy(), type=1, axes=(0,1))}')
        print(f' 2x fft 2D DST-I={dstI2D_M1(x, tmp_arr1, tmp_arr2).numpy()}')
        print(f'.5x ifft 2D DST-I={dstI2D_M2(x, sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx).numpy()}')

    # ---- Benchmark: mean execution time vs. input size (square inputs, M == N).
    from timeit import Timer
    import matplotlib.pyplot as plt

    def _mean_time(fn, number, on_cuda=False):
        """Return the mean wall-clock time in seconds of ``fn()`` over ``number`` calls.

        CUDA kernels run asynchronously. On GPU, the device is therefore
        synchronized before timing and after every call; otherwise only the
        kernel launch overhead would be measured.
        """
        if not on_cuda:
            return Timer(fn).timeit(number=number) / number

        def synced():
            fn()
            torch.cuda.synchronize()

        torch.cuda.synchronize()
        return Timer(synced).timeit(number=number) / number

    Ns = [63, 127, 255, 511, 1023, 2047]
    numbers = [1000, 1000, 500, 250, 100, 100]
    scipy_dst, M1, M2, M1_jit, M2_jit = [], [], [], [], []
    device = 'cpu'  # 'cuda'
    for N, number in zip(Ns, numbers):
        print(N)
        M = N
        x = torch.from_numpy(np.random.rand(M, N)).to(device)
        on_cuda = x.is_cuda
        pi = torch.tensor(np.pi, dtype=x.dtype, device=x.device)
        tmp_arr1 = torch.zeros(M, 2*N+2, dtype=x.dtype, device=x.device)
        tmp_arr2 = torch.zeros(2*M+2, N, dtype=x.dtype, device=x.device)
        sin_vec1 = 1 / (4*torch.sin(torch.arange(1, N+1, dtype=x.dtype, device=x.device)*pi/(N+1))).reshape((1, N))
        sin_vec2 = 1 / (4*torch.sin(torch.arange(1, M+1, dtype=x.dtype, device=x.device)*pi/(M+1))).reshape((M, 1))
        cplx_dtype = (torch.complex64 if x.dtype == torch.float32 else torch.complex128)
        tmp_arr3 = torch.zeros(M, (N+1)//2+1, dtype=cplx_dtype, device=x.device)
        tmp_arr4 = torch.zeros((M+1)//2+1, N, dtype=cplx_dtype, device=x.device)
        j_cplx = torch.tensor(1j, dtype=cplx_dtype, device=x.device)
        dstI2D_M1_jit = torch.jit.trace(dstI2D_M1, (x, tmp_arr1, tmp_arr2))
        dstI2D_M2_jit = torch.jit.trace(dstI2D_M2, (x, sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx))
        x_np = x.cpu().numpy()
        # The lambdas capture this iteration's tensors and are timed immediately,
        # so late binding of the loop variables is not an issue.
        scipy_dst.append(_mean_time(lambda: scipy.fft.dstn(x_np, type=1, axes=(-2,-1)), number))
        M1.append(_mean_time(lambda: dstI2D_M1(x, tmp_arr1, tmp_arr2), number, on_cuda))
        M2.append(_mean_time(lambda: dstI2D_M2(x, sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx), number, on_cuda))
        M1_jit.append(_mean_time(lambda: dstI2D_M1_jit(x, tmp_arr1, tmp_arr2), number, on_cuda))
        M2_jit.append(_mean_time(lambda: dstI2D_M2_jit(x, sin_vec1, sin_vec2, tmp_arr3, tmp_arr4, j_cplx), number, on_cuda))

    plt.figure()
    plt.plot(Ns, scipy_dst, label='scipy')
    plt.plot(Ns, M1, label='M1 2x')
    plt.plot(Ns, M2, label='M2 .5x')
    plt.plot(Ns, M1_jit, label='M1 2x jit')
    plt.plot(Ns, M2_jit, label='M2 .5x jit')
    # loglog(basex=..., basey=...) was deprecated in Matplotlib 3.3 and later
    # removed; set the log scales with the `base` keyword instead.
    plt.xscale('log', base=2)
    plt.yscale('log', base=10)
    plt.legend()
    plt.xlabel('input size')
    plt.ylabel('mean exec. time (sec)')
    plt.title(f'Pytorch implem 2D of typeI DST exec. time on {device}')
    plt.show()
@louity
Copy link
Author

louity commented Nov 25, 2021

Capture d’écran 2021-11-25 à 06 32 41

@louity
Copy link
Author

louity commented Nov 25, 2021

Capture d’écran 2021-11-25 à 06 36 09

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment