-
-
Save sravantit25/77700a2ce308134bfbf0c13fcdd31803 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script uses ChatGPT, a conversational language model, | |
to generate a reply based on a user message. | |
It takes a user message as input and utilizes ChatGPT to generate a response. | |
Uses gpt-4-1106-preview model. An OpenAI API key is required. | |
""" | |
import os | |
import openai | |
from openai import OpenAI | |
# API key is read from the environment; os.getenv returns None if the
# "API_KEY" variable is unset, in which case the client is still built but
# requests will fail with an authentication error at call time.
API_KEY = os.getenv("API_KEY")
# Single shared client instance, reused by get_response() below.
CLIENT = OpenAI(api_key=API_KEY)
def get_response(input_message, system_prompt, max_tokens=None):
    """Return the ChatGPT reply for *input_message*.

    Sends a chat-completion request (model gpt-4-1106-preview, JSON
    response format) with *system_prompt* as the system message and
    *input_message* as the user message.

    Args:
        input_message: The user message to send to the model.
        system_prompt: The system instruction given to the model.
        max_tokens: Optional cap on generated tokens (None = API default).

    Returns:
        The model's reply content as a string, or None if the request
        failed (timeout, connection failure, or invalid request).
    """
    # Initialize so the function returns a defined value even when the
    # request raises before the assignment below (previously this path
    # hit an UnboundLocalError at `return response`).
    response = None
    try:
        completion = CLIENT.chat.completions.create(
            model="gpt-4-1106-preview",
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": input_message},
            ],
            max_tokens=max_tokens,
        )
        print("Input message: ", input_message)
        response = completion.choices[0].message.content
        print("GPT response: ", response)
    # The openai>=1.0 SDK (the client API used above) exposes these
    # exception classes at the top level; the legacy openai.error.* names
    # were removed and would raise AttributeError here.
    except openai.APITimeoutError as error:
        print(f"OpenAI API request timed out: {error}")
    except openai.APIConnectionError as error:
        print(f"OpenAI API request failed to connect: {error}")
    except openai.BadRequestError as error:
        print(f"OpenAI API request was invalid: {error}")
    return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Data validation test for performing numerical validation using GPT-4. | |
The test validates that all the numbers in the provided dataset are between 0.0 and 1.0 | |
""" | |
import sys | |
import os | |
import json | |
import ast | |
import unittest | |
import pandas as pd | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
from get_gpt_response import get_response | |
class TestGithubScores(unittest.TestCase):
    """LLM-backed numerical validation of GitHub repo scores.

    Loads the 'repo_score' column from a CSV dataset and asks GPT-4 to
    confirm every value lies between 0.0 and 1.0 (inclusive).
    """

    # Original hard-coded location, kept as the default so existing runs
    # behave identically; override with the GITHUB_SCORES_CSV environment
    # variable so the test is not tied to one machine's directory layout.
    DEFAULT_DATA_PATH = (
        "/home/sravanti/project/QElo/tests/data_validation/"
        "great_expectations/llm_tests/github_scores.csv"
    )

    def setUp(self) -> None:
        """
        Set the system prompt for the LLM and get the dataset
        """
        self.system_prompt = '''
        You are an expert in numerical data validation. You will receive a dataset formatted as {"dataset": [list of numbers]}. The list of numbers are all float values. Your task is to analyze the dataset and determine its validity based on the following criteria:
        - All values in the dataset must be between 0.0 and 1.0 (inclusive).
        Return a JSON object with two keys:
        1. "valid": true if the dataset meets the criteria, false otherwise
        2. "failed_values": a list containing numbers that do not satisfy the condition that is numbers that are less than 0.0 or greater than 1.0
        '''
        # Env-var override generalizes the previously machine-specific path.
        fname = os.getenv("GITHUB_SCORES_CSV", self.DEFAULT_DATA_PATH)
        df = pd.read_csv(fname)
        self.repo_scores = df['repo_score']

    def test_repo_score_boundaries(self):
        """
        Test to validate that each repo score should be between 0 and 1
        """
        input_data = {"dataset": self.repo_scores.tolist()}
        input_message = json.dumps(input_data)
        llm_response = get_response(input_message, self.system_prompt)
        # get_response() returns None when the API call fails; fail with a
        # clear assertion message instead of a TypeError from json.loads.
        self.assertIsNotNone(llm_response, "No response received from the LLM")
        llm_response = json.loads(llm_response)
        expected_output = {"valid": True, "failed_values": []}
        self.assertEqual(expected_output, llm_response)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment