Last active
August 13, 2020 10:24
-
-
Save tejasvaidhyadev/6c10bdda1f60c3e42472d356ecf3721a to your computer and use it in GitHub Desktop.
tfckpt2bson helps in converting TensorFlow ALBERT checkpoints into BSON files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using BSON
using JSON
using ZipFile
using Requires: @init, @require
using Flux: loadparams!
"""
    tfckpt2bsonforalbert(path;
                         saveto="./",
                         confname = "albert_config.json",
                         ckptname = "model.ckpt-best",
                         vocabname = "30k-clean.vocab")

Turn a Google-released ALBERT checkpoint into a BSON file. `path` may be
either a zip archive or an already-extracted folder containing the config
JSON, checkpoint shards and vocabulary file. The raw config, weights and
vocab are written to `<saveto>/<name>.tfbson`; the written path is returned.
"""
function tfckpt2bsonforalbert(path; saveto="./", confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
    # accept either a zip archive or an already-extracted directory
    data = iszip(path) ? ZipFile.Reader(path) : path
    config, weights, vocab = readckptfolder(data; confname=confname, ckptname=ckptname, vocabname=vocabname)
    # only the zip reader holds an OS resource that needs closing
    iszip(path) && close(data)
    # save as tfbson (raw julia data)
    bsonname = normpath(joinpath(saveto, config["filename"] * ".tfbson"))
    BSON.@save bsonname config weights vocab
    return bsonname
end
"""
    readckpt(path)

Load a TensorFlow checkpoint file into a Julia `Dict`.

This is a fallback stub that always throws; a working implementation is
installed by the `@require` hook below once TensorFlow.jl is loaded.
"""
readckpt(path) = error("readckpt requires TensorFlow.jl installed. run `Pkg.add(\"TensorFlow\"); using TensorFlow`")
@init @require TensorFlow="1d978283-2c37-5f34-9a8e-e9c0ece82495" begin
import .TensorFlow
#should be changed to use the C API once the patch is included
"""
    readckpt(path)

Load a TensorFlow checkpoint at `path` into a `Dict{String, Array}` mapping
every variable name in the checkpoint to its value. Each tensor is
adjoint-transposed (`'`) so TensorFlow's row-major layout is converted to
Julia's column-major convention.
"""
function readckpt(path)
    weights = Dict{String, Array}()
    TensorFlow.init()
    ckpt = TensorFlow.pywrap_tensorflow.x.NewCheckpointReader(path)
    shapes = ckpt.get_variable_to_shape_map()
    # Iterate every variable listed in the checkpoint instead of hard-coding
    # each tensor name: identical result for the released checkpoints, and it
    # generalizes to any ALBERT configuration (other group/layer counts).
    for (name, _) in shapes
        weights[name] = collect(ckpt.get_tensor(name)')
    end
    return weights
end
end
"""
    readckptfolder(z::ZipFile.Reader; confname, ckptname, vocabname)

Read an ALBERT release packaged as a zip archive: parse the config JSON,
read the vocabulary, and extract the checkpoint shards into a temporary
directory so `readckpt` can load the weights.

Returns `(config, weights, vocab)`.
"""
function readckptfolder(z::ZipFile.Reader; confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
    confile = findfile(z, confname)
    confile === nothing && error("config file $confname not found")
    findfile(z, ckptname*".meta") === nothing && error("ckpt file $ckptname not found")
    vocabfile = findfile(z, vocabname)
    vocabfile === nothing && error("vocab file $vocabname not found")
    zdir = zipname(z)
    # the archive root directory name (without trailing slash) labels the model
    filename = basename(isdirpath(zdir) ? zdir[1:end-1] : zdir)
    config = JSON.parse(confile)
    config["filename"] = filename
    vocab = readlines(vocabfile)
    weights = mktempdir() do tmp
        # dump every checkpoint shard into tmp, then load the checkpoint
        ckptprefix = joinpath(zipname(z), ckptname)
        for zf in z.files
            startswith(zf.name, ckptprefix) || continue
            open(joinpath(tmp, basename(zf.name)), "w") do io
                write(io, read(zf))
            end
        end
        readckpt(joinpath(tmp, ckptname))
    end
    config, weights, vocab
end
"""
    readckptfolder(dir; confname, ckptname, vocabname)

Read an extracted ALBERT release folder `dir`: validate that the config,
checkpoint and vocabulary files are present, then parse the config JSON,
load the checkpoint weights and read the vocabulary.

Returns `(config, weights, vocab)`. Throws an `ErrorException` naming the
first missing file.
"""
function readckptfolder(dir; confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
    files = readdir(dir)
    confname ∉ files && error("config file $confname not found")
    ckptname*".meta" ∉ files && error("ckpt file $ckptname not found")
    vocabname ∉ files && error("vocab file $vocabname not found")
    # the folder name (without trailing slash) labels the model
    filename = basename(isdirpath(dir) ? dir[1:end-1] : dir)
    config = JSON.parsefile(joinpath(dir, confname))
    config["filename"] = filename
    # pass the path directly so the file handle is closed after reading
    # (readlines(open(...)) leaked the handle)
    vocab = readlines(joinpath(dir, vocabname))
    weights = readckpt(joinpath(dir, ckptname))
    config, weights, vocab
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment