Skip to content

Instantly share code, notes, and snippets.

@darsnack

darsnack/resnet.jl

Last active Jun 23, 2020
Embed
What would you like to do?
"""
    OptionAIdentity{P<:MaxPool}

Shortcut connection for residual blocks that holds the `MaxPool` layer used to spatially
downsample the block input.

The pooling layer's concrete type is a type parameter, so the field is concretely typed.
`MaxPool` itself is parametric, and a field annotated `::MaxPool` would be abstract.
"""
struct OptionAIdentity{P<:MaxPool}
    pool::P
end
"""
    OptionAIdentity(scale::Integer)
    OptionAIdentity(scale::Tuple{<:Integer, <:Integer})

Create an `OptionAIdentity` whose `MaxPool` window is `scale`. A single integer is used for
both spatial dimensions.
"""
OptionAIdentity(scale::Integer) = OptionAIdentity((scale, scale))
function OptionAIdentity(scale::Tuple{<:Integer, <:Integer})
    pool = MaxPool(scale)
    return OptionAIdentity(pool)
end
# Option-A residual shortcut: combine the residual-branch output `x` with the block input `y`.
# When used as a Flux `SkipConnection` connection, it is called as
# `connection(layers(input), input)`, so `x` is the branch output and `y` the block input.
# `y` is max-pooled by `op.pool`, then zero-padded along dimension 3 (channels) up to
# `size(x, 3)`, and finally added to `x`.
# Previously only the padded shortcut was returned, which discarded the residual branch.
function (op::OptionAIdentity)(x, y)
    z = op.pool(y)
    npadchannels = size(x, 3) - size(z, 3)
    if npadchannels > 0
        # match z's element type instead of hard-coding Float32
        padding = zeros(eltype(z), size(z, 1), size(z, 2), npadchannels, size(z, 4))
        z = cat(z, padding; dims = 3)
    end
    # `+` (not `.+`) so any remaining shape mismatch raises a DimensionMismatch
    return x + z
end
"""
    resnet(width, height, ichannels, nclasses; nlayers = 20)

Build a `6n + 2`-layer residual network (the CIFAR layout of He et al.) as a Flux `Chain`.

The network has these parts:

- A stem convolution.
- Three stages of `n` two-convolution residual blocks with 16, 32 and 64 channels, whose
  target feature-map sizes are 32×32, 16×16 and 8×8.
- A `MeanPool` over the final maps.
- A `Dense` classifier.

Convolution padding is chosen with `_calculatepadding` so that the stem aims to produce a
32×32 output from a `width × height` input.

# Arguments
- `width`, `height`: spatial size of the input. Flux uses WHCN layout, so `width` is array
  dimension 1. NOTE(review): confirm callers pass `size(x, 1)` as `width`.
- `ichannels`: number of input channels.
- `nclasses`: number of output classes.

# Keywords
- `nlayers = 20`: total depth; must equal `6n + 2` for some integer `n ≥ 0`.

# Throws
- `ArgumentError` if `nlayers` is not of the form `6n + 2` with `n ≥ 0`.
"""
function resnet(width, height, ichannels, nclasses; nlayers = 20)
    q, r = divrem(nlayers - 2, 6)
    (nlayers >= 2 && iszero(r)) ||
        throw(ArgumentError("nlayers must be of the form 6n + 2 with n ≥ 0, got $nlayers"))
    n = Int(q)
    ksize = (3, 3)
    ssize = 2
    # channel sizes passing through conv layers: conv k maps lsizes[k] => lsizes[k + 1]
    lsizes = vcat([ichannels], fill(16, 1 + 2n), fill(32, 2n), fill(64, 2n))
    # target spatial output size of conv k (pairs with lsizes[k + 1])
    osizes = vcat(fill(32, 1 + 2n), fill(16, 2n), fill(8, 2n))
    # NOTE(review): every conv uses stride `ssize = 2` and relies on padding to reach the
    # target size, even where the size is unchanged. Standard ResNets use stride 1 there;
    # confirm this is intended.
    convs = Any[]

    # stem convolution
    hpad, ffheight = _padandsize(height, osizes[1], ksize[1], ssize)
    wpad, ffwidth = _padandsize(width, osizes[1], ksize[1], ssize)
    # Flux `pad` tuples follow the array's spatial dimensions, i.e. (W, H) in WHCN layout
    push!(convs, Conv(ksize, lsizes[1] => lsizes[2], stride = ssize, pad = (wpad, hpad)))
    push!(convs, BatchNorm(lsizes[2], relu))

    # residual blocks (two convs each) wrapped in skip connections
    for i in 3:2:(length(lsizes) - 1)
        inheight, inwidth = ffheight, ffwidth
        hpad₁, ffheight = _padandsize(ffheight, osizes[i - 1], ksize[1], ssize)
        wpad₁, ffwidth = _padandsize(ffwidth, osizes[i - 1], ksize[1], ssize)
        hpad₂, ffheight = _padandsize(ffheight, osizes[i], ksize[1], ssize)
        wpad₂, ffwidth = _padandsize(ffwidth, osizes[i], ksize[1], ssize)
        # spatial downsampling factor of the whole block (InexactError if not integral)
        hscale = Int(inheight / ffheight)
        wscale = Int(inwidth / ffwidth)
        resblock = Chain(
            Conv(ksize, lsizes[i - 1] => lsizes[i], stride = ssize, pad = (wpad₁, hpad₁)),
            BatchNorm(lsizes[i], relu),
            Conv(ksize, lsizes[i] => lsizes[i + 1], stride = ssize, pad = (wpad₂, hpad₂)),
            BatchNorm(lsizes[i + 1]),
        )
        # A block that changes the spatial size or the channel count cannot add its input
        # directly, so it needs the option-A shortcut (pool + zero-pad channels). Deciding
        # this from the shapes, rather than `i % (2n) == 3`, also works for n = 1, where
        # that modular test never fires.
        changesshape = (wscale, hscale) != (1, 1) || lsizes[i - 1] != lsizes[i + 1]
        shortcut = changesshape ? OptionAIdentity((wscale, hscale)) : +
        push!(convs, SkipConnection(resblock, shortcut))
        push!(convs, x -> relu.(x))
    end

    # pooling over the final feature maps + fully-connected classifier
    return Chain(
        convs...,
        MeanPool((osizes[end], osizes[end])),  # final feature maps are 8 × 8
        x -> reshape(x, :, size(x, 4)),
        Dense(lsizes[end], nclasses),
    )
end

# Return the padding `_calculatepadding` selects for a `ksize`/`ssize` convolution that should
# map `insize` to `outsize`, together with the output size `convsizeout` reports for it.
function _padandsize(insize, outsize, ksize, ssize)
    pad = _calculatepadding(insize, outsize; ksize = ksize, ssize = ssize)
    return pad, convsizeout(insize; ksize = ksize, ssize = ssize, padding = pad)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment