@waleedkadous
Created April 15, 2023 23:17
from langchain.document_loaders import ReadTheDocsLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
import time
import ray
import numpy as np

from embeddings import LocalHuggingFaceEmbeddings

# To download the files locally for processing, here's the command line
# wget -e robots=off --recursive --no-clobber --page-requisites --html-extension \
#   --convert-links --restrict-file-names=windows \
#   --domains docs.ray.io --no-parent https://docs.ray.io/en/master/

FAISS_INDEX_PATH = "faiss_index_fast"
db_shards = 8

loader = ReadTheDocsLoader("docs.ray.io/en/master/")

text_splitter = RecursiveCharacterTextSplitter(
    # Set a really small chunk size, just to show.
    chunk_size=300,
    chunk_overlap=20,
    length_function=len,
)


@ray.remote(num_gpus=1)
def process_shard(shard):
    print(f'Starting process_shard of {len(shard)} chunks.')
    st = time.time()
    embeddings = LocalHuggingFaceEmbeddings('multi-qa-mpnet-base-dot-v1')
    result = FAISS.from_documents(shard, embeddings)
    et = time.time() - st
    print(f'Shard completed in {et} seconds.')
    return result


# Stage one: read all the docs, split them into chunks.
st = time.time()
print('Loading documents ...')
docs = loader.load()

# Theoretically, we could use Ray to accelerate this, but it's fast enough as is.
chunks = text_splitter.create_documents(
    [doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs]
)
et = time.time() - st
print(f'Time taken: {et} seconds. {len(chunks)} chunks generated')

# Stage two: embed the docs.
print(f'Loading chunks into vector store ... using {db_shards} shards')
st = time.time()
shards = np.array_split(chunks, db_shards)
futures = [process_shard.remote(shards[i]) for i in range(db_shards)]
results = ray.get(futures)
et = time.time() - st
print(f'Shard processing complete. Time taken: {et} seconds.')

st = time.time()
print('Merging shards ...')
# Straight serial merge of the other shards into results[0].
db = results[0]
for i in range(1, db_shards):
    db.merge_from(results[i])
et = time.time() - st
print(f'Merged in {et} seconds.')

st = time.time()
print('Saving faiss index')
db.save_local(FAISS_INDEX_PATH)
et = time.time() - st
print(f'Saved in: {et} seconds.')
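
# For reference, a minimal sketch of reading the index back and querying it.
# It assumes the same embeddings.py helper is importable and that the index
# was saved by the script above; the query string is just an example.

from langchain.vectorstores import FAISS
from embeddings import LocalHuggingFaceEmbeddings

# Load the index with the same embedding model it was built with.
embeddings = LocalHuggingFaceEmbeddings('multi-qa-mpnet-base-dot-v1')
db = FAISS.load_local("faiss_index_fast", embeddings)

# Example query; k is the number of chunks to retrieve.
for doc in db.similarity_search('How do I create a Ray actor?', k=4):
    print(doc.metadata, doc.page_content[:100])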
@waleedkadous (Author)

If you use the non-Ray-accelerated version, does it use the GPU?

In any case, you're not going to see speedups on a GTX 1060. It may be that the GTX 1060 doesn't support the type of acceleration required, or it could be that you don't have the CUDA drivers installed.
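
A quick sanity check worth running (a sketch, assuming PyTorch is installed): if this prints False, sentence-transformers will fall back to the CPU regardless of the hardware.

import torch

# False means no usable CUDA device/driver, so embedding runs on the CPU.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))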

@unni-pure

I also noticed a similar issue when running the script on a DGX A100 with 40 GB of memory. Without Ray, embedding completed in 109 seconds, but with Ray, 3 jobs took around 150 seconds each. I have the CUDA drivers installed.
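
One thing to check (a hedged sketch, not tested on a DGX): with num_gpus=1, each shard task reserves a whole GPU, so if Ray only sees one GPU the eight shards run one after another and the model is loaded eight times. Fractional GPU requests would let several shard tasks share a device; 0.25 below is an illustrative value, not a tuned one.

import ray
from langchain.vectorstores import FAISS
from embeddings import LocalHuggingFaceEmbeddings

@ray.remote(num_gpus=0.25)  # illustrative fraction; tune to the model's memory footprint
def process_shard(shard):
    embeddings = LocalHuggingFaceEmbeddings('multi-qa-mpnet-base-dot-v1')
    return FAISS.from_documents(shard, embeddings)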
