Skip to content

Instantly share code, notes, and snippets.

@DroneBetter
Last active December 21, 2022 18:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DroneBetter/b184fbff57077de3ec720d0a0add9ff2 to your computer and use it in GitHub Desktop.
Save DroneBetter/b184fbff57077de3ec720d0a0add9ff2 to your computer and use it in GitHub Desktop.
Strassen's algorithm in one line
from functools import reduce
from math import isqrt
from random import random
def lap(f,*a):
    """Eager map: apply f across the given iterables (in lockstep) and return the results as a list.

    Plain map() returns a one-shot iterator in Python 3; this helper
    materialises it so the result can be indexed, measured and reused.
    """
    return [*map(f,*a)]
dims=16 #side length of the demo matrices; any power of two
def mint(m,d=None):
    """Matrix print: write a flat row-major square matrix to stdout, one row per line.

    m -- sequence holding the d*d elements row by row (element (row,col) is m[d*row+col]).
    d -- side length; when omitted it is derived from the matrix itself as
         isqrt(len(m)), so the function no longer depends on the global dims
         and prints matrices of any size correctly.

    The output starts with a blank line, wraps the matrix in parentheses and
    ends every row with a trailing comma. Returns None.
    """
    if d is None:
        d=isqrt(len(m))
    rows=(','.join(str(m[d*j+i]) for i in range(d))+',' for j in range(d))
    print('\n('+"\n ".join(rows)+')')
def _quadrants(m,side):
    """Split a flat row-major side*side matrix into its quarters (top-left, top-right, bottom-left, bottom-right)."""
    half=side//2
    return tuple(
        tuple(m[row*side+col]
              for row in range(qy*half,(qy+1)*half)
              for col in range(qx*half,(qx+1)*half))
        for qy in range(2)
        for qx in range(2))


def _assemble(quads,half):
    """Inverse of _quadrants: interleave four flat half*half blocks back into one flat row-major matrix."""
    return tuple(
        quads[qy*2+qx][row*half+col]
        for qy in range(2)
        for row in range(half)
        for qx in range(2)
        for col in range(half))


def _mat_add(x,y):
    """Elementwise sum of two equally sized flat matrices."""
    return tuple(p+q for p,q in zip(x,y))


def _mat_sub(x,y):
    """Elementwise difference of two equally sized flat matrices."""
    return tuple(p-q for p,q in zip(x,y))


def _strassen_square(a,b,side):
    """Multiply two flat side*side matrices; side must be a power of two (validated by strassen)."""
    if side==1:
        return (a[0]*b[0],)
    half=side//2
    a11,a12,a21,a22=_quadrants(a,side)
    b11,b12,b21,b22=_quadrants(b,side)
    mul=lambda x,y: _strassen_square(x,y,half)
    #first step: Strassen's seven block products (instead of the naive eight)
    m1=mul(_mat_add(a11,a22),_mat_add(b11,b22))
    m2=mul(_mat_add(a21,a22),b11)
    m3=mul(a11,_mat_sub(b12,b22))
    m4=mul(a22,_mat_sub(b21,b11))
    m5=mul(_mat_add(a11,a12),b22)
    m6=mul(_mat_sub(a21,a11),_mat_add(b11,b12))
    m7=mul(_mat_sub(a12,a22),_mat_add(b21,b22))
    #second step: combine the products into the four quarters of the result
    c11=_mat_add(_mat_sub(_mat_add(m1,m4),m5),m7)
    c12=_mat_add(m3,m5)
    c21=_mat_add(m2,m4)
    c22=_mat_add(_mat_add(_mat_sub(m1,m2),m3),m6)
    return _assemble((c11,c12,c21,c22),half)


def strassen(a,b):
    """Multiply two square matrices with Strassen's algorithm.

    a, b -- the matrices as flat row-major sequences (element (row,col) of an
            n*n matrix is m[n*row+col]); both must hold n*n elements where n
            is a power of two (1, 2, 4, ...). Elements may be any type that
            supports +, - and * (ints, floats, Fractions, ...).

    Returns the product as a flat row-major tuple.

    Raises ValueError when the two inputs differ in size or the size is not
    the square of a power of two; such inputs used to recurse until
    RecursionError.
    """
    size=len(a)
    if size!=len(b):
        raise ValueError("strassen: both matrices must have the same number of elements")
    side=isqrt(size)
    #side&(side-1) is zero exactly for powers of two
    if side<1 or side*side!=size or side&(side-1):
        raise ValueError("strassen: matrices must be square with a power-of-two side length")
    return _strassen_square(tuple(a),tuple(b),side)
#demo: build two random dims*dims matrices with entries 0..7, then print both operands followed by their product
a=tuple(int(random()*8) for _ in range(dims**2))
b=tuple(int(random()*8) for _ in range(dims**2))
for matrix in (a,b,strassen(a,b)): #the product is computed before anything is printed
    mint(matrix)
#the lambda form of strassen consists of two parts that call each other alternately as the recursion deepens; here it is dissected
#strassen(a,b) multiplies two square matrices of ints that are stored flat, row by row (element (row,column) of m is m[row*d+column]),
#and returns the product as a flat tuple in the same layout
#NOTE(review): it appears to assume the side length is a power of two and at least 2 (each level halves the side until the 4-element base case) - confirm callers never pass other sizes
strassen=(lambda a,b: tuple(( lambda f: f(a,b) #outer part: generates the quarters and recombines the result; a 2*2 matrix (4 elements) is passed straight to the inner part f
if len(a)==4 else #anything larger is quartered first, so f receives two 4-tuples of sub-matrices instead of two 4-tuples of ints
(lambda m: (lambda m,d: ((m[y*2+x][d*j+i] for y in range(2) for j in range(d) #recombines the four result quarters (m[y*2+x] is the quarter in block-row y, block-column x) into one flat matrix
for x in range(2) for i in range(d))))(m,isqrt(len(m[0])))) #d is the side length of one quarter
#splitting step: each of a and b is cut into its four quarters, ordered top-left, top-right, bottom-left, bottom-right; here d is the side length of the whole matrix
(f(*map((lambda m: (lambda m,d: tuple((tuple((m[i*d+j] for i in range(d//2*y,d//(2-y)) for j in range(d//2*x,d//(2-x))))
for y in range(2) for x in range(2))))(m,isqrt(len(m)))),(a,b)))))
((lambda a,b: #inner part (bound to f above): Strassen's algorithm on a 2*2 arrangement; s,n,m are sum, negate and multiply, chosen at the bottom so the same code serves ints and sub-matrices
(lambda s,n,m: (lambda a,b,c,d,e,f,g: lap(s,((a,d,n(e),g),(c,e), #second step: a..g are the seven products; the four result entries are a+d-e+g, c+e,
(b,d),(a,n(b),c,f))))(*map(m,*map(lambda x: map(s,x),(((a[0],a[3]),(a[2],a[3]),(a[0],),(a[3],),(a[0],a[1]),(a[2],n(a[0])),(a[1],n(a[3]))), #b+d and a-b+c+f; first step: the seven left-hand operands, built from a's entries
((b[0],b[3]),(b[0],),(b[1],n(b[3])),(b[2],n(b[0])),(b[3],),(b[0],b[1]),(b[2],b[3])))))))(*( (sum,int.__neg__,int.__mul__) #the seven right-hand operands, built from b's entries; for int entries s,n,m are the ordinary integer operations
if type(a[0])==int else
((lambda x: reduce(lambda a,b: lap(int.__add__,a,b),x)), #for sub-matrices, s adds all the matrices in x elementwise
(lambda x: map(int.__neg__,x)), #n negates elementwise; it returns a one-shot map object, which is fine because every n(...) result is consumed exactly once by s
strassen))))))) #and m is strassen itself: block products are computed recursively
@DroneBetter
Copy link
Author

It seems I had accidentally included an older version as the undissected version at first, which still had the o function (which I later realised was equivalent to s). Sorry, fixed now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment