Skip to content

Instantly share code, notes, and snippets.

@likejazz
Created January 8, 2024 04:44
Show Gist options
  • Save likejazz/291c5eb183853b33ae49a3d92af28bcd to your computer and use it in GitHub Desktop.
Save likejazz/291c5eb183853b33ae49a3d92af28bcd to your computer and use it in GitHub Desktop.
trlX + pipeline
# Example client request (host and port are redacted): PUT a JSON body with a
# "prompt" field to the server below; `jq` pretty-prints the JSON reply.
$ curl -X PUT http://XXX.XXX.XXX.XXX:XXXX/ \
-H 'Content-Type: application/json' \
-d '{"prompt": "this movie was sucks!"}' | jq
import json
from transformers import pipeline
from flask import Flask, request, jsonify
app = Flask(__name__)

# Sentiment classifier on CUDA device 0, used as the reward model (`rm`) that
# scores each generated text in the request handler.
rm = pipeline(task='sentiment-analysis', model="lvwerra/distilbert-imdb", device=0)

# Generation models, one per CUDA device 1..3, keyed by the names in
# `model_name` (pairing is positional). The "XXXX" checkpoints are redacted
# placeholders in this snippet.
model_name = ['gpt', 'sft', 'ppo']
model = dict(zip(model_name, [
    pipeline(model="XXXX", device=1),
    pipeline(model="XXXX", device=2),
    pipeline(model="XXXX", device=3),
]))
def _positive_score(prediction):
    """Return the probability of the POSITIVE label from one `rm` prediction.

    ``prediction`` is a single ``{"label": ..., "score": ...}`` dict from the
    sentiment pipeline. The pipeline reports the probability of its top label
    only, so when that label is not POSITIVE the complement is returned.
    Assumes a binary POSITIVE/NEGATIVE classifier -- TODO confirm if the
    reward model changes.
    """
    if prediction['label'] == 'POSITIVE':
        return prediction['score']
    return 1 - prediction['score']


@app.route('/', methods=['GET', 'PUT'])
def index():
    """Generate a completion from every model and score its sentiment.

    Expects a JSON body ``{"prompt": "<text>"}``. For each name in
    ``model_name``, the prompt is passed to that model's text-generation
    pipeline with sampling (``do_sample=True, top_p=0.95``). The first
    generated text is then scored by the ``rm`` sentiment pipeline.

    Returns:
        A JSON list whose first item is ``{"model": "megatron-gpt2-345m"}``.
        It is followed by one ``{"<NAME>": generated_text, "score": p}`` item
        per model, where ``<NAME>`` is the upper-cased model name and ``p``
        is the reward model's probability that the text is POSITIVE.

        Responds with HTTP 400 and ``{"error": ...}`` when the body is not a
        JSON object with a string ``prompt`` field.
    """
    # force=True parses the body regardless of Content-Type, as the previous
    # json.loads(request.data) did. silent=True yields None instead of
    # raising on an empty/malformed body (e.g. a bare GET), so the client
    # gets a 400 rather than a 500.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('prompt'), str):
        return jsonify({'error': 'request body must be JSON with a string "prompt" field'}), 400
    prompt = data['prompt']

    results = [{'model': 'megatron-gpt2-345m'}]
    for name in model_name:
        generated = model[name](prompt, do_sample=True, top_p=0.95)[0]['generated_text']
        results.append({
            name.upper(): generated,
            'score': _positive_score(rm(generated)[0]),
        })
    return jsonify(results)
if __name__ == '__main__':
    # Start the development server only when executed as a script, so that
    # importing this module (e.g. from a WSGI server) does not block.
    import os

    # The original port was redacted ("XXXX"), so read it from $PORT instead.
    # TODO(review): confirm the intended default port.
    # host='0.0.0.0' listens on all interfaces, exposing the GPU-backed
    # endpoint to the network.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '5000')))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment