Last active
July 18, 2023 22:08
-
-
Save pirate/cc8e770eaf1ddb346d72e2e2c406c077 to your computer and use it in GitHub Desktop.
A SafeNumber type for Python that implements the fractions.Fraction interface with guards to prevent implicit operand type casting leading to a loss of precision.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This implements a SafeNumber class that wraps Fraction (and accepts Decimal inputs)
# and raises a TypeError whenever an infix math or comparison operator would otherwise
# perform a dangerous implicit type conversion (e.g. from float).
#
# Implicit type conversion when using operators is sneaky with Decimal/Fraction.
# The float literal below silently rounds to exactly 10.0, so the comparison succeeds:
# >>> Fraction(10) == 10.0000000000000001
# True
#
# With SafeNumber, the same comparison raises an error instead:
# >>> SafeNumber(10) == 10.0000000000000001
# Traceback (most recent call last):
# ...
# TypeError: Invalid operand type, operands can only be of type Union[int, str, Fraction, Decimal, SafeNumber]
#
# For more on math safety and data integrity in Python, see:
# https://github.com/pirate/django-concurrency-talk/blob/master/README.md#-the-correct-approach-with-a-custom-safenumber-type
from decimal import Decimal
from fractions import Fraction
from typing import Any, Tuple, Union
SAFE_NUM_TYPES = (int, str, Fraction, Decimal)
SAFE_NUM_TYPES_STR = 'Union[int, str, Fraction, Decimal, SafeNumber]'

# Every operand type SafeNumber accepts; anything else (notably float) is rejected.
_SafeOperand = Union[int, str, Fraction, Decimal, 'SafeNumber']


class SafeNumber:
    """A safe number type that protects against implicit type casting.

    The value is held as an exact ``fractions.Fraction``. Construction and
    every overloaded operator accept only ``int``, ``str``, ``Fraction``,
    ``Decimal`` or ``SafeNumber`` operands; anything else (notably ``float``)
    raises ``TypeError`` instead of being silently coerced. Arithmetic
    results are returned as new ``SafeNumber`` instances.
    """

    _value: Fraction

    def __init__(self, numerator: _SafeOperand, denominator: _SafeOperand = 1):
        """Create the exact rational value ``numerator / denominator``.

        Raises:
            TypeError: if either argument is not one of the safe types.
            ValueError: if a ``str`` argument is not a valid number.
            ZeroDivisionError: if ``denominator`` is zero.
        """
        if isinstance(numerator, SafeNumber):
            numerator = numerator._value
        if isinstance(denominator, SafeNumber):
            denominator = denominator._value
        if not isinstance(numerator, SAFE_NUM_TYPES) or not isinstance(denominator, SAFE_NUM_TYPES):
            raise TypeError(
                'To ensure correctness, SafeNumbers can only be instantiated '
                f'with values of type {SAFE_NUM_TYPES_STR}'
            )
        # Convert each part to a Fraction first so str/Decimal inputs are
        # parsed exactly, then divide exactly.
        self._value = Fraction(Fraction(numerator), Fraction(denominator))  # type: ignore

    def __repr__(self) -> str:
        # Mirrors Fraction's repr, e.g. ``SafeNumber(1, 3)``.
        return f'SafeNumber({self._value.numerator}, {self._value.denominator})'

    def __str__(self) -> str:
        return str(self._value)

    def as_fraction(self) -> Fraction:
        """Return the underlying exact Fraction."""
        return self._value

    def as_decimal(self) -> Decimal:
        """Return the value as a Decimal.

        The division is performed by ``decimal``, so non-terminating values
        (e.g. 1/3) are rounded to the current decimal context precision.
        """
        return Decimal(self._value.numerator) / Decimal(self._value.denominator)

    def as_float(self) -> float:
        """Deliberately unsupported: conversion to float would lose precision."""
        raise Exception("Hahah nice try. Floats are the devil.")

    def _safe_param(self, value) -> 'SafeNumber':
        """Coerce a safe operand to SafeNumber, or raise TypeError for anything else."""
        if isinstance(value, SafeNumber):
            return value
        elif isinstance(value, SAFE_NUM_TYPES):
            return SafeNumber(value)
        else:
            raise TypeError(
                'Invalid operand type, operands can only be of type '
                f'{SAFE_NUM_TYPES_STR}'
            )

    # Comparisons are only safe with other Fraction/SafeNumbers
    # >>> Fraction(10) == 10.0000000000000001
    # True
    # >>> Fraction(10) < 10.0000000000000001
    # False
    # The comparison operators are overloaded below to guard against this.
    # Note: unsupported operands raise TypeError (even for ==) rather than
    # returning False/NotImplemented.

    def __hash__(self) -> int:
        # Same hash as the equal Fraction/int, consistent with __eq__.
        return hash(self._value)

    def __eq__(self, other: Any) -> bool:
        other = self._safe_param(other)
        return self._value == other._value

    def __lt__(self, other: _SafeOperand) -> bool:
        other = self._safe_param(other)
        return self._value < other._value

    def __le__(self, other: _SafeOperand) -> bool:
        other = self._safe_param(other)
        return self._value <= other._value

    def __gt__(self, other: _SafeOperand) -> bool:
        other = self._safe_param(other)
        return self._value > other._value

    def __ge__(self, other: _SafeOperand) -> bool:
        other = self._safe_param(other)
        return self._value >= other._value

    def __bool__(self) -> bool:
        # Without this every instance (including zero) would be truthy.
        return bool(self._value)

    # Unary operators stay exact and need no operand checks.

    def __neg__(self) -> 'SafeNumber':
        return SafeNumber(-self._value)

    def __pos__(self) -> 'SafeNumber':
        return SafeNumber(+self._value)

    def __abs__(self) -> 'SafeNumber':
        return SafeNumber(abs(self._value))

    # Math operations are only safe with other Fraction/SafeNumbers
    # >>> Decimal(Fraction(10) * 0.3000000000000001)
    # Decimal('3.00000000000000088817841970012523233890533447265625')
    # The math operators are overloaded below to guard against this.
    # The reflected (__r*__) variants run for `other <op> self`, so they must
    # put `other` on the left-hand side.

    def __mod__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value % other._value)

    def __rmod__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value % self._value)

    def __mul__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value * other._value)

    def __rmul__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value * self._value)

    def __truediv__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value / other._value)

    def __rtruediv__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value / self._value)

    def __floordiv__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value // other._value)

    def __rfloordiv__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value // self._value)

    def __divmod__(self, other: _SafeOperand) -> Tuple['SafeNumber', 'SafeNumber']:
        other = self._safe_param(other)
        # divmod yields a (quotient, remainder) pair; wrap each element,
        # not the tuple itself (SafeNumber rejects tuples).
        quotient, remainder = divmod(self._value, other._value)
        return SafeNumber(quotient), SafeNumber(remainder)

    def __rdivmod__(self, other: _SafeOperand) -> Tuple['SafeNumber', 'SafeNumber']:
        other = self._safe_param(other)
        quotient, remainder = divmod(other._value, self._value)
        return SafeNumber(quotient), SafeNumber(remainder)

    def __add__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value + other._value)

    def __radd__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(other._value + self._value)

    def __sub__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        return SafeNumber(self._value - other._value)

    def __rsub__(self, other: _SafeOperand) -> 'SafeNumber':
        other = self._safe_param(other)
        # Reflected subtraction: computes `other - self`.
        return SafeNumber(other._value - self._value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment