@chowells79
Last active August 30, 2022 01:31
Chinese Remainder Theorem, documented
```haskell
-- Chinese Remainder Theorem
--
-- Generalized to non-coprime moduli, provided a solution exists.
--
-- Inputs and output are (remainder, modulus) pairs.
-- Preconditions:
-- 1. r1 `mod` gcd m1 m2 == r2 `mod` gcd m1 m2
-- 2. m1 > 0
-- 3. m2 > 0
--
-- Precondition 1 is trivially true when m1 and m2 are coprime, since
-- gcd m1 m2 == 1 and every integer is 0 modulo 1.
--
-- crt (r1, m1) (r2, m2) == (r3, m3), such that
-- 1. r3 `mod` m1 == r1 `mod` m1
-- 2. r3 `mod` m2 == r2 `mod` m2
-- 3. m3 == lcm m1 m2
-- 4. 0 <= r3
-- 5. r3 < m3
crt :: (Integer, Integer) -> (Integer, Integer) -> (Integer, Integer)
crt (r1, m1) (r2, m2)
  | m1 <= 0 = error $ "crt: " ++ show m1 ++ " is not greater than 0"
  | m2 <= 0 = error $ "crt: " ++ show m2 ++ " is not greater than 0"
  | r1 `mod` g /= r2 `mod` g =
      error $ "crt: " ++ show m1 ++ " and " ++ show m2 ++
              " are not coprime (gcd=" ++ show g ++ ") and " ++
              show r1 ++ " and " ++ show r2 ++ " are not " ++
              "congruent modulo " ++ show g
  | otherwise = r3 `seq` (r3, m3)
  where
    -- Bezout coefficients: m1 * b1 + m2 * b2 == g
    (g, b1, b2) = egcd m1 m2
    m1' = m1 `div` g
    m2' = m2 `div` g
    -- lcm m1 m2 == m1 * m2 `div` g: divide g out of the product once,
    -- not out of both factors
    m3 = m1 * m2'
    r3 = (r1 * m2' * b2 + r2 * m1' * b1) `mod` m3
```
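Why the `r3` formula satisfies postconditions 1 and 2 can be checked by hand. The derivation below is a sketch that relies only on the Bezout identity returned by `egcd` and on precondition 1; `k` is an integer introduced for the derivation, not a name from the code. The last step uses the fact that both `m1` and `m2` divide `m3` (since `m3 = m1 * m2' = m1' * m2`), so reducing modulo `m3` keeps both congruences.

```latex
\begin{aligned}
m_1 b_1 + m_2 b_2 = g \;&\Longrightarrow\; m_1' b_1 + m_2' b_2 = 1,
  \qquad m_1' = m_1/g,\; m_2' = m_2/g \\
x = r_1 m_2' b_2 + r_2 m_1' b_1
  &= r_1 + m_1' b_1 (r_2 - r_1) = r_2 + m_2' b_2 (r_1 - r_2) \\
r_2 - r_1 = g k \;&\Longrightarrow\;
  x = r_1 + m_1 b_1 k \equiv r_1 \pmod{m_1},
  \quad x = r_2 - m_2 b_2 k \equiv r_2 \pmod{m_2} \\
r_3 = x \bmod m_3,\; m_3 = m_1 m_2' = \operatorname{lcm}(m_1, m_2)
  \;&\Longrightarrow\; r_3 \equiv r_1 \pmod{m_1},\; r_3 \equiv r_2 \pmod{m_2}
\end{aligned}
```

For example, `crt (2, 4) (4, 6)`: `g = 2` and both remainders are even, so precondition 1 holds. `egcd 4 6` returns `(2, -1, 1)` (4 · (-1) + 6 · 1 = 2), so `m1' = 2`, `m2' = 3`, `m3 = 12`, and `r3 = (2·3·1 + 4·2·(-1)) `mod` 12 = (-2) `mod` 12 = 10`. Indeed `10 `mod` 4 == 2` and `10 `mod` 6 == 4`.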
```haskell
-- Extended Euclidean Algorithm
--
-- egcd x y == (g, s, t), such that
-- 1. x * s + y * t == g
-- 2. g == gcd x y
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd a b
  -- Reduce to the case a >= b >= 0, adjusting the coefficients to match
  | a < b = case egcd b a of (g, t, s) -> (g, s, t)
  | b < 0 = case egcd a (negate b) of (g, s, t) -> (g, s, negate t)
  | otherwise = go a b 1 0 0 1
  where
    -- loop invariants:
    -- 1. x >= y >= 0
    -- 2. gcd x y == gcd a b
    -- 3. a * s0 + b * t0 == x
    -- 4. a * s1 + b * t1 == y
    go x y s0 s1 t0 t1
      | y == 0 = (x, s0, t0)
      | otherwise = case x `divMod` y of
          -- force the new coefficients each step so thunks don't pile up
          (q, r) -> s2 `seq` t2 `seq` go y r s1 s2 t1 t2
            -- this `where` is attached to the case alternative, so q is in scope
            where
              s2 = s0 - q * s1
              t2 = t0 - q * t1
```
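To spot-check `crt` against its stated postconditions, a brute-force oracle can scan every residue below `lcm m1 m2`. The sketch below is self-contained and not part of the gist; the name `crtBrute` is hypothetical. It returns `Nothing` exactly where `crt` would call `error` (a non-positive modulus, or remainders that are not congruent modulo the gcd), and otherwise the same unique `(r3, m3)` pair, so comparing `Just (crt p q)` with `crtBrute p q` on small inputs exercises both the coprime and the non-coprime cases.

```haskell
-- Hypothetical brute-force oracle for the crt contract (not from the gist).
-- Scans 0 .. lcm m1 m2 - 1 for the unique residue satisfying both congruences.
crtBrute :: (Integer, Integer) -> (Integer, Integer) -> Maybe (Integer, Integer)
crtBrute (r1, m1) (r2, m2)
  | m1 <= 0 || m2 <= 0 = Nothing
  | otherwise =
      case [ x | x <- [0 .. m3 - 1]
               , x `mod` m1 == r1 `mod` m1
               , x `mod` m2 == r2 `mod` m2 ] of
        (x : _) -> Just (x, m3)
        []      -> Nothing  -- no solution: precondition 1 fails
  where
    m3 = lcm m1 m2

main :: IO ()
main = do
  print (crtBrute (2, 3) (3, 5))  -- coprime moduli
  print (crtBrute (2, 4) (4, 6))  -- non-coprime, compatible remainders
  print (crtBrute (1, 4) (2, 6))  -- non-coprime, incompatible remainders
```

This prints `Just (8,15)`, `Just (10,12)` and `Nothing`. The last pair fails precondition 1 because `gcd 4 6 == 2` and the remainders 1 and 2 have different parities. The linear scan is only practical for small moduli, which is all an oracle needs.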