

@JuliaMLTools
Created September 22, 2023 13:29
Gist: JuliaMLTools/ed17c88e6231bcb3b56049c72f2d1438
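Getting inconsistent matrix-multiply results from a Flux Dense layer
Applying the same Dense layer to a matrix a and to a 3-D array whose first slice is a returns outputs that differ in some entries.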
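using Flux
using Random

Random.seed!(12345)

# Apply f to an (n,n) matrix and to an (n,n,2) array whose first slice is that
# same matrix, then report the first entry where the two outputs disagree.
function showdiff(f, n)
    a = rand(Float32, n, n)          # shape (n,n)
    b = rand(Float32, n, n)          # shape (n,n)
    c = cat(a, b; dims=3)            # shape (n,n,2); c[:,:,1] == a
    a_out = f(a)                     # shape (n,n)
    c_out = f(c)[:, :, 1]            # shape (n,n)
    diff_coord = findfirst(a_out .!= c_out)   # first mismatching index, or nothing
    if !isnothing(diff_coord)
        return "$(a_out[diff_coord]) !== $(c_out[diff_coord])"
    end
    return nothing                   # outputs are bitwise identical
end

n = 8
d = Dense(n, n)   # Dense layer with n inputs and n outputs
showdiff(d, n)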
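Bitwise-identical outputs are not guaranteed here. One plausible explanation, which is an assumption and not something the gist establishes, is that Dense handles the (n,n,2) input by reshaping it to an n×2n matrix, so BLAS multiplies matrices of different shapes and may accumulate each dot product in a different order. Float32 addition is not associative, so a different order can change the last bits. The sketch below uses only Base Julia and LinearAlgebra; `dense_like`, `W`, and `bias` are hypothetical stand-ins for the layer, not Flux's API. It compares outputs with a tolerance instead of `==`.

```julia
using LinearAlgebra
using Random

# Floating-point addition is not associative: grouping changes the result.
grouped_left  = (1f0 + 1f8) - 1f8   # 1f0 is rounded away when added to 1f8 first
grouped_right = 1f0 + (1f8 - 1f8)
println("left = $grouped_left, right = $grouped_right")

# Hypothetical stand-in for a Dense layer with identity activation:
# flatten trailing dimensions, multiply, add the bias, restore the shape.
function dense_like(W, bias, x)
    x2d = reshape(x, size(x, 1), :)             # (n, n*k) matrix
    y2d = W * x2d .+ bias                       # bias broadcast over columns
    return reshape(y2d, size(W, 1), size(x)[2:end]...)
end

Random.seed!(12345)
n = 8
W = rand(Float32, n, n)
bias = rand(Float32, n)
a = rand(Float32, n, n)
b = rand(Float32, n, n)
c = cat(a, b; dims=3)                           # c[:, :, 1] == a

a_out = dense_like(W, bias, a)
c_out = dense_like(W, bias, c)[:, :, 1]

# Exact equality may or may not hold depending on the BLAS kernel chosen;
# a relative tolerance scaled to Float32 precision is the robust check.
println("bitwise equal: ", a_out == c_out)
println("max abs diff:  ", maximum(abs.(a_out .- c_out)))
println("approx equal:  ", isapprox(a_out, c_out; rtol=sqrt(eps(Float32))))
```

With this framing, a nonzero but tiny maximum difference would point to summation order rather than a wrong result.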
VERSION                 # Julia version
import Pkg
Pkg.status("Flux")      # installed Flux version (Pkg.installed() is deprecated)
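If the discrepancy does come from the BLAS backend or its thread count (again an assumption about the cause, not confirmed by the gist), a report like this one may also want to record the linear-algebra configuration alongside the Julia and Flux versions. A minimal sketch using only the standard library; `BLAS.get_config` requires Julia 1.7 or later.

```julia
using InteractiveUtils   # provides versioninfo() outside the REPL
using LinearAlgebra

versioninfo()                                  # Julia version, OS, CPU
println("BLAS config:  ", BLAS.get_config())   # loaded BLAS/LAPACK libraries
println("BLAS threads: ", BLAS.get_num_threads())
```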