Skip to content

Instantly share code, notes, and snippets.

@pshashk
Last active June 10, 2020 21:03
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pshashk/e01bae2df781c23de6d8b13772be3613 to your computer and use it in GitHub Desktop.
Save pshashk/e01bae2df781c23de6d8b13772be3613 to your computer and use it in GitHub Desktop.
Flux unet
using Flux
"""
    downsample(c_in, c_out)

Build one encoder stage as a `Chain` of two layers:

1. a 4×4 `Conv` mapping `c_in` channels to `c_out` channels with stride 2 and
   padding 1 (with these settings an even spatial size is halved), and
2. a `BatchNorm` over the `c_out` channels using `relu` as its activation.
"""
function downsample(c_in, c_out)
    conv = Conv((4, 4), c_in => c_out, stride = 2, pad = 1)
    norm = BatchNorm(c_out, relu)
    return Chain(conv, norm)
end
"""
    upsample(c_in, c_out)

Build one decoder stage as a `Chain` of two layers:

1. a 4×4 `ConvTranspose` mapping `c_in` channels to `c_out` channels with
   stride 2 and padding 1 (with these settings the spatial size is doubled), and
2. a `BatchNorm` over the `c_out` channels using `relu` as its activation.
"""
function upsample(c_in, c_out)
    deconv = ConvTranspose((4, 4), c_in => c_out, stride = 2, pad = 1)
    norm = BatchNorm(c_out, relu)
    return Chain(deconv, norm)
end
"""
    UNet(initial_channels, depth)

Build a U-Net-style encoder/decoder and return it as a closure `input -> output`.

# Arguments
- `initial_channels`: number of channels of the input (and of the output).
- `depth`: number of `downsample` stages, and equally of `upsample` stages.
  Must be at least 1.

# Structure
- Encoder stage `d` maps `initial_channels * 2^(d - 1)` channels to twice as many.
- The first decoder stage consumes the deepest encoder activation directly.
- Every later decoder stage consumes the previous decoder output concatenated
  along dimension 3 with the encoder activation of matching depth, which is why
  its input channel count is doubled.
- The last decoder stage emits `initial_channels` channels again.

# Throws
- `ArgumentError` if `depth < 1`.
"""
function UNet(initial_channels, depth)
    depth >= 1 || throw(ArgumentError("depth must be at least 1, got $depth"))

    # Encoder, shallowest stage first.
    down_path = map(1:depth) do d
        c_in = initial_channels * 2^(d - 1)
        downsample(c_in, 2 * c_in)
    end

    # Decoder, deepest stage first. Only the deepest stage has no skip
    # connection concatenated onto its input, hence the halved input width.
    up_path = map(depth:-1:1) do d
        c_in = initial_channels * 2^(d + 1)
        upsample(d == depth ? c_in ÷ 2 : c_in, c_in ÷ 4)
    end

    # The closure must keep capturing `down_path` and `up_path` under these
    # names: the script in this file reads them back as `model.down_path` and
    # `model.up_path` to collect the trainable parameters.
    return input -> begin
        # Collect encoder activations in an immutable tuple, deepest first.
        # Unlike a `TrackedArray[]` buffer filled with `push!`/`pop!`, this does
        # not depend on a particular AD backend's array type and performs no
        # in-place mutation inside the differentiated code.
        skips = ()
        x = input
        for layer in down_path
            x = layer(x)
            skips = (x, skips...)
        end

        # `skips[i]` is the encoder activation that decoder stage `i` consumes.
        output = up_path[1](skips[1])
        for i in 2:length(up_path)
            output = up_path[i](cat(output, skips[i], dims = 3))
        end
        output
    end
end
# Demo: build a small U-Net, run one forward pass and take one gradient.
initial_channels, depth = 16, 5
model = UNet(initial_channels, depth)

# `model` is a closure, so its layers are reached through the variables it
# captured (`down_path`, `up_path`), which Julia exposes as fields.
pars = params(tuple(model.down_path..., model.up_path...))

# Random 128×128 input with `initial_channels` channels and a batch of one
# (dimension 3 is the channel axis the model concatenates along).
input = rand(Float32, (128, 128, initial_channels, 1))
output = model(input)

# Gradient of the summed absolute reconstruction error w.r.t. all parameters.
gs = gradient(() -> sum(abs, input - model(input)), pars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment