Created
May 29, 2020 21:19
-
-
Save duhaime/24ac8daa5bd060ebb2221411d18661ea to your computer and use it in GitHub Desktop.
Flask and Starlette Servers for GPT-2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import Flask, jsonify, request
import gpt_2_simple as gpt2
import tensorflow as tf
import shutil
import os

# One-time setup: fetch the 117M GPT-2 weights and place them where
# gpt_2_simple expects a fine-tuned "run" to live (checkpoint/run1).
if not os.path.exists('checkpoint'):
    os.makedirs(os.path.join('checkpoint', 'run1'))
    gpt2.download_gpt2(model_name='117M')
    # Relocate the downloaded weights with shutil instead of shelling
    # out to `mv`: portable across platforms and raises on failure
    # rather than silently ignoring a non-zero exit status.
    src_dir = os.path.join('models', '117M')
    for entry in os.listdir(src_dir):
        shutil.move(os.path.join(src_dir, entry),
                    os.path.join('checkpoint', 'run1', entry))

app = Flask(__name__)

# Single TF session created at startup and shared by every request
# (threads=1 keeps generation on one graph thread).
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)


@app.route('/', methods=['GET'])
def index():
    """Generate GPT-2 text from query-string parameters.

    Recognized query parameters (all optional):
        length (int, default 128)      -- number of tokens to generate
        temperature (float, default 0.7)
        top_k (int, default 0)         -- 0 disables top-k filtering
        top_p (float, default 0)       -- 0 disables nucleus sampling
        prefix (str, default '')       -- seed text, truncated to 500 chars
        truncate (str, default None)   -- stop sequence passed to gpt2.generate
        include_prefix (bool-ish, default True) -- 'true'/'false' string

    Returns the generated text as a plain string response body.
    """
    params = request.args
    text = gpt2.generate(
        sess,
        length=int(params.get('length', 128)),
        temperature=float(params.get('temperature', 0.7)),
        top_k=int(params.get('top_k', 0)),
        top_p=float(params.get('top_p', 0)),
        # Cap prefix length to bound generation cost per request.
        prefix=params.get('prefix', '')[:500],
        truncate=params.get('truncate', None),
        # Query params arrive as strings; treat anything but 'true'
        # (case-insensitive) as False. Missing param defaults to True.
        include_prefix=str(params.get('include_prefix', True)).lower() == 'true',
        return_as_list=True)[0]
    return text


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8081)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from starlette.applications import Starlette
from starlette.responses import UJSONResponse
import gpt_2_simple as gpt2
import tensorflow as tf
import uvicorn
import shutil
import os

# One-time setup: fetch the 117M GPT-2 weights and place them where
# gpt_2_simple expects a fine-tuned "run" to live (checkpoint/run1).
if not os.path.exists('checkpoint'):
    os.makedirs(os.path.join('checkpoint', 'run1'))
    gpt2.download_gpt2(model_name='117M')
    # Relocate the downloaded weights with shutil instead of shelling
    # out to `mv`: portable across platforms and raises on failure
    # rather than silently ignoring a non-zero exit status.
    src_dir = os.path.join('models', '117M')
    for entry in os.listdir(src_dir):
        shutil.move(os.path.join(src_dir, entry),
                    os.path.join('checkpoint', 'run1', entry))

app = Starlette(debug=False)

# Single TF session created at startup and shared by every request
# (threads=1 keeps generation on one graph thread).
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)


@app.route('/', methods=['GET'])
async def index(request):
    """Generate GPT-2 text from query-string parameters.

    Recognized query parameters (all optional):
        length (int, default 128)      -- number of tokens to generate
        temperature (float, default 0.7)
        top_k (int, default 0)         -- 0 disables top-k filtering
        top_p (float, default 0)       -- 0 disables nucleus sampling
        prefix (str, default '')       -- seed text, truncated to 500 chars
        truncate (str, default None)   -- stop sequence passed to gpt2.generate
        include_prefix (bool-ish, default True) -- 'true'/'false' string

    Returns a JSON response: {"text": <generated text>}.
    """
    params = request.query_params
    text = gpt2.generate(
        sess,
        length=int(params.get('length', 128)),
        temperature=float(params.get('temperature', 0.7)),
        top_k=int(params.get('top_k', 0)),
        top_p=float(params.get('top_p', 0)),
        # Cap prefix length to bound generation cost per request.
        prefix=params.get('prefix', '')[:500],
        truncate=params.get('truncate', None),
        # Query params arrive as strings; treat anything but 'true'
        # (case-insensitive) as False. Missing param defaults to True.
        include_prefix=str(params.get('include_prefix', True)).lower() == 'true',
        return_as_list=True)[0]
    # Starlette endpoints must return a Response object; the original
    # returned a bare str (a TypeError at request time) despite importing
    # UJSONResponse for exactly this purpose.
    return UJSONResponse({'text': text})


if __name__ == '__main__':
    # Bind all interfaces for parity with the Flask variant above.
    uvicorn.run(app, host='0.0.0.0', port=8080)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.