Skip to content

Instantly share code, notes, and snippets.

@fzimmermann89
Created February 11, 2024 23:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fzimmermann89/df2672e988f0512dca61b2178fc928c8 to your computer and use it in GitHub Desktop.
reshape
import dataclasses
from typing import Type, Callable, Any
def _rapply(obj:Any, functions_per_types: dict[Type | list[Type], Callable]):
""" Apply callables to all fields of a dataclass and keys of a dictionary, recursively.
Which callable is used is determined by the type. By default, the function will recurse into dataclasses.fields and dictionaries, respectively, returning new dataclass instances or dictionaries, respectively. This can be disabled by, for example, adding {dict:lambda x:x} to functions_per_types, as these take precedence.
Example:
>>> _rapply({'a':1, 'string':'not changed'}, {int:lambda x:x+1})
{'a': 2, 'string': 'not changed'}
Here, the function +1 is applied to all ints in the dict,
everything else remains unchanged.
Parameters
----------
obj
Target object.
functions_per_types
Dictionary mapping types to the functions to apply.
"""
for types, function in functions_per_types.items():
# User supplied rules
if isinstance(obj, types):
return function(obj)
if dataclasses.is_dataclass(obj):
# Recurse into dataclasses
new = {field.name: _rapply(getattr(obj, field.name), functions_per_types) for field in dataclasses.fields(obj)}
return dataclasses.replace(obj, **new)
if isinstance(obj, dict):
# Recurse into dictionaries
return {key: _rapply(value, functions_per_types) for key, value in obj.items()}
else:
# No rule: Do nothing
return obj
def _reshape_kdata(kdata: KData, pattern: str, **axes):
    """Return a new KData with one einops ``repeat`` pattern applied to its tensors.

    The same ``pattern`` is applied to ``kdata.data``, to the stacked trajectory
    and to every ``torch.Tensor`` found (recursively) in ``kdata.header.acq_info``.
    Tensors with one or two fewer dimensions than ``kdata.data`` are padded with
    singleton axes so that the pattern lines up, and the padding is removed
    afterwards.

    Parameters
    ----------
    kdata
        data to reshape. No attribute of it is reassigned here: the data is
        cloned and the header deep-copied before use.
        NOTE(review): the new trajectory tensors may share storage with the
        input trajectory, depending on whether ``traj.as_tensor`` copies — confirm.
    pattern
        einops repeat pattern. It must match the kdata.data layout, i.e. include
        the coil dim, k2, k1 and k0 as the last four axes.
        Example: "... other coil k2 k1 k0 -> ... (other 2) coil k2 k1 k0"
        would stack a copy in the last of the 'other' dimensions.
    **axes
        any additional axis-length specifications used by the pattern.
        See https://einops.rocks/api/repeat/ for more information.

    Returns
    -------
        A new KData constructed from the reshaped data, header and trajectory.
    """

    def repeat_like_data(tensor):
        """Apply ``pattern`` to ``tensor``, padding it to the rank of kdata.data first.

        A tensor with two fewer dims than kdata.data gets singleton axes at -1
        (k0) and -4 (coil). A tensor with one fewer dim gets a singleton axis at
        -4 (coil). These axes are squeezed out again after the repeat. Tensors
        of any other rank are passed to einops unchanged.
        """
        n_missing = kdata.data.ndim - tensor.ndim
        if n_missing == 2:
            padded = tensor.unsqueeze(-1).unsqueeze(-4)
        elif n_missing == 1:
            padded = tensor.unsqueeze(-4)
        else:
            padded = tensor
        repeated = einops.repeat(padded, pattern, **axes)
        if n_missing == 2:
            return repeated.squeeze(-4).squeeze(-1)
        if n_missing == 1:
            return repeated.squeeze(-4)
        return repeated

    new_data = repeat_like_data(kdata.data.clone())
    # The trajectory components are stacked along dim -4 (where the coil axis
    # sits in the data layout), repeated together, and then split apart again.
    stacked_traj = repeat_like_data(kdata.traj.as_tensor(-4))
    new_traj = KTrajectory(*stacked_traj.unbind(-4))
    new_header = deepcopy(kdata.header)
    new_header.acq_info = _rapply(new_header.acq_info, {torch.Tensor: repeat_like_data})
    return KData(data=new_data, header=new_header, traj=new_traj)
def split_dyn(kdata, ndyn):
    """Split the k1 axis of ``kdata`` into ``ndyn`` dynamics.

    The k1 axis is treated as ``ndyn`` consecutive blocks of equal length: in
    the einops composition ``(ndyn k1)`` the ``ndyn`` axis varies slowest. Each
    block becomes one entry of a new axis inserted directly before the coil
    axis. Data, trajectory and acquisition-info tensors are all split the same
    way. einops raises an error if the k1 length is not divisible by ``ndyn``.

    Parameters
    ----------
    kdata
        data to split.
    ndyn
        number of dynamics to split k1 into.

    Returns
    -------
        A new KData with the additional dynamics axis.
    """
    return _reshape_kdata(kdata, "... c k2 (ndyn k1) k0 -> ... ndyn c k2 k1 k0", ndyn=ndyn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment