Skip to content

Instantly share code, notes, and snippets.

@cako
Last active June 25, 2023 22:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cako/2838f56f7e63c2164057acdeac500e55 to your computer and use it in GitHub Desktop.
Save cako/2838f56f7e63c2164057acdeac500e55 to your computer and use it in GitHub Desktop.
import numpy as np
import pytest
# Test multiple dtypes at once. "longdouble" replaces the "float128" alias,
# which only exists on platforms where long double is 128 bits wide (it is
# missing on e.g. Windows and Apple-silicon macOS); on x86-64 Linux both
# names refer to the same dtype.
@pytest.mark.parametrize(
    "dtype", ["int8", "int16", "float16", "float32", "float64", "longdouble"]
)
def test_zero(dtype):
    """Check that ``sobel`` maps an all-zeros 2D array to zero.

    The array shape is drawn at random; the seed is included in the failure
    message so the exact case can be replayed with
    ``np.random.default_rng(seed)``.
    """
    # Fresh OS entropy plus a local Generator: the test neither relies on
    # nor reseeds NumPy's global random state, yet stays reproducible.
    seed = np.random.SeedSequence().entropy
    rng = np.random.default_rng(seed)

    # Create a 2D array of random shape (each side in [3, 100)) filled with
    # zeros. `.tolist()` yields plain ints, keeping the failure message tidy.
    nx, ny = rng.integers(3, 100, size=2).tolist()
    arr = np.zeros((nx, ny), dtype=dtype)

    # Apply the function under test. NOTE(review): `sobel` is not defined in
    # this view; presumably it is imported/defined elsewhere — confirm.
    arr_sob = sobel(arr)

    # Choosing the assertion:
    # - `assert_array_equal` demands exact equality: it only passes when
    #   `arr_sob` is identically zero, which is brittle for floating-point
    #   results. Avoid it:
    #   np.testing.assert_array_equal(arr_sob, 0.0, err_msg=err_msg)
    # - `assert_almost_equal` uses a fixed decimal threshold (evaluated in
    #   float64) that is unrelated to the dtype's precision and can fail for
    #   high `decimal` values or extended-precision (longdouble) data:
    #   np.testing.assert_almost_equal(arr_sob, np.zeros_like(arr),
    #                                  err_msg=err_msg, decimal=40)
    # - `assert_allclose` with a dtype-aware absolute tolerance is preferred.
    #   The factor 10 is arbitrary and problem dependent: if a method known
    #   to be correct fails, increase it (100, ...). If the tolerance needed
    #   to pass becomes too large, make sure the method is actually correct.
    if np.issubdtype(arr.dtype, np.inexact):
        tol = 10 * np.finfo(arr.dtype).eps
    else:
        # `np.finfo` raises ValueError for integer dtypes, and an all-zero
        # integer input should produce an exactly zero result anyway.
        tol = 0
    err_msg = f"{seed=} {nx=}, {ny=} {tol=}"  # Log seeds and other info
    np.testing.assert_allclose(
        arr_sob,
        np.zeros_like(arr),
        err_msg=err_msg,
        atol=tol,  # rtol is useless for desired=zeros
    )
@pytest.mark.parametrize(
    "dtype", ["int8", "int16", "float16", "float32", "float64", "longdouble"]
)
def test_constant(dtype):
    """Check that ``sobel`` maps a constant 2D array to (numerically) zero.

    The shape and the constant are drawn at random; the seed is included in
    the failure message so the case can be replayed with
    ``np.random.default_rng(seed)``. "longdouble" is used instead of the
    non-portable "float128" alias.
    """
    # Local Generator seeded from fresh OS entropy: reproducible via the
    # logged seed, and NumPy's global random state is left untouched.
    seed = np.random.SeedSequence().entropy
    rng = np.random.default_rng(seed)
    nx, ny = rng.integers(3, 100, size=2).tolist()
    constant = float(rng.standard_normal())
    # For integer dtypes the float constant is cast (truncated) into the
    # array, so the stored value is often 0; that is still a valid constant.
    arr = np.full((nx, ny), fill_value=constant, dtype=dtype)
    # NOTE(review): `sobel` is not defined in this view; presumably it is
    # imported/defined elsewhere — confirm.
    arr_sob = sobel(arr)
    if np.issubdtype(arr.dtype, np.inexact):
        # Rounding error in the stencil sums grows with the magnitude of the
        # data, so scale the absolute tolerance by |constant| (never below 1).
        tol = 10 * np.finfo(arr.dtype).eps * max(1.0, abs(constant))
    else:
        # `np.finfo` raises ValueError for integer dtypes; a constant integer
        # input should yield an exactly zero result.
        tol = 0
    err_msg = f"{seed=} {nx=}, {ny=} {tol=}"
    np.testing.assert_allclose(
        arr_sob,
        np.zeros_like(arr),
        err_msg=err_msg,
        atol=tol,  # rtol is useless for desired=zeros
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment