Last active
May 25, 2022 06:29
-
-
Save lando22/6f0952fa30ad68c1161fc459aeca63d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import libraries
from flask import request, jsonify, Flask, abort
import transformers
from transformers import pipeline
import humingbird  # third-party toxicity-classification client

# start flask app
app = Flask(__name__)

# download GPT-2/load the model once at startup (first run downloads weights;
# subsequent requests reuse this module-level pipeline)
generator = pipeline('text-generation', model='gpt2')
# our endpoint
@app.route("/gpt-generation", methods=["POST"])
def generate_and_filter():
    """Generate GPT-2 text for the posted prompt and flag toxic output.

    Expects a JSON body ``{"prompt": "<seed text>"}``. Returns JSON with a
    ``generated_text`` key (up to 20 tokens of GPT-2 output) and, when the
    humingbird classifier is at least 70% confident the text is toxic, an
    additional ``warning`` key.

    Aborts with 400 on a missing/malformed JSON payload and 403 on a
    non-POST method (Flask's route restriction normally rejects those
    with 405 first; the check is kept for defence in depth).
    """
    request_json = request.get_json()
    # if not working, try:
    # request_json = request.get_json(force=True)

    # if not POST request, not allowed
    if request.method != "POST":
        return abort(403)

    # Robustness: a missing body or absent 'prompt' key previously raised
    # an unhandled KeyError/TypeError (HTTP 500); reject cleanly with 400.
    if not request_json or "prompt" not in request_json:
        return abort(400)

    # we expect one JSON key, which is 'prompt'
    prompt = request_json["prompt"]

    # generated text from GPT-2 (the pipeline returns a list of dicts)
    generated_text = generator(prompt, max_length=20)[0]['generated_text']

    # lets filter it!
    # BUG FIX: the original passed the undefined name ``generation`` here,
    # which raised a NameError on every request; the generated string is
    # bound to ``generated_text`` above.
    prediction = humingbird.Text.predict(
        text=generated_text,
        labels=['toxic', 'not toxic']
    )

    return_object = {"generated_text": generated_text}

    # Find the toxic key in the prediction object
    for toxic_key in prediction[0]:
        # if we found the toxic key, get the score
        if toxic_key['className'] == 'toxic':
            # if the filter is 70% or more confident
            if toxic_key['score'] >= 0.7:
                return_object["warning"] = "WARNING: This output may contain sensitive or toxic text. Please use with caution."
            break

    return jsonify(return_object)
# Run the Flask development server, listening on all interfaces.
if __name__ == '__main__':
    app.run(host="0.0.0.0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment