Skip to content

Instantly share code, notes, and snippets.

@jkbjh
Last active March 10, 2021 08:57
Show Gist options
  • Save jkbjh/fd807bd4eb4afab28727efbb82c1f0d6 to your computer and use it in GitHub Desktop.
Save jkbjh/fd807bd4eb4afab28727efbb82c1f0d6 to your computer and use it in GitHub Desktop.
# a selection of small numpy helper functions
import joblib
import numpy as np
def apply_along_axes(func, data, axes, n_jobs=None, **parallel_kwargs):
    """
    apply ``func`` to every sub-array obtained by fixing the indices of ``axes``,
    executing the calls through ``joblib.Parallel``.

    For every index tuple ``idx`` over the sizes of ``axes`` (in the order given),
    ``func`` receives ``data`` with axis ``axes[k]`` fixed to ``idx[k]`` and all
    other axes kept whole. The result has shape
    ``tuple(data.shape[a] for a in axes)`` followed by the (common) shape of the
    values returned by ``func``.

    Only worth it when each call is sufficiently costly; reconstructing the
    result may require a large allocation.

    Parameters
    ----------
    func : callable
        called once per selection.
    data : array_like
        input data (converted with ``np.asanyarray``).
    axes : int or sequence of int
        axes to iterate over; negative values count from the end.
    n_jobs : int or None
        forwarded to ``joblib.Parallel``. ``None`` (the previous behaviour) leaves
        the choice to joblib, which runs sequentially unless an enclosing joblib
        parallel configuration says otherwise.
    **parallel_kwargs
        additional keyword arguments for ``joblib.Parallel``.

    Returns
    -------
    numpy.ndarray
        dtype is the common type numpy infers from all returned values.

    Raises
    ------
    ValueError
        if an axis is out of range or repeated, or if ``func`` returns values of
        differing shapes.
    """
    data = np.asanyarray(data)
    ndim = data.ndim
    if np.ndim(axes) == 0:
        axes = (axes,)

    norm_axes = []
    for ax in axes:
        ax = int(ax)
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for array of dimension {ndim}")
        norm_axes.append(ax % ndim)
    if len(set(norm_axes)) != len(norm_axes):
        raise ValueError(f"repeated axis in axes: {tuple(axes)}")
    sizes = tuple(data.shape[ax] for ax in norm_axes)

    def _selection(idx):
        # full indexing key: integers on the iterated axes, whole slices elsewhere,
        # so idx[k] really addresses data axis norm_axes[k]
        key = [slice(None)] * ndim
        for ax, i in zip(norm_axes, idx):
            key[ax] = i
        return tuple(key)

    # np.ndindex yields indices in C order and joblib.Parallel preserves input
    # order, so the flat result list lines up with a plain reshape below.
    results = joblib.Parallel(n_jobs=n_jobs, **parallel_kwargs)(
        joblib.delayed(func)(data[_selection(idx)]) for idx in np.ndindex(*sizes)
    )

    elem_shape = np.shape(results[0]) if results else ()
    for res in results:
        if np.shape(res) != elem_shape:
            raise ValueError(
                f"func returned values of differing shapes: {np.shape(res)} vs {elem_shape}"
            )
    # np.asarray promotes to a common dtype instead of casting to the first result's.
    return np.asarray(results).reshape(sizes + tuple(elem_shape))
def unravel(data):
    """
    merge an array whose cells are arrays of one common shape into a single larger array.

    The result has shape ``data.shape + cell.shape`` and holds ``data[idx]`` at
    ``result[idx]``. If ``data`` is empty, or its first cell has no array shape
    (e.g. a plain number, a list, or a 0-d array), ``data`` is returned unchanged.
    Might need to be applied repeatedly for nested data.
    Might be slow: every cell is copied in a Python-level loop.

    Parameters
    ----------
    data : numpy.ndarray
        typically an object array whose cells are arrays.

    Returns
    -------
    numpy.ndarray
        float64 array, or complex128 when any cell is complex (so imaginary parts
        are not discarded); ``data`` itself when there is nothing to merge.

    Raises
    ------
    ValueError
        if a cell's shape differs from the shape of the first cell.
    """
    if data.size == 0:
        # no cell to inspect, hence nothing to merge
        return data

    positions = list(np.ndindex(*data.shape))
    first = data[positions[0]]
    if not (hasattr(first, "shape") and len(first.shape) > 0):
        return data
    cell_shape = tuple(first.shape)

    cells = [data[pos] for pos in positions]
    for pos, cell in zip(positions, cells):
        # reject mismatches explicitly; assignment below would otherwise broadcast silently
        if np.shape(cell) != cell_shape:
            raise ValueError(
                f"cell at index {pos} has shape {np.shape(cell)}, expected {cell_shape}"
            )

    out_dtype = np.complex128 if any(np.iscomplexobj(c) for c in cells) else np.float64
    out = np.empty(data.shape + cell_shape, dtype=out_dtype)
    for pos, cell in zip(positions, cells):
        out[pos] = cell
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment