import wave
import numpy as np
import matplotlib.pyplot as plt
import struct

# 周波数シフト関数
def shift_freq(F, s_freq_hz, fs, sec=1):
    N = fs * sec
    s_freq = s_freq_hz * sec
    # プラス方向へのシフト
    if s_freq > 0:
        # 前半部分のシフト
        for i in reversed(range(0,int(N/2))):
            si = i - s_freq
            if si >= 0:
                F.real[i] = F.real[si]
                F.imag[i] = F.imag[si]
            elif si < 0:
                F.imag[i] = 0
        # 後半部分のシフト
        for i in range(int(N/2)+1,N):
            si = i + s_freq
            if si < N:
                F.real[i] = F.real[si]
                F.imag[i] = F.imag[si]
            elif si >= N:
                F.imag[i] = 0
    # マイナス方向へのシフト
    elif s_freq < 0:
        # 前半部分のシフト
        for i in range(0,int(N/2)):
            si = i - s_freq
            if si < int(N/2):
                F.real[i] = F.real[si]
                F.imag[i] = F.imag[si]
            elif si >= int(N/2):
                F.imag[i] = 0
        # 後半部分のシフト
        for i in reversed(range(int(N/2)+1,N)):
            si = i + s_freq
            if si > int(N/2)+1:
                F.real[i] = F.real[si]
                F.imag[i] = F.imag[si]
            elif si < int(N/2)+1:
                F.imag[i] = 0
    return F

if __name__ == '__main__':
    a = 1     #振幅
    fs = 8192 #サンプリング周波数
    f0 = 440  #周波数
    sec = 5   #秒
    N = fs * sec
    swav=[]
    for n in np.arange(fs * sec):
        #サイン波を生成
        s = a * np.sin(2.0 * np.pi * f0 * n / fs)
        swav.append(s)

    # 正弦波プロット
    plt.plot(swav[0:100])
    plt.savefig("sin_440.png")
    plt.gca().clear()

    # FFT
    F = np.fft.fft(swav) # 変換結果

    # Fを弄って周波数を変更
    F = shift_freq(F, 440, fs, sec=5)

    # 逆FFT
    swav_ifft = np.fft.ifft(F)
    plt.plot(swav_ifft.real[0:100])
    plt.savefig("sin_ifft.png")
    plt.gca().clear()

    #サイン波を-32768から32767の整数値に変換(signed 16bit pcmへ)
    swav = [int(x * 32767.0) for x in swav_ifft.real]
    #バイナリ化
    binwave = struct.pack("h" * len(swav), *swav)
    #サイン波をwavファイルとして書き出し
    w = wave.Wave_write("output_ifft.wav")
    p = (1, 2, 8000, len(binwave), 'NONE', 'not compressed')
    w.setparams(p)
    w.writeframes(binwave)
    w.close()