Created April 17, 2015 18:56
-
-
Save zackmdavis/ebc08cf7913fcd9f6796 to your computer and use it in GitHub Desktop.
Extremely simplified Reed–Solomon erasure coding in Python 3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from string import ascii_uppercase

# Symbol set: code 0 is the space (also the padding character used by pad),
# codes 1-26 are the letters A-Z.
ALPHABET = " " + ascii_uppercase
# Symbol <-> integer-code lookup tables, derived from ALPHABET itself so their
# size can never drift out of sync with the alphabet.
CHAR_TO_INT = {char: code for code, char in enumerate(ALPHABET)}
INT_TO_CHAR = dict(enumerate(ALPHABET))
def pad(message, chunk_size):
    """Right-pad *message* with spaces to a whole number of chunks.

    Only the padding actually needed is added. A message whose length is
    already a multiple of *chunk_size* is returned unchanged, rather than
    gaining an all-space chunk that would be encoded, stored and decoded for
    nothing. An empty message still becomes one chunk of spaces, so
    erasure_code always yields at least one chunk (disperse reads the node
    count from the first chunk).

    The padding cannot be told apart from trailing spaces in the message
    itself, so unpad cannot restore those.

    Raises ZeroDivisionError if chunk_size is 0.
    """
    missing = (-len(message)) % chunk_size
    if not message:
        missing = chunk_size
    return message + ' ' * missing
def unpad(message):
    """Remove the space padding that pad appended to *message*.

    Only spaces are stripped. pad never adds any other character, so other
    trailing whitespace (tabs, newlines) belongs to the message and is kept.
    Trailing spaces that were part of the original message are also removed,
    because they are indistinguishable from padding.
    """
    return message.rstrip(' ')
def chunkify(message, chunk_size):
    """Split *message* into consecutive slices of *chunk_size* items.

    The last slice is shorter when len(message) is not a multiple of
    chunk_size; an empty message gives an empty list.
    """
    pieces = []
    for start in range(0, len(message), chunk_size):
        stop = start + chunk_size
        pieces.append(message[start:stop])
    return pieces
def unchunkify(chunks):
    """Concatenate an iterable of string *chunks* into a single string.

    This is the inverse of chunkify.
    """
    separator = ''
    return separator.join(chunks)
def convert(string):
    """Map each character of *string* to its integer code via CHAR_TO_INT.

    Raises KeyError for any character outside ALPHABET.
    """
    codes = []
    for character in string:
        codes.append(CHAR_TO_INT[character])
    return codes
def deconvert(sequence):
    """Map integer codes back to characters via INT_TO_CHAR and join them.

    This is the inverse of convert. Raises KeyError for a code with no
    entry in INT_TO_CHAR.
    """
    characters = [INT_TO_CHAR[code] for code in sequence]
    return ''.join(characters)
def evaluate_polynomial(coefficients, x):
    """Evaluate the polynomial sum(coefficients[k] * x**k) at *x*.

    *coefficients* are listed in ascending order of degree. An empty list is
    the zero polynomial and evaluates to 0.
    """
    total = 0
    for power, coefficient in enumerate(coefficients):
        total += coefficient * x**power
    return total
def encode(chunk, n):
    """Evaluate the polynomial whose coefficients are *chunk* at x = 0 .. n-1.

    Returns the n values in order of x.
    """
    values = []
    for x in range(n):
        values.append(evaluate_polynomial(chunk, x))
    return values
def get_coefficient(P, i):
    """Return the coefficient of degree *i* in polynomial *P*.

    Out-of-range degrees, including negative ones, are treated as 0.
    """
    return P[i] if 0 <= i < len(P) else 0
def add_polynomials(P, Q):
    """Return the coefficient-wise sum of polynomials *P* and *Q*.

    The result is as long as the longer input; the shorter one is treated as
    having zero coefficients beyond its end.
    """
    width = max(len(P), len(Q))
    total = []
    for degree in range(width):
        total.append(get_coefficient(P, degree) + get_coefficient(Q, degree))
    return total
def scale_polynomial(P, a):
    """Return a new polynomial with each coefficient of *P* multiplied by *a*."""
    scaled = []
    for coefficient in P:
        scaled.append(a * coefficient)
    return scaled
def multiply_polynomials(P, Q):
    """Return the product of polynomials *P* and *Q*.

    Both are coefficient lists in ascending order of degree. A product of
    polynomials with len(P) and len(Q) coefficients has exactly
    len(P) + len(Q) - 1 coefficients, so that is what is allocated. One more
    slot would leave a spurious trailing zero on every product, and those
    zeros pile up under repeated multiplication (e.g. in
    lagrange_basis_element). If either input has no coefficients, the
    product is the empty list.
    """
    if not P or not Q:
        return []
    R = [0] * (len(P) + len(Q) - 1)
    for i, c in enumerate(P):
        for j, d in enumerate(Q):
            # the X**i term times the X**j term lands on X**(i+j)
            R[i + j] += c * d
    return R
def lagrange_basis_denominator(xs, j):
    """Return the product of (xs[j] - xs[i]) over every index i != j.

    This is the normalising constant of the j-th Lagrange basis polynomial.
    The result is 0 when xs[j] is repeated at another index.
    """
    product = 1
    for i, x in enumerate(xs):
        if i != j:
            product *= xs[j] - x
    return product
def lagrange_basis_element(xs, j):
    """Return the coefficients of the j-th Lagrange basis polynomial for *xs*.

    The result is the polynomial that is 1 at xs[j] and 0 at every other
    node, as a coefficient list in ascending order of degree.

    When the nodes are integers (or Fractions), the coefficients are exact
    fractions.Fraction values. Scaling by a float reciprocal of the
    denominator loses precision once the node products exceed 2**53, and
    interpolate would then round to wrong integers for larger chunk sizes.
    Fractions compare equal to the corresponding ints and floats, and
    round() turns them into ints. Float nodes still produce float
    coefficients.

    Raises ZeroDivisionError if xs[j] also occurs at another index.
    """
    from fractions import Fraction

    numerator = [1]
    for i, x in enumerate(xs):
        if i == j:
            continue
        # multiply by the linear factor (X - x)
        numerator = multiply_polynomials(numerator, [-x, 1])
    scaling_factor = Fraction(1) / lagrange_basis_denominator(xs, j)
    return scale_polynomial(numerator, scaling_factor)
def interpolate(points):
    """Return integer coefficients of the polynomial passing through *points*.

    *points* is a sequence of (x, y) pairs with distinct x values. The
    Lagrange form sum_j y_j * L_j(X) is expanded into a coefficient list in
    ascending order of degree; it may carry trailing zero coefficients. Each
    coefficient is rounded to the nearest integer, because the polynomials
    built by erasure_code have integer coefficients.

    Raises ValueError if *points* is empty.
    """
    xs, ys = zip(*points)
    polynomial = [0]
    for j, y in enumerate(ys):
        weighted_basis = scale_polynomial(lagrange_basis_element(xs, j), y)
        polynomial = add_polynomials(polynomial, weighted_basis)
    return list(map(round, polynomial))
def erasure_code(message, chunk_size, encoded_chunk_size):
    """Erasure-code *message* into shares.

    The message is space-padded and split into chunks of *chunk_size*
    characters. Each chunk's character codes become the coefficients of a
    polynomial, which is evaluated at x = 0 .. encoded_chunk_size-1.

    Returns one list per chunk, each holding encoded_chunk_size (x, y)
    tuples. Any chunk_size of a chunk's shares are enough to recover it.

    Raises ValueError if encoded_chunk_size < chunk_size. Fewer shares than
    coefficients can never be decoded.
    """
    if encoded_chunk_size < chunk_size:
        raise ValueError(
            "encoded_chunk_size ({}) must be at least chunk_size ({})".format(
                encoded_chunk_size, chunk_size))
    chunks = chunkify(pad(message, chunk_size), chunk_size)
    # encode evaluates at x = 0, 1, ...; enumerate pairs each value with its x
    return [list(enumerate(encode(convert(chunk), encoded_chunk_size)))
            for chunk in chunks]
def erasure_decode(encoded_chunks, chunk_size, encoded_chunk_size):
    """Recover the message from the shares produced by erasure_code.

    *encoded_chunks* has one entry per chunk. Each entry is a sequence of
    (x, y) shares with distinct x values: tuples from erasure_code, or
    [x, y] lists after a JSON round trip through disperse/retrieve. The
    first chunk_size shares of each chunk are interpolated back into its
    coefficients. *encoded_chunk_size* is accepted for symmetry with
    erasure_code and is not used.

    Raises ValueError if any chunk has fewer than chunk_size shares. With
    too few points, interpolation silently yields the wrong text.
    """
    converted_chunks = []
    for index, shares in enumerate(encoded_chunks):
        if len(shares) < chunk_size:
            raise ValueError(
                "chunk {} has only {} shares; at least {} are needed".format(
                    index, len(shares), chunk_size))
        # chunk_size distinct points determine a polynomial of degree < chunk_size
        coefficients = interpolate(shares[:chunk_size])[:chunk_size]
        converted_chunks.append(coefficients)
    return unpad(unchunkify(deconvert(chunk) for chunk in converted_chunks))
import json


def disperse(encoded_chunks):
    """Write each share index to its own JSON file in the working directory.

    File 'node<i>' receives the i-th share of every chunk, for i in
    range(len(encoded_chunks[0])). Existing files of the same name are
    overwritten.
    """
    for node in range(len(encoded_chunks[0])):
        with open('node' + str(node), 'w') as partition:
            shares = [chunk[node] for chunk in encoded_chunks]
            partition.write(json.dumps(shares))
def retrieve(*nodes):
    """Read the JSON share files at paths *nodes* and regroup them by chunk.

    Each file holds one share per chunk (as written by disperse). The result
    has one list per chunk, containing that chunk's shares in the order the
    nodes were given. Raises IndexError if no nodes are given.
    """
    responses = []
    for node in nodes:
        with open(node) as partition:
            responses.append(json.load(partition))
    chunk_count = len(responses[0])
    regrouped = []
    for chunk_index in range(chunk_count):
        regrouped.append([response[chunk_index] for response in responses])
    return regrouped
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest
from random import choice, randrange, sample

from reed_solomon import *

# Exclusive upper bounds for the ranges driving the randomized tests below.
MAX_TESTED_MESSAGE_LENGTH = 20
MAX_TESTED_CHUNK_SIZE = 8


def arbitrary_message(length):
    """Return a random string of at most *length* ALPHABET characters.

    Trailing spaces are stripped so the message survives a pad/unpad round
    trip unchanged.
    """
    letters = [choice(ALPHABET) for _ in range(length)]
    return ''.join(letters).rstrip()
class TestPadChunkConvert(unittest.TestCase):
    """Round-trip tests for the text-preparation helpers."""

    def test_pad_invertibility(self):
        for length in range(MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in range(1, MAX_TESTED_CHUNK_SIZE):
                with self.subTest(length=length, chunk_size=chunk_size):
                    message = arbitrary_message(length)
                    padded = pad(message, chunk_size)
                    # the padded text must split into whole chunks
                    self.assertEqual(0, len(padded) % chunk_size)
                    self.assertTrue(padded.startswith(message))
                    self.assertEqual(message, unpad(padded))

    def test_chunk_invertibility(self):
        for length in range(MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in range(1, MAX_TESTED_CHUNK_SIZE):
                with self.subTest(length=length, chunk_size=chunk_size):
                    message = arbitrary_message(length)
                    chunks = chunkify(message, chunk_size)
                    self.assertTrue(all(len(chunk) <= chunk_size
                                        for chunk in chunks))
                    self.assertEqual(message, unchunkify(chunks))

    def test_convert_invertibility(self):
        # every length, including the empty message
        for length in range(MAX_TESTED_MESSAGE_LENGTH):
            with self.subTest(length=length):
                message = arbitrary_message(length)
                self.assertEqual(message, deconvert(convert(message)))
def arbitrary_polynomial(degree):
    """Return degree + 1 random coefficients, each a valid ALPHABET code.

    The leading coefficient may come out as 0, so the true degree can be
    lower than requested.
    """
    coefficients = []
    for _ in range(degree + 1):
        coefficients.append(randrange(len(ALPHABET)))
    return coefficients
def strip_trailing_zeros(polynomial):
    """Return *polynomial* as a list with trailing zero coefficients removed.

    Any falsy coefficient (0, 0.0, -0.0, Fraction(0)) counts as zero;
    interior zeros are kept. A single backward scan keeps this linear,
    instead of re-checking the whole tail for every index.
    """
    end = len(polynomial)
    while end and not polynomial[end - 1]:
        end -= 1
    return list(polynomial[:end])
class TestAlgebra(unittest.TestCase):
    """Known-value and randomized tests for the polynomial helpers."""

    def test_evaluate_known_polynomial(self):
        # 1 + 2x + 3x^2 at x = 2
        self.assertEqual(17, evaluate_polynomial([1, 2, 3], 2))

    def test_add_known_polynomials(self):
        addition = add_polynomials([1, 0, 2], [1, 3, 2, 5])
        self.assertSequenceEqual([2, 3, 4, 5], addition)

    def test_scale_known_polynomial(self):
        self.assertSequenceEqual([5, 5, 5],
                                 scale_polynomial([1, 1, 1], 5))

    def test_multiply_known_polynomials(self):
        product = multiply_polynomials([0, 1, 3, 4], [1, 4])
        # compare without trailing zeros so the test does not depend on padding
        self.assertSequenceEqual([0, 1, 7, 16, 16],
                                 strip_trailing_zeros(product))

    def test_known_lagrange_basis_denominator(self):
        xs = [0, 1, 2]
        expected_denominators = [2, -1, 2]
        for j, expected in enumerate(expected_denominators):
            with self.subTest(j=j):
                self.assertEqual(expected, lagrange_basis_denominator(xs, j))

    def test_known_lagrange_basis_element(self):
        xs = [0, 1, 2]
        expected_elements = [[1, -3/2, 1/2], [0, 2, -1], [0, -1/2, 1/2]]
        for j, expected in enumerate(expected_elements):
            with self.subTest(j=j):
                result = strip_trailing_zeros(lagrange_basis_element(xs, j))
                self.assertSequenceEqual(expected, result)

    def test_known_lagrange_interpolation(self):
        # 1 + 2x + 3x^2 passes through (0, 1), (1, 6) and (2, 17)
        points = [(0, 1), (1, 6), (2, 17)]
        interpolated = strip_trailing_zeros(interpolate(points))
        self.assertSequenceEqual([1, 2, 3], interpolated)

    def test_lagrange_interpolation(self):
        for m in range(2, MAX_TESTED_CHUNK_SIZE):
            for n in range(m + 1, 2 * MAX_TESTED_CHUNK_SIZE):
                with self.subTest(degree=m, point_count=n):
                    polynomial = strip_trailing_zeros(
                        arbitrary_polynomial(m)
                    )
                    points = [(x, evaluate_polynomial(polynomial, x))
                              for x in range(n)]
                    # any m + 1 of the points pin down a degree-<=m polynomial
                    interpolated = strip_trailing_zeros(
                        interpolate(sample(points, m + 1))
                    )
                    self.assertSequenceEqual(polynomial, interpolated)
class TestEncoding(unittest.TestCase):
    """End-to-end erasure coding round trips."""

    def test_erasure_coding(self):
        for length in range(10, MAX_TESTED_MESSAGE_LENGTH):
            for chunk_size in range(1, MAX_TESTED_CHUNK_SIZE):
                for encoded_chunk_size in range(chunk_size,
                                                chunk_size + 10):
                    with self.subTest(length=length, chunk_size=chunk_size,
                                      encoded_chunk_size=encoded_chunk_size):
                        message = arbitrary_message(length)
                        retrieved = erasure_decode(
                            erasure_code(message, chunk_size,
                                         encoded_chunk_size),
                            chunk_size,
                            encoded_chunk_size
                        )
                        self.assertEqual(message, retrieved)

    def test_erasure_decoding_from_any_subset_of_shares(self):
        for chunk_size in range(1, MAX_TESTED_CHUNK_SIZE):
            for encoded_chunk_size in range(chunk_size, chunk_size + 10):
                with self.subTest(chunk_size=chunk_size,
                                  encoded_chunk_size=encoded_chunk_size):
                    message = arbitrary_message(MAX_TESTED_MESSAGE_LENGTH)
                    encoded = erasure_code(message, chunk_size,
                                           encoded_chunk_size)
                    # simulate losing all but chunk_size shares of each chunk
                    surviving = [sample(shares, chunk_size)
                                 for shares in encoded]
                    retrieved = erasure_decode(surviving, chunk_size,
                                               encoded_chunk_size)
                    self.assertEqual(message, retrieved)


if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.