Skip to content

Instantly share code, notes, and snippets.

@lucasaugustus
Last active January 17, 2025 10:56
The Lagarias-Miller-Odlyzko prime-counting algorithm in Python
#! /usr/bin/env python3
# This is a Python implementation of the Lagarias-Miller-Odlyzko prime-counting algorithm.
# See https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777285-5/S0025-5718-1985-0777285-5.pdf for details.
# The P2 section is entirely my own, while the S1 and S2 parts are translated from Kim Walisch's primecount package:
# https://github.com/kimwalisch/primecount/blob/master/src/lmo/pi_lmo4.cpp.
from sys import argv
from time import time
from math import isqrt
def icbrt(n):
    """Return the integer cube root of n, rounded towards zero.

    For n >= 0 the result r satisfies r**3 <= n < (r + 1)**3; negative
    inputs are handled by symmetry, so icbrt(-n) == -icbrt(n).
    """
    if n < 0:
        return -icbrt(-n)
    if n < 2:
        return n
    # Start both bounds at a power of two near n**(1/3), then widen them
    # until lo**3 <= n < hi**3 holds.
    lo = hi = 1 << (n.bit_length() // 3)
    while lo ** 3 > n:
        lo >>= 2
    while hi ** 3 <= n:
        hi <<= 2
    # Bisect while keeping the invariant lo**3 <= n < hi**3.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        cube = mid * mid * mid
        if cube == n:
            return mid
        if cube < n:
            lo = mid
        else:
            hi = mid
    return lo
def pi(x, alpha=2.5, verbose=False):
    """Return the number of primes <= x using the Lagarias-Miller-Odlyzko method.

    With y = int(icbrt(x) * alpha) and a = pi(y), the count is assembled as

        pi(x) = pi(y) - 1 - P2 + S1 + S2,

    where P2 counts the n <= x that are a product of two primes both > y, and
    S1 + S2 = phi(x, a) is split into ordinary leaves (S1) and special leaves
    (S2).  The P2 section follows the LMO paper; S1 and S2 are translated from
    Kim Walisch's primecount (S1.cpp and lmo/pi_lmo4.cpp).

    Args:
        x: integer upper bound.  Any x < 2 gives 0.
        alpha: tuning factor for y.  It must be >= 1 so that y >= icbrt(x);
            the decomposition above assumes that no product of three primes
            > y is <= x, and the P2 segment bookkeeping assumes x//p <= x23
            for every prime p > y.
        verbose: if true, print intermediate quantities and timings.

    Raises:
        ValueError: if x >= 20 and alpha < 1 (or alpha is NaN).
    """
    if x < 2:
        # Also covers negative x, which would otherwise index the table
        # below from its end.
        return 0
    if x < 20:
        return (0, 0, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8)[x]
    if not alpha >= 1:  # written this way so that NaN is rejected as well
        raise ValueError(f"alpha must be >= 1, got {alpha!r}")
    if verbose:
        print("x: ", x)
    starttime = time()
    x13 = icbrt(x)
    x12 = isqrt(x)
    x23 = icbrt(x * x)
    if verbose:
        print("x23: ", x23)
        print("x12: ", x12)
        print("x13: ", x13)
    assert x13 ** 3 <= x < (x13 + 1) ** 3
    if verbose:
        print("alpha: ", alpha)
    y = int(x13 * alpha)
    if verbose:
        print("y: ", y)
    primes = _lmo_primes_up_to(y)
    pi_y = len(primes)
    if verbose:
        print("pi_y: ", pi_y)
        print("pi_y time: ", time() - starttime)
    P2 = _lmo_p2(x, y, x12, x23, primes)
    if verbose:
        print("P2 time: ", time() - starttime)
        print("P2: ", P2)
    starttime = time()
    lpf, mu = _lmo_lpf_mu(y, primes)
    primes = [0] + primes  # 1-based from here on: primes[1] == 2.
    if verbose:
        print("LPF & mu: ", time() - starttime)
    starttime = time()
    # Leaves stop at the c-th prime and are evaluated with primephi.
    # min(8, pi(y)) equals the former lookup: pi(y) for y < 20, else 8.
    c = min(8, pi_y)
    if verbose:
        print("c: ", c)
    S1 = _lmo_s1(x, y, c, primes)
    if verbose:
        print("S1 time: ", time() - starttime)
        print("S1: ", S1)
    starttime = time()
    S2 = _lmo_s2(x, y, c, primes, lpf, mu)
    if verbose:
        print("S2 time: ", time() - starttime)
        print("S2: ", S2)
        print("phi: ", S1 + S2)
    return pi_y - 1 - P2 + S1 + S2


def _lmo_primes_up_to(n):
    """Return the primes <= n in increasing order (sieve of Eratosthenes)."""
    if n < 2:
        return []
    sieve = bytearray([1]) * (n + 1)
    sieve[0] = sieve[1] = 0
    for p in range(2, isqrt(n) + 1):
        if sieve[p]:
            sieve[p * p::p] = bytes(len(range(p * p, n + 1, p)))
    return [k for k, is_prime in enumerate(sieve) if is_prime]


def _lmo_p2(x, y, x12, x23, primes):
    """Return P2(x, y): the number of n <= x with n = p*q, primes y < p <= q.

    Uses P2 = sum over primes y < p <= x12 of (pi(x//p) - pi(p) + 1).  The
    values pi(x//p) are read off a segmented sieve of (y, x23] in blocks of
    length y.  primes lists every prime <= y in increasing order (0-based).
    Assumes icbrt(x) <= y: then x//p <= x23 for every p > y, and primes <= y
    suffice to sieve up to x23.
    """
    pi_y = len(primes)
    # pi(x12).  If x12 <= y, no segment contains it, but then there are no
    # primes in (y, x12] and the closing correction must vanish, which is
    # exactly what pi_12 == pi_y gives.
    pi_12 = pi_y
    total = 0
    pilist = [0] * y
    piprev = pi_y  # pi(lo - 1) for the current segment
    j = 2
    while (j - 1) * y + 1 <= x23:
        # Sieve the segment [lo, hi]; afterwards pilist[k] == pi(lo + k).
        lo = (j - 1) * y + 1
        hi = min(j * y, x23)
        sl = hi - lo + 1
        sieve = bytearray([1]) * sl
        for p in primes:
            start = (-lo) % p
            if lo + start == p:
                start += p  # keep p itself; only its proper multiples go
            sieve[start::p] = bytes(len(range(start, sl, p)))
        pinew = piprev
        for k in range(sl):
            pinew += sieve[k]
            pilist[k] = pinew
        if lo <= x12 <= hi:
            pi_12 = pilist[x12 - lo]
        # The primes p in (y, x12] with lo <= x//p <= hi fill [ij_lo, ij_hi].
        ij_lo = max(x // (j * y + 1), y) + 1
        ij_hi = min(x // lo, x12)
        ij_len = max(0, ij_hi - ij_lo + 1)
        psieve = bytearray([1]) * ij_len
        for p in primes:
            if p * p > ij_hi:
                break
            # ij_lo > y >= p, so p itself never lies in this interval.
            start = (-ij_lo) % p
            psieve[start::p] = bytes(len(range(start, ij_len, p)))
        total += sum(pilist[x // (ij_lo + k) - lo]
                     for k in range(ij_len) if psieve[k])
        piprev = pinew
        j += 1
    # Add the sum of (1 - pi(p)) over primes y < p <= x12, i.e. the sum of
    # (1 - i) for i = pi_y + 1 .. pi_12.
    total += (pi_y * (pi_y - 1) - pi_12 * (pi_12 - 1)) // 2
    return total


def _lmo_lpf_mu(y, primes):
    """Return (lpf, mu): least-prime-factor and Mobius tables over 0..y.

    lpf[0] == lpf[1] == 0; mu[1] == 1; entry 0 of mu is not meaningful.
    """
    lpf = [0] * (y + 1)
    mu = [1] * (y + 1)
    # Descending order so the smallest prime factor is written to lpf last;
    # the order does not matter for mu.
    for p in reversed(primes):
        for n in range(p, y + 1, p):
            lpf[n] = p
            mu[n] = -mu[n]
        for n in range(p * p, y + 1, p * p):
            mu[n] = 0
    return lpf, mu


def _lmo_s1(x, y, c, primes):
    """Return S1, the ordinary leaves: sum of mu(n) * phi(x // n, c).

    The sum runs over n = 1 and the squarefree n <= y built from primes with
    index > c.  primes is 1-based (primes[0] unused) and lists all primes <= y.
    """
    total = primephi(x, c, primes)  # n = 1
    for b in range(c + 1, len(primes)):
        total -= primephi(x // primes[b], c, primes)  # n = primes[b]
        # n = primes[b] times products of primes with larger index.
        total += primepi_S1(x, y, b, c, primes[b], primes, 1)
    return total


def _lmo_s2(x, y, c, primes, lpf, mu):
    """Return S2, the special leaves: minus the sum of mu(m) * phi(x//n, b-1).

    The sum is over n = primes[b] * m with c < b < pi(y), squarefree m <= y,
    lpf(m) > primes[b], and x//n inside the sieved range.  The phi values come
    from a segmented sieve of [1, x//y]; a Fenwick tree over the odd numbers
    of each segment answers the partial counts.  primes is 1-based; lpf and
    mu are tables over 0..y.
    """
    pi_y = len(primes) - 1
    # Every special leaf has n > y, hence x//n <= x//y, so sieve [1, x//y].
    # Stopping at x//y - 1 would miss leaves with x//n == x//y, which occur
    # when y is near or above sqrt(x) (a large alpha).
    limit = x // y + 1
    # Least power of two >= isqrt(limit), but at least 2.  An even size keeps
    # every segment starting on an odd number, as the odd-only tree requires.
    segment_size = max(2, 1 << (isqrt(limit) - 1).bit_length())
    total = 0
    next_ = primes[:]  # next_[b]: next multiple of primes[b] to cross off
    # phi[b]: at the start of a segment, the count of n in [1, low) with no
    # factor among primes[1 .. b-1].
    phi = [0] * len(primes)
    for low in range(1, limit, segment_size):
        # The current segment is [low, high); sieve[k] stands for low + k.
        high = min(low + segment_size, limit)
        sieve = bytearray([1]) * segment_size
        # Nodes with b <= c are not special leaves, so the first c primes are
        # simply crossed off.
        for b in range(1, c + 1):
            prime = primes[b]
            k = next_[b]
            while k < high:
                sieve[k - low] = 0
                k += prime
            next_[b] = k
        # Build the Fenwick tree in linear time.  tree[i] counts surviving
        # odd numbers low + 2t for t in [i & (i + 1), i].
        treesize = segment_size // 2
        tree = [0] * treesize
        for i in range(treesize):
            tree[i] = sieve[2 * i]
            # The loop below runs log2(lowest set bit of i + 1) times,
            # adding the child ranges that end just below i.
            k = ((i + 1) & ~i) >> 1
            j = i
            while k:
                tree[i] += tree[j - 1]
                j &= j - 1  # clear the lowest set bit
                k >>= 1
        for b in range(c + 1, pi_y):
            prime = primes[b]
            # Candidate leaves n = prime * m with low <= x // n < high.
            min_m = max(x // (prime * high), y // prime)
            max_m = min(x // (prime * low), y)
            if prime >= max_m:
                # max_m only shrinks as prime grows, so no larger prime in
                # this segment can contribute either.
                break
            for m in range(max_m, min_m, -1):
                if mu[m] != 0 and prime < lpf[m]:
                    # phi(x // n, b - 1) = phi[b] + survivors in [low, x//n].
                    pos = (x // (prime * m) - low) >> 1
                    phi_xn = phi[b] + tree[pos]
                    pos += 1
                    while True:
                        pos &= pos - 1  # clear the lowest set bit
                        if pos == 0:
                            break
                        phi_xn += tree[pos - 1]
                    total -= mu[m] * phi_xn
            # Fold this segment's survivors into phi[b] before crossing off
            # primes[b] itself.
            pos = (high - 1 - low) >> 1
            count = tree[pos]
            pos += 1
            while True:
                pos &= pos - 1  # clear the lowest set bit
                if pos == 0:
                    break
                count += tree[pos - 1]
            phi[b] += count
            # Cross off the odd multiples of prime, keeping the tree in sync.
            m = next_[b]
            while m < high:
                pos = m - low
                if sieve[pos]:
                    sieve[pos] = 0
                    pos >>= 1
                    while pos < treesize:
                        tree[pos] -= 1
                        pos |= pos + 1  # set the lowest unset bit
                m += 2 * prime
            next_[b] = m
    return total
def primepi_S1(x, y, b, c, square_free, primes, mob):
    """Return the ordinary-leaf sum over all extensions of square_free.

    For each n = square_free * primes[i] with i > b and n <= y, this adds
    mob * primephi(x // n, c, primes) and then recurses on n with the sign
    flipped.  Indices strictly increase along each chain, so every product
    is visited once.  primes is 1-based and increasing.
    """
    acc = 0
    for idx in range(b + 1, len(primes)):
        n = square_free * primes[idx]
        if n > y:
            break  # primes increase, so every later index overshoots too
        acc += mob * primephi(x // n, c, primes)
        acc += primepi_S1(x, y, idx, c, n, primes, -mob)
    return acc
def primephi(x, a, primes):
    """Return phi(x, a): how many 1 <= n <= x have no prime factor <= primes[a].

    primes must be 1-based (primes[0] unused, primes[1] == 2) and increasing.
    For a <= 3, closed forms with periods 1, 2, 6 and 30 are used.  Larger a
    use the Legendre recurrence
        phi(x, a) = phi(x, a - 1) - phi(x // primes[a], a - 1).
    """
    if a == 0:
        return x
    if a == 1:
        return (x + 1) // 2
    if a == 2:
        return 2 * (x // 6) + (0, 1, 1, 1, 1, 2)[x % 6]
    if a == 3:
        return 8 * (x // 30) + (0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4,
                                4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8)[x % 30]
    pa = primes[a]
    if x < pa:
        # Any 2 <= n <= x < pa has a prime factor below pa, so only n = 1
        # can survive -- and only if the range [1, x] is non-empty.
        return 1 if x >= 1 else 0
    return primephi(x, a - 1, primes) - primephi(x // pa, a - 1, primes)
if __name__ == "__main__":
    # Usage: <script> x [alpha].  alpha falls back to 1 when it is missing
    # or cannot be parsed as a float.
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} x [alpha]")
    x = int(argv[1])
    try:
        alpha = float(argv[2])
    except (IndexError, ValueError):
        alpha = 1
    print(pi(x, alpha=alpha, verbose=True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment