Created
April 21, 2021 09:34
-
-
Save adamchainz/025788116abeb3aa9ac392cc5353d7e6 to your computer and use it in GitHub Desktop.
Optimus, Django edition: reversible integer-ID obfuscation for Django auto fields
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import base64 | |
import itertools | |
import random | |
from dataclasses import dataclass | |
from functools import cached_property | |
from math import ceil, floor, log2, log10 | |
from django.db import models | |
from django.db.backends.base.base import BaseDatabaseWrapper | |
from django.db.models.query_utils import DeferredAttribute | |
from django.utils.deconstruct import deconstructible | |
from db_buddy.core.numerical import is_prime, modinv | |
class HashedBigAutoFieldDeferredAttribute(DeferredAttribute):
    """
    Attribute descriptor for HashedBigAutoField.

    Integer assignments are converted to the field's encoded string form
    before being stored on the instance. Any other value (an already-encoded
    string, None, ...) is stored unchanged.
    """

    def __set__(self, instance: object, value: object) -> None:
        # Workaround for: https://github.com/django/django/pull/14007
        # Won't be needed after Django releases featuring that fix.
        field = self.field
        instance.__dict__[field.attname] = (
            field._encode(value) if isinstance(value, int) else value
        )
class HashedBigAutoField(models.BigAutoField):
    """
    A ``BigAutoField`` whose Python-side value is an opaque, URL-safe base64
    string rather than the raw integer.

    The database column still holds the plain integer. ``int_mapper``
    scrambles it when it is read (``from_db_value`` / ``to_python``) and
    unscrambles it when it is written or queried (``get_prep_value``).
    """

    descriptor_class = HashedBigAutoFieldDeferredAttribute

    def __init__(self, *args: object, int_mapper: IntMapper, **kwargs: object) -> None:
        super().__init__(*args, **kwargs)
        self.int_mapper = int_mapper

    def deconstruct(self) -> tuple[str, str, tuple[object, ...], dict[str, object]]:
        """Add ``int_mapper`` to the kwargs so migrations can rebuild the field."""
        name: str
        path: str
        args: tuple[object, ...]
        kwargs: dict[str, object]
        name, path, args, kwargs = super().deconstruct()
        kwargs["int_mapper"] = self.int_mapper
        return name, path, args, kwargs

    def from_db_value(
        self,
        value: int | None,
        expression: HashedBigAutoField,
        connection: BaseDatabaseWrapper,
    ) -> str | None:
        """Convert a raw integer loaded from the database to its encoded string."""
        if value is None:
            return None
        return self._encode(value)

    def get_prep_value(self, value: str | int | None) -> int | None:
        """
        Convert a Python-side value to the integer stored in the database.

        ``None`` and integers pass through unchanged; strings are decoded.
        Malformed strings raise ``ValueError`` from the base64 decoder
        (``binascii.Error`` is a subclass of it).
        """
        if value is None:
            return None
        elif isinstance(value, int):
            return value
        return self._decode(value)

    def to_python(self, value: str | int | None) -> str | None:
        """
        In theory this would only be necessary for form values and we’d never
        receive an integer, but in practice it seems some code paths in Django
        still pass integer values through to_python() so we need to handle
        them.
        """
        if value is None:
            return None
        elif isinstance(value, str):
            return value
        return self._encode(value)

    def _encode(self, integer: int) -> str:
        """
        Map ``integer`` through ``int_mapper`` and return the URL-safe base64
        encoding of the result's minimal big-endian byte representation.
        """
        encoded = self.int_mapper.encode(integer)
        # bit_length() is exact. The previous ceil(log2(encoded) / 8) raised
        # for 0, produced 0 bytes for 1, and under-counted exact powers of
        # 256 (e.g. 2**56 needs 8 bytes, not 7), making to_bytes() overflow.
        # Always emit at least one byte so 0 encodes as well.
        num_bytes = max(1, (encoded.bit_length() + 7) // 8)
        return base64.urlsafe_b64encode(
            encoded.to_bytes(num_bytes, "big"),
        ).decode()

    def _decode(self, string: str) -> int:
        """Reverse ``_encode``: base64-decode ``string``, then un-map the integer."""
        encoded = int.from_bytes(
            base64.urlsafe_b64decode(string),
            "big",
        )
        return self.int_mapper.decode(encoded)
@deconstructible
@dataclass
class IntMapper:
    """
    Maps integers in the range [0, 2**64-1] to a different integer in the same
    range, in a 1-to-1, random-seeming manner.

    Requires two parameters:

    * ``prime`` - an odd prime below 2**64, used as a multiplier modulo
      2**64. Being odd guarantees it has an inverse modulo 2**64, which
      makes the mapping reversible.
    * ``mask`` - an integer XORed onto the product. It should lie in
      [0, 2**64-1] so that outputs stay inside the documented range.

    This is obfuscation, not encryption: ``encode(0)`` is ``mask`` itself and
    ``encode(1) ^ mask`` is ``prime``.
    """

    LIMIT = 2 ** 64 - 1

    prime: int
    mask: int

    @classmethod
    def generate(cls) -> IntMapper:
        """
        Generate a new valid integer hasher with a random prime and mask.

        Both values are drawn from [10**17, LIMIT].
        """
        maximum = cls.LIMIT
        minimum = 10 ** max(1, floor(log10(maximum)) - 2)
        start = random.randint(minimum, maximum)
        # Scan upward from ``start`` but never past ``maximum``: is_prime()
        # rejects n >= 2**64, so an unbounded upward scan from a start with
        # no prime above it would raise. In that case fall back to scanning
        # downward, which always finds a prime well above ``minimum``.
        candidates = itertools.chain(
            range(start, maximum + 1),
            range(start - 1, minimum - 1, -1),
        )
        prime = next(  # pragma: no branch
            x for x in candidates if is_prime(x)
        )
        mask = random.randint(minimum, maximum)
        return cls(prime=prime, mask=mask)

    def __post_init__(self) -> None:
        if not is_prime(self.prime):
            raise ValueError(f"'prime' value {self.prime} is not prime")
        if self.prime % 2 == 0:
            # 2 is prime but even: it has no inverse modulo 2**64, so
            # decode() would fail later. Reject it up front.
            raise ValueError(f"'prime' value {self.prime} must be odd")

    @cached_property
    def _inverse(self) -> int:
        # Multiplicative inverse of ``prime`` modulo 2**64, used by decode().
        return modinv(self.prime, self.LIMIT + 1)

    def encode(self, value: int) -> int:
        """Multiply by ``prime`` modulo 2**64, then XOR with ``mask``."""
        return ((value * self.prime) & self.LIMIT) ^ self.mask

    def decode(self, value: int) -> int:
        """Invert encode(): undo the XOR, then multiply by the inverse of ``prime``."""
        return ((value ^ self.mask) * self._inverse) & self.LIMIT
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from math import floor, log | |
from typing import TypeVar | |
# Type variable for clamp(): constrains its arguments and result to int or
# float for type checkers.
Number = TypeVar("Number", int, float)


def clamp(*, minimum: Number, value: Number, maximum: Number) -> Number:
    """
    Restrict ``value`` to the closed interval [``minimum``, ``maximum``].

    The upper bound is applied first and the lower bound second, so if
    ``minimum > maximum`` the result is ``minimum``.
    """
    capped = maximum if maximum < value else value
    return capped if capped > minimum else minimum
def xgcd(a: int, b: int) -> tuple[int, int, int]:
    """
    Extended Euclidean algorithm.

    Return ``(g, x, y)`` such that ``a*x + b*y == g``, where ``g`` is the last
    non-zero remainder: gcd(a, b) up to sign, and non-negative whenever both
    inputs are non-negative.

    Source:
    https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm#Python # noqa: B950
    """
    # Invariant for both the "old" and the current row:
    #   remainder == a * x_coefficient + b * y_coefficient
    old_r, r = b, a
    old_x, x = 0, 1
    old_y, y = 1, 0
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_x, x = x, old_x - quotient * x
        old_y, y = y, old_y - quotient * y
    return old_r, old_x, old_y
def modinv(a: int, b: int) -> int:
    """
    Modular multiplicative inverse.

    Return ``x`` in ``range(b)`` such that ``(x * a) % b == 1`` (for ``b > 1``;
    for ``b == 1`` the result is 0).

    Raises ValueError if ``b`` is not positive or if gcd(a, b) != 1.
    """
    if b <= 0:
        raise ValueError("modulus must be positive")
    try:
        # Built-in modular inverse (Python 3.8+). It reduces ``a`` modulo
        # ``b`` first, so negative ``a`` works too; a plain xgcd(a, b) can
        # report a gcd of -1 for negative ``a`` and wrongly reject it.
        return pow(a, -1, b)
    except ValueError:
        raise ValueError("gcd(a, b) != 1") from None
def is_prime(n: int) -> bool:
    """
    Deterministic Miller-Rabin primality test for 2 <= n < 2**64.

    The fixed witness set (the primes 2 through 37) is known to
    deterministically cover every integer in that range. Code and witness
    values from Wikipedia:
    https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test

    Raises ValueError when n <= 1 or n >= 2**64.
    """
    if n <= 1:
        raise ValueError("n must be greater than 1")
    if n >= 18_446_744_073_709_551_616:
        raise ValueError("n too large")
    if n <= 3:
        # 2 and 3.
        return True
    if n % 2 == 0:
        return False

    # Factor n - 1 as 2**s * d with d odd.
    s, d = 0, n - 1
    while d % 2 == 0:
        s += 1
        d //= 2

    # Witnesses must lie in [2, n - 2]; the extra 2*ln(n)**2 cap only
    # excludes any of the listed witnesses when n < 74.
    witness_cap = min(n - 2, floor(2 * log(n) ** 2))
    for witness in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if witness > witness_cap:
            continue
        y = pow(witness, d, n)
        if y in (1, n - 1):
            continue
        for _ in range(s - 1):
            y = pow(y, 2, n)
            if y == n - 1:
                break
        else:
            # Repeated squaring never reached n - 1: n is composite.
            return False
    return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from unittest import mock | |
import pytest | |
from django.db import connection | |
from db_buddy.core.ext.django.models import ( | |
HashedBigAutoField, | |
IntMapper, | |
) | |
from db_buddy.core.numerical import is_prime | |
from db_buddy.test import SimpleTestCase | |
class HashedBigAutoFieldTests(SimpleTestCase):
    """Unit tests for HashedBigAutoField with a fixed, known IntMapper."""

    # Under this mapper the integer 1234 encodes to "y3DqG1zR620=".
    field = HashedBigAutoField(
        int_mapper=IntMapper(prime=18319640022251334367, mask=12979739931064651139)
    )

    def test_descriptor_set_int(self):
        class FakeModel:
            _meta = mock.Mock(auto_field=False)

        self.field.contribute_to_class(FakeModel, "field")
        instance = FakeModel()
        instance.field = 1234  # type: ignore[attr-defined]
        assert instance.field == "y3DqG1zR620="  # type: ignore[attr-defined]

    def test_descriptor_set_str(self):
        class FakeModel:
            _meta = mock.Mock(auto_field=False)

        self.field.contribute_to_class(FakeModel, "field")
        instance = FakeModel()
        instance.field = "y3DqG1zR620="  # type: ignore[attr-defined]
        assert instance.field == "y3DqG1zR620="  # type: ignore[attr-defined]

    def test_deconstruct(self):
        name, path, args, kwargs = self.field.deconstruct()
        assert kwargs["int_mapper"] == self.field.int_mapper

    def test_deconstruct_reconstruct(self):
        name, path, args, kwargs = self.field.deconstruct()
        rebuilt = HashedBigAutoField(*args, **kwargs)  # type: ignore[arg-type]
        # Constructing without error is not enough: the rebuilt field must
        # apply the same mapping as the original.
        assert rebuilt.int_mapper == self.field.int_mapper
        assert rebuilt.to_python(1234) == "y3DqG1zR620="
        assert rebuilt.get_prep_value("y3DqG1zR620=") == 1234

    def test_from_db_value_none(self):
        result = self.field.from_db_value(
            None, expression=self.field, connection=connection
        )
        assert result is None

    def test_from_db_value_number(self):
        result = self.field.from_db_value(
            1234, expression=self.field, connection=connection
        )
        assert result == "y3DqG1zR620="

    def test_get_prep_value_none(self):
        result = self.field.get_prep_value(None)
        assert result is None

    def test_get_prep_value_integer(self):
        result = self.field.get_prep_value(1234)
        assert result == 1234

    def test_get_prep_value_string(self):
        result = self.field.get_prep_value("y3DqG1zR620=")
        assert result == 1234

    def test_to_python_none(self):
        result = self.field.to_python(None)
        assert result is None

    def test_to_python_string(self):
        result = self.field.to_python("y3DqG1zR620=")
        assert result == "y3DqG1zR620="

    def test_to_python_integer(self):
        result = self.field.to_python(1234)
        assert result == "y3DqG1zR620="
class IntMapperTests(SimpleTestCase):
    """Unit tests for IntMapper construction, round-tripping and deconstruction."""

    def test_generate(self):
        im = IntMapper.generate()
        assert is_prime(im.prime)
        # A generated mapper must actually be usable, not just well-formed.
        assert im.decode(im.encode(1234)) == 1234

    def test_init_non_prime(self):
        with pytest.raises(ValueError) as excinfo:
            IntMapper(prime=4, mask=1)
        assert excinfo.value.args == ("'prime' value 4 is not prime",)

    def test_encoding(self):
        im = IntMapper(prime=18319640022251334367, mask=12979739931064651139)
        # Small values plus the top of the 64-bit range, where masking matters.
        values = [*range(1_000), 2 ** 32, 2 ** 63, 2 ** 64 - 1]
        for i in values:
            encoded = im.encode(i)
            assert 0 <= encoded <= IntMapper.LIMIT
            assert im.decode(encoded) == i

    def test_deconstruct(self):
        im = IntMapper(prime=18319640022251334367, mask=12979739931064651139)
        path, args, kwargs = im.deconstruct()  # type: ignore[attr-defined]
        assert path == "db_buddy.core.ext.django.models.IntMapper"
        assert args == ()
        assert kwargs == {"prime": 18319640022251334367, "mask": 12979739931064651139}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import pytest | |
from db_buddy.core.numerical import clamp, is_prime, modinv | |
from db_buddy.test import SimpleTestCase | |
class ClampTests(SimpleTestCase):
    """clamp() pins a value to the closed [minimum, maximum] interval."""

    def test_below_minimum(self):
        assert clamp(minimum=1, value=-1, maximum=3) == 1

    def test_in_range(self):
        assert clamp(minimum=1, value=2, maximum=3) == 2

    def test_above_maximum(self):
        assert clamp(minimum=1, value=4, maximum=3) == 3
class ModInvTests(SimpleTestCase):
    """modinv() returns a modular inverse or rejects non-coprime inputs."""

    def test_valid(self):
        assert modinv(31, 11) == 5

    def test_invalid(self):
        # 32 and 10 share the factor 2, so no inverse exists.
        with pytest.raises(ValueError) as exc_info:
            modinv(32, 10)
        assert exc_info.value.args == ("gcd(a, b) != 1",)
class IsPrimeTests(SimpleTestCase):
    """Tests for the deterministic Miller-Rabin is_prime() over [2, 2**64)."""

    def test_too_small(self):
        with pytest.raises(ValueError) as excinfo:
            is_prime(1)
        assert excinfo.value.args == ("n must be greater than 1",)

    def test_too_large(self):
        with pytest.raises(ValueError) as excinfo:
            is_prime(2 ** 70)
        assert excinfo.value.args == ("n too large",)

    def test_2(self):
        assert is_prime(2)

    def test_3(self):
        assert is_prime(3)

    def test_even(self):
        assert not is_prime(4)

    def test_first_100_primes(self):
        # Despite the name, this checks every integer in [2, 100], i.e. the
        # 25 primes up to 100.
        first_primes = [x for x in range(2, 101) if is_prime(x)]
        assert first_primes == [
            2,
            3,
            5,
            7,
            11,
            13,
            17,
            19,
            23,
            29,
            31,
            37,
            41,
            43,
            47,
            53,
            59,
            61,
            67,
            71,
            73,
            79,
            83,
            89,
            97,
        ]

    def test_largest_64_bit_prime(self):
        assert is_prime(2 ** 64 - 59)

    def test_largest_64_bit_value(self):
        # 2**64 - 1 is divisible by 3.
        assert not is_prime(2 ** 64 - 1)

    def test_carmichael_number(self):
        # 561 = 3 * 11 * 17 fools the Fermat test for every coprime base.
        assert not is_prime(561)

    def test_strong_pseudoprime_to_small_bases(self):
        # 3215031751 = 151 * 751 * 28351 passes Miller-Rabin for bases
        # 2, 3, 5 and 7; the larger witnesses must catch it.
        assert not is_prime(3215031751)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment