@tori29umai0123
Last active July 9, 2024 12:36
rag_stable_diffusion_prompt.py
from pathlib import Path

import faiss
import numpy as np
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer

class CustomEmbeddings:
    """Wraps a SentenceTransformer so it returns float32 arrays, as FAISS expects."""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts):
        return np.array(self.model.encode(texts), dtype=np.float32)

    def embed_query(self, text):
        # encode() takes a list, so the result keeps shape (1, dim) for index.search().
        return np.array(self.model.encode([text]), dtype=np.float32)

# Load Danbooru tags (first comma-separated field of each line)
def load_danbooru_tags(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        return [line.split(",")[0].strip() for line in f]

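# Assumed layout of danbooru_tags.csv (illustrative; any columns after the
# first are ignored):
#   1girl,...
#   fireworks,...
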
# Build the vector store
def create_vector_store(tags):
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    embeddings = CustomEmbeddings(model)
    documents = list(tags)
    vectors = embeddings.embed_documents(documents)
    # Flat (exact) L2 index sized to the embedding dimension.
    index = faiss.IndexFlatL2(len(vectors[0]))
    index.add(vectors)
    return index, documents, embeddings

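# Note: nomic-embed-text-v1.5 is trained with task prefixes such as
# "search_document: " and "search_query: "; prepending them in the encode()
# calls may improve retrieval quality. The script works without them, so the
# behavior above is left unchanged.
#
# Illustrative usage (hypothetical tag list):
#   index, documents, embeddings = create_vector_store(["1girl", "fireworks"])
#   D, I = index.search(embeddings.embed_query("fireworks"), 1)
#   documents[I[0][0]]  # -> "fireworks"
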
# Set up the RAG system
def setup_rag_system(index, documents, embeddings, num_of_ref_tags=10):
    model_path = "gemma-2-9b-it-Q8_0.gguf"
    llm = Llama(model_path=model_path, n_ctx=2048, n_batch=512)

    def retrieve(query, k=num_of_ref_tags):
        query_vector = embeddings.embed_query(query)
        # D: squared L2 distances, I: indices of the k nearest tags.
        D, I = index.search(query_vector, k)
        return [documents[i] for i in I[0]]

    def generate_response(query, context):
        prompt = f"""
You are a precise tag matcher for Danbooru tags. Your task is to find exact matches in the provided context for the given input element.
Rules:
1. Only return tags that are exact matches to the input element and present in the context.
2. Do not add any tags that are not exact matches to the input.
3. If an exact match is found in the context, return only that match.
4. If no exact match is found, return the input element as is.
5. Never modify, complete, or expand the input element.
6. Provide a maximum of one tag as output.
Context: {', '.join(context)}
Input: {query}
Output:
"""
        response = llm(prompt, max_tokens=50, stop=["\n"])
        return response["choices"][0]["text"].strip()

    return retrieve, generate_response

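# Example round trip (illustrative; assumes the index and model are loaded):
#   retrieve, generate_response = setup_rag_system(index, documents, embeddings)
#   context = retrieve("fireworks")          # num_of_ref_tags nearest tags
#   generate_response("fireworks", context)  # -> "fireworks" when it is in the store
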
def generate_rich_description(scene_description):
    model_path = "gemma-2-9b-it-Q8_0.gguf"
    # Note: this loads a second copy of the model alongside the one held by
    # setup_rag_system; tensor_split targets a specific multi-GPU layout and
    # may need adjusting (or removing) for other hardware.
    llm = Llama(model_path=model_path, n_ctx=2048, n_batch=512, tensor_split=[48, 0, 0])
    prompt = f"""
Based on the following brief scene description, generate a comma-separated list of danbooru tags that could be used as a Stable Diffusion prompt in English.
Brief description: {scene_description}
Stable Diffusion Prompt Elements:
"""
    response = llm(prompt, max_tokens=200, stop=["\n\n"])
    return response["choices"][0]["text"].strip()

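# Illustrative call (output is model-dependent, shown only as a sketch):
#   generate_rich_description("a girl in a yukata watching fireworks")
#   # might yield: "1girl, yukata, fireworks, smile, night sky, festival"
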
def convert_description_to_danbooru_tags(prompt_elements, retrieve, generate_response):
    elements = [elem.strip() for elem in prompt_elements.split(",")]
    danbooru_tags = []
    for element in elements:
        context = retrieve(element)
        tag = generate_response(element, context)
        if tag:
            danbooru_tags.append(tag)
            print(f"{element}: {tag}")
    # Deduplicate while preserving order.
    unique_tags = []
    seen = set()
    for tag in danbooru_tags:
        if tag not in seen:
            unique_tags.append(tag)
            seen.add(tag)
    return ", ".join(unique_tags)

if __name__ == "__main__":
    vector_store_path = Path("danbooru_tags_vector_store.faiss")
    if vector_store_path.exists():
        # Reuse the saved index; only the embedding model needs reloading.
        index = faiss.read_index(str(vector_store_path))
        tags = load_danbooru_tags("danbooru_tags.csv")
        model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
        embeddings = CustomEmbeddings(model)
    else:
        tags = load_danbooru_tags("danbooru_tags.csv")
        index, documents, embeddings = create_vector_store(tags)
        faiss.write_index(index, str(vector_store_path))

    retrieve, generate_response = setup_rag_system(index, tags, embeddings, num_of_ref_tags=10)

    # "A girl in a yukata smiles while watching fireworks"
    scene_description = "浴衣の少女が花火を見て笑っている"
    prompt_elements = generate_rich_description(scene_description)
    print(f"Generated Danbooru Tags:\n{prompt_elements}")

    danbooru_tags = convert_description_to_danbooru_tags(prompt_elements, retrieve, generate_response)
    print(f"\nMatched Danbooru Tags:\n{danbooru_tags}")

    # Release resources
    del index
    del retrieve
    del generate_response
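
# Assumed dependencies (not pinned in the gist):
#   pip install sentence-transformers faiss-cpu llama-cpp-python einops
# nomic-embed-text-v1.5 loads remote code (hence trust_remote_code=True), and
# gemma-2-9b-it-Q8_0.gguf must be downloaded separately into the working directory.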