Skip to content

Instantly share code, notes, and snippets.

@keunwoochoi
Created December 12, 2019 23:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keunwoochoi/0af90c36651abe6e2c4c5426a54f47fe to your computer and use it in GitHub Desktop.
Save keunwoochoi/0af90c36651abe6e2c4c5426a54f47fe to your computer and use it in GitHub Desktop.
"DIGITAL SIGNAL PROCESSING"
import torch
import pdb
def complex_mul(t1, t2):
    """Multiply two complex tensors stored as stacked (real, imag) pairs.

    Each input keeps the real part at index 0 and the imaginary part at
    index 1 of its last dimension. Inputs of any rank >= 1 are accepted;
    the leading dimensions follow the usual element-wise broadcasting rules.

    Args:
        t1: tensor whose last dimension holds (real, imag).
        t2: tensor with the same number of dimensions as ``t1``.

    Returns:
        The element-wise complex product, with the real and imaginary
        parts stacked along a new last dimension.

    Raises:
        AssertionError: if ``t1`` and ``t2`` differ in number of dimensions.
        NotImplementedError: if the inputs are 0-dimensional.
    """
    assert (t1.dim() == t2.dim()), "dim mismatch in complex_mul, {} and {}".format(t1.dim(), t2.dim())
    if t1.dim() == 0:
        # A scalar has no last axis to hold the (real, imag) pair.
        raise NotImplementedError

    # Ellipsis indexing picks from the last axis regardless of the rank,
    # so no per-rank branching is needed.
    r1, i1 = t1[..., 0], t1[..., 1]
    r2, i2 = t2[..., 0], t2[..., 1]

    # (r1 + j*i1) * (r2 + j*i2) = (r1*r2 - i1*i2) + j*(r1*i2 + i1*r2)
    return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)
def fast_conv1d(signal, kernel):
    """Convolve a batch of mono signals with one kernel via FFT overlap-add.

    The signal is cut into consecutive frames, each frame is convolved with
    the kernel in the frequency domain, and the overlapping tails of the
    per-frame results are summed. This is true convolution (the kernel is
    not flipped) and the output is truncated to the input length.

    Requires PyTorch >= 1.8: uses the ``torch.fft`` module, which replaced
    the removed ``torch.rfft`` / ``torch.irfft`` functions.

    Args:
        signal: tensor of shape (batch, 1, time).
        kernel: tensor of shape (time,), or any shape that flattens to it.
            It is moved to the device and dtype of ``signal``.

    Returns:
        Tensor of shape (batch, 1, time) with the dtype and device of
        ``signal``, holding the first ``time`` samples of the convolution.

    Raises:
        AssertionError: if ``signal`` has more than one channel.
    """
    batch, ch, L_sig = signal.shape
    assert ch == 1

    if L_sig == 0:
        # Nothing to convolve; avoids indexing an empty frame range below.
        return signal.new_zeros(batch, 1, 0)

    kernel = kernel.reshape(1, -1).to(device=signal.device, dtype=signal.dtype)
    len_kernel = kernel.shape[1]

    # FFT size: a power of two strictly larger than the kernel, so that a
    # frame of `hop` samples convolved with the kernel fits without
    # circular wrap-around (hop + len_kernel - 1 == n_fft).
    n_fft = 2 << (len_kernel - 1).bit_length()
    hop = n_fft - len_kernel + 1

    # `n=n_fft` zero-pads the input, so no explicit padding is needed.
    kernel_fft = torch.fft.rfft(kernel, n=n_fft, dim=-1)

    frame_starts = range(0, L_sig, hop)
    # Allocate with the signal's dtype/device so float64 input is not
    # silently accumulated in float32.
    result = signal.new_zeros(batch, 1, frame_starts[-1] + n_fft)

    for start in frame_starts:
        # The last frame may be shorter than `hop`; rfft pads it to n_fft.
        frame = signal[:, 0, start:start + hop]
        spectrum = torch.fft.rfft(frame, n=n_fft, dim=-1) * kernel_fft
        # Overlap-add: each frame's tail spills into the following frames.
        result[:, 0, start:start + n_fft] += torch.fft.irfft(spectrum, n=n_fft, dim=-1)

    return result[:, :, :L_sig]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment