@kmdouglass
Created August 26, 2022 18:49
Round a list of floats to ints so that the sum of the ints is the integer portion of the sum of the floats
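Rounding each value on its own can undershoot or overshoot the integer part of the total. The illustration below shows both failures. The values are arbitrary and exactly representable in binary, so the float sums are exact.

```python
import math

# Four values that each round down, although their sum has integer part 1
values = [0.375, 0.375, 0.375, 0.375]
target = math.floor(sum(values))          # 1
naive = [round(v) for v in values]        # [0, 0, 0, 0], sums to 0

# Four values that each round up, overshooting the integer part 2
values_up = [0.625, 0.625, 0.625, 0.625]
target_up = math.floor(sum(values_up))    # 2
naive_up = [round(v) for v in values_up]  # [1, 1, 1, 1], sums to 4

print(target, sum(naive))        # prints "1 0"
print(target_up, sum(naive_up))  # prints "2 4"
```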
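```python
import numpy as np
import numpy.typing as npt


def safe_round(array: npt.ArrayLike, total: int) -> np.ndarray:
    """Rounds an array of floats, maintaining their integer sum.

    ``total`` is the desired sum, normally the integer portion of ``np.sum(array)``.
    """
    array = np.asanyarray(array)
    # Round the array to the nearest integer
    rounded_array: np.ndarray = np.rint(array)
    error = total - np.sum(rounded_array)
    if error == 0:
        return rounded_array
    # The number of elements to adjust. Each element after rounding is within 0.5 of its
    # original value, so no element needs to change by more than 1.
    n = int(np.abs(error))
    if n > array.size:
        raise ValueError("total cannot be reached by adjusting each element by at most 1")
    # Direction of the correction: +1 if the rounded sum is too small, -1 if too large
    step = np.copysign(1, error)
    # np.argsort() returns the indices that would sort an array. Multiplying the rounding
    # residuals by -step puts the elements rounded down the most first when step is +1,
    # and the elements rounded up the most first when step is -1.
    sorted_index_array = np.argsort(-step * (array - rounded_array), axis=None)
    # Add +/- 1 to the n elements of rounded_array with the largest rounding errors
    safe_rounded_array = rounded_array.flatten()
    safe_rounded_array[sorted_index_array[0:n]] += step
    return safe_rounded_array.reshape(array.shape)
```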
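A closely related technique is the largest remainder method. It rounds every value down, then gives the missing units to the elements with the largest fractional parts. The pure-Python sketch below illustrates that method and is not the gist's function. The name `largest_remainder_round` is hypothetical. Unlike `safe_round`, it computes the target itself as `floor(sum(values))`.

```python
import math
from typing import List


def largest_remainder_round(values: List[float]) -> List[int]:
    """Hypothetical helper: round each value down or up so the results sum to floor(sum(values))."""
    # Start by rounding every value down
    floors = [math.floor(v) for v in values]
    # math.fsum avoids accumulating floating-point error in the total
    target = math.floor(math.fsum(values))
    # Number of elements that must be bumped up by 1 to reach the target
    shortfall = target - sum(floors)
    # Indices ordered by fractional part, largest first (ties keep input order)
    order = sorted(range(len(values)), key=lambda i: values[i] - floors[i], reverse=True)
    result = list(floors)
    for i in order[:shortfall]:
        result[i] += 1
    return result


print(largest_remainder_round([1.25, 2.75, 3.125]))  # prints [1, 3, 3]
```

Every element starts at its floor, so in exact arithmetic the shortfall lies between 0 and `len(values) - 1`. Each result also differs from its input by less than 1, so no bounds check is needed. The trade-off is that an element can move by up to 1 rather than at most 0.5 plus a single correction.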