Skip to content

Instantly share code, notes, and snippets.

@fjkz
Created May 11, 2016 10:26
Show Gist options
  • Save fjkz/4e6281eefb777b3d65017fbf5eee3d10 to your computer and use it in GitHub Desktop.
Save fjkz/4e6281eefb777b3d65017fbf5eee3d10 to your computer and use it in GitHub Desktop.
KdV equation solver using a pseudo-spectral method
#!/usr/bin/env python3
"""
KdV equation solver
"""
from numpy import pi, cos, linspace, arange
from numpy.fft import rfft, irfft
# Problem parameters, following Zabusky & Kruskal (1965).
DELTA = 0.022 ** 2   # coefficient of the dispersive term u_xxx (delta squared)
TB = 1.0 / pi        # breakdown time T_B; reported times are t / TB
N = 512              # number of grid points on the periodic domain [0, 2)
DT = 1e-5 * TB       # time step of the Runge-Kutta integrator
NT_SAVE = 1000       # output interval, in time steps
NT_END = 10 ** 7     # total number of time steps to run
def save_to_file(nt, t, x, u):
    """Write one solution snapshot to the text file "<nt>.txt".

    The file name is ``nt`` zero-padded to 8 digits. The file contains a
    header line "# t = <t>" followed by one tab-separated "x<TAB>u" line per
    grid point, with every number printed to 6 decimal places.

    Args:
        nt: time-step index, used to build the file name.
        t: time value written into the header line.
        x: grid coordinates.
        u: solution values, paired element-wise with ``x``.
    """
    fname = "{0:08d}.txt".format(nt)
    # The context manager closes the file even if a write raises.
    with open(fname, "w", encoding="utf-8") as f:
        f.write("# t = {0:.6f}\n".format(t))
        # Pair the arrays directly rather than relying on the global grid
        # size, and hand all rows to the buffered writer in one call.
        f.writelines(
            "{0:.6f}\t{1:.6f}\n".format(xi, ui) for xi, ui in zip(x, u)
        )
# Check conservativeness
def validate(nt, u):
    """Print the step index and two discrete invariants of the solution.

    The solver integrates u_t = -u u_x - DELTA u_xxx on a periodic grid.
    For that equation the integrals of u and of u**2 / 2 are conserved, so
    their grid sums (taken without the dx factor) should stay
    approximately constant. A drift signals numerical error.

    Prints one line "nt<TAB>I1<TAB>I2".

    Args:
        nt: time-step index, printed first.
        u: solution values on the grid (a NumPy array).

    Returns:
        The tuple (I1, I2) that was printed.
    """
    # Use NumPy reductions: they run in C with pairwise summation, unlike
    # the builtin sum(), which loops over the array elements in Python.
    i1 = u.sum()
    i2 = (0.5 * u ** 2).sum()
    print('\t'.join(str(i) for i in (nt, i1, i2)))
    return i1, i2
# initial condition: u(x, 0) = cos(pi x) on the periodic grid [0, 2)
x = linspace(0.0, 2.0, N, endpoint=False)
u = cos(pi * x)
v = rfft(u)
save_to_file(0, 0.0, x, u)
validate(0, u)

# Wavenumbers of the rfft modes. The domain length is 2, so mode n has
# wavenumber 2*pi*n / 2 = pi*n.
k = pi * arange(len(v))
# Even derivatives keep the Nyquist mode.
kk = k * k
# For even N the last rfft coefficient is the Nyquist mode. Multiplying it by
# 1j*k (or 1j*k**3) turns its real coefficient imaginary. irfft ignores that
# imaginary part for u, but it leaks into u_x through 1j*k*v. The standard
# remedy is to drop the Nyquist mode from odd derivatives. k (used for u_x)
# and kkk (used for u_xxx) therefore get a zero there, while kk keeps it.
if N % 2 == 0:
    k[-1] = 0.0
kkk = kk * k
# Calculate dv/dt with the pseudo-spectral method
def time_grad(v):
    """Return the time derivative of the rfft spectrum ``v``.

    Evaluates u_t = -u u_x - DELTA u_xxx in Fourier space:

    - Nonlinear term: the product u * u_x is formed on the grid and then
      transformed back with rfft (the pseudo-spectral step).
    - Dispersive term: applied directly to the spectrum, where d^3/dx^3
      becomes (i k)^3 = -i k^3.

    Relies on the module-level arrays ``k`` (used for u_x) and ``kkk``
    (used for u_xxx), and on the constant ``DELTA``.
    """
    # -DELTA * (-i k^3) * v  ==  +i DELTA k^3 v
    dispersive = DELTA * 1j * kkk * v
    # Nonlinear advection term, evaluated in physical space.
    grid_u = irfft(v)
    grid_ux = irfft(1j * k * v)
    advective = rfft(grid_u * grid_ux)
    return dispersive - advective
# main loop: classical fourth-order Runge-Kutta stepping of the spectrum v
for nt in range(1, NT_END + 1):
    # Time at the end of this step, expressed in units of TB.
    t = DT / TB * nt

    # The four RK4 stage slopes. Each stage restarts from the current v.
    slope_a = time_grad(v)
    slope_b = time_grad(v + 0.5 * DT * slope_a)
    slope_c = time_grad(v + 0.5 * DT * slope_b)
    slope_d = time_grad(v + DT * slope_c)
    v = v + DT / 6.0 * (slope_a + 2.0 * slope_b + 2.0 * slope_c + slope_d)

    # Only every NT_SAVE-th step goes back to physical space for output.
    if nt % NT_SAVE != 0:
        continue
    u = irfft(v)
    validate(nt, u)
    save_to_file(nt, t, x, u)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment