Skip to content

Instantly share code, notes, and snippets.

@pirate
Last active July 18, 2023 22:08
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pirate/cc8e770eaf1ddb346d72e2e2c406c077 to your computer and use it in GitHub Desktop.
Save pirate/cc8e770eaf1ddb346d72e2e2c406c077 to your computer and use it in GitHub Desktop.
A SafeNumber type for Python that mirrors the fractions.Fraction interface, with guards that reject implicit operand type casting that could cause a loss of precision.
# This implements a SafeNumber class which stores an exact Fraction (accepting int, str,
# Decimal and Fraction inputs) and raises TypeError when infix math or comparison
# operators would otherwise perform a dangerous implicit type conversion.
#
# Implicit type conversion when using operators is sneaky with Decimal/Fraction:
# >>> Fraction(10) == 10.0000000000000001
# True
#
# But with SafeNumber, this throws an error to protect against this scenario:
# >>> SafeNumber(10) == 10.0000000000000001
# Traceback (most recent call last):
# ...
# TypeError: Invalid operand type, operands can only be of type Union[int, str, Fraction, Decimal, SafeNumber]
#
# For more info on math safety and data integrity in Python see here:
# https://github.com/pirate/django-concurrency-talk/blob/master/README.md#-the-correct-approach-with-a-custom-safenumber-type
from typing import Union, Any
from decimal import Decimal
from fractions import Fraction
# Operand types SafeNumber accepts in addition to SafeNumber itself.
# float is deliberately absent so it can never be implicitly converted.
SAFE_NUM_TYPES = (int, str, Fraction, Decimal)
# Human-readable spelling of the accepted types, interpolated into TypeError
# messages. Derived from SAFE_NUM_TYPES so the two can never drift apart.
SAFE_NUM_TYPES_STR = 'Union[' + ', '.join(t.__name__ for t in SAFE_NUM_TYPES) + ', SafeNumber]'
class SafeNumber:
    """A safe number type that protects against implicit type casting.

    The value is held as an exact ``fractions.Fraction``. Constructor
    arguments and the other operand of every comparison or arithmetic
    operator must be an ``int``, ``str``, ``Fraction``, ``Decimal`` or
    ``SafeNumber``; anything else (notably ``float``) raises ``TypeError``
    instead of being silently converted.
    """

    _value: Fraction

    def __init__(
        self,
        numerator: Union[int, str, Fraction, Decimal, 'SafeNumber'],
        denominator: Union[int, str, Fraction, Decimal, 'SafeNumber'] = 1,
    ):
        """Create the exact value ``numerator / denominator``.

        Raises:
            TypeError: if either argument is not one of the accepted types.
            Errors raised by ``Fraction`` itself propagate unchanged, e.g.
            ``ValueError`` for an unparsable string or ``ZeroDivisionError``
            for a zero denominator.
        """
        if isinstance(numerator, SafeNumber):
            numerator = numerator._value
        if isinstance(denominator, SafeNumber):
            denominator = denominator._value
        if not isinstance(numerator, SAFE_NUM_TYPES) or not isinstance(denominator, SAFE_NUM_TYPES):
            raise TypeError(
                f'To ensure correctness, SafeNumbers can only be instantiated '
                f'with values of type {SAFE_NUM_TYPES_STR}'
            )
        # Convert each side to Fraction separately so str/Decimal inputs are
        # parsed exactly before the ratio is formed.
        self._value = Fraction(Fraction(numerator), Fraction(denominator))  # type: ignore

    def __repr__(self):
        return self._value.__repr__().replace('Fraction', 'SafeNumber')

    def __str__(self):
        return self._value.__str__()

    def as_fraction(self) -> Fraction:
        """Return the underlying exact Fraction."""
        return self._value

    def as_decimal(self) -> Decimal:
        """Return the value as a Decimal.

        The division runs in the current ``decimal`` context, so values with
        a non-terminating expansion are rounded to that context's precision.
        """
        return Decimal(self._value.numerator) / Decimal(self._value.denominator)

    def as_float(self) -> float:
        """Refuse conversion to float, which could silently lose precision."""
        raise Exception("Hahah nice try. Floats are the devil.")

    def _safe_param(self, value) -> 'SafeNumber':
        """Coerce an operand to SafeNumber, raising TypeError for unsafe types."""
        if isinstance(value, SafeNumber):
            return value
        elif isinstance(value, SAFE_NUM_TYPES):
            return SafeNumber(value)
        else:
            raise TypeError(
                'Invalid operand type, operands can only be of type '
                f'{SAFE_NUM_TYPES_STR}'
            )

    # Comparisons are only safe with other Fraction/SafeNumbers
    # >>> Fraction(10) == 10.0000000000000001
    # True
    # >>> Fraction(10) < 10.0000000000000001
    # False
    # The comparison operators are overloaded below to guard against this.
    def __hash__(self):
        # Same hash as the wrapped Fraction, consistent with __eq__ comparing
        # equal to int/Fraction/Decimal values of the same magnitude.
        return self._value.__hash__()

    def __bool__(self) -> bool:
        # Mirror Fraction/int truthiness (zero is falsy). Without this method
        # every instance, including SafeNumber(0), would be truthy.
        return bool(self._value)

    def __eq__(self, other: Any):
        other = self._safe_param(other)
        return self._value == other._value

    def __lt__(self, other: Any):
        other = self._safe_param(other)
        return self._value < other._value

    def __le__(self, other: Any):
        other = self._safe_param(other)
        return self._value <= other._value

    def __gt__(self, other: Any):
        other = self._safe_param(other)
        return self._value > other._value

    def __ge__(self, other: Any):
        other = self._safe_param(other)
        return self._value >= other._value

    # Math operations are only safe with other Fraction/SafeNumbers
    # >>> Decimal(Fraction(10) * 0.3000000000000001)
    # Decimal('3.00000000000000088817841970012523233890533447265625')
    # The math operators are overloaded below to guard against this.
    # Reflected (__r*__) methods receive the LEFT operand as `other`, so the
    # operand order is swapped relative to their forward counterparts.
    def __mod__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value % other._value)

    def __rmod__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value % self._value)

    def __mul__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value * other._value)

    def __rmul__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value * self._value)

    def __truediv__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value / other._value)

    def __rtruediv__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value / self._value)

    def __floordiv__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value // other._value)

    def __rfloordiv__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value // self._value)

    def __divmod__(self, other: Any) -> 'tuple[SafeNumber, SafeNumber]':
        # divmod yields a (quotient, remainder) pair; wrap each element
        # rather than the tuple itself (SafeNumber(tuple) raises TypeError).
        other = self._safe_param(other)
        quotient, remainder = divmod(self._value, other._value)
        return SafeNumber(quotient), SafeNumber(remainder)

    def __rdivmod__(self, other: Any) -> 'tuple[SafeNumber, SafeNumber]':
        other = self._safe_param(other)
        quotient, remainder = divmod(other._value, self._value)
        return SafeNumber(quotient), SafeNumber(remainder)

    def __add__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value + other._value)

    def __radd__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value + self._value)

    def __sub__(self, other: Any) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value - other._value)

    def __rsub__(self, other: Any) -> 'SafeNumber':
        # Called for `other - self`; the previous implementation computed
        # `self - other`, flipping the sign of the result.
        other = self._safe_param(other)
        return SafeNumber(other._value - self._value)

    def __neg__(self) -> 'SafeNumber':
        return SafeNumber(-self._value)

    def __pos__(self) -> 'SafeNumber':
        return SafeNumber(+self._value)

    def __abs__(self) -> 'SafeNumber':
        return SafeNumber(abs(self._value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment