Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Compare two `torch.nn.Module` instances, yielding one entry for each mismatched layer.
from dataclasses import dataclass
from typing import Optional, Iterator, Collection, Dict
import torch
@dataclass(frozen=True)
class Mismatch:
    """A single difference found while comparing two models.

    Attributes:
        message: Human-readable description of the difference.
        err: The exception raised while comparing the entry, or ``None``
            when the comparison itself completed without error.
    """

    message: str
    err: Optional[Exception] = None
def model_compare(
    m1: torch.nn.Module, m2: torch.nn.Module, ignore: Optional[Collection[str]] = None
) -> Iterator[Mismatch]:
    """Yield a ``Mismatch`` for every state-dict entry that differs.

    The ``state_dict()`` of both modules is compared key by key. Entries
    are reported in this order: first every key of ``m1`` (either missing
    from ``m2`` or holding a different value), then every key that only
    ``m2`` has.

    Args:
        m1: The first module.
        m2: The second module.
        ignore: State-dict keys to skip entirely. ``None`` (the default)
            skips nothing.

    Yields:
        One ``Mismatch`` per differing or missing entry. If comparing two
        values raises, the exception is captured in ``Mismatch.err``
        instead of propagating.
    """
    # ``ignore`` defaults to None; treat that as "ignore nothing" so the
    # membership tests below do not raise TypeError.
    skipped: Collection[str] = () if ignore is None else ignore

    first: Dict[str, torch.Tensor] = m1.state_dict()
    second: Dict[str, torch.Tensor] = m2.state_dict()

    def compare(key: str, t1: torch.Tensor, t2: torch.Tensor) -> Optional[Mismatch]:
        """Return a ``Mismatch`` if the two values differ, else ``None``."""
        try:
            # ``==`` broadcasts, so tensors of different but compatible
            # shapes (e.g. (3,) vs (1,)) could otherwise compare as equal.
            if t1.shape != t2.shape:
                return Mismatch(
                    f"{key} not equal between models "
                    f"(shape {t1.shape} vs {t2.shape})"
                )
            if not (t1 == t2).all():
                return Mismatch(f"{key} not equal between models")
        except Exception as err:
            # Best-effort: report the failure (device mismatch, non-tensor
            # entry, ...) as a mismatch rather than aborting the whole scan.
            return Mismatch(f"ERROR: {key} not equal between models", err)
        return None

    for key, t1 in first.items():
        if key in skipped:
            continue
        if key not in second:
            yield Mismatch(f"First model has {key} ({t1.shape}) but second does not.")
            continue
        found = compare(key, t1, second[key])
        if found is not None:
            yield found

    # Every non-ignored key shared by both models was handled above, so
    # only keys that exist solely in the second model remain.
    for key, t2 in second.items():
        if key in skipped or key in first:
            continue
        yield Mismatch(f"Second model has {key} ({t2.shape}) but first does not.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment