Skip to content

Instantly share code, notes, and snippets.

@kohr-h
Created January 9, 2018 13:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kohr-h/a7626ea18bef88f237ecfa3698e36369 to your computer and use it in GitHub Desktop.
Save kohr-h/a7626ea18bef88f237ecfa3698e36369 to your computer and use it in GitHub Desktop.
ODL conversion between tensor spaces and power spaces
def as_tensor_space(pspace, axis=None):
    """Convert a `ProductSpace` of `TensorSpace`'s to a tensor space.

    The power axes of ``pspace`` become new axes of the returned space,
    interleaved with the axes of its base space.

    Parameters
    ----------
    pspace : `ProductSpace`
        Power space (with arbitrary shape) whose base is a `TensorSpace`.
    axis : int or sequence of int, optional
        Indices in the returned space at which the power axes should be
        inserted as new axes. There must be exactly one distinct index per
        power axis; negative indices count from the end of the resulting
        shape. The order of the entries does not matter, since power axes
        are always placed in their original order.
        For the default ``None``, the power axes are added to the left.

    Returns
    -------
    tspace : `TensorSpace`
        Space of the same type and dtype as the base space of ``pspace``,
        with ``len(pspace.shape) + base.ndim`` axes. Weighting and other
        attributes of the base space are currently not carried over.

    Raises
    ------
    TypeError
        If ``pspace`` is not a power space whose base is a `TensorSpace`,
        or if the base space type cannot be constructed from a shape and
        dtype alone (e.g., ``DiscreteLp``).
    ValueError
        If ``axis`` has the wrong number of entries, contains duplicates,
        or contains entries out of range.

    Examples
    --------
    >>> pspace = odl.rn(3) ** 2
    >>> as_tensor_space(pspace)
    rn((2, 3))
    >>> as_tensor_space(pspace, axis=1)
    rn((3, 2))
    >>> as_tensor_space(pspace, axis=-1)
    rn((3, 2))
    >>> pspace = odl.rn(4) ** (2, 3)
    >>> as_tensor_space(pspace)
    rn((2, 3, 4))
    >>> as_tensor_space(pspace, axis=(0, 2))
    rn((2, 4, 3))
    """
    import operator

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    if not (isinstance(pspace, odl.ProductSpace) and pspace.is_power_space):
        raise TypeError('`pspace` must be a power `ProductSpace`, got {!r}'
                        ''.format(pspace))
    power_shape = tuple(pspace.shape)
    power_ndim = len(power_shape)
    base = pspace[(0,) * power_ndim]
    if not isinstance(base, odl.space.base_tensors.TensorSpace):
        raise TypeError('base of `pspace` must be a `TensorSpace`, got {!r}'
                        ''.format(base))
    ndim = power_ndim + base.ndim

    # Normalize `axis` to a list of ints (accepts Python and NumPy integers)
    if axis is None:
        axes = list(range(power_ndim))
    else:
        try:
            axes = [operator.index(axis)]
        except TypeError:
            axes = [operator.index(a) for a in axis]

    if len(axes) != power_ndim:
        raise ValueError('`axis` must have {} entries (one per power axis), '
                         'got {!r}'.format(power_ndim, axis))
    if not all(-ndim <= a < ndim for a in axes):
        raise ValueError('`axis` entries must lie in the range [{}, {}), '
                         'got {!r}'.format(-ndim, ndim, axis))
    # Map negative indices to their non-negative equivalents
    power_axes = {a % ndim for a in axes}
    if len(power_axes) != power_ndim:
        raise ValueError('`axis` must not contain duplicate entries, got {!r}'
                         ''.format(axis))

    # Walk over the result axes: take the next power size at positions in
    # `power_axes` and the next base size everywhere else. The counts match
    # exactly after the validation above, so neither iterator runs dry.
    power_sizes = iter(power_shape)
    base_sizes = iter(base.shape)
    newshape = tuple([next(power_sizes) if i in power_axes
                      else next(base_sizes)
                      for i in range(ndim)])

    # TODO: This disregards weighting (and every attribute of `base` except
    # its dtype) completely, needs fix!
    try:
        return type(base)(newshape, dtype=base.dtype)
    except TypeError as err:
        # Spaces like `DiscreteLp` need more constructor arguments than a
        # shape and dtype; give a clearer error than the bare constructor one
        raise TypeError('cannot construct a space of type `{}` from a shape '
                        'and dtype only ({})'
                        ''.format(type(base).__name__, err))
def as_power_space(tspace, axis=None):
    """Convert a `TensorSpace` to a `ProductSpace` of smaller tensor spaces.

    The axes given by ``axis`` are removed from ``tspace`` and turned into
    the power shape of the returned space.

    Parameters
    ----------
    tspace : `TensorSpace`
        Tensor space with ``ndim >= 1`` that should be converted to a
        `ProductSpace`.
    axis : int or sequence of int, optional
        Indices of the axes that should be turned into powers. Negative
        indices count from the end; repeated entries are ignored.
        For the default ``None``, the first axis is taken.

    Returns
    -------
    pspace : `ProductSpace`
        Power space whose shape consists of the sizes of the removed axes
        (in increasing axis order) and whose base is the subspace of
        ``tspace`` along the remaining axes.

    Raises
    ------
    TypeError
        If ``tspace`` is not a `TensorSpace`.
    ValueError
        If ``tspace.ndim < 1`` or an entry of ``axis`` is out of range.

    Examples
    --------
    >>> tspace = odl.rn((2, 3))
    >>> as_power_space(tspace)
    ProductSpace(rn(3), 2)
    >>> as_power_space(tspace, axis=1)
    ProductSpace(rn(2), 3)
    >>> tspace = odl.rn((2, 3, 4))
    >>> as_power_space(tspace)
    ProductSpace(rn((3, 4)), 2)
    >>> as_power_space(tspace, axis=(0, 2))
    ProductSpace(ProductSpace(rn(3), 4), 2)
    """
    import operator

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    if not isinstance(tspace, odl.space.base_tensors.TensorSpace):
        raise TypeError('`tspace` must be a `TensorSpace`, got {!r}'
                        ''.format(tspace))
    ndim = tspace.ndim
    if ndim < 1:
        raise ValueError('`tspace` must have at least one axis, got {!r}'
                         ''.format(tspace))

    # Normalize `axis` to a list of ints (accepts Python and NumPy integers)
    if axis is None:
        axes = [0]
    else:
        try:
            axes = [operator.index(axis)]
        except TypeError:
            axes = [operator.index(a) for a in axis]

    if not all(-ndim <= a < ndim for a in axes):
        raise ValueError('`axis` entries must lie in the range [{}, {}), '
                         'got {!r}'.format(-ndim, ndim, axis))
    # Map negative indices to their non-negative equivalents; a set also
    # makes repeated entries harmless, as they were before
    power_axes = {a % ndim for a in axes}

    remaining_axes = [i for i in range(ndim) if i not in power_axes]
    removed_shape = tuple(n for i, n in enumerate(tspace.shape)
                          if i in power_axes)

    # Some spaces (e.g. `DiscreteLp`) expose per-axis subspaces as
    # `byaxis_in` and have no `byaxis`; fall back to that only when needed
    if hasattr(tspace, 'byaxis'):
        byaxis = tspace.byaxis
    else:
        byaxis = tspace.byaxis_in
    return byaxis[remaining_axes] ** removed_shape
@MJLagerwerf
Copy link

MJLagerwerf commented Jan 11, 2018

dyn_space = odl.uniform_discr([0, -20, -20], [T, 20, 20], [T, 128, 128])
as_power_space(dyn_space)

Gives:

AttributeError: 'DiscreteLp' object has no attribute 'byaxis'

Changing the last line to:

return tspace.byaxis_in[remaining_axes] ** removed_shape

Fixes the problem for DiscreteLp spaces, but breaks it again for the rn example:

tspace = odl.rn((2, 3))
as_power_space(tspace)

Gives

'NumpyTensorSpace' object has no attribute 'byaxis_in'

Not a problem I cannot fix, but it might be a good idea to look at your naming scheme for spaces.

@MJLagerwerf
Copy link

MJLagerwerf commented Jan 11, 2018

Moreover, the following code does not seem to work:

obj_space = odl.uniform_discr([ -20, -20], [20, 20], [128, 128])
dyn_space_P = obj_space ** T
dyn_space = as_tensor_space(dyn_space_P)

This gives:

TypeError: __init__() missing 2 required positional arguments: 'partition' and 'tspace'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment