@ThiagoLira
Last active November 3, 2019 16:24
from flask import Flask, request
import gpt_2_simple as gpt2
import tensorflow as tf
import json
import re
import gc

app = Flask(__name__)

@app.route('/', methods=['GET'])
def generate():
    # Since Flask forks the Python process to answer requests,
    # we need to reset the graph to avoid errors with TensorFlow
    tf.reset_default_graph()
    # Start a TF session and load the model into memory
    sess = gpt2.start_tf_sess(threads=1)
    gpt2.load_gpt2(sess)
    # Get our params from the GET request
    callback = request.args.get('callback')
    sample = request.args.get('sample')
    # If the user didn't provide a prompt, feed the model a default
    if not sample:
        sample = "Ash and Pikachu were"
    samples = gpt2.generate(sess, prefix=sample, return_as_list=True, length=256)
    # The model generates a fixed number of tokens;
    # throw away anything that isn't a complete sentence
    lst = re.split(r'\.', samples[0])
    # Remove the last incomplete sentence (sentences end with a period)
    generated_text = '.'.join(lst[:-1]) + "."
    # Our return payload
    data = {
        'sample_text': generated_text
    }
    # Garbage-collect, since memory doesn't grow on trees
    gc.collect()
    # Wrap the JSON payload in the JSONP callback
    return '{0}({1})'.format(callback, json.dumps(data))
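The sentence-trimming step can be exercised in isolation, which is handy for checking the regex without loading the model. This is a minimal sketch; `trim_to_complete_sentences` is a hypothetical helper name mirroring the logic in the route above:

```python
import re

def trim_to_complete_sentences(text):
    # Split on literal periods; the final chunk is the trailing
    # incomplete sentence the model was cut off in the middle of
    parts = re.split(r'\.', text)
    return '.'.join(parts[:-1]) + "."

print(trim_to_complete_sentences("One. Two. Thr"))  # -> "One. Two."
```

Note the raw string `r'\.'`: an unescaped `.` would match any character and split on everything.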