Last active
December 11, 2018 16:05
-
-
Save pckujawa/4737136 to your computer and use it in GitHub Desktop.
Runge-Kutta, Euler-Richardson midpoint, and Euler methods compared for a simple harmonic oscillator simulation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#------------------------------------------------------------------------------- | |
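# Purpose: Compare Euler, Euler-Richardson midpoint, predictor-corrector and Runge-Kutta methods on a simple harmonic oscillator simulation.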
# Author: Pat | |
#------------------------------------------------------------------------------- | |
#!/usr/bin/env python | |
from __future__ import division | |
import math | |
import numpy as np | |
import pylab as pl | |
from pprint import pprint, pformat | |
def euler(t, x, f, dt):
    """Advance ``x`` by one explicit (forward) Euler step of size ``dt``.

    ``f(t, x)`` returns the time derivative of ``x``. The slope evaluated at
    the start of the interval is used for the whole step.
    """
    slope = f(t, x)
    return x + slope * dt
def euler_richardson_midpoint(t, x, f, dt):
    """Advance ``x`` by one Euler-Richardson (midpoint) step of size ``dt``.

    A half-sized Euler step estimates the state at the middle of the
    interval; the slope evaluated there is then applied over the full step.
    """
    half_step = 0.5 * dt
    x_half = x + half_step * f(t, x)
    mid_slope = f(t + half_step, x_half)
    return x + mid_slope * dt


# Short alias for the midpoint integrator.
erm = euler_richardson_midpoint
def runge_kutta(t, x, f, dt):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step.

    ``f(t, x)`` returns the time derivative of ``x``. Four increments are
    computed (start, two midpoint estimates, end) and combined with weights
    1, 2, 2, 1 over 6.
    """
    t_half = t + 0.5 * dt
    t_full = t + dt

    inc_start = f(t, x) * dt
    inc_mid_a = f(t_half, x + 0.5 * inc_start) * dt
    inc_mid_b = f(t_half, x + 0.5 * inc_mid_a) * dt
    inc_end = f(t_full, x + inc_mid_b) * dt

    weighted_sum = inc_start + 2 * inc_mid_a + 2 * inc_mid_b + inc_end
    return x + 1 / 6.0 * weighted_sum
def predictor_corrector(t, x_prev, x, f, dt):
    """Advance ``x`` by one predictor-corrector step of size ``dt``.

    Predictor (two-step midpoint rule):
        x_p = x(t - dt) + 2 * f(t, x) * dt
    Corrector (trapezoid rule using the predicted end state):
        x(t + dt) = x + 0.5 * (f(t + dt, x_p) + f(t, x)) * dt

    Parameters:
        t      -- current time.
        x_prev -- state at ``t - dt``, or ``None`` on the very first step.
        x      -- state at ``t``.
        f      -- callable ``f(t, x)`` returning the derivative of ``x``.
        dt     -- step size.

    Returns the state at ``t + dt``.
    """
    slope_now = f(t, x)

    if x_prev is None:
        # No history yet: reconstruct the state at t - dt by taking one
        # Runge-Kutta step *backwards* in time. (A forward step would hand
        # the predictor the state at t + dt instead of t - dt.)
        x_prev = runge_kutta(t, x, f, -dt)

    x_p = x_prev + 2 * slope_now * dt  # predicted state at t + dt

    # The predicted state belongs to the end of the interval, so its slope
    # has to be evaluated at t + dt, not at t.
    return x + 0.5 * (f(t + dt, x_p) + slope_now) * dt
# Oscillator parameters, read at module scope by Sho and by the analytical
# solution in doplot().
k = 1.0  # spring constant
m = 1.0  # mass


class Sho(object):
    """Simple harmonic oscillator simulated with a pluggable integrator.

    The state vector is ``[position, velocity]``. After ``run`` the
    intermediate results (``positions``, ``velocities``, ``sim_times``,
    ``energies``, ``dt``, ``t_end``, ``energy_change_percentage``, ...) are
    stored as attributes on the instance for inspection and plotting.
    """

    def step(self, t, state):
        """Return d(state)/dt = ``[v, -k*x/m]`` for ``state = [x, v]``.

        ``t`` is accepted for integrator compatibility but not used.
        """
        x_prev = state[0]
        v_prev = state[1]
        return np.array([v_prev, -k * x_prev / m])

    def run(self, integration_fn, dt, predictor_corrector_used=False):
        """Simulate five oscillation periods and return the energy drift.

        Parameters:
            integration_fn -- ``fn(t, x, f, dt)``, or
                ``fn(t, x_prev, x, f, dt)`` when
                ``predictor_corrector_used`` is true.
            dt -- time step; must be positive.
            predictor_corrector_used -- pass the previous state (``None`` on
                the first step) to ``integration_fn`` as well.

        Returns the absolute change between the first and last total energy,
        as a percentage of the starting energy.

        Raises ValueError if ``dt`` is not positive (the time loop would
        never terminate).
        """
        if dt <= 0:
            raise ValueError('dt must be positive, got %r' % (dt,))

        x_0 = 1  # m
        v_0 = 0  # m/s

        # Run for five complete cycles; the period is 2*pi*sqrt(m/k).
        t_end = 5 * (2 * math.pi * math.sqrt(m / k))

        states = [np.array([x_0, v_0])]
        elapsed_time = 0
        while elapsed_time < t_end:
            if predictor_corrector_used:
                # Two-step method: also hand over the state before the
                # current one (None until there is some history).
                if len(states) == 1:
                    prev_state = None
                else:
                    prev_state = states[-2]
                next_state = integration_fn(
                    elapsed_time, prev_state, states[-1], self.step, dt)
            else:
                next_state = integration_fn(
                    elapsed_time, states[-1], self.step, dt)
            states.append(next_state)
            elapsed_time += dt

        # Turn the list of states into an array so that column slices give
        # all positions and all velocities.
        states = np.array(states)
        positions = states[:, 0]
        velocities = states[:, 1]
        # One time stamp per stored state; building it from the state count
        # guarantees it always matches positions/velocities in length.
        sim_times = dt * np.arange(len(states))

        # Total energy = potential + kinetic
        energies = 0.5 * k * positions**2 + 0.5 * m * velocities**2
        an_energy = 0.5  # at t=0, v=0 so E = 0.5*k*x^2 = 0.5 J
        starting_energy = energies[0]
        ending_energy = energies[-1]
        energy_change_percentage = (
            100.0 * abs(starting_energy - ending_energy) / starting_energy)

        # Store every local (except self) on the object for easier
        # inspection; callers such as the plotting code read them back.
        snapshot = dict(locals())
        del snapshot['self']
        for name, value in snapshot.items():
            setattr(self, name, value)
        return energy_change_percentage

    def run_until(self, integration_fn, energy_change_tolerance_percentage,
                  dt_start=1, dt_narrowing_fn=None, dt_expanding_fn=None,
                  predictor_corrector_used=False):
        """Search for a ``dt`` whose energy drift matches a target percentage.

        Starting from ``dt_start``, the simulation is re-run (at most 11
        times). The search stops when the drift is within 1e-3 of
        ``energy_change_tolerance_percentage``; otherwise ``dt`` is shrunk
        when the drift is too large and grown when it is too small.

        ``dt_narrowing_fn`` / ``dt_expanding_fn`` map a ``dt`` to the next
        candidate and default to halving and multiplying by 1.5.

        Returns the ``dt`` of the last simulation that was actually run.
        ``self.num_times_dt_modified`` is set to the loop counter, which
        starts at 1 and is incremented after every adjustment of ``dt``.
        """
        dt = dt_start
        narrower = dt_narrowing_fn or self._narrow_dt
        expander = dt_expanding_fn or self._expand_dt
        cnt_runs = 1
        while cnt_runs < 12:
            drift = self.run(integration_fn, dt, predictor_corrector_used)
            diff = drift - energy_change_tolerance_percentage
            if abs(diff) < 1e-3:
                break
            elif diff > 0:
                dt = narrower(dt)
            else:
                dt = expander(dt)
            cnt_runs += 1
        self.num_times_dt_modified = cnt_runs
        return self.dt  # set in self.run, so it is the dt that was last run

    def _narrow_dt(self, dt):
        """Default narrower - go by halves."""
        return 0.5 * dt

    def _expand_dt(self, dt):
        """Default expander - go by 3/2."""
        return 1.5 * dt
def doplot():
    """Plot the latest simulation run together with the analytical solution.

    Reads the module-level ``problem`` (the results stored by its last
    ``run``) and ``int_fn_name`` (label of the integration method), and
    draws into the current pylab axes.
    """
    ## Analytical solution: x(t) = x_0 * cos(sqrt(k/m) * t)
    omega = math.sqrt(k/m)
    an_times = np.linspace(0, problem.elapsed_time, 100)
    positions_an = problem.x_0 * np.cos(omega * an_times)

    # The reported difference is the energy drift of the run, not the
    # position mismatch at the end of the simulation.
    percent_diff = problem.energy_change_percentage

    ## Plotting
    title_text = r'Simple Harmonic Osc. (%s method, $dt=%.5f$, $diff=%.3f\%%$)' % (
        int_fn_name, problem.dt, percent_diff)
    pl.title(title_text)

    times = problem.sim_times
    simulated = problem.positions
    pl.scatter(times, simulated, s=3, color='black', label='simulation')
    pl.plot(times, simulated, color='black', alpha=0.5)
    pl.plot(an_times, positions_an, color='red', label='analytical')
    pl.legend(loc='lower left', frameon=False) #TODO make smaller/alpha

    # Move axes to be on the plot
    axes = pl.gca()
    # Remove right and top axes
    for side in ('right', 'top'):
        axes.spines[side].set_color('none')
    axes.spines['bottom'].set_position(('data', 0))
    axes.xaxis.set_ticks_position('bottom')
    axes.yaxis.set_major_locator(pl.MaxNLocator(nbins=3))

    # Labels need to come after messing with the axes or they won't show up
    pl.xlabel('Time (s)', bbox=dict(facecolor='white', alpha=0.5))
    pl.ylabel('Position (m)')
    pl.xlim(-1, times.max()+1)
def do_plots():
    """Run each integrator at its chosen ``dt`` and save a stacked figure.

    One subplot per method is drawn via ``doplot()`` and the figure is
    written to ``pat_improved_ode_sim.png``.
    """
    # doplot() picks the method label up from module scope, so the loop
    # variable below has to be the global name.
    global int_fn_name

    # (integrator, label, dt, uses predictor-corrector signature)
    runs = (
        (euler, 'Euler', 3e-5, False),  # 1e-4 is about as small as you want or Euler takes forever
        (predictor_corrector, 'Pred-Corr', 0.0070, True),
        (euler_richardson_midpoint, 'ER midpoint', 0.0234375, False),
        (runge_kutta, 'Runge-Kutta', 0.1875, False),
    )

    pl.figure(figsize=(12, 9))
    for row, (int_fn, int_fn_name, dt, pc) in enumerate(runs, 1):
        pl.subplot(len(runs), 1, row)
        problem.run(int_fn, dt, predictor_corrector_used=pc)
        doplot()
    ## pl.show()
    pl.savefig("pat_improved_ode_sim.png", dpi=72)
def find_tolerances():
    """Yield ``(integration_fn, dt)`` pairs found by ``problem.run_until``.

    For each integrator the module-level ``problem`` searches for a ``dt``
    giving an energy change of 0.01 percent. Usage:
    ``tols = find_tolerances()`` and then advance the generator three times.
    """
    target_percentage = 0.01
    # (integrator, starting dt, uses predictor-corrector signature)
    candidates = (
        ## (euler, 0.1, False),  # takes a while
        (runge_kutta, 1, False),
        (euler_richardson_midpoint, 0.5, False),
        (predictor_corrector, 0.1, True),
    )
    for fn, dt_start, pc in candidates:
        yield fn, problem.run_until(
            fn, target_percentage, dt_start, predictor_corrector_used=pc)
# Shared simulation object; the plotting and tolerance-search functions read
# it as a module-level global.
problem = Sho()
for tol in find_tolerances():
    # Each line is formatted into a single string first, so the output is
    # the same whether print is a statement (Python 2) or a function
    # (Python 3).
    print('%s after tweaking dt %s times' % (tol, problem.num_times_dt_modified))
    print('energy change %% %s' % problem.energy_change_percentage)
do_plots()
##int_fn_name = 'PC'
##print problem.run(predictor_corrector, 0.1, True)
##doplot(); pl.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment