Skip to content

Instantly share code, notes, and snippets.

@lando22
Last active May 25, 2022 06:29
Show Gist options
  • Save lando22/6f0952fa30ad68c1161fc459aeca63d5 to your computer and use it in GitHub Desktop.
Save lando22/6f0952fa30ad68c1161fc459aeca63d5 to your computer and use it in GitHub Desktop.
# import libraries
from flask import request, jsonify, Flask, abort
import transformers
from transformers import pipeline
import humingbird
# Flask application object; the route below is registered on it.
app = Flask(import_name=__name__)

# Load GPT-2 (downloading it if necessary) once at import time, so every
# request reuses the same text-generation pipeline instead of rebuilding it.
generator = pipeline(task="text-generation", model="gpt2")
# Generation / filtering settings (the original code inlined these literals).
_MAX_LENGTH = 20  # forwarded to the pipeline as max_length
_TOXICITY_THRESHOLD = 0.7  # minimum 'toxic' score that triggers the warning
_TOXICITY_WARNING = (
    "WARNING: This output may contain sensitive or toxic text. "
    "Please use with caution."
)


def _is_toxic(prediction, threshold=_TOXICITY_THRESHOLD):
    """Return True if the 'toxic' label scored at or above *threshold*.

    NOTE(review): assumes humingbird.Text.predict returns a sequence whose
    first element is an iterable of {'className': ..., 'score': ...} dicts,
    which is the shape the original loop relied on. Confirm this against the
    humingbird documentation.
    """
    return any(
        entry["className"] == "toxic" and entry["score"] >= threshold
        for entry in prediction[0]
    )


# our endpoint
@app.route("/gpt-generation", methods=["POST"])
def generate_and_filter():
    """Generate text from a JSON prompt with GPT-2 and flag toxic output.

    Expects a JSON object body of the form ``{"prompt": "<text>"}``.

    Returns:
        200 with ``{"generated_text": ...}``. A ``"warning"`` key is added
        when the classifier's 'toxic' score is >= _TOXICITY_THRESHOLD.
        400 with ``{"error": ...}`` when the body is not a JSON object, or
        when it lacks a non-empty string ``prompt``.

    Non-POST requests never reach this function. The route only accepts
    POST, so Flask answers other methods with 405 on its own.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so every bad request gets the same 400 below. If clients omit the
    # JSON Content-Type header, use request.get_json(force=True, silent=True).
    request_json = request.get_json(silent=True)
    prompt = request_json.get("prompt") if isinstance(request_json, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return (
            jsonify({"error": "Request body must be JSON with a non-empty string 'prompt'."}),
            400,
        )

    # The pipeline returns a list of candidates; use the first one's text.
    generated_text = generator(prompt, max_length=_MAX_LENGTH)[0]["generated_text"]

    # Run the toxicity classifier on the text we just generated.
    prediction = humingbird.Text.predict(
        text=generated_text,
        labels=["toxic", "not toxic"],
    )

    return_object = {"generated_text": generated_text}
    if _is_toxic(prediction):
        return_object["warning"] = _TOXICITY_WARNING
    return jsonify(return_object)
if __name__ == "__main__":
    # Start Flask's built-in development server. host="0.0.0.0" listens on
    # all network interfaces rather than only localhost. For production,
    # serve `app` through a proper WSGI server instead.
    app.run(host="0.0.0.0")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment