Skip to content

Instantly share code, notes, and snippets.

@neutrinoceros
Created June 1, 2023 13:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save neutrinoceros/96b8c0f4df9ec56347bc0baf32cdd7ac to your computer and use it in GitHub Desktop.
Save neutrinoceros/96b8c0f4df9ec56347bc0baf32cdd7ac to your computer and use it in GitHub Desktop.
shearing box point advection benchmark
"""
Implementation examples and tests for the shearing-box boundary advection scheme of point-like particles.
"""
import numpy as np
import pytest
# Half box size. The original author recommends avoiding integer or
# "round-looking" values here (presumably so that floating-point rounding is
# actually exercised by the tests -- TODO confirm).
dy2 = 0.55
# Box limits: the box is centred on 0 and spans from ybeg to yend.
ybeg, yend = -dy2, dy2
# Full box width, used as the period when wrapping positions back into the box.
wy = yend - ybeg
def mod(a, b):
    """Emulate Python's modulus operator ``%`` on floats.

    The remainder is computed as ``a - b * floor(a / b)`` with ``np.floor``,
    so the function also works element-wise on NumPy arrays.

    Parameters
    ----------
    a : float or numpy.ndarray
        Dividend.
    b : float or numpy.ndarray
        Divisor (the modulus).

    Returns
    -------
    float or numpy.ndarray
        Remainder of the floored division of ``a`` by ``b``.
    """
    return a - (b * np.floor(a / b))
def adv_python(y, shift):
    """Advect ``y`` by ``shift`` and wrap it periodically, using ``%``.

    A simple reference implementation relying on Python's modulus operator.

    Parameters
    ----------
    y : float or numpy.ndarray
        Initial position(s).
    shift : float
        Displacement added to ``y`` before wrapping.

    Returns
    -------
    float or numpy.ndarray
        Shifted position(s), wrapped with period ``wy`` starting at ``ybeg``.
    """
    # shifted position expressed in units of the box width, measured from ybeg
    ys = (y + shift - ybeg) / wy
    # keep the fractional part only: whole box crossings are discarded
    m = ys % 1
    # convert back to physical coordinates
    return ybeg + (m * wy)
def adv_mod(y, shift):
    """Advect ``y`` by ``shift`` and wrap it periodically, using ``mod``.

    Same as ``adv_python``, but with ``%`` replaced by the ``mod`` helper
    function.

    Parameters
    ----------
    y : float or numpy.ndarray
        Initial position(s).
    shift : float
        Displacement added to ``y`` before wrapping.

    Returns
    -------
    float or numpy.ndarray
        Shifted position(s), wrapped with period ``wy`` starting at ``ybeg``.
    """
    # shifted position expressed in units of the box width, measured from ybeg
    ys = (y + shift - ybeg) / wy
    # fractional part, obtained through the explicit floor-based helper
    m = mod(ys, 1)
    # convert back to physical coordinates
    return ybeg + (m * wy)
def adv_mod_inlined(y, shift):
    """Advect ``y`` by ``shift`` and wrap it periodically, with ``mod`` inlined.

    Same as ``adv_mod``, but the helper is expanded in place and simplified
    for a modulus of 1 (``ys - floor(ys)``).

    Parameters
    ----------
    y : float or numpy.ndarray
        Initial position(s).
    shift : float
        Displacement added to ``y`` before wrapping.

    Returns
    -------
    float or numpy.ndarray
        Shifted position(s), wrapped with period ``wy`` starting at ``ybeg``.
    """
    # shifted position expressed in units of the box width, measured from ybeg
    ys = (y + shift - ybeg) / wy
    # fractional part of ys
    m = ys - np.floor(ys)
    # convert back to physical coordinates
    return ybeg + (m * wy)
def adv_mod_inlined_2(y, shift):
    """Advect ``y`` by ``shift`` and wrap it periodically, in compact form.

    Same as ``adv_mod_inlined``, with the intermediate fractional-part
    variable folded into the return expression.

    Parameters
    ----------
    y : float or numpy.ndarray
        Initial position(s).
    shift : float
        Displacement added to ``y`` before wrapping.

    Returns
    -------
    float or numpy.ndarray
        Shifted position(s), wrapped with period ``wy`` starting at ``ybeg``.
    """
    # shifted position expressed in units of the box width, measured from ybeg
    ys = (y + shift - ybeg) / wy
    # fractional part of ys, scaled back to physical coordinates
    return ybeg + (ys - np.floor(ys)) * wy
# Select the function to test: rebind `adv` to any of the adv_* implementations
# to run every test below against it.
adv = adv_mod_inlined_2
# Reusable initial positions covering the whole box: 50 evenly spaced points,
# kept 1e-8 away from each box limit.
yinits = np.linspace(ybeg + 1e-8, yend - 1e-8, 50)
def test_no_nullop():
    """A non-zero shift (0.1) must move every initial position."""
    res = adv(yinits, 0.1)
    # the smallest displacement over all points must be strictly positive
    assert np.abs(res - yinits).min() > 0
def test_no_shift():
    """A zero shift must leave the initial positions unchanged."""
    res = adv(yinits, 0)
    np.testing.assert_allclose(res, yinits, atol=1e-15)
def test_periodicity():
    """Adding one full box width to the shift must not change the result."""
    res1 = adv(yinits, 0.1)
    res2 = adv(yinits, 0.1 + wy)
    np.testing.assert_allclose(res2, res1)
@pytest.mark.parametrize("shift", np.linspace(-dy2/2, dy2/2, 50))
def test_small_shift(shift):
    """Shifts that keep every point inside the box must act as plain addition.

    Both the start positions and the shift are bounded by half of ``dy2`` in
    magnitude, so the shifted positions never reach a box limit and no
    wrapping is expected.
    """
    eps = 1e-14
    # local start positions restricted to the central half of the box
    # (named differently from the module-level `yinits` to avoid shadowing it)
    inner = np.linspace(-dy2/2 + eps, +dy2/2 - eps, 100)
    res = adv(inner, shift)
    np.testing.assert_allclose(res, inner + shift)
@pytest.mark.parametrize("shift", np.linspace(-10, 10, 500))
def test_overshoot(shift):
    """Results must stay within the box limits, even for shifts of several box widths."""
    res = adv(yinits, shift)
    assert np.all(res <= yend)
    assert np.all(res >= ybeg)
@pytest.mark.parametrize("shift", np.linspace(0, 10, 250))
def test_roundtrip(shift):
    """Applying a shift and then its opposite must recover the initial positions."""
    # backward then forward
    np.testing.assert_allclose(adv(adv(yinits, -shift), shift), yinits)
    # forward then backward
    np.testing.assert_allclose(adv(adv(yinits, shift), -shift), yinits)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment