Skip to content

Instantly share code, notes, and snippets.

@adamchainz
Created April 21, 2021 09:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adamchainz/025788116abeb3aa9ac392cc5353d7e6 to your computer and use it in GitHub Desktop.
Save adamchainz/025788116abeb3aa9ac392cc5353d7e6 to your computer and use it in GitHub Desktop.
optimus django edition
from __future__ import annotations
import base64
import itertools
import random
from dataclasses import dataclass
from functools import cached_property
from math import ceil, floor, log2, log10
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.query_utils import DeferredAttribute
from django.utils.deconstruct import deconstructible
from db_buddy.core.numerical import is_prime, modinv
class HashedBigAutoFieldDeferredAttribute(DeferredAttribute):
    """
    Field descriptor that stores integer assignments in their encoded form.

    Assigning an int to the attribute stores ``field._encode(int)``; any other
    value (e.g. an already-encoded string, or None) is stored unchanged.

    Workaround for: https://github.com/django/django/pull/14007
    Won't be needed after Django releases featuring that fix.
    """

    def __set__(self, instance: object, value: object) -> None:
        field = self.field
        instance.__dict__[field.attname] = (
            field._encode(value) if isinstance(value, int) else value
        )
class HashedBigAutoField(models.BigAutoField):
    """
    BigAutoField whose Python-side value is an opaque, URL-safe base64 string.

    The database column holds the plain integer. Integers coming from the
    database (from_db_value), through to_python(), or assigned via the
    descriptor are scrambled with ``int_mapper`` and base64-encoded;
    get_prep_value() turns encoded strings back into the integer.
    """

    descriptor_class = HashedBigAutoFieldDeferredAttribute

    def __init__(self, *args: object, int_mapper: IntMapper, **kwargs: object) -> None:
        super().__init__(*args, **kwargs)
        self.int_mapper = int_mapper

    def deconstruct(self) -> tuple[str, str, tuple[object, ...], dict[str, object]]:
        """Add ``int_mapper`` so migrations can reconstruct the field."""
        name: str
        path: str
        args: tuple[object, ...]
        kwargs: dict[str, object]
        name, path, args, kwargs = super().deconstruct()
        kwargs["int_mapper"] = self.int_mapper
        return name, path, args, kwargs

    def from_db_value(
        self,
        value: int | None,
        expression: HashedBigAutoField,
        connection: BaseDatabaseWrapper,
    ) -> str | None:
        """Encode the integer loaded from the database; NULL stays None."""
        if value is None:
            return None
        return self._encode(value)

    def get_prep_value(self, value: str | int | None) -> int | None:
        """
        Return the integer to send to the database.

        Integers are passed through unchanged (treated as raw database
        values); strings are decoded with _decode().
        """
        if value is None:
            return None
        elif isinstance(value, int):
            return value
        return self._decode(value)

    def to_python(self, value: str | int | None) -> str | None:
        """
        In theory this would only be necessary for form values and we’d never
        receive an integer, but in practice it seems some code paths in Django
        still pass integer values through to_python() so we need to handle
        them.
        """
        if value is None:
            return None
        elif isinstance(value, str):
            return value
        return self._encode(value)

    def _encode(self, integer: int) -> str:
        """Scramble ``integer`` with the mapper and base64-encode its big-endian bytes."""
        encoded = self.int_mapper.encode(integer)
        # Exact byte count from the bit length, with at least one byte so that
        # 0 encodes as "AA==". The float-based ceil(log2(encoded) / 8) raised
        # for 0 and under-counted for exact powers of 256 (including 1) and for
        # large values that round to one as floats, making to_bytes() overflow.
        # For every other value both formulas give the same length, so
        # previously issued strings are unchanged.
        num_bytes = max(1, (encoded.bit_length() + 7) // 8)
        return base64.urlsafe_b64encode(
            encoded.to_bytes(num_bytes, "big"),
        ).decode()

    def _decode(self, string: str) -> int:
        """Inverse of _encode(): base64-decode ``string`` and unscramble it."""
        encoded = int.from_bytes(
            base64.urlsafe_b64decode(string),
            "big",
        )
        return self.int_mapper.decode(encoded)
@deconstructible
@dataclass
class IntMapper:
    """
    Map integers in the range [0, 2**64-1] to a different integer in the same
    range, in a 1-to-1, random-seeming manner.

    Requires two values:

    * ``prime`` - an odd prime below 2**64. Being coprime to 2**64, multiplying
      by it modulo 2**64 is a bijection, undone by multiplying by its modular
      inverse.
    * ``mask`` - a value in [0, 2**64-1] that is XORed into the product.

    ``decode(encode(x)) == x`` for every x in range. Values outside the range
    raise ValueError instead of silently wrapping around and colliding.
    """

    LIMIT = 2 ** 64 - 1

    prime: int
    mask: int

    @classmethod
    def generate(cls) -> IntMapper:
        """
        Generate a new valid integer mapper with a random prime and mask.

        Both parameters come from random.SystemRandom (OS entropy) rather than
        the module-level Mersenne Twister: anyone who knows them can reverse
        the mapping.
        """
        rng = random.SystemRandom()
        maximum = cls.LIMIT
        minimum = 10 ** max(1, floor(log10(maximum)) - 2)
        # Sample odd candidates until one is prime. Scanning upwards from a
        # random start could run past LIMIT, where is_prime() raises.
        # ``maximum`` is all one-bits, so ``| 1`` never pushes past it.
        prime = rng.randint(minimum, maximum) | 1
        while not is_prime(prime):
            prime = rng.randint(minimum, maximum) | 1
        mask = rng.randint(minimum, maximum)
        return cls(prime=prime, mask=mask)

    def __post_init__(self) -> None:
        # is_prime() itself raises ValueError for values <= 1 or >= 2**64.
        if not is_prime(self.prime):
            raise ValueError(f"'prime' value {self.prime} is not prime")
        # 2 shares a factor with 2**64, so it has no inverse modulo 2**64.
        if self.prime == 2:
            raise ValueError("'prime' value 2 is not odd")
        if not 0 <= self.mask <= self.LIMIT:
            raise ValueError(
                f"'mask' value {self.mask} is out of range [0, {self.LIMIT}]"
            )

    @cached_property
    def _inverse(self) -> int:
        # Multiplicative inverse of ``prime`` modulo 2**64, used by decode().
        return modinv(self.prime, self.LIMIT + 1)

    def encode(self, value: int) -> int:
        """Map ``value`` in [0, LIMIT] to its scrambled counterpart."""
        self._check_range(value)
        return ((value * self.prime) & self.LIMIT) ^ self.mask

    def decode(self, value: int) -> int:
        """Inverse of encode(): recover the original integer from ``value``."""
        self._check_range(value)
        return ((value ^ self.mask) * self._inverse) & self.LIMIT

    def _check_range(self, value: int) -> None:
        # Out-of-range inputs would be truncated to 64 bits and alias another
        # value, so reject them instead.
        if not 0 <= value <= self.LIMIT:
            raise ValueError(f"value {value} is out of range [0, {self.LIMIT}]")
from __future__ import annotations
from math import floor, log
from typing import TypeVar
Number = TypeVar("Number", int, float)


def clamp(*, minimum: Number, value: Number, maximum: Number) -> Number:
    """
    Return ``value`` limited to the closed range [minimum, maximum].

    Equivalent to ``max(minimum, min(value, maximum))``, including which
    argument object is returned on ties and the result when
    ``minimum > maximum`` (minimum wins).
    """
    capped = maximum if maximum < value else value
    return capped if capped > minimum else minimum
def xgcd(a: int, b: int) -> tuple[int, int, int]:
    """
    Extended Euclidean algorithm.

    Return (g, x, y) such that a*x + b*y = g = gcd(a, b).

    Source:
    https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm#Python # noqa: B950
    """
    # Running Bezout coefficients as (current, next) pairs; when the loop ends,
    # x and y are the coefficients of the original a and b.
    x, x_next = 0, 1
    y, y_next = 1, 0
    while a:
        quotient, remainder = divmod(b, a)
        a, b = remainder, a
        x, x_next = x_next, x - quotient * x_next
        y, y_next = y_next, y - quotient * y_next
    return b, x, y
def modinv(a: int, b: int) -> int:
    """
    Modular multiplicative inverse.

    Return x such that (x * a) % b == 1, reduced modulo b.

    Uses the built-in three-argument ``pow`` with exponent -1 (Python 3.8+).
    Unlike a hand-rolled extended-Euclid check on ``gcd == 1``, this also
    works for negative ``a`` (e.g. modinv(-3, 7) == 2).

    Raises ValueError if b is 0 or if a and b are not coprime.
    """
    if b == 0:
        raise ValueError("modulus must be non-zero")
    try:
        return pow(a, -1, b)
    except ValueError:
        # pow() raises ValueError("base is not invertible ...") here; keep the
        # established message callers and tests rely on.
        raise ValueError("gcd(a, b) != 1") from None
def is_prime(n: int) -> bool:
    """
    Deterministic Miller-Rabin primality test for 2 <= n < 2**64.

    Using the twelve primes 2 through 37 as witnesses is known to give an
    exact answer for every integer in this range.

    Code and witness values from Wikipedia:
    https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test

    Raises ValueError when n <= 1 or n >= 2**64.
    """
    if n <= 1:
        raise ValueError("n must be greater than 1")
    if n >= 18_446_744_073_709_551_616:
        raise ValueError("n too large")
    if n == 2 or n == 3:
        return True
    if n % 2 == 0:
        return False

    # Factor n - 1 as 2**s * odd.
    s, odd = 0, n - 1
    while odd % 2 == 0:
        s, odd = s + 1, odd // 2

    # Witnesses must lie in [2, n - 2]; the logarithmic term additionally
    # skips witnesses above 2 * ln(n)**2.
    base_limit = min(n - 2, floor(2 * log(n) ** 2))

    def proves_composite(base: int) -> bool:
        """True if ``base`` is a Miller-Rabin witness to n being composite."""
        y = pow(base, odd, n)
        if y in (1, n - 1):
            return False
        for _ in range(s - 1):
            y = pow(y, 2, n)
            if y == n - 1:
                return False
        return True

    return not any(
        proves_composite(base)
        for base in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
        if base <= base_limit
    )
from __future__ import annotations
from unittest import mock
import pytest
from django.db import connection
from db_buddy.core.ext.django.models import (
HashedBigAutoField,
IntMapper,
)
from db_buddy.core.numerical import is_prime
from db_buddy.test import SimpleTestCase
class HashedBigAutoFieldTests(SimpleTestCase):
    """Behaviour of HashedBigAutoField with a fixed, known IntMapper."""

    field = HashedBigAutoField(
        int_mapper=IntMapper(prime=18319640022251334367, mask=12979739931064651139)
    )

    # Encoded form of the integer 1234 under the mapper above.
    ENCODED_1234 = "y3DqG1zR620="

    def _attach_to_fake_model(self):
        """Contribute ``field`` to a fresh stand-in model class; return an instance."""

        class FakeModel:
            _meta = mock.Mock(auto_field=False)

        self.field.contribute_to_class(FakeModel, "field")
        return FakeModel()

    def test_descriptor_set_int(self):
        obj = self._attach_to_fake_model()
        obj.field = 1234  # type: ignore[attr-defined]
        assert obj.field == self.ENCODED_1234  # type: ignore[attr-defined]

    def test_descriptor_set_str(self):
        obj = self._attach_to_fake_model()
        obj.field = self.ENCODED_1234  # type: ignore[attr-defined]
        assert obj.field == self.ENCODED_1234  # type: ignore[attr-defined]

    def test_deconstruct(self):
        *_, kwargs = self.field.deconstruct()
        assert kwargs["int_mapper"] == self.field.int_mapper

    def test_deconstruct_reconstruct(self):
        _, _, args, kwargs = self.field.deconstruct()
        HashedBigAutoField(*args, **kwargs)  # type: ignore[arg-type]

    def test_from_db_value_none(self):
        loaded = self.field.from_db_value(
            None, expression=self.field, connection=connection
        )
        assert loaded is None

    def test_from_db_value_number(self):
        loaded = self.field.from_db_value(
            1234, expression=self.field, connection=connection
        )
        assert loaded == self.ENCODED_1234

    def test_get_prep_value_none(self):
        assert self.field.get_prep_value(None) is None

    def test_get_prep_value_integer(self):
        assert self.field.get_prep_value(1234) == 1234

    def test_get_prep_value_string(self):
        assert self.field.get_prep_value(self.ENCODED_1234) == 1234

    def test_to_python_none(self):
        assert self.field.to_python(None) is None

    def test_to_python_string(self):
        assert self.field.to_python(self.ENCODED_1234) == self.ENCODED_1234

    def test_to_python_integer(self):
        assert self.field.to_python(1234) == self.ENCODED_1234
class IntMapperTests(SimpleTestCase):
    """Construction, validation, round-tripping and deconstruction of IntMapper."""

    PRIME = 18319640022251334367
    MASK = 12979739931064651139

    def _mapper(self) -> IntMapper:
        # Keyword arguments matter: deconstruct() reports them as kwargs.
        return IntMapper(prime=self.PRIME, mask=self.MASK)

    def test_generate(self):
        generated = IntMapper.generate()
        assert is_prime(generated.prime)

    def test_init_non_prime(self):
        with pytest.raises(ValueError) as excinfo:
            IntMapper(prime=4, mask=1)
        assert excinfo.value.args == ("'prime' value 4 is not prime",)

    def test_encoding(self):
        mapper = self._mapper()
        for original in range(1_000):
            assert mapper.decode(mapper.encode(original)) == original

    def test_deconstruct(self):
        path, args, kwargs = self._mapper().deconstruct()  # type: ignore[attr-defined]
        assert path == "db_buddy.core.ext.django.models.IntMapper"
        assert args == ()
        assert kwargs == {"prime": self.PRIME, "mask": self.MASK}
from __future__ import annotations
import pytest
from db_buddy.core.numerical import clamp, is_prime, modinv
from db_buddy.test import SimpleTestCase
class ClampTests(SimpleTestCase):
    """clamp() bounds a value by a minimum and a maximum."""

    def test_below_minimum(self):
        assert clamp(minimum=1, value=-1, maximum=3) == 1

    def test_in_range(self):
        assert clamp(minimum=1, value=2, maximum=3) == 2

    def test_above_maximum(self):
        assert clamp(minimum=1, value=4, maximum=3) == 3
class ModInvTests(SimpleTestCase):
    """modinv() finds inverses and rejects inputs that share a factor."""

    def test_valid(self):
        # 31 * 5 == 155 == 14 * 11 + 1
        inverse = modinv(31, 11)
        assert inverse == 5

    def test_invalid(self):
        # gcd(32, 10) == 2, so no inverse exists.
        with pytest.raises(ValueError) as excinfo:
            modinv(32, 10)
        assert excinfo.value.args == ("gcd(a, b) != 1",)
class IsPrimeTests(SimpleTestCase):
    """is_prime() argument validation, small values, and the 64-bit boundary."""

    def test_too_small(self):
        with pytest.raises(ValueError) as excinfo:
            is_prime(1)
        assert excinfo.value.args == ("n must be greater than 1",)

    def test_too_large(self):
        with pytest.raises(ValueError) as excinfo:
            is_prime(2 ** 70)
        assert excinfo.value.args == ("n too large",)

    def test_2(self):
        assert is_prime(2)

    def test_3(self):
        assert is_prime(3)

    def test_even(self):
        assert not is_prime(4)

    def test_primes_up_to_100(self):
        # The 25 primes <= 100 (previously misnamed "first 100 primes").
        found = [x for x in range(2, 101) if is_prime(x)]
        assert found == [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
            43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
        ]

    def test_strong_pseudoprime_base_2(self):
        # 2047 == 23 * 89 is the smallest strong pseudoprime to base 2.
        assert not is_prime(2047)

    def test_strong_pseudoprime_bases_2_to_7(self):
        # 3215031751 == 151 * 751 * 28351 passes bases 2, 3, 5 and 7.
        assert not is_prime(3215031751)

    def test_mersenne_prime(self):
        assert is_prime(2 ** 61 - 1)

    def test_largest_64_bit_prime(self):
        assert is_prime(2 ** 64 - 59)

    def test_largest_64_bit_value(self):
        # 2**64 - 1 == 3 * 5 * 17 * 257 * 641 * 65537 * 6700417
        assert not is_prime(2 ** 64 - 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment