@tori29umai0123
Created July 9, 2024 10:15
tagGEN
import csv
from pathlib import Path

from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import FAISS
from langchain_community.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
from langchain_core.embeddings import Embeddings
# Custom Embeddings wrapper so a SentenceTransformer model can back a
# LangChain vector store
class CustomEmbeddings(Embeddings):
    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts):
        return self.model.encode(texts).tolist()

    def embed_query(self, text):
        return self.model.encode([text])[0].tolist()
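# A minimal standalone sketch of how this wrapper behaves; the tag strings
# are hypothetical examples, not from the gist:
#
#   model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
#   emb = CustomEmbeddings(model)
#   vec = emb.embed_query("yukata")                # -> list[float], one embedding
#   mat = emb.embed_documents(["1girl", "smile"])  # -> list of embeddings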
# Load and process Danbooru tags: the first CSV column is the tag name
def load_danbooru_tags(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        return [row[0].strip() for row in csv.reader(f) if row]
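# Assumed danbooru_tags.csv layout (the rows below are illustrative, not from
# the gist): one tag per line; extra columns such as post counts are ignored.
#
#   1girl,5000000
#   smile,3200000
#   yukata,120000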
# Create vector store: embed every tag as its own document
def create_vector_store(tags):
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    embeddings = CustomEmbeddings(model)
    documents = [Document(page_content=tag) for tag in tags]
    vector_store = FAISS.from_documents(documents, embeddings)
    return vector_store
# Set up RAG system: retrieve the k nearest tags, then have the LLM pick an
# exact match from among them
def setup_rag_system(vector_store, num_of_ref_tags=10):
    model_path = "gemma-2-9b-it-Q6_K.gguf"
    llm = LlamaCpp(model_path=model_path, n_ctx=2048, n_batch=512, verbose=False)
    retriever = vector_store.as_retriever(search_kwargs={"k": num_of_ref_tags})
    template = """
You are a precise tag matcher for Danbooru tags. Your task is to find exact matches in the provided context for the given input element.
Rules:
1. Only return tags that are exact matches to the input element and present in the context.
2. Do not add any tags that are not exact matches to the input.
3. If an exact match is found in the context, return only that match.
4. If no exact match is found, return the input element as is.
5. Never modify, complete, or expand the input element.
6. Provide a maximum of one tag as output.
Context: {context}
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", template),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    chain = create_retrieval_chain(retriever, question_answer_chain)
    return chain
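# For reference, create_retrieval_chain returns a dict per invocation; the
# keys below come from the LangChain API, the values shown are hypothetical:
#
#   result = chain.invoke({"input": "smile"})
#   # result["context"] -> the retrieved Documents (the k candidate tags)
#   # result["answer"]  -> the LLM's single matched tag, e.g. "smile"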
# Expand a brief scene description into comma-separated prompt elements
def generate_rich_description(scene_description):
    model_path = "gemma-2-9b-it-Q6_K.gguf"
    llm = LlamaCpp(model_path=model_path, n_ctx=2048, n_batch=512, verbose=False)
    template = """
Based on the following brief scene description, generate a comma-separated list of danbooru tags that could be used as a Stable Diffusion prompt in English.
Brief description: {scene_description}
## Stable Diffusion Prompt Elements:
"""
    prompt = PromptTemplate(
        input_variables=["scene_description"],
        template=template,
    )
    # invoke() replaces the deprecated direct llm(...) call
    response = llm.invoke(prompt.format(scene_description=scene_description))
    return response.strip()
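# Hypothetical output for the sample scene in __main__ below; actual tags
# depend on the model and sampling:
#
#   generate_rich_description("...")
#   # -> "1girl, yukata, fireworks, night, smile, festival"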
# Match each generated element against the tag database via the RAG chain
def convert_description_to_danbooru_tags(prompt_elements, rag_chain):
    elements = [elem.strip() for elem in prompt_elements.split(",")]
    danbooru_tags = []
    for element in elements:
        result = rag_chain.invoke({"input": element})
        tag = result["answer"].strip()
        if tag:  # only keep non-empty answers
            danbooru_tags.append(tag)
        print(f"{element}: {tag}")  # debug output
    # Remove duplicates while preserving order
    unique_tags = []
    seen = set()
    for tag in danbooru_tags:
        if tag not in seen:
            unique_tags.append(tag)
            seen.add(tag)
    return ", ".join(unique_tags)
if __name__ == "__main__":
    vector_store_path = Path("danbooru_tags_vector_store.faiss")
    if vector_store_path.exists():
        model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
        embeddings = CustomEmbeddings(model)
        # Newer langchain_community versions require opting in to pickle
        # deserialization when loading a local FAISS index
        vector_store = FAISS.load_local(
            str(vector_store_path), embeddings, allow_dangerous_deserialization=True
        )
    else:
        tags = load_danbooru_tags("danbooru_tags.csv")
        vector_store = create_vector_store(tags)
        vector_store.save_local(str(vector_store_path))
    rag_chain = setup_rag_system(vector_store, num_of_ref_tags=10)
    # "A girl in a yukata smiling while watching fireworks"
    scene_description = "浴衣の少女が花火を見て笑っている"
    prompt_elements = generate_rich_description(scene_description)
    print(f"Generated Danbooru Tags:\n{prompt_elements}")
    danbooru_tags = convert_description_to_danbooru_tags(prompt_elements, rag_chain)
    print(f"\nMatched Danbooru Tags:\n{danbooru_tags}")
    # Release resources
    del vector_store
    del rag_chain
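# Files and packages this script assumes (from the paths hard-coded above):
#   - gemma-2-9b-it-Q6_K.gguf in the working directory (any llama.cpp-
#     compatible GGUF should work if model_path is updated)
#   - danbooru_tags.csv with tag names in the first column
#   - pip install langchain langchain-community faiss-cpu llama-cpp-python
#     sentence-transformers (nomic-embed-text-v1.5 also needs einops)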