import theano
import theano.tensor as T
from lasagne import init
from lasagne import nonlinearities
from lasagne.utils import as_tuple
from lasagne.layers.base import Layer, MergeLayer
import numpy as np
# Adapted from scipy.linalg.dft: build the complex n x n matrix whose product with a length-n vector is that vector's DFT.
def dft(n, scale=None):
    """Return the ``n x n`` complex discrete Fourier transform matrix.

    Adapted from ``scipy.linalg.dft``: ``dft(n) @ x`` equals
    ``np.fft.fft(x)`` for a length-``n`` vector ``x``.

    Parameters
    ----------
    n : int
        Size of the transform.
    scale : None, 'sqrtn' or 'n'
        ``None`` leaves the matrix unscaled, ``'sqrtn'`` divides it by
        ``sqrt(n)`` and ``'n'`` divides it by ``n``.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(n, n)`` with ``m[k, j] = exp(-2j*pi*k*j/n)``
        (before scaling).

    Raises
    ------
    ValueError
        If ``scale`` is not one of the accepted values.
    """
    if scale not in [None, 'sqrtn', 'n']:
        raise ValueError("scale must be None, 'sqrtn', or 'n'; "
                         "%r is not valid." % (scale,))
    # Column vector of the n roots of unity; raising it to the powers
    # 0..n-1 broadcasts into the full Vandermonde (DFT) matrix.
    omegas = np.exp(-2j * np.pi * np.arange(n) / n).reshape(-1, 1)
    m = omegas ** np.arange(n)
    if scale == 'sqrtn':
        # np.sqrt rather than math.sqrt: this module never imports math, so
        # the previous math.sqrt call raised NameError on this branch.
        m /= np.sqrt(n)
    elif scale == 'n':
        m /= n
    return m
# Compute the non-redundant half of the DFT matrix (rows 0..n//2) as a real-valued array: real parts stacked above imaginary parts.
def hdft(n, scale=None):
    """Return the non-redundant half of the DFT matrix as real float32 values.

    Only rows ``0 .. n//2`` of ``dft(n, scale)`` are kept. The remaining rows
    are complex conjugates of these (row ``n-k`` is ``conj(row k)``). Real
    parts are placed above imaginary parts, so the result has shape
    ``(2 * (n//2 + 1), n)``.
    """
    kept_rows = n // 2 + 1
    complex_half = dft(n, scale)[:kept_rows]
    real_over_imag = np.concatenate((complex_half.real, complex_half.imag),
                                    axis=0)
    return real_over_imag.astype(np.float32)
class DFTLayer(Layer):
    """Lasagne layer that projects its input onto the real half-DFT basis.

    The weight matrix is ``hdft(n).T``, where ``n`` is the size of the input's
    last axis. For real-valued input the output columns are therefore
    ``[Re X_0 .. Re X_{n//2}, Im X_0 .. Im X_{n//2}]`` (unscaled DFT), giving
    ``num_units = 2 * (n // 2 + 1)`` output features.
    """

    def __init__(self, incoming, **kwargs):
        super(DFTLayer, self).__init__(incoming, **kwargs)
        # Layer.__init__ sets self.input_shape whether `incoming` is a layer
        # or a shape tuple; `incoming.output_shape` only worked for layers.
        n = self.input_shape[-1]
        dft_matrix = hdft(n).transpose()  # shape (n, 2 * (n//2 + 1))
        self.num_units = dft_matrix.shape[1]
        # NOTE(review): add_param registers W with Lasagne's default tags, so
        # the DFT basis is presumably trainable/regularizable; pass
        # trainable=False here if the transform must stay fixed. Confirm intent.
        self.W = self.add_param(dft_matrix, dft_matrix.shape, name="W")

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, num_units)``; leading non-batch axes are folded away."""
        return (input_shape[0], self.num_units)

    def get_output_for(self, input, **kwargs):
        """Multiply the input's last axis by ``W`` and flatten to 2-D.

        The reshape collapses all leading axes, so ``(batch, n)`` and
        ``(batch, 1, n)`` inputs both give ``(batch, num_units)``. Assumes no
        other axis besides the batch has size > 1 (TODO confirm callers);
        otherwise rows would be folded into the batch axis.
        """
        activation = T.dot(input, self.W)
        return activation.reshape((-1, self.num_units))
def real(x):
    """Return the leading half of ``x``'s columns (axis 1), i.e. the real parts.

    Works for NumPy arrays and symbolic Theano tensors alike; the split point
    is ``x.shape[1] // 2``.
    """
    split_at = x.shape[1] // 2
    leading_columns = slice(0, split_at)
    return x[:, leading_columns]
def imag(x):
    """Return the trailing half of ``x``'s columns (axis 1), i.e. the imaginary parts.

    The split point is ``x.shape[1] // 2``; for an odd column count the extra
    column ends up in this half.
    """
    split_at = x.shape[1] // 2
    trailing_columns = slice(split_at, None)
    return x[:, trailing_columns]
# Inputs a and b are laid out along axis 1 as [r0, r1, ..., rk, i0, i1, ..., ik]: all real parts first, then all imaginary parts.
def dft_error(a, b):
    """Per-bin, frequency-weighted spectral error between two half-spectra.

    ``a`` and ``b`` are symbolic tensors in the layout produced by
    ``DFTLayer``: along axis 1, the first half holds real parts and the second
    half holds imaginary parts. They are assumed to be 2-D
    ``(batch, 2 * k)``; TODO confirm for other ranks.

    Two error terms are computed per frequency bin ``j`` (``0 <= j < k``):

    * the squared complex difference ``|A_j - B_j|**2``, weighted by
      ``1 - w_j``;
    * the squared magnitude difference ``(|A_j| - |B_j|)**2``, weighted by
      ``w_j``;

    where ``w_j = (j + 1) / (k + 1)`` rises from near 0 to near 1 with
    frequency.

    Returns
    -------
    Symbolic tensor of shape ``(batch, 2 * k)``: the ``k`` weighted complex
    terms followed by the ``k`` weighted magnitude terms. Reduce it (e.g. with
    ``.mean()``) to get a scalar loss.
    """
    ra, ia = real(a), imag(a)
    rb, ib = real(b), imag(b)

    complex_sqdiff = (ra - rb) ** 2 + (ia - ib) ** 2
    # NOTE(review): d sqrt(x)/dx is infinite at x == 0, so bins with exactly
    # zero energy can produce NaN gradients; consider adding a small epsilon
    # inside the sqrt if this is used as a training loss.
    amag = T.sqrt(ra ** 2 + ia ** 2)
    bmag = T.sqrt(rb ** 2 + ib ** 2)
    mag_sqdiff = (amag - bmag) ** 2

    k = complex_sqdiff.shape[1]
    # Cast explicitly so the int64 arange does not upcast the loss to float64.
    weights = (T.cast(T.arange(1, k + 1), 'float32') /
               T.cast(k + 1, 'float32'))

    # The old T.stack(...).dimshuffle(2, 1, 3) referenced a non-existent axis
    # of the 3-D stacked tensor and always failed. Concatenating along axis 1
    # gives the (batch, 2 * k) result that flatten(2) was meant to produce.
    return T.concatenate([complex_sqdiff * (1. - weights),
                          mag_sqdiff * weights], axis=1)
from lasagne.layers import get_output, InputLayer
def dft_batch(x):
    """Eagerly apply ``DFTLayer`` to a concrete batch of signals.

    Parameters
    ----------
    x : array-like
        Batch of signals whose last axis holds the samples, e.g. shape
        ``(batch, n)`` or ``(batch, 1, n)``. Presumably a NumPy array; TODO
        confirm callers.

    Returns
    -------
    numpy.ndarray
        Shape ``(batch, 2 * (n // 2 + 1))``: real parts of DFT bins
        ``0..n//2`` followed by their imaginary parts.

    Each call builds a fresh layer graph and evaluates it with ``.eval()``,
    so it is suited to one-off/offline use rather than a training loop.
    """
    # Use the last axis: x.shape[1] is 1 for (batch, 1, n) input, which is
    # exactly the shape this InputLayer declares.
    input_layer = InputLayer((None, 1, x.shape[-1]))
    dft_layer = DFTLayer(input_layer)
    return get_output(dft_layer, x).eval()
