-
-
Save danielkelshaw/8911674fbbbbec1676191874eaacc3e3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
============================= test session starts ============================== | |
platform darwin -- Python 3.10.0, pytest-7.4.0, pluggy-1.2.0 | |
rootdir: /Users/djk21/coding/jax-cfd | |
plugins: xdist-3.3.1 | |
created: 16/16 workers | |
16 workers [458 items] | |
........................................................................ [ 15%] | |
........................................F............................... [ 31%] | |
.....................F.................................................. [ 46%] | |
....F...................F....................F.......................F.. [ 62%] | |
.........F......F.......F............................................... [ 78%] | |
........................................................................ [ 94%] | |
....F...F................. [100%] | |
==================================== ERRORS ==================================== | |
________________ ERROR collecting jax_cfd/ml/equations_test.py _________________ | |
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/equations_test.py'. | |
Hint: make sure your test modules/packages have valid Python names. | |
Traceback: | |
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module | |
return _bootstrap._gcd_import(name[level:], package, level) | |
jax_cfd/ml/__init__.py:17: in <module> | |
import jax_cfd.ml.advections | |
jax_cfd/ml/advections.py:8: in <module> | |
from jax_cfd.ml import interpolations | |
jax_cfd/ml/interpolations.py:11: in <module> | |
from jax_cfd.ml import layers | |
jax_cfd/ml/layers.py:8: in <module> | |
import haiku as hk | |
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module> | |
from haiku import experimental | |
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module> | |
from haiku._src.base import current_name | |
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module> | |
from haiku._src.typing import ( # pylint: disable=g-multiple-import | |
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module> | |
from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import | |
E ModuleNotFoundError: No module named 'typing_extensions' | |
__________________ ERROR collecting jax_cfd/ml/layers_test.py __________________ | |
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/layers_test.py'. | |
Hint: make sure your test modules/packages have valid Python names. | |
Traceback: | |
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module | |
return _bootstrap._gcd_import(name[level:], package, level) | |
jax_cfd/ml/__init__.py:17: in <module> | |
import jax_cfd.ml.advections | |
jax_cfd/ml/advections.py:8: in <module> | |
from jax_cfd.ml import interpolations | |
jax_cfd/ml/interpolations.py:11: in <module> | |
from jax_cfd.ml import layers | |
jax_cfd/ml/layers.py:8: in <module> | |
import haiku as hk | |
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module> | |
from haiku import experimental | |
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module> | |
from haiku._src.base import current_name | |
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module> | |
from haiku._src.typing import ( # pylint: disable=g-multiple-import | |
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module> | |
from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import | |
E ModuleNotFoundError: No module named 'typing_extensions' | |
_______________ ERROR collecting jax_cfd/ml/layers_util_test.py ________________ | |
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/layers_util_test.py'. | |
Hint: make sure your test modules/packages have valid Python names. | |
Traceback: | |
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module | |
return _bootstrap._gcd_import(name[level:], package, level) | |
jax_cfd/ml/__init__.py:17: in <module> | |
import jax_cfd.ml.advections | |
jax_cfd/ml/advections.py:8: in <module> | |
from jax_cfd.ml import interpolations | |
jax_cfd/ml/interpolations.py:11: in <module> | |
from jax_cfd.ml import layers | |
jax_cfd/ml/layers.py:8: in <module> | |
import haiku as hk | |
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module> | |
from haiku import experimental | |
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module> | |
from haiku._src.base import current_name | |
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module> | |
from haiku._src.typing import ( # pylint: disable=g-multiple-import | |
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module> | |
from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import | |
E ModuleNotFoundError: No module named 'typing_extensions' | |
__________________ ERROR collecting jax_cfd/ml/towers_test.py __________________ | |
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/towers_test.py'. | |
Hint: make sure your test modules/packages have valid Python names. | |
Traceback: | |
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module | |
return _bootstrap._gcd_import(name[level:], package, level) | |
jax_cfd/ml/__init__.py:17: in <module> | |
import jax_cfd.ml.advections | |
jax_cfd/ml/advections.py:8: in <module> | |
from jax_cfd.ml import interpolations | |
jax_cfd/ml/interpolations.py:11: in <module> | |
from jax_cfd.ml import layers | |
jax_cfd/ml/layers.py:8: in <module> | |
import haiku as hk | |
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module> | |
from haiku import experimental | |
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module> | |
from haiku._src.base import current_name | |
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module> | |
from haiku._src.typing import ( # pylint: disable=g-multiple-import | |
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module> | |
from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import | |
E ModuleNotFoundError: No module named 'typing_extensions' | |
=================================== FAILURES =================================== | |
________ AdvectionTest.test_mass_conservation_dirichlet_dichlet_advect _________ | |
[gw8] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.collocated.advection_test.AdvectionTest testMethod=test_mass_conservation_dirichlet_dichlet_advect> | |
shape = (101,), method = <function _euler_step.<locals>.step at 0x138bd57e0> | |
@parameterized.named_parameters( | |
dict( | |
testcase_name='dichlet_advect', | |
shape=(101,), | |
method=_euler_step(advection.advect_linear)),) | |
def test_mass_conservation_dirichlet(self, shape, method): | |
cfl_number = 0.1 | |
dt = cfl_number / shape[0] | |
num_steps = 1000 | |
grid = grids.Grid(shape, domain=([-1., 1.],)) | |
bc = boundaries.dirichlet_boundary_conditions(grid.ndim) | |
c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),)) | |
def u(grid): | |
x = grid.mesh((0.5,))[0] | |
return grids.GridArray(-jnp.sin(jnp.pi * x), (0.5,), grid) | |
def c0(grid): | |
x = grid.mesh((0.5,))[0] | |
return grids.GridArray(x, (0.5,), grid) | |
v = (bc.impose_bc(u(grid)),) | |
c = c_bc.impose_bc(c0(grid)) | |
ct = c | |
advect = jax.jit(functools.partial(method, v=v, dt=dt)) | |
initial_mass = np.sum(c.data) | |
for _ in range(num_steps): | |
ct = advect(ct) | |
current_total_mass = np.sum(ct.data) | |
> self.assertAllClose(current_total_mass, initial_mass, atol=1e-6) | |
jax_cfd/collocated/advection_test.py:107: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jax_cfd/base/test_util.py:87: in assertAllClose | |
np.testing.assert_allclose(expected, actual, **kwargs) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x14c604160>, array(-9.536743e-07, dtype=float32), array(-2.861023e-06, dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=1e-06', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=1e-07, atol=1e-06 | |
E | |
E Mismatched elements: 1 / 1 (100%) | |
E Max absolute difference: 1.9073486e-06 | |
E Max relative difference: 0.6666667 | |
E x: array(-9.536743e-07, dtype=float32) | |
E y: array(-2.861023e-06, dtype=float32) | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
___________ AdvectionTest.test_neumann_bc_one_step_linear_1d_neumann ___________ | |
[gw8] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.collocated.advection_test.AdvectionTest testMethod=test_neumann_bc_one_step_linear_1d_neumann> | |
shape = (1000,), method = <function advect_linear at 0x138bd4c10> | |
@parameterized.named_parameters( | |
dict( | |
testcase_name='linear_1d_neumann', | |
shape=(1000,), | |
method=advection.advect_linear),) | |
def test_neumann_bc_one_step(self, shape, method): | |
grid = grids.Grid(shape, domain=([-1., 1.],)) | |
bc = boundaries.neumann_boundary_conditions(grid.ndim) | |
c_bc = boundaries.neumann_boundary_conditions(grid.ndim) | |
def u(grid): | |
x = grid.mesh((0.5,))[0] | |
return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid) | |
def c0(grid): | |
x = grid.mesh((0.5,))[0] | |
return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid) | |
def dcdt(grid): | |
x = grid.mesh((0.5,))[0] | |
return grids.GridArray(jnp.pi * jnp.sin(2 * jnp.pi * x), (0.5,), grid) | |
v = (bc.impose_bc(u(grid)),) | |
c = c_bc.impose_bc(c0(grid)) | |
advect = jax.jit(functools.partial(method, v=v)) | |
ct = advect(c) | |
> self.assertAllClose(ct, dcdt(grid), atol=1e-4) | |
jax_cfd/collocated/advection_test.py:137: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jax_cfd/base/test_util.py:87: in assertAllClose | |
np.testing.assert_allclose(expected, actual, **kwargs) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x14c6a51b0>, array([ 0.01972914, 0.05921721, 0.09869038, 0.13811885...0.25632194, -0.21695682, | |
-0.17755595, -0.13812703, -0.09867631, -0.0592115 , -0.01973734], | |
dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0.0001', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=1e-07, atol=0.0001 | |
E | |
E Mismatched elements: 135 / 1000 (13.5%) | |
E Max absolute difference: 0.00017357 | |
E Max relative difference: 0.00055049 | |
E x: array([ 0.019729, 0.059217, 0.09869 , 0.138119, 0.177547, 0.216946, | |
E 0.256315, 0.295639, 0.334918, 0.374153, 0.413313, 0.452399, | |
E 0.491425, 0.530392, 0.569254, 0.608042, 0.646725, 0.685304,... | |
E y: array([ 0.019739, 0.059214, 0.098679, 0.13813 , 0.177557, 0.216957, | |
E 0.256323, 0.29565 , 0.334929, 0.374154, 0.413322, 0.452422, | |
E 0.491453, 0.530406, 0.569274, 0.608054, 0.646736, 0.685317,... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
______ TimeSteppingTest.test_implicit_solve_harmonic_oscillator_implicit _______ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_implicit_solve_harmonic_oscillator_implicit> | |
implicit_terms = <function <lambda> at 0x15332e440> | |
implicit_solve = <function <lambda> at 0x15332e4d0> | |
initial_state = array([1., 1.]) | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_implicit_solve( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
"""Tests that time integration is accurate for a range of test cases.""" | |
del dt, explicit_terms, inner_steps, outer_steps, closed_form # unused | |
del tolerances # unused | |
# Verifies that `implicit_solve` solves (y - eta * F(y)) = x | |
# This does not test the integrator, but rather verifies that the test | |
# case is valid. | |
eta = 0.3 | |
solved_state = implicit_solve(initial_state, eta) | |
reconstructed_state = solved_state - eta * implicit_terms(solved_state) | |
> np.testing.assert_allclose(reconstructed_state, initial_state) | |
jax_cfd/spectral/time_stepping_test.py:159: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1534de560>, array([0.9999999 , 0.99999994], dtype=float32), array([1., 1.])) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=1e-07, atol=0 | |
E | |
E Mismatched elements: 1 / 2 (50%) | |
E Max absolute difference: 1.1920929e-07 | |
E Max relative difference: 1.1920929e-07 | |
E x: array([1., 1.], dtype=float32) | |
E y: array([1., 1.]) | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
____________ TimeSteppingTest.test_integration_constant_derivative _____________ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_constant_derivative> | |
explicit_terms = <function <lambda> at 0x15332da20> | |
implicit_terms = <function <lambda> at 0x15332dab0> | |
implicit_solve = <function <lambda> at 0x15332db40>, dt = 0.01, inner_steps = 10 | |
outer_steps = 5, initial_state = array([1., 1., 1.]) | |
closed_form = <function <lambda> at 0x15332dbd0> | |
tolerances = [1e-12, 1e-12, 1e-12, 1e-12, 1e-12] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1534df250>, array([[1.5, 1.5, 1.5], | |
[2. , 2. , 2. ], | |
[2...99986, 2.4999986], | |
[2.999998 , 2.999998 , 2.999998 ], | |
[3.4999976, 3.4999976, 3.4999976]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=1e-12', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=1e-12 | |
E | |
E Mismatched elements: 15 / 15 (100%) | |
E Max absolute difference: 2.3841858e-06 | |
E Max relative difference: 6.811964e-07 | |
E x: array([[1.5, 1.5, 1.5], | |
E [2. , 2. , 2. ], | |
E [2.5, 2.5, 2.5],... | |
E y: array([[1.5 , 1.5 , 1.5 ], | |
E [1.999999, 1.999999, 1.999999], | |
E [2.499999, 2.499999, 2.499999],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
________ TimeSteppingTest.test_integration_harmonic_oscillator_explicit ________ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_harmonic_oscillator_explicit> | |
explicit_terms = <function <lambda> at 0x15332e320> | |
implicit_terms = <function zeros_like at 0x1155ec5e0> | |
implicit_solve = <function <lambda> at 0x15332e3b0>, dt = 0.01, inner_steps = 20 | |
outer_steps = 5, initial_state = array([1., 1.]) | |
closed_form = <function harmonic_oscillator at 0x15332d6c0> | |
tolerances = [0.01, 3e-05, 6e-08, 5e-11, 6e-08] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1533a64d0>, array([[ 1.1787359 , 0.7813972 ], | |
[ 1.3104794 , ... [ 1.3899782 , 0.26069334], | |
[ 1.414063 , -0.02064916], | |
[ 1.3817736 , -0.30116856]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=6e-08', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=6e-08 | |
E | |
E Mismatched elements: 9 / 10 (90%) | |
E Max absolute difference: 3.5762787e-07 | |
E Max relative difference: 1.2809022e-05 | |
E x: array([[ 1.178736, 0.781397], | |
E [ 1.310479, 0.531643], | |
E [ 1.389978, 0.260693],... | |
E y: array([[ 1.178736, 0.781397], | |
E [ 1.310479, 0.531643], | |
E [ 1.389978, 0.260693],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
________ TimeSteppingTest.test_integration_harmonic_oscillator_implicit ________ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_harmonic_oscillator_implicit> | |
explicit_terms = <function zeros_like at 0x1155ec5e0> | |
implicit_terms = <function <lambda> at 0x15332e440> | |
implicit_solve = <function <lambda> at 0x15332e4d0>, dt = 0.01, inner_steps = 20 | |
outer_steps = 5, initial_state = array([1., 1.]) | |
closed_form = <function harmonic_oscillator at 0x15332d6c0> | |
tolerances = [0.01, 2e-05, 2e-06, 1e-06, 6e-06] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1535d6e60>, array([[ 1.1787359 , 0.7813972 ], | |
[ 1.3104794 , ... [ 1.3899857 , 0.26069555], | |
[ 1.4140741 , -0.02064833], | |
[ 1.3817878 , -0.30117026]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=2e-06', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=2e-06 | |
E | |
E Mismatched elements: 7 / 10 (70%) | |
E Max absolute difference: 1.4543533e-05 | |
E Max relative difference: 5.331295e-05 | |
E x: array([[ 1.178736, 0.781397], | |
E [ 1.310479, 0.531643], | |
E [ 1.389978, 0.260693],... | |
E y: array([[ 1.178739, 0.781399], | |
E [ 1.310485, 0.531645], | |
E [ 1.389986, 0.260696],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
_________ TimeSteppingTest.test_integration_linear_derivative_explicit _________ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_explicit> | |
explicit_terms = <function <lambda> at 0x15332dc60> | |
implicit_terms = <function <lambda> at 0x15332dcf0> | |
implicit_solve = <function <lambda> at 0x15332dd80>, dt = 0.01, inner_steps = 20 | |
outer_steps = 5, initial_state = array([0., 1., 2.]) | |
closed_form = <function <lambda> at 0x15332de10> | |
tolerances = [0.05, 0.0001, 1e-06, 1e-09, 1e-06] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1534df010>, array([[0. , 1.2214028, 2.4428055], | |
[0. ...21195, 3.644239 ], | |
[0. , 2.225541 , 4.451082 ], | |
[0. , 2.7182808, 5.4365616]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=1e-06', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=1e-06 | |
E | |
E Mismatched elements: 2 / 15 (13.3%) | |
E Max absolute difference: 1.9073486e-06 | |
E Max relative difference: 3.508373e-07 | |
E x: array([[0. , 1.221403, 2.442806], | |
E [0. , 1.491825, 2.983649], | |
E [0. , 1.822119, 3.644238],... | |
E y: array([[0. , 1.221403, 2.442806], | |
E [0. , 1.491825, 2.98365 ], | |
E [0. , 1.822119, 3.644239],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
_________ TimeSteppingTest.test_integration_linear_derivative_implicit _________ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_implicit> | |
explicit_terms = <function <lambda> at 0x15332dea0> | |
implicit_terms = <function <lambda> at 0x15332df30> | |
implicit_solve = <function <lambda> at 0x15332dfc0>, dt = 0.01, inner_steps = 20 | |
outer_steps = 5, initial_state = array([0., 1., 2.]) | |
closed_form = <function <lambda> at 0x15332e050> | |
tolerances = [0.05, 5e-05, 1e-05, 1e-05, 3e-05] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1535d6290>, array([[0. , 1.2214028, 2.4428055], | |
[0. ...21304, 3.644261 ], | |
[0. , 2.2255597, 4.4511194], | |
[0. , 2.7183104, 5.4366207]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=5e-05', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=5e-05 | |
E | |
E Mismatched elements: 1 / 15 (6.67%) | |
E Max absolute difference: 5.722046e-05 | |
E Max relative difference: 1.0525005e-05 | |
E x: array([[0. , 1.221403, 2.442806], | |
E [0. , 1.491825, 2.983649], | |
E [0. , 1.822119, 3.644238],... | |
E y: array([[0. , 1.221406, 2.442811], | |
E [0. , 1.491831, 2.983662], | |
E [0. , 1.82213 , 3.644261],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
______ TimeSteppingTest.test_integration_linear_derivative_semi_implicit _______ | |
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_semi_implicit> | |
explicit_terms = <function <lambda> at 0x15332e0e0> | |
implicit_terms = <function <lambda> at 0x15332e170> | |
implicit_solve = <function <lambda> at 0x15332e200>, dt = 0.01, inner_steps = 20 | |
outer_steps = 5, initial_state = array([0., 1., 2.]) | |
closed_form = <function <lambda> at 0x15332e290> | |
tolerances = [0.0001, 2e-05, 2e-06, 1e-06, 2e-05] | |
@parameterized.named_parameters(ALL_TEST_PROBLEMS) | |
def test_integration( | |
self, | |
explicit_terms, | |
implicit_terms, | |
implicit_solve, | |
dt, | |
inner_steps, | |
outer_steps, | |
initial_state, | |
closed_form, | |
tolerances, | |
): | |
# Compute closed-form solution. | |
time = dt * inner_steps * (1 + np.arange(outer_steps)) | |
expected = jax.vmap(closed_form, in_axes=(None, 0))( | |
initial_state, time) | |
# Compute trajectory using time-stepper. | |
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS): | |
with self.subTest(time_stepper.__name__): | |
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve) | |
semi_implicit_step = time_stepper(equation, dt) | |
integrator = funcutils.trajectory( | |
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps) | |
_, actual = integrator(initial_state) | |
> np.testing.assert_allclose(expected, actual, atol=atol, rtol=0) | |
jax_cfd/spectral/time_stepping_test.py:187: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1534dee60>, array([[0. , 1.2214028, 2.4428055], | |
[0. ...21099, 3.6442199], | |
[0. , 2.2255282, 4.4510565], | |
[0. , 2.7182617, 5.4365234]], dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=2e-06', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=0, atol=2e-06 | |
E | |
E Mismatched elements: 10 / 15 (66.7%) | |
E Max absolute difference: 4.005432e-05 | |
E Max relative difference: 7.367635e-06 | |
E x: array([[0. , 1.221403, 2.442806], | |
E [0. , 1.491825, 2.983649], | |
E [0. , 1.822119, 3.644238],... | |
E y: array([[0. , 1.221401, 2.442801], | |
E [0. , 1.49182 , 2.983639], | |
E [0. , 1.82211 , 3.64422 ],... | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
_ SubgridModelsTest.test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model _ | |
[gw14] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.base.subgrid_models_test.SubgridModelsTest testMethod=test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model> | |
cs = 0.12, velocity = <function sinusoidal_velocity_field at 0x13359bb50> | |
forcing = None, shape = (100, 100), step = (1.0, 1.0), density = 1.0 | |
viscosity = 0.0001, convect = <function convect_linear at 0x1116b0ee0> | |
pressure_solve = <function solve_fast_diag at 0x129bb1c60>, dt = 0.001 | |
time_steps = 1000, divergence_atol = 0.001, momentum_atol = 0.001 | |
@parameterized.named_parameters( | |
dict( | |
testcase_name='sinusoidal_velocity_base', | |
cs=0.0, | |
velocity=sinusoidal_velocity_field, | |
forcing=None, | |
shape=(100, 100), | |
step=(1., 1.), | |
density=1., | |
viscosity=1e-4, | |
convect=advection.convect_linear, | |
pressure_solve=pressure.solve_cg, | |
dt=1e-3, | |
time_steps=1000, | |
divergence_atol=1e-3, | |
momentum_atol=2e-3), | |
dict( | |
testcase_name='gaussian_force_upwind_with_subgrid_model', | |
cs=0.12, | |
velocity=zero_velocity_field, | |
forcing=gaussian_forcing, | |
shape=(40, 40, 40), | |
step=(1., 1., 1.), | |
density=1., | |
viscosity=0, | |
convect=_convect_upwind, | |
pressure_solve=pressure.solve_cg, | |
dt=1e-3, | |
time_steps=100, | |
divergence_atol=1e-4, | |
momentum_atol=1e-4), | |
dict( | |
testcase_name='sinusoidal_velocity_with_subgrid_model', | |
cs=0.12, | |
velocity=sinusoidal_velocity_field, | |
forcing=None, | |
shape=(100, 100), | |
step=(1., 1.), | |
density=1., | |
viscosity=1e-4, | |
convect=advection.convect_linear, | |
pressure_solve=pressure.solve_fast_diag, | |
dt=1e-3, | |
time_steps=1000, | |
divergence_atol=1e-3, | |
momentum_atol=1e-3), | |
) | |
def test_divergence_and_momentum( | |
self, | |
cs, | |
velocity, | |
forcing, | |
shape, | |
step, | |
density, | |
viscosity, | |
convect, | |
pressure_solve, | |
dt, | |
time_steps, | |
divergence_atol, | |
momentum_atol, | |
): | |
grid = grids.Grid(shape, step) | |
kwargs = dict( | |
density=density, | |
viscosity=viscosity, | |
cs=cs, | |
dt=dt, | |
grid=grid, | |
convect=convect, | |
pressure_solve=pressure_solve, | |
forcing=forcing) | |
# Explicit and implicit navier-stokes solvers: | |
explicit_eq = subgrid_models.explicit_smagorinsky_navier_stokes(**kwargs) | |
implicit_eq = subgrid_models.implicit_smagorinsky_navier_stokes(**kwargs) | |
v_initial = velocity(grid) | |
v_final = funcutils.repeated(explicit_eq, time_steps)(v_initial) | |
# TODO(dkochkov) consider adding more thorough tests for these models. | |
with self.subTest('divergence free'): | |
divergence = fd.divergence(v_final) | |
self.assertLess(jnp.max(divergence.data), divergence_atol) | |
with self.subTest('conservation of momentum'): | |
initial_momentum = momentum(v_initial, density) | |
final_momentum = momentum(v_final, density) | |
if forcing is not None: | |
expected_change = ( | |
jnp.array([f.data for f in forcing(v_initial)]).sum() * | |
jnp.array(grid.step).prod() * dt * time_steps) | |
else: | |
expected_change = 0 | |
expected_momentum = initial_momentum + expected_change | |
> self.assertAllClose(expected_momentum, final_momentum, atol=momentum_atol) | |
jax_cfd/base/subgrid_models_test.py:211: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jax_cfd/base/test_util.py:87: in assertAllClose | |
np.testing.assert_allclose(expected, actual, **kwargs) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x147687130>, array(-0.00071716, dtype=float32), array(-0.00175476, dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0.001', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=1e-07, atol=0.001 | |
E | |
E Mismatched elements: 1 / 1 (100%) | |
E Max absolute difference: 0.0010376 | |
E Max relative difference: 0.59130436 | |
E x: array(-0.000717, dtype=float32) | |
E y: array(-0.001755, dtype=float32) | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
_______________ AdvectionTest.test_mass_conservation_van_leer_1D _______________ | |
[gw0] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python | |
self = <jax_cfd.base.advection_test.AdvectionTest testMethod=test_mass_conservation_van_leer_1D> | |
shape = (101,), method = <function _euler_step.<locals>.step at 0x129e8f9a0> | |
@parameterized.named_parameters( | |
dict( | |
testcase_name='van_leer_1D', | |
shape=(101,), | |
method=_euler_step(advection.advect_van_leer)), | |
) | |
def test_mass_conservation(self, shape, method): | |
offset = 0.5 | |
cfl_number = 0.1 | |
dt = cfl_number / shape[0] | |
num_steps = 1000 | |
grid = grids.Grid(shape, domain=([-1., 1.],)) | |
bc = boundaries.dirichlet_boundary_conditions(grid.ndim) | |
c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),)) | |
def u(grid, offset): | |
x = grid.mesh((offset,))[0] | |
return grids.GridArray(-jnp.sin(jnp.pi * x), (offset,), grid) | |
def c0(grid, offset): | |
x = grid.mesh((offset,))[0] | |
return grids.GridArray(x, (offset,), grid) | |
v = (bc.impose_bc(u(grid, 1.)),) | |
c = c_bc.impose_bc(c0(grid, offset)) | |
ct = c | |
advect = jax.jit(functools.partial(method, v=v, dt=dt)) | |
initial_mass = np.sum(c.data) | |
for _ in range(num_steps): | |
ct = advect(ct) | |
current_total_mass = np.sum(ct.data) | |
> self.assertAllClose(current_total_mass, initial_mass, atol=1e-6) | |
jax_cfd/base/advection_test.py:442: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jax_cfd/base/test_util.py:87: in assertAllClose | |
np.testing.assert_allclose(expected, actual, **kwargs) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (<function assert_allclose.<locals>.compare at 0x1468c7010>, array(0., dtype=float32), array(-2.861023e-06, dtype=float32)) | |
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=1e-06', 'verbose': True} | |
@wraps(func) | |
def inner(*args, **kwds): | |
with self._recreate_cm(): | |
> return func(*args, **kwds) | |
E AssertionError: | |
E Not equal to tolerance rtol=1e-07, atol=1e-06 | |
E | |
E Mismatched elements: 1 / 1 (100%) | |
E Max absolute difference: 2.861023e-06 | |
E Max relative difference: 1. | |
E x: array(0., dtype=float32) | |
E y: array(-2.861023e-06, dtype=float32) | |
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError | |
=============================== warnings summary =============================== | |
venv/lib/python3.10/site-packages/jax/_src/pjit.py:288: 16 warnings | |
/Users/djk21/coding/jax-cfd/venv/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information. | |
warnings.warn( | |
jax_cfd/base/grids_test.py::GridArrayTest::test_tree_util | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/grids_test.py:32: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten. | |
flat, treedef = jax.tree_flatten(array) | |
jax_cfd/base/grids_test.py::GridArrayTest::test_tree_util | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/grids_test.py:33: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten. | |
roundtripped = jax.tree_unflatten(treedef, flat) | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd3 | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd4 | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd5 | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_poisson_1d1 | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_poisson_2d_fft0 | |
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_random_1d_fft0 | |
/Users/djk21/coding/jax-cfd/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py:511: ComplexWarning: Casting complex values to real discards the imaginary part | |
return _convert_element_type(operand, new_dtype, weak_type=False) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes3 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:118: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten. | |
arrays, tree_def = jax.tree_flatten(inputs) | |
jax_cfd/base/array_utils_test.py: 74 warnings | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:127: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten. | |
return tuple(jax.tree_unflatten(tree_def, leaves) for leaves in splits) | |
jax_cfd/base/resize_test.py: 32 warnings | |
jax_cfd/base/array_utils_test.py: 12 warnings | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:60: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten. | |
arrays, tree_def = jax.tree_flatten(inputs) | |
jax_cfd/base/resize_test.py: 32 warnings | |
jax_cfd/base/array_utils_test.py: 10 warnings | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:71: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten. | |
return jax.tree_unflatten(tree_def, sliced) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:137: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves. | |
self.assertEqual(jax.tree_leaves(split_a)[0].shape[axis], idx) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:141: DeprecationWarning: jax.tree_structure is deprecated: use jax.tree_util.tree_structure. | |
actual_tree_def = jax.tree_structure(reconstruction) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:142: DeprecationWarning: jax.tree_structure is deprecated: use jax.tree_util.tree_structure. | |
expected_tree_def = jax.tree_structure(pytree) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:145: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves. | |
actual_values = jax.tree_leaves(reconstruction) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:146: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves. | |
expected_values = jax.tree_leaves(pytree) | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:160: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves. | |
actual_shape = jax.tree_leaves(double_concat)[0].shape[axis] | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3 | |
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4 | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:161: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves. | |
expected_shape = jax.tree_leaves(pytree)[0].shape[axis] * 2 | |
jax_cfd/base/subgrid_models_test.py: 15 warnings | |
/Users/djk21/coding/jax-cfd/jax_cfd/base/subgrid_models.py:98: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten. | |
return jax.tree_unflatten(jax.tree_util.tree_structure(s_ij), viscosities) | |
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html | |
=========================== short test summary info ============================ | |
FAILED jax_cfd/collocated/advection_test.py::AdvectionTest::test_mass_conservation_dirichlet_dichlet_advect | |
FAILED jax_cfd/collocated/advection_test.py::AdvectionTest::test_neumann_bc_one_step_linear_1d_neumann | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_implicit_solve_harmonic_oscillator_implicit | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_constant_derivative | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_harmonic_oscillator_explicit | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_harmonic_oscillator_implicit | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_explicit | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_implicit | |
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_semi_implicit | |
FAILED jax_cfd/base/subgrid_models_test.py::SubgridModelsTest::test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model | |
FAILED jax_cfd/base/advection_test.py::AdvectionTest::test_mass_conservation_van_leer_1D | |
ERROR jax_cfd/ml/equations_test.py | |
ERROR jax_cfd/ml/layers_test.py | |
ERROR jax_cfd/ml/layers_util_test.py | |
ERROR jax_cfd/ml/towers_test.py | |
====== 11 failed, 447 passed, 242 warnings, 4 errors in 118.92s (0:01:58) ====== |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment