Skip to content

Instantly share code, notes, and snippets.

@tdulcet
Last active June 10, 2024 15:37
Show Gist options
  • Save tdulcet/9644a87211d1f32c25741f240707fee3 to your computer and use it in GitHub Desktop.
Save tdulcet/9644a87211d1f32c25741f240707fee3 to your computer and use it in GitHub Desktop.
LL and PRP Primality Tests.
// Teal Dulcet
// Lucas-Lehmer (LL) Primality Test
// Support for arbitrary-precision integers requires the GNU Multiple Precision (GMP) library
// sudo apt-get update
// sudo apt-get install libgmp3-dev
// Compile: g++ -Wall -g -O3 -flto ll.cpp -o LL -lgmpxx -lgmp
// g++ -Wall -g -fsanitize=undefined ll.cpp -o LL -lgmpxx -lgmp
// Run: ./LL <NUMBER(S)>...
// time ./LL 3 5 7 13 17 19 31 61 89 107 127 521 607 1279 2203 2281 3217 4253 4423 9689 9941 11213 19937 21701 23209 44497 86243 110503 132049 216091 756839 859433 1257787 1398269 2976221 3021377 6972593
#include <iostream>
#include <cmath>
#include <cinttypes>
#include <gmpxx.h>
#include <chrono>
using namespace std;
int jacobi(mpz_class &exp, mpz_class &words)
{
mpz_class w = words - 2;
return mpz_jacobi(w.get_mpz_t(), exp.get_mpz_t());
}
// Rotates the p-bit quantity `value` left by `count` bits inside a p-bit
// window, where `n` is the all-ones mask 2^p - 1. Modulo 2^p - 1 this is the
// same as multiplying by 2^count. Expects 0 <= count < p and value < 2^p.
template <class T>
constexpr T rotl(const T &value, const uintmax_t count, const uintmax_t p, const T &n)
{
	const T kept = (value << count) & n;      // bits still inside the window
	const T wrapped = value >> (p - count);   // bits pushed past bit p-1 wrap to the bottom
	return kept | wrapped;
}
// Rotates the p-bit quantity `value` right by `count` bits inside a p-bit
// window, where `n` is the all-ones mask 2^p - 1. Modulo 2^p - 1 this is the
// same as multiplying by 2^-count; it undoes rotl() with the same count.
// Expects 0 <= count < p and value < 2^p.
template <class T>
constexpr T rotr(const T &value, const uintmax_t count, const uintmax_t p, const T &n)
{
	const T low = value >> count;                     // bits that just move down
	const T carried = (value << (p - count)) & n;     // bits shifted out wrap to the top
	return low | carried;
}
// Runs the Lucas-Lehmer test on 2^p - 1 and prints the low 64 bits of the final
// residue, the verdict, the random shift, the average time per iteration, and a
// Jacobi-symbol check of the final residue.
//
// The residue is stored "shifted": nextval holds s_i * 2^ashift mod (2^p - 1).
// Because 2^p == 1 (mod 2^p - 1), multiplying by a power of two is a p-bit
// rotation (rotl/rotr). Squaring doubles the shift, so the "- 2" of the
// recurrence s_{i+1} = s_i^2 - 2 becomes "- (2 << ashift)".
// p must be > 2; smaller values print an error and return.
void isPrime(const uintmax_t p)
{
	if (p < 3)
	{
		cerr << "Error: Number must be > 2";
		return;
	}
	const uintmax_t iters = p - 2;
	const uintmax_t shift = rand() % p;
	uintmax_t ashift = shift;
	// checkNumber = 2^p - 1
	mpz_class checkNumber = (mpz_class(1) << p) - 1;
	auto start = chrono::steady_clock::now();
	mpz_class nextval = 4;
	nextval = rotl(nextval, ashift, p, checkNumber);
	for (uintmax_t i = 0; i < iters; ++i)
	{
		ashift = (ashift << 1) % p;
		nextval = (nextval * nextval - (mpz_class(2) << ashift)) % checkNumber;
		// mpz_class's % truncates toward zero, so the remainder is negative
		// whenever nextval^2 < (2 << ashift). The next squaring would hide that,
		// but after the final iteration a negative value would corrupt rotr()
		// (which expects 0 <= value < 2^p), the printed residue and the Jacobi
		// check. Normalize into [0, 2^p - 1).
		if (nextval < 0)
			nextval += checkNumber;
	}
	nextval = rotr(nextval, ashift, p, checkNumber);
	auto end = chrono::steady_clock::now();
	auto totaltime = chrono::duration_cast<chrono::microseconds>(end - start);
	// Low 64 bits of the residue, as printed by other LL implementations
	mpz_class result;
	mpz_tdiv_r_2exp(result.get_mpz_t(), nextval.get_mpz_t(), 64);
	gmp_printf("%#018ZX", result.get_mpz_t());
	cout << "\t";
	if (nextval == 0)
		cout << "Mersenne prime!";
	else
		cout << "Composite (Not prime)";
	cout << "\tShift " << shift << "\t\t" << (totaltime / iters).count() << " µs/iter" << endl;
	// Error check: Jacobi((s - 2) / (2^p - 1)) is expected to be -1
	start = chrono::steady_clock::now();
	const int ajacobi = jacobi(checkNumber, nextval);
	end = chrono::steady_clock::now();
	totaltime = chrono::duration_cast<chrono::microseconds>(end - start);
	cout << "\tJacobi " << ajacobi << " (" << (ajacobi == -1 ? "Passed" : "Failed") << ")\t\t" << totaltime.count() << " µs";
}
// Parses each command-line argument as an exponent p and runs the
// Lucas-Lehmer test on 2^p - 1. Base 0 is passed to strtoumax(), so decimal,
// 0-prefixed octal and 0x-prefixed hex are accepted.
// Returns 1 if an argument does not fit in uintmax_t, otherwise 0.
int main(int argc, char *argv[])
{
	const int frombase = 0;
	for (int i = 1; i < argc; ++i)
	{
		// strtoumax() sets errno only on failure and never clears it, so reset it
		// first; otherwise a stale ERANGE would be misreported as an overflow
		// of this argument.
		errno = 0;
		const uintmax_t ll = strtoumax(argv[i], nullptr, frombase);
		if (errno == ERANGE)
		{
			cerr << "Error: Integer number too large to input: '" << argv[i] << "' (" << strerror(errno) << ").\n";
			return 1;
		}
		cout << "2^" << ll << " - 1:\t";
		isPrime(ll);
		cout << endl;
	}
	return 0;
}
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Teal Dulcet
# Lucas-Lehmer (LL) Primality Test
# Run: python3 -OO LL.py <p> [iterations] [shift]
# time for i in 3 5 7 13 17 19 31 61 89 107 127 521 607 1279 2203 2281 3217 4253 4423 9689 9941 11213 19937 21701 23209 44497 86243 110503 132049 216091 756839 859433 1257787 1398269 2976221 3021377 6972593; do python3 -X dev LL.py "$i"; done
from __future__ import division, print_function, unicode_literals
import random
import sys
import timeit
# Adapted from: https://rosettacode.org/wiki/Jacobi_symbol#Python
def jacobi(a, n):
    """Compute the Jacobi symbol (a/n).

    ``n`` must be a positive odd integer; ``a`` may be any integer and is
    reduced modulo ``n`` first. Returns 1 or -1, or 0 when ``a`` and ``n``
    share a factor.

    Raises ValueError if ``n`` is not positive or not odd.
    """
    if n <= 0:
        raise ValueError("'n' must be a positive integer.")
    if n % 2 == 0:
        raise ValueError("'n' must be odd.")
    sign = 1
    a %= n
    while a != 0:
        # Remove factors of two: (2/n) is -1 exactly when n is 3 or 5 mod 8.
        while a % 2 == 0:
            a //= 2
            if n % 8 in (3, 5):
                sign = -sign
        # Quadratic reciprocity: swapping flips the sign iff both are 3 mod 4.
        if a % 4 == 3 and n % 4 == 3:
            sign = -sign
        a, n = n % a, a
    return sign if n == 1 else 0
def rotl(value, count, p, n):
    """Rotate the p-bit integer ``value`` left by ``count`` bits.

    ``n`` is the mask 2**p - 1; modulo n this multiplies ``value`` by 2**count.
    """
    wrapped = value >> (p - count)
    kept = (value << count) & n
    return kept | wrapped
def rotr(value, count, p, n):
    """Rotate the p-bit integer ``value`` right by ``count`` bits.

    ``n`` is the mask 2**p - 1; this undoes ``rotl`` with the same count.
    """
    carried = (value << (p - count)) & n
    low = value >> count
    return low | carried
if not 2 <= len(sys.argv) <= 4:
    print("Usage: " + sys.argv[0] + " <p> [iterations] [shift]", file=sys.stderr)
    sys.exit(1)
# Fixed seed so the random shift is reproducible between runs
random.seed(0)
p = int(sys.argv[1], 0)
if p < 3:
    print("Error: Number must be > 2", file=sys.stderr)
    sys.exit(1)
# Iterations (defaults to the full p - 2 Lucas-Lehmer iterations)
j = int(sys.argv[2], 0) if len(sys.argv) >= 3 and sys.argv[2] else p - 2
if not 0 < j <= p - 2:
    # Previously this exited with status 1 without telling the user why
    print("Error: Iterations must be between 1 and {0:n}".format(p - 2), file=sys.stderr)
    sys.exit(1)
# Shift: the residue is stored as s * 2**ashift mod n
shift = ashift = int(sys.argv[3], 0) % p if len(sys.argv) == 4 and sys.argv[3] else random.randint(0, p - 1)
# n = 2 ** p - 1
n = (1 << p) - 1
start = timeit.default_timer()
s = rotl(4, ashift, p, n)
for _ in range(j):
    # Squaring doubles the shift, so "- 2" becomes "- (2 << ashift)"
    ashift = (ashift << 1) % p
    s = (s * s - (2 << ashift)) % n
s = rotr(s, ashift, p, n)
end = timeit.default_timer()
totaltime = (end - start) * 1000000
print(
    "2^{0} - 1:\t{1:#018X}{2}\tShift {3:n}\t\t{4:.1f} µs/iter".format(
        p,
        s & 0xFFFFFFFFFFFFFFFF,
        "\t" + ("Mersenne prime!" if not s else "Composite (Not prime)") if j == p - 2 else "",
        shift,
        totaltime / j,
    )
)
# Error check: Jacobi((s - 2) / n) is expected to be -1
start = timeit.default_timer()
ajacobi = jacobi(s - 2, n)
end = timeit.default_timer()
totaltime = (end - start) * 1000000
print("\tJacobi {0:n} ({1})\t\t{2:.1f} µs".format(ajacobi, "Passed" if ajacobi == -1 else "Failed", totaltime))
// Teal Dulcet
// Fermat PRobable Prime (PRP) Test
// Support for arbitrary-precision integers requires the GNU Multiple Precision (GMP) library
// sudo apt-get update
// sudo apt-get install libgmp3-dev
// Compile: g++ -Wall -g -O3 -flto prp.cpp -o PRP -lgmpxx -lgmp
// g++ -Wall -g -fsanitize=undefined prp.cpp -o PRP -lgmpxx -lgmp
// Run: ./PRP <NUMBER(S)>...
// time ./PRP 3 5 7 13 17 19 31 61 89 107 127 521 607 1279 2203 2281 3217 4253 4423 9689 9941 11213 19937 21701 23209 44497 86243 110503 132049 216091 756839 859433 1257787 1398269 2976221 3021377 6972593
#include <iostream>
#include <cmath>
#include <cinttypes>
#include <gmpxx.h>
#include <chrono>
// #include <cassert>
using namespace std;
// PRP base: the test raises this value to powers of two modulo 2^p - 1.
constexpr int a = 3;
// PRP residue type (the switch in isPrime handles 1-5). It selects how many
// squarings are done and how the final residue is compared.
constexpr int rt = 1;
// Left rotation of a p-bit quantity: `n` must be the mask 2^p - 1. Modulo
// 2^p - 1 this equals multiplication by 2^count. Expects 0 <= count < p and
// value < 2^p.
template <class T>
constexpr T rotl(const T &value, const uintmax_t count, const uintmax_t p, const T &n)
{
	const uintmax_t back = p - count;   // distance the wrapped-around bits travel
	return (value >> back) | ((value << count) & n);
}
// Right rotation of a p-bit quantity: `n` must be the mask 2^p - 1. Modulo
// 2^p - 1 this equals multiplication by 2^-count and inverts rotl().
// Expects 0 <= count < p and value < 2^p.
template <class T>
constexpr T rotr(const T &value, const uintmax_t count, const uintmax_t p, const T &n)
{
	const uintmax_t up = p - count;   // where the low `count` bits end up
	return ((value << up) & n) | (value >> count);
}
// Runs a Fermat probable-prime (PRP) test with base `a` on 2^p - 1. It prints:
// - the low 64 bits of the final residue;
// - the verdict;
// - the random shift;
// - the average time per iteration;
// - the result of a Gerbicz error check.
//
// The running value is stored shifted: nextval holds x * 2^ashift mod (2^p - 1),
// where x is the current power a^(2^i). Multiplying by a power of two modulo
// 2^p - 1 is a p-bit rotation, so rotr() recovers x and squaring doubles the shift.
//
// Every L = floor(sqrt(iters)) iterations the unshifted value is multiplied into
// the checksum d; the previous checksum is kept in prev_d. At the end the
// relation d == a * prev_d^(2^L) (mod 2^p - 1) is verified for the last checkpoint.
// p must be > 2; smaller values print an error and return.
void isPrime(const uintmax_t p)
{
	if (p < 3)
	{
		cerr << "Error: Number must be > 2";
		return;
	}
	// Residue types 2 and 4 stop after p - 1 squarings, the others after p
	const uintmax_t iters = rt == 2 or rt == 4 ? p - 1 : p;
	const uintmax_t shift = rand() % p;
	uintmax_t ashift = shift;
	// checkNumber = 2^p - 1
	mpz_class checkNumber = (mpz_class(1) << p) - 1;
	const uintmax_t L = sqrt(iters);
	mpz_class d = a, prev_d;
	// Iteration at which d was last updated. This is the checkpoint the final
	// Gerbicz check actually verifies. It is the largest multiple of L below
	// iters, which is generally not L * L.
	uintmax_t checkpoint = 0;
	auto start = chrono::steady_clock::now();
	mpz_class nextval = a;
	nextval = rotl(nextval, ashift, p, checkNumber);
	for (uintmax_t i = 0; i < iters; ++i)
	{
		if (i and i % L == 0)
		{
			prev_d = d;
			d = prev_d * rotr(nextval, ashift, p, checkNumber) % checkNumber;
			checkpoint = i;
		}
		ashift = (ashift << 1) % p;
		nextval = (nextval * nextval) % checkNumber;
	}
	nextval = rotr(nextval, ashift, p, checkNumber);
	auto end = chrono::steady_clock::now();
	auto totaltime = chrono::duration_cast<chrono::microseconds>(end - start);
	bool prime = false;
	switch (rt)
	{
	case 1:
	case 5:
	{
		// nextval = a^(2^p); divide by a^2 modulo 2^p - 1 and expect 1.
		// The exact division is done by first adding the multiple of
		// checkNumber that makes nextval divisible by a^2.
		const int a2 = a * a;
		mpz_class r = nextval % a2;
		if (r != 0)
		{
			mpz_class temp;
			mpz_invert(temp.get_mpz_t(), mpz_class(checkNumber % a2).get_mpz_t(), mpz_class(a2).get_mpz_t());
			nextval += (a2 - r * temp % a2) * checkNumber;
		}
		// assert(nextval % a2 == 0);
		nextval /= a2;
		// prime = nextval == 1 % checkNumber;
		prime = nextval == 1;
		break;
	}
	case 2:
	{
		// nextval = a^(2^(p-1)); divide by a modulo 2^p - 1 and expect -1
		mpz_class r = nextval % a;
		if (r != 0)
		{
			mpz_class temp;
			mpz_invert(temp.get_mpz_t(), mpz_class(checkNumber % a).get_mpz_t(), mpz_class(a).get_mpz_t());
			nextval += (a - r * temp % a) * checkNumber;
		}
		// assert(nextval % a == 0);
		nextval /= a;
		// prime = nextval == 1 % checkNumber or nextval == -1 % checkNumber;
		prime = checkNumber - nextval == 1;
		break;
	}
	case 3:
		// nextval = a^(2^p); expect a^2
		prime = nextval == (a * a) % checkNumber;
		break;
	case 4:
		// nextval = a^(2^(p-1)); expect -a
		// prime = nextval == a % checkNumber or nextval == -a % checkNumber;
		prime = checkNumber - nextval == a;
		break;
	}
	// Low 64 bits of the residue
	mpz_class result;
	mpz_tdiv_r_2exp(result.get_mpz_t(), nextval.get_mpz_t(), 64);
	gmp_printf("%#018ZX", result.get_mpz_t());
	cout << "\t";
	if (prime)
		cout << "Probable prime!";
	else
		cout << "Composite (Not prime)";
	cout << "\tShift " << shift << "\t\t" << (totaltime / iters).count() << " µs/iter" << endl;
	// Gerbicz check for the last checkpoint: d == a * prev_d^(2^L) (mod 2^p - 1)
	start = chrono::steady_clock::now();
	mpz_class temp1 = mpz_class(1) << L;
	mpz_class temp2;
	mpz_powm(temp2.get_mpz_t(), prev_d.get_mpz_t(), temp1.get_mpz_t(), checkNumber.get_mpz_t());
	temp2 = a * temp2 % checkNumber;
	const bool gerbicz = d == temp2;
	end = chrono::steady_clock::now();
	totaltime = chrono::duration_cast<chrono::microseconds>(end - start);
	cout << "\tGerbicz " << (gerbicz ? "Passed" : "Failed") << "\tIteration " << checkpoint << "\t\t" << totaltime.count() << " µs";
}
// Parses each command-line argument as an exponent p and runs the PRP test on
// 2^p - 1. Base 0 is passed to strtoumax(), so decimal, 0-prefixed octal and
// 0x-prefixed hex are accepted.
// Returns 1 if an argument does not fit in uintmax_t, otherwise 0.
int main(int argc, char *argv[])
{
	const int frombase = 0;
	for (int i = 1; i < argc; ++i)
	{
		// strtoumax() sets errno only on failure and never clears it, so reset it
		// first; otherwise a stale ERANGE would be misreported as an overflow
		// of this argument.
		errno = 0;
		const uintmax_t ll = strtoumax(argv[i], nullptr, frombase);
		if (errno == ERANGE)
		{
			cerr << "Error: Integer number too large to input: '" << argv[i] << "' (" << strerror(errno) << ").\n";
			return 1;
		}
		cout << "2^" << ll << " - 1:\t";
		isPrime(ll);
		cout << endl;
	}
	return 0;
}
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Teal Dulcet
# Fermat PRobable Prime (PRP) Test
# Run: python3 -OO PRP.py <p> [iterations] [shift]
# time for i in 3 5 7 13 17 19 31 61 89 107 127 521 607 1279 2203 2281 3217 4253 4423 9689 9941 11213 19937 21701 23209 44497 86243 110503 132049 216091 756839 859433 1257787 1398269 2976221 3021377 6972593; do python3 -X dev PRP.py "$i"; done
from __future__ import division, print_function, unicode_literals
import random
import sys
import timeit
try:
    # Python 3.8+
    from math import isqrt
except ImportError:

    def isqrt(x):
        """Return the floor of the square root of the non-negative integer x.

        Fallback for Python < 3.8. It uses an integer Newton iteration, so the
        result is exact for arbitrarily large x. A float-based int(sqrt(x))
        loses precision once x exceeds the range doubles represent exactly.

        Raises ValueError for negative x, like math.isqrt.
        """
        if x < 0:
            raise ValueError("isqrt() argument must be nonnegative")
        if x == 0:
            return 0
        # Start from a power of two that is >= sqrt(x) and descend
        r = 1 << ((x.bit_length() + 1) >> 1)
        while True:
            y = (r + x // r) >> 1
            if y >= r:
                return r
            r = y
def rotl(value, count, p, n):
    """Left-rotate a p-bit integer by ``count`` bits; ``n`` is 2**p - 1.

    Modulo n this is multiplication by 2**count.
    """
    back = p - count
    return (value >> back) | ((value << count) & n)
def rotr(value, count, p, n):
    """Right-rotate a p-bit integer by ``count`` bits; ``n`` is 2**p - 1.

    This is the inverse of ``rotl`` with the same count.
    """
    up = p - count
    return ((value << up) & n) | (value >> count)
if not 2 <= len(sys.argv) <= 4:
    print("Usage: " + sys.argv[0] + " <p> [iterations] [shift]", file=sys.stderr)
    sys.exit(1)
# PRP Base
a = 3
# PRP Residue Type (1-5): selects the number of squarings and the final comparison below
rt = 1
# Fixed seed so the random shift is reproducible between runs
random.seed(0)
p = int(sys.argv[1], 0)
if p < 3:
    print("Error: Number must be > 2", file=sys.stderr)
    sys.exit(1)
# Iterations (residue types 2 and 4 stop one squaring early)
iters = p - 1 if rt in {2, 4} else p
j = int(sys.argv[2], 0) if len(sys.argv) >= 3 and sys.argv[2] else iters
if not 0 < j <= iters:
    # Previously this exited with status 1 without telling the user why
    print("Error: Iterations must be between 1 and {0:n}".format(iters), file=sys.stderr)
    sys.exit(1)
# Shift: the running value is stored as x * 2**ashift mod n
shift = ashift = int(sys.argv[3], 0) % p if len(sys.argv) == 4 and sys.argv[3] else random.randint(0, p - 1)
# n = 2 ** p - 1
n = (1 << p) - 1
# Gerbicz check: every L iterations the unshifted value is multiplied into d
L = isqrt(iters)
L2 = L * L
d = a
prev_d = None
# Iteration of the most recent checkpoint (0 = none reached yet)
checkpoint = 0
start = timeit.default_timer()
s = rotl(a, ashift, p, n)
for i in range(j):
    if i and not i % L:
        prev_d = d
        d = prev_d * rotr(s, ashift, p, n) % n
        checkpoint = i
    ashift = (ashift << 1) % p
    s = (s * s) % n
s = rotr(s, ashift, p, n)
end = timeit.default_timer()
totaltime = (end - start) * 1000000
if j == iters:
    if rt in {1, 5}:
        # s = s * pow(a, -2, n) % n, done by adding the multiple of n that
        # makes s divisible by a**2 and then dividing exactly
        a2 = a * a
        r = s % a2
        if r:
            s += (a2 - r * pow(n % a2, -1, a2) % a2) * n
        # assert(s % a2 == 0)
        s //= a2
        # prime = s == 1 % n
        prime = s == 1
    elif rt == 2:
        # s = s * pow(a, -1, n) % n
        r = s % a
        if r:
            s += (a - r * pow(n % a, -1, a) % a) * n
        # assert(s % a == 0)
        s //= a
        # prime = s == 1 % n or s == -1 % n
        prime = n - s == 1
    elif rt == 3:
        prime = s == (a * a) % n
    elif rt == 4:
        # prime = s == a % n or s == -a % n
        prime = n - s == a
print(
    "2^{0} - 1:\t{1:#018X}{2}\tShift {3:n}\t\t{4:.1f} µs/iter".format(
        p,
        s & 0xFFFFFFFFFFFFFFFF,
        "\t" + ("Probable prime!" if prime else "Composite (Not prime)") if j == iters else "",
        shift,
        totaltime / j,
    )
)
# The check also needs a checkpoint to have been reached. For example, with
# p = 3 and j = 1 there is none and prev_d would be unset. Report the
# iteration actually checked, which is generally not L2.
if j >= L2 and checkpoint:
    start = timeit.default_timer()
    gerbicz = d == a * pow(prev_d, 1 << L, n) % n
    # gerbicz = d == a * prev_d**(2**L) % n
    end = timeit.default_timer()
    totaltime = (end - start) * 1000000
    print("\tGerbicz {0}\tIteration {1:n}\t\t{2:.1f} µs".format("Passed" if gerbicz else "Failed", checkpoint, totaltime))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment