Skip to content

Instantly share code, notes, and snippets.

@itsuncheng
Last active October 20, 2022 04:56
Show Gist options
  • Save itsuncheng/2f3e889b907e77b8dd5c6d454b7d04d1 to your computer and use it in GitHub Desktop.
Save itsuncheng/2f3e889b907e77b8dd5c6d454b7d04d1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Library"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer, util\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Selection and Initialization"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1.31G/1.31G [01:36<00:00, 13.6MB/s] \n"
]
}
],
"source": [
"# Load the 'stsb-roberta-large' checkpoint. The first run downloads the weights\n",
"# (the progress bar recorded in this cell's output shows 1.31G).\n",
"# List of models optimized for semantic textual similarity can be found at:\n",
"# https://docs.google.com/spreadsheets/d/14QplCdTCDwEmTqrn1LH4yrbKvdogK4oQvYO1K1aPR5M/edit#gid=0\n",
"model = SentenceTransformer('stsb-roberta-large')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate semantic similarity between two sentences"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence 1: I like Python because I can build AI applications\n",
"Sentence 2: I like Python because I can do data analytics\n",
"Similarity score: 0.8188022375106812\n"
]
}
],
"source": [
"sentence1 = \"I like Python because I can build AI applications\"\n",
"sentence2 = \"I like Python because I can do data analytics\"\n",
"\n",
"# encode sentences to get their embeddings\n",
"embedding1 = model.encode(sentence1, convert_to_tensor=True)\n",
"embedding2 = model.encode(sentence2, convert_to_tensor=True)\n",
"\n",
"# compute similarity scores of two embeddings\n",
"cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)\n",
"\n",
"print(\"Sentence 1:\", sentence1)\n",
"print(\"Sentence 2:\", sentence2)\n",
"print(\"Similarity score:\", cosine_scores.item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate semantic similarity between two lists of sentences"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence 1: I like Python because I can build AI applications\n",
"Sentence 2: I like Python because I can do data analytics\n",
"Similarity Score: 0.8188023567199707\n",
"\n",
"Sentence 1: I like Python because I can build AI applications\n",
"Sentence 2: The cat walks on the sidewalk\n",
"Similarity Score: -0.06005367636680603\n",
"\n",
"Sentence 1: The cat sits on the ground\n",
"Sentence 2: I like Python because I can do data analytics\n",
"Similarity Score: 0.12721936404705048\n",
"\n",
"Sentence 1: The cat sits on the ground\n",
"Sentence 2: The cat walks on the sidewalk\n",
"Similarity Score: 0.4131842255592346\n",
"\n"
]
}
],
"source": [
"sentences1 = [\"I like Python because I can build AI applications\", \"The cat sits on the ground\"] \n",
"sentences2 = [\"I like Python because I can do data analytics\", \"The cat walks on the sidewalk\"]\n",
"\n",
"# encode list of sentences to get their embeddings\n",
"embedding1 = model.encode(sentences1, convert_to_tensor=True)\n",
"embedding2 = model.encode(sentences2, convert_to_tensor=True)\n",
"\n",
"# compute similarity scores of two embeddings\n",
"cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)\n",
"\n",
"for i in range(len(sentences1)):\n",
" for j in range(len(sentences2)):\n",
" print(\"Sentence 1:\", sentences1[i])\n",
" print(\"Sentence 2:\", sentences2[j])\n",
" print(\"Similarity Score:\", cosine_scores[i][j].item())\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieve Top K most similar sentences from a corpus given a sentence"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"corpus = [\n",
"    \"I like Python because I can build AI applications\",\n",
"    \"I like Python because I can do data analytics\",\n",
"    \"The cat sits on the ground\",\n",
"    \"The cat walks on the sidewalk\",\n",
"]\n",
"\n",
"# embed the whole corpus in a single encode call (tensor output)\n",
"corpus_embeddings = model.encode(corpus, convert_to_tensor=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"I like Javascript because I can build web applications\"\n",
"\n",
"# encode sentence to get sentence embeddings\n",
"sentence_embedding = model.encode(sentence, convert_to_tensor=True)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence: I like Javascript because I can build web applications \n",
"\n",
"Top 2 most similar sentences in corpus:\n",
"I like Python because I can build AI applications (Score: 0.6253)\n",
"I like Python because I can do data analytics (Score: 0.5348)\n"
]
}
],
"source": [
"# top_k results to return\n",
"top_k=2\n",
"\n",
"# compute similarity scores of the sentence with the corpus\n",
"cos_scores = util.pytorch_cos_sim(sentence_embedding, corpus_embeddings)[0]\n",
"\n",
"# Sort the results in decreasing order and get the first top_k\n",
"top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]\n",
"\n",
"print(\"Sentence:\", sentence, \"\\n\")\n",
"print(\"Top\", top_k, \"most similar sentences in corpus:\")\n",
"for idx in top_results[0:top_k]:\n",
" print(corpus[idx], \"(Score: %.4f)\" % (cos_scores[idx]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment