Created
May 15, 2025 02:04
-
-
Save shshjhjh4455/f78c405e72b95e35a340a34880dee7d1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
HyDE 라이브러리를 사용하여 agriculture 도메인만 처리하는 단순화된 스크립트 | |
""" | |
import os | |
import json | |
import sys | |
import numpy as np | |
import asyncio | |
from tqdm import tqdm | |
import tiktoken | |
from typing import List, Dict, Tuple, Any | |
from openai import AsyncOpenAI, OpenAI | |
# HyDE 레포지토리 경로 추가 | |
hyde_repo_path = os.path.join(os.getcwd(), "hyde_repo/src") | |
sys.path.insert(0, hyde_repo_path) | |
print(f"추가된 HyDE 레포지토리 경로: {hyde_repo_path}") | |
# HyDE 클래스 임포트 | |
from hyde.hyde import HyDE | |
from hyde.promptor import Promptor | |
from hyde.generator import OpenAIGenerator | |
# 출력 디렉토리 | |
OUTPUT_DIR = "./evaluation_results_ksh" | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
# 고유 문서 컨텍스트 JSON 파일 | |
DATASET_FILE = "./datasets/unique_contexts/agriculture_unique_contexts.json" | |
# 질문 파일 (JSON) | |
QUESTIONS_FILE = "./Answer_ksh/agriculture_questions.json" | |
# 직접 API 키 설정 | |
os.environ["OPENAI_API_KEY"] = "" | |
# 초기화 | |
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) | |
async_client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"]) | |
# 문서 청킹 관련 상수 및 유틸리티 함수 | |
CHUNK_SIZE = 128 # 각 청크의 최대 토큰 수 | |
CHUNK_OVERLAP = 0 # 청크 간 겹치는 토큰 수 | |
DEFAULT_ENCODING_MODEL = "text-embedding-3-small" | |
K_RETRIEVAL = 10 | |
def get_encoding(model: str = DEFAULT_ENCODING_MODEL): | |
"""지정된 모델에 맞는 토큰 인코딩 객체 반환""" | |
try: | |
encoding = tiktoken.encoding_for_model(model) | |
except KeyError: | |
print( | |
f"경고: {model}에 대한 토큰화 정보를 찾을 수 없습니다. cl100k_base를 사용합니다." | |
) | |
encoding = tiktoken.get_encoding("cl100k_base") | |
return encoding | |
def chunk_text_by_tokens( | |
text: str, | |
chunk_size: int = CHUNK_SIZE, | |
chunk_overlap: int = CHUNK_OVERLAP, | |
model: str = DEFAULT_ENCODING_MODEL, | |
) -> List[Dict[str, Any]]: | |
"""텍스트를 지정된 토큰 수에 맞게 청크로 나눔""" | |
if not text: | |
return [] | |
encoding = get_encoding(model) | |
tokens = encoding.encode(text) | |
# 청크 나누기 | |
chunks = [] | |
i = 0 | |
while i < len(tokens): | |
# 현재 청크의 토큰 범위 결정 | |
chunk_end = min(i + chunk_size, len(tokens)) | |
# 현재 청크 토큰 추출 | |
chunk_tokens = tokens[i:chunk_end] | |
chunk_text = encoding.decode(chunk_tokens) | |
# 청크 메타데이터 저장 | |
chunks.append( | |
{ | |
"text": chunk_text, | |
"tokens": len(chunk_tokens), | |
"start_idx": i, | |
"end_idx": chunk_end - 1, | |
} | |
) | |
# 다음 시작점 계산 (청크 겹침 고려) | |
if chunk_end == len(tokens): | |
break | |
i = chunk_end - chunk_overlap | |
print( | |
f"텍스트를 {len(chunks)}개의 청크로 분할 (평균 {sum(c['tokens'] for c in chunks) / len(chunks):.1f} 토큰/청크)" | |
) | |
return chunks | |
def chunk_documents( | |
documents: List[str], | |
chunk_size: int = CHUNK_SIZE, | |
chunk_overlap: int = CHUNK_OVERLAP, | |
model: str = DEFAULT_ENCODING_MODEL, | |
) -> List[Dict[str, Any]]: | |
"""문서 리스트를 청크로 나눔""" | |
all_chunks = [] | |
for doc_idx, doc in enumerate(documents): | |
doc_chunks = chunk_text_by_tokens(doc, chunk_size, chunk_overlap, model) | |
# 문서 ID 및 원본 문서 인덱스 추가 | |
for i, chunk in enumerate(doc_chunks): | |
chunk["doc_id"] = f"doc_{doc_idx}" | |
chunk["doc_idx"] = doc_idx | |
chunk["chunk_idx"] = i | |
chunk["doc_chunks_total"] = len(doc_chunks) | |
all_chunks.extend(doc_chunks) | |
print(f"총 {len(documents)}개 문서를 {len(all_chunks)}개 청크로 분할") | |
return all_chunks | |
# OpenAI 임베딩을 위한 클래스 | |
class OpenAIEncoder: | |
def __init__(self, model_name="text-embedding-3-small"): | |
self.model_name = model_name | |
self.client = OpenAI( | |
api_key=os.environ["OPENAI_API_KEY"] | |
) # 직접 지정한 API 키 사용 | |
# 모델 차원 설정 - 모델에 따라 차원 다름 | |
self.dimensions = 1536 # text-embedding-3-small은 1536 차원 | |
if model_name == "text-embedding-3-large": | |
self.dimensions = 3072 # text-embedding-3-large는 3072 차원 | |
# 임베딩 모델 최대 토큰 수 | |
self.max_tokens = 8000 # 여유 있게 8,000으로 설정 (8,192 제한) | |
# 토큰화 엔진 | |
self.encoding = get_encoding(model_name) | |
def encode(self, text): | |
"""텍스트를 임베딩 벡터로 변환 (단일 텍스트)""" | |
# 토큰 수 계산 | |
tokens = self.encoding.encode(text) | |
# 토큰 수가 제한을 초과하면 자르기 | |
if len(tokens) > self.max_tokens: | |
print( | |
f"임베딩을 위한 토큰 수가 제한을 초과하여 자름: {len(tokens)} -> {self.max_tokens}" | |
) | |
text = self.encoding.decode(tokens[: self.max_tokens]) | |
response = self.client.embeddings.create(model=self.model_name, input=text) | |
return response.data[0].embedding | |
def encode_batch(self, texts, batch_size=16): | |
"""텍스트 배치를 임베딩 벡터로 변환 (API 요청 최적화)""" | |
# 각 텍스트의 토큰 수 제한 적용 | |
processed_texts = [] | |
for text in texts: | |
tokens = self.encoding.encode(text) | |
if len(tokens) > self.max_tokens: | |
print( | |
f"임베딩을 위한 토큰 수가 제한을 초과하여 자름: {len(tokens)} -> {self.max_tokens}" | |
) | |
text = self.encoding.decode(tokens[: self.max_tokens]) | |
processed_texts.append(text) | |
# 배치 처리 (API 할당량 고려) | |
all_embeddings = [] | |
for i in range(0, len(processed_texts), batch_size): | |
batch = processed_texts[i : i + batch_size] | |
response = self.client.embeddings.create(model=self.model_name, input=batch) | |
batch_embeddings = [data.embedding for data in response.data] | |
all_embeddings.extend(batch_embeddings) | |
return all_embeddings | |
def encode_chunks(self, chunks): | |
"""청크 목록에 대한 임베딩 생성""" | |
# 청크 텍스트만 추출 | |
chunk_texts = [chunk["text"] for chunk in chunks] | |
# 배치 임베딩 생성 | |
embeddings = self.encode_batch(chunk_texts) | |
# 임베딩 결과를 청크 객체에 추가 | |
for i, embedding in enumerate(embeddings): | |
chunks[i]["embedding"] = embedding | |
return chunks | |
# FAISS 벡터 검색을 위한 클래스 | |
class FAISSSearcher: | |
def __init__(self, dim=1536): | |
import faiss | |
self.chunks = [] # 모든 청크 저장 (메타데이터 포함) | |
self.index = faiss.IndexFlatL2(dim) # 벡터 차원 크기 | |
self.expected_dim = dim # 예상되는 임베딩 차원 | |
def add(self, vector, document): | |
"""단일 벡터와 문서를 인덱스에 추가 (하위 호환성 유지)""" | |
chunk = { | |
"text": document, | |
"doc_id": f"doc_{len(self.chunks)}", | |
"is_chunk": False, | |
} | |
self.add_chunk(vector, chunk) | |
def add_chunk(self, vector, chunk): | |
"""벡터와 청크를 인덱스에 추가""" | |
if isinstance(vector, list): | |
vector = np.array(vector) | |
# 차원 검증 | |
if len(vector.shape) == 1: | |
if vector.shape[0] != self.expected_dim: | |
print( | |
f"경고: 예상 차원({self.expected_dim})과 다른 차원({vector.shape[0]})의 벡터입니다. 건너뜁니다." | |
) | |
return | |
vector = vector.reshape(1, -1) | |
elif vector.shape[1] != self.expected_dim: | |
print( | |
f"경고: 예상 차원({self.expected_dim})과 다른 차원({vector.shape[1]})의 벡터입니다. 건너뜁니다." | |
) | |
return | |
# 유효한 벡터인지 확인 (NaN 또는 무한대 값 체크) | |
if not np.all(np.isfinite(vector)): | |
print( | |
"경고: 유효하지 않은 벡터입니다 (NaN 또는 무한대 값 포함). 건너뜁니다." | |
) | |
return | |
try: | |
# 청크 인덱스 추가 | |
chunk["chunk_idx_in_searcher"] = len(self.chunks) | |
self.chunks.append(chunk) | |
self.index.add(vector) | |
except Exception as e: | |
print(f"벡터 추가 중 오류 발생: {e}") | |
print(f"벡터 shape: {vector.shape}, 타입: {type(vector)}") | |
# 오류 발생 시 마지막에 추가된 청크 제거 | |
if len(self.chunks) > 0: | |
self.chunks.pop() | |
def add_chunks(self, chunks_with_embeddings): | |
"""임베딩을 포함한 청크 배열을 인덱스에 추가""" | |
added_count = 0 | |
for chunk in chunks_with_embeddings: | |
if "embedding" not in chunk: | |
print( | |
f"경고: 청크에 임베딩이 없습니다. 청크 ID: {chunk.get('doc_id', 'unknown')}" | |
) | |
continue | |
embedding = chunk["embedding"] | |
self.add_chunk(embedding, chunk) | |
added_count += 1 | |
print( | |
f"총 {added_count}/{len(chunks_with_embeddings)}개 청크를 인덱스에 추가했습니다." | |
) | |
return added_count | |
def search(self, query_vector, k=10): | |
"""쿼리 벡터와 가장 가까운 k개의 청크 검색""" | |
if isinstance(query_vector, list): | |
query_vector = np.array(query_vector) | |
# 벡터 차원 검증 | |
if len(query_vector.shape) == 1: | |
if query_vector.shape[0] != self.expected_dim: | |
print( | |
f"검색 오류: 예상 차원({self.expected_dim})과 다른 차원({query_vector.shape[0]})의 쿼리 벡터입니다." | |
) | |
return [] # 빈 결과 반환 | |
query_vector = query_vector.reshape(1, -1) | |
elif query_vector.shape[1] != self.expected_dim: | |
print( | |
f"검색 오류: 예상 차원({self.expected_dim})과 다른 차원({query_vector.shape[1]})의 쿼리 벡터입니다." | |
) | |
return [] # 빈 결과 반환 | |
# 유효한 벡터인지 확인 | |
if not np.all(np.isfinite(query_vector)): | |
print("검색 오류: 유효하지 않은 쿼리 벡터입니다 (NaN 또는 무한대 값 포함).") | |
return [] # 빈 결과 반환 | |
# 인덱스가 비어있는지 확인 | |
if self.index.ntotal == 0: | |
print("검색 오류: 인덱스가 비어있습니다.") | |
return [] # 빈 결과 반환 | |
try: | |
# 검색 실행 | |
D, I = self.index.search(query_vector, min(k, self.index.ntotal)) | |
# 결과 반환 (원본 확장 포맷 유지 + 청크 메타데이터 추가) | |
results = [] | |
for i, d in zip(I[0], D[0]): | |
if i >= 0 and i < len(self.chunks): | |
chunk = self.chunks[int(i)] | |
results.append((int(i), float(d), chunk["text"], chunk)) | |
return results | |
except Exception as e: | |
print(f"검색 중 오류 발생: {e}") | |
print( | |
f"쿼리 벡터 shape: {query_vector.shape}, 인덱스 크기: {self.index.ntotal}" | |
) | |
return [] # 오류 발생 시 빈 결과 반환 | |
def load_documents_from_dataset(): | |
"""unique_contexts/agriculture_unique_contexts.json 파일에서 고유 문서 로드""" | |
print( | |
"unique_contexts/agriculture_unique_contexts.json 파일에서 고유 문서 로드 중..." | |
) | |
# JSON 파일에서 전체 문서 리스트 로드 | |
with open(DATASET_FILE, "r", encoding="utf-8") as f: | |
documents = json.load(f) | |
print( | |
f"{len(documents)}개의 고유 문서를 로드했습니다 (agriculture_unique_contexts.json에서 직접 로드)" | |
) | |
return documents | |
def load_questions_from_file(): | |
"""Answer_ksh/agriculture_questions.json 파일에서 질문 로드""" | |
print("Answer_ksh/agriculture_questions.json 파일에서 질문 로드 중...") | |
with open(QUESTIONS_FILE, "r", encoding="utf-8") as f: | |
data = json.load(f) | |
# 각 아이템의 'query' 필드에서 질문 추출 | |
questions = [item["query"] for item in data if "query" in item] | |
print(f"총 {len(questions)}개의 질문을 로드했습니다.") | |
return questions | |
async def generate_hyde_response(query, hyde): | |
"""HyDE를 사용하여 응답 생성 (비동기 처리) - 프롬프트+컨텍스트 정확히 10만 토큰으로 제한""" | |
try: | |
# HyDE 검색 수행 (전체 데이터베이스에서 가장 관련성 높은 청크 k개 검색) | |
k_retrieval = K_RETRIEVAL # 검색할 청크 수 | |
hits = hyde.e2e_search(query, k=k_retrieval) | |
# 검색 결과가 없는 경우 처리 | |
if not hits or len(hits) == 0: | |
print("검색된 청크가 없습니다. 빈 응답을 반환합니다.") | |
return {"query": query, "result": "관련 정보를 찾을 수 없습니다."} | |
# 모델과 토큰화 설정 | |
import tiktoken | |
encoding = tiktoken.encoding_for_model("gpt-4o-mini") | |
# 최대 토큰 수 설정 | |
max_tokens = ( | |
100000 # 프롬프트+컨텍스트 합쳐서 최대 10만 토큰으로 제한 (사용자 지정) | |
) | |
model_max_tokens = 125000 # 실제 모델 제한 (GPT-4o-mini는 125k 토큰) | |
max_tokens = min(max_tokens, model_max_tokens) # 더 작은 값 사용 | |
# 프롬프트 템플릿 (Answer: 부분 포함) | |
prompt_template = f""" | |
---Role--- | |
You are a helpful assistant responding to user query | |
---Goal--- | |
Generate a concise response based on the following information and follow Response Rules. Do not include information not provided by following Information | |
---User Query--- | |
{query} | |
---Information--- | |
{{contexts}} | |
---Response Rules--- | |
- Use markdown formatting with appropriate section headings | |
- Please respond in the same language as the user's question. | |
- Ensure the response maintains continuity with the conversation history. | |
- If you don't know the answer, just say so. | |
- Do not make anything up. Do not include information not provided by the Infromation. | |
Answer: | |
""" | |
# 컨텍스트 없는 프롬프트의 토큰 수 계산 | |
prompt_without_context = prompt_template.format(contexts="") | |
prompt_tokens = len(encoding.encode(prompt_without_context)) | |
# 컨텍스트에 사용 가능한 토큰 수 계산 | |
available_context_tokens = max_tokens - prompt_tokens | |
print( | |
f"프롬프트 토큰 수: {prompt_tokens}, 컨텍스트에 사용 가능한 토큰 수: {available_context_tokens}" | |
) | |
# 청크 정보 추출 및 구조화 | |
context_chunks = [] | |
for hit in hits: | |
if len(hit) >= 4 and isinstance( | |
hit[3], dict | |
): # 청크 메타데이터가 있는 경우 | |
chunk_data = hit[3] | |
chunk_text = hit[2] | |
chunk_score = hit[1] # 유사도 점수 (낮을수록 관련성 높음) | |
chunk_tokens = len(encoding.encode(chunk_text)) | |
context_chunks.append( | |
{ | |
"text": chunk_text, | |
"tokens": chunk_tokens, | |
"score": chunk_score, | |
"metadata": chunk_data, | |
} | |
) | |
else: # 기존 방식으로 반환된 경우 | |
chunk_text = hit[2] | |
chunk_score = hit[1] | |
chunk_tokens = len(encoding.encode(chunk_text)) | |
context_chunks.append( | |
{ | |
"text": chunk_text, | |
"tokens": chunk_tokens, | |
"score": chunk_score, | |
"metadata": {}, | |
} | |
) | |
# 최적의 컨텍스트 선택 (점수 기반 정렬 + 토큰 제한 적용) | |
selected_chunks = [] | |
total_tokens = 0 | |
# 점수 기반 정렬 (낮을수록 관련성 높음) | |
sorted_chunks = sorted(context_chunks, key=lambda x: x["score"]) | |
# HyDE 라이브러리 로직을 따라 단순히 상위 k개의 결과만 사용 | |
# 각 청크의 토큰 수를 계산하면서 토큰 제한 내에서 가능한 한 많이 포함 | |
for chunk in sorted_chunks: | |
# 청크 추가 시 토큰 제한을 초과하는지 확인 | |
if total_tokens + chunk["tokens"] <= available_context_tokens: | |
selected_chunks.append(chunk) | |
total_tokens += chunk["tokens"] | |
else: | |
# 토큰 제한을 초과하면 더 이상 추가하지 않음 | |
break | |
print(f"검색된 {len(context_chunks)}개 청크 중 {len(selected_chunks)}개 선택됨") | |
# 선택된 청크를 문서 ID 기준으로 그룹화하고 정렬 | |
from collections import defaultdict | |
chunks_by_doc = defaultdict(list) | |
for chunk in selected_chunks: | |
doc_id = chunk["metadata"].get("doc_id", "unknown") | |
chunks_by_doc[doc_id].append(chunk) | |
# 선택된 청크의 문서별 통계 출력 | |
print("\n선택된 청크의 문서별 분포:") | |
for doc_id, chunks in chunks_by_doc.items(): | |
print(f" - {doc_id}: {len(chunks)}개 청크 선택") | |
# 각 문서 내의 청크를 순서대로 정렬 (chunk_idx 기준) | |
for doc_id in chunks_by_doc: | |
chunks_by_doc[doc_id].sort(key=lambda x: x["metadata"].get("chunk_idx", 0)) | |
# 컨텍스트 구성 (문서별로 그룹화) | |
context_sections = [] | |
for doc_id, doc_chunks in chunks_by_doc.items(): | |
doc_texts = [chunk["text"] for chunk in doc_chunks] | |
doc_section = f"--- Document: {doc_id} ---\n" + "\n".join(doc_texts) | |
context_sections.append(doc_section) | |
# 최종 컨텍스트 결합 | |
contexts = "\n\n".join(context_sections) | |
# 문자 길이가 아니라 토큰으로 정확히 계산 | |
contexts_tokens = len(encoding.encode(contexts)) | |
while contexts_tokens > available_context_tokens: | |
# 컨텍스트가 여전히 너무 길면 마지막 섹션 제거 | |
if len(context_sections) > 1: | |
context_sections.pop() | |
contexts = "\n\n".join(context_sections) | |
contexts_tokens = len(encoding.encode(contexts)) | |
else: | |
# 마지막 섹션만 남은 경우 부분 자르기 | |
last_section = context_sections[0] | |
reduction = int(len(last_section) * 0.1) # 10% 줄이기 | |
context_sections[0] = last_section[:-reduction] | |
contexts = context_sections[0] | |
contexts_tokens = len(encoding.encode(contexts)) | |
# 최종 프롬프트 구성 | |
prompt = prompt_template.format(contexts=contexts) | |
total_tokens = len(encoding.encode(prompt)) | |
# 응답 생성 | |
response = await async_client.chat.completions.create( | |
model="gpt-4o-mini", messages=[{"role": "user", "content": prompt}] | |
) | |
result = response.choices[0].message.content | |
return {"query": query, "result": result} | |
except Exception as e: | |
print(f"HyDE 응답 생성 중 오류 발생: {e}") | |
print(f"오류 상세 정보: {str(e)}") | |
return {"query": query, "result": f"오류: {str(e)}"} | |
async def process_queries_sequential(queries, hyde): | |
"""모든 쿼리를 순차적으로 처리하며 진행 상황과 예상 시간을 표시""" | |
print(f"총 {len(queries)}개 쿼리 순차 처리 시작") | |
total = len(queries) | |
results = [] | |
# 시간 측정용 변수 | |
import time | |
start_time = time.time() | |
processing_times = [] | |
# 각 쿼리를 순서대로 하나씩 처리 | |
for idx, query in enumerate(queries): | |
query_start_time = time.time() | |
print(f"쿼리 {idx+1}/{total} 처리 중...") | |
try: | |
# 순차적으로 응답 생성 | |
result = await generate_hyde_response(query, hyde) | |
results.append(result) | |
# 처리 시간 계산 | |
query_time = time.time() - query_start_time | |
processing_times.append(query_time) | |
# 평균 처리 시간 계산 | |
avg_time = sum(processing_times) / len(processing_times) | |
# 남은 쿼리 수 | |
remaining_queries = total - (idx + 1) | |
# 예상 남은 시간 (초 단위) | |
remaining_time = avg_time * remaining_queries | |
# 시간 형식으로 변환 | |
hours, remainder = divmod(remaining_time, 3600) | |
minutes, seconds = divmod(remainder, 60) | |
if hours > 0: | |
time_format = f"{int(hours)}시간 {int(minutes)}분 {int(seconds)}초" | |
elif minutes > 0: | |
time_format = f"{int(minutes)}분 {int(seconds)}초" | |
else: | |
time_format = f"{int(seconds)}초" | |
print( | |
f"쿼리 {idx+1}/{total} 완료 ({(idx+1)/total*100:.1f}%) - 처리 시간: {query_time:.1f}초, 남은 예상 시간: {time_format}" | |
) | |
except Exception as e: | |
print(f"쿼리 {idx+1}/{total} 처리 중 오류: {e}") | |
results.append({"query": query, "result": f"오류: {str(e)}"}) | |
# 오류가 발생해도 시간 계산을 위해 처리 시간 추가 | |
processing_times.append(time.time() - query_start_time) | |
total_time = time.time() - start_time | |
hours, remainder = divmod(total_time, 3600) | |
minutes, seconds = divmod(remainder, 60) | |
if hours > 0: | |
time_format = f"{int(hours)}시간 {int(minutes)}분 {int(seconds)}초" | |
elif minutes > 0: | |
time_format = f"{int(minutes)}분 {int(seconds)}초" | |
else: | |
time_format = f"{int(seconds)}초" | |
print(f"총 {len(results)}개 쿼리 순차 처리 완료, 총 소요 시간: {time_format}") | |
return results | |
async def initialize_hyde(): | |
"""HyDE 초기화 및 벡터 데이터베이스 구축, 이미 임베딩된 결과가 있으면 재활용""" | |
import pickle | |
import os.path | |
domain = "agriculture" # 현재 도메인 | |
# 임베딩 캐시 파일 경로 (청크 기반) | |
chunks_cache_file = os.path.join(OUTPUT_DIR, f"{domain}_chunks.pkl") | |
# HyDE 초기화 | |
print("HyDE 초기화...") | |
promptor = Promptor(task="web search") | |
generator = OpenAIGenerator( | |
model_name="gpt-4o-mini", | |
api_key=os.environ["OPENAI_API_KEY"], # 직접 지정한 API 키 사용 | |
n=8, # 라이브러리 기본값: 8개의 가설 문서 생성 | |
max_tokens=512, # 라이브러리 기본값: 512 토큰 | |
temperature=0.7, | |
stop=["\n\n\n"], # 라이브러리 기본값 | |
) | |
# text-embedding-3-small 모델 사용(1536 차원) | |
encoder = OpenAIEncoder(model_name="text-embedding-3-small") | |
# 동일한 차원의 FAISS 인덱스 초기화 | |
searcher = FAISSSearcher(dim=encoder.dimensions) | |
hyde = HyDE(promptor, generator, encoder, searcher) | |
# 새로운 임베딩을 계산할지 여부 | |
force_new_embeddings = False | |
# 청크 캐시 파일이 존재하면 로드하여 사용 | |
if os.path.exists(chunks_cache_file) and not force_new_embeddings: | |
print(f"기존 청크 캐시 파일 발견: {chunks_cache_file}") | |
try: | |
with open(chunks_cache_file, "rb") as f: | |
cached_chunks = pickle.load(f) | |
# 캐시된 청크 수와 포맷 확인 | |
if not isinstance(cached_chunks, list) or len(cached_chunks) == 0: | |
print("캐시 파일 형식이 올바르지 않습니다. 새로 처리합니다.") | |
else: | |
# 임베딩 차원 확인 | |
if any( | |
"embedding" not in chunk | |
or len(np.array(chunk["embedding"])) != encoder.dimensions | |
for chunk in cached_chunks | |
): | |
print( | |
f"캐시된 임베딩 차원이 현재 모델({encoder.dimensions})과 일치하지 않습니다. 새로 계산합니다." | |
) | |
else: | |
# 벡터 데이터베이스에 청크 추가 | |
print( | |
f"캐시에서 {len(cached_chunks)}개의 청크를 로드하여 벡터 DB에 추가 중..." | |
) | |
valid_count = searcher.add_chunks(cached_chunks) | |
print( | |
f"청크 및 임베딩 추가 완료: {valid_count}/{len(cached_chunks)} (캐시 사용)" | |
) | |
return hyde | |
except Exception as e: | |
print(f"캐시 파일 로드 중 오류 발생: {e}") | |
print("새로 청크 및 임베딩을 계산합니다.") | |
else: | |
if force_new_embeddings: | |
print("기존 캐시를 무시하고 새로운 청크 및 임베딩을 계산합니다.") | |
# 문서 로드 및 청크 분할 | |
print("문서 로드 및 청크 분할 시작...") | |
documents = load_documents_from_dataset() | |
print(f"총 {len(documents)}개의 문서를 로드했습니다.") | |
# 문서를 청크로 분할 | |
chunks = chunk_documents(documents) | |
# 문서별 청크 수 통계 출력 | |
from collections import defaultdict | |
chunks_per_doc = defaultdict(int) | |
for chunk in chunks: | |
doc_id = chunk["doc_id"] | |
chunks_per_doc[doc_id] += 1 | |
print(f"\n문서별 청크 수:") | |
for doc_id, count in chunks_per_doc.items(): | |
print(f" - {doc_id}: {count}개 청크") | |
print( | |
f"총 {len(documents)}개 문서를 {len(chunks)}개의 청크로 분할했습니다 (평균 {len(chunks)/len(documents):.1f}개/문서)" | |
) | |
# 각 청크에 대한 임베딩 생성 | |
chunk_count = len(chunks) | |
print(f"총 {chunk_count}개 청크에 대한 임베딩 계산 중...") | |
chunks_with_embeddings = encoder.encode_chunks(chunks) | |
# 벡터 데이터베이스에 추가 | |
print(f"총 {len(chunks_with_embeddings)}개 청크 및 임베딩을 벡터 DB에 추가 중...") | |
valid_count = searcher.add_chunks(chunks_with_embeddings) | |
print( | |
f"벡터 데이터베이스 구축 완료: {valid_count}/{len(chunks_with_embeddings)} 청크 추가됨" | |
) | |
print( | |
f"FAISS 인덱스 크기: {searcher.index.ntotal}, 저장된 청크 수: {len(searcher.chunks)}" | |
) | |
# 청크 및 임베딩 결과 캐시에 저장 | |
try: | |
print(f"청크 및 임베딩 결과를 캐시 파일에 저장 중: {chunks_cache_file}") | |
with open(chunks_cache_file, "wb") as f: | |
pickle.dump(chunks_with_embeddings, f) | |
print("청크 캐시 저장 완료") | |
except Exception as e: | |
print(f"청크 캐시 저장 중 오류 발생: {e}") | |
return hyde | |
async def run_process(): | |
"""agriculture.jsonl 파일 처리""" | |
# 결과 파일 경로 확인 | |
output_file = os.path.join(OUTPUT_DIR, "hyde_agriculture_results.json") | |
# 이미 처리된 결과 파일이 있는지 확인 | |
if os.path.exists(output_file): | |
with open(output_file, "r", encoding="utf-8") as f: | |
try: | |
existing_results = json.load(f) | |
print( | |
f"agriculture 도메인 결과가 이미 존재합니다: {len(existing_results)}개 결과" | |
) | |
return f"agriculture: {len(existing_results)}개 처리 완료 (기존 파일)" | |
except json.JSONDecodeError: | |
print(f"경고: {output_file} 파일이 손상되었습니다. 다시 생성합니다.") | |
# 질문 로드 - datasets/agriculture.jsonl 파일에서 직접 로드 | |
queries = load_questions_from_file() | |
print(f"{len(queries)}개의 질문을 로드했습니다.") | |
# HyDE 초기화 (병렬 처리) - 문서도 datasets/agriculture.jsonl에서 로드 | |
hyde = await initialize_hyde() | |
# 순차 처리로 HyDE 응답 생성 | |
print("HyDE 응답 순차 생성 중...") | |
hyde_results = await process_queries_sequential(queries, hyde) | |
# 결과 저장 | |
with open(output_file, "w", encoding="utf-8") as f: | |
json.dump(hyde_results, f, ensure_ascii=False, indent=2) | |
print(f"agriculture 결과 저장 완료: {output_file}") | |
return f"agriculture: {len(queries)}개 처리 완료" | |
if __name__ == "__main__": | |
result = asyncio.run(run_process()) | |
print("\n===== 처리 완료 =====") | |
print(result) | |
print(f"\n모든 결과는 {OUTPUT_DIR} 디렉토리에 저장되었습니다.") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment