embeddings through sentence-transformers or with PyTorch directly
# for more details see
# https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
# compute embeddings with the sentence-transformers library
from sentence_transformers import SentenceTransformer, util
docs = ["Around 9 Million people live in London", "This is nice"]
# Load the model
model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
doc_emb = model.encode(docs, convert_to_tensor=True)  # return a torch tensor rather than a numpy array
print(doc_emb[:2, :2])
# tensor([[ 0.1678, -0.0616],
#         [ 0.1457, -0.0301]])
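# the document embeddings can then be scored against a query embedding; a
# minimal sketch, where the query string is only an illustration (dot score
# equals cosine similarity here because the embeddings are normalized)
query_emb = model.encode("How many people live in London?", convert_to_tensor=True)
scores = util.dot_score(query_emb, doc_emb)[0]
print(scores)  # higher score = more relevant document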
# compute embeddings without sentence-transformers, using transformers + PyTorch directly
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
# Mean pooling: average the token embeddings, ignoring padding tokens via the attention mask
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
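# toy check of the masked pooling above (hypothetical numbers, not model
# output): the padded third token must not contribute to the average
from types import SimpleNamespace
toy_emb = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (batch=1, seq=3, dim=2)
toy_mask = torch.tensor([[1, 1, 0]])                            # last position is padding
print(mean_pooling(SimpleNamespace(last_hidden_state=toy_emb), toy_mask))
# tensor([[2., 3.]]) -- mean of the two unmasked tokens only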
def encode(texts):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings
doc_emb = encode(docs)
print(doc_emb[:2, :2])
# [[ 0.16776945 -0.06160375]
#  [ 0.14571878 -0.03010267]]
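# sanity check, as a sketch: both code paths should yield near-identical
# embeddings; st_model is a hypothetical name, and the model is reloaded
# because `model` was rebound to the raw transformers model above
import numpy as np
st_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
print(np.allclose(st_model.encode(docs), doc_emb.numpy(), atol=1e-5))  # expected: True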