Skip to content

Instantly share code, notes, and snippets.

@ProGamerGov
Created May 1, 2022 18:23
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ProGamerGov/e4060b55c702835ac933d95f063a2f6e to your computer and use it in GitHub Desktop.
Save ProGamerGov/e4060b55c702835ac933d95f063a2f6e to your computer and use it in GitHub Desktop.
Remove hooks in PyTorch without using the hook handle
from collections import OrderedDict
from typing import Callable, Dict, Optional
from warnings import warn
import torch
def _remove_all_forward_hooks(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    """
    Remove forward hooks from the specified module and all of its submodules,
    without requiring any hook handles. This lets us clean up & remove any hooks
    that weren't properly deleted.

    Hooks are deleted from each module's existing ``_forward_hooks`` dict in place
    (the dict object itself is never replaced), so handles previously returned by
    ``register_forward_hook`` for the remaining hooks can still be used to remove
    them later.

    Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
    caution should be exercised when removing all hooks. Users are recommended to give
    their hook function a unique name that can be used to safely identify and remove
    the target forward hooks.

    Args:

        module (nn.Module): The module instance to remove forward hooks from.
        hook_fn_name (str, optional): Optionally only remove specific forward hooks
            based on their function's __name__ attribute. Hooks that have no
            __name__ attribute (e.g. functools.partial objects) never match a
            name and are therefore kept.
            Default: None
    """
    if hook_fn_name is None:
        warn("Removing all active hooks will break some PyTorch modules & systems.")

    # Side tables that newer PyTorch releases key by the same hook id as
    # ``_forward_hooks``; absent on older releases, hence the getattr below.
    extra_tables = ("_forward_hooks_with_kwargs", "_forward_hooks_always_called")

    # ``modules()`` yields the target module itself followed by every submodule.
    for m in module.modules():
        hooks = getattr(m, "_forward_hooks", None)
        if not hooks:
            continue

        # Collect the ids first so the dict is not mutated while iterating it.
        stale_ids = [
            hook_id
            for hook_id, fn in hooks.items()
            if hook_fn_name is None or getattr(fn, "__name__", None) == hook_fn_name
        ]
        for hook_id in stale_ids:
            del hooks[hook_id]
            for table_name in extra_tables:
                table = getattr(m, table_name, None)
                if table is not None:
                    table.pop(hook_id, None)
from collections import OrderedDict
from typing import List, Optional
import torch
def _count_forward_hooks(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> int:
    """
    Count the number of active forward hooks on the specified module instance
    and all of its submodules.

    Each distinct module object is visited once, so hooks on a submodule that is
    registered at several places in the module tree are only counted once.

    Args:

        module (nn.Module): The model module instance to count the number of
            forward hooks on.
        hook_fn_name (str, optional): Optionally only count specific forward hooks
            based on their function's __name__ attribute. Hooks that have no
            __name__ attribute (e.g. functools.partial objects) never match a
            name.
            Default: None

    Returns:
        num_hooks (int): The number of active hooks in the specified module.
    """
    num_hooks = 0

    # ``modules()`` yields the target module itself followed by every submodule.
    for m in module.modules():
        hooks = getattr(m, "_forward_hooks", None)
        if not hooks:
            continue
        if hook_fn_name is None:
            num_hooks += len(hooks)
        else:
            num_hooks += sum(
                1
                for fn in hooks.values()
                if getattr(fn, "__name__", None) == hook_fn_name
            )
    return num_hooks
@ProGamerGov
Copy link
Author

ProGamerGov commented May 1, 2022

These functions are based on issues with PyTorch's hook management system that I raised here: pytorch/pytorch#70455

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment