Last active
November 3, 2019 16:24
-
-
Save ThiagoLira/89aa3e759584cbfd17f490962fc9917a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Prefix fed to the model when the request does not supply one.
_DEFAULT_PREFIX = "Ash and Pikachu were"

# Length argument passed to gpt2.generate for every request.
_GENERATE_LENGTH = 256

# A JSONP callback must be a plain (optionally dotted) JavaScript identifier;
# echoing anything else back would let the caller inject arbitrary script.
_JSONP_CALLBACK_PATTERN = r'[A-Za-z_$][0-9A-Za-z_$]*(?:\.[A-Za-z_$][0-9A-Za-z_$]*)*'


def _is_valid_jsonp_callback(name):
    """Return True if *name* is safe to echo back as a JSONP function name."""
    return re.fullmatch(_JSONP_CALLBACK_PATTERN, name) is not None


def _trim_to_complete_sentences(text):
    """Drop the trailing fragment after the last period in *text*.

    Generation is cut off at a fixed length, so everything after the final
    "." is discarded. Text containing no period at all is returned unchanged
    rather than being reduced to a lone ".".
    """
    head, period, _fragment = text.rpartition('.')
    return head + period if period else text


def _jsonp_response(callback, data):
    """Serialize *data* as JSON, wrapped in ``callback(...)`` when one is given.

    ``callback`` must already have passed ``_is_valid_jsonp_callback``. When it
    is empty or None, the bare JSON document is returned as application/json.
    """
    import json

    # json.dumps (unlike str(dict)) produces valid JSON, and with the default
    # ensure_ascii=True it escapes every non-ASCII character.
    payload = json.dumps(data)
    if not callback:
        return app.response_class(payload, mimetype='application/json')
    return app.response_class(
        '{0}({1})'.format(callback, payload),
        mimetype='application/javascript',
    )


@app.route('/', methods=['GET'])
def generate():
    """Generate text with GPT-2, continuing the ``sample`` query argument.

    Query arguments:
        sample: prefix for the model; ``_DEFAULT_PREFIX`` is used when it is
            missing or empty.
        callback: optional JSONP function name. When given, the body is
            ``callback({"sample_text": ...})`` served as JavaScript; otherwise
            it is the bare JSON object.

    Returns:
        A Flask response carrying ``{"sample_text": ...}``. The text is the
        first generated sample trimmed to its last complete sentence. A 400
        response is returned when ``callback`` is not a valid identifier.
    """
    # Validate request arguments before paying for a model load.
    callback = request.args.get('callback')
    if callback and not _is_valid_jsonp_callback(callback):
        return app.response_class(
            'Invalid callback name', status=400, mimetype='text/plain')
    sample = request.args.get('sample') or _DEFAULT_PREFIX

    # The original author reports TensorFlow errors across requests without
    # this reset.
    # NOTE(review): presumably load_gpt2 would otherwise rebuild the model into
    # a default graph still holding variables from an earlier request — confirm.
    tf.reset_default_graph()

    sess = gpt2.start_tf_sess(threads=1)
    try:
        gpt2.load_gpt2(sess)
        samples = gpt2.generate(
            sess, prefix=sample, return_as_list=True, length=_GENERATE_LENGTH)
    finally:
        # Release the session even if loading or generation fails, instead of
        # leaving it for the garbage collector.
        sess.close()
        gc.collect()

    data = {'sample_text': _trim_to_complete_sentences(samples[0])}
    return _jsonp_response(callback, data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment