
@DNGros
Created January 4, 2019 07:38
PyTorch nn.Module with multiple forward methods

"""This code lets one nn.Module expose several forward methods with
different signatures while still supporting static analysis / IDE hinting.

Example:

    class MyModule(MultiforwardTorchModule):
        @add_hooks
        def forward_train(
            self,
            hidden_state: torch.Tensor,
            teacher_force_seq: List[str],
            # training-specific args...
        ) -> torch.Tensor:
            # ....
            return loss

        @add_hooks
        def forward_inference(
            self,
            hidden_state: torch.Tensor,
            beam_size: int,
            # inference-specific args...
        ) -> List[str]:
            # ....
            return result

    mod = MyModule()
    # instead of this: mod(foo, bar)
    # we can do this:  mod.forward_train(foo, bar)
    # and the module's forward/backward hooks are still called

Copyright 2019 David Gros. Freely available under MIT license
(https://opensource.org/licenses/MIT)
"""
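
import functools
from typing import Callable

import torch.nn


def add_hooks(new_forward_func: Callable) -> Callable:
    """A decorator for methods inside a MultiforwardTorchModule that makes a
    method act like a forward call (the module's forward/backward hooks are
    still called)."""
    # functools.wraps keeps the method's name and docstring for IDEs/introspection.
    @functools.wraps(new_forward_func)
    def wrapper(self: "MultiforwardTorchModule", *args, **kwargs):
        # Route through nn.Module.__call__, which runs the registered hooks and
        # then calls MultiforwardTorchModule.forward(new_forward_func, self, ...).
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper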


class MultiforwardTorchModule(torch.nn.Module):
    """Wraps nn.Module to work with the add_hooks decorator. Instead of
    overriding forward and calling the module through __call__, decorate each
    method that acts like a forward with @add_hooks and call it by name."""

    def forward(self, actual_forward: Callable, *args, **kwargs):
        """Calls the method passed in by the add_hooks decorator. This should
        not be overridden, unless you want something to happen on all of your
        forwards (somewhat like a forward hook)."""
        return actual_forward(*args, **kwargs)
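
The docstring's MyModule leaves its method bodies as `# ....`. Below is a minimal, runnable sketch of the same pattern, assuming PyTorch is installed. `ToyModule`, its `proj` layer, the cross-entropy loss, and the greedy `argmax` inference are hypothetical stand-ins, not the author's code, and `add_hooks`/`MultiforwardTorchModule` are repeated from the file above so the snippet runs on its own. A forward hook records which entry point ran, to show that hooks fire for both named methods.

```python
import functools
from typing import Callable, List

import torch
import torch.nn
import torch.nn.functional as F


def add_hooks(new_forward_func: Callable) -> Callable:
    # Same mechanism as in the file above: go through nn.Module.__call__.
    @functools.wraps(new_forward_func)
    def wrapper(self, *args, **kwargs):
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper


class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        return actual_forward(*args, **kwargs)


class ToyModule(MultiforwardTorchModule):
    """Hypothetical module: one linear layer with two entry points."""

    def __init__(self, dim: int = 4, num_classes: int = 3):
        super().__init__()
        self.proj = torch.nn.Linear(dim, num_classes)

    @add_hooks
    def forward_train(self, hidden_state: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:
        # Training entry point: returns a scalar loss.
        return F.cross_entropy(self.proj(hidden_state), targets)

    @add_hooks
    def forward_inference(self, hidden_state: torch.Tensor) -> List[int]:
        # Inference entry point: greedy class index per row.
        return self.proj(hidden_state).argmax(dim=-1).tolist()


calls: List[str] = []


def record_hook(module, inputs, output):
    # inputs[0] is the undecorated method, because add_hooks passes it
    # as the first positional argument to __call__.
    calls.append(inputs[0].__name__)


torch.manual_seed(0)
mod = ToyModule()
handle = mod.register_forward_hook(record_hook)
x = torch.randn(2, 4)
loss = mod.forward_train(x, torch.tensor([0, 2]))
preds = mod.forward_inference(x)
loss.backward()  # gradients flow as with an ordinary forward
handle.remove()
print(calls)  # ['forward_train', 'forward_inference']
```

Because `add_hooks` passes the wrapped function and `self` as the first two positional arguments to `__call__`, a hook's `inputs` tuple starts with those two objects rather than with the tensors; hooks written for an ordinary `forward(x)` need to skip them.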
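
The `forward` docstring in the file notes that overriding `forward` lets you run something on every decorated entry point, somewhat like a forward hook. One way to do that is sketched here, assuming PyTorch is installed; `CountingMultiforward`, `entry_counts`, and `Scaler` are hypothetical names, and the two helpers are repeated from the file so the snippet runs standalone. The override tallies each call by the wrapped function's name, then delegates to the base implementation.

```python
import collections
import functools
from typing import Callable

import torch
import torch.nn


def add_hooks(new_forward_func: Callable) -> Callable:
    @functools.wraps(new_forward_func)
    def wrapper(self, *args, **kwargs):
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper


class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        return actual_forward(*args, **kwargs)


class CountingMultiforward(MultiforwardTorchModule):
    """Hypothetical base class: counts how often each entry point runs."""

    def __init__(self):
        super().__init__()
        self.entry_counts = collections.Counter()

    def forward(self, actual_forward: Callable, *args, **kwargs):
        # Shared logic that runs before every @add_hooks method...
        self.entry_counts[actual_forward.__name__] += 1
        # ...then delegate exactly as the base class does.
        return super().forward(actual_forward, *args, **kwargs)


class Scaler(CountingMultiforward):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(2.0))

    @add_hooks
    def forward_double(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale

    @add_hooks
    def forward_negate(self, x: torch.Tensor) -> torch.Tensor:
        return -x * self.scale


m = Scaler()
x = torch.ones(3)
m.forward_double(x)
m.forward_double(x)
m.forward_negate(x)
print(dict(m.entry_counts))  # {'forward_double': 2, 'forward_negate': 1}
```

Putting the shared logic in `forward` keeps it inside `nn.Module.__call__`, so it runs between any registered pre-hooks and forward hooks, just as an ordinary `forward` body would.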