Skip to content

Instantly share code, notes, and snippets.

@louity
Created May 17, 2022 10:25
Show Gist options
  • Save louity/f51ec9c5ffea7dbf13e6f182e73d29d4 to your computer and use it in GitHub Desktop.
Save louity/f51ec9c5ffea7dbf13e6f182e73d29d4 to your computer and use it in GitHub Desktop.
Type-II DCT and DST with PyTorch. Note that the inverse DCT-II is a DCT-III up to a normalizing constant, and the inverse DST-II is likewise a DST-III up to a normalizing constant.
import numpy as np
import scipy.fft
import scipy.fftpack
import torch
np.set_printoptions(precision=4, linewidth=200)

# Transform length used by the demo below.
N = 8
# Random float64 test signal of length N (previously hard-coded to 8).
x = torch.empty(N, dtype=torch.float64).normal_()
# Twiddle factors, built from a float64 frequency grid so they carry the same
# precision as x. torch.arange(N) alone is int64, and multiplying it by Python
# float/complex scalars promotes to float32/complex64, which would cap the
# transforms' accuracy near 1e-7 even for double-precision input.
exp_vec_1 = 2 * torch.exp(-1j * torch.pi * torch.arange(N, dtype=torch.float64) / (2 * N))  # forward DCT-II / DST-II
exp_vec_2 = torch.exp(1j * torch.pi * torch.arange(N, dtype=torch.float64) / (2 * N))  # inverse DCT-II / DST-II
def dctII_pt(x, exp_vec=None):
    """Unnormalized type-II DCT along the last axis, computed with one length-N FFT.

    Matches ``scipy.fft.dct(x, type=2)`` (``norm=None``), i.e.
    ``y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``.

    Uses Makhoul's reordering: the even-indexed samples in order, followed by
    the odd-indexed samples reversed. An FFT and a per-frequency twiddle then
    yield the DCT.

    Args:
        x: real tensor. The transform runs over the last dimension; leading
            dimensions are treated as a batch. Any length N >= 1 is accepted.
        exp_vec: optional precomputed twiddle ``2 * exp(-1j*pi*k/(2N))`` for
            ``k = 0..N-1``. When omitted it is computed in the input's precision.

    Returns:
        Real tensor with the same shape as ``x``.
    """
    N = x.shape[-1]
    # Even samples ascending, then odd samples descending. Slicing the odd
    # samples before flipping keeps the length at N for odd N too.
    v = torch.cat([x[..., ::2], torch.flip(x[..., 1::2], dims=(-1,))], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    if exp_vec is None:
        k = torch.arange(N, device=x.device, dtype=torch.promote_types(x.dtype, torch.float32))
        exp_vec = 2 * torch.exp(-1j * torch.pi * k / (2 * N))
    return (V * exp_vec).real
def dstII_pt(x, exp_vec=None):
    """Unnormalized type-II DST along the last axis, computed with one length-N FFT.

    Matches ``scipy.fft.dst(x, type=2)`` (``norm=None``), i.e.
    ``y[k] = 2 * sum_n x[n] * sin(pi * (k + 1) * (2n + 1) / (2N))``.

    Uses the identity ``DST-II(x)[k] == DCT-II(x * (-1)**n)[N-1-k]``. The
    sign-alternated signal is Makhoul-reordered exactly as in the DCT
    (even samples ascending, odd samples negated and descending). The result
    is then reversed along the last axis.

    Args:
        x: real tensor. The transform runs over the last dimension; leading
            dimensions are treated as a batch. Any length N >= 1 is accepted.
        exp_vec: optional precomputed twiddle ``2 * exp(-1j*pi*k/(2N))`` for
            ``k = 0..N-1``. When omitted it is computed in the input's precision.

    Returns:
        Real tensor with the same shape as ``x``.
    """
    N = x.shape[-1]
    # Odd-indexed samples carry the (-1)**n sign flip. Slicing them before
    # flipping keeps the length at N for odd N too.
    v = torch.cat([x[..., ::2], -torch.flip(x[..., 1::2], dims=(-1,))], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    if exp_vec is None:
        k = torch.arange(N, device=x.device, dtype=torch.promote_types(x.dtype, torch.float32))
        exp_vec = 2 * torch.exp(-1j * torch.pi * k / (2 * N))
    return torch.flip((V * exp_vec).real, dims=(-1,))
def idctII_pt(x, exp_vec=None):
    """Exact inverse of the unnormalized DCT-II along the last axis.

    Inverts ``dctII_pt`` / ``scipy.fft.dct(x, type=2)``, so the result matches
    ``scipy.fft.idct(x, type=2)``. The method rebuilds the FFT of the
    Makhoul-reordered signal from the DCT coefficients:
    ``V[0] = X[0] / 2`` and
    ``V[k] = exp(1j*pi*k/(2N)) * (X[k] - 1j*X[N-k]) / 2`` for ``k >= 1``.
    It then inverse-FFTs and undoes the reordering.

    Args:
        x: real DCT-II coefficients. The transform runs over the last
            dimension; leading dimensions are treated as a batch. Any length
            N >= 1 is accepted.
        exp_vec: optional precomputed twiddle ``exp(1j*pi*k/(2N))`` for
            ``k = 0..N-1``. Only entries ``1..N-1`` are used. When omitted it
            is computed in the input's precision.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    N = x.shape[-1]
    if exp_vec is None:
        k = torch.arange(N, device=x.device, dtype=torch.promote_types(x.dtype, torch.float32))
        exp_vec = torch.exp(1j * torch.pi * k / (2 * N))
    x_rev = torch.flip(x[..., 1:], dims=(-1,))  # X[N-1], ..., X[1]
    tail = exp_vec[1:N] * (x[..., 1:] - 1j * x_rev)  # frequencies k = 1..N-1
    # Cast the DC term explicitly so the concatenation is complex-to-complex.
    v = torch.cat([x[..., :1].to(tail.dtype), tail], dim=-1) / 2
    V = torch.fft.ifft(v, dim=-1)
    y = torch.zeros_like(x)
    # Undo the reordering: the first ceil(N/2) entries are the even samples;
    # the remaining N//2, read backwards, are the odd samples.
    y[..., ::2] = V[..., :(N + 1) // 2].real
    y[..., 1::2] = torch.flip(V, dims=(-1,))[..., :N // 2].real
    return y
def idstII_pt(x, exp_vec=None):
    """Inverse of the unnormalized DST-II along the last axis.

    Uses ``DST-II(x)[k] == DCT-II(x * (-1)**n)[N-1-k]``. The steps are:
    reverse the coefficients, apply the inverse DCT-II via ``idctII_pt``,
    then undo the ``(-1)**n`` modulation.

    Args:
        x: real DST-II coefficients; the transform runs over the last dimension.
        exp_vec: optional precomputed twiddle ``exp(1j*pi*k/(2N))`` for
            ``k = 0..N-1``, forwarded to ``idctII_pt``. When omitted it is
            computed here in the input's precision.

    Returns:
        Tensor with the same shape as ``x``.
    """
    N = x.shape[-1]
    if exp_vec is None:
        k = torch.arange(N, device=x.device, dtype=torch.promote_types(x.dtype, torch.float32))
        exp_vec = torch.exp(1j * torch.pi * k / (2 * N))
    idct_x_ = idctII_pt(torch.flip(x, dims=(-1,)), exp_vec)
    # Build the sign vector on x's device; a CPU tensor here fails for CUDA inputs.
    return idct_x_ * (-1) ** torch.arange(N, device=x.device)
# Compare the FFT-based transforms with SciPy's reference implementations and
# check that the inverses round-trip. The max-abs-error lines make the check
# quantitative; 4-decimal printouts alone cannot reveal small precision loss.
x_np = x.cpu().numpy()

dct_pt = dctII_pt(x, exp_vec_1).cpu().numpy()
dct_sp = scipy.fft.dct(x_np, type=2)
print(f'pytorch dct-II: {dct_pt}')
print(f'scipy dct-II: {dct_sp}')
print(f'max |pytorch - scipy| dct-II: {np.abs(dct_pt - dct_sp).max():.3e}\n')

dst_pt = dstII_pt(x, exp_vec_1).cpu().numpy()
dst_sp = scipy.fft.dst(x_np, type=2)
print(f'pytorch dst-II: {dst_pt}')
print(f'scipy dst-II: {dst_sp}')
print(f'max |pytorch - scipy| dst-II: {np.abs(dst_pt - dst_sp).max():.3e}\n')

x_dct_rt = idctII_pt(dctII_pt(x, exp_vec_1), exp_vec_2).cpu().numpy()
x_dst_rt = idstII_pt(dstII_pt(x, exp_vec_1), exp_vec_2).cpu().numpy()
print('x :', x_np)
print('pytorch idctII(dctII(x)):', x_dct_rt)
print('pytorch idstII(dstII(x)):', x_dst_rt)
print(f'max |x - idctII(dctII(x))|: {np.abs(x_dct_rt - x_np).max():.3e}')
print(f'max |x - idstII(dstII(x))|: {np.abs(x_dst_rt - x_np).max():.3e}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment