Skip to content

Instantly share code, notes, and snippets.

@NikoOinonen
Created January 15, 2024 15:59
Show Gist options
  • Save NikoOinonen/a3107e4573a455eada0f2be9a7ee6059 to your computer and use it in GitHub Desktop.
Save NikoOinonen/a3107e4573a455eada0f2be9a7ee6059 to your computer and use it in GitHub Desktop.
Convert Keras weights to PyTorch weights
import torch
from keras_model import create_keras_model
from pt_model import PTModel
# Map Keras layer names (as auto-generated by create_keras_model) to the
# dotted module paths of the matching layers in PTModel's state dict.
#
# The encoder and middle layers are listed explicitly. The three decoders
# share one layout: decoder ``d`` consumes Keras layers
# conv2d_reflective_{3 + 7*d} ... conv2d_reflective_{9 + 7*d}, which land on
# sub-modules 1, 3, 6, 8, 11, 13, 15 of ``decoders.d``. The comprehension
# iterates sub-module first and decoder second so that the dict keeps the
# same insertion order as the original hand-written table.
translation_table = {
    "conv3d_reflective_1": "encoder.0",
    "conv3d_reflective_2": "encoder.3",
    "conv3d_reflective_3": "encoder.6",
    "conv2d_reflective_1": "middle.0",
    "conv2d_reflective_2": "middle.2",
    **{
        f"conv2d_reflective_{3 + offset + 7 * decoder}": f"decoders.{decoder}.{sub_index}"
        for offset, sub_index in enumerate((1, 3, 6, 8, 11, 13, 15))
        for decoder in range(3)
    },
}
def _keras_to_torch_tensor(array):
    """Convert one Keras weight array into a dense PyTorch tensor.

    4-D kernels are permuted with axes (3, 2, 0, 1) and 5-D kernels with
    (4, 3, 0, 1, 2), i.e. the last two axes (input/output channels in the
    Keras kernel layout) move to the front in reversed order. Arrays of any
    other rank, such as 1-D biases, are converted without reordering.

    Args:
        array: numpy array obtained from a Keras weight variable.

    Returns:
        A contiguous ``torch.Tensor`` sharing no strides with the transpose view.
    """
    if array.ndim == 4:
        array = array.transpose(3, 2, 0, 1)
    elif array.ndim == 5:
        array = array.transpose(4, 3, 0, 1, 2)
    # ``transpose`` only returns a strided view; materialize a dense copy so the
    # saved checkpoint holds ordinary contiguous tensors.
    return torch.from_numpy(array).contiguous()


# Build both models, load the trained Keras weights, copy every mapped layer's
# kernel/bias into the PyTorch state dict, and write it to model_pt.pth.
model_keras = create_keras_model()
model_keras.load_weights("model_keras.h5")
model_pt = PTModel()
state = model_pt.state_dict()

converted_layers = set()
for layer in model_keras.layers:
    pt_layer_name = translation_table.get(layer.name)
    if pt_layer_name is None:
        # Layers without parameters to transfer (or deliberately unmapped) are skipped.
        continue
    for keras_weight in layer.weights:
        suffix = ".weight" if "kernel" in keras_weight.name else ".bias"
        pt_name = pt_layer_name + suffix
        tensor = _keras_to_torch_tensor(keras_weight.numpy())
        # Fail loudly instead of silently adding a stray key or storing a
        # mis-shaped tensor that would only surface when the checkpoint is loaded.
        if pt_name not in state:
            raise KeyError(
                f"PyTorch state dict has no entry {pt_name!r} "
                f"(mapped from Keras weight {keras_weight.name!r})"
            )
        if state[pt_name].shape != tensor.shape:
            raise ValueError(
                f"Shape mismatch for {pt_name!r}: PyTorch expects "
                f"{tuple(state[pt_name].shape)}, converted Keras weight "
                f"{keras_weight.name!r} has {tuple(tensor.shape)}"
            )
        state[pt_name] = tensor
    converted_layers.add(layer.name)

# Keras auto-generates layer names; if they drift (e.g. extra layers created
# earlier in the session), some table entries match nothing and the output
# would silently keep the PyTorch model's initial weights for those layers.
unmatched = sorted(set(translation_table) - converted_layers)
if unmatched:
    raise ValueError(f"No Keras layers found for table entries: {unmatched}")

torch.save(state, "model_pt.pth")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment