Skip to content

Instantly share code, notes, and snippets.

@niklasf
Last active December 17, 2015 08:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save niklasf/5582042 to your computer and use it in GitHub Desktop.
Save niklasf/5582042 to your computer and use it in GitHub Desktop.
An O(n^2.81) algorithm (Strassen's; the exponent is log2(7) ≈ 2.807) for multiplying two n×n matrices.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import unittest
def multiply(A, B):
    """Return the matrix product A*B using Strassen's algorithm.

    A and B are square matrices of the same size n, given as lists of row
    lists (neither the shape nor the sizes are checked).  Odd sizes are
    handled by padding both operands with a zero row and column, multiplying,
    and stripping the padding off the result.  Any n >= 0 is accepted; an
    empty matrix yields an empty product.
    """
    n = len(A)
    if n == 0:
        # Without this guard an empty matrix splits into empty blocks and
        # multiply() recurses on them forever.
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    if n % 2 == 1:
        # Pad to an even size so the matrices split into equal quadrants.
        return shrink(multiply(expand(A), expand(B)))

    A00, A01, A10, A11 = split_into_blocks(A)
    B00, B01, B10, B11 = split_into_blocks(B)

    # Strassen's seven half-size products (instead of the naive eight).
    M1 = multiply(add(A00, A11), add(B00, B11))
    M2 = multiply(add(A10, A11), B00)
    M3 = multiply(A00, add(B01, B11, -1))
    M4 = multiply(A11, add(B10, B00, -1))
    M5 = multiply(add(A00, A01), B11)
    M6 = multiply(add(A10, A00, -1), add(B00, B01))
    M7 = multiply(add(A01, A11, -1), add(B10, B11))

    C00 = add(M1, add(M4, add(M7, M5, -1)))  # M1 + M4 - M5 + M7
    C01 = add(M3, M5)                        # M3 + M5
    C10 = add(M2, M4)                        # M2 + M4
    C11 = add(M1, add(M3, add(M6, M2, -1)))  # M1 - M2 + M3 + M6
    return join_blocks(C00, C01, C10, C11)
def add(A, B, l=1):
    """Return the element-wise combination ``A + l * B`` as a new matrix.

    The result has the shape of ``B``: one row per row of ``B`` and, in each
    row, one entry per entry of that row.  ``A`` must be at least that large.
    Passing ``l=-1`` computes ``A - B``.
    """
    # Walk each row's own length rather than reusing the row count for the
    # columns, so rectangular matrices are combined correctly as well.
    return [[A[i][j] + l * b for j, b in enumerate(row)]
            for i, row in enumerate(B)]
def split_into_blocks(A):
    """Split an n-by-n matrix (n even) into its four quadrants.

    Returns the tuple ``(A00, A01, A10, A11)``: top-left, top-right,
    bottom-left and bottom-right.  Every quadrant is built from freshly
    sliced row lists, so mutating one does not affect ``A``.
    """
    # Floor division: under Python 3, ``n / 2`` is a float, which makes both
    # range() and list slicing fail.  ``//`` yields the same int on 2 and 3.
    half = len(A) // 2
    top, bottom = A[:half], A[half:]
    A00 = [row[:half] for row in top]
    A01 = [row[half:] for row in top]
    A10 = [row[:half] for row in bottom]
    A11 = [row[half:] for row in bottom]
    return (A00, A01, A10, A11)
def join_blocks(A00, A01, A10, A11):
    """Assemble four quadrants back into one matrix (inverse of splitting).

    The top half pairs each row of ``A00`` with the same-index row of ``A01``;
    the bottom half does the same for ``A10`` and ``A11``.  Rows are
    concatenated into new lists.
    """
    upper = [left + A01[k] for k, left in enumerate(A00)]
    lower = [left + A11[k] for k, left in enumerate(A10)]
    return upper + lower
def expand(A):
    """Return a copy of ``A`` padded with one trailing zero column and zero row.

    The new bottom row is one entry longer than the last row of ``A``.  An
    empty matrix becomes ``[[0]]``.  ``A`` itself is not modified.
    """
    # Compute the padded width explicitly instead of reading the loop
    # variable after the loop, which raised NameError for an empty ``A``.
    width = len(A[-1]) + 1 if A else 1
    result = [row + [0] for row in A]
    result.append([0] * width)
    return result
def shrink(A):
    """Drop the last row and the last column of ``A``, returning a new matrix.

    This undoes the zero padding added by ``expand``.
    """
    kept_rows = A[:-1]
    return [row[:-1] for row in kept_rows]
class MatrixTestCase(unittest.TestCase):
    """Tests for the matrix helpers and the Strassen ``multiply``."""

    def test_add(self):
        A = [[1, 2],
             [3, 4]]
        B = [[4, 3],
             [2, 1]]
        Five = [[5, 5],
                [5, 5]]
        self.assertEqual(add(B, A), Five)
        self.assertEqual(add(A, B), add(B, A))
        self.assertEqual(add(A, B), Five)
        # A scale factor of -1 turns add() into subtraction.
        self.assertEqual(add(Five, B, -1), A)

    def test_joining_and_splitting_blocks(self):
        A00 = [[1, 2],
               [3, 4]]
        A01 = [[-1, -2],
               [-3, -4]]
        A10 = [[5, 6],
               [7, 8]]
        A11 = [[-5, -6],
               [-7, -8]]
        Joined = [[1, 2, -1, -2],
                  [3, 4, -3, -4],
                  [5, 6, -5, -6],
                  [7, 8, -7, -8]]
        self.assertEqual(join_blocks(A00, A01, A10, A11), Joined)
        B00, B01, B10, B11 = split_into_blocks(Joined)
        self.assertEqual(B00, A00)
        self.assertEqual(B01, A01)
        self.assertEqual(B10, A10)
        self.assertEqual(B11, A11)

    def test_expanding_and_shrinking(self):
        A = [[1, 2, 3],
             [3, 2, 1],
             [2, 1, 3]]
        A_expanded = [[1, 2, 3, 0],
                      [3, 2, 1, 0],
                      [2, 1, 3, 0],
                      [0, 0, 0, 0]]
        self.assertEqual(expand(A), A_expanded)
        # shrink() was previously untested despite the test's name; it must
        # exactly undo expand().
        self.assertEqual(shrink(A_expanded), A)
        self.assertEqual(shrink(expand(A)), A)

    def test_2x2(self):
        id_2 = [[1, 0],
                [0, 1]]
        A = [[1, 2],
             [3, 4]]
        B = [[3, 2],
             [6, 7]]
        AxB = [[15, 16],
               [33, 34]]
        self.assertEqual(multiply(id_2, A), A)
        self.assertEqual(multiply(B, id_2), B)
        self.assertEqual(multiply(A, B), AxB)

    def test_3x3(self):
        id_3 = [[1, 0, 0],
                [0, 1, 0],
                [0, 0, 1]]
        A = [[1, 2, 3],
             [4, 5, 6],
             [7, 8, 9]]
        AxA = [[ 30,  36,  42],
               [ 66,  81,  96],
               [102, 126, 150]]
        self.assertEqual(multiply(id_3, A), A)
        self.assertEqual(multiply(A, id_3), A)
        self.assertEqual(multiply(A, A), AxA)

    def test_4x4(self):
        id_4 = [[1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]]
        A = [[ 1,  2,  3,  4],
             [ 5,  6,  7,  8],
             [ 9, 10, 11, 12],
             [13, 14, 15, 16]]
        AxA = [[ 90, 100, 110, 120],
               [202, 228, 254, 280],
               [314, 356, 398, 440],
               [426, 484, 542, 600]]
        self.assertEqual(multiply(id_4, id_4), id_4)
        self.assertEqual(multiply(id_4, A), A)
        # Previously A was only multiplied by the identity from the left.
        self.assertEqual(multiply(A, id_4), A)
        self.assertEqual(multiply(A, A), AxA)

    def test_against_naive_product(self):
        # n = 5 pads at two recursion depths (5 -> 6 -> 3 -> 4 -> 2 -> 1);
        # n = 6 pads only after the first split.
        for n in (5, 6):
            A = [[(3 * i + j) % 7 - 3 for j in range(n)] for i in range(n)]
            B = [[(i * j + 1) % 5 - 2 for j in range(n)] for i in range(n)]
            expected = [[sum(A[i][k] * B[k][j] for k in range(n))
                         for j in range(n)]
                        for i in range(n)]
            self.assertEqual(multiply(A, B), expected)
if __name__ == '__main__':
    # Executing this file directly runs the unittest suite above.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment