Skip to content

Instantly share code, notes, and snippets.

@tobywf
Created April 9, 2017 16:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tobywf/b0f67e9e414e41813c1590ed1ea39e2c to your computer and use it in GitHub Desktop.
Save tobywf/b0f67e9e414e41813c1590ed1ea39e2c to your computer and use it in GitHub Desktop.
Decimal rounding up or down to a multiple of a given quantum (plus unit tests)
from decimal import Decimal, ROUND_CEILING, ROUND_FLOOR
# Decimal one: coefficient 1, exponent 0.
ONE = Decimal("1")
def round_up(number, quantum):
    """Round a decimal number up to the nearest multiple of quantum.

    "Up" means towards positive infinity: the result is the smallest
    integer multiple of ``abs(quantum)`` that is greater than or equal to
    ``number``.  ``quantum`` and ``-quantum`` have the same multiples, so
    the sign of ``quantum`` is ignored.

    Args:
        number: the ``Decimal`` value to round.
        quantum: the non-zero, finite step, as a ``Decimal`` or ``int``.

    Returns:
        ``k * abs(quantum)`` for some integer ``k``, as a ``Decimal``.  For a
        ``Decimal`` quantum the result typically carries quantum's exponent
        (e.g. ``1.30`` for quantum ``0.10``).  A quiet NaN ``number``
        propagates as NaN.

    Raises:
        decimal.DivisionByZero, decimal.InvalidOperation: (under the default
        traps) if ``quantum`` is zero, ``number`` is infinite, or the integer
        part of ``number / quantum`` needs more digits than the current
        context precision.
    """
    # copy_abs() is exact, whereas abs() would round a long quantum to the
    # context precision.  Plain ints fall back to the builtin abs().
    step = quantum.copy_abs() if isinstance(quantum, Decimal) else abs(quantum)
    # Don't divide and then round: number / quantum is rounded to the context
    # precision and can land exactly on an integer even when number is not a
    # multiple (e.g. a 29-digit 1.000...0001 at prec=28), yielding a result
    # below number.  divmod() gives the exact integer quotient (truncated
    # towards zero) and a remainder carrying number's sign that is zero only
    # when number is an exact multiple of step.
    quotient, remainder = divmod(number, step)
    # A positive remainder means number lies strictly above quotient * step,
    # so move up one more step.  Tested via truth value and is_signed()
    # rather than "> 0" so a NaN propagates instead of raising on comparison.
    if remainder and not remainder.is_signed():
        quotient += 1
    return quotient * step
def round_down(number, quantum):
    """Round a decimal number down to the nearest multiple of quantum.

    "Down" means towards negative infinity: the result is the largest
    integer multiple of ``abs(quantum)`` that is less than or equal to
    ``number``.  ``quantum`` and ``-quantum`` have the same multiples, so
    the sign of ``quantum`` is ignored.

    Args:
        number: the ``Decimal`` value to round.
        quantum: the non-zero, finite step, as a ``Decimal`` or ``int``.

    Returns:
        ``k * abs(quantum)`` for some integer ``k``, as a ``Decimal``.  For a
        ``Decimal`` quantum the result typically carries quantum's exponent
        (e.g. ``1.20`` for quantum ``0.10``).  A quiet NaN ``number``
        propagates as NaN.

    Raises:
        decimal.DivisionByZero, decimal.InvalidOperation: (under the default
        traps) if ``quantum`` is zero, ``number`` is infinite, or the integer
        part of ``number / quantum`` needs more digits than the current
        context precision.
    """
    # copy_abs() is exact, whereas abs() would round a long quantum to the
    # context precision.  Plain ints fall back to the builtin abs().
    step = quantum.copy_abs() if isinstance(quantum, Decimal) else abs(quantum)
    # Don't divide and then round: number / quantum is rounded to the context
    # precision and can land exactly on an integer even when number is not a
    # multiple (e.g. a 29-digit 1.999...9 at prec=28 rounds to 2), yielding a
    # result above number.  divmod() gives the exact integer quotient
    # (truncated towards zero) and a remainder carrying number's sign that is
    # zero only when number is an exact multiple of step.
    quotient, remainder = divmod(number, step)
    # A negative remainder means number lies strictly below quotient * step
    # (truncation went towards zero, i.e. upwards), so move down one step.
    # Tested via truth value and is_signed() rather than "< 0" so a NaN
    # propagates instead of raising on comparison.
    if remainder and remainder.is_signed():
        quotient -= 1
    return quotient * step
import unittest
from decimal import Decimal, BasicContext, localcontext
# Decimal one: coefficient 1, exponent 0.
ONE = Decimal("1")
class RoundTestCase:
    """Shared checks for a rounding function, mixed into a ``unittest.TestCase``.

    A concrete subclass supplies three attributes:

    * ``method(number, quantum)`` -- the rounding function under test;
    * ``under(multiple)`` -- the integer ``k`` such that rounding a value
      slightly below ``multiple * quantum`` should give ``k * quantum``;
    * ``over(multiple)`` -- the same for a value slightly above.
    """

    def test_unity(self):
        # 1 is already a multiple of 1, so it must come back unchanged.
        self.assertEqual(self.method(ONE, ONE), ONE)

    def test_epsilon(self):
        def _inclusive(lo, hi, stride=1):
            """Like range(), but ``hi`` itself is included."""
            return range(lo, hi + stride, stride)

        with localcontext(BasicContext) as ctx:
            for precision in _inclusive(1, 19):
                # multiple * quantum has at most two digits (9 * 9 == 81) and
                # epsilon sits ``precision`` places below quantum's exponent,
                # so precision + 2 digits let fma() build the inputs exactly.
                ctx.prec = precision + 2
                for exponent in _inclusive(-10, 10):
                    # epsilon == 10 ** (exponent - precision)
                    epsilon = Decimal((0, (1,), exponent - precision))
                    for digit in _inclusive(1, 9):
                        # quantum == digit * 10 ** exponent
                        quantum = Decimal((0, (digit,), exponent))
                        for multiple in _inclusive(-9, 9):
                            labels = dict(
                                precision=precision,
                                exponent=exponent,
                                digit=digit,
                                multiple=multiple,
                            )
                            with self.subTest(**labels):
                                # Just below multiple * quantum.
                                below = quantum.fma(multiple, -epsilon)
                                self.assertEqual(
                                    self.method(below, quantum),
                                    self.under(multiple) * quantum,
                                )
                                # Just above multiple * quantum.
                                above = quantum.fma(multiple, epsilon)
                                self.assertEqual(
                                    self.method(above, quantum),
                                    self.over(multiple) * quantum,
                                )
class RoundUpTestCase(unittest.TestCase, RoundTestCase):
    """Runs the shared rounding checks against ``round_up``."""

    method = staticmethod(round_up)

    @staticmethod
    def under(multiple):
        # A value just below k * quantum rounds up onto k * quantum itself.
        return multiple

    @staticmethod
    def over(multiple):
        # A value just above k * quantum rounds up to the next multiple.
        return 1 + multiple
class RoundDownTestCase(unittest.TestCase, RoundTestCase):
    """Runs the shared rounding checks against ``round_down``."""

    method = staticmethod(round_down)

    @staticmethod
    def under(multiple):
        # A value just below k * quantum rounds down to the previous multiple.
        return -1 + multiple

    @staticmethod
    def over(multiple):
        # A value just above k * quantum rounds down onto k * quantum itself.
        return multiple
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment