Skip to content

Instantly share code, notes, and snippets.

@EtaoinWu
Created November 9, 2023 09:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EtaoinWu/556b630b4c9f7095050b3bd015605510 to your computer and use it in GitHub Desktop.
Save EtaoinWu/556b630b4c9f7095050b3bd015605510 to your computer and use it in GitHub Desktop.
Dirty equinox/jaxtyping/beartype hack for type-checking eqx.Module methods
import inspect
from abc import abstractmethod as abstractmethod
import beartype._decor._decornontype
import equinox as eqx
import jax as jax
from beartype import BeartypeConf, beartype as typechecker
from beartype.door import die_if_unbearable, is_bearable
from beartype.typing import * # pyright: ignore[reportWildcardImportFromLibrary]
from equinox import Module as EqxModule
from jaxtyping import (
Array as Array,
Float as Float,
Integer as Integer,
Key as Key,
Scalar as Scalar,
jaxtyped as jaxtyped,
)
# Shorthand jaxtyping annotations for 0-d (shape "") arrays:
#   RNGKey      -- a scalar array annotated with jaxtyping's `Key` dtype
#   FloatScalar -- a scalar array annotated with jaxtyping's `Float` dtype
RNGKey, FloatScalar = Key[Scalar, ""], Float[Scalar, ""]
def bypass_jaxtype[T: Callable](fn: T) -> T:
"""Decorator that bypasses `jaxtyping` type checking.
"""
fn.__bypass_jaxtype__ = True
return fn
def jaxtype_bypassed(fn: Any) -> bool:
    """Report whether ``fn`` carries the ``bypass_jaxtype`` marker.

    Objects without a ``__bypass_jaxtype__`` attribute (including ``None``)
    are reported as not bypassed; otherwise the truthiness of that
    attribute is returned.
    """
    try:
        marker = fn.__bypass_jaxtype__
    except AttributeError:
        return False
    return bool(marker)
def jaxtype_class[T](c: type[T]) -> type[T]:
"""Decorates a class to add @jaxtyped @beartype on each of its members.
Also works for `eqx.Module`s.
"""
if not inspect.isclass(c):
raise TypeError("`jaxtype_class` expected a class.")
# We use .__sizeof__ to store our special attributes,
# just like beartype, because it's a rarely used method.
c_sizeof = c.__sizeof__
if getattr(c_sizeof, '__class_jaxtyped', False):
# No-op if already done.
return c
if getattr(c_sizeof, "__beartyped_cls", None) is not c:
raise ValueError(f"{c} is not already `@beartype`-ed")
for attr_name, attr_value in c.__dict__.items():
if jaxtype_bypassed(attr_value) \
or jaxtype_bypassed(getattr(attr_value, 'method', None)) \
or jaxtype_bypassed(getattr(attr_value, '__wrapped__', None)):
continue
if isinstance(attr_value, eqx._module._wrap_method):
# We have to use a beartype internal here
# for PEP 673 Self annotation.
attr_value.method = beartype._decor._decornontype.beartype_func(
attr_value.method,
conf=BeartypeConf(),
cls_stack=(c,),
)
attr_value.method = jaxtyped(attr_value.method)
setattr(attr_value, "__beartype_wrapper", True)
continue
if not hasattr(attr_value, "__beartype_wrapper"):
continue
# This attribute is beartyped (is thus a function or method)
# and we want it to be @jaxtyped.
new_fn = jaxtyped(attr_value)
setattr(new_fn, "__beartype_wrapper", True)
setattr(c, attr_name, new_fn)
setattr(c_sizeof, "__class_jaxtyped", True)
return c
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment