Skip to content

Instantly share code, notes, and snippets.

@PhilipVinc
Created November 3, 2022 18:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PhilipVinc/728a4a35f5d98713c384ab6f91dd6d3c to your computer and use it in GitHub Desktop.
Save PhilipVinc/728a4a35f5d98713c384ab6f91dd6d3c to your computer and use it in GitHub Desktop.
bug in jax
# pip install jax jaxlib netket
"""Module for the common control flow utilities."""
import os
from functools import partial
from typing import Callable, Optional, Sequence, Set
from jax import core
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
from jax.tree_util import tree_map, tree_unflatten, tree_structure
map, unsafe_map = safe_map, map
allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
def _abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
def _typecheck_param(prim, param, name, msg_required, pred):
if not pred:
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
f'{msg_required} required:')
param_str = str(param)
sep = os.linesep if os.linesep in param_str else ' '
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)
@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()
@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
                         primitive_name: Optional[str] = None):
    """Like ``_initial_style_open_jaxpr`` but returns a ``core.ClosedJaxpr``.

    The open jaxpr's constvars are converted with
    ``pe.convert_constvars_jaxpr`` and the result is closed over no
    constants. The constants themselves are returned separately so the
    caller can supply them as arguments.

    Returns:
      ``(closed_jaxpr, consts, out_tree)``.
    """
    open_jaxpr, consts, out_tree = _initial_style_open_jaxpr(
        fun, in_tree, in_avals, primitive_name)
    converted = pe.convert_constvars_jaxpr(open_jaxpr)
    return core.ClosedJaxpr(converted, ()), consts, out_tree
@cache()
def _initial_style_jaxprs_with_common_consts(
        funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
    """Stage each of ``funs`` into a closed jaxpr, all sharing one const signature.

    When staging the branches of a conditional into jaxprs, constants are
    extracted from each branch and converted to jaxpr arguments. To use the
    staged jaxprs as the branches of a conditional *primitive*, their input
    signatures must match. This function "joins" the staged jaxprs: each one
    is rebuilt to accept *all* branches' constants but only uses those it
    needs; the other branches' slots are bound to fresh, unused variables.

    NOTE(review): memoized with ``cache()`` keyed on ``funs``; the cache
    presumably keeps those callables (and what they close over) referenced —
    confirm against ``jax._src.util.cache``.

    Returns:
      ``(closed_jaxprs, consts, out_trees)``: one ``core.ClosedJaxpr`` per
      entry of ``funs``, the concatenation of every branch's constants in
      ``funs`` order, and the per-branch output trees.
    """
    staged = [_initial_style_open_jaxpr(f, in_tree, in_avals, primitive_name)
              for f in funs]
    open_jaxprs, consts_per_fun, out_trees = unzip3(staged)

    # Fresh variables standing in for each branch's constants inside the
    # *other* branches. Created branch by branch, const by const.
    fresh_var = core.gensym(open_jaxprs, suffix='_')
    avals_per_fun = [[_abstractify(c) for c in fun_consts]
                     for fun_consts in consts_per_fun]
    placeholders = [[fresh_var(aval) for aval in avals]
                    for avals in avals_per_fun]

    def with_padded_constvars(idx, jaxpr):
        # Own constvars sit between the placeholders of earlier and later
        # branches, so every padded jaxpr has the same constvar layout.
        before = [v for group in placeholders[:idx] for v in group]
        after = [v for group in placeholders[idx + 1:] for v in group]
        return jaxpr.replace(constvars=[*before, *jaxpr.constvars, *after])

    padded = [with_padded_constvars(idx, jaxpr)
              for idx, jaxpr in enumerate(open_jaxprs)]
    closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                     for jaxpr in padded]
    all_consts = [c for fun_consts in consts_per_fun for c in fun_consts]
    return closed_jaxprs, all_consts, out_trees
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
Corresponding `tree` and `avals` must match in the sense that the number of
leaves in `tree` must be equal to the length of `avals`. `what` will be
prepended to details of the mismatch in TypeError.
"""
if tree1 != tree2:
raise TypeError(
f"{what} must have same type structure, got {tree1} and {tree2}.")
if not all(map(core.typematch, avals1, avals2)):
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2))
raise TypeError(f"{what} must have identical types, got\n{diff}.")
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
if has_aux:
actual_tree_children = actual_tree.children()
if len(actual_tree_children) == 2:
# select first child as result tree
actual_tree = tree_structure(actual_tree_children[0])
else:
raise ValueError(
f"{func_name}() produced a pytree with structure "
f"{actual_tree}, but a pytree tuple with auxiliary "
f"output was expected because has_aux was set to True.")
if actual_tree != expected_tree:
raise TypeError(
f"{func_name}() output pytree structure must match {expected_name}, "
f"got {actual_tree} and {expected_tree}.")
def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
def _show_diff(array1, array2):
if core.typematch(array1, array2):
return f"{array1}"
return f"DIFFERENT {array1} vs. {array2}"
def _avals_short(avals):
to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
return ' '.join(map(to_str, avals))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment