Skip to content

Instantly share code, notes, and snippets.

@louity
Created April 20, 2022 08:04
Show Gist options
  • Save louity/a3c19bb927120a3c5a6f31df6572f5fd to your computer and use it in GitHub Desktop.
Save louity/a3c19bb927120a3c5a6f31df6572f5fd to your computer and use it in GitHub Desktop.
Solve the Poisson equation with homogeneous Dirichlet boundary conditions using the Discrete Sine Transform and PyTorch
import torch
import torch.nn.functional as F
def compute_laplace_dst(nx, ny, dx, dy, arr_kwargs=None):
    """Return the type-I DST symbol of the 2D centered discrete Laplacian.

    Entry ``(i, j)`` of the result is::

        2*(cos(pi*(i+1)/(nx-1)) - 1)/dx**2 + 2*(cos(pi*(j+1)/(ny-1)) - 1)/dy**2

    Args:
        nx, ny: Number of grid points along each axis, boundary included.
        dx, dy: Grid spacing along each axis.
        arr_kwargs: Optional dict of keyword arguments forwarded to
            ``torch.arange`` (e.g. ``{'dtype': torch.float64}``).
            ``None`` means no extra arguments.

    Returns:
        Tensor of shape ``(nx-2, ny-2)``, one value per interior mode.
    """
    arr_kwargs = {} if arr_kwargs is None else arr_kwargs
    # Mode indices 1..n-2 along each axis.
    kx = torch.arange(1, nx - 1, **arr_kwargs)
    ky = torch.arange(1, ny - 1, **arr_kwargs)
    # The symbol is separable, so compute each 1D factor once and broadcast
    # instead of taking cosines on two full (nx-2, ny-2) grids.
    lam_x = 2 * (torch.cos(torch.pi / (nx - 1) * kx) - 1) / dx**2
    lam_y = 2 * (torch.cos(torch.pi / (ny - 1) * ky) - 1) / dy**2
    return lam_x[:, None] + lam_y[None, :]
def dstI1D(x, norm='ortho'):
    """Compute the 1D type-I discrete sine transform along the last dimension.

    The input is zero-padded by one sample on each side, multiplied by ``-1j``
    and fed to ``torch.fft.irfft``; the first ``x.shape[-1]`` outputs after
    index 0 are returned.

    Args:
        x: Real tensor; the transform acts on its last dimension.
        norm: Normalization mode forwarded to ``torch.fft.irfft``. With the
            default ``'ortho'`` the transform is orthonormal.

    Returns:
        Tensor with the same shape as ``x``.
    """
    size = x.shape[-1]
    # Purely imaginary half-spectrum with zero first and last bins.
    half_spectrum = -1j * F.pad(x, (1, 1))
    # Output length 2*(size+1) is irfft's default for this input, spelled out.
    signal = torch.fft.irfft(half_spectrum, n=2 * (size + 1), dim=-1, norm=norm)
    return signal[..., 1:size + 1]
def dstI2D(x, norm='ortho'):
    """Compute the 2D type-I discrete sine transform over the last two dimensions.

    Applies ``dstI1D`` along the last dimension, then along the second-to-last
    one by transposing before and after the second pass.

    Args:
        x: Tensor with at least two dimensions.
        norm: Normalization mode forwarded to both ``dstI1D`` calls.

    Returns:
        Tensor with the same shape as ``x``.
    """
    last_done = dstI1D(x, norm=norm)
    # Swap the two trailing axes so the second pass also runs on dim -1.
    both_done = dstI1D(last_done.transpose(-1, -2), norm=norm)
    return both_done.transpose(-1, -2)
def laplacian_h_nobc(f):
    """Apply the 5-point Laplacian stencil to the interior of ``f``.

    Works on the last two dimensions. No boundary handling is done and no
    division by the grid spacing is applied: the result is the sum of the
    four neighbours minus four times the centre value.

    Args:
        f: Tensor with at least two dimensions.

    Returns:
        Tensor whose last two dimensions are each two smaller than those
        of ``f``.
    """
    centre = f[..., 1:-1, 1:-1]
    neighbours = (f[..., 2:, 1:-1] + f[..., :-2, 1:-1]
                  + f[..., 1:-1, 2:] + f[..., 1:-1, :-2])
    return neighbours - 4 * centre
def inverse_elliptic_dst(f, operator_dst):
    """Invert an elliptic operator (e.g. Laplace, Helmholtz) with the DST.

    Transforms ``f`` with ``dstI2D``, divides by ``operator_dst`` and
    transforms back with ``dstI2D`` again (with its default ``'ortho'``
    normalization the type-I DST is its own inverse).

    No dtype conversion is done here, so the working precision follows the
    dtypes of ``f`` and ``operator_dst``.

    Args:
        f: Right-hand side on the interior grid points.
        operator_dst: DST symbol of the operator, e.g. the output of
            ``compute_laplace_dst``.

    Returns:
        The solution on the interior grid points.
    """
    spectrum = dstI2D(f)
    return dstI2D(spectrum / operator_dst)
def inverse_elliptic_dst_f64(f, operator_dst):
    """Invert an elliptic operator (e.g. Laplace, Helmholtz) with the DST,
    casting the input to float64 first.

    Args:
        f: Right-hand side on the interior grid points; cast to float64.
        operator_dst: DST symbol of the operator.

    Returns:
        The solution on the interior grid points.
    """
    spectrum = dstI2D(f.to(torch.float64))
    return dstI2D(spectrum / operator_dst)
def inverse_elliptic_dst_f32_64(f, operator_dst):
    """Invert an elliptic operator (e.g. Laplace, Helmholtz) with the DST,
    using a float32 forward transform and a float64 backward transform.

    Args:
        f: Right-hand side on the interior grid points; cast to float32
            for the forward transform.
        operator_dst: DST symbol of the operator.

    Returns:
        The solution on the interior grid points.
    """
    spectrum_f32 = dstI2D(f.to(torch.float32))
    # Promote the spectrum before dividing so the backward pass is float64.
    spectrum_f64 = spectrum_f32.to(torch.float64)
    return dstI2D(spectrum_f64 / operator_dst)
def inverse_elliptic_dst_f64_32(f, operator_dst):
    """Invert an elliptic operator (e.g. Laplace, Helmholtz) with the DST,
    using a float64 forward transform and a float32 backward transform.

    Args:
        f: Right-hand side on the interior grid points; cast to float64
            for the forward transform.
        operator_dst: DST symbol of the operator.

    Returns:
        The solution on the interior grid points, as a float32 tensor.
    """
    # Cast explicitly so the forward pass is float64 whatever dtype f has,
    # as the sibling variants do for their own forward precision.
    spectrum = dstI2D(f.to(torch.float64))
    scaled = (spectrum / operator_dst).to(torch.float32)
    return dstI2D(scaled)
def inverse_elliptic_dst_f32(f, operator_dst):
    """Invert an elliptic operator (e.g. Laplace, Helmholtz) with the DST,
    using float32 for both the forward and the backward transform.

    Args:
        f: Right-hand side on the interior grid points; cast to float32.
        operator_dst: DST symbol of the operator.

    Returns:
        The solution on the interior grid points, cast to float64.
    """
    spectrum = dstI2D(f.to(torch.float32))
    # Dividing by a float64 operator promotes the quotient to float64, which
    # would make the backward transform run in float64. Cast back to float32
    # so both transforms really use single precision.
    scaled = (spectrum / operator_dst).to(torch.float32)
    return dstI2D(scaled).to(torch.float64)
import os
# Restrict the run to a single thread.
# NOTE(review): torch is imported earlier in this file; thread-count
# environment variables may only be read when the native libraries load, so
# these two assignments may have no effect here -- confirm.
for _thread_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
    os.environ[_thread_var] = '1'
torch.set_num_threads(1)

# Grid size (boundary included) and spacing.
nx = ny = 257
dx = dy = 1
laplace_dst = compute_laplace_dst(nx, ny, dx, dy, {'dtype': torch.float64})

# Reference field: random in the interior, zero on the boundary.
torch.manual_seed(0)
p = torch.zeros(nx, ny, dtype=torch.float64)
p[1:-1, 1:-1].uniform_(-1, 1)
delta_p = laplacian_h_nobc(p)


def _solve_and_score(solver):
    """Recover p from delta_p with ``solver``; return (field, mean abs error)."""
    recovered = torch.zeros_like(p)
    recovered[..., 1:-1, 1:-1] = solver(delta_p, laplace_dst)
    return recovered, torch.mean(torch.abs(p - recovered))


p1, d1 = _solve_and_score(inverse_elliptic_dst_f64)
p2, d2 = _solve_and_score(inverse_elliptic_dst_f32_64)
p3, d3 = _solve_and_score(inverse_elliptic_dst_f64_32)
p4, d4 = _solve_and_score(inverse_elliptic_dst_f32)

print(f'Solve poisson eq: \n - diff with DST f64 :{d1.item():.3E}\n - diff with DST f32 forward, f64 backward :{d2.item():.3E}\n - diff with DST f64 forward, f32 backward :{d3.item():.3E}\n - diff with DST f32 :{d4.item():.3E}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment