Skip to content

Instantly share code, notes, and snippets.

@RGGH
Created July 20, 2024 22:17
Show Gist options
  • Save RGGH/4edeea4d49c369455268705db14dc0bb to your computer and use it in GitHub Desktop.
Save RGGH/4edeea4d49c369455268705db14dc0bb to your computer and use it in GitHub Desktop.
flask + faiss python
from flask import Flask, request, jsonify, render_template_string
from transformers import AutoTokenizer, AutoModel
import numpy as np
import faiss
import torch
# Flask application object; the route and the script entry point below use it.
app = Flask(__name__)
# Initialize the model and tokenizer once at import time so that every call to
# get_embedding() reuses the same loaded objects.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def get_embedding(text: str) -> np.ndarray:
    """Embed text with the module-level encoder using masked mean pooling.

    Args:
        text: The text to embed. The value is handed straight to the
            tokenizer, so a list of strings yields one row per string.

    Returns:
        A C-contiguous float32 array with one row per input text, each row
        being the average of that text's token vectors from the model's last
        hidden state.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():  # inference only; no autograd graph is needed
        outputs = model(**inputs)
        hidden = outputs.last_hidden_state
        # Weight each position by the attention mask so that padding added
        # for batched input does not dilute the average. For a single text
        # the mask is all ones and this equals a plain mean over tokens.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (hidden * mask).sum(dim=1) / token_counts
    # FAISS consumes contiguous float32 matrices; make that explicit here.
    return np.ascontiguousarray(pooled.numpy(), dtype=np.float32)
# Demo corpus that incoming queries are compared against
sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming industries.",
    "FAISS is a library for efficient similarity search.",
    "Natural language processing enables machines to understand text.",
    "Transformers models are state-of-the-art for NLP tasks.",
]
# Embed every sentence in list order and stack the rows into one matrix, so
# that row i of the index corresponds to sample_texts[i].
embeddings = np.vstack(list(map(get_embedding, sample_texts)))
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the search page and, on POST, the closest sample texts.

    GET returns the empty form. POST reads the ``text`` form field, embeds
    it with ``get_embedding`` and looks up the nearest entries of
    ``sample_texts`` in the FAISS index (at most three). A missing or blank
    field re-renders the form with an error message instead.

    Returns:
        The rendered ``HTML_TEMPLATE`` with ``results`` (a list of dicts with
        ``"text"`` and ``"distance"`` keys) and ``error`` (a message or None).
    """
    if request.method != 'POST':
        return render_template_string(HTML_TEMPLATE, results=[], error=None)

    # Treat a whitespace-only submission the same as an empty field.
    query_text = (request.form.get('text') or '').strip()
    if not query_text:
        return render_template_string(HTML_TEMPLATE, results=[], error="No text provided")

    query_embedding = get_embedding(query_text)
    # Number of nearest neighbors; never request more than the index holds.
    k = min(3, index.ntotal)
    results = []
    if k > 0:
        distances, indices = index.search(query_embedding, k)
        # FAISS marks a missing neighbor with the id -1; skip those so they
        # cannot silently select the last sample via negative indexing.
        results = [
            {"text": sample_texts[int(idx)], "distance": float(dist)}
            for idx, dist in zip(indices[0], distances[0])
            if idx >= 0
        ]
    return render_template_string(HTML_TEMPLATE, results=results, error=None)
# Jinja template for the single page: a search form, an optional error alert,
# and the list of results. home() renders it with two variables:
#   results - list of dicts with "text" and "distance" keys (may be empty)
#   error   - message to show in the alert box, or None
# NOTE(review): the stylesheet link is Bootstrap 4.5.2 while the script tags
# load Bootstrap 5.1.3 - confirm which major version is intended.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Text Similarity Search</title>
<!-- Bootstrap CSS CDN -->
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" rel="stylesheet">
</head>
<body>
<div class="container mt-4">
<h1 class="mb-4">Text Similarity Search</h1>
<form method="post" class="mb-4">
<div class="form-group">
<label for="text">Enter text:</label>
<input type="text" id="text" name="text" class="form-control" required>
</div>
<button type="submit" class="btn btn-primary">Search</button>
</form>
{% if error %}
<div class="alert alert-danger" role="alert">
{{ error }}
</div>
{% endif %}
{% if results %}
<h2>Results:</h2>
<ul class="list-group">
{% for result in results %}
<li class="list-group-item"><strong>{{ result.text }}</strong> - Distance: {{ result.distance }}</li>
{% endfor %}
</ul>
{% endif %}
</div>
<!-- Bootstrap JS and dependencies -->
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.10.2/dist/umd/popper.min.js" integrity="sha384-7+zCNj/IqJ95wo16oMtfsKbZ9ccEh31eOz1HGyDuCQ6wgnyJNSYdrPa03rtR1zdB" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.min.js" integrity="sha384-QJHtvGhmr9XOIpI6YVutG+2QOK9T+ZnN4kzFN1RtK3zEFEIsxhlmWl5/YESvpZ13" crossorigin="anonymous"></script>
</body>
</html>
"""
if __name__ == '__main__':
    # Script entry point: serve on all interfaces, port 8000.
    app.run(port=8000, host='0.0.0.0')
@RGGH
Copy link
Author

RGGH commented Jul 21, 2024

ah

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment