Skip to content

Instantly share code, notes, and snippets.

@marl0ny
Last active February 13, 2024 00:33
Show Gist options
  • Save marl0ny/5059782ce8e9bace8971e8fab2a54279 to your computer and use it in GitHub Desktop.
Save marl0ny/5059782ce8e9bace8971e8fab2a54279 to your computer and use it in GitHub Desktop.
"""
This script numerically solves the linear and nonlinear Schrodinger equations
using the split operator method, and then renders a matplotlib animation of
the results.
References:
Split operator method:
James Schloss. The Split Operator Method - Arcane Algorithm Archive.
https://www.algorithm-archive.org/contents/split-operator_method/
split-operator_method.html
Nonlinear Schrodinger equation:
Xavier Antoine, Weizhu Bao, Christophe Besse.
Computational methods for the dynamics of the nonlinear
Schrodinger/Gross-Pitaevskii equations.
https://arxiv.org/abs/1305.1093
"""
import numpy as np
# from scipy.fft import dstn, idstn
from visualization import animate
# Number of grid points along x and y.
NX, NY = 1024, 1024
# Side lengths of the simulation box along x and y.
LX, LY = 1024.0, 1024.0

# Sample positions along each axis. The last sample sits one grid spacing
# short of the box length, so the grid does not repeat the point at 0 when
# the box is treated as periodic.
_X_AXIS = np.linspace(0.0, LX*(1.0 - 1.0/NX), NX)
_Y_AXIS = np.linspace(0.0, LY*(1.0 - 1.0/NY), NY)

# 2D coordinate arrays; np.meshgrid makes x vary along axis 1 and y along
# axis 0.
X, Y = np.meshgrid(_X_AXIS, _Y_AXIS)

# Grid spacings, read off from neighbouring samples.
DX = X[0, 1] - X[0, 0]
DY = Y[1, 0] - Y[0, 0]

# Energies for a free particle with periodic boundary conditions.
# N*fftfreq(N) is the integer mode number n, so each term equals k**2/2
# with k = 2*pi*n/L.
_ENERGY_X = 2.0*(np.pi*NX*np.fft.fftfreq(NX)/LX)**2
_ENERGY_Y = 2.0*(np.pi*NY*np.fft.fftfreq(NY)/LY)**2
_EX, _EY = np.meshgrid(_ENERGY_X, _ENERGY_Y)
E = _EX + _EY

# nonlinear() switches its interaction term on only once |t| exceeds this
# value. An alternative setting that was tried is 100000.0.
NL_TIME = 0.0
def nonlinear(psi, t):
    """Return the nonlinear (density-dependent) potential term for psi.

    The term is ``g*psi*conj(psi)`` where the coupling ``g`` is 4000.0 once
    ``abs(t)`` exceeds ``NL_TIME`` and 0.0 otherwise. The result has the same
    shape as ``psi``.
    """
    # The interaction stays switched off until |t| passes the threshold.
    coupling = 4000.0 if np.abs(t) > NL_TIME else 0.0
    return coupling*psi*np.conj(psi)
def u(psi, t):
    """Propagate psi over a time t in free space with periodic boundaries.

    psi is transformed to Fourier space, each mode is multiplied by
    ``exp(-1j*E*t)`` using the module-level energy array ``E``, and the
    result is transformed back.
    """
    phase = np.exp(-1.0j*E*t)
    psi_k = np.fft.fftn(psi)
    return np.fft.ifftn(phase*psi_k)
def normalize(psi):
    """Scale psi so that ``DX*DY*sum(|psi|**2)`` equals one."""
    # Discrete approximation of the integral of |psi|**2 over the grid.
    norm_squared = DX*DY*np.sum(np.abs(psi)**2)
    return psi/np.sqrt(norm_squared)
def step(psi, t, dt, phi, should_normalize=False):
    """Advance psi by a single time step dt.

    The step is split into a half step with the potential, a full free
    propagation via u(), and a second half step with the potential.

    psi: wave function on the grid.
    t: current time, forwarded to nonlinear().
    dt: time step.
    phi: external potential, added to the nonlinear term.
    should_normalize: if true, the result is passed through normalize().
    """
    # First half step; the nonlinear term is evaluated from the incoming psi.
    first_half = np.exp(-1.0j*(phi + nonlinear(psi, t))*dt/2.0)
    psi_after_potential = first_half*psi

    # Full free-particle step in Fourier space.
    psi_after_kinetic = u(psi_after_potential, dt)

    # Second half step; the nonlinear term is re-evaluated from the
    # propagated wave function.
    second_half = np.exp(
        -1.0j*(phi + nonlinear(psi_after_kinetic, t))*dt/2.0)
    psi_new = second_half*psi_after_kinetic

    if should_normalize:
        return normalize(psi_new)
    return psi_new
def make_heart_potential(height, size, edge_sharpness, x_off, y_off):
    """Build a potential on the grid with a heart-shaped low region.

    height: value the potential approaches outside the shape (it approaches
        zero inside).
    size: overall scale of the shape.
    edge_sharpness: steepness of the tanh transition across the boundary.
    x_off, y_off: centre of the shape, as fractions of LX and LY.
    """
    # Offsets from the centre in box-fraction units.
    rel_x = X/LX - x_off
    rel_y = Y/LY - y_off
    radius = 5.0*np.sqrt(rel_x**2 + rel_y**2)

    # Polar angle with the y offset as the real part and the x offset as the
    # imaginary part, shifted from (-pi, pi] into [0, 2*pi).
    theta = np.angle(1.0j*rel_x + rel_y)
    theta = np.where(theta < 0.0, theta + 2.0*np.pi, theta)

    # Angle-dependent boundary radius of the shape.
    boundary = size*(np.abs(np.sin(theta))
                     + 2.0*np.exp(-1.2*np.abs(theta - np.pi)))

    # Smooth step from 0 (inside the boundary) to height (outside).
    smooth_step = np.tanh(edge_sharpness*(radius - boundary))/2.0 + 0.5
    return height*smooth_step
# Parameters of the alternative Gaussian wave packet initial state below
# (currently disabled): wave numbers in units of 2*pi/L, and widths and
# centre as fractions of the box size.
nx, ny = 20.0, 20.0
sigma_x, sigma_y = 0.07, 0.07
r0x, r0y = 0.35, 0.5
# psi0 = normalize(np.exp(-0.5*((X/LX - r0x)/sigma_x)**2
#                         -0.5*((Y/LY - r0y)/sigma_y)**2
#                         )*np.exp(2.0j*np.pi*(nx*X/LX + ny*Y/LY)))

# Initial state: uniform amplitude with an independent random phase at every
# grid point. The array is built with the shape of X, i.e. (NY, NX) as
# returned by np.meshgrid, so it stays compatible with X, Y, E and the
# potential even when NX != NY.
psi0 = normalize(np.exp(2.0j*np.pi*np.random.rand(*X.shape)))

animate(
    wave_function=psi0,
    x=X,
    y=Y,
    potential=make_heart_potential(0.5, 1.8, 8.0, 0.5, 0.75),
    steps_per_frame=1,
    # Needed with the complex time step below, which does not conserve the
    # norm of the wave function.
    normalize_after_each_step=True,
    step_function=step,
    t=0.0,
    # The negative imaginary part of dt turns exp(-1j*E*dt) into a decaying
    # factor for E > 0, damping the high-energy modes as the state evolves.
    dt_func=lambda t: 10.0*(1.0 - 0.2j),
    # Other time steps that were tried:
    # dt=1.0,
    # dt=10.0*(1.0 - 0.2j)
    # dt_func=lambda t: 5.0 if np.abs(t) < NL_TIME
    #     else 5.0*np.exp(-1.0j*np.pi*0.1),
)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Shared mutable state for the animation. animate() stores 'frames', 'fig',
# its keyword arguments and 'animation' here; init_func() adds 'ax' and
# 'plots'; func() reads the settings and updates 'wave_function', 't' and
# 'frames' on every frame.
data = {}
def init_func():
    """Create the axes and the two image artists used by the animation.

    Reads 'x', 'y', 'wave_function' and 'fig' from the shared ``data`` dict
    and stores the new axes under 'ax' and the image pair under 'plots'.
    The first image shows the phase of the wave function as hue, with the
    relative probability density as per-pixel opacity; the second is an
    all-zero grayscale image.
    """
    x, y = data['x'], data['y']
    # imshow draws row 0 at the top by default, so the rows are reversed to
    # make y increase upwards.
    psi = np.flip(data['wave_function'], axis=0)
    ax = data['fig'].add_subplot(1, 1, 1)
    data['ax'] = ax

    # Both images share the same extent. The left edge uses the first x
    # sample, x[0, 0], matching how the y edges are taken (the original used
    # x[0, 1], which was off by one grid spacing).
    extent = (x[0, 0], x[0, -1], y[0, 0], y[-1, 0])

    density = np.abs(psi)**2
    im = ax.imshow(np.angle(psi),
                   alpha=density/np.amax(density),
                   extent=extent,
                   interpolation='nearest',
                   cmap='hsv',
                   )
    # NOTE(review): im2 is added after im with the same default zorder, so a
    # full (non-blit) redraw may paint it over the wave function - confirm
    # the intended layering.
    im2 = ax.imshow(np.zeros([512, 512]),
                    extent=extent,
                    interpolation='nearest', cmap='gray')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    data['plots'] = im, im2
def func(*args):
    """Advance the simulation by one animation frame and update the image.

    Runs data['steps_per_frame'] time steps with data['step_function'],
    then refreshes the phase image and its opacity from the new wave
    function. Positional arguments (the frame value passed by
    FuncAnimation) are ignored.

    Returns the two image artists, background first.
    """
    steps_per_frame = data['steps_per_frame']
    step = data['step_function']
    dt_func = data['dt_func']
    potential = data['potential']
    im, im2 = data['plots']
    should_normalize = data.get('normalize_after_each_step', False)

    for _ in range(steps_per_frame):
        # Read the current time on every iteration so that each sub-step
        # sees an up-to-date t when steps_per_frame > 1, and evaluate the
        # time step only once per sub-step.
        t = data['t']
        dt = dt_func(t)
        data['wave_function'] = step(data['wave_function'], t, dt, potential,
                                     should_normalize=should_normalize)
        data['t'] = t + dt

    # Rows are reversed for display, as in init_func().
    psi_view = np.flip(data['wave_function'], axis=0)
    im.set_data(np.angle(psi_view))

    # Opacity follows the relative probability density, boosted by a factor
    # of five and capped at fully opaque.
    abs_wavefunc2 = np.abs(psi_view)**2
    alpha_map = 5.0*abs_wavefunc2/np.amax(abs_wavefunc2)
    im.set_alpha(np.where(alpha_map > 1.0, 1.0, alpha_map))

    data['frames'] += 1
    if data['frames'] % 60 == 0:
        print('frames: ', data['frames'], '\n', 'time: ', data['t'])
    return im2, im,
def animate(**kw):
    """Set up the figure, run the animation and save it to 'animation.mp4'.

    Every keyword argument is copied into the shared ``data`` dict, where
    init_func() and func() look up the entries they need: 'x', 'y',
    'wave_function', 'potential', 'steps_per_frame', 'step_function', 't',
    'dt_func' and, optionally, 'normalize_after_each_step'.
    """
    data['frames'] = 0
    figure = plt.figure()
    data['fig'] = figure
    # Keyword arguments are applied last, so they override the entries above.
    data.update(kw)
    init_func()

    movie = animation.FuncAnimation(
        figure,
        func,
        frames=60,
        blit=True,
        interval=1000.0/60.0,
    )
    data['animation'] = movie
    # The movie is written to disk rather than shown interactively
    # (plt.show() / plt.close() are intentionally not called).
    movie.save('animation.mp4', writer='ffmpeg', fps=30, bitrate=1800)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment