Skip to content

Instantly share code, notes, and snippets.

@nlw0
Last active November 30, 2019 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nlw0/a047700a15941c8c1dbbf2cb52c07a19 to your computer and use it in GitHub Desktop.
UNet in Flux (work in progress!)
## Process data from https://www.kaggle.com/c/data-science-bowl-2018/
## Merge a bunch of binary masks into a single one
using Glob, Images
root = "/home/user/src/data/data-science-bowl-2018/train/"
for case in glob("*", root)
    labels = glob("*png", case * "/masks")
    # Skip cases with no mask files: `sum` over an empty vector of arrays
    # would throw instead of producing an empty mask.
    isempty(labels) && continue
    # Sum the per-instance binary masks into one image. `clamp01` keeps
    # pixels in the valid [0, 1] range when instance masks overlap, which
    # would otherwise make `save` fail or wrap values.
    mergedlabels = clamp01.(sum(map(load, labels)))
    save(case * "/mergedmasks.png", mergedlabels)
end
using Pkg
Pkg.activate(".")
using Glob, Flux, Images, ImageView
# U-Net https://arxiv.org/pdf/1505.04597.pdf
# https://medium.com/analytics-vidhya/semantic-segmentation-using-u-net-data-science-bowl-2018-data-set-ed046c2004a5
# Test data: https://www.kaggle.com/c/data-science-bowl-2018/
"""Two stacked 3×3 same-padding convolutions with ReLU activations,
mapping `in` channels to `out` channels."""
function convconv(in, out)
    layer1 = Conv((3, 3), in => out, pad = (1, 1), relu)
    layer2 = Conv((3, 3), out => out, pad = (1, 1), relu)
    return Chain(layer1, layer2)
end
"""Downscale by 2×2 max-pooling, apply the given sub-network, then upscale
back via a 2×2 stride-2 transposed convolution over the `inout` channel pair."""
function downup(subnetwork, inout)
    pool = MaxPool((2, 2))
    upsample = ConvTranspose((2, 2), inout; stride = (2, 2))
    return Chain(pool, subnetwork, upsample)
end
# Recursive U-Net body: each level applies convconv, recurses one level
# deeper (doubling the channel count) through a maxpool/upsample sandwich,
# then concatenates the skip connection along the channel dimension.
# NOTE(review): `convconv`/`downup` construct *fresh* Conv layers on every
# forward call, so these weights are re-initialized each evaluation and are
# not reachable through `params` — training cannot update them. TODO:
# build the layers once (outside the forward pass) and reuse them.
function unet(x, ::Val{LayerDepth}; first=true) where LayerDepth
if LayerDepth == 64
# Bottom of the U: the maxpooled activations from the level above carry
# 32 channels at this point.
convconv(32, 64)(x)
else
# The first level consumes the 1-channel input image; deeper levels
# receive half their own channel count from the level above.
ini = convconv(if first 1 else LayerDepth ÷ 2 end, LayerDepth)(x)
sub = downup(x->unet(x, Val(LayerDepth*2); first=false), LayerDepth * 2 => LayerDepth)(ini)
# Skip connection: channels live in dimension 3 of the WHCN layout.
final = cat(ini, sub; dims=3)
convconv(LayerDepth * 2, LayerDepth)(final)
end
end
# Full model: recursive U-Net body followed by two 1×1 convolutions that map
# the 4 feature channels down to a single-channel mask prediction.
# NOTE(review): only these two 1×1 Conv layers hold persistent parameters;
# the layers built inside `unet` are recreated on every forward call.
UNet = Chain(x->unet(x, Val(4)), Conv((1,1), 4=>2), Conv((1,1), 2=>1))
# UNet(rand(572,572,2,1))
# Smoke test: one random 256×256 single-channel image in WHCN layout.
UNet(rand(256,256,1,1))
root = "/home/user/src/data/data-science-bowl-2018/train/"

"""
    getdatacrop(imgid)

Load the image and merged mask for `imgid` from `root` and return a random
aligned 256×256 crop of each, reshaped to WHCN `(256, 256, 1, 1)` `Float32`
arrays ready for the network.

Throws an `ArgumentError` when the image is smaller than the crop size.
"""
function getdatacrop(imgid)
    imgname = "$root/$imgid/images/$imgid.png"
    maskname = "$root/$imgid/mergedmasks.png"
    imgori = Gray.(load(imgname))
    maskori = Gray.(load(maskname))
    w, h = size(imgori)
    # rand(0:w-256) throws an opaque MethodError for images smaller than the
    # crop; fail early with a clear message instead.
    (w >= 256 && h >= 256) ||
        throw(ArgumentError("image $imgid is $(w)×$(h), smaller than the 256×256 crop"))
    sw = rand(0:w-256)
    sh = rand(0:h-256)
    # Convert Gray colorants to Float32 so Flux's Conv/mse operate on plain
    # numeric arrays rather than Colorant-typed ones.
    img = Float32.(imgori[sw+1:sw+256, sh+1:sh+256])
    mask = Float32.(maskori[sw+1:sw+256, sh+1:sh+256])
    return reshape(img, 256, 256, 1, 1), reshape(mask, 256, 256, 1, 1)
end
## do the backprop dance
# Pixel-wise mean-squared error between the network output and the mask crop.
loss(x, y) = Flux.mse(UNet(x), y)
# Fixed monitoring subset: one random crop from every 20th training case.
errset = [getdatacrop(split(x, "/")[end]) for x in glob("*", root)[1:20:end]]
# Lazy generator over all cases; each element is a fresh random crop.
dataset = (getdatacrop(split(x, "/")[end]) for x in glob("*", root))
# Progress callback: total loss over the monitoring subset.
evalcb = () -> @info("err", sum(x->loss(x...), errset))
opt = ADAM()
# NOTE(review): `unet` rebuilds its Conv layers on every forward pass, so
# `params(UNet)` only captures the two 1×1 Conv layers of `UNet` — the bulk
# of the network receives no gradient updates as written.
Flux.train!(loss, params(UNet), dataset, opt, cb = Flux.throttle(evalcb, 10))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment