Last active: December 21, 2022 18:32
Save DroneBetter/b184fbff57077de3ec720d0a0add9ff2 to your computer and use it in GitHub Desktop.
Strassen's algorithm in one line
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce
from math import isqrt
from operator import add, mul, neg
from random import random
lap=(lambda f,*a: list(map(f,*a))) #an eager map, just like Python 2 used to make (Python 3's map objects are lazy iterators)
dims=16 #side length of the demo matrices; strassen accepts any power of two (1, 2, 4, 8, ...)
mint=(lambda m: (lambda d: print('\n('+"\n ".join(','.join(str(m[d*j+i]) for i in range(d))+',' for j in range(d))+')'))(isqrt(len(m)))) #matrix print (frankly delicious); the side length is taken from the matrix itself
#strassen(a,b) multiplies two square matrices stored flat in row-major order (tuples or lists) and returns a tuple.
#The side length must be a power of two; 1*1 (and empty) matrices are multiplied directly without subdividing.
#Entries may be any numbers supporting +, unary - and * (ints, floats, Fractions, ...), not only exact ints.
strassen=(lambda a,b: tuple(( lambda f: f(a,b) if len(a)==4 else map(mul,a,b) if len(a)<2 else (lambda m: (lambda m,d: ((m[y*2+x][d*j+i] for y in range(2) for j in range(d) for x in range(2) for i in range(d))))(m,isqrt(len(m[0]))))(f(*map((lambda m: (lambda m,d: tuple((tuple((m[i*d+j] for i in range(d//2*y,d//(2-y)) for j in range(d//2*x,d//(2-x)))) for y in range(2) for x in range(2))))(m,isqrt(len(m)))),(a,b)))))((lambda a,b: (lambda s,n,m: (lambda a,b,c,d,e,f,g: lap(s,((a,d,n(e),g),(c,e),(b,d),(a,n(b),c,f))))(*map(m,*map(lambda x: map(s,x),(((a[0],a[3]),(a[2],a[3]),(a[0],),(a[3],),(a[0],a[1]),(a[2],n(a[0])),(a[1],n(a[3]))),((b[0],b[3]),(b[0],),(b[1],n(b[3])),(b[2],n(b[0])),(b[3],),(b[0],b[1]),(b[2],b[3])))))))(*((sum,neg,mul) if not isinstance(a[0],(tuple,list)) else ((lambda x: reduce(lambda a,b: lap(add,a,b),x)),(lambda x: map(neg,x)),strassen)))))))
(a,b)=(tuple(int(random()*8) for _ in range(dims**2)) for _ in range(2)) #two random dims*dims matrices with entries from 0 to 7
lap(mint,(a,b,strassen(a,b))) #prints a, b and their product
#It consists of two parts that ping-pong back and forth as the recursion deepens; here is a dissection.
#(Running this rebinds strassen to this laid-out version.)
strassen=(lambda a,b: tuple((
    #part 1 generates the subdivisions and recombines them; matrices are square and stored flat in
    #row-major order (tuples or lists), and their side length must be a power of two
    lambda f: f(a,b) #a 2*2 matrix passes directly into part 2
    if len(a)==4 else
    map(mul,a,b) #a 1*1 (or empty) matrix needs no subdividing: its product is elementwise
    if len(a)<2 else
    (lambda m: (lambda m,d: ((m[y*2+x][d*j+i] #recombines the four quarter products into the whole, row by row
                              for y in range(2) for j in range(d)
                              for x in range(2) for i in range(d))))
        (m,isqrt(len(m[0])))) #d is the side length of a quarter
    (f(*map((lambda m: (lambda m,d: tuple((tuple((m[i*d+j] #cuts a d*d matrix into its quarters, ordered 11, 12, 21, 22
                                                  for i in range(d//2*y,d//(2-y))
                                                  for j in range(d//2*x,d//(2-x))))
                                           for y in range(2) for x in range(2))))
                (m,isqrt(len(m)))), #d is the side length of the whole matrix
            (a,b)))))
    ((lambda a,b: #part 2 enacts Strassen's algorithm itself on a 2*2 matrix whose entries are numbers or quarter matrices
        #s, n, m are sum, negate and multiply for whichever kind of entry it is working on
        (lambda s,n,m:
            (lambda a,b,c,d,e,f,g: #the seven products M1..M7
                lap(s,((a,d,n(e),g),(c,e),(b,d),(a,n(b),c,f)))) #second step: C11, C12, C21, C22 as signed sums of the products
            (*map(m,*map(lambda x: map(s,x),( #first step: the seven (wacky) pairs of sums that get multiplied together
                ((a[0],a[3]),(a[2],a[3]),(a[0],),(a[3],),(a[0],a[1]),(a[2],n(a[0])),(a[1],n(a[3]))),
                ((b[0],b[3]),(b[0],),(b[1],n(b[3])),(b[2],n(b[0])),(b[3],),(b[0],b[1]),(b[2],b[3])))))))
        (*((sum,neg,mul) #numeric entries (anything that is not a tuple or list)
           if not isinstance(a[0],(tuple,list)) else #quarter-matrix entries (tuples, or lists produced by earlier sums)
           ((lambda x: reduce(lambda a,b: lap(add,a,b),x)), #matrices are added and negated elementwise
            (lambda x: map(neg,x)),
            strassen))))))) #and multiplied by recursing, so strassen acts as the matrices' __mul__ method
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
It seems I had initially included an older version as the undissected version by accident; it had an `o` function (which I later realised was equivalent to `s`). Sorry — fixed now.