Skip to content

Instantly share code, notes, and snippets.

@malcolmgreaves
Created July 1, 2023 03:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malcolmgreaves/6dd6058d599d34220c2d0adf9df9e06c to your computer and use it in GitHub Desktop.
Save malcolmgreaves/6dd6058d599d34220c2d0adf9df9e06c to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
from typing import Collection, Dict, Iterator, Optional, Set
import torch
def torchscript(model: torch.nn.Module, **kwargs) -> torch.ScriptModule:
    """Compile ``model`` with TorchScript's scripting mode.

    A scripted model can be serialized and run in a Python-free execution
    environment, which makes it well suited to production inference. Unlike
    tracing (see ``torchtrace``), scripting compiles the module's code, so
    dynamic control flow in ``forward`` is kept.

    NOTE: ``model`` is switched to evaluation mode *in place*
    (``nn.Module.eval()`` mutates and returns the same module), so the
    caller's model is left in eval mode afterwards.

    NOTE: Scripting runs under ``torch.no_grad()``. Grad mode is thread-local
    state that only applies inside that ``with`` block; it is not stored on
    the returned module. Wrap inference calls in ``torch.no_grad()`` /
    ``torch.inference_mode()`` yourself if you need them gradient-free.

    Args:
        model: The module to compile.
        **kwargs: Forwarded unchanged to ``torch.jit.script``; see its
            documentation for the accepted arguments.

    Returns:
        The scripted module produced by ``torch.jit.script``.
    """
    with torch.no_grad():
        model = model.eval()
        scripted_model = torch.jit.script(model, **kwargs)
    return scripted_model
def torchtrace(model: torch.nn.Module, fwd_input, **kwargs) -> torch.ScriptModule:
    """Record one forward pass of ``model`` on ``fwd_input`` as a ScriptModule.

    The model is put into evaluation mode (in place, so the caller's module is
    changed too) and the trace is recorded with autograd disabled. Grad mode
    is thread-local and only active while tracing; it is not stored on the
    returned module.

    NOTE: Tracing only records the operations executed for this particular
    input, so, unlike scripting, no dynamic logic survives in the result.
    Implementation details such as the compute device may also be baked into
    the traced ScriptModule.

    Args:
        model: The module to trace.
        fwd_input: Example input passed to ``model`` during tracing.
        **kwargs: Forwarded unchanged to ``torch.jit.trace``.

    Returns:
        The traced module produced by ``torch.jit.trace``.
    """
    with torch.no_grad():
        # nn.Module.eval() flips `training` off in place and returns the same module.
        eval_model = model.eval()
        return torch.jit.trace(eval_model, fwd_input, **kwargs)
@dataclass(frozen=True)
class Mismatch:
    """A single discrepancy reported by ``model_compare``.

    Instances are immutable (``frozen=True``) and, because dataclass equality
    is enabled by default, also hashable.
    """

    # Human-readable description of what differs.
    message: str
    # Exception raised while comparing the two entries, when the comparison
    # itself failed; ``None`` for ordinary differences.
    err: Optional[Exception] = None
def model_compare(
    m1: torch.nn.Module, m2: torch.nn.Module, ignore: Optional[Collection[str]] = None
) -> Iterator[Mismatch]:
    """Lazily yield every difference between the state dicts of two modules.

    For a key present in both models, a ``Mismatch`` is yielded when the
    entries differ in shape or in any element. If the comparison itself raises
    (for instance, tensors on incompatible devices), the exception is stored
    in ``Mismatch.err`` instead of being propagated. Keys present in only one
    model are reported as missing from the other.

    Mismatches for the keys of ``m1`` come first, in ``m1``'s state-dict
    order, followed by keys that only ``m2`` has.

    NOTE: This is a generator; the state dicts are read when iteration starts,
    not when ``model_compare`` is called.
    NOTE: Element equality uses ``==``, so NaN entries never compare equal.

    Args:
        m1: First model.
        m2: Second model.
        ignore: State-dict keys to skip entirely, in either model.

    Yields:
        One ``Mismatch`` per differing, erroring, or unmatched key.
    """
    ignored: Collection[str] = () if ignore is None else ignore
    m1_d: Dict[str, torch.Tensor] = m1.state_dict()
    m2_d: Dict[str, torch.Tensor] = m2.state_dict()

    def comp(k: str, v1: torch.Tensor, v2: torch.Tensor) -> Optional[Mismatch]:
        """Return a ``Mismatch`` if ``v1`` and ``v2`` differ, else ``None``."""
        try:
            # `==` broadcasts, so e.g. shapes (3,) and (1,) holding the same
            # value would otherwise compare as equal; check shapes first.
            if v1.shape != v2.shape:
                return Mismatch(
                    f"{k} not equal between models (shape {v1.shape} vs {v2.shape})"
                )
            if not (v1 == v2).all():
                return Mismatch(f"{k} not equal between models")
        except Exception as err:
            # Best-effort: report comparison failures rather than abort the scan.
            return Mismatch(f"ERROR: {k} not equal between models", err)
        return None

    def describe(v: object) -> object:
        """Shape of a tensor entry; type name for non-tensor entries (e.g. extra state)."""
        return getattr(v, "shape", type(v).__name__)

    for k, v in m1_d.items():
        if k in ignored:
            continue
        v2 = m2_d.get(k)
        if v2 is None:
            yield Mismatch(f"First model has {k} ({describe(v)}) but second does not.")
            continue
        mismatch = comp(k, v, v2)
        if mismatch is not None:
            yield mismatch

    for k, v in m2_d.items():
        # Every non-ignored key of m1 was already handled in the first pass,
        # so only keys unique to m2 remain.
        if k in ignored or k in m1_d:
            continue
        yield Mismatch(f"Second model has {k} ({describe(v)}) but first does not.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment