Last active
February 21, 2023 21:01
-
-
Save malcolmgreaves/f8603c7b65be836294de98684d3c2e26 to your computer and use it in GitHub Desktop.
descent_script: A function that recursively expands a model's children and TorchScripts everything, making it easy to identify which network parts are and are not TorchScript compatible.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import traceback | |
from abc import ABC, abstractmethod | |
from dataclasses import dataclass | |
from typing import Iterator, Sequence, Union | |
import torch | |
from core_utils.common import type_name | |
__all__: Sequence[str] = ( | |
"descent_script", | |
"Ok", | |
"Fail", | |
"TsDescentResult", | |
) | |
@dataclass(frozen=True)
class TsDescentResult(ABC):
    """Abstract record of one TorchScript attempt made by `descent_script`.

    Concrete subclasses add the payload of the attempt and must implement
    :meth:`inspection`. Instances are immutable (frozen dataclass).
    """

    # Name under which the module was visited.
    name: str
    # Recursion depth at which the module was visited (the root call uses 0).
    level: int
    # Position of the module in its parent's `named_children()` iteration.
    index: int

    @property
    def identifier(self) -> str:
        """Compact `name=...,level=...,index=...` label for this result."""
        pieces = (
            f"name={self.name}",
            f"level={self.level}",
            f"index={self.index}",
        )
        return ",".join(pieces)

    @abstractmethod
    def inspection(self) -> str:
        """Return a human-readable description of this result."""
        raise NotImplementedError

    def __str__(self) -> str:
        # The label is whatever `type_name` renders for the concrete class.
        class_label = type_name(type(self))
        return f"{class_label}({self.identifier})"

    def __repr__(self) -> str:
        # Defined explicitly so the dataclass decorator does not generate a
        # field-by-field repr; repr and str are intentionally the same.
        return str(self)
@dataclass(frozen=True)
class Ok(TsDescentResult):
    """Result for a module that `torch.jit.script` compiled without raising."""

    # The value returned by `torch.jit.script` for the visited module.
    scripted: torch.jit.ScriptModule

    def inspection(self) -> str:
        """Describe the successful attempt, including its level and index."""
        summary = f"TorchScripted '{self.name}' ({self.level=}, {self.index=})"
        return summary

    def __repr__(self) -> str:
        # Re-declared here because the dataclass decorator would otherwise
        # generate a field-based repr for this subclass.
        return str(self)
@dataclass(frozen=True)
class Fail(TsDescentResult):
    """Result for a module on which `torch.jit.script` raised an exception."""

    # The module that could not be TorchScripted.
    child: torch.nn.Module
    # The exception raised by the scripting attempt.
    error: Exception

    def __str__(self) -> str:
        return (
            f"Fail({self.identifier},error={self.error},type(child)={type_name(type(self.child))})"
        )

    def __repr__(self) -> str:
        # Re-declared here because the dataclass decorator would otherwise
        # generate a field-based repr for this subclass.
        return str(self)

    def inspection(self) -> str:
        """Describe the failure, followed by the stored exception's traceback.

        The traceback is rendered straight from the stored exception object.
        Re-raising it just to call `traceback.format_exc()` would append a
        new frame to `error.__traceback__` on every call and would need a
        bare `except:` clause.
        """
        error_str: str = "".join(
            traceback.format_exception(
                type(self.error), self.error, self.error.__traceback__
            )
        )
        # Both level and index are labelled, matching `Ok.inspection`.
        return (
            f"Cannot TorchScript '{self.name}' ({self.level=}, {self.index=}) "
            f"due to {self.error}\n{error_str}"
        )
def descent_script(
    model: torch.nn.Module,
    name: str = "",
    level: int = 0,
    index: int = 0,
    *,
    descend_failure_only: bool = True,
) -> Iterator[Union[Ok, Fail]]:
    """Recursively descend into the model's children, TorchScripting all.

    This function is very useful on the first attempt at TorchScripting an existing model.
    It will find all parts of a :param:`model` that are able to be TorchScripted and will identify
    the network parts that require modification.

    A variation of depth-first search, the :param:`model` is TorchScripted (TS) and _then_ its
    children are recursively descended into. TS results are yielded as they are encountered.
    A result of :class:`Ok` means that TS was successful on the :param:`model`. In contrast, a
    result of :class:`Fail` means that TS encountered an exception.

    The name of each child module is preserved in the output, along with the recursion level and
    its order, according to the iteration order of the `named_children()` function on
    `torch.nn.Module` instances.

    The base case for the recursion is supplied by the default values for :param:`name`,
    :param:`level`, and :param:`index`, which is the empty string and zeros, respectively.
    At the base case, :param:`name` will become the type name of :param:`model`, as rendered
    by `type_name(..., keep_main=False)`.

    Debugging note: Call the `.inspection()` method on the returned results to easily identify
                    what field of the original :param:`model` is not TorchScript-able.

    Note: By default, will only descend into the named_children when TorchScript fails.
          To also descend into children on a successful TorchScript call, set the option
          :param:`descend_failure_only` to `False`.
    """
    if not name:
        name = type_name(type(model), keep_main=False)

    # Keep the try body to the single call that is expected to raise. Any
    # exception from scripting is captured as data rather than propagated.
    result: Union[Ok, Fail]
    try:
        scripted = torch.jit.script(model)
    except Exception as error:
        result = Fail(name, level, index, model, error)
    else:
        result = Ok(name, level, index, scripted)

    # The parent's result is always yielded before any of its children's.
    yield result

    # Descend when scripting failed, or unconditionally when the caller opted
    # out of failure-only descent.
    if isinstance(result, Fail) or not descend_failure_only:
        for i, (c_name, child) in enumerate(model.named_children()):
            yield from descent_script(
                child,
                c_name,
                level + 1,
                i,
                descend_failure_only=descend_failure_only,
            )
if __name__ == "__main__":
    # Demo: build a small model with one scriptable and one non-scriptable
    # child, then print the result of every TorchScript attempt.
    from typing import List, Tuple

    class Simple(torch.nn.Module):
        """Base module that just stores a weight tensor."""

        def __init__(self, weights: torch.Tensor) -> None:
            super().__init__()
            self.weights = weights

    class TsModule(Simple):
        """Child with a plain single-tensor forward."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.weights * x

    class NoTsModule(Simple):
        """Child with a variadic forward.

        NOTE(review): presumably the `*xs` signature is what TorchScript
        rejects here -- confirm against the TorchScript language reference.
        """

        def forward(self, *xs) -> List[torch.Tensor]:
            # The attribute set in `Simple.__init__` is `weights`; the
            # previous `self.weight` did not exist on the module.
            return [self.weights * x for x in xs]

    class Together(torch.nn.Module):
        """Parent module that combines both children."""

        def __init__(self, weights: torch.Tensor):
            super().__init__()
            self.ok_for_ts = TsModule(weights)
            self.fail_for_ts = NoTsModule(weights)

        def forward(self, xs: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
            return (
                [self.ok_for_ts.forward(x) for x in xs],
                # Unpack the list: `NoTsModule.forward` takes its tensors as
                # separate positional arguments, not as one list.
                self.fail_for_ts.forward(*xs),
            )

    W = torch.randn((5, 5))
    model = Together(W)
    print(f"Attempting TorchScript of {model=}")
    print("-" * 80)
    attempts = list(descent_script(model))
    for a in attempts:
        print(a)
    print("-" * 80)
Notice that the ok_for_ts
model is successfully TorchScripted while the fail_for_ts
part is responsible for the Together
model not working with TorchScript.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Running the example in this module produces the following output: