Created
July 1, 2023 03:01
-
-
Save malcolmgreaves/6dd6058d599d34220c2d0adf9df9e06c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import Collection, Dict, Iterator, Optional, Set | |
import torch | |
def torchscript(model: torch.nn.Module) -> torch.ScriptModule:
    """Compile `model` with TorchScript's scripting mode.

    The resulting ScriptModule can execute in a Python-free runtime,
    which makes it suitable for production inference.

    NOTE: The model is switched to evaluation mode and compiled under
    `torch.no_grad()` before scripting.

    :param model: module to compile.
    :return: the scripted equivalent of `model`.
    """
    with torch.no_grad():
        return torch.jit.script(model.eval())
def torchtrace(model: torch.nn.Module, fwd_input, **kwargs) -> torch.ScriptModule:
    """Trace one forward-pass execution of `model` into a ScriptModule.

    NOTE: The model is switched to evaluation mode and traced under
    `torch.no_grad()`.

    NOTE: Unlike scripting, tracing records a single concrete execution:
    dynamic control flow is not captured, and implementation details such
    as the compute device may end up hard-coded in the traced module.

    NOTE: All extra keyword arguments are forwarded to `torch.jit.trace`.

    :param model: module to trace.
    :param fwd_input: example input(s) driving the traced forward pass.
    :return: the traced equivalent of `model`.
    """
    with torch.no_grad():
        return torch.jit.trace(model.eval(), fwd_input, **kwargs)
@dataclass(frozen=True)
class Mismatch:
    """One difference found while comparing two models' state dicts."""

    # Human-readable description of the difference.
    message: str
    # Populated only when the comparison itself raised (e.g. dtype/device error).
    err: Optional[Exception] = None


def model_compare(
    m1: torch.nn.Module, m2: torch.nn.Module, ignore: Optional[Collection[str]] = None
) -> Iterator[Mismatch]:
    """Lazily compare the state dicts of two models, yielding one Mismatch per difference.

    Reports parameters/buffers that differ in value or shape, as well as keys
    present in one model but not the other. Keys listed in `ignore` are skipped
    entirely. Yields nothing when the (non-ignored) state dicts are identical.

    :param m1: first model.
    :param m2: second model.
    :param ignore: optional state-dict keys to exclude from the comparison.
    :return: iterator of Mismatch records; empty iteration means equality.
    """
    m1_d: Dict[str, torch.Tensor] = m1.state_dict()
    m2_d: Dict[str, torch.Tensor] = m2.state_dict()

    def comp(k: str, v1: torch.Tensor, v2: torch.Tensor) -> Iterator[Mismatch]:
        # BUG FIX: the previous `(v1 == v2).all()` broadcasts, so tensors with
        # different but broadcast-compatible shapes (e.g. (1,3) vs (3,1)) could
        # wrongly compare as equal. Check shape first and use torch.equal,
        # which requires identical sizes.
        if v1.shape != v2.shape:
            yield Mismatch(f"{k} not equal between models")
            return
        try:
            if not torch.equal(v1, v2):
                yield Mismatch(f"{k} not equal between models")
        except Exception as err:
            # e.g. tensors on different devices or incomparable dtypes.
            yield Mismatch(f"ERROR: {k} not equal between models", err)

    seen: Set[str] = set()
    for k, v1 in m1_d.items():
        if ignore is not None and k in ignore:
            continue
        seen.add(k)
        v2 = m2_d.get(k)
        if v2 is None:
            yield Mismatch(f"First model has {k} ({v1.shape}) but second does not.")
        else:
            yield from comp(k, v1, v2)
    # Second pass: keys that only exist in m2 (everything shared was handled above).
    for k, v2 in m2_d.items():
        if k in seen or (ignore is not None and k in ignore):
            continue
        v1 = m1_d.get(k)
        if v1 is None:
            yield Mismatch(f"Second model has {k} ({v2.shape}) but first does not.")
        else:
            yield from comp(k, v1, v2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment