Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sumanmichael/9f7f8e0ab3ca3b03ffa3b08e2869892f to your computer and use it in GitHub Desktop.
Save sumanmichael/9f7f8e0ab3ca3b03ffa3b08e2869892f to your computer and use it in GitHub Desktop.
Get PyTorch LSTM weights (w_ih, w_hh, b_ih, b_hh) from TensorFlow LSTM weights (kernel, bias)
import torch
from torch import nn
import numpy as np
# Get PyTorch LSTMCell weights (w_ih, w_hh, b_ih, b_hh) from TensorFlow LSTM weights (kernel, bias)
def get_pytorch_lstm_weights_from_tensorflow(kernel, bias, INPUT_SIZE, HIDDEN_SIZE):
    """Convert TensorFlow LSTM cell weights into ``nn.LSTMCell`` parameters.

    The TensorFlow kernel stacks the input-to-hidden rows on top of the
    hidden-to-hidden rows and lays the four gates out along the last axis in
    the order ``i, j, f, o``. PyTorch stores two separate matrices with the
    gates along the first axis in the order ``i, f, g, o``.

    Args:
        kernel: Array-like of shape
            ``(INPUT_SIZE + HIDDEN_SIZE, 4 * HIDDEN_SIZE)``.
        bias: Array-like of shape ``(4 * HIDDEN_SIZE,)``, or ``None`` for
            all-zero biases.
        INPUT_SIZE: Number of input features.
        HIDDEN_SIZE: Number of hidden units.

    Returns:
        Tuple ``(w_ih, w_hh, b_ih, b_hh)`` of ``nn.Parameter`` with shapes
        ``(4H, INPUT_SIZE)``, ``(4H, HIDDEN_SIZE)``, ``(4H,)`` and ``(4H,)``.

    Raises:
        ValueError: If ``kernel`` or ``bias`` does not have the shape above.
    """
    # NOTE(review): some TensorFlow cells add a constant ``forget_bias`` at
    # run time that is not stored in ``bias``; if the source model does, add
    # it to the forget-gate slice before converting -- confirm against the
    # source model.

    def reorder_lstm_gates(w):
        # The split order of gates differs between PyTorch and TensorFlow.
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        # Gates live on the LAST axis of the TensorFlow arrays, so split and
        # re-join there; ``.T`` then yields PyTorch's (4H, in) layout and is
        # a no-op for the 1-D bias.
        i, j, f, o = np.split(w, 4, axis=-1)
        return np.concatenate((i, f, j, o), axis=-1).T

    def param_tensor_perm(x):
        return nn.Parameter(torch.tensor(reorder_lstm_gates(x)))

    kernel = np.asarray(kernel)
    expected_kernel_shape = (INPUT_SIZE + HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    if kernel.shape != expected_kernel_shape:
        raise ValueError(
            f"kernel must have shape {expected_kernel_shape}, got {kernel.shape}"
        )

    w_ih = param_tensor_perm(kernel[:INPUT_SIZE])
    w_hh = param_tensor_perm(kernel[INPUT_SIZE:])

    # ``if bias:`` would be ambiguous for arrays; only None means "no bias".
    if bias is not None:
        bias = np.asarray(bias)
        if bias.shape != (4 * HIDDEN_SIZE,):
            raise ValueError(
                f"bias must have shape {(4 * HIDDEN_SIZE,)}, got {bias.shape}"
            )
        # TensorFlow keeps a single bias vector while PyTorch adds b_ih and
        # b_hh together, so carry the whole bias in b_ih and zero b_hh.
        b_ih = param_tensor_perm(bias)
        b_hh = nn.Parameter(torch.zeros_like(b_ih))
    else:
        # Match the weight dtype so the cell's parameters stay consistent.
        b_ih = nn.Parameter(torch.zeros(4 * HIDDEN_SIZE, dtype=w_ih.dtype))
        b_hh = nn.Parameter(torch.zeros(4 * HIDDEN_SIZE, dtype=w_ih.dtype))
    return w_ih, w_hh, b_ih, b_hh
# Example
if __name__ == "__main__":
    # Illustrative sizes and random stand-ins for real TensorFlow weights.
    INPUT_SIZE, HIDDEN_SIZE = 8, 4
    # TensorFlow-style kernel: (INPUT_SIZE + HIDDEN_SIZE, 4 * HIDDEN_SIZE).
    kw = np.random.randn(INPUT_SIZE + HIDDEN_SIZE, 4 * HIDDEN_SIZE).astype(np.float32)
    # One batch row of input features.
    inp = torch.randn(1, INPUT_SIZE)

    rnn = nn.LSTMCell(INPUT_SIZE, HIDDEN_SIZE)
    # Passing None as the bias requests all-zero biases.
    rnn.weight_ih, rnn.weight_hh, rnn.bias_ih, rnn.bias_hh = (
        get_pytorch_lstm_weights_from_tensorflow(kw, None, INPUT_SIZE, HIDDEN_SIZE)
    )
    h, c = rnn(inp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment