chirag1992m / weight_transfer.py
Created December 1, 2017 21:36
weight_transfer
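The function below copies weights between two models that share an architecture. Most of that work is reordering axes, because PyTorch and Keras store the same parameters in different layouts. What follows is a minimal NumPy sketch of those conversions. It assumes Keras' default `channels_last` data format and parameters already extracted from PyTorch as NumPy arrays. The helper names (`pt_conv2d_to_keras`, `pt_linear_to_keras`, `pt_batchnorm_to_keras`, `convert_by_name`, `CONVERTERS`) are hypothetical and not part of this gist.

```python
import numpy as np


def pt_conv2d_to_keras(weight, bias=None):
    # PyTorch Conv2d weight: (out_ch, in_ch, kH, kW).
    # Keras Conv2D kernel (channels_last assumed): (kH, kW, in_ch, out_ch).
    kernel = np.transpose(weight, (2, 3, 1, 0))
    return [kernel] if bias is None else [kernel, bias]


def pt_linear_to_keras(weight, bias=None):
    # PyTorch Linear weight: (out_features, in_features).
    # Keras Dense kernel: (in_features, out_features).
    kernel = weight.T
    return [kernel] if bias is None else [kernel, bias]


def pt_batchnorm_to_keras(weight, bias, running_mean, running_var):
    # PyTorch's weight, bias, running_mean, running_var map one-to-one onto
    # Keras' gamma, beta, moving_mean, moving_variance, in that order.
    return [weight, bias, running_mean, running_var]


# Hypothetical dispatch table keyed by a layer "kind" label.
CONVERTERS = {
    "conv2d": lambda p: pt_conv2d_to_keras(p["weight"], p.get("bias")),
    "linear": lambda p: pt_linear_to_keras(p["weight"], p.get("bias")),
    "batchnorm": lambda p: pt_batchnorm_to_keras(
        p["weight"], p["bias"], p["running_mean"], p["running_var"]
    ),
}


def convert_by_name(pt_params, layer_kinds):
    """Map {layer_name: {param_name: ndarray}} to {layer_name: [Keras weights]}.

    Relies on the same assumption as the gist: corresponding layers carry
    the same name in both models, so the name is the only link between them.
    """
    return {
        name: CONVERTERS[kind](pt_params[name])
        for name, kind in layer_kinds.items()
    }
```

Under these assumptions, each returned list could be passed to `keras_model.get_layer(name).set_weights(...)`. The PyTorch side would be collected from `pytorch_model.state_dict()` with `.detach().cpu().numpy()`, so a key such as `conv1.weight` becomes `pt_params["conv1"]["weight"]`. This sketch does not handle one common case. When a `Dense` layer follows a `Flatten` after convolutions, PyTorch flattens in (C, H, W) order while a `channels_last` Keras model flattens in (H, W, C) order, so the kernel rows would also need permuting. The default `BatchNormalization` epsilon also differs between the frameworks (1e-5 in PyTorch, 1e-3 in Keras), which can cause small output mismatches unless it is set explicitly.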
import numpy as np
import torch
import keras


def pyt_to_keras(pytorch_model, keras_model):
    """
    Given a PyTorch model, this method transfers its weights to
    a Keras model (TensorFlow backend) with the same architecture.
    Assumptions:
    1. Corresponding layers have the same names in both models