Skip to content

Instantly share code, notes, and snippets.

@lkluft
Created May 10, 2020 15:02
Show Gist options
  • Save lkluft/24c007fc47ad67c5f9ece93f58e62066 to your computer and use it in GitHub Desktop.
Save lkluft/24c007fc47ad67c5f9ece93f58e62066 to your computer and use it in GitHub Desktop.
Solving an n-body problem numerically using Python
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import os\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classes that provide the general infrastructure"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class Body:\n",
"    \"\"\"A point mass that moves in 2D under the gravity of other bodies.\"\"\"\n",
"\n",
"    def __init__(self, xy=(0.0, 0.0), mass=1.0e10, velocity=(0.0, 0.0), static=False):\n",
"        \"\"\"Create a body.\n",
"\n",
"        Parameters:\n",
"            xy: Initial (x, y) position.\n",
"            mass: Mass of the body; static bodies get twice this value.\n",
"            velocity: Initial (vx, vy) velocity.\n",
"            static: If True, the body never moves but still attracts others.\n",
"        \"\"\"\n",
"        self.xy = np.array(xy, dtype=float)\n",
"        self.velocity = np.array(velocity, dtype=float)\n",
"        # Most recent acceleration, updated by `accelerate`.\n",
"        self.acceleration = np.array([0.0, 0.0], dtype=float)\n",
"\n",
"        self.mass = 2 * mass if static else mass  # increase mass of static bodies\n",
"        self.is_static = static\n",
"\n",
"    def get_vector(self, other):\n",
"        \"\"\"Return the 2D vector pointing from this body to `other`.\"\"\"\n",
"        return other.xy - self.xy\n",
"\n",
"    def accelerate(self, others, timestep, softening=0.0):\n",
"        \"\"\"Update the velocity due to the gravitational pull of `others`.\n",
"\n",
"        Each other body contributes Newton's law of universal gravitation,\n",
"        F = G * m1 * m2 / r**2, directed along the unit vector towards it.\n",
"        The velocity is then advanced by one explicit Euler step.\n",
"\n",
"        Parameters:\n",
"            others: Iterable of bodies that attract this one.\n",
"            timestep: Length of the time step.\n",
"            softening: Optional softening length; r**2 is replaced by\n",
"                r**2 + softening**2 to tame close encounters. The default\n",
"                of 0 gives the plain Newtonian force.\n",
"        \"\"\"\n",
"        # Static bodies never move, so skip the computation.\n",
"        if self.is_static:\n",
"            return\n",
"\n",
"        G = 6.67430e-11  # gravitational constant\n",
"\n",
"        acceleration = np.zeros(2)\n",
"        for other in others:\n",
"            v = self.get_vector(other)\n",
"            # |F| = G*m1*m2 / r**2 along the unit vector v/r. Dividing by m1\n",
"            # gives the acceleration G*m2 * v / r**3 (v is NOT normalised).\n",
"            denom = (np.dot(v, v) + softening**2) ** 1.5\n",
"            if denom == 0:\n",
"                # Coincident bodies: direction undefined, force would diverge.\n",
"                continue\n",
"            acceleration += G * other.mass / denom * v\n",
"\n",
"        self.acceleration = acceleration\n",
"        self.velocity += acceleration * timestep\n",
"\n",
"    def move(self, timestep):\n",
"        \"\"\"Move the body according to its velocity (static bodies stay put).\"\"\"\n",
"        if self.is_static:\n",
"            return\n",
"\n",
"        self.xy += self.velocity * timestep\n",
"\n",
"    def enforce_boundaries(self, width=9, height=16):\n",
"        \"\"\"Billiard-table-like boundaries centred on the origin.\n",
"\n",
"        A velocity component is reflected when the body reaches or passes\n",
"        the corresponding wall, and the position is clipped back inside.\n",
"        \"\"\"\n",
"        if np.abs(self.xy[0]) >= width / 2:\n",
"            self.velocity[0] *= -1\n",
"\n",
"        if np.abs(self.xy[1]) >= height / 2:\n",
"            self.velocity[1] *= -1\n",
"\n",
"        self.xy[0] = np.clip(self.xy[0], -width / 2, width / 2)\n",
"        self.xy[1] = np.clip(self.xy[1], -height / 2, height / 2)\n",
" \n",
"\n",
"class Space:\n",
"    \"\"\"A collection of bodies that are advanced together in time.\"\"\"\n",
"\n",
"    def __init__(self, bodies, timestep=0.02):\n",
"        self.bodies = bodies\n",
"        self.timestep = timestep\n",
"\n",
"    def move(self):\n",
"        \"\"\"Advance every body by one timestep.\n",
"\n",
"        All velocities are updated from the current positions before any\n",
"        body is moved, so the order of the bodies does not matter.\n",
"        \"\"\"\n",
"        for body in self.bodies:\n",
"            body.accelerate(\n",
"                [b for b in self.bodies if b is not body],\n",
"                timestep=self.timestep,\n",
"            )\n",
"\n",
"        for body in self.bodies:\n",
"            body.move(self.timestep)\n",
"            body.enforce_boundaries()\n",
"\n",
"    def plot(self, ax=None):\n",
"        \"\"\"Scatter-plot all bodies on `ax` (default: the current axes).\n",
"\n",
"        Marker area scales with mass; static bodies are drawn as grey\n",
"        hexagon outlines behind the mobile ones.\n",
"        \"\"\"\n",
"        if ax is None:\n",
"            ax = plt.gca()\n",
"\n",
"        for body in self.bodies:\n",
"            size = 6.0**2 * body.mass / 1e10\n",
"            if body.is_static:\n",
"                ax.scatter(*body.xy, s=size,\n",
"                           c=\"none\", marker=\"h\", edgecolors=\"grey\", zorder=-1)\n",
"            else:\n",
"                ax.scatter(*body.xy, s=size)\n",
"\n",
"    def run(self, iterations=600, outdir=\"plots\"):\n",
"        \"\"\"Run the model and save one PNG frame per iteration.\n",
"\n",
"        Frames are written to `outdir` as 0000.png, 0001.png, ...; the\n",
"        directory is created if it does not exist yet.\n",
"        \"\"\"\n",
"        os.makedirs(outdir, exist_ok=True)\n",
"        for n in range(iterations):\n",
"            # Create a new figure for each iteration.\n",
"            fig, ax = plt.subplots(figsize=(5.4, 9.6))\n",
"            ax.set_position([0, 0, 1, 1])\n",
"            ax.set_aspect(\"equal\")\n",
"            # Limits match the default boundaries of Body.enforce_boundaries.\n",
"            ax.set_xlim(-4.5, 4.5)\n",
"            ax.set_ylim(-8, 8)\n",
"            ax.axis(\"off\")\n",
"\n",
"            # Move every body in THIS space and plot the new constellation.\n",
"            self.move()\n",
"            self.plot(ax=ax)\n",
"\n",
"            # Create a PNG for every iteration; they are merged into a movie offline.\n",
"            # TODO: Use matplotlib's animation mechanism.\n",
"            fig.savefig(os.path.join(outdir, f\"{n:04d}.png\"), dpi=100)\n",
"            plt.close(fig)\n",
"\n",
" \n",
"def random_from_range(rmin=0, rmax=1, scale='lin'):\n",
"    \"\"\"Draw a random sample from the given range between rmin and rmax.\n",
"\n",
"    Parameters:\n",
"        rmin, rmax: Range limits; both must be positive for scale='log'.\n",
"        scale: 'lin' samples uniformly in the value itself, 'log' samples\n",
"            uniformly in its logarithm.\n",
"\n",
"    Raises:\n",
"        ValueError: If `scale` is neither 'lin' nor 'log'.\n",
"    \"\"\"\n",
"    if scale == 'lin':\n",
"        return (rmax - rmin) * np.random.rand() + rmin\n",
"    if scale == 'log':\n",
"        log_min, log_max = np.log(rmin), np.log(rmax)\n",
"        return np.exp((log_max - log_min) * np.random.rand() + log_min)\n",
"    # Fail loudly instead of implicitly returning None.\n",
"    raise ValueError(f\"scale must be 'lin' or 'log', got {scale!r}\")\n",
" \n",
" \n",
"def random_state():\n",
"    \"\"\"Return `Body` keyword arguments for a randomly placed mobile body.\n",
"\n",
"    x is drawn from [-4, 4), y from [-8, 8), and the mass log-uniformly\n",
"    between 0.5e10 and 3e10.\n",
"    \"\"\"\n",
"    x = random_from_range(-4, 4)\n",
"    y = random_from_range(-8, 8)\n",
"    mass = random_from_range(0.5e10, 3e10, \"log\")\n",
"    # To make ~40% of the bodies static, also return\n",
"    # static=np.random.choice([True, False], p=[0.4, 0.6])\n",
"    return {\"xy\": (x, y), \"mass\": mass}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Three-body problem"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Three equal (default) masses at the corners of a triangle, each given an initial push.\n",
"triangle = [\n",
"    Body(xy=(-0.5, 0.0), velocity=(0.5, 0.5)),\n",
"    Body(xy=(0.0, 0.5), velocity=(0.0, -0.5)),\n",
"    Body(xy=(0.5, 0.0), velocity=(-0.5, -0.5)),\n",
"]\n",
"\n",
"space = Space(triangle, timestep=0.02)\n",
"space.run(iterations=600)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Random n-body problem"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Eight mobile bodies with random positions and masses.\n",
"random = []\n",
"for _ in range(8):\n",
"    random.append(Body(**random_state()))\n",
"\n",
"space = Space(random)\n",
"space.run()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Adding static features"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# We need to randomize the grid slightly to ensure \"uneven\" forces.\n",
"static_field = [\n",
"    Body((x + 0.1 * np.random.randn(), y + 0.1 * np.random.randn()), static=True)\n",
"    for x, y in itertools.product(np.linspace(-3, 3, 5), np.linspace(-6, 6, 7))\n",
"]\n",
"\n",
"# Draw fresh mobile bodies: the `random` list from the previous cell has\n",
"# already been advanced by its own run, so reusing it would make this cell\n",
"# depend on execution order (and change on every re-run).\n",
"movers = [Body(**random_state()) for _ in range(8)]\n",
"\n",
"space = Space([*movers, *static_field])\n",
"space.run()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Concatenating the PNGs using ffmpeg\n",
"```bash\n",
"ffmpeg -framerate 60 \\\n",
" -pattern_type glob \\\n",
" -i 'plots/*.png' \\\n",
" -s:v 1080x1920 \\\n",
" -c:v libx264 \\\n",
" -profile:v high \\\n",
" -crf 20 \\\n",
" -pix_fmt yuv420p \\\n",
" -y plots/test.m4v\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment