Created
March 10, 2024 11:37
-
-
Save ElisonSherton/b4942a7e5d9c705ce975e5f30e1c95f3 to your computer and use it in GitHub Desktop.
Demonstrating issue of cosine similarity computation in pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "affca021-dd73-45bd-8e76-319458e0556e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'3.10.13'" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import platform\n", | |
"platform.python_version()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6c8d9707-25f7-43a3-8c6d-be245006afd3", | |
"metadata": {}, | |
"source": [ | |
"# Similarity Computation Issues with torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "3bbf005f-8452-45f5-8592-6ee7429fc8cb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import math" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "376105c3-f034-4932-82c6-d309a89f365d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'1.13.1'" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "9880657d-7a66-4261-ab05-5d6cf7e3ec96", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"e1 = [-3.6967427730560303, -0.08157289028167725, -0.9293916821479797]\n", | |
"e2 = [-1.2304151058197021, -0.02738991193473339, -0.3091711401939392]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "0cd4d7e4-8fb5-455d-9707-b2b26e38e1c5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"e1_torch = torch.tensor(e1)\n", | |
"e2_torch = torch.tensor(e2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "91cdde06-19be-4733-be26-425a96da2551", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def compute_plain_similarity(e1, e2):\n", | |
" nr, m1, m2 = 0.0, 0.0, 0.0\n", | |
" for i, j in zip(e1, e2):\n", | |
" nr += i * j\n", | |
" m1 += i ** 2\n", | |
" m2 += j ** 2\n", | |
" return nr / math.sqrt(m1 * m2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "f4688ab2-3bdd-47e1-989d-242b8b669879", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.9999999740977576" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Plain Similarity -> fp64\n", | |
"compute_plain_similarity(e1, e2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "cff84a9c-7500-40b1-95af-4bff282157ce", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0000001192092896" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Pytorch Similarity -> fp32\n", | |
"# nr is the dot product of the two tensors, dr the product of their norms;\n", | |
"# their ratio is the cosine similarity.\n", | |
"nr = torch.matmul(e1_torch, e2_torch)\n", | |
"dr = e1_torch.norm() * e2_torch.norm()\n", | |
"sim = nr / dr\n", | |
"# .item() turns the 0-dim tensor into a Python number so all digits are shown.\n", | |
"sim.item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "1da36dda-db40-4d02-912c-daea22caad5c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.9999999740977578" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Pytorch Similarity -> fp64\n", | |
"e1_torch = torch.tensor(e1, dtype = torch.float64)\n", | |
"e2_torch = torch.tensor(e2, dtype = torch.float64)\n", | |
"\n", | |
"nr = torch.matmul(e1_torch, e2_torch)\n", | |
"dr = e1_torch.norm() * e2_torch.norm()\n", | |
"sim = nr / dr\n", | |
"sim.item()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7bdcf27e-48e3-4a79-bf06-e88b827dd743", | |
"metadata": {}, | |
"source": [ | |
"# Problem with clamp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "79aad5e2-d7c8-4d06-9b46-cf7d7e29f35f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.nn as nn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "b73c0cd4-c201-46b0-b4eb-7bad88b9837a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def simulate_scenario(cossim = 1, clip_grad = True):\n", | |
" MARGIN = 0.5\n", | |
" # Also depends on torch.float32 or torch.float64\n", | |
" cossim = nn.Parameter(torch.tensor(cossim, dtype = torch.float32), requires_grad=True)\n", | |
"\n", | |
" clamped_cossim = torch.clamp(cossim, max = 1, min = -1)\n", | |
" clamped_cossim.retain_grad()\n", | |
"\n", | |
" new_angle = clamped_cossim.acos() + MARGIN\n", | |
" new_angle.retain_grad()\n", | |
" \n", | |
" cosine_of_new_angle = new_angle.cos()\n", | |
" cosine_of_new_angle.retain_grad()\n", | |
" \n", | |
" cosine_of_new_angle.backward()\n", | |
" print(f\"{cosine_of_new_angle.grad=}\\n{new_angle.grad=}\\n{clamped_cossim.grad=}\\n{cossim.grad=}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "78e3eb24-accb-4ae8-a80c-49a8b2427e2b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"cosine_of_new_angle.grad=tensor(1.)\n", | |
"new_angle.grad=tensor(-0.4794)\n", | |
"clamped_cossim.grad=tensor(inf)\n", | |
"cossim.grad=tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"simulate_scenario(1.0004, clip_grad=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "f944bc05-2f90-4708-88e8-acd9ed02c14c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"cosine_of_new_angle.grad=tensor(1.)\n", | |
"new_angle.grad=tensor(0.4794)\n", | |
"clamped_cossim.grad=tensor(-inf)\n", | |
"cossim.grad=tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"simulate_scenario(-1.0004)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "55e369c1-26cd-4235-ac0f-8f644b59ee7c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"cosine_of_new_angle.grad=tensor(1.)\n", | |
"new_angle.grad=tensor(-0.4794)\n", | |
"clamped_cossim.grad=tensor(inf)\n", | |
"cossim.grad=tensor(inf)\n" | |
] | |
} | |
], | |
"source": [ | |
"simulate_scenario(1-1e-8)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "5645f7d1-feba-4a04-942d-9e4330e5a6d9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"cosine_of_new_angle.grad=tensor(1.)\n", | |
"new_angle.grad=tensor(0.4794)\n", | |
"clamped_cossim.grad=tensor(-inf)\n", | |
"cossim.grad=tensor(0.)\n" | |
] | |
} | |
], | |
"source": [ | |
"simulate_scenario(-1-1e-7)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "215b183d-9f43-4013-bd93-c501a3f4aca6", | |
"metadata": {}, | |
"source": [ | |
"# Root of the Issue" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "1a4146c0-9bba-4826-b963-1dba94447486", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.tensor(1 - 1e-8) == torch.tensor(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "ce41ae1e-875f-40ec-94fe-88adb2eb1bb9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.tensor(1 + 1e-8) == torch.tensor(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "f34df21a-44ad-431e-9b02-3a3b971610b6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.tensor(-1 - 1e-8) == torch.tensor(-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "f194bc17-ad31-4ccc-89b5-736e844e10a2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.tensor(-1 + 1e-8) == torch.tensor(-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "57699947-3362-46cb-b30c-bd041170e17e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(False)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.tensor(1 + 1e-7) == torch.tensor(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "94f8a141-b3f9-4627-ad00-922cb4cd5008", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# https://github.com/pytorch/pytorch/blob/53fe804322640653d2dddaed394838b868ce9a26/torch/autograd/_functions/pointwise.py#L95\n", | |
"def get_mask(item, min_ = -1, max_ = 1):\n", | |
" return item.ge(min_) * item.le(max_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "204f42ca-af02-41f9-aaed-130a3252c600", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[tensor(True), tensor(True), tensor(True), tensor(True)]" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"eps = 1e-8\n", | |
"[get_mask(torch.tensor(x)) for x in [1 - eps, 1 + eps, -1 - eps, -1 + eps]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "e8f6eda5-dd24-42d6-9533-a9b3bc138213", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[tensor(True), tensor(False), tensor(False), tensor(True)]" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"eps = 1e-7\n", | |
"[get_mask(torch.tensor(x)) for x in [1 - eps, 1 + eps, -1 - eps, -1 + eps]]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment