Skip to content

Instantly share code, notes, and snippets.

@ubless607
Last active May 2, 2024 13:01
Show Gist options
  • Save ubless607/ba41f83647d0e0bf4a6c2d996f345dde to your computer and use it in GitHub Desktop.
Save ubless607/ba41f83647d0e0bf4a6c2d996f345dde to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 264,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from loguru import logger\n",
"\n",
"class Add(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" def forward(self, x1):\n",
" output = torch.where(x1==1., torch.nan, x1)\n",
" output = nn.ReLU()(output)\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 265,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1., 2., 3., 4., 5.],\n",
" [ 6., 7., 8., 9., 10.]])"
]
},
"execution_count": 265,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input = torch.arange(1, 11).reshape(2, 5).to(torch.float32)\n",
"input"
]
},
{
"cell_type": "code",
"execution_count": 266,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1., 2., 3., 4., 5.],\n",
" [ 6., 7., 8., inf, 10.]])"
]
},
"execution_count": 266,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input[1, 3] = float('inf')\n",
"input"
]
},
{
"cell_type": "code",
"execution_count": 267,
"metadata": {},
"outputs": [],
"source": [
"model = Add()"
]
},
{
"cell_type": "code",
"execution_count": 268,
"metadata": {},
"outputs": [],
"source": [
"def nan_hook(module, inp, output):\n",
" if not isinstance(output, tuple):\n",
" outputs = [output]\n",
" else:\n",
" outputs = output\n",
"\n",
" for i, out in enumerate(outputs):\n",
" nan_mask = torch.isnan(out)\n",
" if nan_mask.any():\n",
" logger.debug(f\"Found NAN in output {i} at indices: {nan_mask.nonzero()} where: {out[nan_mask.nonzero()[:, 0].unique(sorted=True)]}\")\n",
"\n",
"def inf_hook(module, inp, output):\n",
" if not isinstance(output, tuple):\n",
" outputs = [output]\n",
" else:\n",
" outputs = output\n",
"\n",
" for i, out in enumerate(outputs):\n",
" inf_mask = torch.isinf(out)\n",
" if inf_mask.any():\n",
" logger.debug(f\"Found INF in output {i} at indices: {inf_mask.nonzero()} where: {out[inf_mask.nonzero()[:, 0].unique(sorted=True)]}\")\n",
"\n",
"for submodule in model.modules():\n",
" submodule.register_forward_hook(inf_hook)\n",
" submodule.register_forward_hook(nan_hook)"
]
},
{
"cell_type": "code",
"execution_count": 269,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-11-08 10:31:54.312\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36minf_hook\u001b[0m:\u001b[36m21\u001b[0m - \u001b[34m\u001b[1mFound INF in output 0 at indices: tensor([[1, 3]]) where: tensor([[ 6., 7., 8., inf, 10.]])\u001b[0m\n",
"\u001b[32m2023-11-08 10:31:54.314\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mnan_hook\u001b[0m:\u001b[36m10\u001b[0m - \u001b[34m\u001b[1mFound NAN in output 0 at indices: tensor([[0, 0]]) where: tensor([[nan, 2., 3., 4., 5.]])\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"tensor([[nan, 2., 3., 4., 5.],\n",
" [ 6., 7., 8., inf, 10.]])"
]
},
"execution_count": 269,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(input)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment