Skip to content

Instantly share code, notes, and snippets.

@AH-dark
Last active July 31, 2023 18:29
Show Gist options
  • Save AH-dark/306d76d001ecd94d1796f64a5ea74df6 to your computer and use it in GitHub Desktop.
Save AH-dark/306d76d001ecd94d1796f64a5ea74df6 to your computer and use it in GitHub Desktop.
Miller–Rabin primality test
#include <cstdlib>
#include <ctime>
// Computes (a * b) % mod with double-and-add so that no intermediate value
// ever reaches 2 * mod, i.e. nothing overflows long long for any positive mod.
//
// a:   any value; it is first reduced into [0, mod) (negative a included).
// b:   multiplier; b <= 0 yields 0, as the loop never runs for it.
// mod: modulus, must be positive.
// Returns a value in [0, mod).
long long fastMultiply(long long a, long long b, long long mod) {
    a %= mod;
    if (a < 0) {
        a += mod;
    }
    long long res = 0;
    while (b > 0) {
        if (b & 1) {
            // res + a could exceed LLONG_MAX when mod > 2^62, so compare
            // against mod - a (always > 0) and subtract instead of adding.
            res = (res >= mod - a) ? res - (mod - a) : res + a;
        }
        a = (a >= mod - a) ? a - (mod - a) : a + a;
        b >>= 1;
    }
    return res;
}
// Computes (base^exponent) % mod by square-and-multiply, routing every
// product through fastMultiply so intermediates stay below mod.
//
// base:     any value; it is reduced into [0, mod) first so that negative
//           bases produce the correct residue.
// exponent: exponent <= 0 is treated as 0 (the loop never runs for it).
// mod:      modulus, must be positive.
// Returns a value in [0, mod); in particular the result is 0 when mod == 1.
long long fastExponentiation(long long base, long long exponent, long long mod) {
    long long res = 1 % mod;
    base %= mod;
    if (base < 0) {
        base += mod;
    }
    while (exponent > 0) {
        if (exponent & 1) {
            res = fastMultiply(res, base, mod);
        }
        base = fastMultiply(base, base, mod);
        exponent >>= 1;
    }
    return res;
}
// Miller-Rabin probabilistic primality test.
//
// n:         value to test; values below 2 are reported as not prime.
// iteration: number of random witnesses to try (default 5); with
//            iteration <= 0 no witness is tried and odd n >= 5 returns true.
// Returns false when some witness proves n composite, true otherwise
// ("probably prime").
//
// Witnesses come from rand() % (n - 3) + 2, i.e. [2, n - 2]. When n - 3 exceeds
// RAND_MAX only the low end of that range is sampled. The generator is seeded
// from time() on the first call only. Reseeding on every call made calls
// within the same second reuse identical witnesses.
bool isPrime(long long n, int iteration = 5) {
    if (n < 4) {
        return n == 2 || n == 3;
    }
    if (n % 2 == 0) {
        return false;
    }

    // Write n - 1 as 2^r * d with d odd. d stays fixed for every witness.
    long long d = n - 1;
    int r = 0;
    while (d % 2 == 0) {
        d /= 2;
        ++r;
    }

    static bool seeded = false;
    if (!seeded) {
        srand(static_cast<unsigned>(time(0)));
        seeded = true;
    }

    for (int i = 0; i < iteration; i++) {
        long long a = 2 + rand() % (n - 3);  // witness in [2, n - 2]
        long long x = fastExponentiation(a, d, n);
        if (x == 1 || x == n - 1) {
            continue;  // this witness says nothing against n
        }
        // Square up to r - 1 more times looking for n - 1.
        bool passed = false;
        for (int j = 1; j < r; j++) {
            x = fastMultiply(x, x, n);
            if (x == n - 1) {
                passed = true;
                break;
            }
            if (x == 1) {
                return false;  // reached 1 without passing through n - 1
            }
        }
        if (!passed) {
            return false;  // a is a witness: n is definitely composite
        }
    }
    return true;
}
package main
import (
"math/big"
"math/rand"
"time"
)
// fastExponentiation returns (base^exponent) mod mod as a newly allocated
// big.Int. None of the arguments are modified.
//
// A negative exponent is treated as 0, so the result is 1 mod mod (0 when
// mod is 1). Otherwise Int.Exp would compute a modular inverse or return nil.
// The result is non-negative for a positive mod, including for negative bases.
func fastExponentiation(base, exponent, mod *big.Int) *big.Int {
	if exponent.Sign() < 0 {
		exponent = new(big.Int)
	}
	return new(big.Int).Exp(base, exponent, mod)
}
// isPrime reports whether n is probably prime, using the Miller-Rabin test
// with iteration randomly chosen witnesses.
//
// Values below 2 are reported as not prime, 2 and 3 as prime, and any other
// even value as composite. For odd n >= 5 each witness a is drawn uniformly
// from [2, n-2]. A false result means some witness proved n composite; true
// means none of the witnesses tried did. With iteration <= 0 no witness is
// tried. n is not modified.
func isPrime(n *big.Int, iteration int) bool {
	one := big.NewInt(1)
	two := big.NewInt(2)
	switch {
	case n.Cmp(two) < 0:
		return false
	case n.Cmp(big.NewInt(3)) <= 0:
		return true
	case n.Bit(0) == 0:
		return false
	}

	nMinus1 := new(big.Int).Sub(n, one)
	// Write n-1 as 2^r * d with d odd.
	r := nMinus1.TrailingZeroBits()
	d := new(big.Int).Rsh(nMinus1, r)

	// One generator per call. Reseeding the global source (rand.Seed) and
	// building a fresh time-seeded source per witness is wasteful and may
	// repeat seeds within one clock tick.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	// Rand yields [0, n-3); adding 2 gives [2, n-2]. The witness n-1 must be
	// excluded: with d odd it always gives x == n-1 and proves nothing.
	span := new(big.Int).Sub(n, big.NewInt(3))

	a := new(big.Int)
	x := new(big.Int)
witnesses:
	for i := 0; i < iteration; i++ {
		a.Rand(rng, span).Add(a, two)
		x.Exp(a, d, n)
		if x.Cmp(one) == 0 || x.Cmp(nMinus1) == 0 {
			continue
		}
		// Square up to r-1 times looking for n-1.
		for j := uint(1); j < r; j++ {
			x.Mul(x, x).Mod(x, n)
			if x.Cmp(nMinus1) == 0 {
				continue witnesses
			}
			if x.Cmp(one) == 0 {
				return false // reached 1 without passing through n-1
			}
		}
		return false // a is a witness: n is definitely composite
	}
	return true
}
import random
def fast_multiply(a, b, mod):
    """Compute ``(a * b) % mod``.

    Python ints are arbitrary precision and never overflow. The product is
    therefore formed directly instead of through a double-and-add loop.

    Args:
        a: Integer multiplicand.
        b: Integer multiplier. Values <= 0 yield 0, which preserves the
            behaviour of the previous loop-based version (it never iterated
            for them).
        mod: Non-zero integer modulus.

    Returns:
        ``(a * b) % mod`` when ``b > 0``, otherwise 0.
    """
    if b <= 0:
        return 0
    return a * b % mod
def fast_exponentiation(base, exponent, mod):
    """Compute ``(base ** exponent) % mod``.

    Delegates to the built-in three-argument ``pow``, which does modular
    square-and-multiply on arbitrary-precision ints.

    Args:
        base: Integer base. Negative bases give the proper residue.
        exponent: Integer exponent. Values <= 0 are treated as 0.
        mod: Non-zero integer modulus.

    Returns:
        ``base ** max(exponent, 0) % mod``; for example 0 whenever ``mod == 1``.
    """
    # pow() would compute a modular inverse for a negative exponent; keep the
    # established "exponent <= 0 behaves like exponent 0" contract instead.
    return pow(base, max(exponent, 0), mod)
def is_prime(n, iterations=5):
    """Probabilistic Miller-Rabin primality test.

    Args:
        n: Integer to test. Values below 2 are reported as not prime, 2 and 3
            as prime, and any other even value as composite.
        iterations: Number of random witnesses, each drawn from [2, n - 2].
            With ``iterations <= 0`` no witness is tried.

    Returns:
        False if some witness proves ``n`` composite, True otherwise
        (``n`` is probably prime).
    """
    if n < 4:
        return n == 2 or n == 3
    if n % 2 == 0:
        return False

    # Write n - 1 as 2^r * d with d odd; d stays fixed for every witness.
    d, r = n - 1, 0
    while d % 2 == 0:
        d //= 2
        r += 1

    for _ in range(iterations):
        a = random.randint(2, n - 2)  # random witness in [2, n - 2]
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        # Square up to r - 1 times looking for n - 1.
        for _ in range(r - 1):
            x = x * x % n
            if x == n - 1:
                break
            if x == 1:
                return False  # reached 1 without passing through n - 1
        else:
            return False  # a is a witness: n is definitely composite
    return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment