Skip to content

Instantly share code, notes, and snippets.

@yzhliu
Created June 29, 2016 01:56
Show Gist options
  • Save yzhliu/4e499fca0345929268c5a7085953db7a to your computer and use it in GitHub Desktop.
infer_lstm.py
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
from lstm import lstm_unroll
from bucket_io import default_build_vocab
from rnn_model import LSTMInferenceModel
def make_input(char, vocab, arr):
    """Copy the vocabulary index of ``char`` into ``arr`` in place.

    The index is looked up in ``vocab``; a ``KeyError`` propagates for a
    token that is not present, and ``arr`` is left untouched in that case.
    The index is converted to a float64 one-element array and written into
    ``arr`` through slice assignment, so the caller's buffer is reused
    rather than replaced.

    NOTE(review): ``arr`` is assumed to be a length-1 buffer (the script
    passes ``mx.nd.zeros((1,))``); this is not checked here.
    """
    token_id = vocab[char]
    arr[:] = np.array([token_id], dtype=np.float64)
if __name__ == '__main__':
    # Feed each command-line token through the trained LSTM, one step at a
    # time, and print the highest-scoring output entry after every step:
    # its score, its vocabulary index and the corresponding token.
    #
    # Usage: python infer_lstm.py TOKEN [TOKEN ...]

    # Network shape.  The loaded arg_params are handed to a model built with
    # these values, so they must agree with the trained checkpoint.
    # NOTE(review): assumed to mirror the training script — confirm.
    num_hidden = 200
    num_embed = 200
    num_lstm_layer = 2

    tks = sys.argv[1:]
    if not tks:
        # Without tokens the loop below would do nothing; fail fast instead
        # of loading the vocabulary and checkpoint for no output.
        sys.exit("usage: %s TOKEN [TOKEN ...]" % sys.argv[0])

    vocab = default_build_vocab("./data/ptb.train.txt")
    # Reverse mapping (index -> token) used to print the predicted token.
    rvocab = dict((v, k) for k, v in vocab.items())

    # make_input would raise a bare KeyError mid-loop for an unknown token;
    # report all of them up front, before the checkpoint is loaded.
    unknown = [t for t in tks if t not in vocab]
    if unknown:
        sys.exit("tokens not in vocabulary: %s" % " ".join(unknown))

    # Only the argument parameters of checkpoint epoch 3 are used; the
    # symbol and auxiliary params returned alongside them are discarded.
    _, arg_params, __ = mx.model.load_checkpoint("model/ptb", 3)
    model = LSTMInferenceModel(num_lstm_layer, len(vocab),
                               num_hidden=num_hidden, num_embed=num_embed,
                               num_label=len(vocab), arg_params=arg_params)

    # One-element input buffer, reused for every step.
    input_ndarray = mx.nd.zeros((1,))
    for tok in tks:
        make_input(tok, vocab, input_ndarray)
        # NOTE(review): the second argument presumably means "start a new
        # sequence"; passing False keeps the LSTM state across tokens.
        # Confirm against rnn_model.LSTMInferenceModel.forward.
        prob = model.forward(input_ndarray, False)
        # Assumes forward returns a numpy-compatible (1, vocab_size) array
        # of scores (presumably softmax probabilities) — TODO confirm.
        idx = np.argmax(prob, axis=1)[0]
        # A single pre-formatted string prints identically on Python 2
        # (print statement) and Python 3 (print function).
        print("%s %s %s" % (prob[0][idx], idx, rvocab[idx]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment