Last active
October 1, 2016 11:59
-
-
Save jliszka/7460817 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.List | |
import Test.QuickCheck | |
import qualified Data.Map as Map | |
-- | The natural numbers 1, 2, 3, ... as an infinite lazy list.
nats = iterate (+ 1) 1
-- | @d `divides` x@ holds when @x@ is an exact multiple of @d@
-- (throws on @d == 0@, like 'mod').
divides d = (== 0) . (`mod` d)
-- | Incremental (lazy) sieve of Eratosthenes over odd candidates
-- 3, 5, 7, ... (the stream 'primes' feeds it).
--
-- @m@ maps each upcoming odd composite to the primes already known to
-- divide it.  A candidate missing from @m@ is prime: emit it and schedule
-- its square, the first multiple not already covered by a smaller prime.
-- A candidate found in @m@ is composite: drop its entry and reschedule each
-- of its primes @p@ at the next odd multiple, @n + 2*p@.
--
-- The step must be @2*p@, not @p@: with odd @n@ and odd @p@, @n + p@ is
-- even and never appears among the candidates.  The marker would then be
-- lost, and odd composites such as 15, 21 and 27 would be emitted as primes.
--
-- Requires the candidates to be consecutive odd numbers in increasing
-- order.  A finite candidate list simply ends the output rather than
-- failing with a pattern-match error.
sieve [] _ = []
sieve (n:t) m =
  case Map.lookup n m of
    Nothing -> n : sieve t (Map.insert (n * n) [n] m)
    Just ps -> sieve t (foldl reinsert (Map.delete n m) ps)
  where
    -- Move prime p's marker to the next odd multiple after n.
    reinsert acc p = Map.insertWith (++) (n + 2 * p) [p] acc
-- | Intended to be the infinite, increasing list of primes: 2 followed by
-- whatever 'sieve' keeps from the odd candidates 3, 5, 7, ...
primes = 2 : sieve oddCandidates Map.empty
  where
    -- 2k + 1 for k = 1, 2, 3, ...
    oddCandidates = [2 * k + 1 | k <- nats]
-- | Prime factorisation by trial division: the primes dividing @n@, in
-- non-decreasing order and with multiplicity, e.g. @factor 12 == [2, 2, 3]@.
-- Expects @n >= 1@; @factor 1 == []@, and @factor 0@ is an endless list
-- of 2s (every number divides 0).
--
-- Trial division stops once @p*p > m@.  At that point the remaining
-- cofactor @m@ has no divisor up to its square root, so it is itself prime.
-- Without this cutoff, a large prime factor was found only by walking the
-- prime list all the way up to it.
factor n = factor' n primes
  where
    factor' 1 _ = []
    factor' m ps@(p:pt)
      -- Divisibility is tested first so repeated factors are all extracted.
      | p `divides` m = p : factor' (m `div` p) ps
      | p * p > m = [m]
      | otherwise = factor' m pt
-- | Modular exponentiation by repeated squaring: @a^n mod p@ using
-- O(log n) multiplications.  @n@ must be >= 0; a negative @n@ never reaches
-- the base case.  The base case returns 1 without reducing it mod @p@, so
-- @modexp a 0 1 == 1@.
modexp _ 0 _ = 1
modexp a n p
  | even n = half
  | otherwise = half * a `mod` p
  where
    -- (a^2 mod p)^(n div 2) mod p; odd n needs one extra factor of a.
    half = modexp (a * a `mod` p) (n `div` 2) p
-- | Naive modular exponentiation, used as a reference for 'modexp'.  It
-- multiplies the accumulator @b@ by @a@ (reducing mod @p@) @n@ times and
-- returns it, i.e. @b * a^n mod p@ for @n >= 1@.  For @n == 0@ it returns
-- @b@ itself, unreduced.  A negative @n@ never terminates.
modexp1 a n p b = snd (until done step (n, b))
  where
    done (remaining, _) = remaining == 0
    step (remaining, acc) = (remaining - 1, acc * a `mod` p)
-- | QuickCheck property: square-and-multiply 'modexp' agrees with the naive
-- 'modexp1' (accumulator started at 1) whenever @n >= 0@ and @p > 0@.
test0 a n p = precondition ==> fastMatchesNaive
  where
    precondition = n >= 0 && p > 0
    fastMatchesNaive = modexp a n p == modexp1 a n p 1
-- | Euler's four-square identity: multiply two quadruples so that the sum of
-- squares of the result is the product of the two sums of squares.  The
-- components are those of the quaternion product conj(q1) * q2.
combine (a1, a2, a3, a4) (b1, b2, b3, b4) =
  ( a1*b1 + a2*b2 + a3*b3 + a4*b4
  , a1*b2 - a2*b1 - a3*b4 + a4*b3
  , a1*b3 + a2*b4 - a3*b1 - a4*b2
  , a1*b4 - a2*b3 + a3*b2 - a4*b1
  )
-- | Sum of the squares of the four components.
sumOfSquares (a, b, c, d) = foldl1 (+) (map square [a, b, c, d])
  where
    square v = v * v
-- | QuickCheck property for Euler's four-square identity: 'combine' yields a
-- quadruple whose sum of squares is the product of the inputs' sums.
test1 a b = sumOfSquares (combine a b) == sumOfSquares a * sumOfSquares b
-- | Reverse application ("pipe"): @x |> f@ is @f x@, so pipelines read left
-- to right.  There is no fixity declaration, so it defaults to infixl 9.
(|>) :: a -> (a -> b) -> b
(|>) = flip ($)
-- | Stable sort of a list by a key extracted from each element.
sortWith key = sortBy byKey
  where
    byKey x y = key x `compare` key y
-- | Apply one function to each component of a 4-tuple.
map4 g = \(w, x, y, z) -> (g w, g x, g y, g z)
-- | Constructive Lagrange four-square theorem: for @n >= 0@, return
-- @(a, b, c, d)@ with @sumOfSquares (a, b, c, d) == n@.
--
-- @n@ is factored into primes, and each prime gets its own four-square
-- representation.  The pieces are then multiplied together with 'combine'
-- (Euler's four-square identity), whose result has the product of the two
-- sums of squares.
--
-- A negative @n@ is not a sum of squares, so it now fails fast with an
-- explicit error.  Previously 'factor' was left searching forever for a
-- factorisation of the negative residue.
foursquare 0 = (0, 0, 0, 0)
foursquare 1 = (1, 0, 0, 0)
foursquare n
  | n < 0 = error "foursquare: negative argument"
  | otherwise = factor n |> map foursquare' |> foldl1 combine
  where
    -- Four-square representation of a single prime p.  Build a starting
    -- quadruple whose sum of squares is a multiple of p, then shrink that
    -- multiple down to 1 with 'reduce'.
    foursquare' p
      | p == 2 = (1, 1, 0, 0)
      | p `mod` 4 == 1 =
          let
            -- g^((p-1)/4) squares to g^((p-1)/2), which is -1 mod p whenever
            -- g is a quadratic non-residue.  Try candidates g from 'primes'
            -- until the square check succeeds.
            findSqrtMinus1 (g:gs) =
              let r = modexp g ((p - 1) `div` 4) p
              in if r * r `mod` p == p - 1 then r else findSqrtMinus1 gs
            s = findSqrtMinus1 primes
          -- s^2 + 1 is a multiple of p.
          in (s, 1, 0, 0) |> reduce
      | p `mod` 4 == 3 =
          let
            -- Search for g such that x = -1 - g^2 (mod p) is a square.
            -- Because p = 3 (mod 4), x^((p+1)/4) is a square root of x
            -- whenever one exists, and the check confirms it.
            findNumberWithSqrt (g:gs) =
              let
                x = (-1 - g * g) `mod` p
                r = modexp x ((p + 1) `div` 4) p
              in if r * r `mod` p == x then (g, r) else findNumberWithSqrt gs
            (a, b) = findNumberWithSqrt [1 .. p]
          -- a^2 + b^2 + 1 is a multiple of p.
          in (a, b, 1, 0) |> reduce
      | otherwise = error "foursquare: factor produced a non-prime"
      where
        -- Descent.  The argument's sum of squares is k * p; each step yields
        -- a quadruple with a smaller multiple of p until the sum equals p.
        reduce abcd@(a, b, c, d)
          | total == p = abcd
          | even k =
              -- k * p is even, so an even number of the entries are odd.
              -- Sorting by parity pairs entries of equal parity, so every
              -- sum and difference below is even.  Halving gives a sum of
              -- squares of (k / 2) * p.
              let [a', b', c', d'] = sortWith (`mod` 2) [a, b, c, d]
              in (a' + b', a' - b', c' + d', c' - d') |> map4 (`div` 2) |> reduce
          | otherwise =
              -- Odd k: combine abcd with its absolute least residues mod k.
              -- Every entry of that product is divisible by k, and dividing
              -- through leaves a quadruple whose sum is a smaller multiple
              -- of p.
              abcd
                |> map4 (absoluteLeastResidue k)
                |> combine abcd
                |> map4 (`div` k)
                |> reduce
          where
            total = sumOfSquares abcd
            k = total `div` p
        -- Representative of m modulo k whose absolute value is at most k / 2.
        absoluteLeastResidue k m = if r <= k `div` 2 then r else r - k
          where r = m `mod` k
-- | QuickCheck property: for every non-negative n, the quadruple returned by
-- 'foursquare' has sum of squares n.
test2 n = (n >= 0) ==> (n == sumOfSquares (foursquare n))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment