Last active
February 26, 2024 19:37
-
-
Save fzimmermann89/22907277168feb0d98015b0fc34cf3c2 to your computer and use it in GitHub Desktop.
superres matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import math
from typing import Callable, Literal

import einops
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from torch import Tensor
class MatrixMultiplication(torch.autograd.Function):
    """Matrix-vector product whose backward pass uses a caller-supplied adjoint.

    ``forward`` returns ``matrix @ x`` and stores ``matrixH``; ``backward``
    returns ``matrixH @ grad_output`` as the gradient with respect to ``x``
    and ``None`` for both matrix arguments.
    """

    @staticmethod
    def forward(ctx, x, matrix, matrixH):
        # Only the adjoint is required to propagate gradients back to x.
        ctx.save_for_backward(matrixH)
        product = torch.matmul(matrix, x)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        adjoint = ctx.saved_tensors[0]
        grad_x = torch.matmul(adjoint, grad_output)
        # One entry per forward input: x, matrix, matrixH.
        return grad_x, None, None
class SliceGaussian: | |
"""Gaussian Slice Profile""" | |
def __init__(self, fwhm: float | Tensor): | |
self.fwhm = torch.as_tensor(fwhm) | |
def __call__(self, x): | |
return torch.exp(-(x**2) / (0.36 * self.fwhm**2)) | |
class SliceSmoothedRect: | |
"""Rectangular Slice Profile with smooth flanks | |
The smaller n, the smoother it is. For n<1, the FWHM might be wrong | |
""" | |
def __init__(self, fwhm: float | Tensor, n: float | Tensor): | |
self.n = n | |
self.fwhm = fwhm | |
def __call__(self, x): | |
y = x * 2 / self.fwhm | |
return torch.erf(self.n * (1 - y)) + torch.erf(self.n * (1 + y)) | |
class SliceInterpolate:
    """Slice profile based on linear interpolation of a measured profile.

    Parameters
    ----------
    xs
        Positions at which the profile was measured. ``numpy.interp`` requires
        them to be increasing.
    weights
        Profile intensity at each position in ``xs``.

    Calling the instance with distances ``x`` from the slice centre returns the
    interpolated intensity; positions outside the range of ``xs`` evaluate to 0.
    The interpolation runs in NumPy on the CPU and is not differentiable.
    """

    def __init__(self, xs: Tensor, weights: Tensor):
        self._xs = xs
        self._weights = weights

    def __call__(self, x):
        x = torch.as_tensor(x)
        # Explicit detach/cpu so tensors that require grad or live on another
        # device can still be handed to NumPy.
        interpolated = np.interp(
            x.detach().cpu().numpy(),
            torch.as_tensor(self._xs).detach().cpu().numpy(),
            torch.as_tensor(self._weights).detach().cpu().numpy(),
            left=0,
            right=0,
        )
        # Keep the caller's floating dtype; integer inputs get the default float dtype.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        return torch.as_tensor(interpolated, dtype=dtype, device=x.device)
class SliceProjection(torch.nn.Module):
    """Projection of a 3D volume onto one or more 2D slices via a sparse matrix.

    The projection matrices of all requested slices are built once in
    ``__init__`` and stacked row-wise; ``forward`` and ``adjoint`` are then
    sparse matrix-vector products.
    """

    def __init__(
        self,
        input_shape: tuple[int, int, int],
        slice_rotation: None | Rotation | tuple[Rotation, ...] | Tensor,
        slice_shift: float | Tensor = 0.0,
        slice_profile: Callable[[Tensor], Tensor] | tuple[Callable[[Tensor], Tensor], ...] = SliceGaussian(2.0),
        optimize_for: Literal["forward", "adjoint", "both"] = "both",
    ):
        """Create a module that represents the projection of a volume onto a plane

        The projection will be done by sparse matrix multiplication.
        The shapes of the rotations (without the quaternion axis) and of the
        shifts are broadcast against each other; one slice is created per
        element of the broadcast batch shape.

        Parameters
        ----------
        input_shape:
            Shape of the volume to sample from
        slice_rotation
            Rotation that describes the orientation of the plane.
            Either None (identity), a scipy Rotation, a tuple of scipy Rotations
            or a tensor of rotation quaternions with last dimension 4.
        slice_shift
            Offset of the plane perpendicular to the plane, from the center of the volume
        slice_profile:
            A function that called with a distance x from the slice center should return the
            intensity along the slice thickness at x. Either a single function used for all
            slices or one function per slice.
        optimize_for: Literal["forward", "adjoint", "both"]
            Whether to optimize for forward or adjoint operation or both.
            Optimizing for both takes more memory but is faster for both operations

        Raises
        ------
        ValueError
            If optimize_for is not a valid choice, the quaternions do not have 4 components
            or the number of slice profiles does not match the number of slices.
        """
        super().__init__()
        # Validate the cheap option first so a typo does not cost a full matrix build.
        if optimize_for not in ("forward", "adjoint", "both"):
            raise ValueError("optimize_for must be one of 'forward', 'adjoint', 'both'")

        if slice_rotation is None:
            quaternions = torch.tensor((0.0, 0.0, 0.0, 1.0))
        elif isinstance(slice_rotation, Rotation):
            quaternions = torch.as_tensor(slice_rotation.as_quat())
        elif (
            isinstance(slice_rotation, (tuple, list))
            and len(slice_rotation) > 0
            and all(isinstance(s, Rotation) for s in slice_rotation)
        ):
            quaternions = torch.stack([torch.as_tensor(rot.as_quat()) for rot in slice_rotation])
        else:
            quaternions = torch.as_tensor(slice_rotation)
        if quaternions.ndim == 0 or quaternions.shape[-1] != 4:
            raise ValueError("Rotation quaternions must have 4 components")

        quaternions = torch.atleast_2d(quaternions)
        shifts = torch.atleast_1d(torch.as_tensor(slice_shift))
        batch_shapes = torch.broadcast_shapes(quaternions.shape[:-1], shifts.shape)
        n_slices = math.prod(batch_shapes)

        if not isinstance(slice_profile, (tuple, list)):
            profiles = (slice_profile,) * n_slices
        elif len(slice_profile) == 1 and n_slices > 1:
            profiles = tuple(slice_profile) * n_slices
        elif len(slice_profile) != n_slices:
            raise ValueError("length of slice_profile does not match batch shapes")
        else:
            profiles = tuple(slice_profile)

        m = max(input_shape)
        positions = torch.arange(-m, m)
        widths = []
        for profile in profiles:
            # figure out how far along the profile we have to consider values
            # clip up to 0.01 of intensity on both sides
            sampled = profile(positions)
            cumulative = torch.cumsum(sampled, -1) / sampled.sum()
            # argmax of the 0/1 mask gives the first position above the threshold
            left = int(positions[(cumulative > 0.01).int().argmax()])
            right = int(positions[(cumulative > 0.99).int().argmax()])
            widths.append(max(abs(left), abs(right)) + 1)

        quaternions = quaternions.expand(*batch_shapes, 4).reshape(-1, 4)
        shifts = shifts.expand(batch_shapes).reshape(-1)
        matrices = [
            SliceProjection.projection_matrix(
                input_shape,
                (m, m),
                rotation_quaternion=quat,
                offset=torch.tensor([0.0, 0.0, float(shift)]),
                w=width,
                slice_function=profile,
            )
            for quat, shift, profile, width in zip(quaternions, shifts, profiles, widths)
        ]
        matrix = SliceProjection.join_matrices(matrices)

        # in csr format the matmul is faster, but saving one for forward and adjoint takes more memory
        if optimize_for == "forward":
            self.matrix = matrix.to_sparse_csr()
            self.matrixT = self.matrix.H
        elif optimize_for == "adjoint":
            self.matrixT = matrix.H.to_sparse_csr()
            self.matrix = self.matrixT.H
        else:  # "both"
            self.matrix = matrix.to_sparse_csr()
            self.matrixT = self.matrix.H.to_sparse_csr()

        self._range_shape = (*batch_shapes, m, m)
        self._domain_shape = input_shape

    def forward(self, x):
        """Project the volume ``x`` onto the slices; the result has shape ``(*batch, m, m)``."""
        x = MatrixMultiplication.apply(x.ravel(), self.matrix, self.matrixT)
        return x.reshape(self._range_shape)

    def adjoint(self, x):
        """Apply the adjoint of the projection; the result has the shape of the input volume."""
        x = MatrixMultiplication.apply(x.ravel(), self.matrixT, self.matrix)
        return x.reshape(self._domain_shape)

    @staticmethod
    def join_matrices(matrices):
        """Stack sparse matrices of identical shape along the row dimension.

        Parameters
        ----------
        matrices
            Non-empty sequence of sparse COO matrices that all share one shape

        Returns
        -------
            Sparse COO float32 matrix of shape ``(len(matrices) * rows, cols)``

        Raises
        ------
        ValueError
            If no matrix is given or the shapes differ
        """
        if len(matrices) == 0:
            raise ValueError("at least one matrix is required")
        n_rows, n_cols = matrices[0].shape
        values = []
        target = []
        source = []
        for i, current in enumerate(matrices):
            if current.shape != matrices[0].shape:
                raise ValueError("all matrices should have the same shape")
            coalesced = current.coalesce()
            rows, cols = coalesced.indices()
            values.append(coalesced.values())
            source.append(cols)
            # shift the rows so that the i-th matrix occupies the i-th row block
            target.append(rows + i * n_rows)
        return torch.sparse_coo_tensor(
            indices=torch.stack([torch.cat(target), torch.cat(source)]),
            values=torch.cat(values),
            dtype=torch.float32,
            size=(len(matrices) * n_rows, n_cols),
        )

    @staticmethod
    def projection_matrix(
        input_shape: tuple[int, int, int],
        output_shape: tuple[int, int],
        rotation_quaternion: Tensor,
        offset: Tensor,
        w: int,
        slice_function: Callable[[Tensor], Tensor],
        rotation_center=None,
    ):
        """Create a sparse matrix that represents the projection of a volume onto a plane

        Outside the volume values are approximately zero padded

        Parameters
        ----------
        input_shape:
            Shape of the volume to sample from
        output_shape:
            Shape of the resulting plane
        rotation_quaternion
            Rotation that describes the orientation of the plane as a quaternion
        offset: Tensor
            Offset of the plane in the volume in plane coordinates from the center of the volume
        w: int
            Number of pixels on each side of the plane that are considered in the projection
            along the slice profile direction.
        slice_function: Callable
            Function that describes the slice profile
        rotation_center: Tensor
            Center of rotation, if None the center of the volume is used,
            i.e. for 4 pixels 0 1 2 3 it is between 1 and 2

        Returns
        -------
        torch.sparse_coo_matrix
            Sparse matrix that represents the projection of the volume onto the plane
        """
        # Convert explicitly so scipy always receives a plain CPU array.
        quaternion = torch.as_tensor(rotation_quaternion).detach().cpu().numpy()
        rotmat = torch.tensor(Rotation.from_quat(quaternion).as_matrix(), dtype=torch.float32)

        def _rotate(vector, inverse=False):
            # The transpose of a rotation matrix is its inverse.
            rotation = rotmat.T if inverse else rotmat
            return (rotation @ vector.reshape(-1, 3, 1)).reshape(vector.shape)

        X, Y, Z = input_shape
        x, y = output_shape  # a xy plane
        sx, sy = (X - x) // 2, (Y - y) // 2
        size = (x * y, X * Y * Z)

        # coordinates of the 2d output pixels in the coordinate system of the
        # input volume, so shape (x, y, 3)
        # NOTE(review): the plane sits at Z / 2 while the default rotation center
        # is at Z / 2 - 0.5 -- confirm this half voxel offset is intended
        grid_x, grid_y = torch.meshgrid(
            torch.arange(sx, sx + x, dtype=torch.float32),
            torch.arange(sy, sy + y, dtype=torch.float32),
            indexing="ij",
        )
        pixel = torch.stack([grid_x, grid_y, torch.full((x, y), Z / 2, dtype=torch.float32)], dim=-1)
        if offset is not None:
            pixel = pixel + torch.as_tensor(offset, dtype=pixel.dtype)
        if rotation_center is None:
            # default rotation center is the center of the volume, i.e. for 4 pixels
            # 0 1 2 3 it is between 1 and 2
            rotation_center = torch.tensor([X / 2 - 0.5, Y / 2 - 0.5, Z / 2 - 0.5])
        else:
            rotation_center = torch.as_tensor(rotation_center, dtype=pixel.dtype)
        pixel_rotated = _rotate(pixel - rotation_center) + rotation_center

        # We cast a ray from the pixel normal to the plane in both directions
        # points in the original volume further away than w will not be considered
        steps = torch.arange(-w, w + 1, dtype=torch.float32)
        ray = _rotate(torch.stack([torch.zeros_like(steps), torch.zeros_like(steps), steps], dim=-1))

        # In all possible directions for each point along the line we consider the eight neighboring points
        # by adding all possible combinations of 0 and 1 to the point and flooring
        # (this is the same as adding -.5, .5 to the point and rounding)
        offsets = torch.tensor(list(itertools.product([0, 1], repeat=3)))

        # all points that influence a pixel
        # (x, y, 8-neighbours, (2*w+1)-raylength, 3-dimensions XYZ)
        points_influencing_pixel = (
            einops.rearrange(pixel_rotated, "x y XYZ -> x y 1 1 XYZ")
            + einops.rearrange(ray, "ray XYZ -> 1 1 1 ray XYZ")
            + einops.rearrange(offsets, "neighbours XYZ -> 1 1 neighbours 1 XYZ")
        ).floor()

        # directional distance in source volume coordinate system
        distance = pixel_rotated[:, :, None, None, :] - points_influencing_pixel
        # Inverse rotation projects this back to the original coordinate system, i.e.
        # distance in z is distance along the line, i.e. the slice profile weighted direction
        # distance in x and y is the distance of a pixel to the ray and linear interpolation
        # is used to weight the distance
        distance_x, distance_y, distance_z = _rotate(distance, inverse=True).unbind(-1)
        weight_xy = (1 - distance_x.abs()).clamp_min(0) * (1 - distance_y.abs()).clamp_min(0)
        weight_z = slice_function(distance_z)
        weight = (weight_xy * weight_z).reshape(x * y, -1)

        source = einops.rearrange(
            points_influencing_pixel,
            "x y neighbours raylength XYZdim -> (x y) (neighbours raylength) XYZdim",
        ).long()

        # mask of only potential source points inside the source volume
        mask = (
            (source[..., 0] >= 0)
            & (source[..., 0] < X)
            & (source[..., 1] >= 0)
            & (source[..., 1] < Y)
            & (source[..., 2] >= 0)
            & (source[..., 2] < Z)
        )

        # We need this at the edge of the volume to approximate zero padding.
        # The denominator is clamped so pixels without any contributing point
        # get a fraction of 0 instead of NaN.
        contributing = weight > 0
        fraction_in_view = (mask & contributing).sum(-1) / contributing.sum(-1).clamp_min(1)

        # Row-major (C order) flat index into the raveled (X, Y, Z) volume
        inside = source[mask]
        source_index = (inside[:, 0] * Y + inside[:, 1]) * Z + inside[:, 2]
        target_index = torch.repeat_interleave(torch.arange(x * y), mask.sum(-1))
        indices = torch.stack((target_index, source_index))

        # Count duplicates. Coalesce will sum the values of duplicate indices
        counts = torch.sparse_coo_tensor(
            indices=indices,
            values=torch.ones(source_index.shape[0]),
            size=size,
            dtype=torch.float32,
        ).coalesce()
        summed = torch.sparse_coo_tensor(
            indices=indices,
            values=weight[mask],
            size=size,
            dtype=torch.float32,
        ).coalesce()

        # To avoid giving too much weight to points that are duplicated in our logic and
        # summed up by coalesce. Both tensors were coalesced from the same indices, so
        # their values are in the same order.
        values = summed.values() / counts.values()

        # Normalize each row to the fraction of its contributing points inside the volume
        rows = summed.indices()[0]
        row_sum = torch.zeros(x * y, dtype=values.dtype).index_add_(0, rows, values)
        norm = fraction_in_view / (row_sum + 1e-6)
        values = values * norm[rows]

        return torch.sparse_coo_tensor(
            indices=summed.indices(),
            values=values,
            size=size,
            dtype=torch.float32,
        ).coalesce()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment