Skip to content

Instantly share code, notes, and snippets.

View fzimmermann89's full-sized avatar

Felix F Zimmermann fzimmermann89

View GitHub Profile
@fzimmermann89
fzimmermann89 / grid_sample_adj.py
Last active April 2, 2024 11:49
Adjoint linear operator for torch.nn.functional.grid_sample
class AdjointGridSample(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
y: torch.Tensor,
grid: torch.Tensor,
xshape: Sequence[int],
interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'] = 'bilinear',
padding_mode: Literal['zeros', 'border', 'reflection'] = 'zeros',
align_corners: bool = True,
@fzimmermann89
fzimmermann89 / projection.py
Last active February 26, 2024 19:37
Super-resolution matrix
import torch
from scipy.spatial.transform import Rotation
from typing import Callable, Literal
import einops
import itertools
from torch import Tensor
class MatrixMultiplication(torch.autograd.Function):
"""Helper to do a matrix multiplication if we know the adjoint of the matrix"""
import dataclasses
from typing import Type, Callable, Any
def _rapply(obj:Any, functions_per_types: dict[Type | list[Type], Callable]):
""" Apply callables to all fields of a dataclass and keys of a dictionary, recursively.
Which callable is used is determined by the type. By default, the function will recurse into dataclasses.fields and dictionaries, respectively, returning new dataclass instances or dictionaries, respectively. This can be disabled by, for example, adding {dict:lambda x:x} to functions_per_types, as these take precedence.
Example:
>>> _rapply({'a':1, 'string':'not changed'}, {int:lambda x:x+1})
@fzimmermann89
fzimmermann89 / slice_profile.py
Created February 6, 2024 16:19
Pulseq slice profile
import math
from pathlib import Path
from datetime import datetime
import pypulseq as pp
### SETTINGS ###
timestamp = datetime.now().strftime("%y%m%d-%H%M%S")
filename = Path(f"profile_{timestamp}.seq")
@fzimmermann89
fzimmermann89 / inati.py
Created January 9, 2024 19:30
inati_iterative
import warnings
import torch
import numpy as np
def _filter_separable(x, kernels, axis):
"""Apply the separable filter kernels to the tensor x along the axes axis.
Does zero-padding to keep the output the same size as the input.
Parameters
@fzimmermann89
fzimmermann89 / sliding_window.py
Created January 8, 2024 09:45
PyTorch sliding window
import warnings
from typing import Sequence
import torch
import numpy as np
# fzimmermann89, felix.zimmermann@ptb.de, 2024
def sliding_window(x:torch.Tensor, window_shape:int|Sequence[int], axis:None|int|Sequence[int]|=None):
"""Sliding window into the tensor x.
Returns a view into the tensor x that represents a sliding window.
@fzimmermann89
fzimmermann89 / mri_demo.ipynb
Last active November 14, 2022 09:36
MRI Demo for M4AIM Presentation @ PTB Berlin. (c) Felix Zimmermann, MIT License
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@fzimmermann89
fzimmermann89 / lut.py
Last active September 7, 2021 09:47
lut
import torch
from torch import nn
from typing import Tuple, Callable
class LUT(nn.Module):
def __init__(self, f: Callable, dx: float, xrange: Tuple[float, float], mode: str = "linear"):
"""
LUT of values of a function
f: function to use, does not need to be differentiable
@fzimmermann89
fzimmermann89 / utilities_lv65.py
Last active July 28, 2021 16:54
utilities_lv65
import numpy as np
### Requested pixel2q and q2pixel ####
def pixel2q(pixel, E_ev, detz_m, pixelsize_m=75e-6):
"""
returns q in reciprocal nm
E_ev: photon Energy in eV
detz_m: detector distance in m
pixelsize_m: detector pixelsize in m
@fzimmermann89
fzimmermann89 / simplecorrelator.ipynb
Last active July 28, 2021 14:06
simplecorrelator
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.