@phinate
Created July 1, 2022 14:39
convert a pyhf spec in-place to contain jax arrays
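from __future__ import annotations

from typing import Any

from jax import numpy as jnp


def convert_jax(spec: dict[str, Any]) -> None:
    """Recursively replace the numeric lists in a pyhf spec with JAX arrays, in place."""
    for key, value in spec.items():
        if isinstance(value, dict):
            convert_jax(value)
        elif isinstance(value, list):
            try:
                # Numeric lists (e.g. sample data) become JAX arrays.
                spec[key] = jnp.asarray(value)
            except TypeError:
                # Lists that cannot be converted are expected to hold nested
                # dicts (e.g. channels, samples, modifiers); recurse into each.
                assert isinstance(value[0], dict)
                for item in value:
                    convert_jax(item)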
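
As a usage illustration, here is a minimal, self-contained sketch that applies the same idea to a small hand-written spec. The `toy_spec` dictionary is hypothetical and only loosely shaped like a single-channel pyhf workspace; it is not taken from the gist. The function is restated as a variant that checks for lists of dicts explicitly, so the sketch does not depend on which exception `jnp.asarray` raises for non-numeric input. That is an assumption of this example, not the gist's own approach.

```python
from __future__ import annotations

from typing import Any

from jax import numpy as jnp


def convert_jax(spec: dict[str, Any]) -> None:
    """Variant of the gist's function: checks for dicts instead of catching TypeError."""
    for key, value in spec.items():
        if isinstance(value, dict):
            convert_jax(value)
        elif isinstance(value, list):
            if value and all(isinstance(item, dict) for item in value):
                # A list of nested dicts (channels, samples, ...): recurse.
                for item in value:
                    convert_jax(item)
            else:
                # Anything else is assumed to be numeric data.
                spec[key] = jnp.asarray(value)


# Hypothetical toy spec, loosely modelled on a single-channel pyhf workspace.
toy_spec: dict[str, Any] = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": [5.0, 10.0],
                    "modifiers": [
                        {"name": "mu", "type": "normfactor", "data": None}
                    ],
                },
                {
                    "name": "background",
                    "data": [50.0, 60.0],
                    "modifiers": [
                        {
                            "name": "uncorr_bkguncrt",
                            "type": "shapesys",
                            "data": [5.0, 12.0],
                        }
                    ],
                },
            ],
        }
    ],
    "observations": [{"name": "singlechannel", "data": [55.0, 62.0]}],
    "version": "1.0.0",
}

convert_jax(toy_spec)  # mutates toy_spec; returns None

signal = toy_spec["channels"][0]["samples"][0]
print(type(signal["data"]))  # a JAX array type
print(signal["name"])        # strings are left untouched
```

Because the conversion happens in place, make a `copy.deepcopy` of the spec first if the original plain-list version is still needed.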