Skip to content

Instantly share code, notes, and snippets.

@GiggleLiu
Created March 6, 2022 23:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GiggleLiu/62c4d4a0c54855fdf4a0456ad82fd6f5 to your computer and use it in GitHub Desktop.
Save GiggleLiu/62c4d4a0c54855fdf4a0456ad82fd6f5 to your computer and use it in GitHub Desktop.
Compute the matrix product `C = A * B` using the Strassen algorithm.
using LinearAlgebra: mul!
"""
    strassen!(C, A, B, s0=8)

Compute the matrix product `C = A * B` in place with Strassen's algorithm and return `C`.

`s0` is the cutoff size: as soon as the smallest of the dimensions `m`, `p`, `n` of a
(sub)problem is at most `s0`, the recursion stops and `mul!` is used instead. Values of
`s0` below 1 are treated as 1. Recursing too deep can be slow; raise `s0` in that case.

Odd or rectangular shapes are handled by running Strassen's scheme on the largest even
square core `s × s` (`s ≤ min(m, p, n)`) and computing the remaining border strips with
ordinary matrix products.

Throws `DimensionMismatch` unless `size(A, 1) == m`, `size(B, 1) == size(A, 2)` and
`size(B, 2) == n`, where `(m, n) = size(C)`.

!!! note
    This function is meant for generic element types whose multiplication is expensive,
    where replacing 8 block products by 7 per recursion level can pay off.
    It gives no speedup for regular (BLAS-backed) element types!
"""
function strassen!(C::AbstractMatrix{T}, A, B, s0=8) where T
    m, n = size(C)
    p = size(A, 2)
    # Validate shapes with a real exception: `@assert` may be disabled.
    if size(A, 1) != m || size(B, 1) != p || size(B, 2) != n
        throw(DimensionMismatch("cannot compute C = A * B with size(C) = $(size(C)), " *
                                "size(A) = $(size(A)), size(B) = $(size(B))"))
    end
    # Stop condition. Clamping the cutoff to at least 1 guarantees that the even core
    # size `s` below is at least 2, so every recursive call works on strictly smaller,
    # non-empty blocks (a cutoff < 1 would otherwise recurse forever on 0×0 blocks).
    if min(m, n, p) <= max(s0, 1)
        return mul!(C, A, B)
    end

    # Largest even square core `s × s`; `hs` is its half size.
    s = min(m, n, p)
    s -= s & 1
    hs = s >> 1

    # Odd or rectangular shapes: Strassen on the core, plain products on the border.
    if m != s || n != s || p != s
        Ccore = view(C, 1:s, 1:s)
        strassen!(Ccore, view(A, 1:s, 1:s), view(B, 1:s, 1:s), s0)
        if p > s
            # contribution of the extra inner dimension to the core block
            Ccore .+= view(A, 1:s, s+1:p) * view(B, s+1:p, 1:s)
        end
        # The border strips use the full inner dimension `p` and are overwritten directly.
        if m > s
            mul!(view(C, s+1:m, 1:s), view(A, s+1:m, :), view(B, :, 1:s))
        end
        if n > s
            mul!(view(C, 1:s, s+1:n), view(A, 1:s, :), view(B, :, s+1:n))
        end
        if m > s && n > s
            mul!(view(C, s+1:m, s+1:n), view(A, s+1:m, :), view(B, :, s+1:n))
        end
        return C
    end

    # Main block: here m == n == p == s, and s is even.
    A11, A21, A12, A22 = devide22(A)
    B11, B21, B12, B22 = devide22(B)
    # The seven Strassen products (each hs × hs) live in slices of one buffer.
    M = similar(C, hs, hs, 7)
    M1, M2, M3, M4, M5, M6, M7 = ntuple(k -> view(M, :, :, k), 7)
    strassen!(M1, A11 .+ A22, B11 .+ B22, s0)
    strassen!(M2, A21 .+ A22, B11, s0)
    strassen!(M3, A11, B12 .- B22, s0)
    strassen!(M4, A22, B21 .- B11, s0)
    strassen!(M5, A11 .+ A12, B22, s0)
    strassen!(M6, A21 .- A11, B11 .+ B12, s0)
    strassen!(M7, A12 .- A22, B21 .+ B22, s0)
    # Combine the products into the four quadrants of C.
    C11, C21, C12, C22 = devide22(C)
    C11 .= M1 .+ M4 .- M5 .+ M7
    C12 .= M3 .+ M5
    C21 .= M2 .+ M4
    C22 .= M1 .- M2 .+ M3 .+ M6
    return C
end
"""
    devide22(A)

Split the matrix `A` into its four quadrant views, returned in column-major order as
`(A11, A21, A12, A22)` (top-left, bottom-left, top-right, bottom-right).

Rows are split at `size(A, 1) >> 1` and columns at `size(A, 2) >> 1`; for odd sizes the
extra row/column goes to the bottom/right quadrants. The returned views share memory
with `A`.
"""
function devide22(A)
    nrow, ncol = size(A, 1), size(A, 2)
    # Split each dimension at its own half; the row count alone is wrong for non-square A.
    hr, hc = nrow >> 1, ncol >> 1
    return (view(A, 1:hr, 1:hc), view(A, hr+1:nrow, 1:hc),
            view(A, 1:hr, hc+1:ncol), view(A, hr+1:nrow, hc+1:ncol))
end
using Test
@testset "strassen" begin
    # All eight parity combinations of (m, p, n) around the core size, plus a power of two.
    shapes = [(13, 15, 9), (13, 15, 10), (13, 16, 9), (13, 16, 10),
              (14, 15, 9), (14, 15, 10), (14, 16, 9), (14, 16, 10),
              (128, 128, 128)]
    # Floating-point inputs: compare approximately, for the deepest and the default cutoff.
    for (m, p, n) in shapes, s0 in (1, 8)
        A, B = randn(m, p), randn(p, n)
        C = similar(A, m, n)
        @test strassen!(C, A, B, s0) === C   # the result is written into C in place
        @test isapprox(C, A * B, atol=1e-6)
    end
    # Integer inputs: arithmetic is exact, so the result must match A * B exactly.
    for (m, p, n) in shapes[1:end-1]
        A, B = rand(-5:5, m, p), rand(-5:5, p, n)
        @test strassen!(similar(A, m, n), A, B, 1) == A * B
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment