Skip to content

Instantly share code, notes, and snippets.

@HudsonHuang
Created June 26, 2019 07:56
Show Gist options
  • Save HudsonHuang/12ecce121871362b546271a01d775e8c to your computer and use it in GitHub Desktop.
Save HudsonHuang/12ecce121871362b546271a01d775e8c to your computer and use it in GitHub Desktop.
DFT的梯度
# First-, second-, and n-th-order derivatives via the FFT: https://math.mit.edu/~stevenj/fft-deriv.pdf
# Derivative of the DFT: https://math.stackexchange.com/a/1658364/684858
import tensorflow as tf
import numpy as np
import torch
from torch.autograd import gradcheck, Variable
# mag loss
def mag(x):
    """Return the squared magnitude of the FFT of a real-valued input.

    The input is promoted to a complex tensor with a zero imaginary part,
    passed through ``tf.fft``, and the squared absolute value of the result
    is returned. The input and output shapes are printed for debugging.

    Args:
        x: Real-valued tensor (or array) used as the real part of the signal.

    Returns:
        Tensor containing ``abs(tf.fft(x)) ** 2``.
    """
    print("mag in:", x.shape)
    # tf.complex requires the real and imaginary parts to share a dtype, so
    # derive the zero imaginary part from x instead of hard-coding float64.
    spectrum = tf.fft(tf.complex(x, tf.zeros_like(x)))
    # Use a distinct local name so the function name `mag` is not shadowed.
    power = tf.abs(spectrum) ** 2
    print("mag out:", power.shape)
    return power
def tesnorflow_check():
    """Check that TensorFlow can differentiate the ``mag`` loss in eager mode.

    Enables eager execution, builds a gradient function for ``mag`` with
    ``tf.contrib.eager.gradients_function``, evaluates it on a random array
    and prints the first returned gradient as a NumPy array.
    """
    # Eager execution is switched on before any TensorFlow op is created.
    tf.enable_eager_execution()

    # Random input for the TensorFlow version of the mag loss.
    # NOTE(review): the innermost axis has length 1; if tf.fft transforms only
    # the innermost axis, the 256-sample axis is not the one being
    # transformed -- confirm the intended layout.
    samples = np.random.random((4, 1, 256, 1))

    # Check whether the gradient can be computed at all.
    grad_of_mag = tf.contrib.eager.gradients_function(mag)
    first_grad = grad_of_mag(samples)[0]
    print(first_grad.numpy())
def pytorch_check():
    """Numerically verify autograd's gradient of the complex FFT.

    A random double-precision tensor of shape ``(1, 128, 2)`` is used as the
    input; its last axis is interpreted as (real, imaginary) pairs and the
    transform covers the two preceding axes. The value returned by
    ``gradcheck`` is printed; with its default settings ``gradcheck`` raises
    if the analytical and numerical gradients disagree.
    """
    signal_ndim = 2
    # requires_grad_ replaces the deprecated Variable(..., requires_grad=True).
    inx = torch.randn(1, 128, 2).double().requires_grad_(True)

    if callable(torch.fft):
        # torch < 1.8: ``torch.fft`` is the legacy function
        # ``torch.fft(input, signal_ndim)``.
        test = gradcheck(torch.fft, (inx, signal_ndim))
    else:
        # torch >= 1.8: ``torch.fft`` is a module, so calling it would raise
        # TypeError. Reproduce the legacy semantics by viewing the trailing
        # (real, imag) axis as complex, transforming the last ``signal_ndim``
        # axes, and viewing the result as real pairs again.
        fft_axes = tuple(range(-signal_ndim, 0))

        def fft_on_real_pairs(x):
            spectrum = torch.fft.fftn(torch.view_as_complex(x), dim=fft_axes)
            return torch.view_as_real(spectrum)

        test = gradcheck(fft_on_real_pairs, (inx,))
    print(test)
if __name__ == "__main__":
    # Script entry point: run the TensorFlow check first, then the PyTorch one.
    for run_check in (tesnorflow_check, pytorch_check):
        run_check()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment