Skip to content

Instantly share code, notes, and snippets.

@malcolmgreaves
Last active February 21, 2023 21:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malcolmgreaves/f8603c7b65be836294de98684d3c2e26 to your computer and use it in GitHub Desktop.
Save malcolmgreaves/f8603c7b65be836294de98684d3c2e26 to your computer and use it in GitHub Desktop.
descent_script: A function that recursively expands a model's children and TorchScripts each one, making it easy to identify which parts of a network are and are not TorchScript-compatible.
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, Sequence, Union
import torch
from core_utils.common import type_name
# Public API of this module: the entry point plus the result types it yields.
__all__: Sequence[str] = ("descent_script", "Ok", "Fail", "TsDescentResult")
@dataclass(frozen=True)
class TsDescentResult(ABC):
    """Common base for the outcome of one TorchScript attempt.

    Each result records where in the module tree the attempt happened:

    * ``name``  -- the name given to the module (``descent_script`` uses the attribute name
      from ``named_children()``, or the model's type name at the root),
    * ``level`` -- recursion depth at which the attempt was made,
    * ``index`` -- position of the module in its parent's ``named_children()`` iteration.

    Subclasses must implement :meth:`inspection`.
    """

    name: str
    level: int
    index: int

    @property
    def identifier(self) -> str:
        """Compact ``name=...,level=...,index=...`` locator for this result."""
        return "name={},level={},index={}".format(self.name, self.level, self.index)

    @abstractmethod
    def inspection(self) -> str:
        """Return a human-readable report describing this result."""
        raise NotImplementedError

    def __str__(self) -> str:
        # Type-name prefix followed by the locator, e.g. ``__main__.Ok(name=...,level=...,index=...)``.
        return "{}({})".format(type_name(type(self)), self.identifier)

    def __repr__(self) -> str:
        # Declared explicitly so @dataclass keeps this instead of generating a field-by-field repr.
        return str(self)
@dataclass(frozen=True)
class Ok(TsDescentResult):
    """Result for a module that ``torch.jit.script`` compiled successfully.

    ``scripted`` holds the object returned by ``torch.jit.script``.
    """

    scripted: torch.jit.ScriptModule

    def inspection(self) -> str:
        # The ``self.level=`` / ``self.index=`` labels mirror self-documenting f-string output,
        # which renders the values with repr(); ``!r`` reproduces that exactly.
        template = "TorchScripted '{}' (self.level={!r}, self.index={!r})"
        return template.format(self.name, self.level, self.index)

    def __repr__(self) -> str:
        # Keep the base-class str() form rather than the dataclass-generated repr.
        return str(self)
@dataclass(frozen=True)
class Fail(TsDescentResult):
    """Result for a module that ``torch.jit.script`` could not compile.

    ``child`` is the module that was passed to ``torch.jit.script``; ``error`` is the
    exception that call raised.
    """

    child: torch.nn.Module
    error: Exception

    def __str__(self) -> str:
        # Use the same ``type_name(type(self))`` prefix as the base class (and thus ``Ok``)
        # instead of a hard-coded "Fail", so both result kinds print consistently and
        # subclasses of Fail report their own type.
        return (
            f"{type_name(type(self))}({self.identifier},error={self.error},"
            f"type(child)={type_name(type(self.child))})"
        )

    def __repr__(self) -> str:
        # Keep the str() form rather than the dataclass-generated repr.
        return str(self)

    def inspection(self) -> str:
        """Describe the failure, including the formatted traceback of ``error``.

        The traceback is rendered from the stored exception directly instead of by
        re-raising it. Re-raising prepends a new frame to ``error.__traceback__`` on every
        call, which mutates the stored exception and makes repeated reports grow. It also
        required a bare ``except``.
        """
        error_str: str = "".join(
            traceback.format_exception(type(self.error), self.error, self.error.__traceback__)
        )
        return (
            f"Cannot TorchScript '{self.name}' ({self.level=}, {self.index=}) "
            f"due to {self.error}\n{error_str}"
        )
def descent_script(
    model: torch.nn.Module,
    name: str = "",
    level: int = 0,
    index: int = 0,
    *,
    descend_failure_only: bool = True,
) -> Iterator[Union[Ok, Fail]]:
    """Recursively descend into the model's children, TorchScripting each one.

    This function is very useful on the first attempt at TorchScripting an existing model.
    It will find all parts of a :param:`model` that are able to be TorchScripted and will identify
    the network parts that require modification.

    A variation of depth-first search: the :param:`model` is TorchScripted (TS) first, and _then_
    its children are recursively descended into. TS results are yielded as they are encountered.
    A result of :class:`Ok` means that TS was successful on the :param:`model`. In contrast, a
    result of :class:`Fail` means that TS raised an exception.

    Each result carries the module's name, the recursion level and its position according to
    the iteration order of `named_children()` on `torch.nn.Module` instances. A child's name is
    its attribute name relative to its parent, not a fully qualified dotted path.

    The base case for the recursion is supplied by the default values for :param:`name`,
    :param:`level`, and :param:`index`, which are the empty string and zeros, respectively.
    When :param:`name` is empty, it becomes the type name of :param:`model`.

    Debugging note: Call the `.inspection()` method on the returned results to easily identify
    what field of the original :param:`model` is not TorchScript-able.

    Note: By default, this only descends into `named_children()` when TorchScripting the current
    module fails. To also descend into the children of modules that TorchScript successfully,
    set :param:`descend_failure_only` to `False`.
    """
    if not name:
        name = type_name(type(model), keep_main=False)

    result: Union[Ok, Fail]
    try:
        scripted = torch.jit.script(model)
    except Exception as error:  # any scripting failure is a reportable result, not a crash
        result = Fail(name, level, index, model, error)
    else:
        result = Ok(name, level, index, scripted)

    # Yield outside the ``except`` clause so the generator is never suspended mid-handler;
    # otherwise anything thrown into it at the yield (e.g. GeneratorExit on close()) would get
    # the scripting error chained onto it as ``__context__``.
    yield result

    # Successful modules are leaves unless the caller asked to descend everywhere.
    if descend_failure_only and isinstance(result, Ok):
        return

    for i, (c_name, child) in enumerate(model.named_children()):
        yield from descent_script(
            child,
            c_name,
            level + 1,
            i,
            descend_failure_only=descend_failure_only,
        )
if __name__ == "__main__":
    from typing import List, Tuple

    class Simple(torch.nn.Module):
        """Stores a weight tensor shared by the demo modules below."""

        def __init__(self, weights: torch.Tensor) -> None:
            super().__init__()
            self.weights = weights

    class TsModule(Simple):
        """TorchScript-compatible: a fixed, fully annotated signature."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.weights * x

    class NoTsModule(Simple):
        """Not TorchScript-compatible: TorchScript rejects the variadic ``*xs`` parameter."""

        def forward(self, *xs) -> List[torch.Tensor]:
            # ``self.weights`` (previously the non-existent ``self.weight``), so the only
            # intended incompatibility is ``*xs``.
            return [self.weights * x for x in xs]

    class Together(torch.nn.Module):
        """Composes one compatible and one incompatible submodule."""

        def __init__(self, weights: torch.Tensor):
            super().__init__()
            self.ok_for_ts = TsModule(weights)
            self.fail_for_ts = NoTsModule(weights)

        def forward(self, xs: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
            return (
                [self.ok_for_ts.forward(x) for x in xs],
                # Unpack so each tensor becomes one of NoTsModule.forward's variadic arguments;
                # passing the list itself would multiply the weights by a Python list.
                self.fail_for_ts.forward(*xs),
            )

    W = torch.randn((5, 5))
    model = Together(W)
    print(f"Attempting TorchScript of {model=}")
    print("-" * 80)
    attempts = list(descent_script(model))
    for a in attempts:
        print(a)
        print("-" * 80)
@malcolmgreaves
Copy link
Author

malcolmgreaves commented Feb 21, 2023

Running the example in this module produces the following output:

Attempting TorchScript of model=Together(
  (ok_for_ts): TsModule()
  (fail_for_ts): NoTsModule()
)
--------------------------------------------------------------------------------
Fail(name=Together,level=0,index=0,error=Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "...", line 127
        def forward(self, *xs) -> List[torch.Tensor]:
                          ~~~ <--- HERE
            return [self.weight * x for x in xs]
,type(child)=__main__.Together)
--------------------------------------------------------------------------------
__main__.Ok(name=ok_for_ts,level=1,index=0)
--------------------------------------------------------------------------------
Fail(name=fail_for_ts,level=1,index=1,error=Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "...", line 127
        def forward(self, *xs) -> List[torch.Tensor]:
                          ~~~ <--- HERE
            return [self.weight * x for x in xs]
,type(child)=__main__.NoTsModule)
--------------------------------------------------------------------------------

@malcolmgreaves
Copy link
Author

Notice that the `ok_for_ts` submodule is successfully TorchScripted, while the `fail_for_ts` submodule is responsible for the `Together` model failing to TorchScript.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment