Created
February 7, 2022 15:40
-
-
Save Alexander-Barth/3f17e6f59daebe94f548833fa3d01f0e to your computer and use it in GitHub Desktop.
PyTorch-resnet-for-flux
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# needs julia 1.6 as long as this issue is open: | |
# https://github.com/JuliaIO/BSON.jl/issues/107 | |
# License MIT | |
using Flux | |
import Metalhead | |
using DataStructures | |
using Statistics | |
using BSON | |
using PyCall | |
using Images | |
using Test | |
"""
    block((in_channels, intermediate_channels), identity_downsample = identity; stride = 1)

Build one bottleneck residual unit as a Flux `Parallel` layer.

Main branch: 1×1 conv → BatchNorm+ReLU → 3×3 conv (carrying `stride`) → BatchNorm+ReLU →
1×1 conv widening to `4 * intermediate_channels` → BatchNorm.
Skip branch: `identity_downsample` (plain `identity` by default).
The two branch outputs are added and passed through ReLU.
"""
function block((in_channels, intermediate_channels), identity_downsample = identity;
               stride = 1)
    expansion = 4
    widened = expansion * intermediate_channels

    main_branch = Chain(
        Conv((1, 1), in_channels => intermediate_channels; stride = 1, pad = 0, bias = false),
        BatchNorm(intermediate_channels, relu),
        Conv((3, 3), intermediate_channels => intermediate_channels;
             stride = stride, pad = 1, bias = false),
        BatchNorm(intermediate_channels, relu),
        Conv((1, 1), intermediate_channels => widened; stride = 1, pad = 0, bias = false),
        BatchNorm(widened),
    )

    # residual sum followed by the final activation
    connection = (a, b) -> relu.(a + b)

    return Parallel(connection, main_branch, identity_downsample)
end
"""
    make_layer(in_channels, block, num_residual_blocks, intermediate_channels; stride = 1)

Stack `num_residual_blocks` units produced by `block` into one `Chain` (a ResNet stage).

Only the first unit receives `stride`. When the stride is not 1, or when `in_channels`
differs from `4 * intermediate_channels`, it also receives a projection shortcut
(1×1 conv + BatchNorm). Every following unit maps `4 * intermediate_channels` channels
back to the same count and uses the default identity shortcut.
"""
function make_layer(in_channels, block, num_residual_blocks, intermediate_channels; stride = 1)
    out_channels = 4 * intermediate_channels

    needs_projection = stride != 1 || in_channels != out_channels
    shortcut = if needs_projection
        Chain(
            Conv((1, 1), in_channels => out_channels; stride = stride, bias = false),
            BatchNorm(out_channels),
        )
    else
        identity
    end

    first_unit = block(in_channels => intermediate_channels, shortcut, stride = stride)
    # channels of each later unit: 4*w -> w, then back to 4*w
    later_units = [block(out_channels => intermediate_channels)
                   for _ in 1:(num_residual_blocks - 1)]

    return Chain(first_unit, later_units...)
end
"""
    resnet(block, layers, image_channels, num_classes)

Assemble a bottleneck ResNet as a Flux `Chain`, in this order:

1. Stem: 7×7 stride-2 convolution to 64 channels → BatchNorm+ReLU → 3×3 stride-2 max-pool.
2. Four stages built by `make_layer`, with bottleneck widths 64, 128, 256, 512 and
   strides 1, 2, 2, 2. Stage `k` contains `layers[k]` units created with `block`.
3. Global average pooling, `flatten`, and a `Dense` layer with `num_classes` outputs.

The weight-loading code pairs this model's parameters with a PyTorch state_dict by
position, so the layer order is significant.

# Arguments
- `block`: constructor for one residual unit, forwarded to `make_layer`.
- `layers`: number of units in each of the four stages.
- `image_channels`: number of channels of the input images.
- `num_classes`: output size of the final classifier.

# Throws
- `ArgumentError` if `layers` does not have exactly four entries.
"""
function resnet(block, layers, image_channels, num_classes)
    length(layers) == 4 || throw(ArgumentError(
        "`layers` must give the number of blocks for each of the 4 stages, " *
        "got $(length(layers)) entries"))

    expansion = 4               # bottleneck units output expansion * width channels
    in_channels = 64            # channels produced by the stem
    stem = Chain(
        Conv((7, 7), image_channels => in_channels; pad = 3, stride = 2, bias = false),
        BatchNorm(in_channels, relu),
        MaxPool((3, 3); stride = 2, pad = 1),
    )

    widths = (64, 128, 256, 512)
    strides = (1, 2, 2, 2)
    # stage 1 consumes the stem output; stage k > 1 consumes expansion * widths[k-1]
    stage_inputs = (in_channels, (expansion .* widths[1:end-1])...)

    # built in stage order so parameters appear in the same order as before
    stages = map(1:4) do k
        make_layer(stage_inputs[k], block, layers[k], widths[k]; stride = strides[k])
    end

    return Chain(
        stem,
        stages...,
        AdaptiveMeanPool((1, 1)),
        flatten,
        Dense(expansion * last(widths), num_classes),
    )
end
"""
    ResNet50(; img_channels = 3, num_classes = 1000)

ResNet-50: four bottleneck stages with 3, 4, 6 and 3 units.
"""
ResNet50(; img_channels = 3, num_classes = 1000) =
    resnet(block, [3, 4, 6, 3], img_channels, num_classes)

"""
    ResNet101(; img_channels = 3, num_classes = 1000)

ResNet-101: four bottleneck stages with 3, 4, 23 and 3 units.
"""
ResNet101(; img_channels = 3, num_classes = 1000) =
    resnet(block, [3, 4, 23, 3], img_channels, num_classes)

"""
    ResNet152(; img_channels = 3, num_classes = 1000)

ResNet-152: four bottleneck stages with 3, 8, 36 and 3 units.
"""
ResNet152(; img_channels = 3, num_classes = 1000) =
    resnet(block, [3, 8, 36, 3], img_channels, num_classes)
# Python-side libraries, reached through PyCall.
torchvision = pyimport("torchvision")
torch = pyimport("torch")

# Each entry pairs a Flux constructor with the torchvision constructor of the same depth.
modellib = [
    (ResNet50, torchvision.models.resnet50),
    (ResNet101, torchvision.models.resnet101),
    (ResNet152, torchvision.models.resnet152),
]

# Convert the third entry (ResNet-152). Load the pretrained PyTorch network, and build a
# freshly initialised Flux network with the same layout to receive its weights.
jlmodel, pymodel = modellib[3]
pytorchmodel = pymodel(pretrained = true)
model = jlmodel()
# Reverse the order of all dimensions of `x`. Used to swap between Julia's column-major
# layout and the row-major layout of NumPy/PyTorch arrays.
tr(x) = permutedims(x, reverse(1:ndims(x)))
# BatchNorm: emit its four arrays in the same order as PyTorch's state_dict,
# i.e. weight, bias, running mean, running variance.
function _list_state(node::Flux.BatchNorm, channel, prefix)
    entries = (
        ".γ" => node.γ,     # scale / weight (learnable)
        ".β" => node.β,     # shift / bias (learnable)
        ".μ" => node.μ,     # running mean
        ".σ²" => node.σ²,   # running variance
    )
    for (suffix, value) in entries
        put!(channel, (prefix * suffix, value))
    end
end
# True when `bias` is a real parameter array, not the placeholder Flux stores for
# `bias = false`. Flux 0.12 uses `Flux.Zeros()`; later releases removed `Flux.Zeros` and
# store `false`. The `isdefined` guard keeps this valid on both.
_hasbias(bias) = !(bias === false || (isdefined(Flux, :Zeros) && bias isa Flux.Zeros))

# Conv / Dense: emit the weight, then the bias if the layer has one.
# This matches the weight/bias order of the corresponding PyTorch layers.
function _list_state(node::Union{Flux.Conv,Flux.Dense}, channel, prefix)
    put!(channel, (prefix * ".weight", node.weight))
    if _hasbias(node.bias)
        put!(channel, (prefix * ".bias", node.bias))
    end
end
# Fallback for layers without learnable state (pooling, `flatten`, `identity`, ...):
# nothing is emitted.
_list_state(::Any, ::Any, ::Any) = nothing
# Containers: recurse into each child, extending the key with the child's 1-based
# position, e.g. "model.layers[2].layers[1]".
function _list_state(node::Union{Flux.Chain,Flux.Parallel}, channel, prefix)
    foreach(enumerate(node.layers)) do (idx, child)
        _list_state(child, channel, prefix * ".layers[$idx]")
    end
end
"""
    list_state(node; prefix = "model")

Return a `Channel` that yields `(key, array)` tuples for every array visited by
`_list_state` under `node`, in traversal order. Keys start with `prefix`.
"""
list_state(node; prefix = "model") = Channel(ch -> _list_state(node, ch, prefix))
# Flux arrays, in model traversal order.
state = OrderedDict(list_state(model))

# Calling pytorchmodel.state_dict() directly loses the order (presumably PyCall converts
# the result into an unordered Julia Dict). So keep it as a raw PyObject and iterate
# over `.items()`.
state_dict = OrderedDict(pycall(pytorchmodel.state_dict, PyObject).items())

# Drop the BatchNorm "num_batches_tracked" counters, which have no Flux counterpart, and
# turn the remaining tensors into Julia arrays.
pytorch_pp = OrderedDict(
    key => tensor.numpy() for (key, tensor) in state_dict
    if !occursin("num_batches_tracked", key)
)
"""
    _copy_pytorch_param!(dest, src)

Copy the PyTorch-derived array `src` into the Flux array `dest` in place.

- 4-d arrays (convolution weights): PyTorch stores `(out, in, kh, kw)` and Flux stores
  `(kw, kh, in, out)`. All dimensions are reversed, then both spatial axes are flipped,
  because Flux's `Conv` flips the kernel while PyTorch computes a cross-correlation.
- Arrays of identical size (Dense weights, bias and BatchNorm vectors) are copied as is.
- Other arrays whose size is the reverse of `src`'s are copied with dimensions reversed.

Return `true` on success and `false` when the shapes are incompatible.
"""
function _copy_pytorch_param!(dest, src)
    if ndims(dest) == 4
        # decide by rank first, so a conv kernel whose two layouts happen to have
        # equal sizes is still permuted and flipped
        size(dest) == reverse(size(src)) || return false
        dest .= reverse(permutedims(src, ndims(src):-1:1); dims = (1, 2))
    elseif size(dest) == size(src)
        dest .= src
    elseif size(dest) == reverse(size(src))
        dest .= permutedims(src, ndims(src):-1:1)
    else
        return false
    end
    return true
end

# Flux and PyTorch keys have different names, so arrays are paired purely by position.
# `zip` would silently stop at the shorter list, so check the counts first.
length(state) == length(pytorch_pp) || error(
    "parameter count mismatch: Flux model has $(length(state)) arrays, " *
    "PyTorch model has $(length(pytorch_pp))")

# loop over all parameters
for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
    if !_copy_pytorch_param!(flux_param, pytorch_param)
        error("incompatible shape $flux_key $pytorch_key: " *
              "$(size(flux_param)) vs $(size(pytorch_param))")
    end
end
"""
    normalize(data; cmean = Float32[0.485, 0.456, 0.406], cstd = Float32[0.229, 0.224, 0.225])

Standardise `data` channel by channel: `(data .- cmean) ./ cstd`.

The statistics are laid out along dimension 3, the channel axis of a `(W, H, C, N)` batch.
The defaults are the values previously hard-coded here, which are the ImageNet
statistics torchvision's pretrained models expect. Statistics are converted to `Float32`.

# Throws
- `DimensionMismatch` if `cmean` and `cstd` have different lengths.
"""
function normalize(data;
                   cmean = Float32[0.485, 0.456, 0.406],
                   cstd = Float32[0.229, 0.224, 0.225])
    length(cmean) == length(cstd) || throw(DimensionMismatch(
        "cmean has $(length(cmean)) entries but cstd has $(length(cstd))"))
    # shape (1, 1, C, 1) so the statistics broadcast along the channel dimension
    μ = reshape(collect(Float32, cmean), 1, 1, :, 1)
    σ = reshape(collect(Float32, cstd), 1, 1, :, 1)
    return (data .- μ) ./ σ
end
# Put normalisation layers in inference mode before saving and evaluating the model.
Flux.testmode!(model)

@info "saving model"
# file name derived from the constructor's name, e.g. "resnet152.bson"
let bson_file = string(lowercase(String(Symbol(jlmodel))), ".bson")
    BSON.@save bson_file model
end
# Test image, downloaded to a temporary file and resized to the 224×224 network input.
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")
sz = (224, 224)
img = imresize(Images.load(guitar_path), sz);

# channelview gives (C, H, W); reverse it to Flux's (W, H, C) layout, append a singleton
# batch dimension, and apply the channel-wise normalisation.
data = permutedims(convert(Array{Float32}, channelview(img)), (3, 2, 1))
data = normalize(data[:, :, :, 1:1])

# class probabilities for the single image of the batch
out = softmax(model(data))[:, 1];
# ImageNet class names, one per line, in output-index order.
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))

# the five most probable classes according to the Flux model
println("Flux:")
foreach(sortperm(out, rev = true)[1:5]) do i
    println("$(labels[i]): $(out[i])")
end
# Run the reference PyTorch model on the same preprocessed batch.
pytorchmodel.eval()
# `tr` reverses the Julia (W, H, C, N) batch into PyTorch's (N, C, H, W) order
output = pytorchmodel(torch.Tensor(tr(data)));
# NOTE(review): `output[0]` is expected to select the first (only) image of the batch;
# confirm PyCall's indexing convention if upgrading PyCall.
probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();

println("PyTorch:")
for i in sortperm(probabilities[:,1], rev=true)[1:5]
    println("$(labels[i]): $(probabilities[i])")
end

# The top probability may differ by float round-off; the predicted class index is an
# integer and must match exactly.
@test maximum(out) ≈ maximum(probabilities)
@test argmax(out) == argmax(probabilities)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment