import numarray as N
import numarray.fft as F

def czt(x, m=None, w=None, a=1.0, axis=-1):
    """
    Copyright (C) 2000 Paul Kienzle

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  US

    usage: y = czt(x, m=None, w=None, a=1.0, axis=-1)

    Chirp z-transform.  Evaluate the z-transform of x at the m points
    z_k = a * w**(-k), k = 0..m-1, i.e.

        y[k] = sum_{j=0}^{n-1} x[j] * a**(-j) * w**(j*k)

    where n is the length of x along `axis`.  a is the starting point in
    the complex plane and w sets the ratio between successive points
    (the radius changes geometrically and the angle changes linearly).

    Parameters:
      x    -- input sequence/array; transformed along `axis`.  All other
              axes are transformed independently.
      m    -- number of output points (default: n).
      w    -- point ratio (default: exp(-2j*pi/m)).
      a    -- starting point (default: 1.0).
      axis -- axis of x to transform (default: the last axis).

    Returns an array shaped like x except that `axis` has length m.

    To evaluate the frequency response for the range f1 to f2 in a signal
    with sampling frequency Fs, use the following:

        m = 32                                  # number of points desired
        w = N.exp(-2j*N.pi*(f2-f1)/(m*Fs))      # frequency step of (f2-f1)/m
        a = N.exp(2j*N.pi*f1/Fs)                # starting at frequency f1
        y = czt(x, m, w, a)

    (The last point is then f2 - (f2-f1)/m; divide by (m-1) instead of m
    in w to end exactly at f2.)

    If you don't specify them, then the parameters default to a Fourier
    transform:
      m = n, w = exp(-2j*pi/m), a = 1
    so that czt(x) matches the forward FFT of x.  Because it is computed
    with three FFTs, this will be faster than computing the Fourier
    transform directly for large m (which is otherwise the best you can
    do with fft(x, n) for n prime).

    TODO: More testing---particularly when m+N-1 approaches a power of 2
    TODO: Consider treating w,a as f1,f2 expressed in radians if w is real
    """
    # Convenience declarations
    ifft = F.inverse_fft
    fft = F.fft
    x = N.asarray(x)  # accept plain sequences; the code below needs .rank/.shape

    # Move the transform axis to the end so the rest of the code can assume
    # axis=-1.  The swap permutation is its own inverse, so the same `axes`
    # restores the caller's layout at the end.
    do_transpose = (axis != -1) and (x.rank > 1)
    if axis < 0:
        axis += x.rank
    if do_transpose:
        axes = N.arange(x.rank)
        axes[[axis, x.rank-1]] = axes[[x.rank-1, axis]]
        x = N.transpose(x, axes)

    n = x.shape[-1]
    if m is None:
        m = n
    if w is None:
        # Negative exponent so the default matches the forward FFT convention
        # y[k] = sum_j x[j] * exp(-2j*pi*j*k/m).
        w = N.exp(-2j*N.pi/m)

    # Bluestein's identity j*k = (j**2 + k**2 - (k-j)**2)/2 turns the sum into
    # a linear convolution of the pre-chirped input with the chirp w**(-l**2/2).
    k = N.arange(m, type=N.Float64)           # output indices 0..m-1
    j = N.arange(n, type=N.Float64)           # input indices 0..n-1
    Nk = N.arange(-(n-1), m, type=N.Float64)  # chirp lags -(n-1)..m-1 inclusive, length n+m-1

    # Any nfft >= n+m-1 keeps circular wrap-around out of the outputs we keep;
    # this (larger) bound matches the original algorithm.
    nfft = next2pow(min(m, n) + len(Nk) - 1)
    Wk2 = w**(-(Nk**2)/2)              # chirp, length n+m-1
    AWj2 = a**(-j) * w**((j**2)/2)     # pre-chirp, one factor per input sample (length n)
    y = ifft(fft(Wk2, nfft) * fft(x * AWj2, nfft))
    # Output k lands at convolution index k + (n-1) because the chirp starts at lag -(n-1).
    y = N.take(y, range(n-1, m+n-1), axis=-1)
    y = y * w**((k**2)/2)              # post-chirp, broadcast along the last axis
    if do_transpose:
        y = N.transpose(y, axes)
    return y

def next2pow(x):
    """
    Return the smallest power of two that is greater than or equal to x.

    The exponent comes from math.frexp, which decomposes a float exactly.
    Exact powers of two therefore map to themselves, whereas
    ceil(log(x)/log(2)) can be pushed up by floating-point rounding
    (e.g. for 2**29).

    x must be a finite positive number; ValueError is raised otherwise.
    For 0 < x < 1 the result is a fractional power of two (0.5 -> 0.5).
    """
    import math
    x = float(x)
    if not 0.0 < x < float('inf'):
        raise ValueError("next2pow requires a finite positive number, got %r" % (x,))
    mantissa, exponent = math.frexp(x)  # x == mantissa * 2**exponent, 0.5 <= mantissa < 1
    if mantissa == 0.5:                 # x is already an exact power of two
        exponent -= 1
    return 2**exponent