Created
July 20, 2024 22:17
-
-
Save RGGH/4edeea4d49c369455268705db14dc0bb to your computer and use it in GitHub Desktop.
flask + faiss python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask import Flask, request, jsonify, render_template_string | |
from transformers import AutoTokenizer, AutoModel | |
import numpy as np | |
import faiss | |
import torch | |
# Hugging Face checkpoint shared by the tokenizer and the encoder.
model_name = 'bert-base-uncased'

# Load the tokenizer/encoder pair once at import time; get_embedding() reads
# these module-level objects on every call instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Flask application object that the route decorator registers handlers on.
app = Flask(__name__)
def get_embedding(text: str) -> np.ndarray:
    """Embed ``text`` with the module-level ``tokenizer`` and ``model``.

    The final hidden layer's token vectors are averaged into one vector per
    input. The average is weighted by the tokenizer's attention mask so that
    padding positions (the tokenizer is called with ``padding=True``) never
    dilute the result. For a single unpadded string the mask is all ones, so
    this equals a plain mean over the sequence axis.

    Args:
        text: The text to embed.

    Returns:
        A C-contiguous ``float32`` array with one row per encoded input and
        one column per hidden unit, ready to pass to a FAISS index.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)

    # Inference only: skip autograd bookkeeping so .numpy() is allowed below.
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    # Broadcast the (batch, seq) mask over the hidden dimension.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    # Guard against division by zero should a row ever be fully masked.
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    pooled = (hidden * mask).sum(dim=1) / token_counts

    # FAISS expects contiguous float32 input.
    return np.ascontiguousarray(pooled.numpy(), dtype=np.float32)
# Small in-memory corpus that the demo searches over.
sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming industries.",
    "FAISS is a library for efficient similarity search.",
    "Natural language processing enables machines to understand text.",
    "Transformers models are state-of-the-art for NLP tasks."
]

# Embed every corpus entry and stack the per-text rows into one matrix whose
# row order matches sample_texts.
embeddings = np.concatenate(
    [get_embedding(sentence) for sentence in sample_texts],
    axis=0,
)

# Exact (brute-force) L2 index sized to the embedding width, filled with the
# corpus vectors so that returned positions map back into sample_texts.
embedding_dim = embeddings.shape[-1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(embeddings)
@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the search page and, on POST, the nearest corpus texts.

    GET renders the empty form. POST reads the ``text`` form field, embeds
    it, looks up the closest entries of ``sample_texts`` in the FAISS index
    and renders them with their L2 distances.

    Returns:
        The rendered ``HTML_TEMPLATE``. ``error`` is set to
        ``"No text provided"`` when the submitted text is missing or
        contains only whitespace.
    """
    if request.method != 'POST':
        return render_template_string(HTML_TEMPLATE, results=[], error=None)

    # Treat whitespace-only input the same as a missing field.
    query_text = (request.form.get('text') or '').strip()
    if not query_text:
        return render_template_string(HTML_TEMPLATE, results=[], error="No text provided")

    # Never request more neighbours than the index holds: FAISS pads missing
    # hits with label -1, which would silently index sample_texts[-1].
    k = min(3, index.ntotal)
    results = []
    if k > 0:
        query_embedding = get_embedding(query_text)
        distances, indices = index.search(query_embedding, k)
        results = [
            {"text": sample_texts[int(idx)], "distance": float(dist)}
            for idx, dist in zip(indices[0], distances[0])
            if idx >= 0  # skip any "no result" placeholder labels
        ]
    return render_template_string(HTML_TEMPLATE, results=results, error=None)
# Jinja2 template for the single page of the app. It is rendered with two
# variables: ``results`` (a list of dicts with ``text`` and ``distance`` keys)
# and ``error`` (a message string or None).
# The stylesheet and scripts all come from the same Bootstrap release (5.1.3)
# so that the CSS classes and the JS bundle agree with each other.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Text Similarity Search</title>
    <!-- Bootstrap CSS CDN -->
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
</head>
<body>
    <div class="container mt-4">
        <h1 class="mb-4">Text Similarity Search</h1>
        <form method="post" class="mb-4">
            <div class="mb-3">
                <label for="text" class="form-label">Enter text:</label>
                <input type="text" id="text" name="text" class="form-control" required>
            </div>
            <button type="submit" class="btn btn-primary">Search</button>
        </form>
        {% if error %}
        <div class="alert alert-danger" role="alert">
            {{ error }}
        </div>
        {% endif %}
        {% if results %}
        <h2>Results:</h2>
        <ul class="list-group">
            {% for result in results %}
            <li class="list-group-item"><strong>{{ result.text }}</strong> - Distance: {{ result.distance }}</li>
            {% endfor %}
        </ul>
        {% endif %}
    </div>
    <!-- Bootstrap JS and dependencies -->
    <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.10.2/dist/umd/popper.min.js" integrity="sha384-7+zCNj/IqJ95wo16oMtfsKbZ9ccEh31eOz1HGyDuCQ6wgnyJNSYdrPa03rtR1zdB" crossorigin="anonymous"></script>
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.min.js" integrity="sha384-QJHtvGhmr9XOIpI6YVutG+2QOK9T+ZnN4kzFN1RtK3zEFEIsxhlmWl5/YESvpZ13" crossorigin="anonymous"></script>
</body>
</html>
"""
if __name__ == '__main__':
    # Start Flask's built-in server only when this file is executed as a
    # script; it binds to every network interface on port 8000.
    app.run(port=8000, host='0.0.0.0')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
ah