Skip to content

Instantly share code, notes, and snippets.

@acidsound
Created August 28, 2023 18:04
Show Gist options
  • Save acidsound/187d8f917027bcfbcfb2434bce529bce to your computer and use it in GitHub Desktop.
Save acidsound/187d8f917027bcfbcfb2434bce529bce to your computer and use it in GitHub Desktop.
exllama 사용법 +qdrant 색인&검색
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import numpy as np
from sentence_transformers import SentenceTransformer
from qdrant_client.models import PointStruct
# Connect to the Qdrant instance listening on localhost:6333.
client = QdrantClient("localhost:6333")
COLLECTION_NAME = "docs"

# Load the embedding model before creating the collection so the vector size
# can be taken from the model instead of being hard-coded.
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2", device="cuda")

# The collection's vector size must equal the embedding dimension of the model
# whose vectors are upserted below; asking the model keeps the two in sync.
client.recreate_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=VectorParams(
        size=model.get_sentence_embedding_dimension(),
        distance=Distance.COSINE,
    ),
)

sentences = [
    'This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.',
]
sentence_embeddings = model.encode(sentences)

# One point per sentence. The source sentence is stored in the payload under
# "text" so that a search hit can be mapped back to the original text.
client.upsert(
    collection_name=COLLECTION_NAME,
    points=[
        PointStruct(
            id=idx,
            vector=vector.tolist(),
            payload={"color": "red", "rand_number": idx % 10, "text": sentence},
        )
        for idx, (sentence, vector) in enumerate(zip(sentences, sentence_embeddings))
    ],
)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"g:\\llm\\exllama1\\venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from qdrant_client import QdrantClient\n",
"from qdrant_client.models import Distance, VectorParams\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"from qdrant_client.models import PointStruct\n",
"\n",
"# Initialize the client\n",
"client = QdrantClient(\"localhost:6333\")\n",
"COLLECTION_NAME = \"docs\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"model = SentenceTransformer(\"paraphrase-multilingual-mpnet-base-v2\", device=\"cuda\")\n",
"\n",
"sentences = ['This framework generates embeddings for each input sentence',\n",
" 'Sentences are passed as a list of string.', \n",
" 'The quick brown fox jumps over the lazy dog.']\n",
"sentence_embeddings = model.encode(sentences)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.00925649, -0.25025985, -0.01117302, ..., 0.13195273,\n",
" -0.16931531, -0.17440248],\n",
" [ 0.15494238, -0.06809117, -0.01351546, ..., 0.0124484 ,\n",
" -0.11615191, -0.12480731],\n",
" [-0.09213714, 0.08014818, -0.00743247, ..., 0.15812056,\n",
" 0.2550481 , 0.03574588]], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence_embeddings"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.recreate_collection(\n",
" collection_name=COLLECTION_NAME,\n",
" vectors_config=VectorParams(size=768, distance=Distance.COSINE),\n",
")\n",
"\n",
"client.upsert(\n",
" collection_name=\"docs\",\n",
" points=[\n",
" PointStruct(\n",
" id=idx,\n",
" vector=vector.tolist(), \n",
" payload={\"color\": \"red\", \"rand_number\": idx % 10, \"text\": sentences[idx]}\n",
" )\n",
" for idx, vector in enumerate(sentence_embeddings)\n",
" ]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"docs = client.search(\n",
" collection_name=COLLECTION_NAME, \n",
" query_vector=model.encode(\"목록\"),\n",
" limit=1\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"references_text = \"\"\n",
"for i, reference in enumerate(docs, start=1):\n",
" text = reference.payload[\"text\"].strip()\n",
" references_text += f\"\\n[{i}]: {text}\""
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n[1]: Sentences are passed as a list of string.'"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"references_text"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.lora import ExLlamaLora
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
import torch
import datetime
def log(message):
    """Print *message* to stdout prefixed with the current local time.

    The output has the form ``[YYYY-MM-DD HH:MM:SS] message``.
    """
    now = datetime.datetime.now()
    # The format spec after the colon is handed to strftime-style formatting.
    print(f"[{now:%Y-%m-%d %H:%M:%S}] {message}")
# This script only runs generation, so gradient tracking is switched off.
torch.set_grad_enabled(False)
torch.cuda._lazy_init()

# Directory holding config.json, model.safetensors and tokenizer.model.
config_path = "./models/Scarlett-13B-GPTQ"
config = ExLlamaConfig(f"{config_path}/config.json")
config.model_path = f"{config_path}/model.safetensors"

# Build the model exactly once; constructing ExLlama a second time would
# repeat the whole load for no benefit.
log(">>> model loading...")
model = ExLlama(config)

log(">>> model caching...")
cache = ExLlamaCache(model)

log(">>> model tokenizing...")
tokenizer = ExLlamaTokenizer(f"{config_path}/tokenizer.model")

log(">>> model generating...")
generator = ExLlamaGenerator(model, tokenizer, cache)

# Optional generator tuning, left disabled. Uncomment to override the
# generator's default settings.
# generator.disallow_tokens([tokenizer.eos_token_id])
# generator.settings.token_repetition_penalty_max = 1.2
# generator.settings.temperature = 0.95
# generator.settings.top_p = 0.65
# generator.settings.top_k = 100
# generator.settings.typical = 0.5

# Produce a simple generation
log(">>> make simple autocomplete..")
prompt = "Once upon a time,"
log(prompt)
output = generator.generate_simple(prompt, max_new_tokens=200)
# Skip the first len(prompt) characters so only the continuation is logged.
log(output[len(prompt):])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment