Created
January 9, 2018 13:51
-
-
Save kohr-h/a7626ea18bef88f237ecfa3698e36369 to your computer and use it in GitHub Desktop.
ODL conversion between tensor spaces and power spaces
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def as_tensor_space(pspace, axis=None):
    """Convert a `ProductSpace` of `TensorSpace`'s to a tensor space.

    Parameters
    ----------
    pspace : `ProductSpace`
        Power space (with arbitrary shape) whose base is a `TensorSpace`.
    axis : int or sequence of int, optional
        Indices at which the powers should be inserted as new axes,
        given with respect to the shape of the resulting space.
        Negative indices count from the end of that shape. The power
        dimensions are placed at these positions in their original
        order, regardless of the order of the entries in ``axis``.
        For the default ``None``, the power axes are added to the left.

    Returns
    -------
    tspace : `TensorSpace`
        Space of the same type and ``dtype`` as the base of ``pspace``,
        with the power dimensions merged into its shape.

    Raises
    ------
    ValueError
        If an entry of ``axis`` is not an integer, is out of range for
        the resulting number of dimensions, or occurs more than once.

    Examples
    --------
    >>> pspace = odl.rn(3) ** 2
    >>> as_tensor_space(pspace)
    rn((2, 3))
    >>> as_tensor_space(pspace, axis=1)
    rn((3, 2))
    >>> as_tensor_space(pspace, axis=-1)
    rn((3, 2))
    >>> pspace = odl.rn(4) ** (2, 3)
    >>> as_tensor_space(pspace)
    rn((2, 3, 4))
    >>> as_tensor_space(pspace, axis=(0, 2))
    rn((2, 4, 3))
    """
    assert isinstance(pspace, odl.ProductSpace) and pspace.is_power_space
    power_shape = pspace.shape
    power_ndim = len(power_shape)
    # All components of a power space are the same, so the first one
    # serves as the base space.
    base = pspace[(0,) * power_ndim]
    assert isinstance(base, odl.space.base_tensors.TensorSpace)

    total_ndim = power_ndim + base.ndim

    if axis is None:
        axis = list(range(power_ndim))
    elif np.isscalar(axis):
        axis = [axis]
    else:
        axis = list(axis)
    assert len(axis) == power_ndim

    # Normalize to non-negative indices and reject invalid entries up
    # front; otherwise the shape assembly below would run past the end
    # of `base.shape` with an unhelpful `IndexError`.
    power_axes = set()
    for ax in axis:
        ax_int = int(ax)
        if ax_int != ax:
            raise ValueError('`axis` entries must be integers, got {!r}'
                             ''.format(ax))
        if not -total_ndim <= ax_int < total_ndim:
            raise ValueError('`axis` entry {} is out of range for a space '
                             'with {} dimensions'
                             ''.format(ax_int, total_ndim))
        ax_int %= total_ndim
        if ax_int in power_axes:
            raise ValueError('`axis` contains duplicate entries: {!r}'
                             ''.format(axis))
        power_axes.add(ax_int)

    # Interleave power and base dimensions: positions listed in `axis`
    # take the next power dimension, all others the next base dimension.
    power_sizes = iter(power_shape)
    base_sizes = iter(base.shape)
    newshape = [next(power_sizes) if nd in power_axes else next(base_sizes)
                for nd in range(total_ndim)]

    # TODO: This disregards weighting completely, needs fix!
    # NOTE: this assumes that `type(base)` can be constructed from
    # `(shape, dtype=...)`; space types whose constructor has a different
    # signature will raise a `TypeError` here.
    return type(base)(newshape, dtype=base.dtype)
def as_power_space(tspace, axis=None):
    """Convert a `TensorSpace` to a `ProductSpace` of smaller tensor spaces.

    Parameters
    ----------
    tspace : `TensorSpace`
        Tensor space with ``ndim >= 1`` that should be converted to a
        `ProductSpace`.
    axis : int or sequence of int, optional
        Indices of the axes that should be turned into powers.
        Negative indices count from the last axis. The powers are
        always taken in the order in which the axes appear in
        ``tspace.shape``, regardless of the order of the entries.
        For the default ``None``, the first axis is taken.

    Returns
    -------
    pspace : `ProductSpace`
        Power space whose base is ``tspace`` restricted to the axes not
        listed in ``axis``, and whose powers are the sizes of the axes
        that were listed.

    Raises
    ------
    ValueError
        If an entry of ``axis`` is not an integer or is out of range for
        ``tspace.ndim``.

    Examples
    --------
    >>> tspace = odl.rn((2, 3))
    >>> as_power_space(tspace)
    ProductSpace(rn(3), 2)
    >>> as_power_space(tspace, axis=1)
    ProductSpace(rn(2), 3)
    >>> as_power_space(tspace, axis=-1)
    ProductSpace(rn(2), 3)
    >>> tspace = odl.rn((2, 3, 4))
    >>> as_power_space(tspace)
    ProductSpace(rn((3, 4)), 2)
    >>> as_power_space(tspace, axis=(0, 2))
    ProductSpace(ProductSpace(rn(3), 4), 2)
    """
    assert isinstance(tspace, odl.space.base_tensors.TensorSpace)
    assert tspace.ndim >= 1

    if axis is None:
        axis = [0]
    elif np.isscalar(axis):
        axis = [axis]
    else:
        axis = list(axis)

    # Normalize to non-negative indices. Without this, negative or
    # out-of-range entries would match no axis and be silently ignored.
    ndim = tspace.ndim
    power_axes = set()
    for ax in axis:
        ax_int = int(ax)
        if ax_int != ax:
            raise ValueError('`axis` entries must be integers, got {!r}'
                             ''.format(ax))
        if not -ndim <= ax_int < ndim:
            raise ValueError('`axis` entry {} is out of range for a space '
                             'with {} dimensions'.format(ax_int, ndim))
        power_axes.add(ax_int % ndim)

    # Split the axes into those that stay in the base space and those
    # whose sizes become the powers, both in their original order.
    remaining_axes = [i for i in range(ndim) if i not in power_axes]
    removed_shape = [n for i, n in enumerate(tspace.shape)
                     if i in power_axes]
    return tspace.byaxis[remaining_axes] ** removed_shape
Moreover the following code does not seem to work:
obj_space = odl.uniform_discr([ -20, -20], [20, 20], [128, 128])
dyn_space_P = obj_space ** T
dyn_space = as_tensor_space(dyn_space_P)
This gives:
TypeError: __init__() missing 2 required positional arguments: 'partition' and 'tspace'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Gives:
Changing the last line to:
Fixes the problem for DiscreteLp spaces, but breaks it again for the `rn` example, which gives:
Not a problem I cannot fix, but it might be a good thing to look at your naming scheme for spaces