Resnet18 on CIFAR10 in Julia

yiyu-resnet.jl
using Flux
using Flux: @functor

function conv3x3(in_planes, out_planes; stride=1)
    Conv((3, 3), in_planes => out_planes; pad=1, stride=stride)
end

function conv1x1(in_planes, out_planes; stride=1)
    Conv((1, 1), in_planes => out_planes; pad=0, stride=stride)
end
struct BasicBlock
    conv1::Conv
    bn1::BatchNorm
    conv2::Conv
    bn2::BatchNorm
end
@functor BasicBlock

# base_width and dilation are accepted to mirror the torchvision signature
# but are unused by the basic block.
function BasicBlock(inplanes, planes; stride=1, base_width=64, dilation=1)
    BasicBlock(
        conv3x3(inplanes, planes; stride=stride),
        BatchNorm(planes, relu),
        conv3x3(planes, planes),
        BatchNorm(planes),
    )
end
# Main path: conv -> BN(relu) -> conv -> BN; the identity skip is added
# before the final relu.
function (m::BasicBlock)(x)
    y = x |> m.conv1 |> m.bn1 |> m.conv2 |> m.bn2
    relu.(x + y)
end
struct BasicBlockDownsample
    conv1::Conv
    bn1::BatchNorm
    conv2::Conv
    bn2::BatchNorm
    downsample::Chain
end
@functor BasicBlockDownsample
function BasicBlockDownsample(inplanes, planes; stride=1, base_width=64, dilation=1, downsample)
    BasicBlockDownsample(
        conv3x3(inplanes, planes; stride=stride),
        BatchNorm(planes, relu),
        conv3x3(planes, planes),
        BatchNorm(planes),
        downsample,
    )
end
# Same as BasicBlock, but the skip path is downsampled (1x1 conv + BN)
# so its shape matches the strided main path before the residual addition.
function (m::BasicBlockDownsample)(x)
    y = x |> m.conv1 |> m.bn1 |> m.conv2 |> m.bn2
    relu.(m.downsample(x) + y)
end
# The Bottleneck block (used by ResNet-50 and deeper) is omitted.
function make_layer(inplanes::Int, planes::Int, blocks::Int; base_width=64, stride=1)
    layers = Any[]
    if stride > 1
        # A strided stage: its first block must also downsample the skip path
        # (1x1 conv + BN) so the residual shapes match.
        downsample = Chain(conv1x1(inplanes, planes; stride=stride), BatchNorm(planes))
        push!(layers, BasicBlockDownsample(inplanes, planes; stride=stride, base_width=base_width, downsample=downsample))
    else
        push!(layers, BasicBlock(inplanes, planes; stride=stride, base_width=base_width))
    end
    for i in 2:blocks
        push!(layers, BasicBlock(planes, planes; base_width=base_width))
    end
    Chain(layers...)
end
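# Usage sketch (illustrative; the names are examples, not part of the gist):
# the first block of a strided stage halves the spatial dims and changes the
# channel count, while the remaining blocks preserve shape.
#
#   stage = make_layer(64, 128, 2; stride=2)
#   x = randn(Float32, 8, 8, 64, 1)   # WHCN layout
#   size(stage(x))                    # (4, 4, 128, 1)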
Base.@kwdef struct ResNet
    conv1::Conv
    bn1::BatchNorm
    maxpool::MaxPool
    layer1::Chain
    layer2::Chain
    layer3::Chain
    layer4::Chain
    avgpool::GlobalMeanPool
    fc::Dense
end
@functor ResNet
# `layers` gives the number of basic blocks per stage: [2, 2, 2, 2] for ResNet-18.
function ResNet(layers::Vector{Int}, num_classes::Int)
    ResNet(;
        conv1 = Conv((7, 7), 3 => 64; stride=2, pad=3),
        bn1 = BatchNorm(64, relu),
        maxpool = MaxPool((3, 3); stride=2, pad=1),
        layer1 = make_layer(64, 64, layers[1]),
        layer2 = make_layer(64, 128, layers[2]; stride=2),
        layer3 = make_layer(128, 256, layers[3]; stride=2),
        layer4 = make_layer(256, 512, layers[4]; stride=2),
        avgpool = GlobalMeanPool(),
        fc = Dense(512, num_classes),
    )
end
function (m::ResNet)(x)
    x |> m.conv1 |> m.bn1 |> m.maxpool |>
        m.layer1 |> m.layer2 |> m.layer3 |> m.layer4 |>
        m.avgpool |> flatten |> m.fc
end
# Flux's generic testmode! does not recurse into these custom structs,
# so forward it explicitly to every field holding BatchNorm layers.
function Flux.testmode!(m::ResNet, mode = true)
    for layer in (m.bn1, m.layer1, m.layer2, m.layer3, m.layer4)
        Flux.testmode!(layer, mode)
    end
end
function Flux.testmode!(m::BasicBlock, mode = true)
    Flux.testmode!(m.bn1, mode)
    Flux.testmode!(m.bn2, mode)
end

function Flux.testmode!(m::BasicBlockDownsample, mode = true)
    Flux.testmode!(m.bn1, mode)
    Flux.testmode!(m.bn2, mode)
    Flux.testmode!(m.downsample, mode)  # the Chain forwards this to the BatchNorm it contains
end
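# Smoke test (illustrative, commented out). CIFAR-10 inputs are 32x32x3; after
# the stride-2 stem, the max-pool, and three stride-2 stages, the feature map
# entering GlobalMeanPool is 1x1x512.
#
#   model = ResNet([2, 2, 2, 2], 10)            # ResNet-18
#   size(model(randn(Float32, 32, 32, 3, 4)))   # (10, 4)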
@info "Loading libraries"
@info " Loading Flux"
using Flux
using Statistics
using Flux: onehotbatch, crossentropy, Momentum, update!, onecold
@info " Loading MLDatasets"
using MLDatasets: CIFAR10
using Base.Iterators: partition
batchsize = 1000
trainsize = 50000 - batchsize  # hold out the last batch of the train split for validation
@info "Loading training data"
trainimgs = CIFAR10.traintensor(Float32);
trainlabels = onehotbatch(CIFAR10.trainlabels(Float32) .+ 1, 1:10);
@info "Building the trainset"
trainset = [(trainimgs[:,:,:,i], trainlabels[:,i]) for i in partition(1:trainsize, batchsize)];
batchnum = length(trainset)
@info "Loading validation data"
valset = (trainsize+1):(trainsize+batchsize)
valX = trainimgs[:,:,:,valset] |> gpu;
valY = trainlabels[:, valset] |> gpu;
# The model outputs raw logits (the final Dense has no softmax),
# so use logitcrossentropy rather than crossentropy on m(x).
loss(x, y) = logitcrossentropy(m(x), y)
opt = Momentum(0.01)
# `onecold` inverts the one-hot encoding, giving class indices in 1:10;
# move results to the CPU before comparing element-wise.
accuracy(x, y) = mean(onecold(cpu(m(x))) .== onecold(cpu(y)))
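# One-hot round trip (illustrative):
#   onecold(onehotbatch([3, 1, 10], 1:10)) == [3, 1, 10]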
@info "Loading the model"
include("yiyu-resnet.jl")
m = ResNet([2,2,2,2], 10) |> gpu;  # ResNet-18
epochs = 10
for epoch in 1:epochs
    @info "epoch" epoch
    for i in 1:batchnum
        batch = trainset[i] |> gpu
        gs = gradient(params(m)) do
            loss(batch...)
        end
        @info "batch fraction" i/batchnum
        update!(opt, params(m), gs)
    end
    @show accuracy(valX, valY)
end
@info "Loading test data"
testimgs = CIFAR10.testtensor(Float32);
testlabels = onehotbatch(CIFAR10.testlabels(Float32) .+ 1, 1:10);
testset = [(testimgs[:,:,:,i], testlabels[:,i]) for i in partition(1:10000, batchsize)] |> gpu;
class_correct = zeros(10)
class_total = zeros(10)
for i in 1:length(testset)  # note: 10000/batchsize is a Float64 and cannot index testset
    @info "Evaluating testset batch" i
    preds = cpu(m(testset[i][1]))  # back to the CPU for cheap scalar indexing
    lab = cpu(testset[i][2])
    for j in 1:batchsize
        pred_class = findmax(preds[:, j])[2]
        actual_class = findmax(lab[:, j])[2]
        if pred_class == actual_class
            class_correct[pred_class] += 1
        end
        class_total[actual_class] += 1
    end
end
class_correct ./ class_total
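# Optional per-class report (illustrative; the class names are the standard
# CIFAR-10 labels, not taken from the gist):
#
#   classes = ["airplane", "automobile", "bird", "cat", "deer",
#              "dog", "frog", "horse", "ship", "truck"]
#   for (name, acc) in zip(classes, class_correct ./ class_total)
#       println(rpad(name, 12), round(acc; digits=3))
#   end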
natema commented Aug 20, 2020

yiyu-resnet.jl is copied from @yiyuezhuo's resnet.jl.

natema commented Aug 22, 2020

See this discussion on the Julia forum.
