Created
November 9, 2023 09:29
-
-
Save EtaoinWu/556b630b4c9f7095050b3bd015605510 to your computer and use it in GitHub Desktop.
dirty equinox/jaxtyping/beartype hack for typechecking eqx.Module methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from abc import abstractmethod as abstractmethod | |
import beartype._decor._decornontype | |
import equinox as eqx | |
import jax as jax | |
from beartype import BeartypeConf, beartype as typechecker | |
from beartype.door import die_if_unbearable, is_bearable | |
from beartype.typing import * # pyright: ignore[reportWildcardImportFromLibrary] | |
from equinox import Module as EqxModule | |
from jaxtyping import ( | |
Array as Array, | |
Float as Float, | |
Integer as Integer, | |
Key as Key, | |
Scalar as Scalar, | |
jaxtyped as jaxtyped, | |
) | |
# Shorthand annotation: a jaxtyping `Key` over `Scalar` with an empty shape
# string (presumably a single JAX PRNG key -- confirm against callers).
RNGKey = Key[Scalar, ""]
# Shorthand annotation: a jaxtyping `Float` over `Scalar` with an empty shape
# string (presumably a zero-dimensional floating-point value -- confirm).
FloatScalar = Float[Scalar, ""]
def bypass_jaxtype[T: Callable](fn: T) -> T: | |
"""Decorator that bypasses `jaxtyping` type checking. | |
""" | |
fn.__bypass_jaxtype__ = True | |
return fn | |
def jaxtype_bypassed(fn: Any) -> bool:
    """Tell whether an object has been marked with `@bypass_jaxtype`.

    Args:
        fn: Any object; it does not need to be callable.

    Returns:
        `True` when `fn` has a truthy `__bypass_jaxtype__` attribute,
        `False` when the attribute is missing or falsy.
    """
    marker = getattr(fn, "__bypass_jaxtype__", False)
    return True if marker else False
def jaxtype_class[T](c: type[T]) -> type[T]: | |
"""Decorates a class to add @jaxtyped @beartype on each of its members. | |
Also works for `eqx.Module`s. | |
""" | |
if not inspect.isclass(c): | |
raise TypeError("`jaxtype_class` expected a class.") | |
# We use .__sizeof__ to store our special attributes, | |
# just like beartype, because it's a rarely used method. | |
c_sizeof = c.__sizeof__ | |
if getattr(c_sizeof, '__class_jaxtyped', False): | |
# No-op if already done. | |
return c | |
if getattr(c_sizeof, "__beartyped_cls", None) is not c: | |
raise ValueError(f"{c} is not already `@beartype`-ed") | |
for attr_name, attr_value in c.__dict__.items(): | |
if jaxtype_bypassed(attr_value) \ | |
or jaxtype_bypassed(getattr(attr_value, 'method', None)) \ | |
or jaxtype_bypassed(getattr(attr_value, '__wrapped__', None)): | |
continue | |
if isinstance(attr_value, eqx._module._wrap_method): | |
# We have to use a beartype internal here | |
# for PEP 673 Self annotation. | |
attr_value.method = beartype._decor._decornontype.beartype_func( | |
attr_value.method, | |
conf=BeartypeConf(), | |
cls_stack=(c,), | |
) | |
attr_value.method = jaxtyped(attr_value.method) | |
setattr(attr_value, "__beartype_wrapper", True) | |
continue | |
if not hasattr(attr_value, "__beartype_wrapper"): | |
continue | |
# This attribute is beartyped (is thus a function or method) | |
# and we want it to be @jaxtyped. | |
new_fn = jaxtyped(attr_value) | |
setattr(new_fn, "__beartype_wrapper", True) | |
setattr(c, attr_name, new_fn) | |
setattr(c_sizeof, "__class_jaxtyped", True) | |
return c |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment