# tfckpt2bson: convert a released ALBERT TensorFlow checkpoint into a BSON file.
using JSON
using ZipFile
using BSON            # BSON.@save is used below
using Requires        # provides @init and @require
using Flux: loadparams!
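# Note: `iszip`, `findfile`, and `zipname` are used below but not defined in
# this gist; in Transformers.jl they come from the pretrain utilities. A
# minimal sketch of what this file needs from them (assumptions, not the
# package's actual definitions):
iszip(path) = endswith(path, ".zip")  # crude check; a robust helper would inspect the file header
findfile(z::ZipFile.Reader, name) = (i = findfirst(zf -> basename(zf.name) == name, z.files); i === nothing ? nothing : z.files[i])
zipname(z::ZipFile.Reader) = z.files[1].name  # assumes the first zip entry is the archive's root folder, e.g. "albert_base_v2/"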
"""
tfckpt2bsonforalbert(path;
raw=false,
saveto="./",
confname = "albert_config.json",
ckptname = "model.ckpt-best",
vocabname = "30k-clean.vocab")
turn google released albert format into BSON file. Set `raw` to `true` to remain the origin data format in bson.
"""
function tfckpt2bsonforalbert(path; saveto = "./", confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
  if iszip(path)
    data = ZipFile.Reader(path)
  else
    data = path
  end
  config, weights, vocab = readckptfolder(data; confname = confname, ckptname = ckptname, vocabname = vocabname)
  iszip(path) && close(data)

  # save as tfbson (raw Julia data)
  bsonname = normpath(joinpath(saveto, config["filename"] * ".tfbson"))
  BSON.@save bsonname config weights vocab
  bsonname
end
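# A minimal usage sketch (the archive name is a placeholder for wherever the
# released ALBERT weights were downloaded):
#
#   bsonfile = tfckpt2bsonforalbert("albert_base_v2.zip"; saveto = "./")
#   BSON.@load bsonfile config weights vocab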
"loading tensorflow checkpoint file into julia Dict"
readckpt(path) = error("readckpt require TensorFlow.jl installed. run `Pkg.add(\"TensorFlow\"); using TensorFlow`")
@init @require TensorFlow="1d978283-2c37-5f34-9a8e-e9c0ece82495" begin
  import .TensorFlow
  # should be changed to use the C API once the patch is included
  function readckpt(path)
    weights = Dict{String, Array}()
    TensorFlow.init()
    ckpt = TensorFlow.pywrap_tensorflow.x.NewCheckpointReader(path)
    shapes = ckpt.get_variable_to_shape_map()  # variable name => shape, handy for inspection
weights["bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma"))')
weights["bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/beta"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/beta"))')
weights["cls/predictions/transform/dense/kernel"]=collect((ckpt.get_tensor("cls/predictions/transform/dense/kernel"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel"))')
weights["bert/embeddings/word_embeddings"]=collect((ckpt.get_tensor("bert/embeddings/word_embeddings"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel"))')
weights["bert/pooler/dense/kernel"]=collect((ckpt.get_tensor("bert/pooler/dense/kernel"))')
weights["cls/predictions/output_bias"]=collect((ckpt.get_tensor("cls/predictions/output_bias"))')
weights["cls/predictions/transform/LayerNorm/beta"]=collect((ckpt.get_tensor("cls/predictions/transform/LayerNorm/beta"))')
weights["cls/seq_relationship/output_bias"]=collect((ckpt.get_tensor("cls/seq_relationship/output_bias"))')
weights["bert/embeddings/LayerNorm/gamma"]=collect((ckpt.get_tensor("bert/embeddings/LayerNorm/gamma"))')
weights["global_step"]=collect((ckpt.get_tensor("global_step"))')
weights["bert/embeddings/LayerNorm/beta"]=collect((ckpt.get_tensor("bert/embeddings/LayerNorm/beta"))')
weights["cls/predictions/transform/LayerNorm/gamma"]=collect((ckpt.get_tensor("cls/predictions/transform/LayerNorm/gamma"))')
weights["bert/encoder/embedding_hidden_mapping_in/bias"]=collect((ckpt.get_tensor("bert/encoder/embedding_hidden_mapping_in/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel"))')
weights["cls/seq_relationship/output_weights"]=collect((ckpt.get_tensor("cls/seq_relationship/output_weights"))')
weights["cls/predictions/transform/dense/bias"]=collect((ckpt.get_tensor("cls/predictions/transform/dense/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/LayerNorm/beta"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/LayerNorm/beta"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel"))')
weights["bert/encoder/embedding_hidden_mapping_in/kernel"]=collect((ckpt.get_tensor("bert/encoder/embedding_hidden_mapping_in/kernel"))')
weights["bert/embeddings/token_type_embeddings"]=collect((ckpt.get_tensor("bert/embeddings/token_type_embeddings"))')
weights["bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel"))')
weights["bert/embeddings/position_embeddings"]=collect((ckpt.get_tensor("bert/embeddings/position_embeddings"))')
weights["bert/pooler/dense/bias"]=collect((ckpt.get_tensor("bert/pooler/dense/bias"))')
weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias"]=collect((ckpt.get_tensor("bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias"))')
return(weights)
#print((weights["bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias"]))
end
end
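# Once TensorFlow.jl is loaded, `readckpt` can also be called on its own; a
# minimal sketch, assuming an extracted checkpoint at a placeholder path:
#
#   using TensorFlow
#   weights = readckpt("albert_base_v2/model.ckpt-best")
#   size(weights["bert/embeddings/word_embeddings"])  # (embedding_size, vocab_size) after the transpose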
function readckptfolder(z::ZipFile.Reader; confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
  (confile = findfile(z, confname)) === nothing && error("config file $confname not found")
  findfile(z, ckptname*".meta") === nothing && error("ckpt file $ckptname not found")
  (vocabfile = findfile(z, vocabname)) === nothing && error("vocab file $vocabname not found")
  dir = zipname(z)
  filename = basename(isdirpath(dir) ? dir[1:end-1] : dir)
  config = JSON.parse(confile)
  config["filename"] = filename
  vocab = readlines(vocabfile)
  weights = mktempdir(
    dir -> begin
      # dump the checkpoint shards from the zip into a temporary directory
      for fidx ∈ findall(zf -> startswith(zf.name, joinpath(zipname(z), ckptname)), z.files)
        zf = z.files[fidx]
        zfn = basename(zf.name)
        f = open(joinpath(dir, zfn), "w+")
        buffer = Vector{UInt8}(undef, zf.uncompressedsize)
        write(f, read!(zf, buffer))
        close(f)
      end
      readckpt(joinpath(dir, ckptname))
    end
  )
  config, weights, vocab
end
function readckptfolder(dir; confname = "albert_config.json", ckptname = "model.ckpt-best", vocabname = "30k-clean.vocab")
  files = readdir(dir)
  confname ∉ files && error("config file $confname not found")
  ckptname*".meta" ∉ files && error("ckpt file $ckptname not found")
  vocabname ∉ files && error("vocab file $vocabname not found")
  filename = basename(isdirpath(dir) ? dir[1:end-1] : dir)
  config = JSON.parsefile(joinpath(dir, confname))
  config["filename"] = filename
  vocab = readlines(open(joinpath(dir, vocabname)))
  weights = readckpt(joinpath(dir, ckptname))
  config, weights, vocab
end
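# `readckptfolder` also accepts an already-extracted release directory; a
# sketch with a placeholder directory name:
#
#   config, weights, vocab = readckptfolder("albert_base_v2/")
#   config["hidden_size"], length(vocab)  # assumes the standard albert_config.json fields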