Skip to content

Instantly share code, notes, and snippets.

@zackmdavis
Created April 17, 2015 18:56
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zackmdavis/ebc08cf7913fcd9f6796 to your computer and use it in GitHub Desktop.
Save zackmdavis/ebc08cf7913fcd9f6796 to your computer and use it in GitHub Desktop.
extremely simplified Reed-Solomon coding in Python 3
from fractions import Fraction
from string import ascii_uppercase
# The message alphabet: a space followed by the 26 uppercase ASCII letters.
ALPHABET = " " + ascii_uppercase
# Lookup tables between each symbol and its position in ALPHABET.
CHAR_TO_INT = {symbol: index for index, symbol in enumerate(ALPHABET)}
INT_TO_CHAR = dict(enumerate(ALPHABET))
def pad(message, chunk_size):
    """Append spaces so the length of *message* is a multiple of *chunk_size*.

    Between 1 and ``chunk_size`` spaces are always appended: a message whose
    length is already a multiple of ``chunk_size`` receives one full extra
    chunk of spaces.
    """
    remainder = len(message) % chunk_size
    filler = ' ' * (chunk_size - remainder)
    return message + filler
def unpad(message):
    """Return *message* with all trailing whitespace removed."""
    trimmed = message.rstrip()
    return trimmed
def chunkify(message, chunk_size):
    """Split *message* into consecutive slices of at most *chunk_size* items.

    The last slice is shorter when the length of *message* is not a multiple
    of *chunk_size*; an empty message yields an empty list.
    """
    pieces = []
    for start in range(0, len(message), chunk_size):
        stop = start + chunk_size
        pieces.append(message[start:stop])
    return pieces
def unchunkify(chunks):
    """Concatenate an iterable of string chunks back into a single string."""
    glue = ''
    return glue.join(chunks)
def convert(string):
    """Map each character of *string* to its integer code via CHAR_TO_INT.

    Raises KeyError for a character that is not a key of CHAR_TO_INT.
    """
    codes = []
    for character in string:
        codes.append(CHAR_TO_INT[character])
    return codes
def deconvert(sequence):
    """Map each integer of *sequence* to its character via INT_TO_CHAR.

    Returns the characters joined into one string. Raises KeyError for an
    integer that is not a key of INT_TO_CHAR.
    """
    characters = [INT_TO_CHAR[code] for code in sequence]
    return ''.join(characters)
def evaluate_polynomial(coefficients, x):
    """Evaluate a polynomial at *x*.

    ``coefficients[k]`` is the coefficient of ``x**k`` (lowest degree first).
    An empty coefficient sequence evaluates to 0.
    """
    total = 0
    for power, coefficient in enumerate(coefficients):
        total = total + coefficient * x**power
    return total
def encode(chunk, n):
    """Evaluate the polynomial with coefficients *chunk* at x = 0, 1, ..., n-1.

    Returns the list of the *n* resulting values, in order of increasing x.
    """
    values = []
    for x in range(n):
        values.append(evaluate_polynomial(chunk, x))
    return values
def get_coefficient(P, i):
    """Return ``P[i]``, or 0 when *i* is outside ``0 <= i < len(P)``.

    Negative indices are treated as out of range rather than wrapping around.
    """
    in_range = 0 <= i < len(P)
    return P[i] if in_range else 0
def add_polynomials(P, Q):
    """Add two polynomials given as coefficient lists (lowest degree first).

    The shorter operand is treated as if padded with zeros, so the result has
    ``max(len(P), len(Q))`` coefficients.
    """
    width = max(len(P), len(Q))
    total = []
    for k in range(width):
        # A coefficient past the end of an operand counts as zero.
        p_k = P[k] if k < len(P) else 0
        q_k = Q[k] if k < len(Q) else 0
        total.append(p_k + q_k)
    return total
def scale_polynomial(P, a):
    """Multiply every coefficient of polynomial *P* by the scalar *a*."""
    scaled = []
    for coefficient in P:
        scaled.append(a * coefficient)
    return scaled
def multiply_polynomials(P, Q):
    """Multiply two polynomials given as coefficient lists (lowest degree first).

    The product of polynomials with ``len(P)`` and ``len(Q)`` coefficients has
    exactly ``len(P) + len(Q) - 1`` coefficients, so the result carries no
    spurious trailing zero term. If either operand is empty (the zero
    polynomial) the result is the empty list.
    """
    if len(P) == 0 or len(Q) == 0:
        return []
    # The highest-degree term is x**((len(P) - 1) + (len(Q) - 1)).
    product = [0] * (len(P) + len(Q) - 1)
    for i, c in enumerate(P):
        for j, d in enumerate(Q):
            product[i + j] += c * d
    return product
def lagrange_basis_denominator(xs, j):
    """Return the product of ``xs[j] - xs[i]`` over every index ``i != j``.

    This is the denominator of the j-th Lagrange basis polynomial for the
    sample positions *xs*. With a single position the product is empty and
    the result is 1.
    """
    pivot = xs[j]
    product = 1
    for i in range(len(xs)):
        if i != j:
            product *= pivot - xs[i]
    return product
def lagrange_basis_element(xs, j):
    """Return the j-th Lagrange basis polynomial for sample positions *xs*.

    The polynomial is the product of ``(x - xs[i])`` over every ``i != j``,
    divided by ``lagrange_basis_denominator(xs, j)``, returned as a
    coefficient list with the lowest degree first.

    The scaling factor is an exact ``Fraction`` rather than a float, so for
    integer positions the coefficients are exact rationals and no precision
    is lost however large the intermediate values become. The coefficients
    still compare equal to the corresponding floats (for example
    ``Fraction(1, 2) == 0.5``).

    Raises ZeroDivisionError if ``xs[j]`` coincides with another position.
    """
    element = [1]
    for i, x in enumerate(xs):
        if i == j:
            continue
        # Multiply by the linear factor (x - xs[i]), i.e. [-xs[i], 1].
        element = multiply_polynomials(element, [-x, 1])
    scaling_factor = Fraction(1) / lagrange_basis_denominator(xs, j)
    return scale_polynomial(element, scaling_factor)
def interpolate(points):
    """Recover polynomial coefficients from ``(x, y)`` sample points.

    Sums the Lagrange basis polynomials, each scaled by the corresponding y
    value, and rounds every coefficient of the sum with ``round``. Returns the
    coefficient list with the lowest degree first.
    """
    xs, ys = zip(*points)
    total = [0]
    for j in range(len(points)):
        term = scale_polynomial(lagrange_basis_element(xs, j), ys[j])
        total = add_polynomials(total, term)
    return [round(coefficient) for coefficient in total]
def erasure_code(message, chunk_size, encoded_chunk_size):
    """Encode *message* into lists of polynomial sample points.

    The message is padded, split into chunks of *chunk_size* characters, and
    each chunk is converted to integers that serve as polynomial
    coefficients. Each polynomial is evaluated at
    x = 0, 1, ..., encoded_chunk_size - 1.

    Returns one list per chunk, each holding ``(x, value)`` tuples.
    """
    padded = pad(message, chunk_size)
    coefficient_lists = [convert(piece)
                         for piece in chunkify(padded, chunk_size)]
    encoded = []
    for coefficients in coefficient_lists:
        samples = []
        for x in range(encoded_chunk_size):
            samples.append((x, evaluate_polynomial(coefficients, x)))
        encoded.append(samples)
    return encoded
def erasure_decode(encoded_chunks, chunk_size, encoded_chunk_size):
    """Decode lists of sample points back into the original message.

    For each encoded chunk, the first *chunk_size* points are interpolated and
    the first *chunk_size* coefficients of the result are kept. The
    coefficient lists are converted back to characters, joined, and unpadded.

    *encoded_chunk_size* is accepted but not used by this function.
    """
    coefficient_lists = []
    for points in encoded_chunks:
        usable_points = points[:chunk_size]
        coefficient_lists.append(interpolate(usable_points)[:chunk_size])
    pieces = (deconvert(coefficients) for coefficients in coefficient_lists)
    return unpad(unchunkify(pieces))
import json
def disperse(encoded_chunks):
    """Write the i-th entry of every encoded chunk to the file ``node<i>``.

    The number of files equals the length of the first encoded chunk. Each
    file is opened for writing in the current working directory and receives
    a JSON list with one entry per chunk.
    """
    share_count = len(encoded_chunks[0])
    for index in range(share_count):
        filename = 'node' + str(index)
        with open(filename, 'w') as partition:
            share = [chunk[index] for chunk in encoded_chunks]
            partition.write(json.dumps(share))
def retrieve(*nodes):
    """Read JSON lists from the given node files and regroup them by chunk.

    Each path in *nodes* is opened and parsed as JSON. The result has one
    list per position in the first file's list; entry ``k`` gathers the k-th
    item of every file, in the order the paths were given.
    """
    responses = []
    for path in nodes:
        with open(path) as partition:
            responses.append(json.load(partition))
    chunk_count = len(responses[0])
    regrouped = []
    for k in range(chunk_count):
        regrouped.append([response[k] for response in responses])
    return regrouped
import unittest
from random import choice, randrange, sample
from reed_solomon import *
MAX_TESTED_MESSAGE_LENGTH = 20
MAX_TESTED_CHUNK_SIZE = 8
def arbitrary_message(length):
    """Return a random message of up to *length* symbols drawn from ALPHABET.

    Trailing whitespace is stripped, so the result may be shorter than
    *length* and never ends in a space.
    """
    symbols = [choice(ALPHABET) for _ in range(length)]
    return ''.join(symbols).rstrip()
class TestPadChunkConvert(unittest.TestCase):
    # Round-trip checks for the padding, chunking and conversion helpers.

    def test_pad_invertibility(self):
        # unpad must undo pad for every tested length and chunk size.
        chunk_sizes = range(1, MAX_TESTED_CHUNK_SIZE)
        for length in range(MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in chunk_sizes:
                original = arbitrary_message(length)
                padded = pad(original, chunk_size)
                self.assertEqual(original, unpad(padded))

    def test_chunk_invertibility(self):
        # unchunkify must undo chunkify for every tested length and size.
        chunk_sizes = range(1, MAX_TESTED_CHUNK_SIZE)
        for length in range(MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in chunk_sizes:
                original = arbitrary_message(length)
                pieces = chunkify(original, chunk_size)
                self.assertEqual(original, unchunkify(pieces))

    def test_convert_invertibility(self):
        # deconvert must undo convert.
        original = arbitrary_message(10)
        codes = convert(original)
        self.assertEqual(original, deconvert(codes))
def arbitrary_polynomial(degree):
    """Return ``degree + 1`` random coefficients, each in ``range(len(ALPHABET))``."""
    bound = len(ALPHABET)
    coefficients = []
    for _ in range(degree + 1):
        coefficients.append(randrange(bound))
    return coefficients
def strip_trailing_zeros(polynomial):
    """Return the coefficients of *polynomial* up to its last truthy entry.

    Falsy entries (zeros) at the end are dropped; zeros before the last
    non-zero coefficient are kept. An all-zero input yields an empty list.
    """
    end = len(polynomial)
    # Walk back from the end past every falsy (zero) coefficient.
    while end > 0 and not polynomial[end - 1]:
        end -= 1
    return [polynomial[k] for k in range(end)]
class TestAlgebra(unittest.TestCase):
    # Checks for the polynomial arithmetic and Lagrange interpolation helpers.

    def test_add_known_polynomials(self):
        total = add_polynomials([1, 0, 2], [1, 3, 2, 5])
        self.assertSequenceEqual([2, 3, 4, 5], total)

    def test_scale_known_polynomial(self):
        scaled = scale_polynomial([1, 1, 1], 5)
        self.assertSequenceEqual([5, 5, 5], scaled)

    def test_multiply_known_polynomials(self):
        product = multiply_polynomials([0, 1, 3, 4], [1, 4])
        self.assertSequenceEqual([0, 1, 7, 16, 16],
                                 strip_trailing_zeros(product))

    def test_known_lagrange_basis_denominator(self):
        xs = [0, 1, 2]
        for j, expected in enumerate([2, -1, 2]):
            self.assertEqual(expected, lagrange_basis_denominator(xs, j))

    def test_known_lagrange_basis_element(self):
        xs = [0, 1, 2]
        expected_elements = [[1, -3/2, 1/2], [0, 2, -1], [0, -1/2, 1/2]]
        for j, expected in enumerate(expected_elements):
            actual = strip_trailing_zeros(lagrange_basis_element(xs, j))
            self.assertSequenceEqual(expected, actual)

    def test_known_lagrange_interpolation(self):
        # The points lie on 1 + 2x + 3x^2.
        samples = [(0, 1), (1, 6), (2, 17)]
        recovered = strip_trailing_zeros(interpolate(samples))
        self.assertSequenceEqual([1, 2, 3], recovered)

    def test_lagrange_interpolation(self):
        # Any m+1 of the n sampled points must recover the polynomial.
        for m in range(2, MAX_TESTED_CHUNK_SIZE):
            for n in range(m + 1, 2 * MAX_TESTED_CHUNK_SIZE):
                polynomial = strip_trailing_zeros(arbitrary_polynomial(m))
                points = []
                for x in range(n):
                    points.append((x, evaluate_polynomial(polynomial, x)))
                chosen = sample(points, m + 1)
                recovered = strip_trailing_zeros(interpolate(chosen))
                self.assertSequenceEqual(polynomial, recovered)
class TestEncoding(unittest.TestCase):
    # End-to-end check of erasure_code followed by erasure_decode.

    def test_erasure_coding(self):
        # Decoding an encoded message must give back the original message
        # for every tested length, chunk size and encoded chunk size.
        for length in range(10, MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in range(1, MAX_TESTED_CHUNK_SIZE):
                encoded_sizes = range(chunk_size, chunk_size + 10)
                for encoded_chunk_size in encoded_sizes:
                    original = arbitrary_message(length)
                    encoded = erasure_code(original, chunk_size,
                                           encoded_chunk_size)
                    decoded = erasure_decode(encoded, chunk_size,
                                             encoded_chunk_size)
                    self.assertEqual(original, decoded)
# Run the unit tests defined above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment