Skip to content

Instantly share code, notes, and snippets.

@dmigo
Last active June 17, 2022 15:24
Show Gist options
  • Save dmigo/12e31a6d12e140a90a21ee715cf5f994 to your computer and use it in GitHub Desktop.
Save dmigo/12e31a6d12e140a90a21ee715cf5f994 to your computer and use it in GitHub Desktop.
Runs Haystack indexing on Ray
import os

import ray
import s3fs
import yaml
from haystack.pipelines.base import Pipeline
from ray import serve
# Cluster and storage settings.
# The original gist had bare ``<PLACEHOLDER>`` tokens here, which are a
# SyntaxError. Every value is now read from an environment variable of the
# same name, so the script runs without editing the source. The old
# defaults are kept where the gist had real values.
RAY_ADDRESS = os.environ.get("RAY_ADDRESS", "")  # head-node host used in the ray:// address
RAY_PORT = os.environ.get("RAY_PORT", "10001")  # port used in the ray:// client address
RAY_SERVE_PORT = os.environ.get("RAY_SERVE_PORT", "8000")  # NOTE(review): not referenced by the code shown
RAY_NAMESPACE = os.environ.get("RAY_NAMESPACE", "default")  # namespace passed to ray.init
DATA_PATH = os.environ.get("DATA_PATH", "")  # S3 prefix for input data; NOTE(review): not read yet
PIPELINES_PATH = os.environ.get("PIPELINES_PATH", "")  # S3 prefix that holds pipeline YAML files
class RayUnavailableException(Exception):
    """Raised when connecting to Ray or starting Ray Serve fails."""


class RayConnection:
    """Context manager that holds a Ray Client connection plus Ray Serve.

    The connection is opened eagerly in ``__init__``, not in ``__enter__``,
    so constructing the object is enough to connect. Leaving the ``with``
    block shuts Serve down first and then disconnects from Ray.
    """

    def __init__(self):
        """Connect to ``ray://RAY_ADDRESS:RAY_PORT`` and start Ray Serve.

        Raises:
            RayUnavailableException: if ``ray.init`` or ``serve.start`` fails.
                The underlying error is chained as ``__cause__``.
        """
        address = f"ray://{RAY_ADDRESS}:{RAY_PORT}"
        namespace = RAY_NAMESPACE
        try:
            ray.init(address=address, namespace=namespace)
        except Exception as error:
            raise RayUnavailableException(f'Connection to ray failed due to "{error}".') from error
        try:
            # NOTE(review): a *detached* Serve instance is meant to outlive this
            # script, yet __exit__ shuts it down. Confirm which is intended.
            serve.start(detached=True)
        except Exception as error:
            # Ray is already connected at this point. Close it rather than
            # leak the connection when only Serve failed to start.
            ray.shutdown()
            raise RayUnavailableException(f'Connection to ray failed due to "{error}".') from error

    def __enter__(self):
        return self

    def __exit__(self, typ, value, traceback):
        # serve.shutdown() works against the cluster, so it must run while the
        # Ray connection is still open. The original called ray.shutdown()
        # first. The finally clause guarantees the disconnect even if Serve
        # teardown fails. Exceptions from the with-body are not suppressed.
        try:
            serve.shutdown()
        finally:
            ray.shutdown()
def read_config(s3prefix: str, s3file: str):
    """Read a YAML file from S3 and return the parsed content.

    Args:
        s3prefix: S3 "directory" such as ``s3://bucket/pipelines``, with or
            without a trailing slash.
        s3file: file name relative to ``s3prefix``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. For a pipeline
        file this is presumably a dict; confirm against callers.
    """
    # Normalise the join. A trailing "/" on the prefix (or a leading one on
    # the file name) would otherwise produce an "a//b" key, which S3 treats
    # as a different object.
    path = f"{s3prefix.rstrip('/')}/{s3file.lstrip('/')}"
    fs = s3fs.S3FileSystem()
    with fs.open(path) as f:
        # safe_load only, because the file comes from external storage.
        return yaml.safe_load(f)
@ray.remote
def index_docs(pipeline: str, data: str):
    """Ray task: load a pipeline config from S3 and build its "indexing" pipeline.

    Args:
        pipeline: file name of the pipeline YAML under ``PIPELINES_PATH``.
        data: name of the input data file. It is only echoed back in the
            result; this task does not read it.

    Returns:
        ``{'data': data, 'config': config}``, where ``config`` is the parsed YAML.
    """
    config = read_config(PIPELINES_PATH, pipeline)
    # Building the pipeline on the worker presumably fails fast on a bad
    # config. The resulting object is discarded, not run.
    # TODO(review): run the pipeline over DATA_PATH/data to actually index.
    Pipeline.load_from_config(config, pipeline_name='indexing')
    return {'data': data, 'config': config}
def main():
    """Submit a single indexing task to the Ray cluster and print its result."""
    pipeline_file, data_file = 'sparse.yaml', 'case_corpus_10.jsonl'
    with RayConnection():
        # .remote() only schedules the task; ray.get blocks until it returns.
        pending = index_docs.remote(pipeline_file, data_file)
        print(ray.get(pending))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment