Skip to content

Instantly share code, notes, and snippets.

@nialv7
Created April 16, 2017 16:14
Show Gist options
  • Save nialv7/312936298e641eeb237152579eb56e9b to your computer and use it in GitHub Desktop.
Save nialv7/312936298e641eeb237152579eb56e9b to your computer and use it in GitHub Desktop.
gen.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import inspect
import time
import json
import os
from bottle import route, request, HTTPError, run, post, static_file
import numpy as np
import tensorflow as tf
# Alias for TensorFlow's logging module; nothing in the visible code uses it.
logging = tf.logging
def data_type():
  """Return the TensorFlow dtype used for the model's variables and state."""
  dtype = tf.float32
  return dtype
class PTBSample(object):
  """Single-step LSTM language model used for sampling.

  The graph handles one token at a time with a batch size of 1: it embeds the
  token id fed through ``step_input``, advances a stacked LSTM by one step and
  turns the cell output into a softmax distribution over the vocabulary.

  Args:
    config: object providing ``hidden_size``, ``vocab_size`` and
      ``num_layers``.
  """

  def __init__(self, config):
    size = config.hidden_size
    vocab_size = config.vocab_size

    # Slightly better results can be obtained with forget gate biases
    # initialized to 1 but the hyperparameters of the model would need to be
    # different than reported in the paper.
    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          size, forget_bias=0.0, state_is_tuple=True)

    cell = tf.contrib.rnn.MultiRNNCell(
        [lstm_cell() for _ in range(config.num_layers)], state_is_tuple=True)

    # Zero state for a batch of one; callers may feed these tensors with the
    # state returned by a previous step to continue a sequence.
    self._initial_state = cell.zero_state(1, data_type())

    # One token id per step. The shape is spelled as a list because ``(1)``
    # is just the integer 1, not a 1-tuple.
    input_ = tf.placeholder(tf.int32, shape=[1], name="step_input")

    with tf.device("/cpu:0"):
      embedding = tf.get_variable(
          "embedding", [vocab_size, size], dtype=data_type())
      inputs = tf.nn.embedding_lookup(embedding, input_)

    with tf.variable_scope("RNN"):
      (cell_output, nstate) = cell(inputs, self._initial_state)

    # The output projection lives outside the "RNN" scope so its variable
    # names are "<enclosing scope>/softmax_w" and ".../softmax_b".
    softmax_w = tf.get_variable(
        "softmax_w", [size, vocab_size], dtype=data_type())
    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
    logits = tf.matmul(cell_output, softmax_w) + softmax_b

    self._step_output = tf.nn.softmax(logits)
    self._next_state = nstate
    self._step_input = input_

  @property
  def initial_state(self):
    """State tensors the LSTM step reads (zero state unless fed)."""
    return self._initial_state

  def forward(self, session, feed):
    """Run one step of the model.

    Args:
      session: TensorFlow session in which to evaluate the graph.
      feed: feed dict; must supply ``step_input`` and may override the
        tensors of ``initial_state``.

    Returns:
      A ``(probabilities, next_state)`` pair as returned by ``session.run``.
    """
    return session.run((self._step_output, self._next_state), feed)

  @property
  def step_input(self):
    """Placeholder holding the id of the token to process."""
    return self._step_input
class Config(object):
  """Hyperparameters for the sampling model.

  Only ``num_layers``, ``hidden_size``, ``vocab_size`` and ``init_scale`` are
  read by the code visible in this file; the other values are not referenced
  here.
  """

  # Architecture, read by PTBSample.
  num_layers = 2
  hidden_size = 300
  # Placeholder: main() overwrites this with the size of the loaded vocabulary.
  vocab_size = 0

  # Half-width of the uniform variable initializer built in main().
  init_scale = 0.05

  # Not referenced by the code visible in this file.
  learning_rate = 1.0
  lr_decay = 0.8
  max_grad_norm = 5
  keep_prob = 0.5
  num_steps = 50
  batch_size = 20
  max_epoch = 6
  max_max_epoch = 39
def get_sample(k, cap=0.5):
  """Pick an index from the weight sequence ``k``.

  A threshold is drawn uniformly from ``[0, cap]`` and the weights are walked
  in order, subtracting each one from the threshold; the first index whose
  weight is at least the remaining threshold is returned. With weights that
  sum to 1 and the default ``cap`` of 0.5, only the indices covering the
  first half of the cumulative weight can be selected.

  Args:
    k: non-empty sequence of weights (anything supporting ``len`` and
      iteration).
    cap: upper bound of the uniformly drawn threshold.

  Returns:
    An index in ``range(len(k))``. If the walk runs off the end (the weights
    never reach the threshold, e.g. because they contain NaN), the last
    index is returned so the result is always usable as a lookup key.

  Raises:
    ValueError: if ``k`` is empty.
  """
  count = len(k)
  if count == 0:
    raise ValueError("cannot sample from an empty distribution")

  remaining = random.uniform(0, cap)
  for index, weight in enumerate(k):
    if weight >= remaining:
      return index
    remaining -= weight
  # Fell through: clamp to the last valid index instead of len(k).
  return count - 1
def ask_(prob, i2w):
  """Return the 100 highest-probability words, best first.

  Args:
    prob: sequence of probabilities whose elements expose ``.item()``;
      position ``i`` belongs to the word ``i2w[i]``.
    i2w: mapping from index to word.

  Returns:
    A list of at most 100 dicts of the form ``{"prob": ..., "word": ...}``
    ordered by descending probability; ties keep their index order.
  """
  entries = [
      {"prob": value.item(), "word": i2w[index]}
      for index, value in enumerate(prob)
  ]
  # Stable sort on the negated probability, i.e. descending order.
  entries.sort(key=lambda entry: -entry["prob"])
  return entries[:100]
import pickle
def main(_):
  """Restore the trained model and serve it over HTTP with bottle.

  Loads the word-to-index vocabulary from ``w2i1.p``, builds the sampling
  graph, restores the latest checkpoint found in ``mmodel`` and then blocks
  in bottle's ``run``, listening on the port named by the ``PORT``
  environment variable. Two routes are registered:

    * ``POST /sample`` -- the form field ``words`` holds a JSON list of
      prompt words. The response is a JSON object with ``fwl`` (up to 200
      sampled continuation words, stopping at ``<eos>``) and ``cands`` (the
      result of ``ask_`` for the distribution that follows the prompt).
      A missing, malformed or empty ``words`` field yields status 400; a
      word outside the vocabulary yields status 401.
    * ``/`` -- serves ``index.html`` from the current directory.

  Args:
    _: ignored.
  """
  # NOTE: pickle can execute arbitrary code while loading; w2i1.p must come
  # from a trusted source.
  with open("w2i1.p", "rb") as f:
    w2i = pickle.load(f)
  i2w = {index: word for word, index in w2i.items()}

  config = Config()
  config.vocab_size = len(w2i)

  with tf.Graph().as_default():
    initializer = tf.random_uniform_initializer(-config.init_scale,
                                                config.init_scale)
    with tf.name_scope("Test"):
      with tf.variable_scope("Model", reuse=None, initializer=initializer):
        model = PTBSample(config)

    sv = tf.train.Supervisor(logdir="x")
    checkpoint = tf.train.latest_checkpoint("mmodel")
    print(checkpoint)

    with sv.managed_session() as session:
      sv.saver.restore(session, checkpoint)

      def step(token_id, state=None):
        """Feed one token (and optionally the previous state) to the model."""
        feed = {model.step_input: [token_id]}
        if state is not None:
          for layer, (c, h) in enumerate(model.initial_state):
            feed[c] = state[layer].c
            feed[h] = state[layer].h
        return model.forward(session, feed)

      @post('/sample')
      def gen():
        raw = request.forms.get("words")
        print(raw)
        try:
          words = json.loads(raw)
        except (TypeError, ValueError):
          # Field missing (None) or not valid JSON.
          raise HTTPError(status=400)
        if not isinstance(words, list) or not words:
          raise HTTPError(status=400)
        # NOTE: 401 is kept for unknown words because clients may rely on
        # it, although 400 would describe the condition more accurately.
        if any(word not in w2i for word in words):
          raise HTTPError(status=401)

        # Run the prompt through the model; the first word starts from the
        # graph's zero state, later ones continue from the returned state.
        probs, state = step(w2i[words[0]])
        for word in words[1:]:
          probs, state = step(w2i[word], state)

        # Computed once from the distribution following the last prompt
        # word, so it is also available for one-word prompts.
        cands = ask_(probs[0], i2w)

        fwl = []
        for _ in range(200):
          xi = get_sample(probs[0])
          word = i2w[xi]
          if word == '<eos>':
            break
          fwl.append(word)
          probs, state = step(xi, state)
        return json.dumps({"fwl": fwl, "cands": cands})

      @route('/')
      def index():
        return static_file("index.html", root=".")

      print("restored")
      # Blocks while serving; the session stays open for the handlers above.
      run(host='0.0.0.0', port=os.environ['PORT'])
# Runs whenever the module is executed or imported; main() ignores its argument.
main(0)
# vim: set sw=2:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment