Skip to content

Instantly share code, notes, and snippets.

@Gurpreetsingh9465
Created March 24, 2019 09:48
Show Gist options
  • Save Gurpreetsingh9465/f76cc9e53107c29fd76515d64c294d3f to your computer and use it in GitHub Desktop.
Save Gurpreetsingh9465/f76cc9e53107c29fd76515d64c294d3f to your computer and use it in GitHub Desktop.
implementation of fftshift and ifftshift
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shift the zero-frequency component to the center of the spectrum"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numbers

from tensorflow.python import roll as _roll
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("signal.fftshift")
def fftshift(x, axes=None):
    """Shift the zero-frequency component to the center of the spectrum.

    This function swaps half-spaces for all axes listed (defaults to all).
    Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

    Shift amounts are computed from the runtime shape of ``x``, so axes whose
    length is unknown when the graph is built are supported.

    Parameters
    ----------
    x : array_like, Tensor
        Input array.
    axes : int or sequence of int, optional
        Axes over which to shift. Default is None, which shifts all axes;
        this requires ``x`` to have a statically known rank. Negative values
        count from the last axis.

    Returns
    -------
    y : Tensor
        ``x`` rolled by ``n // 2`` along each selected axis, where ``n`` is
        the length of that axis.

    Raises
    ------
    ValueError
        If ``axes`` is None and the rank of ``x`` is not statically known.
    """
    x = ops.convert_to_tensor_v2(x)
    # Use the dynamic shape: static dimensions may be ``None`` in graph mode,
    # and ``None // 2`` would raise instead of producing a shift.
    dims = array_ops.shape(x)
    if axes is None:
        rank = x.shape.ndims
        if rank is None:
            raise ValueError(
                "fftshift with axes=None requires an input of known rank")
        axes = tuple(range(rank))
        shift = dims // 2
    elif isinstance(axes, numbers.Integral):
        # numbers.Integral also covers NumPy integer scalars.
        shift = dims[axes] // 2
    else:
        # Tensor indexing accepts negative axes, like ``x.shape[ax]`` did.
        shift = [dims[ax] // 2 for ax in axes]
    return _roll(x, shift, axes)
@tf_export("signal.ifftshift")
def ifftshift(x, axes=None):
    """The inverse of `fftshift`.

    Although identical for even-length `x`, the functions differ by one
    sample for odd-length `x`.

    Shift amounts are computed from the runtime shape of ``x``, so axes whose
    length is unknown when the graph is built are supported.

    Parameters
    ----------
    x : array_like, Tensor
        Input array.
    axes : int or sequence of int, optional
        Axes over which to calculate. Defaults to None, which shifts all axes;
        this requires ``x`` to have a statically known rank. Negative values
        count from the last axis.

    Returns
    -------
    y : Tensor
        ``x`` rolled by ``-(n // 2)`` along each selected axis, where ``n`` is
        the length of that axis.

    Raises
    ------
    ValueError
        If ``axes`` is None and the rank of ``x`` is not statically known.
    """
    x = ops.convert_to_tensor_v2(x)
    # Use the dynamic shape: static dimensions may be ``None`` in graph mode,
    # and ``None // 2`` would raise instead of producing a shift.
    dims = array_ops.shape(x)
    if axes is None:
        rank = x.shape.ndims
        if rank is None:
            raise ValueError(
                "ifftshift with axes=None requires an input of known rank")
        axes = tuple(range(rank))
        shift = -(dims // 2)
    elif isinstance(axes, numbers.Integral):
        # numbers.Integral also covers NumPy integer scalars.
        shift = -(dims[axes] // 2)
    else:
        # Tensor indexing accepts negative axes, like ``x.shape[ax]`` did.
        shift = [-(dims[ax] // 2) for ax in axes]
    return _roll(x, shift, axes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment