Last active
January 17, 2025 10:56
The Lagarias-Miller-Odlyzko prime-counting algorithm in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3
# This is a Python implementation of the Lagarias-Miller-Odlyzko prime-counting algorithm.
# See https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777285-5/S0025-5718-1985-0777285-5.pdf for details.
# The P2 section is entirely my own, while the S1 and S2 parts are translated from Kim Walisch's primecount package:
# https://github.com/kimwalisch/primecount/blob/master/src/lmo/pi_lmo4.cpp.
from itertools import accumulate
from math import isqrt
from sys import argv
from time import time
def icbrt(n):
    """Return the integer cube root of n, truncated towards zero.

    For n >= 0 the result r satisfies r**3 <= n < (r + 1)**3.  Negative
    inputs are handled by symmetry: icbrt(-n) == -icbrt(n).
    """
    if n < 0:
        return -icbrt(-n)
    if n < 2:
        return n
    # Seed both bounds with a power-of-two estimate, then widen them until
    # lo**3 <= n < hi**3 holds.
    lo = hi = 1 << (n.bit_length() // 3)
    while lo ** 3 > n:
        lo >>= 2
    while hi ** 3 <= n:
        hi <<= 2
    # Bisect while preserving lo**3 <= n < hi**3 until the bounds are adjacent.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        cube = mid ** 3
        if cube == n:
            return mid
        if cube < n:
            lo = mid
        else:
            hi = mid
    return lo
def pi(x, alpha=2.5, verbose=False):
    """Return pi(x), the number of primes <= x, via Lagarias-Miller-Odlyzko.

    With y ~ alpha * cbrt(x) and a = pi(y), the count is assembled as
        pi(x) = phi(x, a) + a - 1 - P2,    phi(x, a) = S1 + S2,
    where P2 counts the integers <= x that are a product of two primes > y,
    S1 is the sum over ordinary leaves and S2 the sum over special leaves.

    x       -- integer upper bound; every x < 2 (including negatives) gives 0.
    alpha   -- tuning factor for y.  y is clamped to [cbrt(x), sqrt(x)]: the
               decomposition above requires y >= cbrt(x) (otherwise a P3
               term would be needed and the P2 sieve would stop short), and
               for y >= sqrt(x) P2 is empty, so a larger y gains nothing.
    verbose -- if true, print intermediate values and timings.
    """
    # pi(n) for 0 <= n < 20; also used below to choose c when y is small.
    small_pi = (0, 0, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8)
    if x < 2:
        return 0  # also keeps negative x from indexing small_pi from the end
    if x < 20:
        return small_pi[x]
    report = print if verbose else (lambda *args: None)
    report("x: ", x)
    starttime = time()

    # ---- P2 section (written from scratch following the LMO paper) ----
    x13 = icbrt(x)
    x12 = isqrt(x)
    x23 = icbrt(x * x)
    report("x23: ", x23)
    report("x12: ", x12)
    report("x13: ", x13)
    assert x13**3 <= x < (x13 + 1)**3
    report("alpha: ", alpha)
    y = min(max(int(x13 * alpha), x13), x12)
    report("y: ", y)

    # Sieve of Eratosthenes up to y; primes becomes the increasing list of primes <= y.
    is_prime = bytearray([True]) * (y + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, isqrt(y) + 1):
        if is_prime[p]:
            is_prime[p * p::p] = bytes(len(range(p * p, y + 1, p)))
    primes = [n for n, flag in enumerate(is_prime) if flag]
    del is_prime
    pi_y = len(primes)
    report("pi_y: ", pi_y)
    report("pi_y time: ", time() - starttime)

    # P2 = sum over primes y < p <= sqrt(x) of (pi(x // p) - pi(p) + 1).
    # The values pi(x // p) are read off a sieve of (y, x23] processed in
    # blocks of length y.  Every such x // p is <= x // (y + 1) <= x23
    # because y >= cbrt(x).
    P2 = 0
    pi_12 = pi_y  # pi(sqrt(x)) when sqrt(x) <= y, i.e. no block contains x12
    piprev = pi_y  # pi(lo - 1) for the current block
    j = 2
    while (j - 1) * y + 1 <= x23:
        lo = (j - 1) * y + 1
        hi = min(j * y, x23)
        sl = hi - lo + 1  # "sl" stands for "sieve length"
        block = bytearray([True]) * sl  # block[k] corresponds to lo + k
        # Every composite <= x23 has a prime factor <= cbrt(x) <= y, so the
        # primes <= y suffice.  Since lo > y >= p, p itself is never in the
        # block; only proper multiples get crossed off.
        for p in primes:
            start = (-lo) % p
            block[start::p] = bytes(len(range(start, sl, p)))
        # pilist[k] == pi(lo - 1 + k) for 0 <= k <= sl, so pi(n) == pilist[n - lo + 1].
        pilist = list(accumulate(block, initial=piprev))
        if lo <= x12 <= hi:
            pi_12 = pilist[x12 - lo + 1]
        # The primes p in (y, sqrt(x)] with lo <= x // p <= j*y form [Ij_lo, Ij_hi].
        Ij_lo = max(x // (j * y + 1), y) + 1
        Ij_hi = min(x // lo, x12)
        isl = Ij_hi - Ij_lo + 1
        if isl > 0:
            cand = bytearray([True]) * isl  # cand[k] corresponds to Ij_lo + k
            for p in primes:
                if p * p > Ij_hi:
                    break
                start = (-Ij_lo) % p  # Ij_lo > y >= p, so p itself is never crossed off
                cand[start::p] = bytes(len(range(start, isl, p)))
            P2 += sum(pilist[x // (Ij_lo + k) - lo + 1] for k in range(isl) if cand[k])
        piprev = pilist[-1]  # pi(hi)
        j += 1
    # Subtract sum of (pi(p) - 1) over primes y < p <= sqrt(x).  With
    # pi(p_k) - 1 == k - 1 this telescopes to a difference of triangular numbers.
    P2 += (pi_y * (pi_y - 1) - pi_12 * (pi_12 - 1)) // 2
    report("P2 time: ", time() - starttime)
    report("P2: ", P2)
    starttime = time()

    # Least-prime-factor and Mobius tables for 0..y.  Visiting primes in
    # decreasing order lets the smallest prime write lpf[n] last.
    lpf = [0] * (y + 1)
    mu = [1] * (y + 1)
    for p in reversed(primes):
        for n in range(p, y + 1, p):
            lpf[n] = p
            mu[n] = -mu[n]
        for n in range(p * p, y + 1, p * p):
            mu[n] = 0
    primes = [0] + primes  # 1-based from here on: primes[1] == 2
    report("LPF & mu: ", time() - starttime)
    starttime = time()

    # ---- S1 section, translated from primecount's src/S1.cpp ----
    c = 8 if y >= 20 else small_pi[y]
    report("c: ", c)
    S1 = primephi(x, c, primes)
    for b in range(c + 1, len(primes)):
        S1 -= primephi(x // primes[b], c, primes)
        S1 += primepi_S1(x, y, b, c, primes[b], primes, 1)
    report("S1 time: ", time() - starttime)
    report("S1: ", S1)
    starttime = time()

    # ---- S2 section, translated from primecount's src/lmo/pi_lmo4.cpp ----
    # [1, limit) is sieved in segments.  In each segment a Fenwick tree over the
    # odd positions answers "how many survivors are <= n" while primes are
    # crossed off one at a time.
    limit = x // y  # >= sqrt(x) >= 4 because y <= sqrt(x)
    lr = isqrt(limit)
    segment_size = lr if lr & (lr - 1) == 0 else 1 << lr.bit_length()  # least power of 2 >= lr
    S2 = 0
    next_ = primes[:]  # next_[b]: next multiple of primes[b] still to cross off
    phi = [0] * len(primes)  # phi[b]: survivors below the current segment, primes[1..b-1] removed
    for low in range(1, limit, segment_size):
        high = min(low + segment_size, limit)  # current segment is [low, high)
        sieve = bytearray([1]) * segment_size  # sieve[k] corresponds to low + k
        # Cross off multiples of the first c primes (this removes every even number).
        for b in range(1, c + 1):
            prime = primes[b]
            k = next_[b]
            while k < high:
                sieve[k - low] = 0
                k += prime
            next_[b] = k
        # Build the Fenwick tree over sieve[0], sieve[2], ...  Since low is
        # odd, these are the odd numbers low, low + 2, ...
        treesize = segment_size // 2
        tree = [0] * treesize
        for i in range(treesize):
            tree[i] = sieve[2 * i]
            k = ((i + 1) & ~i) >> 1  # half the lowest set bit of i + 1
            node = i
            while k:
                tree[i] += tree[node - 1]
                node &= node - 1  # clear the lowest set bit
                k >>= 1
        for b in range(c + 1, pi_y):
            prime = primes[b]
            min_m = max(x // (prime * high), y // prime)
            max_m = min(x // (prime * low), y)
            if prime >= max_m:
                break  # max_m only shrinks for larger primes and later segments
            for m in range(max_m, min_m, -1):
                if mu[m] != 0 and prime < lpf[m]:
                    # phi_xn = phi[b] + survivors in [low, x // (prime * m)]
                    pos = (x // (prime * m) - low) >> 1
                    phi_xn = phi[b] + tree[pos]
                    pos += 1
                    while True:
                        pos &= pos - 1  # clear the lowest set bit
                        if pos == 0:
                            break
                        phi_xn += tree[pos - 1]
                    S2 -= mu[m] * phi_xn
            # phi[b] += survivors in [low, high - 1]
            pos = (high - 1 - low) >> 1
            phi[b] += tree[pos]
            pos += 1
            while True:
                pos &= pos - 1  # clear the lowest set bit
                if pos == 0:
                    break
                phi[b] += tree[pos - 1]
            # Cross off the odd multiples of prime in this segment.
            m = next_[b]
            while m < high:
                pos = m - low
                if sieve[pos]:
                    sieve[pos] = 0
                    pos >>= 1
                    while True:
                        tree[pos] -= 1
                        pos |= pos + 1  # set the lowest unset bit
                        if pos >= treesize:
                            break
                m += prime * 2
            next_[b] = m
    report("S2 time: ", time() - starttime)
    report("S2: ", S2)
    report("phi: ", S1 + S2)
    return pi_y - 1 - P2 + S1 + S2
def primepi_S1(x, y, b, c, square_free, primes, mob):
    """Return the recursive part of the S1 (ordinary-leaf) sum.

    For every prime q = primes[k] with k > b and square_free * q <= y, this
    adds mob * primephi(x // (square_free * q), c, primes).  It then recurses
    on square_free * q with the sign flipped, so that each squarefree
    product of increasing primes <= y is weighted by its Mobius value.
    `primes` must be sorted ascending and 1-based (primes[1] == 2).
    """
    total = 0
    for k in range(b + 1, len(primes)):
        m = square_free * primes[k]
        if m > y:
            break  # primes are ascending, so every later product is also > y
        total += mob * primephi(x // m, c, primes)
        total += primepi_S1(x, y, k, c, m, primes, -mob)
    return total
def primephi(x, a, primes):
    """Return phi(x, a): how many integers 1 <= n <= x have no prime factor
    among the first a primes.

    primes -- 1-based list of primes (primes[1] == 2).  For a >= 4 it is
              indexed at positions 4..a, so it needs at least a + 1 entries.
    Returns 0 for x <= 0, since the range [1, x] is then empty.
    """
    if x <= 0:
        return 0
    if a == 0:
        return x
    if a == 1:
        return (x + 1) // 2
    # For a == 2 and a == 3, phi is periodic mod 2*3 and mod 2*3*5 respectively.
    if a == 2:
        return 2 * (x // 6) + (0, 1, 1, 1, 1, 2)[x % 6]
    if a == 3:
        return 8 * (x // 30) + (0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4,
                                4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8)[x % 30]
    pa = primes[a]
    if x < pa:
        # Each 2 <= n <= x has a prime factor smaller than p_a, so only n = 1 survives.
        return 1
    # Legendre's recurrence: phi(x, a) = phi(x, a - 1) - phi(x // p_a, a - 1).
    return primephi(x, a - 1, primes) - primephi(x // pa, a - 1, primes)
if __name__ == "__main__":
    # Usage: <script> X [ALPHA].  When ALPHA is omitted the script uses 1,
    # not pi()'s own default of 2.5.
    x = int(argv[1])
    try:
        alpha = float(argv[2])
    except IndexError:  # ALPHA not supplied; a malformed ALPHA still raises ValueError
        alpha = 1
    print(pi(x, alpha=alpha, verbose=True))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment