Skip to content

Instantly share code, notes, and snippets.

@takuma104
Last active May 19, 2023 18:26
Show Gist options
  • Save takuma104/894dff4e48a7e1dbebedcff136da5956 to your computer and use it in GitHub Desktop.
Save takuma104/894dff4e48a7e1dbebedcff136da5956 to your computer and use it in GitHub Desktop.
monkey_patch_minimum_test.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "T4",
"name": "monkey_patch_minimum_test.ipynb",
"authorship_tag": "ABX9TyMFttedlOM6CLYig1i5WIyI",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/takuma104/894dff4e48a7e1dbebedcff136da5956/untitled11.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "dQwQ6YK0BaFf"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn"
]
},
{
"cell_type": "code",
"source": [
"# Use the GPU when one is available; otherwise fall back to the CPU so the\n",
"# notebook still runs on CPU-only runtimes.\n",
"# Forcing this to 'cpu' still results in the vanilla case failing.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
],
"metadata": {
"id": "PesDolhICF0L"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class TargetModule(torch.nn.Module):\n",
"    \"\"\"Minimal model for the monkey-patching tests: two 2->2 Linear layers in sequence.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.linears = torch.nn.ModuleList([\n",
"            torch.nn.Linear(2, 2),\n",
"            torch.nn.Linear(2, 2),  # If you comment out this line, the vanilla case will succeed.\n",
"        ])\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply every layer in ``self.linears`` to ``x`` in order and return the result.\"\"\"\n",
"        for module in self.linears:\n",
"            x = module(x)\n",
"        return x\n"
],
"metadata": {
"id": "McbX2HrlBdrp"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def test_monkey_patch_instance_method():\n",
"    \"\"\"Patch every Linear's ``forward`` with a function bound to the module as a method.\n",
"\n",
"    The wrapper reaches the original implementation through ``self.old_forward``,\n",
"    an attribute stored on each module, so every patched layer calls its own\n",
"    original ``forward``. The test asserts that the patched model reproduces the\n",
"    unpatched output.\n",
"    \"\"\"\n",
"    def monkey_patch(target):\n",
"        for name, module in target.named_modules():\n",
"            if isinstance(module, torch.nn.Linear):\n",
"                print(f'monkey patching to {name}')\n",
"                # Keep the original bound forward on the module itself.\n",
"                module.old_forward = module.forward\n",
"                def new_forward(self, x):\n",
"                    return self.old_forward(x)\n",
"                # __get__ binds new_forward to this module, so it receives the module as ``self``.\n",
"                module.forward = new_forward.__get__(module)\n",
"\n",
"    torch.manual_seed(0)\n",
"    x = torch.randn((2, 2)).to(device)\n",
"    target = TargetModule().to(device)\n",
"    with torch.no_grad():\n",
"        print('')\n",
"        print('*' * 80)\n",
"        print('instance_method:')\n",
"\n",
"        # Reference output before patching.\n",
"        y = target(x)\n",
"        print(y)\n",
"        assert y.shape == (2, 2)\n",
"\n",
"        monkey_patch(target)\n",
"\n",
"        # Output after patching must match the reference.\n",
"        yy = target(x)\n",
"        print(yy)\n",
"        assert torch.allclose(yy, y), \"instance_method: monkey patching failed\"\n"
],
"metadata": {
"id": "RENlHKFNBh6x"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def test_monkey_patch_fix_closure():\n",
"    \"\"\"Patch every Linear's ``forward`` with a closure built by a factory function.\n",
"\n",
"    ``make_new_forward`` receives the original ``forward`` as an argument, so each\n",
"    wrapper closes over its own ``old_forward`` instead of a loop variable that is\n",
"    rebound on every iteration. The test asserts that the patched model reproduces\n",
"    the unpatched output.\n",
"    \"\"\"\n",
"    def monkey_patch(target):\n",
"        # Defined once, outside the loop: every call creates a fresh scope that\n",
"        # pins the ``old_forward`` it was given.\n",
"        def make_new_forward(old_forward):\n",
"            def new_forward(x):\n",
"                return old_forward(x)\n",
"            return new_forward\n",
"\n",
"        for name, module in target.named_modules():\n",
"            if isinstance(module, torch.nn.Linear):\n",
"                print(f'monkey patching to {name}')\n",
"                module.forward = make_new_forward(module.forward)\n",
"\n",
"    torch.manual_seed(0)\n",
"    x = torch.randn((2, 2)).to(device)\n",
"    target = TargetModule().to(device)\n",
"    with torch.no_grad():\n",
"        print('')\n",
"        print('*' * 80)\n",
"        # Label this run with its own name so its output is distinguishable\n",
"        # from the vanilla test's output.\n",
"        print('fix closure:')\n",
"\n",
"        # Reference output before patching.\n",
"        y = target(x)\n",
"        print(y)\n",
"        assert y.shape == (2, 2)\n",
"\n",
"        monkey_patch(target)\n",
"\n",
"        # Output after patching must match the reference.\n",
"        yy = target(x)\n",
"        print(yy)\n",
"        assert torch.allclose(yy, y), \"fix closure: monkey patching failed\""
],
"metadata": {
"id": "I3-XZvruQRwD"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def test_monkey_patch_vanilla():\n",
"    \"\"\"Patch every Linear's ``forward`` with a plain closure; expected to FAIL.\n",
"\n",
"    ``new_forward`` looks up ``old_forward`` in ``monkey_patch``'s scope when it is\n",
"    called, not when it is defined. The loop rebinds ``old_forward`` on every\n",
"    iteration, so after patching, every patched layer calls the original\n",
"    ``forward`` of the last Linear visited. With more than one Linear the patched\n",
"    output differs from the reference and the final assert raises AssertionError.\n",
"    This failure is the behaviour the notebook demonstrates, so it is kept as is.\n",
"    \"\"\"\n",
"    def monkey_patch(target):\n",
"        for name, module in target.named_modules():\n",
"            if isinstance(module, torch.nn.Linear):\n",
"                print(f'monkey patching to {name}')\n",
"                # Rebound on each iteration; all wrappers share this one variable.\n",
"                old_forward = module.forward\n",
"                def new_forward(x):\n",
"                    return old_forward(x)\n",
"                module.forward = new_forward\n",
"\n",
"    torch.manual_seed(0)\n",
"    x = torch.randn((2, 2)).to(device)\n",
"    target = TargetModule().to(device)\n",
"    with torch.no_grad():\n",
"        print('')\n",
"        print('*' * 80)\n",
"        print('vanilla:')\n",
"\n",
"        # Reference output before patching.\n",
"        y = target(x)\n",
"        print(y)\n",
"        assert y.shape == (2, 2)\n",
"\n",
"        monkey_patch(target)\n",
"\n",
"        # Output after patching; compared against the reference below.\n",
"        yy = target(x)\n",
"        print(yy)\n",
"        assert torch.allclose(yy, y), \"vanilla: monkey patching failed\"\n"
],
"metadata": {
"id": "DGjU6E2jBgsi"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_monkey_patch_instance_method()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0AG9kwGoBsQ5",
"outputId": "70185998-3a8c-4bde-8bb1-cc02d425d6e8"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"********************************************************************************\n",
"instance_method:\n",
"tensor([[-0.2581, -0.8602],\n",
" [-0.3555, -0.4635]], device='cuda:0')\n",
"monkey patching to linears.0\n",
"monkey patching to linears.1\n",
"tensor([[-0.2581, -0.8602],\n",
" [-0.3555, -0.4635]], device='cuda:0')\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"test_monkey_patch_fix_closure()"
],
"metadata": {
"id": "y63bOY-CQlQX",
"outputId": "d54b493e-9620-4d82-a352-4197c0cd81d4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"********************************************************************************\n",
"vanilla:\n",
"tensor([[-0.2581, -0.8602],\n",
" [-0.3555, -0.4635]], device='cuda:0')\n",
"monkey patching to linears.0\n",
"monkey patching to linears.1\n",
"tensor([[-0.2581, -0.8602],\n",
" [-0.3555, -0.4635]], device='cuda:0')\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"test_monkey_patch_vanilla()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "Y0vUq9ZxBu6R",
"outputId": "8c2f6af9-0717-4be1-9bfb-f72010d23fb7"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"********************************************************************************\n",
"vanilla:\n",
"tensor([[-0.2581, -0.8602],\n",
" [-0.3555, -0.4635]], device='cuda:0')\n",
"monkey patching to linears.0\n",
"monkey patching to linears.1\n",
"tensor([[-0.2065, -0.5703],\n",
" [-0.5468, -0.5470]], device='cuda:0')\n"
]
},
{
"output_type": "error",
"ename": "AssertionError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-05af1f1bc236>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_monkey_patch_vanilla\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-6-95b097a22274>\u001b[0m in \u001b[0;36mtest_monkey_patch_vanilla\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0myy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"vanilla: monkey patching failed\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m: vanilla: monkey patching failed"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment