Created
February 11, 2024 23:36
-
-
Save fzimmermann89/df2672e988f0512dca61b2178fc928c8 to your computer and use it in GitHub Desktop.
reshape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
from typing import Type, Callable, Any | |
def _rapply(obj:Any, functions_per_types: dict[Type | list[Type], Callable]): | |
""" Apply callables to all fields of a dataclass and keys of a dictionary, recursively. | |
Which callable is used is determined by the type. By default, the function will recurse into dataclasses.fields and dictionaries, respectively, returning new dataclass instances or dictionaries, respectively. This can be disabled by, for example, adding {dict:lambda x:x} to functions_per_types, as these take precedence. | |
Example: | |
>>> _rapply({'a':1, 'string':'not changed'}, {int:lambda x:x+1}) | |
{'a': 2, 'string': 'not changed'} | |
Here, the function +1 is applied to all ints in the dict, | |
everything else remains unchanged. | |
Parameters | |
---------- | |
obj | |
Target object. | |
functions_per_types | |
Dictionary mapping types to the functions to apply. | |
""" | |
for types, function in functions_per_types.items(): | |
# User supplied rules | |
if isinstance(obj, types): | |
return function(obj) | |
if dataclasses.is_dataclass(obj): | |
# Recurse into dataclasses | |
new = {field.name: _rapply(getattr(obj, field.name), functions_per_types) for field in dataclasses.fields(obj)} | |
return dataclasses.replace(obj, **new) | |
if isinstance(obj, dict): | |
# Recurse into dictionaries | |
return {key: _rapply(value, functions_per_types) for key, value in obj.items()} | |
else: | |
# No rule: Do nothing | |
return obj | |
def _reshape_kdata(kdata: KData, pattern: str, **axes):
    """Apply one einops repeat pattern to the data, trajectory and acq_info of a KData.

    Parameters
    ----------
    kdata
        data to reshape
    pattern
        einops repeat pattern. Must match the kdata.data shape
        (i.e. include coil dim, k2, k1, k0)
        Example: "... other coil k2 k1 k0 -> ... (other 2) coil k2 k1 k0"
        would stack a copy in the last of the 'other' dimensions
    **axes
        any additional specifications for dimensions
        See https://einops.rocks/api/repeat/ for more information

    Returns
    -------
        A new KData built from the reshaped data, the reshaped trajectory and
        a deep copy of the header whose acq_info tensors were reshaped.
    """
    full_ndim = kdata.data.ndim

    def rearrange(x):
        """Apply the pattern to one tensor, padding dims it lacks compared to kdata.data."""
        n_missing = full_ndim - x.ndim
        if n_missing == 2:
            # missing coil and k0 dim: add both as singleton dims for the
            # pattern and remove them again afterwards
            padded = x.unsqueeze(-1).unsqueeze(-4)
            return einops.repeat(padded, pattern, **axes).squeeze(-4).squeeze(-1)
        if n_missing == 1:
            # missing coil dim
            return einops.repeat(x.unsqueeze(-4), pattern, **axes).squeeze(-4)
        # anything else is passed to the pattern as it is
        return einops.repeat(x, pattern, **axes)

    reshaped_data = rearrange(kdata.data.clone())

    # The trajectory is stacked along dim -4, run through the pattern as a
    # single tensor and split along dim -4 again.
    # NOTE(review): presumably the stacked axis stands in for the coil dim of
    # the pattern - confirm against KTrajectory.as_tensor.
    stacked_traj = kdata.traj.as_tensor(-4)
    traj_components = rearrange(stacked_traj).unbind(-4)
    reshaped_traj = KTrajectory(*traj_components)

    # Work on a copy of the header; every tensor found in acq_info is reshaped.
    new_header = deepcopy(kdata.header)
    new_header.acq_info = _rapply(new_header.acq_info, {torch.Tensor: rearrange})

    return KData(data=reshaped_data, header=new_header, traj=reshaped_traj)
def split_dyn(kdata, ndyn):
    """Split the k1 axis of `kdata` into `ndyn` parts placed on a new axis.

    The k1 axis is interpreted as ``(ndyn k1)`` and the resulting ``ndyn``
    axis is inserted directly in front of the coil axis.

    Parameters
    ----------
    kdata
        data to split, forwarded to `_reshape_kdata`
    ndyn
        length of the new ``ndyn`` axis

    Returns
    -------
        The result of `_reshape_kdata` for the split pattern.
    """
    return _reshape_kdata(
        kdata,
        "... c k2 (ndyn k1) k0 -> ... ndyn c k2 k1 k0",
        ndyn=ndyn,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment