Skip to content

Instantly share code, notes, and snippets.

@ninegua
Last active November 26, 2019 02:24
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ninegua/4f272ea5e6cd21c8ddcdf524d10b46c1 to your computer and use it in GitHub Desktop.
Save ninegua/4f272ea5e6cd21c8ddcdf524d10b46c1 to your computer and use it in GitHub Desktop.
A Haskell Schnorr Signature Tutorial
-------------------------------------------------------------------------------------
-- Schnorr Signatures - A Haskell Tutorial
--
-- While studying Schnorr Signatures, I find most online materials either
-- imprecise or inadequate. Mathematical notation is often quoted without
-- fully explaining conditions, expectations, and the domains and/or ranges of
-- variables and functions. These gaps are confusing enough that any attempt
-- to turn them
-- into real programs is doomed, either producing something that is wrong, or
-- even worse, something that you think is correct but is actually wrong.
--
-- Here is my attempt to reconstruct the theory behind Schnorr Signatures (and
-- basic crypto primitives that are used) in a systematic manner, hopefully
-- precise enough to read and understand by a Haskell programmer.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module Schnorr where
import Data.List (inits, sort)
import Test.QuickCheck
-------------------------------------------------------------------------------------
-- Before we dive into any theory or modular arithmetic, we introduce the
-- concept of a sub-type, and define an embed function to convert values from
-- a sub-type to its super-type.
-- s <: t means type s is a sub-type of t.
class s <: t where
  -- | Convert a value of the sub-type @s@ into the super-type @t@.
  -- Every instance is total: the conversion never fails.
  embed :: s -> t

-- | A positive number is, in particular, non-negative.
instance (Show a, Ord a, Num a) => Positive a <: NonNegative a where
  embed = NonNegative . getPositive

-- | Forget the 'Positive' wrapper.
instance (Show a, Ord a, Num a) => Positive a <: a where
  embed = getPositive

-- | Forget the 'NonNegative' wrapper.
instance (Show a, Ord a, Num a) => NonNegative a <: a where
  embed = getNonNegative
-------------------------------------------------------------------------------------
-- The Positive and NonNegative types are defined in QuickCheck, but they were
-- not made as instances of Num. Here we give the missing definitions.
-- Arithmetic on Positive values. Results that could leave the positive range
-- (subtraction, literals) are re-validated through fromInteger, which raises
-- an error instead of producing an invalid Positive.
instance (Ord a, Num a, Show a, Integral a) => Num (Positive a) where
  Positive a + Positive b = Positive (a + b)
  Positive a * Positive b = Positive (a * b)
  Positive a - Positive b = fromInteger (toInteger (a - b))
  abs = id
  signum = const 1
  negate p = error ("negate: " ++ show p)
  fromInteger n
    | n <= 0 = error ("fromInteger: " ++ show n ++ " is not Positive")
    | otherwise = Positive (fromInteger n)
-- Arithmetic on NonNegative values. Subtraction and literals are re-checked
-- through fromInteger, so a negative result raises an error.
instance (Ord a, Num a, Show a, Integral a) => Num (NonNegative a) where
  NonNegative a + NonNegative b = NonNegative (a + b)
  NonNegative a * NonNegative b = NonNegative (a * b)
  NonNegative a - NonNegative b = fromInteger (toInteger (a - b))
  abs = id
  signum (NonNegative a)
    | a == 0 = 0
    | otherwise = 1
  negate n = error ("negate: " ++ show n)
  fromInteger n
    | n < 0 = error ("fromInteger: " ++ show n ++ " is not NonNegative")
    | otherwise = NonNegative (fromInteger n)
-------------------------------------------------------------------------------------
-- Basic Modular Arithmetic
--
-- The mod function from Prelude's Integral class is not precise enough to
-- capture the expected domain and range used in modular arithmetics.
--
-- For this purpose, we define a custom modulo function: x `modulo` y, where x
-- is an integer, y is Positive, and the result is NonNegative.
-- | @x `modulo` y@: the residue of x modulo a positive y, in [0, y).
-- Haskell's 'mod' already takes the sign of the divisor, so for a genuinely
-- positive y the first guard never fires; it only matters if a Positive was
-- built around a non-positive value.
modulo :: (Ord a, Num a, Integral a) => a -> Positive a -> NonNegative a
modulo x (Positive y)
  | r < 0 = NonNegative (r + y)
  | otherwise = NonNegative r
  where
    r = x `mod` y
-- Two numbers x and y are coprime to each other if and only if gcd(x, y) == 1,
-- where both x and y are either Positive or NonNegative.
-- | True iff gcd(x, y) == 1, after embedding both arguments into Integer.
coprime :: (a <: Integer, b <: Integer) => a -> b -> Bool
-- The annotation is required: '<:' has no functional dependency, so without
-- it the target type of 'embed' is ambiguous and cannot be resolved from the
-- given (a <: Integer) constraint. Fixing one argument fixes both, since
-- 'gcd' takes two arguments of the same type.
coprime x y = gcd (embed x :: Integer) (embed y) == 1
-- The Extended Euclidean Algorithm finds integers x and y such that
-- ax + my = gcd(a, m), where all numbers are integers.
-- https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
--
-- If ax + my = 1 is given as a condition, it implies ax = 1 (mod m), when m is
-- Positive. Therefore, this algorithm is used to find modular multiplicative
-- inverses.
--
-- | Extended Euclid: @egcd a b = (d, s, t)@ with @s * a + t * b == d@.
-- For non-negative inputs, d is gcd(a, b).
egcd :: Integral a => a -> a -> (a, a, a)
egcd a0 b0 = step (a0, 1, 0) (b0, 0, 1)
  where
    -- Each triple (r, s, t) keeps the invariant r == s * a0 + t * b0.
    -- The larger remainder is reduced by the smaller until one reaches 0;
    -- the other triple is then the answer.
    step lhs@(r1, s1, t1) rhs@(r2, s2, t2)
      | r1 == 0 = rhs
      | r2 == 0 = lhs
      | r1 > r2 =
          let (k, r1') = r1 `divMod` r2
          in step (r1', s1 - k * s2, t1 - k * t2) rhs
      | otherwise =
          let (k, r2') = r2 `divMod` r1
          in step lhs (r2', s2 - k * s1, t2 - k * t1)
-- Find the multiplicative inverse x of a (mod m), such that ax = 1 (mod m).
-- This function is partial: it raises an error when a and m are not coprime.
-- | Modular inverse of a modulo m, in [0, m). Errors with
-- "invmod: not coprime" when gcd(a mod m, m) /= 1.
invmod :: (Integral a, Show a) => Positive a -> Positive a -> NonNegative a
invmod (Positive a) (Positive m)
  | d /= 1 = error $ "invmod: not coprime " ++ show (a, m)
  | otherwise = fromIntegral (if s < 0 then s + m else s)
  where
    -- d = coef * (a mod m) + _ * m; coef is the inverse when d == 1.
    (d, coef, _) = egcd (a `mod` m) m
    s = coef `mod` m
-- Calculate the modular exponent: g^x (mod m), where g is NonNegative,
-- x is Integer, and m is Positive.
-- | Modular exponentiation: g^x (mod m), with the result always in [0, m).
--
-- * A negative exponent uses the modular inverse of g, so it errors (via
--   'invmod', or the 'Positive' conversion when g == 0) unless g and m are
--   coprime.
-- * A positive exponent uses square-and-multiply, reducing modulo m at every
--   step instead of materialising the full-size integer g^x first.
expmod
  :: forall a
   . (Integral a, Show a, Ord a)
  => NonNegative a
  -> a
  -> Positive a
  -> NonNegative a
expmod (NonNegative g) x pm@(Positive m) = case x `compare` 0 of
  -- g^(-k) = (g^-1)^k (mod m): invert g once, then raise to the power k.
  LT -> expmod (invmod (fromIntegral g) pm) (negate x) pm
  -- g^0 = 1, reduced so that m == 1 yields 0 rather than the out-of-range 1.
  EQ -> 1 `modulo` pm
  GT -> NonNegative (go (g `mod` m) x 1)
  where
    -- Right-to-left binary exponentiation with the invariant
    -- acc * base^e == g^x (mod m). Since x > 0, e passes through 1 (odd)
    -- before reaching 0, so acc is always reduced modulo m at least once.
    go _ 0 acc = acc
    go base e acc
      | odd e = go (base * base `mod` m) (e `div` 2) (acc * base `mod` m)
      | otherwise = go (base * base `mod` m) (e `div` 2) acc
-- A QuickCheck property to test g^x g^(-x) == 1 (mod m).
-- | For every g coprime to m: g^x * g^(-x) == 1 (mod m).
propExpMod :: Positive Integer -> NonNegative Integer -> Property
propExpMod m x =
  -- The guard must come before 'elements': gs is empty for m <= 2, and
  -- 'elements []' raises an error instead of discarding the test case.
  not (null gs) ==> forAll (embed <$> elements gs) $ \g ->
    embed (expmod g e m * expmod g (negate e) m) `modulo` m === 1
  where
    -- Candidate bases: every value in [2 .. m] that is coprime to m, so each
    -- has a modular inverse.
    gs = filter (coprime m) [2 .. m]
    e = embed x :: Integer
-------------------------------------------------------------------------------------
--
-- Multiplicative order
--
-- Given coprime positive integers g and p, find the smallest positive q such
-- that g^q = 1 (mod p).
--
-- Foundational theorems related to multiplicative order:
--
-- a = b (mod q) ==> g^a = g^b (mod p)
-- a = b+c (mod q) ==> g^a = g^(b + c) (mod p)
-- a = bc (mod q) ==> g^a = (g^b)^c (mod p)
-- | Multiplicative order of g modulo p: the smallest positive q such that
-- g^q = 1 (mod p). Errors when g and p are not coprime, since no such q
-- exists then.
orderOf :: Positive Integer -> Positive Integer -> Positive Integer
orderOf g@(Positive base) p@(Positive m)
  | not (coprime g p) = error $ "orderOf: not coprime " ++ show (g, p)
  | otherwise = Positive (1 + toInteger (length (takeWhile (/= one) powers)))
  where
    -- 1 reduced modulo m (0 when m == 1, where every residue is 0).
    one = 1 `mod` m
    -- g^1, g^2, g^3, ... reduced modulo m at every step to keep the numbers
    -- small. Because g and m are coprime, the sequence reaches `one`.
    powers = iterate (\r -> r * base `mod` m) (base `mod` m)
-------------------------------------------------------------------------------------
--
-- Discrete Logarithm Problem (DLOG).
-- https://en.wikipedia.org/wiki/Discrete_logarithm
--
-- Choose g, p, q such that f(x) = g^x (mod p), where 0 <= x < q, is hard to
-- invert, or in other words, it is hard to guess what x is just from a value
-- of f(x).
--
-- One solution is the multiplicative group of integers modulo prime number p.
-- https://en.wikipedia.org/wiki/Multiplicative_group_of_integers_modulo_n
--
-- 1. Choose q that is a large prime such that p = 2q + 1 is also prime.
--
-- 2. Define G = [g | g <- [0..p-1], g ^ q == 1 (mod p)], which is called
-- the multiplicative group of prime order q modulo p.
--
-- 3. It can be proven that G has q elements.
--
-- 4. Every g in G where g /= 1 is called a generator.
--
-- 5. For any generator g, f(x) = g ^ x (mod p) is a bijective mapping from
--    [0..q-1] to G.
--
-- Essentially this means every element in [f(x) | x <- [0..q-1]] is unique,
-- and is a number in G.
-- A small set of (p, q) pairs used for testing.
-- | Picks one of the first 100 pairs (p, q) where q is prime and p = 2q + 1
-- is prime too.
allPQ :: Gen (Positive Integer, Positive Integer)
allPQ = elements (take 100 safePairs)
  where
    safePairs =
      map (\(p, q) -> (fromIntegral p, fromIntegral q))
        . filter (isPrime . fst)
        $ map (\q -> (2 * q + 1, q)) primesST
-- One way to calculate generators is by testing all values in [0..p-1].
-- | Every g in [0, p-1], in ascending order, with g^q = 1 (mod p).
generators :: Positive Integer -> Positive Integer -> [NonNegative Integer]
generators p q =
  [ g | g <- [0 .. embed p - 1], expmod g (embed q) p == 1 ]
-- Calculate the range (co-domain) of f(x) = g^x (mod p).
-- | The values g^x (mod p) for x in [0, q-1], sorted ascending.
rangeOfF
  :: NonNegative Integer
  -> Positive Integer
  -> Positive Integer
  -> [NonNegative Integer]
rangeOfF g p q = sort (map (\x -> expmod g x p) [0 .. embed q - 1])
-- For any qualifying pair (p, q), let g be a generator of (p, q), we can
-- verify that G is the co-domain of f(x) = g^x (mod p)
-- | For every non-identity element g of G, the image of x |-> g^x (mod p)
-- over [0, q-1] is exactly G.
propDLOG :: Property
propDLOG = forAll allPQ check
  where
    check (p, q) =
      let gs = generators p q
      in forAll (elements gs) $ \g -> g /= 1 ==> rangeOfF g p q == gs
-- Instead of iterating through all [0..p-1] to find generators, we can also
-- start with a known generator, and apply f to find all.
-- | G computed as the image of a single known generator.
generators' :: Positive Integer -> Positive Integer -> [NonNegative Integer]
generators' p q = rangeOfF g p q
  where
    -- generators p q is ascending, and for p > 1 its first element is 1
    -- (the identity), so index 1 is the smallest element other than 1.
    g = generators p q !! 1
-------------------------------------------------------------------------------------
-- Fiat-Shamir heuristic turns an interactive proof of a known secret into a
-- non-interactive one, using random oracle access.
-- https://en.wikipedia.org/wiki/Fiat–Shamir_heuristic
-- (This wiki page likely contains an error)
--
-- Non-interactive proof of knowing a secret x without revealing it, in
-- y = g ^ x (mod p), where y is known as the public key.
-- | A triple (p, g, y). The public key is y = g^x (mod p); the modulus p and
-- the base g are the public parameters needed to compute with y.
type PubKey = (Positive Integer, NonNegative Integer, NonNegative Integer)
-- | A pair (q, x). The secret is the exponent x; q is the modulus by which
-- exponent arithmetic involving x is reduced (e.g. r = v - cx (mod q) when
-- proving or signing). q does not need to be kept secret.
type SecKey = (Positive Integer, NonNegative Integer)
-- Generate a random key pair given parameter (p, q), by choosing a random
-- x in the range [0, embed q - 1].
--
-- Note that in practice, x = 0 is a trivial case that should be ignored.
-- | Random key pair for parameters (p, q): a random generator g of G and a
-- random secret x in [0, q-1], with public key y = g^x (mod p).
--
-- Assumes (p, q) is a valid pair (q prime, p = 2q + 1 prime) so that G has
-- at least one element other than 1.
keypairs :: Positive Integer -> Positive Integer -> Gen (PubKey, SecKey)
keypairs p q = do
  -- 1 is the identity of G, not a generator: with g = 1 the public key would
  -- be 1 for every secret x. Draw g only from the non-identity elements.
  g <- elements (filter (/= 1) (generators p q))
  x <- choose (0, embed q - 1)
  pure ((p, g, expmod g x p), (q, fromIntegral x))
-- A made-up hash function that is absolutely insecure.
-- | A hash function that accepts any showable value.
type Hashing = forall x . Show x => x -> NonNegative Integer

-- | Toy hash: the sum of the character codes of the value's 'show' output.
-- Trivially collidable; for illustration only.
hash :: Hashing
hash v = fromIntegral (toInteger (sum [ fromEnum ch | ch <- show v ]))
-- A proof (t, r) is produced by:
--
-- Choose a random v in [0..q-1]
-- let t = g^v (mod p)
-- let c = hash(g, y, t)   (the code keeps c unreduced; only c (mod q) matters)
-- let r = v - cx (mod q)
--
-- The resulting proof can be checked by anyone with knowledge of PubKey.
type Proof = (NonNegative Integer, NonNegative Integer)

-- | Non-interactive proof of knowledge of x for y = g^x (mod p), built from
-- a random nonce v in [0, q-1].
prove :: Hashing -> PubKey -> SecKey -> Gen Proof
prove h (p, g, y) (q, x) = mkProof <$> choose (0, embed q - 1)
  where
    -- Commitment t = g^v, challenge c = h(g, y, t), response r = v - c*x.
    mkProof v =
      let t = expmod g v p
          c = h (g, y, t)
      in (t, (v - embed (c * x)) `modulo` q)
-- Verify the proof (t, r) by checking that t == g^r y^c (mod p).
--
-- This always holds because (where all exponent values are modulo q):
--
-- g^r y^c (mod p)
-- = g^r (g^x)^c (mod p)
-- = g^(r + xc) (mod p)
-- = g^(v - cx + xc) (mod p)
-- = g^v (mod p)
-- = t
-- | Recomputes the challenge and checks t == g^r * y^c (mod p).
verifyProof :: Hashing -> PubKey -> Proof -> Property
verifyProof h (p, g, y) (t, r) = t === expected
  where
    c = h (g, y, t)
    expected = embed (expmod g (embed r) p * expmod y (embed c) p) `modulo` p
-- For any key pair, we can check that a random proof can always be verified.
-- | Every honestly generated proof verifies, for random parameters and keys.
propFiatShamir :: Property
propFiatShamir =
  forAll allPQ $ \(p, q) ->
    forAll (keypairs p q) $ \(pub, sec) ->
      forAll (prove hash pub sec) (verifyProof hash pub)
-------------------------------------------------------------------------------------
-- Schnorr Signature
-- https://en.wikipedia.org/wiki/Schnorr_signature
-- For a given message msg, a signature (c, r) is produced by:
--
-- 1. Choose a random v in [0..q-1]
-- 2. Compute t = g^v (mod p)
-- 3. Compute c = hash(t, msg)
-- 4. Compute r = v - cx (mod q)
--
-- The resulting signature can be checked by anyone with knowledge of msg and PubKey.
type Signature = (NonNegative Integer, NonNegative Integer)

-- | Schnorr signature (c, r) of msg: c = h(g^v mod p, msg) and
-- r = v - c*x (mod q) for a random nonce v in [0, q-1].
-- The public key y itself is not needed to sign.
sign :: Show a => Hashing -> PubKey -> SecKey -> a -> Gen Signature
sign h (p, g, _) (q, x) msg = do
  nonce <- choose (0, embed q - 1)
  let c = h (expmod g nonce p, msg)
  pure (c, (nonce - embed (c * x)) `modulo` q)
-- Verify the signature (c, r) by:
--
-- 1. Compute t' = g^r y^c (mod p).
-- 2. Check if hash(t', msg) == c
--
-- This always holds because (where all exponent values are modulo q):
--
-- g^r y^c (mod p)
-- = g^r (g^x)^c (mod p)
-- = g^(r + xc) (mod p)
-- = g^(v - cx + xc) (mod p)
-- = g^v (mod p)
-- = t
--
-- Therefore t' = t, and hash(t', msg) == hash(t, msg) == c
-- | Recomputes t' = g^r * y^c (mod p) and checks that h(t', msg) == c.
verify :: Show a => Hashing -> PubKey -> a -> Signature -> Property
verify h (p, g, y) msg (c, r) = h (t', msg) === c
  where
    t' = embed (expmod g (embed r) p * expmod y (embed c) p) `modulo` p
-- For any key pair and a random message, we can check that a Schnorr signature
-- can always be verified.
-- | A signature of msg made with any generated key pair verifies.
propSchnorr :: String -> Property
propSchnorr msg =
  forAll allPQ $ \(p, q) ->
    forAll (keypairs p q) $ \(pub, sec) ->
      forAll (sign hash pub sec msg) (verify hash pub msg)
-------------------------------------------------------------------------------------
-- Prime number functions
-- https://wiki.haskell.org/Prime_numbers
--
-- | The infinite, ascending list of primes.
--
-- The odd primes `ops` are produced in segments. A segment covers the odd
-- numbers [x, x + 2 .. q - 2], where q is the square of an odd prime
-- (9, 25, 49, ...); q itself is composite and skipped, and the next segment
-- starts at q + 2. `fs`, the matching prefix from `inits ops`, holds the odd
-- primes below sqrt q. A candidate survives if it leaves a positive
-- remainder for every d in `fs`, which suffices for numbers below q.
-- The irrefutable pattern ~(_ : pt) lets `sieve` consume `ops` while `ops`
-- is still being defined by it.
primesST :: [Integer]
primesST = 2 : ops
  where
    ops = sieve 3 9 ops (inits ops) -- odd primes
    sieve x q ~(_ : pt) (fs : ft) =
      filter ((`all` fs) . ((> 0) .) . rem) [x, x + 2 .. q - 2]
        ++ sieve (q + 2) (head pt ^ 2) pt ft
-- | True if no d in the list divides n. The scan stops (answering True) at
-- the first d with d * d > n, so the list is expected to be ascending.
noDivs :: Integer -> [Integer] -> Bool
noDivs n = go
  where
    go [] = True
    go (d : ds) = d * d > n || (rem n d > 0 && go ds)
-- | Primality by trial division using the primes from 'primesST'.
isPrime :: Integer -> Bool
isPrime n
  | n <= 1 = False
  | otherwise = noDivs n primesST
-------------------------------------------------------------------------------------
-- | Runs every property of the tutorial with QuickCheck, in order.
main :: IO ()
main =
  sequence_
    [ quickCheck propExpMod
    , quickCheck propDLOG
    , quickCheck propFiatShamir
    , quickCheck propSchnorr
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment