Skip to content

Instantly share code, notes, and snippets.

@nwunderly
Last active January 31, 2021 01:04
Show Gist options
  • Save nwunderly/2b22072164bd078b6ade9f70171bfed3 to your computer and use it in GitHub Desktop.
Save nwunderly/2b22072164bd078b6ade9f70171bfed3 to your computer and use it in GitHub Desktop.
python two-body problem solver
"""
Two-body problem integrator.
"""
import argparse
import numpy as np
from scipy.integrate import solve_ivp
from matplotlib import pyplot as plt
class Earth:
    """Constants for Earth as the central body, in km / s units."""

    # Equatorial radius (km).
    R = 6378
    # Standard gravitational parameter GM (km^3 / s^2).
    MU = 3.986004418e5


# Used as both the absolute and the relative tolerance of the ODE solver.
# (int ** negative int is evaluated as a float power, so this equals 10**-13.)
TOLERANCE = 10.0 ** -13
def parse(argv=None):
    """Parse the command-line options for the integrator.

    INPUT
        argv -> list of argument strings; None (default) means sys.argv[1:]
    OUTPUT
        argparse.Namespace with
            r0   -> (x, y, z) initial position, as a tuple of 3 floats
            v0   -> (vx, vy, vz) initial velocity, as a tuple of 3 floats
            tmax -> end time of the integration (default 30000)
    """
    def float_tuple(s):
        # Accept "x,y,z", optionally wrapped in () or [], and return 3 floats.
        # eval() is deliberately NOT used: this string comes straight from the
        # command line and eval would execute arbitrary Python code.
        parts = s.strip().strip('()[]').split(',')
        if len(parts) != 3:
            raise argparse.ArgumentTypeError(
                f"expected 3 comma-separated numbers, got {s!r}")
        try:
            return tuple(float(p) for p in parts)
        except ValueError:
            # ArgumentTypeError makes argparse print this message and exit.
            raise argparse.ArgumentTypeError(
                f"could not parse numbers from {s!r}") from None

    parser = argparse.ArgumentParser(description="Two-body problem integrator.")
    parser.add_argument(
        "--r0", type=float_tuple, required=True,
        help="initial position x,y,z in km (write --r0=-7000,0,0 when the "
             "value starts with a minus sign)")
    parser.add_argument(
        "--v0", type=float_tuple, required=True,
        help="initial velocity vx,vy,vz in km/s")
    parser.add_argument(
        "--tmax", type=float, required=False, default=30000,
        help="integration end time in s (default: 30000)")
    return parser.parse_args(argv)
def eqm_2bp(r, mu=Earth.MU, array=False):
    """Two-body equation of motion: acceleration as a function of position.

    Computes r_dot_dot = -mu * r / |r|**3.

    INPUT
        r     -> (x, y, z); any array-like of numbers.
                 With array=True, a sequence of such vectors, one per row,
                 e.g. shape (n, 3).
        mu    -> gravitational parameter of the central body (default Earth.MU)
        array -> if True, evaluate every row of r independently
    OUTPUT
        r_dot_dot -> (ax, ay, az) as a numpy array; with array=True, one
                     acceleration vector per row of r (same shape as r).
    """
    # asarray lets plain tuples/lists work (float * tuple would raise).
    r = np.asarray(r, dtype=float)
    if array:
        # One magnitude per row, kept as a column so it broadcasts over the
        # components. mu is applied to every row, same as the single-vector case.
        dist = np.linalg.norm(r, axis=-1, keepdims=True)
    else:
        dist = np.linalg.norm(r)
    return -mu / dist ** 3 * r
def system_2bp(t, y):
    """Right-hand side of the two-body problem written as a first-order system.

    Follows the fun(t, y) signature expected by scipy.integrate.solve_ivp.
    The dynamics do not depend on time, so t is accepted but unused.

    INPUT
        t -> current time (ignored)
        y -> (x, y, z, vx, vy, vz): position followed by velocity
    OUTPUT
        y_dot -> (vx, vy, vz, ax, ay, az): the velocity, followed by the
                 acceleration eqm_2bp(position) with its default mu
    """
    position, velocity = y[:3], y[3:]
    acceleration = eqm_2bp(position)
    return np.concatenate((velocity, acceleration))
def integrate_2bp(x0, y0, z0, vx0, vy0, vz0, t_max, *, method="RK45"):
    """Integrate the two-body problem from the given initial conditions.

    INPUT
        x0, y0, z0    -> initial position r0
        vx0, vy0, vz0 -> initial velocity r_dot0
        t_max         -> end time; the integration runs over (0, t_max)
        method        -> solver name forwarded to scipy.integrate.solve_ivp
                         (default "RK45", which is also solve_ivp's default)
    OUTPUT
        t     -> 1-D array of the times chosen by the solver
        r     -> array of shape (3, n_points); rows are x, y, z
        r_dot -> array of shape (3, n_points); rows are vx, vy, vz
    RAISES
        RuntimeError if the solver reports that the integration failed.
    """
    # Named state0 so the y0 parameter (initial y coordinate) is not shadowed.
    state0 = np.array([x0, y0, z0, vx0, vy0, vz0], dtype=float)
    # solve_ivp rejects a first_step larger than the time span, so cap the
    # 1-unit initial step. For an empty span, let the solver pick its own.
    span = abs(t_max)
    first_step = min(1.0, span) if span > 0 else None
    soln = solve_ivp(
        system_2bp, (0, t_max), state0,
        method=method, atol=TOLERANCE, rtol=TOLERANCE, first_step=first_step,
    )
    if not soln.success:
        # Without this check, a failed run would silently return a truncated orbit.
        raise RuntimeError(f"two-body integration failed: {soln.message}")
    return soln.t, soln.y[:3], soln.y[3:]
def norm(vec, array=False):
    """Euclidean length of a vector, or of each element of a sequence of vectors.

    INPUT
        vec   -> array-like vector; with array=True, an iterable of vectors
        array -> if True, return one length per element of vec
    OUTPUT
        scalar length (array=False) or a numpy array of lengths (array=True)
    """
    if not array:
        return np.linalg.norm(vec)
    lengths = [np.linalg.norm(item) for item in vec]
    return np.array(lengths)
def plot_magnitudes(t, **data):
    """Plot the magnitude history of each vector series in its own subplot.

    INPUT
        t    -> sample times, used as the x axis of every subplot
        data -> title -> component array laid out as
                [[x0, ...], [y0, ...], [z0, ...]]; one subplot per entry,
                placed in a single row in insertion order
    """
    figure = plt.figure()
    n_panels = len(data)
    for slot, (title, components) in enumerate(data.items(), start=1):
        axes = figure.add_subplot(1, n_panels, slot, title=title)
        # [[x0, ...], [y0, ...], [z0, ...]] -> [[x0, y0, z0], ...] so that
        # norm() sees one vector per row.
        magnitudes = norm(components.transpose(), array=True)
        axes.plot(t, magnitudes)
def plot_3d(x, y, z, title):
    """Draw the 3-D path through (x, y, z) in a new figure, marking the origin in red."""
    axes = plt.figure().add_subplot(111, projection='3d', title=title)
    axes.plot(x, y, z)
    # The origin is the attracting centre of the equation of motion.
    axes.plot(0, 0, 0, 'ro')
def main():
    """Command-line entry point: integrate an orbit, plot it, and print the final state."""
    args = parse()
    x0, y0, z0 = args.r0
    vx0, vy0, vz0 = args.v0
    t_max = args.tmax
    t, r, r_dot = integrate_2bp(x0, y0, z0, vx0, vy0, vz0, t_max)
    # r is laid out as (3, n_points), but eqm_2bp(array=True) expects one
    # position vector per row. Evaluate on r.T, then transpose back so the
    # acceleration has the same (3, n_points) layout as r and r_dot.
    # Passing r directly would apply the formula to whole x/y/z time series.
    r_dot_dot = eqm_2bp(r.T, array=True).T
    data = {
        'Position (km)': r,
        'Velocity (km/s)': r_dot,
        'Acceleration (km/s^2)': r_dot_dot,
    }
    plot_magnitudes(t, **data)
    plot_3d(*r, 'Orbit path')
    print(
        f"Final state:\n"
        f" x = {r[0][-1]}\n"
        f" y = {r[1][-1]}\n"
        f" z = {r[2][-1]}\n"
        f" vx = {r_dot[0][-1]}\n"
        f" vy = {r_dot[1][-1]}\n"
        f" vz = {r_dot[2][-1]}"
    )
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment