Created February 7, 2024 19:54
Save vwxyzjn/8879308fffd8f3d70a25144ca16b046f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from scipy.special import softmax

# Deita quality scorer: a causal LM fine-tuned so its next token after the
# prompt below is a digit 1-6 rating the response quality.
model_name = "hkust-nlp/deita-quality-scorer"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def infer_quality(model, tokenizer, input_text, output_text):
    """Score the quality of *output_text* as a response to *input_text*.

    Prompts the Deita scorer, takes the logits of the single next token,
    restricts them to the six digit tokens "1".."6", and returns the
    expected score under the softmax over those six logits — a float in
    [1.0, 6.0].
    """
    quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ")
    user_input = quality_template.format(instruction=input_text, output=output_text)
    input_ids = tokenizer.encode(user_input, return_tensors="pt")
    # Only the first generated token's logits are ever read (scores[0]), so
    # generate exactly one new token. The original max_length=512 generated
    # up to ~500 unused tokens and would truncate/fail on long prompts.
    outputs = model.generate(
        input_ids,
        max_new_tokens=1,
        num_return_sequences=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # Vocab-sized logits for the first (and only) generated token.
    logits = outputs.scores[0][0]
    # Llama tokenizer ids of the single-character digit tokens "1".."6".
    id2score = {
        29896: "1",
        29906: "2",
        29941: "3",
        29946: "4",
        29945: "5",
        29953: "6",
    }
    score_template = np.array([1, 2, 3, 4, 5, 6])
    # Dicts preserve insertion order, so this lines up with score_template.
    score_logits = np.array([float(logits[token_id]) for token_id in id2score])
    probs = softmax(score_logits, axis=0)
    # Expectation of the score under the distribution over the six digits.
    return float(np.sum(probs * score_template))
# Demo: for each prompt, score a pair of responses (a sensible one first,
# then a weaker/wrong one) so the quality gap is visible in the output.
demo_cases = [
    ("word to describe UI with helpful tooltips",
     ["User-friendly or intuitive UI", "I am sorry I can't help"]),
    ("What is 123 + 10",
     ["133", "130"]),
    ("Why is it not recommended to use your phone while driving?",
     ["Because it's dangerous.",
      "Because you will receive a new expensive laptop that you don't need."]),
    ("Tom has four apples and gives two to his friend. How many apples does Tom have now?",
     ["Two apples", "Three apples."]),
]
last_case = len(demo_cases) - 1
for case_idx, (input_text, responses) in enumerate(demo_cases):
    for resp_idx, output_text in enumerate(responses):
        quality_score = infer_quality(model, tokenizer, input_text, output_text)
        if resp_idx == 0:
            # First response of a pair: echo the prompt as well.
            print(f"{input_text=}\n{output_text=} {quality_score=}")
        elif case_idx == last_case:
            # The final line of the script carries no trailing blank line.
            print(f"{output_text=} {quality_score=}")
        else:
            print(f"{output_text=} {quality_score=}\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.