Skip to content

Instantly share code, notes, and snippets.

@louity
Created October 13, 2022 08:55
Show Gist options
  • Save louity/c4cdfbf66a7d0c4c64e265466b424533 to your computer and use it in GitHub Desktop.
Save louity/c4cdfbf66a7d0c4c64e265466b424533 to your computer and use it in GitHub Desktop.
Solve the Poisson equation on a doubly periodic domain
import matplotlib.pyplot as plt
import torch
def laplacian_per(f, dx, dy):
    """Return the 5-point finite-difference Laplacian of ``f`` with periodic wrap.

    The last two dimensions of ``f`` are the grid axes; any leading dimensions
    are treated as a batch. The second-to-last axis uses spacing ``dx`` and the
    last axis uses spacing ``dy``. Neighbours past either edge wrap around to
    the opposite edge, i.e. the domain is doubly periodic.

    Args:
        f: Tensor with at least two dimensions.
        dx: Grid spacing along the second-to-last dimension of ``f``.
        dy: Grid spacing along the last dimension of ``f``.

    Returns:
        Tensor with the same shape as ``f``.
    """
    # torch.roll(f, -1, d)[i] == f[(i + 1) % N] and torch.roll(f, 1, d)[i] ==
    # f[(i - 1) % N]: exactly the periodic neighbour lookup, without building
    # padded copies of the whole field.
    d2x = (torch.roll(f, -1, dims=-2) + torch.roll(f, 1, dims=-2) - 2 * f) / dx**2
    d2y = (torch.roll(f, -1, dims=-1) + torch.roll(f, 1, dims=-1) - 2 * f) / dy**2
    return d2x + d2y
# Extent of the doubly periodic domain [xmin, xmax) x [ymin, ymax).
xmin, xmax = 0.0, 1.0
ymin, ymax = 0.0, 1.0
# Grid resolution: number of points along each axis.
Nx, Ny = 16, 16
# With periodic wrap-around, Nx points tile one full period, so the spacing is
# the period length divided by Nx (not Nx - 1).
dx, dy = (xmax - xmin) / Nx, (ymax - ymin) / Ny
def solve_poisson_per(rhs, dx, dy):
    """Solve the periodic discrete Poisson equation ``laplacian_per(u) = rhs``.

    The 5-point periodic Laplacian is diagonal in the 2D discrete Fourier
    basis, so it is inverted exactly by dividing by its eigenvalues. The last
    two dimensions of ``rhs`` are the grid axes (second-to-last uses ``dx``,
    last uses ``dy``); leading dimensions are treated as a batch.

    Constants lie in the Laplacian's null space, so the mean of ``u`` is not
    determined: the zero-mean solution is returned. The mean of ``rhs`` (which
    must be zero for an exact solution to exist) is discarded.

    Args:
        rhs: Real floating-point tensor with at least two dimensions.
        dx: Grid spacing along the second-to-last dimension.
        dy: Grid spacing along the last dimension.

    Returns:
        Real tensor with the same shape as ``rhs``.
    """
    nx, ny = rhs.shape[-2:]
    kx = torch.arange(nx, dtype=rhs.dtype, device=rhs.device)
    ky = torch.arange(ny, dtype=rhs.dtype, device=rhs.device)
    # Shifting by +/-1 multiplies DFT mode k by exp(+/-2i*pi*k/N), so the
    # periodic 3-point second difference along one axis has eigenvalue
    # 2 * (cos(2*pi*k/N) - 1) / h**2; the 2D operator is the sum of both axes.
    eig = 2 * ((torch.cos(2.0 * torch.pi * kx / nx) - 1.0)[:, None] / dx**2
               + (torch.cos(2.0 * torch.pi * ky / ny) - 1.0)[None, :] / dy**2)
    # eig[0, 0] is exactly zero (the constant mode). Use a dummy divisor to
    # avoid 0/0, then pin that mode to zero to select the zero-mean solution.
    eig[0, 0] = 1.0
    u_hat = torch.fft.fft2(rhs) / eig
    u_hat[..., 0, 0] = 0
    # rhs is real, so the inverse transform is real up to round-off.
    return torch.fft.ifft2(u_hat).real


# Manufacture a test problem: draw a random zero-mean field f and use its
# discrete Laplacian as the right-hand side, so the exact solution is known.
f = torch.randn(Nx, Ny, dtype=torch.float64)
f -= f.mean()
lap_f = laplacian_per(f, dx, dy)
# Invert the Laplacian spectrally; since the inversion is exact, the error
# against the known zero-mean solution should be at round-off level.
f_rec = solve_poisson_per(lap_f, dx, dy)
print(f'Max absolute error: {torch.max(torch.abs(f - f_rec)).item():.3E}')
plt.ion()
fig, a = plt.subplots(1, 3)
# Share one colour scale between the exact and reconstructed fields so the
# two panels are visually comparable; imshow would otherwise autoscale each.
vmin = min(f.min().item(), f_rec.min().item())
vmax = max(f.max().item(), f_rec.max().item())
# Arrays are indexed [x, y]; transposing puts x on the horizontal axis and,
# with origin='lower', y increasing upwards.
im = a[0].imshow(f.T, origin='lower', vmin=vmin, vmax=vmax)
a[0].set_title('exact')
a[1].imshow(f_rec.T, origin='lower', vmin=vmin, vmax=vmax)
a[1].set_title('reconstructed')
fig.colorbar(im, ax=a[:2])
# The error has its own (much smaller) scale, hence a separate colorbar.
a[2].set_title('error')
fig.colorbar(a[2].imshow((f - f_rec).T, origin='lower'), ax=a[2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment