
@shivance
Created December 31, 2022 11:19
Flux implementation of the UNet forward pass
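using Flux  # provides σ (sigmoid) and the layer types stored in UNet

# Forward pass of a 4-level UNet. Assumes `u` holds vectors of 4 layers each
# in `encoder`, `pool`, `upconv` and `decoder`, plus `bottleneck` and `final_conv`.
function (u::UNet)(x::AbstractArray)
    # Keep encoder outputs as skip connections. A tuple is used instead of
    # push! on an array, since Zygote cannot differentiate array mutation.
    skips = ()
    out = x
    for i in 1:4
        out = u.encoder[i](out)
        skips = (skips..., out)
        out = u.pool[i](out)                 # downsample
    end
    out = u.bottleneck(out)
    for i in 4:-1:1
        out = u.upconv[5 - i](out)           # upconv[1] is applied at the deepest level
        out = cat(out, skips[i]; dims = 3)   # concatenate along channels (WHCN layout)
        out = u.decoder[i](out)
    end
    # σ is defined on scalars, so it must be broadcast over the output array.
    return σ.(u.final_conv(out))
end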
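The gist does not show the `UNet` struct or how its layers are built, so the forward pass cannot run on its own. The sketch below is one self-contained way to make it runnable. The field names are inferred from how the forward pass indexes into `u`; the helpers `conv_block` and `build_unet`, the channel widths, and the kernel sizes are hypothetical choices for illustration, not the author's actual model.

```julia
using Flux

# Hypothetical container: the gist does not show the struct, so these field
# names are inferred from how the forward pass indexes into `u`.
struct UNet
    encoder      # 4 encoder blocks, level 1 (shallowest) to 4 (deepest)
    pool         # 4 downsampling layers
    bottleneck
    upconv       # 4 upsampling layers; upconv[1] is used at level 4
    decoder      # 4 decoder blocks, indexed by level
    final_conv
end

# Hypothetical helper: two 3x3 convolutions; pad = 1 keeps height and width unchanged.
conv_block(in_ch, out_ch) = Chain(
    Conv((3, 3), in_ch => out_ch, relu; pad = 1),
    Conv((3, 3), out_ch => out_ch, relu; pad = 1),
)

# Hypothetical builder; channel widths double at each level (base, 2base, ..., 16base).
function build_unet(in_ch::Int, out_ch::Int; base::Int = 8)
    chs = [base * 2^(k - 1) for k in 1:5]
    encoder = [conv_block(k == 1 ? in_ch : chs[k - 1], chs[k]) for k in 1:4]
    pool = [MaxPool((2, 2)) for _ in 1:4]
    bottleneck = conv_block(chs[4], chs[5])
    # upconv[j] serves level i = 5 - j and maps chs[i + 1] => chs[i] channels
    upconv = [ConvTranspose((2, 2), chs[6 - j] => chs[5 - j]; stride = 2) for j in 1:4]
    # decoder[i] sees the upsampled tensor plus the level-i skip: 2 * chs[i] channels
    decoder = [conv_block(2 * chs[i], chs[i]) for i in 1:4]
    final_conv = Conv((1, 1), chs[1] => out_ch)
    return UNet(encoder, pool, bottleneck, upconv, decoder, final_conv)
end

# Forward pass from the gist, with σ broadcast and skips kept in a tuple.
function (u::UNet)(x::AbstractArray)
    skips = ()
    out = x
    for i in 1:4
        out = u.encoder[i](out)
        skips = (skips..., out)
        out = u.pool[i](out)
    end
    out = u.bottleneck(out)
    for i in 4:-1:1
        out = u.upconv[5 - i](out)
        out = cat(out, skips[i]; dims = 3)   # channel dimension in WHCN layout
        out = u.decoder[i](out)
    end
    return σ.(u.final_conv(out))
end

model = build_unet(1, 1)
x = rand(Float32, 64, 64, 1, 2)   # width, height, channels, batch
y = model(x)
println(size(y))                  # (64, 64, 1, 2)
```

With four 2×2 poolings, the input height and width must be divisible by 16 so that each upsampled tensor matches its skip tensor in `cat`. For training, the struct would also need to be registered with Flux (for example via `Flux.@layer UNet` in recent Flux versions), which this sketch omits.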