Skip to content

Instantly share code, notes, and snippets.

@christophernhill
Created April 9, 2019 01:03
Show Gist options
  • Save christophernhill/63b40a28a95c1179d947544d79d37536 to your computer and use it in GitHub Desktop.
Save christophernhill/63b40a28a95c1179d947544d79d37536 to your computer and use it in GitHub Desktop.
# Julia code for a Helmholtz solve: a tridiagonal solve with Neumann boundaries in
# the inner (z) dimension, and FFTs in up to 2 periodic outer (x, y) dimensions.
# Set environment
using Pkg
Pkg.add("LinearAlgebra")
using LinearAlgebra
Pkg.add("FFTW")
using FFTW
Pkg.add("LaTeXStrings")
using LaTeXStrings
Pkg.add("PyCall")
using PyCall
Pkg.add("PyPlot")
using PyPlot
# Pkg.add("Interact")
# using Interact
Pkg.add("SparseArrays")
using SparseArrays
# tdsolve - solve a tridiagonal system.
# Ultimately we want to pass in just dz and RHS and loop
# over each level's eigenvalue modes.
"""
    tdsolve(ld, md, ud, rhs)

Solve the tridiagonal system `A * phi = rhs` with the Thomas algorithm
(Numerical Recipes, Press et al. 1992, Sec. 2.4) and return `phi`.

# Arguments
- `ld`: lower diagonal; `ld[j]` multiplies `phi[j-1]` in row `j`, so only `ld[2:N]` is read.
- `md`: main diagonal, `md[1:N]`.
- `ud`: upper diagonal; `ud[j]` multiplies `phi[j+1]` in row `j`, so only `ud[1:N-1]`
  is read (a trailing pad element is allowed and ignored).
- `rhs`: right-hand side, `N = length(rhs)` entries (indexed linearly).

# Returns
An array shaped like `rhs` whose element type is the floating-point promotion of all
input element types, so integer inputs and real-rhs/complex-diagonal mixes both work.
An empty `rhs` yields an empty result.

# Throws
`SingularException(j)` if the `j`-th pivot is exactly zero. No pivoting is performed,
so the system should be stable without it (e.g. diagonally dominant).
"""
function tdsolve(ld, md, ud, rhs)
    N = length(rhs)
    # Work in a float type able to hold every coefficient: allocating with
    # eltype(rhs) fails for integer rhs or for complex diagonals with real rhs.
    T = float(promote_type(eltype(ld), eltype(md), eltype(ud), eltype(rhs)))
    phi = similar(rhs, T)
    N == 0 && return phi
    # gamma[j] holds the eliminated upper-diagonal factors; gamma[1] is never read.
    gamma = zeros(T, N)

    # Forward sweep: eliminate the lower diagonal.
    beta = T(md[1])
    iszero(beta) && throw(SingularException(1))
    phi[1] = rhs[1] / beta
    for j in 2:N
        gamma[j] = ud[j-1] / beta
        beta = md[j] - ld[j] * gamma[j]
        iszero(beta) && throw(SingularException(j))
        phi[j] = (rhs[j] - ld[j] * phi[j-1]) / beta
    end

    # Back substitution.
    for k in N-1:-1:1
        phi[k] -= gamma[k+1] * phi[k+1]
    end
    return phi
end
# mkwaves - generate (squared) wavenumbers for periodic and Neumann modes.
"""
    mkwaves(N, L)

Return `(scyc, sneu)`: two `N×1` `Float64` arrays of squared discrete wavenumbers
for a uniform grid of `N` cells spanning length `L` (spacing `h = L/N`). For
`k = 0, …, N-1`:

- `scyc[k+1] = (2 sin(kπ/N) / h)^2` (cyclic / periodic modes),
- `sneu[k+1] = (2 sin(kπ/(2N)) / h)^2` (Neumann modes).

The eigenvalues of the matching second-difference operator are the negatives of
these values.
"""
function mkwaves(N, L)
    h = L / N
    modes = 0:N-1
    # Typed comprehensions keep the Float64 element type of the original zeros(N,1) storage.
    scyc = reshape(Float64[(2 * sin(k * π / N) / h)^2 for k in modes], N, 1)
    sneu = reshape(Float64[(2 * sin(k * π / (2 * (N))) / h)^2 for k in modes], N, 1)
    return scyc, sneu
end
# Make 1-d N operator with variable grid spacing in Z specified
"""
    mkA_N(dzArr::AbstractArray)

Build the dense `N×N` second-difference operator for a 1-d column of
`N = length(dzArr)` cells with thicknesses `dzArr` (indexed linearly), using
zero-flux (Neumann) conditions at both ends.

The coupling between cell `k` and a neighbour `m = k ± 1` is
`2 / (dz[k]*dz[m] + dz[k]^2)`, i.e. `1 / (dz[k] * (dz[k] + dz[m]) / 2)`. It is placed
on the off-diagonal, and each diagonal entry is minus the sum of its row's
off-diagonals, so every row sums to zero.

Returns a `Matrix{Float64}`. A single cell (`N == 1`) gives the `1×1` zero matrix.
"""
function mkA_N(dzArr::AbstractArray)
    N = length(dzArr)
    A = zeros(N, N)
    for k in 1:N
        dK = dzArr[k]
        # Coupling to cell k-1 ("up"); absent at k == 1 (Neumann boundary).
        if k > 1
            upTerm = 2.0 / (dK * dzArr[k-1] + dK * dK)
            A[k, k-1] = upTerm
            A[k, k] -= upTerm
        end
        # Coupling to cell k+1 ("down"); absent at k == N (Neumann boundary).
        # Guarding on k < N (rather than k == 1) also makes N == 1 valid.
        if k < N
            dnTerm = 2.0 / (dK * dzArr[k+1] + dK * dK)
            A[k, k+1] = dnTerm
            A[k, k] -= dnTerm
        end
    end
    return A
end
"""
    mkA_PPN(Nx, Ny, Nz)

Same as `mkA_PPN(Nx, Ny, Nz, Nx, Ny, Nz)`: domain lengths equal to the cell
counts, i.e. unit grid spacing in every direction.
"""
function mkA_PPN(Nx, Ny, Nz)
    return mkA_PPN(Nx, Ny, Nz, Nx, Ny, Nz)
end
# Make 3-d PPN operator with domain lengths specified
"""
    mkA_PPN(Nx, Ny, Nz, Lx, Ly, Lz)

Build the dense `NN×NN` (`NN = Nx*Ny*Nz`) 7-point second-difference operator on a
uniform `Nx×Ny×Nz` grid spanning `Lx×Ly×Lz`. The operator is periodic in x and y and
zero-flux (Neumann) at both z ends ("PPN"). Unknowns are ordered with `i` (x)
fastest, then `j` (y), then `k` (z). Every row of the result sums to zero.

Returns the tuple `(A, Nx, Ny, Nz, NN)`.
"""
function mkA_PPN(Nx, Ny, Nz, Lx, Ly, Lz)
    NN = Nx * Ny * Nz
    A = zeros(NN, NN)
    dx = Lx / Nx; rdx2 = 1.0 / dx^2
    dy = Ly / Ny; rdy2 = 1.0 / dy^2
    dz = Lz / Nz; rdz2 = 1.0 / dz^2
    # 1-based periodic wrap, and linear offset with i fastest.
    wrap(i, n) = mod(i - 1, n) + 1
    offset(i, j, k) = (k - 1) * Nx * Ny + (j - 1) * Nx + i
    for k in 1:Nz, j in 1:Ny, i in 1:Nx
        offc = offset(i, j, k)
        # Full interior stencil centre. The x term previously used rdz2 by mistake,
        # which went unnoticed whenever dx == dz (e.g. the unit-spacing method).
        A[offc, offc] = -(2 * rdx2 + 2 * rdy2 + 2 * rdz2)
        # Periodic x/y neighbours. `+=` matters: for Nx or Ny of 1 or 2 the wrapped
        # neighbours coincide with each other or with the centre column.
        A[offc, offset(wrap(i - 1, Nx), j, k)] += rdx2
        A[offc, offset(wrap(i + 1, Nx), j, k)] += rdx2
        A[offc, offset(i, wrap(j - 1, Ny), k)] += rdy2
        A[offc, offset(i, wrap(j + 1, Ny), k)] += rdy2
        # z neighbours do not wrap. At k == 1 / k == Nz the missing neighbour's weight
        # goes back onto the diagonal, giving zero flux through that face.
        if k > 1
            A[offc, offset(i, j, k - 1)] += rdz2
        else
            A[offc, offc] += rdz2
        end
        if k < Nz
            A[offc, offset(i, j, k + 1)] += rdz2
        else
            A[offc, offc] += rdz2
        end
    end
    # show(IOContext(stdout), "text/plain", Matrix(A))
    return A, Nx, Ny, Nz, NN
end
# ---------------------------------------------------------------------------
# Driver: build a small PPN Helmholtz problem and solve it three ways:
#   1. direct dense solve with the assembled 3-d matrix          -> phi
#   2. DCT (REDFT10) in z + FFT in x,y, divide by eigenvalues    -> fiF
#   3. FFT in x,y + a tridiagonal solve in z per (i,j) mode      -> fiF11
# NOTE(review): method 2 uses uniform-spacing Neumann eigenvalues (Lz=Nz, i.e.
# dz == 1 everywhere), while Az below is built from variable dz. So presumably
# only methods 1 and 3 are expected to agree here — confirm intent.
# ---------------------------------------------------------------------------
# Vertical cell thicknesses; the ones(5,1) value is immediately overwritten.
dz=ones(5,1)
dz=[1.,2.,3.]
# dz=ones(50,1)
# 1-d Neumann second-difference operator in z for thicknesses dz.
Az=mkA_N(dz);
Nz=size(Az,1)
# Periodic in X
# 3x1x1 operator with unit spacing used as the horizontal part (Nzz is unused).
Ah,Nx,Ny,Nzz,NN=mkA_PPN(3,1,1);
Nh=size(Ah,1)
# Add Helmholtz term if we want
# (adds hh to every diagonal entry of Az; through kron(I_Nh, Az) below this
# becomes hh*I on the full operator)
hh=0.1
vv=zeros(Nz,1).+hh
Azplus=diagm(0=>vv[:])
Az=Az.+Azplus;
show(stdout,"text/plain",Az);println()
show(stdout,"text/plain",Ah);println()
# Now create a full A matrix for direct solve
# kron(I_Nh, Az) places Az in diagonal blocks, so z is the fastest index, then
# the horizontal index; this matches reshape(f,Nz,Nx,Ny) used below.
# (The Float32 identities are promoted to Float64 by kron with Float64 operands.)
AhExp=kron(Ah,Matrix{Float32}(I,Nz,Nz))
AzExp=kron(Matrix{Float32}(I,Nh,Nh),Az)
A=AzExp+AhExp;
# Random right-hand side with its mean removed (RNG is not seeded).
f=rand(Nh*Nz,1);f=f.-sum(f)/(Nh*Nz);
AA=copy(A);
# alpha is multiplied by 0., so this diagonal tweak is currently a no-op (AA == A).
# NOTE(review): presumably a leftover switch for pinning the null space when
# hh == 0 — confirm.
alpha=-1*2/(dz[1]*dz[1]+dz[1]*dz[2])*0.
AA[1]=AA[1]-alpha
println(alpha)
# AA[1]=AA[1]-1
# show(IOContext(stdout), "text/plain", Matrix(AA))
# Method 1: direct dense solve.
phi=AA\f
# Domain lengths equal to cell counts (unit spacing), matching mkA_PPN(3,1,1).
Lx=Nx;
Ly=Ny;
Lz=Nz;
# fz = FFTW.r2r(f,FFTW.REDFT10,3)
# fxyz= FFTW.fft(fz,[1,2])
# fF1=FFTW.fft(reshape(f,Nx,Ny,Nz),[1,2]);
# fF2= FFTW.r2r(real.(fF1),FFTW.REDFT10)
# Method 2 forward transform: DCT-II (REDFT10) along z (dim 1), then FFT along x,y.
fF1=FFTW.r2r(reshape(f,Nz,Nx,Ny),FFTW.REDFT10,1)
fF2= FFTW.fft(fF1,[2,3])
####
# Method 3 forward transform: FFT along x,y only; z is handled by tdsolve below.
fF22=FFTW.fft(reshape(f,Nz,Nx,Ny),[2,3])
####
# display(fF22[:,2,1])
# Operator eigenvalues: mkwaves returns squared wavenumbers, negated here.
# Cyclic modes for the periodic x,y directions, Neumann modes for z.
sxcyc, sxneu=mkwaves(Nx,Lx);sx=-sxcyc;
sycyc, syneu=mkwaves(Ny,Ly);sy=-sycyc;
szcyc, szneu=mkwaves(Nz,Lz);sz=-szneu;
# sx=sxneu;
# sy=syneu;
# Method 2: divide each transformed mode by its eigenvalue plus the Helmholtz
# shift hh, zeroing modes whose eigenvalue is ~0. fFi aliases fF2 (no copy), so
# fF2 is overwritten in place.
fFi=fF2;
for i=1:Nx
for j=1:Ny
for k=1:Nz
s=sx[i]+sy[j]+sz[k]+hh
if abs(s) < 1.e-12
fFi[k,i,j]=0;
else
fFi[k,i,j]=fFi[k,i,j]./s
end
end
end
end
#####
# Method 3: tridiagonal coefficients of Az (which already includes hh), padded
# to length Nz so ld[k] and ud[k] line up with row k as tdsolve expects.
ld=[0 diag(Az,-1)']';
md=diag(Az,0);
ud=[diag(Az,1)' 0]';
xx=complex.(zeros(Nz,Nx,Ny))
s=0
# For each horizontal Fourier mode (i,j), solve (Az + s*I) x = fhat along z,
# where s = sx[i]+sy[j] is that mode's horizontal eigenvalue. The s ~ 0 branch
# solves with the unshifted md, which equals md .+ s up to roundoff.
for i=1:Nx
for j=1:Ny
s=sx[i]+sy[j]
if abs(s) < 1.e-12
# xx[:,i,j].=2
mdd=copy(md)
# mdd[1]=mdd[1]-0.6666666
# println(sum(fF22[:,i,j]))
xx[:,i,j].=tdsolve(complex.(ld),complex.(mdd),complex.(ud),complex.(fF22[:,i,j]))
# println(fF22[:,i,j],reshape(f,Nz,Nx,Ny)[:,i,j])
# xx[:,i,j].=complex.(f[:,1,1])
else
# ss=[s,s,s,s,s]
xx[:,i,j]=tdsolve(complex.(ld),complex.(md.+s),complex.(ud),complex.(fF22[:,i,j]))
end
end
end
#####
# Method 2 inverse: inverse FFT along x,y, then DCT-III (REDFT01) along z. The
# 1/(2*Nz) factor normalises the unnormalised FFTW REDFT10/REDFT01 pair.
fiF1=FFTW.ifft(fFi,[2,3])
fiF2=FFTW.r2r(real.(fiF1),FFTW.REDFT01,1)/(2*Nz)
####
# Method 3 inverse: inverse FFT along x,y. For real f the imaginary parts are
# expected to be roundoff, so only the real part is kept.
fiF11=real.(FFTW.ifft(xx,[2,3]))
####
# Flatten in the same z-fastest order as A, then print the solutions,
# A*solution vs f, and the difference between methods 2 and 3.
fiF=reshape(fiF2,Nx*Ny*Nz);
println("f ",f)
println("A*fiF ",A*fiF)
println("A*fiF11 ",A*reshape(fiF11,Nx*Ny*Nz))
println("fiF ",fiF)
println("fiF11 ",fiF11)
println("phi. ",phi)
df=fiF-reshape(fiF11,Nx*Ny*Nz)
println("fiF-fiF11 ",df)
println("f./(A*fiF) ",f./(A*fiF))
println("f./(A*fiF11) ",f./(A*reshape(fiF11,Nx*Ny*Nz)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment