Skip to content

Instantly share code, notes, and snippets.

@joshlk
Created June 21, 2024 09:59
Show Gist options
  • Save joshlk/20f1f51900e7f299feb618aa2fc55921 to your computer and use it in GitHub Desktop.
Save joshlk/20f1f51900e7f299feb618aa2fc55921 to your computer and use it in GitHub Desktop.
Print PyTorch backwards ops
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9ba437fc",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b5cb8743",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Tracing back tensors:\n",
"<MulBackward0 object at 0x11c9ffaf0>\n",
"<SumBackward0 object at 0x12e9c3a60>\n",
"<MulBackward0 object at 0x12e9c3400>\n",
"<AccumulateGrad object at 0x12e9c3070>\n",
"Tensor with grad found: tensor([0.3643, 0.6264, 0.1329, 0.5581, 0.3163], requires_grad=True)\n",
" - gradient: tensor([3., 3., 3., 3., 3.])\n",
"\n",
"<AccumulateGrad object at 0x12e9c3ee0>\n",
"Tensor with grad found: tensor([1., 1., 1., 1., 1.], requires_grad=True)\n",
" - gradient: tensor([1.0928, 1.8791, 0.3986, 1.6743, 0.9489])\n",
"\n"
]
}
],
"source": [
"torch.manual_seed(0)  # fixed seed so the printed tensors are reproducible\n",
"input1 = torch.randn(100, 128, requires_grad=True)\n",
"input2 = torch.randn(100, 128, requires_grad=True)\n",
"cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n",
"output = cos(input1, input2)\n",
"\n",
"print()\n",
"print('Tracing back tensors:')\n",
"def getBack(var_grad_fn, _seen=None):\n",
"    \"\"\"Recursively print the autograd graph reachable from `var_grad_fn`.\n",
"\n",
"    Nodes that expose a `variable` attribute (AccumulateGrad leaves) are\n",
"    printed together with their leaf tensor and its accumulated `.grad`;\n",
"    every other node is printed and its `next_functions` are walked in turn.\n",
"    Each node is visited at most once, so a leaf that is reachable along\n",
"    several paths (e.g. an input used more than once) is reported only once.\n",
"\n",
"    Args:\n",
"        var_grad_fn: an autograd node, e.g. `some_tensor.grad_fn`.\n",
"        _seen: internal set of already-visited nodes; leave as None.\n",
"    \"\"\"\n",
"    if _seen is None:\n",
"        _seen = set()\n",
"    print(var_grad_fn)\n",
"    # next_functions holds (node, input_nr) pairs; the node may be None\n",
"    # (e.g. for an input that does not require grad), so skip those.\n",
"    for next_fn, _ in var_grad_fn.next_functions:\n",
"        if next_fn is None or next_fn in _seen:\n",
"            continue\n",
"        _seen.add(next_fn)\n",
"        tensor = getattr(next_fn, 'variable', None)\n",
"        if tensor is None:\n",
"            getBack(next_fn, _seen)\n",
"        else:\n",
"            print(next_fn)\n",
"            print('Tensor with grad found:', tensor)\n",
"            print(' - gradient:', tensor.grad)\n",
"            print()\n",
"\n",
"# Run backward first so the leaf tensors' .grad fields are populated,\n",
"# and keep a reference to the scalar so its grad_fn can be traced.\n",
"loss = output.sum()\n",
"loss.backward()\n",
"getBack(loss.grad_fn)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3e320c04",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.8534, 0.8103, -0.8892, ..., -0.2915, -0.8183, -0.6481],\n",
" [-0.9195, -0.0834, 1.6122, ..., 0.0945, -0.0291, -0.7190],\n",
" [ 0.3594, 1.0440, -0.5852, ..., -0.2921, -0.4885, 0.1041],\n",
" ...,\n",
" [-0.3927, 0.2467, 0.3223, ..., 0.1250, -0.3101, 0.2410],\n",
" [ 1.3586, -1.4949, 0.3142, ..., -0.1608, 0.8276, 0.8251],\n",
" [-0.2738, 0.9730, -0.6034, ..., -0.8690, 0.0268, -0.8985]],\n",
" grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad4fd55f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment