Skip to content

Instantly share code, notes, and snippets.

@duhaime
Created May 29, 2020 21:19
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save duhaime/24ac8daa5bd060ebb2221411d18661ea to your computer and use it in GitHub Desktop.
Flask and Starlette Servers for GPT-2
from flask import Flask, jsonify, request
import gpt_2_simple as gpt2
import tensorflow as tf
import os
import glob
import shutil

# Directory the pretrained weights are placed in; the argument-less
# gpt2.load_gpt2(sess) call below relies on the model being here
# (gpt_2_simple's default checkpoint/run1 location).
_RUN_DIR = os.path.join('checkpoint', 'run1')

# Download only when the run directory is missing or empty. Checking the run
# directory itself (rather than just 'checkpoint') means an empty directory
# left behind by an interrupted first start is retried instead of skipped.
if not os.path.isdir(_RUN_DIR) or not os.listdir(_RUN_DIR):
    os.makedirs(_RUN_DIR, exist_ok=True)
    gpt2.download_gpt2(model_name='117M')
    # Move the downloaded files in-process instead of os.system('mv ...'):
    # portable, and a failed move raises instead of being silently ignored.
    for src in glob.glob(os.path.join('models', '117M', '*')):
        shutil.move(src, os.path.join(_RUN_DIR, os.path.basename(src)))

app = Flask(__name__)

# One TensorFlow session, created at import time and shared by all requests.
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)
@app.route('/', methods=['GET'])
def index():
    """Generate GPT-2 text from query-string parameters.

    Query parameters (all optional):
        length (int, default 128), temperature (float, default 0.7),
        top_k (int, default 0), top_p (float, default 0),
        prefix (str, truncated to its first 500 characters),
        truncate (str, passed through unchanged),
        include_prefix ('true' case-insensitively enables it; default true).

    Returns:
        The first generated sample as a text/plain body, or a 400 text/plain
        response when a numeric parameter cannot be parsed.
    """
    params = request.args
    # Serve plain text: the response can echo the caller-supplied prefix, and
    # Flask's default text/html would let that prefix inject markup/script.
    plain = {'Content-Type': 'text/plain; charset=utf-8'}

    try:
        length = int(params.get('length', 128))
        temperature = float(params.get('temperature', 0.7))
        top_k = int(params.get('top_k', 0))
        top_p = float(params.get('top_p', 0))
    except ValueError as err:
        # Malformed client input is a client error, not an unhandled 500.
        return 'Invalid query parameter: {}'.format(err), 400, plain

    text = gpt2.generate(
        sess,
        length=length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # Bound the prompt size; longer prefixes are silently truncated.
        prefix=params.get('prefix', '')[:500],
        truncate=params.get('truncate', None),
        include_prefix=params.get('include_prefix', 'true').lower() == 'true',
        return_as_list=True,
    )[0]
    return text, 200, plain
if __name__ == '__main__':
    # Run Flask's built-in server on port 8081, listening on every interface
    # (0.0.0.0) so it is reachable from outside the local machine/container.
    app.run(port=8081, host='0.0.0.0')
from starlette.applications import Starlette
from starlette.responses import UJSONResponse
import gpt_2_simple as gpt2
import tensorflow as tf
import uvicorn
import os
import glob
import shutil

# Directory the pretrained weights are placed in; the argument-less
# gpt2.load_gpt2(sess) call below relies on the model being here
# (gpt_2_simple's default checkpoint/run1 location).
_RUN_DIR = os.path.join('checkpoint', 'run1')

# Download only when the run directory is missing or empty. Checking the run
# directory itself (rather than just 'checkpoint') means an empty directory
# left behind by an interrupted first start is retried instead of skipped.
if not os.path.isdir(_RUN_DIR) or not os.listdir(_RUN_DIR):
    os.makedirs(_RUN_DIR, exist_ok=True)
    gpt2.download_gpt2(model_name='117M')
    # Move the downloaded files in-process instead of os.system('mv ...'):
    # portable, and a failed move raises instead of being silently ignored.
    for src in glob.glob(os.path.join('models', '117M', '*')):
        shutil.move(src, os.path.join(_RUN_DIR, os.path.basename(src)))

app = Starlette(debug=False)

# One TensorFlow session, created at import time and shared by all requests.
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)
from starlette.responses import PlainTextResponse


@app.route('/', methods=['GET'])
async def index(request):
    """Generate GPT-2 text from query-string parameters.

    Query parameters (all optional):
        length (int, default 128), temperature (float, default 0.7),
        top_k (int, default 0), top_p (float, default 0),
        prefix (str, truncated to its first 500 characters),
        truncate (str, passed through unchanged),
        include_prefix ('true' case-insensitively enables it; default true).

    Returns:
        A PlainTextResponse with the first generated sample, or a 400
        PlainTextResponse when a numeric parameter cannot be parsed.
    """
    params = request.query_params

    try:
        length = int(params.get('length', 128))
        temperature = float(params.get('temperature', 0.7))
        top_k = int(params.get('top_k', 0))
        top_p = float(params.get('top_p', 0))
    except ValueError as err:
        # Malformed client input is a client error, not an unhandled 500.
        return PlainTextResponse(
            'Invalid query parameter: {}'.format(err), status_code=400)

    # NOTE: gpt2.generate is a synchronous call made directly inside this
    # coroutine, so it blocks the event loop until generation finishes.
    text = gpt2.generate(
        sess,
        length=length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # Bound the prompt size; longer prefixes are silently truncated.
        prefix=params.get('prefix', '')[:500],
        truncate=params.get('truncate', None),
        include_prefix=params.get('include_prefix', 'true').lower() == 'true',
        return_as_list=True,
    )[0]
    # Starlette endpoints must return a Response object (a bare str cannot be
    # sent); plain text matches what the Flask server returns.
    return PlainTextResponse(text)
if __name__ == '__main__':
    # Listen on all interfaces, like the Flask server does; uvicorn's default
    # host is loopback-only, which is unreachable from outside the
    # machine/container.
    uvicorn.run(app, host='0.0.0.0', port=8080)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment