Embeddings through sentence-transformers or with transformers/PyTorch directly
# For more details see
# https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1

# Compute embeddings with sentence-transformers
from sentence_transformers import SentenceTransformer, util

docs = ["Around 9 Million people live in London", "This is nice"]

# Load the model
model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')

# encode() returns a NumPy array by default
doc_emb = model.encode(docs)
print(doc_emb[:2, :2])
# [[ 0.16776945 -0.06160375]
#  [ 0.14571878 -0.03010267]]
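
# Quick usage sketch (an addition, not part of the original gist): the model
# is tuned for semantic search with cosine similarity, so the document
# embeddings can be scored against a query via util.cos_sim. The query
# string below is made up for illustration.
query_emb = model.encode("How many people live in London?")
print(util.cos_sim(query_emb, doc_emb))
# Expect a much higher score for the first document than the second.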

# Compute the same embeddings without sentence-transformers,
# using transformers and PyTorch directly
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Load model from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

# Mean pooling: average the token embeddings, ignoring padding tokens
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state
    # Expand the attention mask to the embedding dimension so that
    # padding positions contribute zero to the sum
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum over the sequence dimension and divide by the number of real tokens
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def encode(texts):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings to unit length (the model is tuned for cosine similarity)
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings

doc_emb = encode(docs)
print(doc_emb[:2, :2])
# tensor([[ 0.1678, -0.0616],
#         [ 0.1457, -0.0301]])
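
# Optional sanity check (an addition, not part of the original gist): the
# embeddings are L2-normalized, so cosine similarity reduces to a dot product.
scores = doc_emb @ doc_emb.T
print(scores)
# Expect ~1.0 on the diagonal (each document compared with itself).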