Skip to content

Instantly share code, notes, and snippets.

@gchenfc
Created January 28, 2022 23:33
Show Gist options
  • Save gchenfc/a0efa92e954a609bf031f7da4cc8dd70 to your computer and use it in GitHub Desktop.
Save gchenfc/a0efa92e954a609bf031f7da4cc8dd70 to your computer and use it in GitHub Desktop.
Simple implementation of the Miller-Rabin primality test.
"""Simple implementation of the Miller-Rabin primality test.
Usage:
```
from primetest import is_prime
print(is_prime(17)) # True
print(is_prime(12345678910987654321)) # True
print(is_prime(46)) # False
```
See: https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases
Author: Gerry Chen
Date: Jan 28, 2022
"""
import itertools
def _is_composite(n, a, d, s):
    """Return True if witness ``a`` proves ``n`` composite; return False if unsure.

    Args:
        n: the number to test, written as ``n = d * 2**s + 1``.
        a: the witness base (a prime from ``_known_primes`` in this module).
        d, s: the decomposition of ``n - 1`` described above (``d`` odd).
    """
    x = pow(a, d, n)
    if x == 1 or x == n - 1:
        return False
    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == n - 1:
            return False
        if x == 1:
            # Reached 1 without passing through n - 1: the previous x is a
            # nontrivial square root of 1 mod n, so n is composite.
            return True
    return True


def is_prime(n):
    """Returns True iff n is prime (deterministic).

    Trial-divides by the primes below 1000, then runs Miller-Rabin with the
    prime bases that OEIS A014233 shows are sufficient for this size of n.

    Raises:
        RuntimeError: if n >= _thresholds[-1], i.e. too big to guarantee a
            correct result. Use `is_probably_prime` instead.
    """
    if n < 2:
        # Covers 0, 1 and negative numbers (which previously got no witnesses
        # and were reported prime).
        return False
    if n in _known_primes:
        return True
    if any((n % p) == 0 for p in _known_primes):
        return False
    if n >= _thresholds[-1]:
        raise RuntimeError('Cannot guarantee primes larger than ' + str(_thresholds[-1]))
    # _known_primes[i] is required as soon as n >= _thresholds[i]. The threshold
    # value itself is a strong pseudoprime to all earlier bases, so the
    # comparison must be inclusive.
    prime_witnesses = [p for th, p in zip(_thresholds, _known_primes) if th <= n]
    # begin Miller-Rabin test: write n - 1 = d * 2**s with d odd
    d, s = n - 1, 0
    while d % 2 == 0:  # bitwise ops are slower for small numbers, not much faster for large numbers
        d, s = d // 2, s + 1
    return not any(_is_composite(n, a, d, s) for a in prime_witnesses)


# https://oeis.org/A014233
# Smallest odd number false positive for Miller-Rabin primality test using bases <= n-th prime.
# A leading 0 is prepended so that _thresholds[i] pairs with _known_primes[i]:
# base _known_primes[i] must be used whenever n >= _thresholds[i].
_thresholds = (0, 2047, 1373653, 25326001, 3215031751, 2152302898747, 3474749660383,
               341550071728321, 341550071728321, 3825123056546413051, 3825123056546413051,
               3825123056546413051, 318665857834031151167461, 3317044064679887385961981)
# Primes below 1000, used for trial division and as Miller-Rabin bases.
# Bootstrapped with is_prime itself. The right-hand side is fully evaluated
# before `+=`, so while it runs _known_primes is still [2, 3]. That is enough,
# because every odd x < 1000 < 2047 only needs base 2.
_known_primes = [2, 3]
_known_primes += [x for x in range(5, 1000, 2) if is_prime(x)]
def is_probably_prime(n, num_witnesses=20):
    """Returns False if n is definitely not prime. Returns True if n is probably prime.

    For n below _thresholds[-1] the answer is exact: the prime bases that
    OEIS A014233 shows are sufficient are used, and num_witnesses is ignored.
    For larger n, the smallest num_witnesses primes are used as bases. More
    witnesses (bases) decrease the probability of false positives.

    Note: this uses the smallest num_witnesses primes instead of sampling randomly.

    Args:
        n: the integer to test.
        num_witnesses (optional): the number of prime witnesses to use for large n
            (max 168; values below 1 are treated as 1). Defaults to 20.
    """
    if n < 2:
        # 0, 1 and negative numbers are not prime.
        return False
    if n in _known_primes:
        return True
    if any((n % p) == 0 for p in _known_primes):
        return False
    if n < _thresholds[-1]:
        # Deterministic range: _known_primes[i] is needed once n >= _thresholds[i].
        # The comparison is inclusive because each threshold is itself a strong
        # pseudoprime to all earlier bases.
        witnesses = [p for th, p in zip(_thresholds, _known_primes) if th <= n]
    else:
        # No guarantee available: use the smallest primes as bases. Only real
        # primes are used; padding with a 0 base would wrongly "prove" primes
        # composite.
        witnesses = _known_primes[:max(num_witnesses, 1)]
    # begin Miller-Rabin test: write n - 1 = d * 2**s with d odd
    d, s = n - 1, 0
    while d % 2 == 0:  # bitwise ops are slower for small numbers, not much faster for large numbers
        d, s = d // 2, s + 1
    return not any(_is_composite(n, a, d, s) for a in witnesses)
if __name__ == '__main__':
    # Unit tests and timing. The checks are repeated 30 times so the elapsed
    # time is large enough to measure.
    import time

    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the wall clock is adjusted, so it is unsuitable for measuring intervals.
    start = time.perf_counter()
    for _ in range(30):
        assert len(_known_primes) == 168, 'Known primes list is incorrect'
        assert len([x for x in range(1000) if is_prime(x)]) == 168, 'Some prime < 1000 is incorrect'
        assert [x for x in range(901, 1000) if is_prime(x)
                ] == [907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
        assert is_probably_prime(4547337172376300111955330758342147474062293202868155909489)
        assert not is_probably_prime(4547337172376300111955330758342147474062293202868155909393)
        assert is_probably_prime(
            643808006803554439230129854961492699151386107534013432918073439524138264842370630061369715394739134090922937332590384720397133335969549256322620979036686633213903952966175107096769180017646161851573147596390153
        )
        assert not is_probably_prime(
            743808006803554439230129854961492699151386107534013432918073439524138264842370630061369715394739134090922937332590384720397133335969549256322620979036686633213903952966175107096769180017646161851573147596390153
        )
    print('ran in {:}s'.format(time.perf_counter() - start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment