Skip to content

Instantly share code, notes, and snippets.

@youben11
Created September 28, 2021 11:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save youben11/9fe7ae3a1829e0cae3f3171856d079f6 to your computer and use it in GitHub Desktop.
Save youben11/9fe7ae3a1829e0cae3f3171856d079f6 to your computer and use it in GitHub Desktop.
class Inferer:
    """Run an LSTM + linear + sigmoid forward pass using only NumPy.

    The weights are copied out of a trained model once, at construction time.
    ``infer`` then feeds a sequence through a hand-written LSTM cell and
    applies the final linear layer followed by a sigmoid.

    The gate naming (``i``, ``f``, ``g``, ``o``) and the ``W_i*`` / ``W_h*`` /
    ``b_i*`` / ``b_h*`` attribute names follow the notation of the
    ``torch.nn.LSTM`` documentation.
    """

    def __init__(self, model):
        """Extract the LSTM and linear-layer weights of ``model`` as arrays.

        Args:
            model: Object exposing ``model.lstm`` and ``model.fc``.
                ``model.lstm.parameters()`` must yield, in this order, the
                input-hidden weights, hidden-hidden weights, input-hidden
                bias and hidden-hidden bias, each stacking the four gates
                (i, f, g, o) along the first axis. ``model.fc`` must expose
                ``weight`` and ``bias``. Presumably a single-layer,
                unidirectional ``torch.nn.LSTM`` followed by a
                ``torch.nn.Linear`` -- confirm against the caller.

        Raises:
            ValueError: If a stacked parameter does not split into exactly
                four gate chunks.
        """
        parameters = list(model.lstm.parameters())
        # Every stacked tensor holds four gates, so one gate is a quarter of
        # the rows. Deriving the size from the model (instead of a global
        # constant) keeps the split consistent with the actual weights.
        hidden_size = parameters[0].shape[0] // 4

        self.W_ii, self.W_if, self.W_ig, self.W_io = self._split_gates(
            parameters[0], hidden_size
        )
        self.W_hi, self.W_hf, self.W_hg, self.W_ho = self._split_gates(
            parameters[1], hidden_size
        )
        self.b_ii, self.b_if, self.b_ig, self.b_io = self._split_gates(
            parameters[2], hidden_size
        )
        self.b_hi, self.b_hf, self.b_hg, self.b_ho = self._split_gates(
            parameters[3], hidden_size
        )

        # Stored transposed so that ``infer`` can compute ``h @ W + b``.
        self.W = self._to_numpy(model.fc.weight).T
        self.b = self._to_numpy(model.fc.bias)

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a NumPy array, detached from autograd."""
        # ``.cpu()`` returns the tensor itself when it already lives on the
        # CPU and makes the conversion work for models held on another device.
        return tensor.detach().cpu().numpy()

    @classmethod
    def _split_gates(cls, stacked, hidden_size):
        """Split a stacked gate tensor into per-gate NumPy arrays."""
        return tuple(
            cls._to_numpy(chunk) for chunk in stacked.split(hidden_size)
        )

    def infer(self, x):
        """Run the whole sequence ``x`` through the model.

        Args:
            x: Sequence of time steps, iterated along its first axis; each
                step is multiplied by the input-hidden weight matrices.
                Presumably shaped ``(seq_len, input_size)`` -- confirm
                against the caller.

        Returns:
            The sigmoid of the linear layer applied to the last hidden state.
            For an empty sequence the hidden state stays at zero, so the
            result is ``sigmoid(self.b)``.
        """
        hidden_size = self.W_hi.shape[0]
        h_t = np.zeros(hidden_size)
        c_t = np.zeros(hidden_size)
        for x_t in x:
            _, h_t, c_t = self.lstm_cell(x_t, h_t, c_t)
        return self.sigmoid(np.dot(h_t, self.W) + self.b)

    def lstm_cell(self, x_t, h_tm1, c_tm1):
        """Advance the LSTM by one time step.

        Args:
            x_t: Input at the current step.
            h_tm1: Hidden state from the previous step.
            c_tm1: Cell state from the previous step.

        Returns:
            Tuple ``(o_t, h_t, c_t)``: the output-gate activation, the new
            hidden state and the new cell state.
        """
        i_t = self.sigmoid(
            np.dot(self.W_ii, x_t) + self.b_ii
            + np.dot(self.W_hi, h_tm1) + self.b_hi
        )
        f_t = self.sigmoid(
            np.dot(self.W_if, x_t) + self.b_if
            + np.dot(self.W_hf, h_tm1) + self.b_hf
        )
        g_t = np.tanh(
            np.dot(self.W_ig, x_t) + self.b_ig
            + np.dot(self.W_hg, h_tm1) + self.b_hg
        )
        o_t = self.sigmoid(
            np.dot(self.W_io, x_t) + self.b_io
            + np.dot(self.W_ho, h_tm1) + self.b_ho
        )
        # Forget part of the old cell state, then add the gated candidate.
        c_t = f_t * c_tm1 + i_t * g_t
        h_t = o_t * np.tanh(c_t)
        return o_t, h_t, c_t

    @staticmethod
    def sigmoid(x):
        """Return the logistic function ``1 / (1 + exp(-x))`` element-wise.

        Only ``exp`` of non-positive values is evaluated, so large negative
        inputs do not overflow. For ``x >= 0`` the expression reduces to
        exactly ``1 / (1 + exp(-x))``; for ``x < 0`` it is the algebraically
        equal ``exp(x) / (1 + exp(x))``.
        """
        return np.exp(np.minimum(x, 0)) / (1 + np.exp(-np.abs(x)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment