Created
January 18, 2019 19:31
-
-
Save tanayv/a655bc5b2d1000c318c66c767c878382 to your computer and use it in GitHub Desktop.
Clickbait Repel Flask App
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import Flask | |
import tensorflow as tf | |
import numpy as np | |
import os | |
import time | |
import datetime | |
import data_helpers | |
from text_cnn import TextCNN | |
from tensorflow.contrib import learn | |
# Flask application instance; the /check/<estring> route below attaches to it.
app = Flask(__name__)
@app.route('/check/<estring>')
def hello_world(estring):
    """HTTP endpoint: classify the text taken from the URL segment.

    Delegates to check(), which returns a one-character label
    ("L" or "C") for the given string.
    """
    return check(estring)
# ---- Model configuration (TF1-style flags) --------------------------------
tf.flags.DEFINE_string("checkpoint_dir", "runs/1483816176/checkpoints/", "Checkpoint directory from training run")
tf.flags.DEFINE_boolean("eval_train", True, "Evaluate on all training data")
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
# NOTE(review): _parse_flags() is a private TF1 API removed in later releases
# (newer code calls FLAGS(sys.argv)) -- confirm against the installed version.
FLAGS._parse_flags()

# Restore the vocabulary built at training time so inference maps tokens to
# the same ids the model was trained with.
vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)

# Most recent checkpoint, plus a dedicated graph that check() loads it into.
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
def check(estring):
    """Run the saved CNN classifier on a single input string.

    Args:
        estring: raw text to classify (e.g. a headline).

    Returns:
        "L" if the model predicts class 1 for the input, otherwise "C".
        (Which letter denotes clickbait depends on the training labels --
        confirm against the training script.)
    """
    # Convert the raw text to the fixed-length id sequence the model expects.
    x_raw = [estring]
    x_test = np.array(list(vocab_processor.transform(x_raw)))

    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        # Use the session as a context manager so it is always closed;
        # the original code leaked one open session per request.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables.
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Look up the input placeholders and prediction tensor by name.
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # One epoch over the single example: batch size 1, no shuffling.
            batches = data_helpers.batch_iter(list(x_test), 1, 1, shuffle=False)

            # Collect the predictions here.
            all_predictions = []
            for x_test_batch in batches:
                # Keep-prob 1.0 disables dropout at inference time.
                batch_predictions = sess.run(
                    predictions,
                    {input_x: x_test_batch, dropout_keep_prob: 1.0})
                all_predictions = np.concatenate([all_predictions, batch_predictions])

            return "L" if all_predictions[0] == 1 else "C"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment