Created
March 6, 2022 23:41
-
-
Save GiggleLiu/62c4d4a0c54855fdf4a0456ad82fd6f5 to your computer and use it in GitHub Desktop.
Compute matrix multiplication `C = A * B` using the Strassen algorithm.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra: mul!
"""
    strassen!(C, A, B, s0=8)

Compute the matrix product `C = A * B` in place using Strassen's algorithm and return `C`.

With `C` of size `m×n`, `A` of size `m×p` and `B` of size `p×n`, `s0` is the cutoff:
as soon as any of `m`, `n`, `p` is `≤ s0` the recursion stops and
`LinearAlgebra.mul!` is called instead. Deep recursion can be slow; if it is, use a
larger `s0`.

Odd and rectangular shapes are handled by applying Strassen to the leading `s×s`
blocks, where `s` is `min(m, n, p)` rounded down to an even number. The leftover rows,
columns and inner-dimension strip are then computed with ordinary matrix products.

# Throws
- `DimensionMismatch` if the sizes of `C`, `A` and `B` are incompatible.
- `ArgumentError` if `s0 < 0` (the recursion would never reach its base case), or if
  any argument does not use one-based indexing.

!!! note
    This function is meant for generic element types for which multiplication is
    expensive. It gives no speed-up for regular element types such as `Float64`.
"""
function strassen!(C::AbstractMatrix{T}, A, B, s0=8) where T
    # All block ranges below are written as `1:k`, so offset axes are rejected up front.
    Base.require_one_based_indexing(C, A, B)
    m, n = size(C)
    p = size(A, 2)
    if size(A, 1) != m || size(B, 1) != p || size(B, 2) != n
        throw(DimensionMismatch("incompatible sizes: C $(size(C)), A $(size(A)), B $(size(B))"))
    end
    # With s0 < 0 an empty 0×0 block never satisfies the stop condition below.
    s0 < 0 && throw(ArgumentError("cutoff size s0 must be non-negative, got $s0"))
    if p <= s0 || m <= s0 || n <= s0
        return mul!(C, A, B)
    end

    # Largest even size shared by all three dimensions, and what is left over.
    s = min(m, n, p)
    s -= s & 1
    hs = s >> 1
    rm, rn, rp = m - s, n - s, p - s

    if rm != 0 || rn != 0 || rp != 0
        # Odd/rectangular shape: Strassen on the even s×s core, then patch in the rest.
        strassen!(view(C, 1:s, 1:s), view(A, 1:s, 1:s), view(B, 1:s, 1:s), s0)
        if rp != 0
            # Contribution of the inner-dimension strip (columns s+1:p of A) to the core.
            view(C, 1:s, 1:s) .+= view(A, 1:s, s+1:p) * view(B, s+1:p, 1:s)
        end
        if rm != 0
            # Rows below the core, restricted to the core's columns.
            mul!(view(C, s+1:m, 1:s), view(A, s+1:m, :), view(B, :, 1:s))
        end
        if rn != 0
            # Columns right of the core, over the full height of C.
            mul!(view(C, :, s+1:n), A, view(B, :, s+1:n))
        end
        return C
    end

    # Square, even case (m == n == p == s): one Strassen step on hs×hs quadrants.
    A11, A21, A12, A22 = devide22(A)
    B11, B21, B12, B22 = devide22(B)
    M = similar(C, hs, hs, 7)  # M[:, :, k] holds the k-th Strassen product
    M1, M2, M3, M4, M5, M6, M7 = ntuple(k -> view(M, :, :, k), 7)
    strassen!(M1, A11 .+ A22, B11 .+ B22, s0)
    strassen!(M2, A21 .+ A22, B11, s0)
    strassen!(M3, A11, B12 .- B22, s0)
    strassen!(M4, A22, B21 .- B11, s0)
    strassen!(M5, A11 .+ A12, B22, s0)
    strassen!(M6, A21 .- A11, B11 .+ B12, s0)
    strassen!(M7, A12 .- A22, B21 .+ B22, s0)
    # Reassemble the four output quadrants from the seven products.
    view(C, 1:hs, 1:hs) .= M1 .+ M4 .- M5 .+ M7
    view(C, 1:hs, hs+1:s) .= M3 .+ M5
    view(C, hs+1:s, 1:hs) .= M2 .+ M4
    view(C, hs+1:s, hs+1:s) .= M1 .- M2 .+ M3 .+ M6
    return C
end
"""
    devide22(A)

Split the matrix `A` into four quadrant views and return them as
`(A11, A21, A12, A22)`, i.e. in column-major block order.

Rows are split after `size(A, 1) ÷ 2` and columns after `size(A, 2) ÷ 2`. For odd
sizes the second half gets the extra row/column. The blocks are views, so no data is
copied. Assumes one-based indexing. (The name is presumably a misspelling of
"divide"; it is kept for compatibility.)
"""
function devide22(A)
    m, n = size(A, 1), size(A, 2)
    hm, hn = m >> 1, n >> 1
    top, bottom = 1:hm, hm+1:m
    left, right = 1:hn, hn+1:n
    return (view(A, top, left), view(A, bottom, left),
            view(A, top, right), view(A, bottom, right))
end
using Test
using Random: MersenneTwister

@testset "strassen" begin
    rng = MersenneTwister(42)  # fixed seed so any failure is reproducible
    # Every parity combination of (m, p, n) exercises each odd-shape branch; the
    # 128³ case is a power of two that recurses deeply through the square branch.
    parities = vec(collect(Iterators.product((13, 14), (15, 16), (9, 10))))
    for (m, p, n) in vcat(parities, [(128, 128, 128)])
        A, B = randn(rng, m, p), randn(rng, p, n)
        for s0 in (1, 8)  # deepest possible recursion and the default cutoff
            C = similar(A, m, n)
            @test isapprox(strassen!(C, A, B, s0), A * B; atol=1e-6)
        end
    end
    # Integer inputs: every step is exact, so the result must match exactly.
    Ai = [i - 2j for i in 1:14, j in 1:16]
    Bi = [3i + j for i in 1:16, j in 1:9]
    @test strassen!(zeros(Int, 14, 9), Ai, Bi, 1) == Ai * Bi
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment