Skip to content

Instantly share code, notes, and snippets.

@jamesgregson
Created July 10, 2020 03:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesgregson/036dedb4cfc67d49c5fb0f9790050138 to your computer and use it in GitHub Desktop.
Save jamesgregson/036dedb4cfc67d49c5fb0f9790050138 to your computer and use it in GitHub Desktop.
Spatial, Fourier and matrix representations of 2D filtering operations
import unittest
import numpy as np
import scipy.sparse as sparse
class Filter2D:
    '''Sparse 2D filter given as taps (vertical offset, horizontal offset, value).

    The same filter can be applied in three equivalent ways, all using
    periodic (wrap-around) boundaries:

    - spatially: ``filt @ img`` correlates the filter with ``img``;
    - in the Fourier domain: ``ifft2(filt.fft(img.shape) * fft2(img))``;
    - as a sparse matrix: ``filt.matrix(img.shape) @ img.flatten()``.

    Offsets are expected to be integers and images 2-D.
    '''

    def __init__(self, offy, offx, vals):
        '''2D filtering operation with spatial, fourier and matrix features

        Args:
        - offy (integer sequence): vertical offsets of filter coefficients
        - offx (integer sequence): horizontal offsets of filter coefficients
        - vals (float sequence): filter coefficients

        Raises:
        - ValueError: if offy, offx and vals are not 1-D with the same length
        '''
        self._offy = np.array(offy)
        self._offx = np.array(offx)
        self._vals = np.array(vals)
        # __matmul__ pairs taps up with zip(), which would silently drop
        # unmatched entries, so reject mismatched inputs up front.
        if not (self._offy.ndim == 1
                and self._offy.shape == self._offx.shape == self._vals.shape):
            raise ValueError(
                'offy, offx and vals must be 1-D sequences of equal length, '
                'got shapes {}, {} and {}'.format(
                    self._offy.shape, self._offx.shape, self._vals.shape))

    @property
    def T(self):
        '''Mirror/conjugate/transpose the filter for spatial, fourier & matrix operations respectively

        Negating every offset mirrors the spatial kernel. For a real-valued
        filter this conjugates its spectrum and transposes its sparse matrix.
        Values are copied unchanged (they are not complex-conjugated).
        '''
        # np.array() in __init__ copies, so the new filter shares no storage.
        return Filter2D(-self._offy, -self._offx, self._vals)

    def __matmul__(self, img):
        '''Perform spatial correlation with the input image, no padding is performed

        Computes ``result[i, j] = sum_k vals[k] * img[(i + offy[k]) % H, (j + offx[k]) % W]``,
        i.e. boundaries wrap around.

        The result keeps ``img``'s dtype whenever the filter values can be
        cast to it (e.g. float64 taps on a float32 image). Otherwise (e.g.
        integer images, or complex taps on a real image) it is promoted so
        the in-place accumulation below cannot fail.
        '''
        img = np.asanyarray(img)
        out_dtype = np.result_type(img.dtype, self._vals.dtype)
        if np.can_cast(out_dtype, img.dtype, casting='same_kind'):
            out_dtype = img.dtype
        result = np.zeros_like(img, dtype=out_dtype)
        for dy, dx, v in zip(self._offy, self._offx, self._vals):
            # Rolling by -offset makes result[i] read img[i + offset].
            result += v * np.roll(img, (-dy, -dx), (0, 1))
        return result

    def image(self, dim, dtype=np.float32):
        '''Return the image representation of the filter, not shifted to center

        Each tap's value is placed at pixel ``(-offy % H, -offx % W)``, so the
        kernel is stored as a convolution kernel wrapped around the origin.
        Taps that land on the same pixel (duplicate offsets, or offsets that
        coincide modulo the image size) are summed. This matches how
        ``__matmul__`` and ``matrix`` combine them.

        Args:
        - dim (2-tuple): output image shape (H, W)
        - dtype: dtype of the returned array
        '''
        K = np.zeros(dim, dtype=dtype)
        if self._vals.size == 0:
            return K
        rows = (-self._offy) % dim[0]
        cols = (-self._offx) % dim[1]
        # np.add.at accumulates repeated indices; plain fancy assignment
        # would keep only the last value written to a pixel.
        np.add.at(K, (rows, cols), self._vals.astype(dtype))
        return K

    def fft(self, dim, dtype=np.float32):
        '''Return the filter spectrum, for Fourier-domain operations

        This is ``fft2(self.image(dim, dtype))``. Multiply it by ``fft2(img)``
        and take ``ifft2`` to apply the filter with periodic boundaries.
        '''
        return np.fft.fft2(self.image(dim, dtype))

    def matrix(self, dim, dtype=np.float32):
        '''Return a sparse-matrix representation of the filter with pixels in numpy array ordering, periodic boundaries

        The result is an (N, N) ``scipy.sparse.coo_matrix`` with N = H * W.
        Row ``i * W + j`` is the output pixel (i, j). Each tap contributes its
        value in the column of the input pixel
        ``((i + dy) % H, (j + dx) % W)``. Duplicate taps produce duplicate COO
        entries, which scipy.sparse sums when the matrix is multiplied or
        converted.

        Args:
        - dim (2-tuple): image shape (H, W)
        - dtype: dtype of the sparse matrix
        '''
        H, W = int(dim[0]), int(dim[1])
        N = H * W
        if self._vals.size == 0:
            # No taps: an all-zero operator (np.concatenate would fail on []).
            return sparse.coo_matrix((N, N), dtype=dtype)
        y, x = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        # Row indices are simply 0..N-1 once per tap, in tap order.
        rows = np.tile(np.arange(N), self._vals.size)
        cols = np.concatenate([
            (((y + dy) % H) * W + (x + dx) % W).ravel()
            for dy, dx in zip(self._offy, self._offx)
        ])
        vals = np.repeat(self._vals, N)
        return sparse.coo_matrix((vals, (rows, cols)), shape=(N, N), dtype=dtype)
class FilterTest(unittest.TestCase):
    '''Checks that the spatial, Fourier and sparse-matrix forms of Filter2D agree.'''

    def test_all(self):
        # Seeded generator so any failure is reproducible.
        rng = np.random.default_rng(0)
        # test images for gradient filters and laplacian:
        # x increases along columns, y along rows
        x, y = [v.astype(np.float32) for v in np.meshgrid(np.arange(51), np.arange(100))]
        img = rng.standard_normal((100, 51))
        # derivative in x & y, Laplacian, and a random 5-tap filter
        Dx = Filter2D([0, 0], [-1, 0], [-1.0, 1.0])
        Dy = Filter2D([-1, 0], [0, 0], [-1.0, 1.0])
        L = Filter2D([0, 0, -1, 1, 0], [-1, 1, 0, 0, 0], [-0.25, -0.25, -0.25, -0.25, 1.0])
        F = Filter2D([0, 0, -1, 1, 0], [-1, 1, 0, 0, 0], rng.uniform(size=5))
        # Same tolerances as np.allclose, but with informative failure output.
        tol = dict(rtol=1e-5, atol=1e-8)

        with self.subTest('x-derivative'):
            gx_s = Dx @ x
            gx_f = np.real(np.fft.ifft2(Dx.fft(x.shape) * np.fft.fft2(x)))
            gx_m = (Dx.matrix(x.shape) @ x.flatten()).reshape(x.shape)
            np.testing.assert_allclose(gx_s, gx_f, **tol)
            np.testing.assert_allclose(gx_s, gx_m, **tol)
            # backward difference of a unit ramp is 1 away from the wrap column
            self.assertAlmostEqual(np.median(gx_s), 1.0)

        with self.subTest('y-derivative'):
            gy_s = Dy @ y
            gy_f = np.real(np.fft.ifft2(Dy.fft(y.shape) * np.fft.fft2(y)))
            gy_m = (Dy.matrix(y.shape) @ y.flatten()).reshape(y.shape)
            np.testing.assert_allclose(gy_s, gy_f, **tol)
            np.testing.assert_allclose(gy_s, gy_m, **tol)
            self.assertAlmostEqual(np.median(gy_s), 1.0)

        with self.subTest('laplacian'):
            l_s = L @ img
            l_f = np.real(np.fft.ifft2(L.fft(img.shape) * np.fft.fft2(img)))
            l_m = (L.matrix(img.shape) @ img.flatten()).reshape(img.shape)
            np.testing.assert_allclose(l_s, l_m, **tol)
            np.testing.assert_allclose(l_s, l_f, **tol)

        with self.subTest('transpose/conjugation'):
            ft_s = F.T @ img
            ft_f = np.real(np.fft.ifft2(F.T.fft(img.shape) * np.fft.fft2(img)))
            ft_m1 = (F.T.matrix(img.shape) @ img.flatten()).reshape(img.shape)
            ft_m2 = (F.matrix(img.shape).T @ img.flatten()).reshape(img.shape)
            # norm-based tolerance: fft()/matrix() default to float32 coefficients
            self.assertLess(np.linalg.norm(ft_s - ft_f), 1e-5)
            self.assertLess(np.linalg.norm(ft_s - ft_m1), 1e-5)
            self.assertLess(np.linalg.norm(ft_m1 - ft_m2), 1e-5)
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment