Created June 8, 2022 22:37
-
-
Save darsnack/bfb8594cf5fdc702bdacb66586f518ef to your computer and use it in GitHub Desktop.
torchvision to Metalhead port scripts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare a Flux model from Metalhead with the corresponding PyTorch
# (torchvision) model on a sample image.
# PyTorch and torchvision need to be installed for the Python used by PyCall.
# Tested on ResNet and VGG models.
using Flux | |
import Metalhead | |
using DataStructures | |
using Statistics | |
using BSON | |
using PyCall | |
using Images | |
using Test | |
using MLUtils | |
using Random | |
# Python modules reached through PyCall.
torchvision = pyimport("torchvision")
torch = pyimport("torch")

# Model table, one entry per architecture/depth:
#   (name, thunk that builds the Metalhead model, torchvision model constructor)
# `name` is also the stem of the "<name>.bson" weights file read by the
# comparison loop.
modellib = vcat(
    [(
        "vgg$d",
        () -> Metalhead.VGG(d),
        getproperty(torchvision.models, Symbol("vgg", d)),
    ) for d in (11, 13, 16, 19)],
    [(
        "resnet$d",
        () -> Metalhead.ResNet(d),
        getproperty(torchvision.models, Symbol("resnet", d)),
    ) for d in (18, 34, 50, 101, 152)],
)
"""
    tr(x)

Return a copy of the array `x` with the order of all its dimensions reversed,
e.g. a W×H×C×N array becomes N×C×H×W. This script uses it to turn the Flux
input batch into the layout handed to `torch.Tensor`.
"""
tr(x) = permutedims(x, reverse(1:ndims(x)))
"""
    normalize(data; cmean = Float32[0.485, 0.456, 0.406],
                    cstd = Float32[0.229, 0.224, 0.225])

Standardize `data` channel-wise and return `(data .- cmean) ./ cstd`, where the
channel axis is dimension 3 of `data` (e.g. a W×H×C×N image batch).

The defaults are the RGB mean/std that torchvision's ImageNet-pretrained models
expect. You can pass other per-channel vectors (their length must match the
number of channels in dimension 3) to reuse this with different preprocessing.
"""
function normalize(data;
                   cmean = Float32[0.485, 0.456, 0.406],
                   cstd = Float32[0.229, 0.224, 0.225])
    # Shape the statistics as 1×1×C×1 so they broadcast along dimension 3 only;
    # `:` lets the channel count follow the length of the given vectors.
    μ = reshape(cmean, 1, 1, :, 1)
    σ = reshape(cstd, 1, 1, :, 1)
    return (data .- μ) ./ σ
end
# Sample photo (a guitar) fed to every model. `download` saves it to a
# temporary file and returns that file's path.
guitar_path = download(
    "https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg",
)

# ImageNet class names, one per line. The comparison loop looks up `labels[i]`
# for output index `i` of the models' prediction vectors.
labels =
    download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt") |>
    readlines
"""
    _print_top5(header, probs, labels)

Print `header`, then the five entries of `probs` with the largest values, in
descending order, each shown as `labels[i]: probs[i]`.
"""
function _print_top5(header, probs, labels)
    println(" $(header):")
    for i in partialsortperm(probs, 1:5; rev = true)
        println(" $(labels[i]): $(probs[i])")
    end
    return nothing
end

# Run every (Metalhead, torchvision) pair in `modellib` on the same image and
# check that both agree on the top-1 class and its probability.
# A `let` keeps the temporaries out of global scope.
let
    # The input does not depend on the model, so build it once.
    # channelview gives C×H×W; permute to W×H×C, then add a trailing batch
    # dimension of size 1 and apply the channel normalization.
    img = imresize(Images.load(guitar_path), (224, 224))
    whc = permutedims(convert(Array{Float32}, channelview(img)), (3, 2, 1))
    data = normalize(whc[:, :, :, 1:1])
    # The same values with every dimension reversed (N×C×H×W) for PyTorch.
    torch_input = torch.Tensor(tr(data))

    # Nested testsets: a failing model is recorded and reported in the final
    # summary instead of aborting the remaining comparisons (a failing @test
    # outside any testset throws immediately).
    @testset "Metalhead vs torchvision" begin
        for (modelname, jlmodel, pymodel) in modellib
            @testset "$modelname" begin
                println(modelname)

                # Metalhead model with weights loaded from "<modelname>.bson"
                # next to this script, switched to inference behaviour.
                model = jlmodel()
                saved_model = BSON.load(joinpath(@__DIR__, "$(modelname).bson"))
                Flux.loadmodel!(model, saved_model[:model])
                Flux.testmode!(model)
                # softmax over the class dimension; column 1 is the only sample.
                out = softmax(model(data))[:, 1]
                _print_top5("Flux", out, labels)

                # Reference torchvision model in eval mode.
                pytorchmodel = pymodel(pretrained = true)
                pytorchmodel.eval()
                logits = pytorchmodel(torch_input)
                # Softmax over dim 1 (classes) of the (batch, classes) logits;
                # `vec` drops the size-1 batch dimension of the converted array.
                probabilities =
                    vec(torch.nn.functional.softmax(logits, dim = 1).detach().numpy())
                _print_top5("PyTorch", probabilities, labels)

                # Same top-1 confidence (default isapprox tolerance) and the
                # same predicted class index.
                @test maximum(out) ≈ maximum(probabilities)
                @test argmax(out) == argmax(probabilities)
                println()
            end
        end
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file is machine-generated - editing it directly is not advised | |
julia_version = "1.7.2" | |
manifest_format = "2.0" | |
[[deps.AbstractFFTs]] | |
deps = ["ChainRulesCore", "LinearAlgebra"] | |
git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" | |
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" | |
version = "1.1.0" | |
[[deps.Accessors]] | |
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "Future", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"] | |
git-tree-sha1 = "2a1a240c2656a3198859e0f8f181ae61f294a3fb" | |
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | |
version = "0.1.13" | |
[[deps.Adapt]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" | |
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | |
version = "3.3.3" | |
[[deps.ArgCheck]] | |
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" | |
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" | |
version = "2.3.0" | |
[[deps.ArgTools]] | |
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" | |
[[deps.ArnoldiMethod]] | |
deps = ["LinearAlgebra", "Random", "StaticArrays"] | |
git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" | |
uuid = "ec485272-7323-5ecc-a04f-4719b315124d" | |
version = "0.2.0" | |
[[deps.ArrayInterface]] | |
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] | |
git-tree-sha1 = "1ee88c4c76caa995a885dc2f22a5d548dfbbc0ba" | |
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | |
version = "3.2.2" | |
[[deps.Artifacts]] | |
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | |
[[deps.AxisAlgorithms]] | |
deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] | |
git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7" | |
uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" | |
version = "1.0.1" | |
[[deps.AxisArrays]] | |
deps = ["Dates", "IntervalSets", "IterTools", "RangeArrays"] | |
git-tree-sha1 = "1dd4d9f5beebac0c03446918741b1a03dc5e5788" | |
uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" | |
version = "0.4.6" | |
[[deps.BFloat16s]] | |
deps = ["LinearAlgebra", "Printf", "Random", "Test"] | |
git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" | |
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" | |
version = "0.2.0" | |
[[deps.BSON]] | |
git-tree-sha1 = "306bb5574b0c1c56d7e1207581516c557d105cad" | |
uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" | |
version = "0.3.5" | |
[[deps.BangBang]] | |
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] | |
git-tree-sha1 = "b15a6bc52594f5e4a3b825858d1089618871bf9d" | |
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" | |
version = "0.3.36" | |
[[deps.Base64]] | |
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" | |
[[deps.Baselet]] | |
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" | |
uuid = "9718e550-a3fa-408a-8086-8db961cd8217" | |
version = "0.1.1" | |
[[deps.CEnum]] | |
git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" | |
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" | |
version = "0.4.2" | |
[[deps.CUDA]] | |
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] | |
git-tree-sha1 = "925a16b909fdae16920c1319feadecffb6695b9d" | |
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" | |
version = "3.10.1" | |
[[deps.Calculus]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" | |
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" | |
version = "0.5.1" | |
[[deps.CatIndices]] | |
deps = ["CustomUnitRanges", "OffsetArrays"] | |
git-tree-sha1 = "a0f80a09780eed9b1d106a1bf62041c2efc995bc" | |
uuid = "aafaddc9-749c-510e-ac4f-586e18779b91" | |
version = "0.2.2" | |
[[deps.ChainRules]] | |
deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] | |
git-tree-sha1 = "e9023f88b1655ffc6a4aaef2502878e8116151ef" | |
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" | |
version = "1.35.1" | |
[[deps.ChainRulesCore]] | |
deps = ["Compat", "LinearAlgebra", "SparseArrays"] | |
git-tree-sha1 = "9489214b993cd42d17f44c36e359bf6a7c919abf" | |
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | |
version = "1.15.0" | |
[[deps.ChangesOfVariables]] | |
deps = ["ChainRulesCore", "LinearAlgebra", "Test"] | |
git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" | |
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" | |
version = "0.1.3" | |
[[deps.Clustering]] | |
deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "SparseArrays", "Statistics", "StatsBase"] | |
git-tree-sha1 = "75479b7df4167267d75294d14b58244695beb2ac" | |
uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" | |
version = "0.14.2" | |
[[deps.ColorTypes]] | |
deps = ["FixedPointNumbers", "Random"] | |
git-tree-sha1 = "0f4e115f6f34bbe43c19751c90a38b2f380637b9" | |
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" | |
version = "0.11.3" | |
[[deps.ColorVectorSpace]] | |
deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] | |
git-tree-sha1 = "d08c20eef1f2cbc6e60fd3612ac4340b89fea322" | |
uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" | |
version = "0.9.9" | |
[[deps.Colors]] | |
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] | |
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" | |
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" | |
version = "0.12.8" | |
[[deps.CommonSubexpressions]] | |
deps = ["MacroTools", "Test"] | |
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" | |
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" | |
version = "0.3.0" | |
[[deps.Compat]] | |
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] | |
git-tree-sha1 = "9be8be1d8a6f44b96482c8af52238ea7987da3e3" | |
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" | |
version = "3.45.0" | |
[[deps.CompilerSupportLibraries_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" | |
[[deps.CompositionsBase]] | |
git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" | |
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" | |
version = "0.1.1" | |
[[deps.ComputationalResources]] | |
git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" | |
uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" | |
version = "0.3.2" | |
[[deps.Conda]] | |
deps = ["Downloads", "JSON", "VersionParsing"] | |
git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071" | |
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" | |
version = "1.7.0" | |
[[deps.ConstructionBase]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" | |
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" | |
version = "1.3.0" | |
[[deps.ContextVariablesX]] | |
deps = ["Compat", "Logging", "UUIDs"] | |
git-tree-sha1 = "8ccaa8c655bc1b83d2da4d569c9b28254ababd6e" | |
uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" | |
version = "0.1.2" | |
[[deps.CoordinateTransformations]] | |
deps = ["LinearAlgebra", "StaticArrays"] | |
git-tree-sha1 = "681ea870b918e7cff7111da58791d7f718067a19" | |
uuid = "150eb455-5306-5404-9cee-2592286d6298" | |
version = "0.6.2" | |
[[deps.CustomUnitRanges]] | |
git-tree-sha1 = "1a3f97f907e6dd8983b744d2642651bb162a3f7a" | |
uuid = "dc8bdbbb-1ca9-579f-8c36-e416f6a65cce" | |
version = "1.0.2" | |
[[deps.DataAPI]] | |
git-tree-sha1 = "fb5f5316dd3fd4c5e7c30a24d50643b73e37cd40" | |
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" | |
version = "1.10.0" | |
[[deps.DataStructures]] | |
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] | |
git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" | |
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | |
version = "0.18.13" | |
[[deps.DataValueInterfaces]] | |
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" | |
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" | |
version = "1.0.0" | |
[[deps.Dates]] | |
deps = ["Printf"] | |
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" | |
[[deps.DefineSingletons]] | |
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" | |
uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" | |
version = "0.1.2" | |
[[deps.DelimitedFiles]] | |
deps = ["Mmap"] | |
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" | |
[[deps.DiffResults]] | |
deps = ["StaticArrays"] | |
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" | |
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" | |
version = "1.0.3" | |
[[deps.DiffRules]] | |
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] | |
git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120" | |
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" | |
version = "1.11.0" | |
[[deps.Distances]] | |
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] | |
git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" | |
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" | |
version = "0.10.7" | |
[[deps.Distributed]] | |
deps = ["Random", "Serialization", "Sockets"] | |
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" | |
[[deps.DocStringExtensions]] | |
deps = ["LibGit2"] | |
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" | |
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | |
version = "0.8.6" | |
[[deps.Downloads]] | |
deps = ["ArgTools", "LibCURL", "NetworkOptions"] | |
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" | |
[[deps.DualNumbers]] | |
deps = ["Calculus", "NaNMath", "SpecialFunctions"] | |
git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" | |
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" | |
version = "0.6.8" | |
[[deps.ExprTools]] | |
git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" | |
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" | |
version = "0.1.8" | |
[[deps.FFTViews]] | |
deps = ["CustomUnitRanges", "FFTW"] | |
git-tree-sha1 = "cbdf14d1e8c7c8aacbe8b19862e0179fd08321c2" | |
uuid = "4f61f5a4-77b1-5117-aa51-3ab5ef4ef0cd" | |
version = "0.3.2" | |
[[deps.FFTW]] | |
deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] | |
git-tree-sha1 = "505876577b5481e50d089c1c68899dfb6faebc62" | |
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" | |
version = "1.4.6" | |
[[deps.FFTW_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" | |
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" | |
version = "3.3.10+0" | |
[[deps.FLoops]] | |
deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] | |
git-tree-sha1 = "4391d3ed58db9dc5a9883b23a0578316b4798b1f" | |
uuid = "cc61a311-1640-44b5-9fba-1b764f453329" | |
version = "0.2.0" | |
[[deps.FLoopsBase]] | |
deps = ["ContextVariablesX"] | |
git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" | |
uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" | |
version = "0.1.1" | |
[[deps.FileIO]] | |
deps = ["Pkg", "Requires", "UUIDs"] | |
git-tree-sha1 = "9267e5f50b0e12fdfd5a2455534345c4cf2c7f7a" | |
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" | |
version = "1.14.0" | |
[[deps.FillArrays]] | |
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] | |
git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" | |
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" | |
version = "0.13.2" | |
[[deps.FixedPointNumbers]] | |
deps = ["Statistics"] | |
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" | |
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" | |
version = "0.8.4" | |
[[deps.Flux]] | |
deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] | |
git-tree-sha1 = "62350a872545e1369b1d8f11358a21681aa73929" | |
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" | |
version = "0.13.3" | |
[[deps.FoldsThreads]] | |
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] | |
git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" | |
uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" | |
version = "0.1.1" | |
[[deps.ForwardDiff]] | |
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] | |
git-tree-sha1 = "2f18915445b248731ec5db4e4a17e451020bf21e" | |
uuid = "f6369f11-7733-5829-9624-2563aa707210" | |
version = "0.10.30" | |
[[deps.FunctionWrappers]] | |
git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" | |
uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" | |
version = "1.1.2" | |
[[deps.Functors]] | |
git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4" | |
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | |
version = "0.2.8" | |
[[deps.Future]] | |
deps = ["Random"] | |
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" | |
[[deps.GPUArrays]] | |
deps = ["Adapt", "LLVM", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] | |
git-tree-sha1 = "c783e8883028bf26fb05ed4022c450ef44edd875" | |
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" | |
version = "8.3.2" | |
[[deps.GPUCompiler]] | |
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] | |
git-tree-sha1 = "d8c5999631e1dc18d767883f621639c838f8e632" | |
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" | |
version = "0.15.2" | |
[[deps.Ghostscript_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "78e2c69783c9753a91cdae88a8d432be85a2ab5e" | |
uuid = "61579ee1-b43e-5ca0-a5da-69d92c66a64b" | |
version = "9.55.0+0" | |
[[deps.Graphics]] | |
deps = ["Colors", "LinearAlgebra", "NaNMath"] | |
git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd" | |
uuid = "a2bd30eb-e257-5431-a919-1863eab51364" | |
version = "1.1.2" | |
[[deps.Graphs]] | |
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] | |
git-tree-sha1 = "4888af84657011a65afc7a564918d281612f983a" | |
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" | |
version = "1.7.0" | |
[[deps.IRTools]] | |
deps = ["InteractiveUtils", "MacroTools", "Test"] | |
git-tree-sha1 = "af14a478780ca78d5eb9908b263023096c2b9d64" | |
uuid = "7869d1d1-7146-5819-86e3-90919afe41df" | |
version = "0.4.6" | |
[[deps.IfElse]] | |
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" | |
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" | |
version = "0.1.1" | |
[[deps.ImageAxes]] | |
deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] | |
git-tree-sha1 = "c54b581a83008dc7f292e205f4c409ab5caa0f04" | |
uuid = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac" | |
version = "0.6.10" | |
[[deps.ImageBase]] | |
deps = ["ImageCore", "Reexport"] | |
git-tree-sha1 = "b51bb8cae22c66d0f6357e3bcb6363145ef20835" | |
uuid = "c817782e-172a-44cc-b673-b171935fbb9e" | |
version = "0.1.5" | |
[[deps.ImageContrastAdjustment]] | |
deps = ["ImageCore", "ImageTransformations", "Parameters"] | |
git-tree-sha1 = "0d75cafa80cf22026cea21a8e6cf965295003edc" | |
uuid = "f332f351-ec65-5f6a-b3d1-319c6670881a" | |
version = "0.3.10" | |
[[deps.ImageCore]] | |
deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] | |
git-tree-sha1 = "9a5c62f231e5bba35695a20988fc7cd6de7eeb5a" | |
uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" | |
version = "0.9.3" | |
[[deps.ImageDistances]] | |
deps = ["Distances", "ImageCore", "ImageMorphology", "LinearAlgebra", "Statistics"] | |
git-tree-sha1 = "7a20463713d239a19cbad3f6991e404aca876bda" | |
uuid = "51556ac3-7006-55f5-8cb3-34580c88182d" | |
version = "0.2.15" | |
[[deps.ImageFiltering]] | |
deps = ["CatIndices", "ComputationalResources", "DataStructures", "FFTViews", "FFTW", "ImageBase", "ImageCore", "LinearAlgebra", "OffsetArrays", "Reexport", "SparseArrays", "StaticArrays", "Statistics", "TiledIteration"] | |
git-tree-sha1 = "15bd05c1c0d5dbb32a9a3d7e0ad2d50dd6167189" | |
uuid = "6a3955dd-da59-5b1f-98d4-e7296123deb5" | |
version = "0.7.1" | |
[[deps.ImageIO]] | |
deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] | |
git-tree-sha1 = "d9a03ffc2f6650bd4c831b285637929d99a4efb5" | |
uuid = "82e4d734-157c-48bb-816b-45c225c6df19" | |
version = "0.6.5" | |
[[deps.ImageMagick]] | |
deps = ["FileIO", "ImageCore", "ImageMagick_jll", "InteractiveUtils", "Libdl", "Pkg", "Random"] | |
git-tree-sha1 = "5bc1cb62e0c5f1005868358db0692c994c3a13c6" | |
uuid = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" | |
version = "1.2.1" | |
[[deps.ImageMagick_jll]] | |
deps = ["Artifacts", "Ghostscript_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pkg", "Zlib_jll", "libpng_jll"] | |
git-tree-sha1 = "f025b79883f361fa1bd80ad132773161d231fd9f" | |
uuid = "c73af94c-d91f-53ed-93a7-00f77d67a9d7" | |
version = "6.9.12+2" | |
[[deps.ImageMetadata]] | |
deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] | |
git-tree-sha1 = "36cbaebed194b292590cba2593da27b34763804a" | |
uuid = "bc367c6b-8a6b-528e-b4bd-a4b897500b49" | |
version = "0.9.8" | |
[[deps.ImageMorphology]] | |
deps = ["ImageCore", "LinearAlgebra", "Requires", "TiledIteration"] | |
git-tree-sha1 = "e7c68ab3df4a75511ba33fc5d8d9098007b579a8" | |
uuid = "787d08f9-d448-5407-9aad-5290dd7ab264" | |
version = "0.3.2" | |
[[deps.ImageQualityIndexes]] | |
deps = ["ImageContrastAdjustment", "ImageCore", "ImageDistances", "ImageFiltering", "OffsetArrays", "Statistics"] | |
git-tree-sha1 = "1d2d73b14198d10f7f12bf7f8481fd4b3ff5cd61" | |
uuid = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9" | |
version = "0.3.0" | |
[[deps.ImageSegmentation]] | |
deps = ["Clustering", "DataStructures", "Distances", "Graphs", "ImageCore", "ImageFiltering", "ImageMorphology", "LinearAlgebra", "MetaGraphs", "RegionTrees", "SimpleWeightedGraphs", "StaticArrays", "Statistics"] | |
git-tree-sha1 = "36832067ea220818d105d718527d6ed02385bf22" | |
uuid = "80713f31-8817-5129-9cf8-209ff8fb23e1" | |
version = "1.7.0" | |
[[deps.ImageShow]] | |
deps = ["Base64", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] | |
git-tree-sha1 = "b563cf9ae75a635592fc73d3eb78b86220e55bd8" | |
uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" | |
version = "0.3.6" | |
[[deps.ImageTransformations]] | |
deps = ["AxisAlgorithms", "ColorVectorSpace", "CoordinateTransformations", "ImageBase", "ImageCore", "Interpolations", "OffsetArrays", "Rotations", "StaticArrays"] | |
git-tree-sha1 = "42fe8de1fe1f80dab37a39d391b6301f7aeaa7b8" | |
uuid = "02fcd773-0e25-5acc-982a-7f6622650795" | |
version = "0.9.4" | |
[[deps.Images]] | |
deps = ["Base64", "FileIO", "Graphics", "ImageAxes", "ImageBase", "ImageContrastAdjustment", "ImageCore", "ImageDistances", "ImageFiltering", "ImageIO", "ImageMagick", "ImageMetadata", "ImageMorphology", "ImageQualityIndexes", "ImageSegmentation", "ImageShow", "ImageTransformations", "IndirectArrays", "IntegralArrays", "Random", "Reexport", "SparseArrays", "StaticArrays", "Statistics", "StatsBase", "TiledIteration"] | |
git-tree-sha1 = "03d1301b7ec885b266c0f816f338368c6c0b81bd" | |
uuid = "916415d5-f1e6-5110-898d-aaa5f9f070e0" | |
version = "0.25.2" | |
[[deps.Imath_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "87f7662e03a649cffa2e05bf19c303e168732d3e" | |
uuid = "905a6f67-0a94-5f89-b386-d35d92009cd1" | |
version = "3.1.2+0" | |
[[deps.IndirectArrays]] | |
git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" | |
uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" | |
version = "1.0.0" | |
[[deps.Inflate]] | |
git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c" | |
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" | |
version = "0.1.2" | |
[[deps.InitialValues]] | |
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" | |
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" | |
version = "0.3.1" | |
[[deps.IntegralArrays]] | |
deps = ["ColorTypes", "FixedPointNumbers", "IntervalSets"] | |
git-tree-sha1 = "be8e690c3973443bec584db3346ddc904d4884eb" | |
uuid = "1d092043-8f09-5a30-832f-7509e371ab51" | |
version = "0.1.5" | |
[[deps.IntelOpenMP_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c" | |
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" | |
version = "2018.0.3+2" | |
[[deps.InteractiveUtils]] | |
deps = ["Markdown"] | |
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | |
[[deps.Interpolations]] | |
deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] | |
git-tree-sha1 = "b7bc05649af456efc75d178846f47006c2c4c3c7" | |
uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" | |
version = "0.13.6" | |
[[deps.IntervalSets]] | |
deps = ["Dates", "Random", "Statistics"] | |
git-tree-sha1 = "57af5939800bce15980bddd2426912c4f83012d8" | |
uuid = "8197267c-284f-5f27-9208-e0e47529a953" | |
version = "0.7.1" | |
[[deps.InverseFunctions]] | |
deps = ["Test"] | |
git-tree-sha1 = "b3364212fb5d870f724876ffcd34dd8ec6d98918" | |
uuid = "3587e190-3f89-42d0-90ee-14403ec27112" | |
version = "0.1.7" | |
[[deps.IrrationalConstants]] | |
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" | |
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" | |
version = "0.1.1" | |
[[deps.IterTools]] | |
git-tree-sha1 = "fa6287a4469f5e048d763df38279ee729fbd44e5" | |
uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" | |
version = "1.4.0" | |
[[deps.IteratorInterfaceExtensions]] | |
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" | |
uuid = "82899510-4779-5014-852e-03e436cf321d" | |
version = "1.0.0" | |
[[deps.JLD2]] | |
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "TranscodingStreams", "UUIDs"] | |
git-tree-sha1 = "81b9477b49402b47fbe7f7ae0b252077f53e4a08" | |
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" | |
version = "0.4.22" | |
[[deps.JLLWrappers]] | |
deps = ["Preferences"] | |
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" | |
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" | |
version = "1.4.1" | |
[[deps.JSON]] | |
deps = ["Dates", "Mmap", "Parsers", "Unicode"] | |
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" | |
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" | |
version = "0.21.3" | |
[[deps.JpegTurbo]] | |
deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] | |
git-tree-sha1 = "a77b273f1ddec645d1b7c4fd5fb98c8f90ad10a5" | |
uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" | |
version = "0.1.1" | |
[[deps.JpegTurbo_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba" | |
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" | |
version = "2.1.2+0" | |
[[deps.JuliaVariables]] | |
deps = ["MLStyle", "NameResolution"] | |
git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" | |
uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" | |
version = "0.2.4" | |
[[deps.LERC_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" | |
uuid = "88015f11-f218-50d7-93a8-a6af411a945d" | |
version = "3.0.0+1" | |
[[deps.LLVM]] | |
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] | |
git-tree-sha1 = "e7e9184b0bf0158ac4e4aa9daf00041b5909bf1a" | |
uuid = "929cbde3-209d-540e-8aea-75f648917ca0" | |
version = "4.14.0" | |
[[deps.LLVMExtra_jll]] | |
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] | |
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" | |
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" | |
version = "0.0.16+0" | |
[[deps.LazyArtifacts]] | |
deps = ["Artifacts", "Pkg"] | |
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" | |
[[deps.LazyModules]] | |
git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" | |
uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" | |
version = "0.3.1" | |
[[deps.LibCURL]] | |
deps = ["LibCURL_jll", "MozillaCACerts_jll"] | |
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" | |
[[deps.LibCURL_jll]] | |
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] | |
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" | |
[[deps.LibGit2]] | |
deps = ["Base64", "NetworkOptions", "Printf", "SHA"] | |
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" | |
[[deps.LibSSH2_jll]] | |
deps = ["Artifacts", "Libdl", "MbedTLS_jll"] | |
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" | |
[[deps.Libdl]] | |
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | |
[[deps.Libtiff_jll]] | |
deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "Pkg", "Zlib_jll", "Zstd_jll"] | |
git-tree-sha1 = "3eb79b0ca5764d4799c06699573fd8f533259713" | |
uuid = "89763e89-9b03-5906-acba-b20f662cd828" | |
version = "4.4.0+0" | |
[[deps.LinearAlgebra]] | |
deps = ["Libdl", "libblastrampoline_jll"] | |
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | |
[[deps.LogExpFunctions]] | |
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] | |
git-tree-sha1 = "09e4b894ce6a976c354a69041a04748180d43637" | |
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" | |
version = "0.3.15" | |
[[deps.Logging]] | |
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" | |
[[deps.MKL_jll]] | |
deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] | |
git-tree-sha1 = "e595b205efd49508358f7dc670a940c790204629" | |
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" | |
version = "2022.0.0+0" | |
[[deps.MLStyle]] | |
git-tree-sha1 = "2041c1fd6833b3720d363c3ea8140bffaf86d9c4" | |
uuid = "d8e11817-5142-5d16-987a-aa16d5891078" | |
version = "0.4.12" | |
[[deps.MLUtils]] | |
deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] | |
git-tree-sha1 = "95ab49a8c9afb6a8a0fc81df25617a6798c0fb73" | |
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" | |
version = "0.2.5" | |
[[deps.MacroTools]] | |
deps = ["Markdown", "Random"] | |
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" | |
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" | |
version = "0.5.9" | |
[[deps.MappedArrays]] | |
git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" | |
uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" | |
version = "0.4.1" | |
[[deps.Markdown]] | |
deps = ["Base64"] | |
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" | |
[[deps.MbedTLS_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" | |
[[deps.MetaGraphs]] | |
deps = ["Graphs", "JLD2", "Random"] | |
git-tree-sha1 = "2af69ff3c024d13bde52b34a2a7d6887d4e7b438" | |
uuid = "626554b9-1ddb-594c-aa3c-2596fe9399a5" | |
version = "0.7.1" | |
[[deps.Metalhead]] | |
deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "NeuralAttentionlib", "Statistics"] | |
git-tree-sha1 = "5587e8dd00fc53bf11b3947f00d506df1eb76712" | |
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | |
version = "0.7.1" | |
[[deps.MicroCollections]] | |
deps = ["BangBang", "InitialValues", "Setfield"] | |
git-tree-sha1 = "6bb7786e4f24d44b4e29df03c69add1b63d88f01" | |
uuid = "128add7d-3638-4c79-886c-908ea0c25c34" | |
version = "0.1.2" | |
[[deps.Missings]] | |
deps = ["DataAPI"] | |
git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" | |
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" | |
version = "1.0.2" | |
[[deps.Mmap]] | |
uuid = "a63ad114-7e13-5084-954f-fe012c677804" | |
[[deps.MosaicViews]] | |
deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] | |
git-tree-sha1 = "b34e3bc3ca7c94914418637cb10cc4d1d80d877d" | |
uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" | |
version = "0.3.3" | |
[[deps.MozillaCACerts_jll]] | |
uuid = "14a3606d-f60d-562e-9121-12d972cd8159" | |
[[deps.NNlib]] | |
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] | |
git-tree-sha1 = "f89de462a7bc3243f95834e75751d70b3a33e59d" | |
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | |
version = "0.8.5" | |
[[deps.NNlibCUDA]] | |
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] | |
git-tree-sha1 = "e161b835c6aa9e2339c1e72c3d4e39891eac7a4f" | |
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" | |
version = "0.2.3" | |
[[deps.NaNMath]] | |
git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5" | |
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" | |
version = "1.0.0" | |
[[deps.NameResolution]] | |
deps = ["PrettyPrint"] | |
git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" | |
uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" | |
version = "0.1.5" | |
[[deps.NearestNeighbors]] | |
deps = ["Distances", "StaticArrays"] | |
git-tree-sha1 = "ded92de95031d4a8c61dfb6ba9adb6f1d8016ddd" | |
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" | |
version = "0.4.10" | |
[[deps.Netpbm]] | |
deps = ["FileIO", "ImageCore"] | |
git-tree-sha1 = "18efc06f6ec36a8b801b23f076e3c6ac7c3bf153" | |
uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" | |
version = "1.0.2" | |
[[deps.NetworkOptions]] | |
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" | |
[[deps.NeuralAttentionlib]] | |
deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "NNlib", "PartialFunctions", "Requires", "Static"] | |
git-tree-sha1 = "40be841d5510c3cee5e07fd2e0c5391f6a8d94c1" | |
uuid = "12afc1b8-fad6-47e1-9132-84abc478905f" | |
version = "0.0.4" | |
[[deps.OffsetArrays]] | |
deps = ["Adapt"] | |
git-tree-sha1 = "b4975062de00106132d0b01b5962c09f7db7d880" | |
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" | |
version = "1.12.5" | |
[[deps.OpenBLAS_jll]] | |
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] | |
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" | |
[[deps.OpenEXR]] | |
deps = ["Colors", "FileIO", "OpenEXR_jll"] | |
git-tree-sha1 = "327f53360fdb54df7ecd01e96ef1983536d1e633" | |
uuid = "52e1d378-f018-4a11-a4be-720524705ac7" | |
version = "0.3.2" | |
[[deps.OpenEXR_jll]] | |
deps = ["Artifacts", "Imath_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] | |
git-tree-sha1 = "923319661e9a22712f24596ce81c54fc0366f304" | |
uuid = "18a262bb-aa17-5467-a713-aee519bc75cb" | |
version = "3.1.1+0" | |
[[deps.OpenLibm_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "05823500-19ac-5b8b-9628-191a04bc5112" | |
[[deps.OpenSpecFun_jll]] | |
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" | |
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" | |
version = "0.5.5+0" | |
[[deps.Optimisers]] | |
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] | |
git-tree-sha1 = "013596dcee5e55eb36ff56b8d4df888df01e040d" | |
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" | |
version = "0.2.6" | |
[[deps.OrderedCollections]] | |
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" | |
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | |
version = "1.4.1" | |
[[deps.PNGFiles]] | |
deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] | |
git-tree-sha1 = "e925a64b8585aa9f4e3047b8d2cdc3f0e79fd4e4" | |
uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" | |
version = "0.3.16" | |
[[deps.PaddedViews]] | |
deps = ["OffsetArrays"] | |
git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d" | |
uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" | |
version = "0.5.11" | |
[[deps.Parameters]] | |
deps = ["OrderedCollections", "UnPack"] | |
git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" | |
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" | |
version = "0.12.3" | |
[[deps.Parsers]] | |
deps = ["Dates"] | |
git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413" | |
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" | |
version = "2.3.1" | |
[[deps.PartialFunctions]] | |
git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb" | |
uuid = "570af359-4316-4cb7-8c74-252c00c2016b" | |
version = "1.1.1" | |
[[deps.Pkg]] | |
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] | |
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | |
[[deps.PkgVersion]] | |
deps = ["Pkg"] | |
git-tree-sha1 = "a7a7e1a88853564e551e4eba8650f8c38df79b37" | |
uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" | |
version = "0.1.1" | |
[[deps.Preferences]] | |
deps = ["TOML"] | |
git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" | |
uuid = "21216c6a-2e73-6563-6e65-726566657250" | |
version = "1.3.0" | |
[[deps.PrettyPrint]] | |
git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" | |
uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" | |
version = "0.2.0" | |
[[deps.Printf]] | |
deps = ["Unicode"] | |
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" | |
[[deps.ProgressLogging]] | |
deps = ["Logging", "SHA", "UUIDs"] | |
git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" | |
uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" | |
version = "0.1.4" | |
[[deps.ProgressMeter]] | |
deps = ["Distributed", "Printf"] | |
git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" | |
uuid = "92933f4c-e287-5a05-a399-4b506db050ca" | |
version = "1.7.2" | |
[[deps.PyCall]] | |
deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] | |
git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc" | |
uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" | |
version = "1.93.1" | |
[[deps.QOI]] | |
deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] | |
git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" | |
uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" | |
version = "1.0.0" | |
[[deps.Quaternions]] | |
deps = ["DualNumbers", "LinearAlgebra", "Random"] | |
git-tree-sha1 = "b327e4db3f2202a4efafe7569fcbe409106a1f75" | |
uuid = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" | |
version = "0.5.6" | |
[[deps.REPL]] | |
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] | |
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" | |
[[deps.Random]] | |
deps = ["SHA", "Serialization"] | |
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | |
[[deps.Random123]] | |
deps = ["Random", "RandomNumbers"] | |
git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474" | |
uuid = "74087812-796a-5b5d-8853-05524746bad3" | |
version = "1.5.0" | |
[[deps.RandomNumbers]] | |
deps = ["Random", "Requires"] | |
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" | |
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" | |
version = "1.5.3" | |
[[deps.RangeArrays]] | |
git-tree-sha1 = "b9039e93773ddcfc828f12aadf7115b4b4d225f5" | |
uuid = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d" | |
version = "0.3.2" | |
[[deps.Ratios]] | |
deps = ["Requires"] | |
git-tree-sha1 = "dc84268fe0e3335a62e315a3a7cf2afa7178a734" | |
uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" | |
version = "0.4.3" | |
[[deps.RealDot]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" | |
uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" | |
version = "0.1.0" | |
[[deps.Reexport]] | |
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" | |
uuid = "189a3867-3050-52da-a836-e630ba90ab69" | |
version = "1.2.2" | |
[[deps.RegionTrees]] | |
deps = ["IterTools", "LinearAlgebra", "StaticArrays"] | |
git-tree-sha1 = "4618ed0da7a251c7f92e869ae1a19c74a7d2a7f9" | |
uuid = "dee08c22-ab7f-5625-9660-a9af2021b33f" | |
version = "0.3.2" | |
[[deps.Requires]] | |
deps = ["UUIDs"] | |
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" | |
uuid = "ae029012-a4dd-5104-9daa-d747884805df" | |
version = "1.3.0" | |
[[deps.Rotations]] | |
deps = ["LinearAlgebra", "Quaternions", "Random", "StaticArrays", "Statistics"] | |
git-tree-sha1 = "3177100077c68060d63dd71aec209373c3ec339b" | |
uuid = "6038ab10-8711-5258-84ad-4b1120ba62dc" | |
version = "1.3.1" | |
[[deps.SHA]] | |
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" | |
[[deps.Serialization]] | |
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" | |
[[deps.Setfield]] | |
deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] | |
git-tree-sha1 = "38d88503f695eb0301479bc9b0d4320b378bafe5" | |
uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" | |
version = "0.8.2" | |
[[deps.SharedArrays]] | |
deps = ["Distributed", "Mmap", "Random", "Serialization"] | |
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" | |
[[deps.ShowCases]] | |
git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" | |
uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" | |
version = "0.1.0" | |
[[deps.SimpleTraits]] | |
deps = ["InteractiveUtils", "MacroTools"] | |
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" | |
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" | |
version = "0.9.4" | |
[[deps.SimpleWeightedGraphs]] | |
deps = ["Graphs", "LinearAlgebra", "Markdown", "SparseArrays", "Test"] | |
git-tree-sha1 = "a6f404cc44d3d3b28c793ec0eb59af709d827e4e" | |
uuid = "47aef6b3-ad0c-573a-a1e2-d07658019622" | |
version = "1.2.1" | |
[[deps.Sixel]] | |
deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] | |
git-tree-sha1 = "8fb59825be681d451c246a795117f317ecbcaa28" | |
uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" | |
version = "0.1.2" | |
[[deps.Sockets]] | |
uuid = "6462fe0b-24de-5631-8697-dd941f90decc" | |
[[deps.SortingAlgorithms]] | |
deps = ["DataStructures"] | |
git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" | |
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" | |
version = "1.0.1" | |
[[deps.SparseArrays]] | |
deps = ["LinearAlgebra", "Random"] | |
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | |
[[deps.SpecialFunctions]] | |
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] | |
git-tree-sha1 = "a9e798cae4867e3a41cae2dd9eb60c047f1212db" | |
uuid = "276daf66-3868-5448-9aa4-cd146d93841b" | |
version = "2.1.6" | |
[[deps.SplittablesBase]] | |
deps = ["Setfield", "Test"] | |
git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" | |
uuid = "171d559e-b47b-412a-8079-5efa626c420e" | |
version = "0.1.14" | |
[[deps.StackViews]] | |
deps = ["OffsetArrays"] | |
git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" | |
uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" | |
version = "0.1.1" | |
[[deps.Static]] | |
deps = ["IfElse"] | |
git-tree-sha1 = "7f5a513baec6f122401abfc8e9c074fdac54f6c1" | |
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" | |
version = "0.4.1" | |
[[deps.StaticArrays]] | |
deps = ["LinearAlgebra", "Random", "Statistics"] | |
git-tree-sha1 = "383a578bdf6e6721f480e749d503ebc8405a0b22" | |
uuid = "90137ffa-7385-5640-81b9-e52037218182" | |
version = "1.4.6" | |
[[deps.Statistics]] | |
deps = ["LinearAlgebra", "SparseArrays"] | |
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | |
[[deps.StatsAPI]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "2c11d7290036fe7aac9038ff312d3b3a2a5bf89e" | |
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" | |
version = "1.4.0" | |
[[deps.StatsBase]] | |
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] | |
git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" | |
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | |
version = "0.33.16" | |
[[deps.TOML]] | |
deps = ["Dates"] | |
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" | |
[[deps.TableTraits]] | |
deps = ["IteratorInterfaceExtensions"] | |
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" | |
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" | |
version = "1.0.1" | |
[[deps.Tables]] | |
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] | |
git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1" | |
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | |
version = "1.7.0" | |
[[deps.Tar]] | |
deps = ["ArgTools", "SHA"] | |
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" | |
[[deps.TensorCore]] | |
deps = ["LinearAlgebra"] | |
git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" | |
uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" | |
version = "0.1.1" | |
[[deps.Test]] | |
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] | |
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | |
[[deps.TiffImages]] | |
deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] | |
git-tree-sha1 = "f90022b44b7bf97952756a6b6737d1a0024a3233" | |
uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" | |
version = "0.5.5" | |
[[deps.TiledIteration]] | |
deps = ["OffsetArrays"] | |
git-tree-sha1 = "5683455224ba92ef59db72d10690690f4a8dc297" | |
uuid = "06e1c1a7-607b-532d-9fad-de7d9aa2abac" | |
version = "0.3.1" | |
[[deps.TimerOutputs]] | |
deps = ["ExprTools", "Printf"] | |
git-tree-sha1 = "464d64b2510a25e6efe410e7edab14fffdc333df" | |
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" | |
version = "0.5.20" | |
[[deps.TranscodingStreams]] | |
deps = ["Random", "Test"] | |
git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" | |
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" | |
version = "0.9.6" | |
[[deps.Transducers]] | |
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] | |
git-tree-sha1 = "c76399a3bbe6f5a88faa33c8f8a65aa631d95013" | |
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" | |
version = "0.4.73" | |
[[deps.UUIDs]] | |
deps = ["Random", "SHA"] | |
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" | |
[[deps.UnPack]] | |
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" | |
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" | |
version = "1.0.2" | |
[[deps.Unicode]] | |
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" | |
[[deps.VersionParsing]] | |
git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" | |
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" | |
version = "1.3.0" | |
[[deps.WoodburyMatrices]] | |
deps = ["LinearAlgebra", "SparseArrays"] | |
git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" | |
uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" | |
version = "0.5.5" | |
[[deps.Zlib_jll]] | |
deps = ["Libdl"] | |
uuid = "83775a58-1f1d-513f-b197-d71354ab007a" | |
[[deps.Zstd_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "e45044cd873ded54b6a5bac0eb5c971392cf1927" | |
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" | |
version = "1.5.2+0" | |
[[deps.Zygote]] | |
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] | |
git-tree-sha1 = "a49267a2e5f113c7afe93843deea7461c0f6b206" | |
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" | |
version = "0.6.40" | |
[[deps.ZygoteRules]] | |
deps = ["MacroTools"] | |
git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" | |
uuid = "700de1a5-db45-46bc-99cf-38207098b444" | |
version = "0.2.2" | |
[[deps.libblastrampoline_jll]] | |
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] | |
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" | |
[[deps.libpng_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] | |
git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c" | |
uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" | |
version = "1.6.38+0" | |
[[deps.libsixel_jll]] | |
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] | |
git-tree-sha1 = "78736dab31ae7a53540a6b752efc61f77b304c5b" | |
uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" | |
version = "1.8.6+1" | |
[[deps.nghttp2_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" | |
[[deps.p7zip_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps] | |
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" | |
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | |
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | |
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" | |
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | |
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | |
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" | |
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Converts the weights of a PyTorch (torchvision) model to a Flux model from Metalhead
# PyTorch needs to be installed
# Tested on ResNet and VGG models
using Flux | |
import Metalhead | |
using DataStructures | |
using Statistics | |
using BSON | |
using PyCall | |
using PyCall: Conda | |
using Images | |
using Test | |
## | |
# Make the "pytorch" Conda channel available, install PyTorch and torchvision
# through Conda.jl (imported via PyCall), then load both Python modules.
Conda.add_channel("pytorch")
foreach(Conda.add, ("pytorch", "torchvision"))
torch = pyimport("torch")
torchvision = pyimport("torchvision")
## | |
# Models to convert, as (name, Metalhead constructor, torchvision constructor).
# Each torchvision constructor is looked up under the same name as its entry,
# e.g. "resnet50" -> torchvision.models.resnet50.
modellib = vcat(
    [("vgg$depth", () -> Metalhead.VGG(depth),
      getproperty(torchvision.models, Symbol("vgg", depth)))
     for depth in (11, 13, 16, 19)],
    [("resnet$depth", () -> Metalhead.ResNet(depth),
      getproperty(torchvision.models, Symbol("resnet", depth)))
     for depth in (18, 34, 50, 101, 152)],
)
## | |
"""
    _list_state(node::Flux.BatchNorm, channel, prefix)

Put the BatchNorm arrays on `channel` as `(name, array)` tuples, in the order
PyTorch's `state_dict` lists them: weight (γ, learnable), bias (β, learnable),
running mean (μ) and running variance (σ²).

Fields that are `nothing` are skipped. Flux leaves γ/β as `nothing` for
`affine=false` and μ/σ² as `nothing` for `track_stats=false`. PyTorch omits the
corresponding entries from its `state_dict` in those cases too, so skipping them
keeps the two positional streams aligned instead of emitting a `nothing` that
would later fail on `size`.
"""
function _list_state(node::Flux.BatchNorm, channel, prefix)
    entries = ((".γ", node.γ), (".β", node.β), (".μ", node.μ), (".σ²", node.σ²))
    for (suffix, value) in entries
        isnothing(value) || put!(channel, (prefix * suffix, value))
    end
    return nothing
end
"""
    _list_state(node::Union{Flux.Conv,Flux.Dense}, channel, prefix)

Put the layer's weight, and its bias when it has one, on `channel` as
`(name, array)` tuples. The weight comes first, matching the order of
PyTorch's `state_dict` for convolution and linear layers.
"""
function _list_state(node::Union{Flux.Conv,Flux.Dense}, channel, prefix)
    put!(channel, (prefix * ".weight", node.weight))
    # A layer built with `bias=false` stores `false` rather than an array. Test for
    # an array directly instead of comparing with `Flux.Zeros()`: in Flux 0.13
    # (required by Metalhead 0.7) `Zeros` is only a deprecation shim that warns on
    # every call and returns `false`.
    if node.bias isa AbstractArray
        put!(channel, (prefix * ".bias", node.bias))
    end
    return nothing
end
# Catch-all: a layer type with no more specific `_list_state` method contributes
# no entries (presumably parameter-free layers such as activations or pooling).
function _list_state(node, channel, prefix)
    return nothing
end
"""
    _list_state(node::Union{Flux.Chain,Flux.Parallel}, channel, prefix)

Walk every sub-layer of a container in order. Each sub-layer's entries are
named with `prefix` extended by `.layers[i]`, where `i` is its 1-based position.
"""
function _list_state(node::Union{Flux.Chain,Flux.Parallel}, channel, prefix)
    foreach(enumerate(node.layers)) do (position, sublayer)
        _list_state(sublayer, channel, string(prefix, ".layers[", position, "]"))
    end
    return nothing
end
"""
    list_state(node; prefix = "model")

Return an unbuffered `Channel` that yields the `(name, array)` tuples produced
by walking `node` with `_list_state`, in traversal order. Every name starts
with `prefix`.
"""
function list_state(node; prefix = "model")
    producer(channel) = _list_state(node, channel, prefix)
    return Channel(producer)
end
"""
    _copy_pytorch_param!(flux_param, pytorch_param, flux_key, pytorch_key)

Copy one PyTorch parameter array (already converted to a Julia array) into the
matching Flux array in place, and return `flux_param`.

- 4-D arrays are treated as convolution weights. The dimension order is reversed
  (PyTorch `(out, in, kh, kw)` -> Flux `(kw, kh, in, out)`), and the two spatial
  dims are flipped because Flux's `Conv` computes a true (kernel-flipped)
  convolution while PyTorch computes a cross-correlation. This case is checked
  first, so a conv weight whose shape happens to equal PyTorch's (e.g. 3×3×3×3)
  is still permuted and flipped.
- Arrays of identical shape (Dense weights, biases, BatchNorm vectors) are
  copied as-is.
- Other arrays stored with the reversed dimension order are transposed.

Throws an error naming both keys and both shapes if no rule matches.
"""
function _copy_pytorch_param!(flux_param, pytorch_param, flux_key, pytorch_key)
    converted = if ndims(flux_param) == 4 && ndims(pytorch_param) == 4
        reverse(permutedims(pytorch_param, 4:-1:1); dims = (1, 2))
    elseif size(flux_param) == size(pytorch_param)
        pytorch_param
    elseif size(flux_param) == reverse(size(pytorch_param))
        permutedims(pytorch_param, ndims(pytorch_param):-1:1)
    else
        nothing
    end
    if converted === nothing || size(converted) != size(flux_param)
        error("incompatible shape $flux_key $pytorch_key: " *
              "$(size(flux_param)) (Flux) vs $(size(pytorch_param)) (PyTorch)")
    end
    flux_param .= converted
    return flux_param
end

for (modelname, jlmodel, pymodel) in modellib
    model = jlmodel()
    pytorchmodel = pymodel(pretrained=true)
    state = OrderedDict(list_state(model.layers))
    # Calling pytorchmodel.state_dict() directly loses the order (presumably
    # because PyCall auto-converts it to an unordered Julia Dict), so keep it as
    # a PyObject and iterate its items.
    state_dict = OrderedDict(pycall(pytorchmodel.state_dict, PyObject).items())
    pytorch_pp = OrderedDict((k, v.numpy()) for (k, v) in state_dict
                             if !occursin("num_batches_tracked", k))
    # Parameters are matched purely by position, and `zip` silently stops at the
    # shorter sequence, so a missing or extra parameter must be caught here.
    if length(state) != length(pytorch_pp)
        error("parameter count mismatch for $modelname: " *
              "$(length(state)) (Flux) vs $(length(pytorch_pp)) (PyTorch)")
    end
    # loop over all parameters
    for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
        _copy_pytorch_param!(flux_param, pytorch_param, flux_key, pytorch_key)
    end
    @info "saving model $modelname"
    BSON.@save joinpath(@__DIR__, "$(modelname).bson") model
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment