Skip to content

Instantly share code, notes, and snippets.

@andyfaff
Last active April 17, 2024 06:28
Show Gist options
  • Save andyfaff/5880370330da655291271d3d28cf2f32 to your computer and use it in GitHub Desktop.
Patch enabling basic JAX usage with refnx's Objective, ReflectModel and Structure classes
diff --git a/refnx/analysis/objective.py b/refnx/analysis/objective.py
index 9ff56057..c77139ad 100644
--- a/refnx/analysis/objective.py
+++ b/refnx/analysis/objective.py
@@ -674,8 +674,8 @@ class Objective(BaseObjective):
logl += (y - model) ** 2 / var_y
# nans play havoc
- if np.isnan(logl).any():
- raise RuntimeError("Objective.logl encountered a NaN.")
+ # if np.isnan(logl).any():
+ # raise RuntimeError("Objective.logl encountered a NaN.")
# add on extra 'potential' terms from the model.
extra_potential = self.model.logp()
diff --git a/refnx/reflect/reflect_model.py b/refnx/reflect/reflect_model.py
index 5cebba6c..34b06949 100644
--- a/refnx/reflect/reflect_model.py
+++ b/refnx/reflect/reflect_model.py
@@ -514,9 +514,10 @@ class ReflectModel:
# fallback to what this object was constructed with
x_err = float(self.dq)
+ slabs = self.structure.slabs()[..., :4]
return reflectivity(
x,
- self.structure.slabs()[..., :4],
+ slabs,
scale=self.scale.value,
bkg=self.bkg.value,
dq=x_err,
diff --git a/refnx/reflect/structure.py b/refnx/reflect/structure.py
index d3ea2edb..65f78128 100644
--- a/refnx/reflect/structure.py
+++ b/refnx/reflect/structure.py
@@ -320,12 +320,12 @@ class Structure(UserList):
# if all the interfaces are Gaussian, then simply concatenate
# the default slabs property of each component.
sl = [c.slabs(structure=self) for c in self.components]
-
+ import jax.numpy as jnp
try:
- slabs = np.concatenate(sl)
+ slabs = jnp.concatenate(sl)
except ValueError:
# some of slabs may be None. np can't concatenate arr and None
- slabs = np.concatenate([s for s in sl if s is not None])
+ slabs = jnp.concatenate([s for s in sl if s is not None])
else:
# there is a non-default interfacial roughness, create a microslab
# representation
@@ -912,9 +912,12 @@ class SLD(Scatterer):
return f"SLD([{self.real!r}, {self.imag!r}], name={self.name!r})"
def __complex__(self):
- sldc = complex(self.real.value, self.imag.value)
+ sldc = self.real.value + self.imag.value * 1j
return sldc
+ def complex(self):
+ return self.real.value + self.imag.value * 1j
+
@property
def parameters(self):
"""
@@ -1289,22 +1292,22 @@ class Slab(Component):
Slab representation of this component. See :class:`Component.slabs`
"""
# speculative shortcut to prevent a number of attribute retrievals
+ import jax.numpy as jnp
if self.sld.dispersive:
sldc = self.sld.complex(getattr(structure, "wavelength", None))
else:
- sldc = complex(self.sld)
+ sldc = self.sld.complex()
- return np.array(
+ return jnp.array(
[
[
- self.thick.value,
+ self.thick._value,
sldc.real,
sldc.imag,
- self.rough.value,
- self.vfsolv.value,
+ self.rough._value,
+ self.vfsolv._value,
]
- ],
- dtype=float,
+ ]
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment