Skip to content

Instantly share code, notes, and snippets.

@galleon
Created February 11, 2024 07:25
Show Gist options
  • Save galleon/9c7e4f42e58e4ab686c461b514f60080 to your computer and use it in GitHub Desktop.
import argparse
import itertools
import json
import os
import sys
import time
import uuid

import pandas as pd
from dotenv import load_dotenv
from loguru import logger
from tqdm import tqdm
# Load environment variables from .env file
load_dotenv()
# French-language retrieval-task descriptions used to seed example generation.
# One synthetic (query, positive, hard-negative) triple is requested per
# combination of task x the dimension lists below (E5-style data synthesis).
_task = [
"Trouver des articles qui comparent l'efficacité de différents régimes alimentaires pour la perte de poids.",
"Identifier des documents juridiques pertinents pour un cas spécifique de droit des brevets.",
"Extraire des études de cas sur l'utilisation de l'intelligence artificielle dans le diagnostic médical.",
"Rassembler des informations sur les dernières tendances en matière de développement durable dans l'industrie de la mode.",
"Collecter des preuves historiques soutenant ou contredisant une théorie spécifique sur l'origine des civilisations anciennes.",
"Rechercher des documents sur les impacts économiques de l'immigration dans différents pays.",
"Trouver des études sur l'efficacité des différentes méthodes d'enseignement en ligne pour les élèves du primaire.",
"Compiler des recherches sur les effets à long terme de l'exposition aux écrans sur la santé mentale des adolescents.",
"Identifier des sources qui discutent des avancées récentes dans le traitement de maladies rares.",
"Rassembler des articles examinant les conséquences sociales de l'automatisation sur l'emploi.",
"Trouver des documents qui analysent l'impact des réseaux sociaux sur la politique.",
"Extraire des informations sur les innovations en matière de recyclage des plastiques et leur efficacité.",
"Rechercher des études comparant les avantages de l'éducation à domicile par rapport à l'éducation traditionnelle.",
"Compiler des preuves sur les changements dans les modèles de migration des oiseaux en réponse au changement climatique.",
"Rassembler des analyses sur les stratégies de gestion des crises sanitaires mondiales.",
"Trouver des articles détaillant l'évolution des marchés financiers au cours de la dernière décennie.",
"Extraire des études sur l'influence de la musique sur la performance sportive.",
"Identifier des documents sur les meilleures pratiques pour la réhabilitation des écosystèmes dégradés.",
"Collecter des données sur l'utilisation de l'énergie renouvelable dans les zones urbaines.",
"Compiler des informations sur les avancées dans la recherche sur l'intelligence artificielle et l'éthique.",
]
# Prompt-variation dimensions; their full cross product drives generation.
_query_type = ["extremely long-tail", "long-tail", "common"]
# NOTE(review): the last bucket overlaps the middle one ("at least 10 words"
# includes 10-15) — presumably intended as "more than 15 words"; verify.
_query_length = ["less than 5 words", "5 to 15 words", "at least 10 words"]
_difficulty = ["high school", "college", "PhD"]
_clarity = ["clear", "understandable with some effort", "ambiguous"]
# Minimum document length (in words) requested from the model.
_num_words = [50, 100, 200, 300, 400, 500]
_language = ["English", "French"]
def get_prompt(
task, query_type, query_length, difficulty, clarity, num_words, language
):
    """Build the instruction prompt for one synthetic retrieval example.

    The model is asked to emit a single JSON object with the keys
    'user_query', 'positive_document' and 'hard_negative_document',
    constrained by the supplied style/difficulty/length dimensions.
    """
    # A named-placeholder template keeps the prompt text in one piece and
    # makes the injected dimensions easy to spot.
    template = """You have been assigned a retrieval task: {task}
Your mission is to write one text retrieval example for this task in JSON format. The JSON object must
contain the following keys:
- 'user_query': a string, a random user search query specified by the retrieval task.
- 'positive_document': a string, a relevant document for the user query.
- 'hard_negative_document': a string, a hard negative document that only appears relevant to the query.
Please adhere to the following guidelines:
- The 'user_query' should be {query_type}, {query_length}, {clarity}, and diverse in topic.
- All documents must be created independent of the query. Avoid copying the query verbatim. It’s acceptable
if some parts of the 'positive_document' are not topically related to the query.
- All documents should be at least {num_words} words long.
- The 'hard_negative_document' contains some useful information, but it should be less useful or comprehensive compared to the 'positive_document'.
- Both the query and documents should be in {language}.
- Do not provide any explanation in any document on why it is relevant or not relevant to the query.
- Both the query and documents require {difficulty} level education to understand.
Your output must always be a valid JSON object only, do not explain yourself or output anything else. Be creative!"""
    return template.format(
        task=task,
        query_type=query_type,
        query_length=query_length,
        clarity=clarity,
        num_words=num_words,
        language=language,
        difficulty=difficulty,
    )
def process_text_with_ollama(client, prompt, model="mistral", max_retries=5):
    """Generate one JSON example via the Ollama client, with retries.

    Args:
        client: an Ollama-style client exposing ``generate(model=, prompt=)``
            that returns a mapping with a ``"response"`` text field.
        prompt: the fully-rendered instruction prompt.
        model: model name to run (default ``"mistral"``).
        max_retries: attempts before giving up.

    Returns:
        The parsed JSON object, or a sentinel dict after ``max_retries``
        failures.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = client.generate(model=model, prompt=prompt)
            # Models sometimes escape underscores in JSON keys
            # (e.g. "user\_query"); undo that before parsing.
            return json.loads(response["response"].replace("\\_", "_"))
        except json.JSONDecodeError as e:
            # Narrowed from a broad Exception: only parse failures should
            # dump the raw response text.
            logger.error(f"JSON parsing error: {e}")
            logger.debug(response["response"])
        except Exception as e:
            # Was hard-coded "Mixtral processing error" regardless of the
            # model actually in use.
            logger.error(f"{model} processing error: {e}")
        logger.debug(f"Retrying (attempt {attempt}/{max_retries})...")
        time.sleep(1)
    logger.warning("Max retries reached. Moving to the next chunk.")
    # NOTE(review): sentinel keys don't match the generated schema
    # (user_query/positive_document/hard_negative_document) — kept unchanged
    # for backward compatibility with existing consumers of the CSV.
    return {"rank": None, "id": None, "text": ""}
def _save_dataset(dataset):
    """Persist the accumulated rows to e5_dataset.csv (overwrites the file)."""
    pd.DataFrame(dataset).to_csv("e5_dataset.csv", index=False)
    logger.info("E5 dataset saved to e5_dataset.csv")


def build_e5_dataset(client, selected_model):
    """Build the E5 dataset over the full cross product of prompt dimensions.

    Args:
        client: generation client passed through to
            ``process_text_with_ollama``.
        selected_model: model name chosen on the command line.

    Side effects: writes/overwrites ``e5_dataset.csv`` after each task (as a
    checkpoint) and once more at the end.
    """
    dataset = []
    start = time.time()
    dims = (_query_type, _query_length, _difficulty, _clarity, _num_words, _language)
    combos_per_task = 1
    for dim in dims:
        combos_per_task *= len(dim)
    for task_id, task in enumerate(_task):
        task_start = time.time()
        # itertools.product flattens the former 7-deep loop nest; a single
        # progress bar replaces the noisy per-level tqdm wrappers.
        for combo in tqdm(itertools.product(*dims), total=combos_per_task):
            query_type, query_length, difficulty, clarity, num_words, language = combo
            prompt = get_prompt(
                task, query_type, query_length, difficulty, clarity, num_words, language
            )
            # BUG FIX: selected_model was previously ignored, so generation
            # always used the default "mistral" even when --mixtral was chosen.
            response = process_text_with_ollama(client, prompt, model=selected_model)
            dataset.append(
                {
                    "task": task,
                    "query_type": query_type,
                    "query_length": query_length,
                    "difficulty": difficulty,
                    "clarity": clarity,
                    "num_words": num_words,
                    "language": language,
                    "prompt": prompt,
                    "response": response,
                }
            )
        print(f"Time taken for task {task_id:2d}: {time.time() - task_start:.2f} seconds")
        # Checkpoint after every task so a crash loses at most one task's work.
        _save_dataset(dataset)
    print(f"Total Time taken: {time.time() - start:.2f} seconds")
    _save_dataset(dataset)
if __name__ == "__main__":
    # Configure Loguru: errors to a rotating file, errors to the screen,
    # and full debug logs to a separate rotating file.
    logger.add("dataset_error.log", level="ERROR", rotation="10 MB")
    logger.add(sys.stdout, level="ERROR")
    logger.add("dataset_debug.log", level="DEBUG", rotation="10 MB")

    parser = argparse.ArgumentParser(description="Run a language model")
    model_group = parser.add_mutually_exclusive_group(required=True)
    model_group.add_argument(
        "--text-davinci-003",
        action="store_const",
        const="text_davinci_003",
        help="Use the text-davinci-003 model",
    )
    model_group.add_argument(
        "--mixtral", action="store_const", const="mixtral", help="Use the Mixtral model"
    )
    model_group.add_argument(
        "--mistral", action="store_const", const="mistral", help="Use the Mistral model"
    )
    args = parser.parse_args()
    # Exactly one of these is non-None thanks to the required exclusive group.
    selected_model = args.text_davinci_003 or args.mixtral or args.mistral
    if not selected_model:
        # Defensive: unreachable with required=True, kept as a safety net.
        parser.error(
            "Please select a model using --text-davinci-003, --mixtral, or --mistral"
        )
    logger.info(f"Using the {selected_model} model")
    openai_api_key = os.getenv("OPENAI_API_KEY")
    # BUG FIX: these comparisons previously used "text-davinci-003" (hyphens)
    # while the argparse const is "text_davinci_003" (underscores), so the
    # OpenAI branch and the API-key check were unreachable.
    if not openai_api_key and selected_model == "text_davinci_003":
        logger.error(
            "No OpenAI API key found. Please set the OPENAI_API_KEY environment variable. Exiting..."
        )
        exit(1)
    client = None
    try:
        if selected_model == "text_davinci_003":
            from openai import OpenAI

            # BUG FIX: the OpenAI v1 client requires the key as a keyword
            # argument; positional OpenAI(openai_api_key) raises TypeError.
            client = OpenAI(api_key=openai_api_key)
            logger.info("Using OpenAI client for text processing.")
        else:
            from ollama import Client

            client = Client(host="http://localhost:11434")
            logger.info("Using Ollama client for text processing.")
        build_e5_dataset(client, selected_model)
    except Exception as e:
        logger.error(f"An error occurred: {e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment