Skip to content

Instantly share code, notes, and snippets.

@mikeboers
Created September 18, 2010 04:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mikeboers/585345 to your computer and use it in GitHub Desktop.
Save mikeboers/585345 to your computer and use it in GitHub Desktop.
from __future__ import division
import hashlib
import hmac
import math
import os
import random
from fractions import Fraction
def str_to_int(input):
    """Interpret *input* as a big-endian byte string and return its integer value.

    Each character contributes one byte (its ``ord``), most significant first,
    so the two-character string ``chr(1) + chr(0)`` becomes 256. The empty
    string maps to 0, and leading ``chr(0)`` characters do not change the
    result.
    """
    value = 0
    for char in input:
        value = value * 256 + ord(char)
    return value
def int_to_str(input):
    """Encode an integer as a big-endian byte string (inverse of str_to_int).

    The result never starts with ``chr(0)``, so leading zero bytes of an
    original string are not restored. Zero, and any negative number, encode
    to the empty string.
    """
    chunks = []
    remaining = input
    while remaining > 0:
        remaining, low_byte = divmod(remaining, 256)
        chunks.append(chr(low_byte))
    # Bytes were collected least-significant first.
    chunks.reverse()
    return ''.join(chunks)
def get_coefs(n, size):
    """Return a list of *n* random polynomial coefficients, each in ``range(size)``.

    The coefficients are what keeps a single share from revealing the secret.
    They are therefore drawn from the operating system's CSPRNG
    (``random.SystemRandom``) instead of the module-level Mersenne Twister,
    whose output is predictable and documented as unsuitable for security.

    ``size`` is exclusive. split() reduces every share value modulo the same
    number it passes here, so a coefficient equal to ``size`` would only alias
    0 and bias the distribution.

    Raises ValueError if ``n > 0`` and ``size < 1``.
    """
    rng = random.SystemRandom()
    return [rng.randrange(size) for _ in range(n)]
def get_point(val, coefs, n):
    """Evaluate a polynomial at x = *n* and return the share ``(n, y)``.

    ``val`` is the constant term and ``coefs[i]`` multiplies ``n ** (i + 1)``,
    i.e. y = val + coefs[0]*n + coefs[1]*n**2 + ...
    """
    terms = (coef * n ** power for power, coef in enumerate(coefs, 1))
    # sum() starts from val and adds the terms left to right.
    return (n, sum(terms, val))
def solve(points):
    """Evaluate the Lagrange interpolating polynomial through *points* at x=0.

    ``points`` is a re-iterable sequence of ``(x, y)`` integer pairs with
    distinct x values. The value at 0 is computed exactly with rational
    arithmetic and then floored, so the result is an int. It is exact
    whenever the interpolated value at 0 is itself an integer. An empty
    sequence yields 0.

    Raises ZeroDivisionError if two points share the same x.
    """
    total = Fraction(0)
    for j, (xj, yj) in enumerate(points):
        # y_j times the j-th Lagrange basis polynomial evaluated at x = 0:
        #   y_j * prod_{f != j} (0 - x_f) / (x_j - x_f)
        # (variable names follow the Wikipedia presentation).
        num = yj
        den = 1
        for f, (xf, _) in enumerate(points):
            if f != j:
                num *= -xf
                den *= xj - xf
        # Fraction reduces each term as it is added. This avoids the huge,
        # unreduced common denominator that cross-multiplying every
        # numerator by every other denominator would build up.
        total += Fraction(num, den)
    # Floor division of the reduced fraction equals ``//`` on any unreduced
    # numerator/denominator pair, so rounding matches plain integer floor.
    return total.numerator // total.denominator
def split(secret, threshold, num=None, modlen=None):
    """Split *secret* into *num* shares of a random polynomial.

    The secret (a byte string) is read as a big-endian integer. That integer
    becomes the constant term of a polynomial of degree ``threshold - 1``
    with random coefficients. Share ``i`` is the point
    ``(i, f(i) % 2**modlen)`` for ``i = 1..num``.

    :param secret: byte string to share.
    :param threshold: polynomial degree plus one, i.e. how many shares are
        needed to interpolate it; must be at least 1.
    :param num: number of shares to produce; defaults to ``threshold``.
    :param modlen: bit width of the share values; defaults to
        ``8 * len(secret)``, which is wide enough for any byte string.
    :returns: ``(points, modlen)``; pass both to combine().
    :raises ValueError: if ``threshold < 1``, ``num < threshold``, or the
        secret does not fit in ``modlen`` bits.

    NOTE(review): share values are reduced modulo 2**modlen rather than a
    prime. Reconstruction is therefore only reliable from shares with
    consecutive x values (e.g. all of 1..threshold); a subset such as x = 1
    and x = 3 can recombine to the wrong secret.
    """
    # Explicit raises rather than asserts: asserts are stripped under -O.
    num = num or threshold
    if threshold < 1:
        raise ValueError('threshold must be at least 1')
    if num < threshold:
        raise ValueError('num must be at least threshold')
    modlen = modlen or 8 * len(secret)
    mod = 2 ** modlen
    value = str_to_int(secret)
    if value >= mod:
        raise ValueError('secret does not fit in modlen bits')
    coefs = get_coefs(threshold - 1, mod)
    # f(0) is the secret itself, so share x values start at 1.
    points = [get_point(value, coefs, x) for x in range(1, num + 1)]
    return [(x, y % mod) for x, y in points], modlen
def combine(points, modlen):
    """Reconstruct the secret byte string from shares produced by split().

    Interpolates the shares at x = 0, reduces the result modulo
    ``2 ** modlen`` and converts it back to bytes. Leading zero bytes of the
    original secret are not restored, because int_to_str never emits them;
    pad() prepends a 0x01 byte for this reason.

    NOTE(review): split() reduces share values modulo 2**modlen rather than a
    prime. This is therefore only reliable for shares with consecutive x
    values (such as x = 1..k); some other subsets, e.g. x = 1 and 3, can
    produce a wrong secret.
    """
    modulus = 2 ** modlen
    reduced = solve(points) % modulus
    return int_to_str(reduced)
def pad(msg, salt_length=None, hash_length=None, hash_func=None):
    """Frame *msg* so it survives split()/combine() and can be verified.

    Layout: a single 0x01 byte, then *msg*, then ``salt_length`` random salt
    bytes, then a tag. The tag is the first ``hash_length`` bytes of
    ``HMAC(key=salt, msg, hash_func)``. The leading 0x01 byte keeps leading
    zero bytes of *msg* from being dropped by the integer round-trip in
    split()/combine().

    Defaults: ``hash_func`` is SHA-1, ``salt_length`` is its block size and
    ``hash_length`` is its digest size. A value of 0 or None selects the
    default.

    The salt is stored in the clear next to the tag, so the tag detects
    corruption (e.g. a bad reconstruction) but does not authenticate the
    message against someone who can rewrite it.
    """
    hash_func = hash_func or hashlib.sha1
    salt_length = salt_length or hash_func().block_size
    hash_length = hash_length or hash_func().digest_size
    # os must be imported at module level; importing it only in the
    # __main__ guard made this raise NameError when used as a library.
    salt = os.urandom(salt_length)
    tag = hmac.new(salt, msg, hash_func).digest()[:hash_length]
    return b'\x01' + msg + salt + tag
def unpad(msg, salt_length=None, hash_length=None, hash_func=None):
    """Strip the framing added by pad() and verify its HMAC tag.

    The optional parameters must match those given to pad(); 0 or None
    selects the same defaults (SHA-1, its block size, its digest size).
    Returns the original message.

    :raises ValueError: if *msg* is shorter than the framing, does not start
        with the 0x01 marker byte, or its tag does not match.
    """
    hash_func = hash_func or hashlib.sha1
    salt_length = salt_length or hash_func().block_size
    hash_length = hash_length or hash_func().digest_size
    # Explicit raises rather than assert: `python -O` strips asserts, which
    # would silently disable the integrity check.
    if len(msg) < 1 + salt_length + hash_length:
        raise ValueError('padded message is too short')
    if msg[:1] != b'\x01':
        raise ValueError('padded message is missing its 0x01 marker byte')
    body_end = len(msg) - salt_length - hash_length
    body = msg[1:body_end]
    salt = msg[body_end:body_end + salt_length]
    digest = msg[body_end + salt_length:]
    if hmac.new(salt, body, hash_func).digest()[:hash_length] != digest:
        raise ValueError('padded message failed HMAC verification')
    return body
if __name__ == '__main__':
    import sys

    # Smoke test: frame a message, split it into 3 shares, recombine and
    # unframe it several times. Each run draws a fresh salt and fresh
    # coefficients.
    original = 'this is my message'
    for _ in range(4):
        padded = pad(original, 8)
        points, modlen = split(padded, 3)
        print(padded.encode('hex'))
        for x, y in points:
            print('\t%d-%x' % (x, y))
        # unpad() raises if the reconstruction fails its HMAC check.
        recovered = unpad(combine(points, modlen), 8)
        if recovered != original:
            print('')
            print(original.encode('hex'))
            print(len(points))
            print(recovered.encode('hex'))
            # The bare exit() used here before exits with status 0 (success)
            # and only exists when the site module is loaded.
            sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment