Last active
June 25, 2023 22:57
-
-
Save cako/2838f56f7e63c2164057acdeac500e55 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pytest | |
# Test multiple dtypes at once
@pytest.mark.parametrize(
    "dtype", ["int8", "int16", "float16", "float32", "float64", "float128"]
)
def test_zero(dtype):
    """Check that ``sobel`` maps an all-zero 2D array to (numerically) zero.

    The array shape is drawn at random. The seed used is included in the
    failure message so a failing case can be reproduced.

    Parameters
    ----------
    dtype : str
        NumPy dtype name of the input array. Dtypes the current platform
        does not provide (``float128`` is missing on e.g. Windows) are
        skipped rather than reported as errors.
    """
    # ``float128`` only exists on some platforms; skip instead of erroring.
    try:
        np.dtype(dtype)
    except TypeError:
        pytest.skip(f"dtype {dtype} is not available on this platform")

    # Draw a seed from the current global state, then reseed with it so the
    # exact case can be replayed from the logged value.
    seed = int(np.random.rand() * (2**32 - 1))
    np.random.seed(seed)
    # Create a 2D array of random shape and fill with zeros
    nx, ny = np.random.randint(3, 100, size=(2,))
    arr = np.zeros((nx, ny), dtype=dtype)
    # Apply sobel function
    arr_sob = sobel(arr)
    # `assert_array_equal` should fail most of the time.
    # It only works when `arr_sob` is identically zero,
    # which is usually not the case. Don't use it!
    # np.testing.assert_array_equal(arr_sob, 0.0, err_msg=f"{seed=} {nx=}, {ny=}")
    # `assert_almost_equal` can fail when used with high decimals.
    # It also relies on float64 checking, which might fail for
    # float128 types. I would avoid using it.
    # np.testing.assert_almost_equal(
    #     arr_sob, np.zeros_like(arr), err_msg=f"{seed=} {nx=}, {ny=}", decimal=40
    # )
    # `assert_allclose` with custom tolerance is my preferred method.
    # The 10 is arbitrary and depends on the problem. If a method
    # which you know to be correct does not pass, increase to 100, etc.
    # If the tolerance needed to make the tests pass is too high, make
    # sure the method is actually correct.
    if np.issubdtype(arr.dtype, np.inexact):
        tol = 10 * np.finfo(arr.dtype).eps
    else:
        # `np.finfo` rejects integer dtypes (it raises ValueError). Integer
        # zeros carry no rounding error, so demand an exact result.
        tol = 0
    err_msg = f"{seed=} {nx=}, {ny=} {tol=}"  # Log seeds and other info
    np.testing.assert_allclose(
        arr_sob,
        np.zeros_like(arr),
        err_msg=err_msg,
        atol=tol,  # rtol is useless for desired=zeros
    )
@pytest.mark.parametrize(
    "dtype", ["int8", "int16", "float16", "float32", "float64", "float128"]
)
def test_constant(dtype):
    """Check that ``sobel`` maps a constant 2D array to (numerically) zero.

    Both the array shape and the constant value are drawn at random. The
    seed used is included in the failure message so a failing case can be
    reproduced.

    Parameters
    ----------
    dtype : str
        NumPy dtype name of the input array. Dtypes the current platform
        does not provide (``float128`` is missing on e.g. Windows) are
        skipped rather than reported as errors.
    """
    # ``float128`` only exists on some platforms; skip instead of erroring.
    try:
        np.dtype(dtype)
    except TypeError:
        pytest.skip(f"dtype {dtype} is not available on this platform")

    # Draw a seed from the current global state, then reseed with it so the
    # exact case can be replayed from the logged value.
    seed = int(np.random.rand() * (2**32 - 1))
    np.random.seed(seed)
    nx, ny = np.random.randint(3, 100, size=(2,))
    # For integer dtypes `np.full` casts this float to an integer value.
    constant = np.random.randn(1).item()
    arr = np.full((nx, ny), fill_value=constant, dtype=dtype)
    arr_sob = sobel(arr)
    if np.issubdtype(arr.dtype, np.inexact):
        # Absolute rounding error grows with the magnitude of the inputs, so
        # scale the machine-epsilon tolerance by |constant| (never below 1).
        tol = 10 * np.finfo(arr.dtype).eps * max(1.0, abs(constant))
    else:
        # `np.finfo` rejects integer dtypes (it raises ValueError). Integer
        # values are exact, so demand an exactly zero result.
        tol = 0
    err_msg = f"{seed=} {nx=}, {ny=} {tol=}"
    np.testing.assert_allclose(
        arr_sob,
        np.zeros_like(arr),
        err_msg=err_msg,
        atol=tol,  # rtol is useless for desired=zeros
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment