-- Last active October 1, 2016 11:59
import Data.List
import Test.QuickCheck
import qualified Data.Map as Map
-- | The positive integers in ascending order: 1, 2, 3, ...
nats :: [Integer]
nats = iterate (+ 1) 1
-- | @divisor \`divides\` dividend@ is 'True' when @dividend@ is an exact
-- multiple of @divisor@ (remainder zero under 'mod').
divides :: Integral a => a -> a -> Bool
divides divisor dividend = 0 == dividend `mod` divisor
-- | Incremental sieve of Eratosthenes.
--
-- Takes an ascending list of candidates and a table mapping each upcoming
-- composite to the primes that generate it, and returns the candidates that
-- are not hit by the table. A candidate @n@ that survives is emitted and its
-- square is scheduled as the first composite to cross off; a candidate found
-- in the table is dropped and each of its primes @p@ is rescheduled @p@ further
-- along.
--
-- The candidate list may skip numbers (for example, odd numbers only). A
-- rescheduled composite can then land on a value that is never offered as a
-- candidate, so table entries that have fallen behind the current candidate
-- are advanced until they catch up instead of being left behind forever
-- (which would let later multiples such as 15 slip through as "prime").
--
-- Candidates not greater than 1 are passed through without scheduling
-- anything, since stepping by such a value would never move forward.
-- A finite candidate list yields a finite result.
sieve :: (Ord a, Num a) => [a] -> Map.Map a [a] -> [a]
sieve [] _ = []
sieve candidates@(n : rest) pending =
  case Map.minViewWithKey pending of
    Just ((k, ps), others)
      -- Smallest scheduled composite is behind the candidate: move it forward
      -- and look at the same candidate again.
      | k < n -> sieve candidates (reschedule k ps others)
      -- The candidate is a scheduled composite: drop it.
      | k == n -> sieve rest (reschedule k ps others)
    -- Table is empty, or every scheduled composite lies beyond n.
    _
      | n > 1 -> n : sieve rest (Map.insert (n * n) [n] pending)
      | otherwise -> n : sieve rest pending
  where
    -- Move every prime registered at composite k to its next multiple k + p.
    reschedule k ps table =
      foldl' (\acc p -> Map.insertWith (++) (k + p) [p] acc) table ps
-- | The infinite ascending list of primes: 2, 3, 5, 7, 11, ...
--
-- Candidates are all integers from 2 upward, because 'sieve' advances each
-- prime's next composite in steps of @p@; offering only odd numbers would make
-- those steps land on even values that never come up as candidates.
primes :: [Integer]
primes = sieve (drop 1 nats) Map.empty
-- | Prime factorisation by trial division against 'primes'.
--
-- Returns the prime factors of @n@ in ascending order, repeated according to
-- multiplicity; @factor 1@ is the empty list. Intended for @n >= 1@: the
-- recursion only stops when the remaining cofactor reaches 1.
factor :: Integer -> [Integer]
factor n = go n primes
  where
    go 1 _ = []
    go m ps@(p : rest)
      -- Keep the same prime in play so repeated factors are all extracted.
      | p `divides` m = p : go (m `div` p) ps
      | otherwise = go m rest
    go _ [] = error "factor: ran out of primes (unreachable: the prime list is infinite)"
-- | Modular exponentiation by repeated squaring: @modexp a n p@ is
-- @a ^ n@ reduced modulo @p@.
--
-- An exponent of 0 yields 1 without reducing it. A negative exponent is
-- rejected with an error; halving it with 'div' never reaches 0, so the
-- recursion would otherwise not terminate.
modexp :: (Integral a, Integral b) => a -> b -> a -> a
modexp base expo modulus
  | expo < 0 = error "modexp: negative exponent"
  | expo == 0 = 1
  | even expo = half
  | otherwise = half * base `mod` modulus
  where
    -- (base^2) ^ (expo div 2), already reduced modulo `modulus`.
    half = modexp (base * base `mod` modulus) (expo `div` 2) modulus
-- | Reference modular exponentiation by repeated multiplication.
--
-- @modexp1 a n p b@ multiplies the accumulator @b@ by @a@ (reducing modulo
-- @p@ after each step) exactly @n@ times and returns the result; with @n == 0@
-- the accumulator is returned untouched.
modexp1 :: (Integral a, Eq b, Num b) => a -> b -> a -> a -> a
modexp1 base count modulus = loop count
  where
    loop remaining acc
      | remaining == 0 = acc
      | otherwise = loop (remaining - 1) ((acc * base) `mod` modulus)
-- | QuickCheck property: for a non-negative exponent and a positive modulus,
-- the squaring-based 'modexp' agrees with the step-by-step 'modexp1' started
-- from an accumulator of 1.
test0 :: (Integral a, Integral b) => a -> b -> a -> Property
test0 a n p = inDomain ==> bySquaring == byMultiplying
  where
    inDomain = n >= 0 && p > 0
    bySquaring = modexp a n p
    byMultiplying = modexp1 a n p 1
-- | Multiply two four-square representations (Euler's four-square identity).
--
-- For any two quadruples, the sum of squares of the result equals the product
-- of the sums of squares of the arguments.
combine :: Num a => (a, a, a, a) -> (a, a, a, a) -> (a, a, a, a)
combine (a, b, c, d) (a', b', c', d') = (w, x, y, z)
  where
    w = a * a' + b * b' + c * c' + d * d'
    x = a * b' - b * a' - c * d' + d * c'
    y = a * c' + b * d' - c * a' - d * b'
    z = a * d' - b * c' + c * b' - d * a'
-- | Sum of the squares of the four components of a quadruple.
sumOfSquares :: Num a => (a, a, a, a) -> a
sumOfSquares (p, q, r, s) = square p + square q + square r + square s
  where
    square v = v * v
-- | Property: 'combine' is multiplicative with respect to 'sumOfSquares',
-- i.e. the product of the two sums of squares equals the sum of squares of
-- the combined quadruple.
test1 :: (Eq a, Num a) => (a, a, a, a) -> (a, a, a, a) -> Bool
test1 a b = productOfNorms == normOfProduct
  where
    productOfNorms = sumOfSquares a * sumOfSquares b
    normOfProduct = sumOfSquares (combine a b)
-- | Left-to-right application: feed the value on the left into the function
-- on the right, so pipelines read in execution order.
(|>) :: a -> (a -> b) -> b
(|>) value fn = fn value
-- | Stable sort of a list by comparing the keys that @f@ extracts from the
-- elements.
sortWith :: Ord b => (a -> b) -> [a] -> [a]
sortWith f xs = sortBy byKey xs
  where
    byKey l r = f l `compare` f r
-- | Apply a function to each of the four components of a quadruple.
map4 :: (a -> b) -> (a, a, a, a) -> (b, b, b, b)
map4 fn (p, q, r, s) = (fn p, fn q, fn r, fn s)
-- | Write @n@ as a sum of four squares.
--
-- 0 and 1 are answered directly. Any other input is factored with 'factor',
-- each factor is decomposed by @foursquare'@, and the per-factor quadruples
-- are folded together with 'combine'.
foursquare :: Integer -> (Integer, Integer, Integer, Integer)
foursquare n = case n of
  0 -> (0, 0, 0, 0)
  1 -> (1, 0, 0, 0)
  _ -> foldl1 combine (map foursquare' (factor n))
-- | Write a prime @p@ as a sum of four squares.
--
-- * @p == 2@ is answered directly as @1 + 1 + 0 + 0@.
-- * For @p mod 4 == 1@ a value @a@ with @a*a mod p == p - 1@ is searched for
--   among powers of the primes, giving a start quadruple @(a, 1, 0, 0)@.
-- * For @p mod 4 == 3@ a pair @(a, b)@ with @b*b mod p == (-1 - a*a) mod p@
--   is searched for with @a@ in @[1 .. p]@, giving @(a, b, 1, 0)@.
--
-- In both searched cases the start quadruple's sum of squares is a multiple
-- of @p@, and @reduce@ shrinks that multiple until it is exactly @p@.
-- Any other argument (not 2 and not odd) is rejected with an error.
foursquare' :: Integer -> (Integer, Integer, Integer, Integer)
foursquare' p
  | p == 2 = (1, 1, 0, 0)
  | p `mod` 4 == 1 = reduce (sqrtOfMinusOne primes, 1, 0, 0)
  | p `mod` 4 == 3 =
      let (a, b) = squarePair [1 .. p]
       in reduce (a, b, 1, 0)
  | otherwise = error "foursquare': argument must be a prime"
  where
    -- First b = a^((p-1)/4) mod p, over the given candidates a, whose square
    -- is congruent to -1 modulo p.
    sqrtOfMinusOne (a : rest)
      | b * b `mod` p == p - 1 = b
      | otherwise = sqrtOfMinusOne rest
      where
        b = modexp a ((p - 1) `div` 4) p
    sqrtOfMinusOne [] = error "foursquare': no square root of -1 found"

    -- First (a, b), over the given candidates a, with b*b == -1 - a*a modulo p.
    squarePair (a : rest)
      | b * b `mod` p == x = (a, b)
      | otherwise = squarePair rest
      where
        x = (-1 - a * a) `mod` p
        b = modexp x ((p + 1) `div` 4) p
    squarePair [] = error "foursquare': no pair (a, b) with a^2 + b^2 + 1 divisible by p found"

    -- Descent: the quadruple's sum of squares is k * p; produce a quadruple
    -- with a smaller multiple and repeat until the sum is exactly p.
    reduce abcd@(a, b, c, d)
      | total == p = abcd
      | even k =
          -- Sorting by parity pairs up components of equal parity, so the
          -- sums and differences below are even and can be halved exactly.
          case sortWith (`mod` 2) [a, b, c, d] of
            [a', b', c', d'] ->
              reduce (map4 (`div` 2) (a' + b', a' - b', c' + d', c' - d'))
            _ -> error "foursquare': sorting four values must yield four values"
      | otherwise =
          -- Combine with the residues of smallest absolute value modulo k,
          -- then divide every component by k.
          let residues = map4 (absoluteLeastResidue k) abcd
           in reduce (map4 (`div` k) (combine abcd residues))
      where
        total = sumOfSquares abcd
        k = total `div` p
-- | Representative of @m@ modulo @k@ with the smallest absolute value.
--
-- The non-negative remainder @m mod k@ is kept when it does not exceed
-- @k div 2@; otherwise @k@ is subtracted from it, giving a negative
-- representative.
absoluteLeastResidue :: Integral a => a -> a -> a
absoluteLeastResidue k m
  | n <= k `div` 2 = n
  | otherwise = n - k
  where
    n = m `mod` k
-- | QuickCheck property: for every non-negative @n@, the quadruple returned
-- by 'foursquare' has squares summing back to @n@.
test2 :: Integer -> Property
test2 n = (n >= 0) ==> (sumOfSquares decomposition == n)
  where
    decomposition = foursquare n
