Skip to content

Instantly share code, notes, and snippets.

@takuma104
Last active May 31, 2023 16:34
Show Gist options
  • Save takuma104/93094f989ee89e4cd61af09f9d909e26 to your computer and use it in GitHub Desktop.
Save takuma104/93094f989ee89e4cd61af09f9d909e26 to your computer and use it in GitHub Desktop.
monkey_patch_minimum_test.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "T4",
"name": "monkey_patch_minimum_test.ipynb",
"authorship_tag": "ABX9TyNjZOFW4I2ayqE1wnNBP8OS",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/takuma104/93094f989ee89e4cd61af09f9d909e26/untitled11.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "dQwQ6YK0BaFf"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn"
]
},
{
"cell_type": "code",
"source": [
"# Prefer the GPU, but fall back to the CPU so the notebook still runs on a runtime without CUDA.\n",
"# Author's note: changing this to 'cpu' still results in the vanilla case failing\n",
"# (NOTE(review): the 'vanilla case' is not part of this notebook - confirm what it refers to).\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
],
"metadata": {
"id": "PesDolhICF0L"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class TargetModule(torch.nn.Module):\n",
"    \"\"\"Toy module that applies the layers held in ``self.linears`` in order.\n",
"\n",
"    It contains a single ``Linear(2, 2)`` inside a ``ModuleList``, so\n",
"    ``named_modules()`` reports that layer under the name ``linears.0``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.linears = torch.nn.ModuleList([\n",
"            torch.nn.Linear(2, 2),\n",
"        ])\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Feed ``x`` through every layer in ``self.linears`` and return the result.\"\"\"\n",
"        for module in self.linears:\n",
"            x = module(x)\n",
"        return x\n"
],
"metadata": {
"id": "McbX2HrlBdrp"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def test_monkey_patch_fix_closure():\n",
"    \"\"\"Check that patching ``Linear.forward`` doubles the output and that re-patching does not stack.\n",
"\n",
"    Runs ``TargetModule`` unpatched, then again after one and after two calls\n",
"    to ``monkey_patch``, asserting that both patched outputs equal twice the\n",
"    unpatched output.\n",
"\n",
"    Raises:\n",
"        AssertionError: if the output shape or either patched output is wrong.\n",
"    \"\"\"\n",
"\n",
"    def make_new_forward(old_forward):\n",
"        # Factory: the wrapper closes over this call's ``old_forward`` argument,\n",
"        # not over a variable that the patching loop reassigns per module.\n",
"        def new_forward(x):\n",
"            return old_forward(x) * 2.0\n",
"        return new_forward\n",
"\n",
"    def monkey_patch(target):\n",
"        \"\"\"Wrap ``forward`` of every ``torch.nn.Linear`` inside ``target`` so it returns 2x.\"\"\"\n",
"        for name, module in target.named_modules():\n",
"            if not isinstance(module, torch.nn.Linear):\n",
"                continue\n",
"            print(f'monkey patching to {name}')\n",
"\n",
"            # Already patched: restore the saved forward first so the wrapper\n",
"            # is applied once instead of being stacked on top of itself.\n",
"            if hasattr(module, 'old_forward'):\n",
"                print('undo monkey-patch')\n",
"                module.forward = module.old_forward\n",
"                delattr(module, 'old_forward')\n",
"\n",
"            # Keep the current forward on the module so a later call can undo the patch.\n",
"            module.old_forward = module.forward\n",
"            module.forward = make_new_forward(module.old_forward)\n",
"\n",
"    torch.manual_seed(0)  # fixed seed: same input and same Linear weights on every run\n",
"    x = torch.randn((2, 2)).to(device)\n",
"    target = TargetModule().to(device)\n",
"    with torch.no_grad():\n",
"        print('')\n",
"        print('*' * 80)\n",
"        print('vanilla:')\n",
"\n",
"        # Baseline output of the unpatched module.\n",
"        y = target(x)\n",
"        print(y)\n",
"        assert y.shape == (2, 2)\n",
"\n",
"        # First patch: output must be exactly doubled.\n",
"        monkey_patch(target)\n",
"\n",
"        yy = target(x)\n",
"        print(yy)\n",
"        assert torch.allclose(yy, y * 2.0), \"fix closure: monkey patching failed\"\n",
"\n",
"        # Second patch: the undo step must keep the factor at 2x.\n",
"        monkey_patch(target)\n",
"\n",
"        yyy = target(x)\n",
"        print(yyy)\n",
"        assert torch.allclose(yyy, y * 2.0), \"fix closure: monkey patching failed\"\n"
],
"metadata": {
"id": "I3-XZvruQRwD"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_monkey_patch_fix_closure()"
],
"metadata": {
"id": "y63bOY-CQlQX",
"outputId": "ce428186-1e12-4739-d6f5-6b2b16e0df88",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"********************************************************************************\n",
"vanilla:\n",
"tensor([[-0.8271, -0.7568],\n",
" [-0.4325, -0.0817]], device='cuda:0')\n",
"monkey patching to linears.0\n",
"tensor([[-1.6543, -1.5137],\n",
" [-0.8649, -0.1634]], device='cuda:0')\n",
"monkey patching to linears.0\n",
"undo monkey-patch\n",
"tensor([[-1.6543, -1.5137],\n",
" [-0.8649, -0.1634]], device='cuda:0')\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "aQU2IEcImwta"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment