Skip to content

Instantly share code, notes, and snippets.

@ElisonSherton
Created March 10, 2024 11:37
Show Gist options
  • Save ElisonSherton/b4942a7e5d9c705ce975e5f30e1c95f3 to your computer and use it in GitHub Desktop.
Save ElisonSherton/b4942a7e5d9c705ce975e5f30e1c95f3 to your computer and use it in GitHub Desktop.
Demonstrating issues with cosine similarity computation in PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "affca021-dd73-45bd-8e76-319458e0556e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'3.10.13'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Interpreter version the outputs in this notebook were produced with\n",
"from platform import python_version\n",
"python_version()"
]
},
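{
"cell_type": "markdown",
"id": "6c8d9707-25f7-43a3-8c6d-be245006afd3",
"metadata": {},
"source": [
"# Cosine Similarity Computation Issues in PyTorch\n",
"\n",
"The same cosine similarity computed in fp32 comes out slightly above 1 (`1.0000001192092896` below), while the fp64 result stays just under 1 (`0.99999997...`). Any `acos` applied afterwards then receives an input outside its domain."
]
},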
{
"cell_type": "code",
"execution_count": 2,
"id": "3bbf005f-8452-45f5-8592-6ee7429fc8cb",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "376105c3-f034-4932-82c6-d309a89f365d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.13.1'"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# PyTorch version the outputs in this notebook were produced with\n",
"str(torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9880657d-7a66-4261-ab05-5d6cf7e3ec96",
"metadata": {},
"outputs": [],
"source": [
"# Two nearly parallel vectors (e1 is roughly 3 * e2), so their cosine similarity is just below 1\n",
"e1, e2 = (\n",
"    [-3.6967427730560303, -0.08157289028167725, -0.9293916821479797],\n",
"    [-1.2304151058197021, -0.02738991193473339, -0.3091711401939392],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0cd4d7e4-8fb5-455d-9707-b2b26e38e1c5",
"metadata": {},
"outputs": [],
"source": [
"# torch.tensor infers torch's default float dtype (float32) from Python floats\n",
"e1_torch, e2_torch = (torch.tensor(values) for values in (e1, e2))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "91cdde06-19be-4733-be26-425a96da2551",
"metadata": {},
"outputs": [],
"source": [
"def compute_plain_similarity(e1, e2):\n",
"    \"\"\"Cosine similarity of two equal-length sequences of numbers, in plain Python floats (fp64).\n",
"\n",
"    Raises ValueError if the lengths differ and ZeroDivisionError if either vector is all zeros.\n",
"    \"\"\"\n",
"    nr, m1, m2 = 0.0, 0.0, 0.0\n",
"    # strict=True (Python 3.10+): a length mismatch would otherwise be silently truncated\n",
"    for i, j in zip(e1, e2, strict=True):\n",
"        nr += i * j\n",
"        m1 += i ** 2\n",
"        m2 += j ** 2\n",
"    # dot(e1, e2) / (|e1| * |e2|)\n",
"    return nr / math.sqrt(m1 * m2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f4688ab2-3bdd-47e1-989d-242b8b669879",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9999999740977576"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Plain Python similarity -> fp64 (Python floats are IEEE doubles)\n",
"compute_plain_similarity(e1=e1, e2=e2)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cff84a9c-7500-40b1-95af-4bff282157ce",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0000001192092896"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# PyTorch similarity -> fp32: dot product divided by the product of the L2 norms\n",
"sim = (e1_torch @ e2_torch) / (e1_torch.norm() * e2_torch.norm())\n",
"sim.item()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1da36dda-db40-4d02-912c-daea22caad5c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9999999740977578"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pytorch Similarity -> fp64\n",
"# Distinct names keep the fp32 tensors (e1_torch, e2_torch) intact, so re-running\n",
"# the fp32 cell after this one still computes in fp32.\n",
"e1_torch64 = torch.tensor(e1, dtype = torch.float64)\n",
"e2_torch64 = torch.tensor(e2, dtype = torch.float64)\n",
"\n",
"nr64 = torch.matmul(e1_torch64, e2_torch64)\n",
"dr64 = e1_torch64.norm() * e2_torch64.norm()\n",
"sim64 = nr64 / dr64\n",
"sim64.item()"
]
},
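{
"cell_type": "markdown",
"id": "7bdcf27e-48e3-4a79-bf06-e88b827dd743",
"metadata": {},
"source": [
"# Problem with `torch.clamp` before `acos`\n",
"\n",
"Clamping the similarity to [-1, 1] before `acos` does not give a usable gradient at the boundary. In the cells below, the gradient that reaches `cossim` takes one of two values. It is `0` when the input is strictly outside [-1, 1] even after rounding to fp32. It is `inf` when the input rounds to exactly +/-1, where the derivative of `acos` is infinite."
]
},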
{
"cell_type": "code",
"execution_count": 9,
"id": "79aad5e2-d7c8-4d06-9b46-cf7d7e29f35f",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b73c0cd4-c201-46b0-b4eb-7bad88b9837a",
"metadata": {},
"outputs": [],
"source": [
"def simulate_scenario(cossim = 1, clip_grad = True, margin = 0.5, dtype = torch.float32):\n",
"    \"\"\"Backpropagate through cos(acos(clamp(cossim, -1, 1)) + margin) and print each gradient.\n",
"\n",
"    cossim: value wrapped in a scalar nn.Parameter of the given dtype.\n",
"    clip_grad: NOTE(review): accepted but never read; kept only so existing calls still work.\n",
"    margin: angle added after acos (previously the hard-coded MARGIN = 0.5).\n",
"    dtype: parameter precision, e.g. torch.float32 (default) or torch.float64. Rounding to\n",
"        this dtype decides whether a value near +/-1 becomes exactly +/-1.\n",
"\n",
"    Prints the gradients; returns None.\n",
"    \"\"\"\n",
"    cossim = nn.Parameter(torch.tensor(cossim, dtype = dtype), requires_grad=True)\n",
"\n",
"    # clamp passes gradient only where -1 <= cossim <= 1; elsewhere cossim.grad ends up 0\n",
"    clamped_cossim = torch.clamp(cossim, max = 1, min = -1)\n",
"    clamped_cossim.retain_grad()  # keep .grad on non-leaf tensors so it can be printed\n",
"\n",
"    # d/dx acos(x) = -1/sqrt(1 - x**2) is infinite at x = +/-1\n",
"    new_angle = clamped_cossim.acos() + margin\n",
"    new_angle.retain_grad()\n",
"\n",
"    cosine_of_new_angle = new_angle.cos()\n",
"    cosine_of_new_angle.retain_grad()\n",
"\n",
"    cosine_of_new_angle.backward()\n",
"    print(f\"{cosine_of_new_angle.grad=}\\n{new_angle.grad=}\\n{clamped_cossim.grad=}\\n{cossim.grad=}\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "78e3eb24-accb-4ae8-a80c-49a8b2427e2b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cosine_of_new_angle.grad=tensor(1.)\n",
"new_angle.grad=tensor(-0.4794)\n",
"clamped_cossim.grad=tensor(inf)\n",
"cossim.grad=tensor(0.)\n"
]
}
],
"source": [
"# 1.0004 stays above 1 in fp32, so clamp blocks the gradient (cossim.grad = 0)\n",
"simulate_scenario(cossim=1.0004, clip_grad=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f944bc05-2f90-4708-88e8-acd9ed02c14c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cosine_of_new_angle.grad=tensor(1.)\n",
"new_angle.grad=tensor(0.4794)\n",
"clamped_cossim.grad=tensor(-inf)\n",
"cossim.grad=tensor(0.)\n"
]
}
],
"source": [
"# -1.0004 stays below -1 in fp32, so clamp blocks the gradient (cossim.grad = 0)\n",
"simulate_scenario(cossim=-1.0004)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "55e369c1-26cd-4235-ac0f-8f644b59ee7c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cosine_of_new_angle.grad=tensor(1.)\n",
"new_angle.grad=tensor(-0.4794)\n",
"clamped_cossim.grad=tensor(inf)\n",
"cossim.grad=tensor(inf)\n"
]
}
],
"source": [
"# 1 - 1e-8 rounds to exactly 1.0 in fp32: the gradient passes clamp and the acos derivative makes it inf\n",
"simulate_scenario(cossim=1 - 1e-8)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5645f7d1-feba-4a04-942d-9e4330e5a6d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cosine_of_new_angle.grad=tensor(1.)\n",
"new_angle.grad=tensor(0.4794)\n",
"clamped_cossim.grad=tensor(-inf)\n",
"cossim.grad=tensor(0.)\n"
]
}
],
"source": [
"# -1 - 1e-7 stays strictly below -1 in fp32, so clamp zeroes the gradient\n",
"simulate_scenario(cossim=-1 - 1e-7)"
]
},
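{
"cell_type": "markdown",
"id": "215b183d-9f43-4013-bd93-c501a3f4aca6",
"metadata": {},
"source": [
"# Root of the Issue\n",
"\n",
"In fp32 the gap between 1.0 and its neighbours is about 6e-8 below and 1.2e-7 above. An offset of 1e-8 is therefore rounded away, and `1 +/- 1e-8` (or `-1 +/- 1e-8`) becomes exactly +/-1. An offset of 1e-7 survives. The clamp backward mask is inclusive (`min <= x <= max`, reproduced by `get_mask` below). A value that rounded to +/-1 therefore keeps its gradient and picks up the infinite `acos` derivative."
]
},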
{
"cell_type": "code",
"execution_count": 15,
"id": "1a4146c0-9bba-4826-b963-1dba94447486",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1 - 1e-8 is less than half an fp32 step away from 1.0, so it rounds to exactly 1.0\n",
"torch.tensor(1 - 1e-8).eq(torch.tensor(1))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ce41ae1e-875f-40ec-94fe-88adb2eb1bb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1 + 1e-8 also rounds to exactly 1.0 in fp32\n",
"torch.tensor(1 + 1e-8).eq(torch.tensor(1))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f34df21a-44ad-431e-9b02-3a3b971610b6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same on the negative side: -1 - 1e-8 rounds to exactly -1.0 in fp32\n",
"torch.tensor(-1 - 1e-8).eq(torch.tensor(-1))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f194bc17-ad31-4ccc-89b5-736e844e10a2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# -1 + 1e-8 rounds to exactly -1.0 in fp32\n",
"torch.tensor(-1 + 1e-8).eq(torch.tensor(-1))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "57699947-3362-46cb-b30c-bd041170e17e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(False)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1e-7 survives rounding: 1 + 1e-7 becomes the next fp32 value above 1.0\n",
"torch.tensor(1 + 1e-7).eq(torch.tensor(1))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "94f8a141-b3f9-4627-ad00-922cb4cd5008",
"metadata": {},
"outputs": [],
"source": [
"# https://github.com/pytorch/pytorch/blob/53fe804322640653d2dddaed394838b868ce9a26/torch/autograd/_functions/pointwise.py#L95\n",
"def get_mask(item, min_ = -1, max_ = 1):\n",
"    \"\"\"Bool tensor that is True where min_ <= item <= max_ (both bounds inclusive).\n",
"\n",
"    Intended to reproduce the clamp-backward mask in the linked (old) PyTorch source:\n",
"    elements where it is True keep their gradient.\n",
"    \"\"\"\n",
"    at_or_above_min = item >= min_\n",
"    at_or_below_max = item <= max_\n",
"    return at_or_above_min & at_or_below_max"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "204f42ca-af02-41f9-aaed-130a3252c600",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor(True), tensor(True), tensor(True), tensor(True)]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With eps = 1e-8 every probe rounds to exactly +/-1 in fp32, so all of them pass the inclusive mask\n",
"eps = 1e-8\n",
"probes = [1 - eps, 1 + eps, -1 - eps, -1 + eps]\n",
"[get_mask(torch.tensor(p)) for p in probes]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "e8f6eda5-dd24-42d6-9533-a9b3bc138213",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor(True), tensor(False), tensor(False), tensor(True)]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With eps = 1e-7, 1 + eps and -1 - eps survive fp32 rounding and fall outside [-1, 1]\n",
"eps = 1e-7\n",
"probes = [1 - eps, 1 + eps, -1 - eps, -1 + eps]\n",
"[get_mask(torch.tensor(p)) for p in probes]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment