@Alexander-Barth
Created February 7, 2022 15:40
PyTorch-resnet-for-flux
# needs julia 1.6 as long as this issue is open:
# https://github.com/JuliaIO/BSON.jl/issues/107
# License MIT
using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test
# Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions) with a skip
# connection; identity_downsample adapts the skip path when the shape changes.
function block((in_channels, intermediate_channels), identity_downsample = identity;
               stride = 1)
    expansion = 4
    c = Parallel(
        (a, b) -> relu.(a + b),
        Chain(
            Conv((1,1), in_channels => intermediate_channels, stride = 1, pad = 0, bias = false),
            BatchNorm(intermediate_channels, relu),
            Conv((3,3), intermediate_channels => intermediate_channels, stride = stride, pad = 1, bias = false),
            BatchNorm(intermediate_channels, relu),
            Conv((1,1), intermediate_channels => expansion * intermediate_channels, stride = 1, pad = 0, bias = false),
            BatchNorm(expansion * intermediate_channels),
        ),
        identity_downsample,
    )
    return c
end
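
# A minimal shape check (a sketch, not part of the original gist): with 64 input
# channels the block expands to 4*64 = 256 output channels, so the skip path needs
# a 1x1 convolution; the spatial size (here 8x8) is preserved for stride = 1.
let b = block(64 => 64, Chain(Conv((1,1), 64 => 256, bias = false), BatchNorm(256)))
    @assert size(b(randn(Float32, 8, 8, 64, 1))) == (8, 8, 256, 1)
end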
# Builds one ResNet stage: the first block may downsample (stride > 1) and/or
# widen the channels; the remaining blocks keep 4*intermediate_channels.
function make_layer(in_channels, block, num_residual_blocks, intermediate_channels; stride = 1)
    identity_downsample = identity
    layers = []

    if (stride != 1) || (in_channels != intermediate_channels * 4)
        identity_downsample = Chain(
            Conv((1,1), in_channels => intermediate_channels * 4, stride = stride, bias = false),
            BatchNorm(4 * intermediate_channels))
    end

    push!(layers, block(in_channels => intermediate_channels, identity_downsample, stride = stride))
    in_channels = 4 * intermediate_channels

    for i = 1:(num_residual_blocks-1)
        # channels: e.g. 256 -> 64, then back up to 4*64
        push!(layers, block(in_channels => intermediate_channels))
    end

    return Chain(layers...)
end
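
# Quick sanity sketch (an addition for illustration): the first stage maps
# 64 -> 256 channels and, with stride = 1, keeps the spatial size.
let layer1 = make_layer(64, block, 3, 64, stride = 1)
    @assert size(layer1(randn(Float32, 8, 8, 64, 1))) == (8, 8, 256, 1)
end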
# Assembles the full network: stem (7x7 convolution + max pooling), four stages
# of bottleneck blocks, global average pooling and the final classifier.
function resnet(block, layers, image_channels, num_classes)
    in_channels = 64
    c = Chain(
        Conv((7,7), image_channels => in_channels, pad = 3, stride = 2, bias = false),
        BatchNorm(64, relu),
        MaxPool((3,3), stride = 2, pad = 1)
    )

    _intermediate_channels = [64, 128, 256, 512]
    _in_channels = vcat([in_channels], _intermediate_channels[1:end-1] * 4)

    return Chain(
        c,
        make_layer(_in_channels[1], block, layers[1], 64,  stride = 1),
        make_layer(_in_channels[2], block, layers[2], 128, stride = 2),
        make_layer(_in_channels[3], block, layers[3], 256, stride = 2),
        make_layer(_in_channels[4], block, layers[4], 512, stride = 2),
        AdaptiveMeanPool((1,1)),
        flatten,
        Dense(512*4, num_classes),
    )
end

function ResNet50(; img_channels = 3, num_classes = 1000)
    return resnet(block, [3, 4, 6, 3], img_channels, num_classes)
end

function ResNet101(; img_channels = 3, num_classes = 1000)
    return resnet(block, [3, 4, 23, 3], img_channels, num_classes)
end

function ResNet152(; img_channels = 3, num_classes = 1000)
    return resnet(block, [3, 8, 36, 3], img_channels, num_classes)
end
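
# Forward-pass sketch (left commented out to keep the script light): an untrained
# ResNet50 maps a 224x224 RGB batch of one image to 1000 class logits.
# let m = ResNet50()
#     @assert size(m(randn(Float32, 224, 224, 3, 1))) == (1000, 1)
# end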
torchvision = pyimport("torchvision")
torch = pyimport("torch")

modellib = [
    (ResNet50,  torchvision.models.resnet50),
    (ResNet101, torchvision.models.resnet101),
    (ResNet152, torchvision.models.resnet152),
]

# pick ResNet152, load the pretrained PyTorch weights and build the
# corresponding (still untrained) Flux model
jlmodel, pymodel = modellib[3]
pytorchmodel = pymodel(pretrained = true)
model = jlmodel()

# reverse the dimension order (PyTorch uses NCHW, Flux uses WHCN)
tr(tmp) = permutedims(tmp, ndims(tmp):-1:1)

function _list_state(node::Flux.BatchNorm, channel, prefix)
    # use the same order of parameters as PyTorch
    put!(channel, (prefix * ".γ", node.γ))   # weight (learnable)
    put!(channel, (prefix * ".β", node.β))   # bias (learnable)
    put!(channel, (prefix * ".μ", node.μ))   # running mean
    put!(channel, (prefix * ".σ²", node.σ²)) # running variance
end
function _list_state(node::Union{Flux.Conv,Flux.Dense}, channel, prefix)
    put!(channel, (prefix * ".weight", node.weight))

    if node.bias !== Flux.Zeros()
        put!(channel, (prefix * ".bias", node.bias))
    end
end

_list_state(node, channel, prefix) = nothing

function _list_state(node::Union{Flux.Chain,Flux.Parallel}, channel, prefix)
    for (i, n) in enumerate(node.layers)
        _list_state(n, channel, prefix * ".layers[$i]")
    end
end

function list_state(node; prefix = "model")
    Channel() do channel
        _list_state(node, channel, prefix)
    end
end
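
# Sketch of what list_state yields: lazily produced (name, array) pairs in
# model-definition order, e.g. ("model.layers[1].layers[1].weight", a 7×7×3×64
# array for the stem convolution). Uncomment to print all entries:
# foreach(((name, p),) -> println(name, "  ", size(p)), list_state(model))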

state = OrderedDict(list_state(model))

# pytorchmodel.state_dict() loses the order
state_dict = OrderedDict(pycall(pytorchmodel.state_dict, PyObject).items())
pytorch_pp = OrderedDict((k, v.numpy()) for (k, v) in state_dict if !occursin("num_batches_tracked", k))

# loop over all parameters and copy the PyTorch values into the Flux model
for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
    if size(flux_param) == size(pytorch_param)
        # Dense weights and 1-D vectors (biases, BatchNorm parameters)
        flux_param .= pytorch_param
    elseif size(flux_param) == reverse(size(pytorch_param))
        tmp = pytorch_param
        tmp = permutedims(tmp, ndims(tmp):-1:1)

        if ndims(flux_param) == 4
            # convolutional weights: Flux implements a true convolution while
            # PyTorch uses cross-correlation, hence the spatial flip
            flux_param .= reverse(tmp, dims = (1,2))
        else
            flux_param .= tmp
        end
    else
        @debug begin
            @show size(flux_param), size(pytorch_param)
        end
        error("incompatible shape $flux_key $pytorch_key")
    end
end
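
# Optional spot-check (sketch, not in the original gist): the first copied Flux
# kernel should equal the first PyTorch kernel up to dimension reversal and the
# spatial flip performed above.
# let flux_w = first(values(state)), py_w = first(values(pytorch_pp))
#     @assert flux_w ≈ reverse(permutedims(py_w, (4, 3, 2, 1)), dims = (1, 2))
# end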

function normalize(data)
    cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
    cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
    return (data .- cmean) ./ cstd
end
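
# The mean/std above are the standard torchvision ImageNet statistics (RGB order).
# Usage sketch: `normalize` expects a WHCN Float32 array scaled to [0, 1], e.g.
# x = rand(Float32, 224, 224, 3, 1); summary(normalize(x))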
Flux.testmode!(model)
@info "saving model"
BSON.@save "$(lowercase(String(Symbol(jlmodel)))).bson" model
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")
sz = (224, 224)
img = Images.load(guitar_path);
img = imresize(img, sz);
# CHW -> WHC
data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))
data = normalize(data[:,:,:,1:1])
out = model(data) |> softmax;
out = out[:,1]
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
println("Flux:")
for i in sortperm(out, rev = true)[1:5]
    println("$(labels[i]): $(out[i])")
end
pytorchmodel.eval()
output = pytorchmodel(torch.Tensor(tr(data)));
probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();
println("PyTorch:")
for i in sortperm(probabilities[:,1], rev = true)[1:5]
    println("$(labels[i]): $(probabilities[i])")
end
@test maximum(out) ≈ maximum(probabilities)
@test argmax(out) == argmax(probabilities)