Skip to content

Instantly share code, notes, and snippets.

@rldotai
Last active January 7, 2019 04:31
Show Gist options
  • Save rldotai/de982d90cbce548b027c0704047c8c52 to your computer and use it in GitHub Desktop.
Save rldotai/de982d90cbce548b027c0704047c8c52 to your computer and use it in GitHub Desktop.
Convert objects (including iterables of objects) to dictionaries mapping `sympy` symbol names to the object's values.
import itertools
def is_iterable(x) -> bool:
    """Report whether `iter()` accepts `x`.

    This mirrors Python's own iteration protocol, so objects that only define
    `__getitem__` count as iterable too. Strings are iterable under this test.
    """
    try:
        iter(x)
    except TypeError:
        return False
    else:
        return True
def nditer(obj, iter_dict=True, use_dict_keys=False):
    """Given an object, flatten the iterables it contains into a single
    (generated) sequence, while recording the n-dimensional coordinates.

    It is similar to `numpy.ndenumerate`, but it supports nested sequences of
    variable sizes, not only what can be represented as an array. Because it
    returns a generator, those sizes may even be unbounded. It also doesn't
    modify datatypes.

    Strings are iterable, even when empty or a single character. For
    simplicity, and following `numpy.ndenumerate`, they are treated as
    terminal elements and are not expanded.

    Dictionaries need special handling. A `dict` is iterable, but if its keys
    are not part of the coordinates, some information is lost. If the keys
    are included, elements can no longer be addressed purely by integers.
    Including *both* the position and the key helps no-one, because it
    changes the apparent depth of the terminal elements. Doing so would give:

        {'a': 1, 'b': [23, 32]}
        --> [((0, 'a'), 1), ((1, 'b', 0), 23), ((1, 'b', 1), 32)]

    Instead, exactly one of the two is used. By default this is the integer
    position. With `use_dict_keys=True` it is the key. Iterating over
    dictionaries at all can be disabled with `iter_dict=False`.

    Parameters
    ----------
    obj: A Python object.
        Any Python object, although only iterables really make sense as input.
    iter_dict: bool, optional
        Iterate over dictionaries rather than just returning them as if they
        were a non-iterable element. Defaults to `True`.
    use_dict_keys: bool, optional
        Whether to use dictionary keys in place of integer positions. This is
        only relevant if the object contains dicts. Defaults to `False`.

    Returns
    -------
    out: generator[tuple]
        A generator of `(coordinate, element)` pairs. Each `coordinate` is a
        tuple giving the position of `element` within `obj`. Its entries are
        nonnegative integers, except that dict levels contribute their keys
        when `use_dict_keys=True`. A non-iterable `obj` yields the single
        pair `((), obj)`.

    See Also
    --------
    ndenum: The same flattening, but materialized as a list.

    Examples
    --------
    >>> list(nditer([[3, 2], [4]]))
    [((0, 0), 3), ((0, 1), 2), ((1, 0), 4)]
    >>> list(nditer({'a': 1, 'b': [23, 32]}))
    [((0,), 1), ((1, 0), 23), ((1, 1), 32)]
    >>> list(nditer({'a': 1, 'b': [23, 32]}, use_dict_keys=True))
    [(('a',), 1), (('b', 0), 23), (('b', 1), 32)]
    >>> list(nditer({'a': 1}, iter_dict=False))
    [((), {'a': 1})]
    """
    def _walk(elem, prefix):
        """Recursively yield `(coordinate, element)` pairs found under `elem`."""
        # Strings are iterable but deliberately terminal. Checking them first
        # also keeps single-character strings from recursing forever.
        if isinstance(elem, str) or not is_iterable(elem):
            yield (prefix, elem)
        elif isinstance(elem, dict):
            if not iter_dict:
                # Treat the whole dict as a single terminal element.
                yield (prefix, elem)
            elif use_dict_keys:
                for key, value in elem.items():
                    yield from _walk(value, prefix + (key,))
            else:
                for ix, value in enumerate(elem.values()):
                    yield from _walk(value, prefix + (ix,))
        else:
            # Every other iterable is addressed by its integer position.
            for ix, child in enumerate(elem):
                yield from _walk(child, prefix + (ix,))

    return _walk(obj, ())
def ndenum(obj, iter_dict=True, use_dict_keys=False):
    """Given an object, flatten the iterables it contains into a single
    list, while recording the n-dimensional coordinates.

    Similar to `numpy.ndenumerate`, but it supports nested sequences of
    variable sizes, not only what can be represented as an array, and it
    doesn't modify datatypes. This simply materializes the output of
    `nditer`. See that function's documentation for how strings and dicts
    are handled.

    Parameters
    ----------
    obj: A Python object.
        Any Python object, although only iterables really make sense as input.
    iter_dict: bool, optional
        Iterate over dictionaries rather than just returning them as if they
        were a non-iterable element. Defaults to `True`.
    use_dict_keys: bool, optional
        Whether to use dictionary keys in place of integer positions. This is
        only relevant if the object contains dicts. Defaults to `False`.

    Returns
    -------
    out: List[Tuple]
        A list of `(coordinate, element)` pairs. Each `coordinate` is a tuple
        giving the position of `element` within `obj`. Its entries are
        nonnegative integers, or dict keys when `use_dict_keys=True`.

    See Also
    --------
    nditer: The same flattening, but returned lazily as a generator.

    Examples
    --------
    Enumerate over lists or arrays:

    >>> ndenum([[3, 2], [4, 1]])
    [((0, 0), 3), ((0, 1), 2), ((1, 0), 4), ((1, 1), 1)]
    >>> ndenum([2])
    [((0,), 2)]

    A "deficient" 3-dimensional array:

    >>> ndenum([[[3, 2], [4, 1]], [[8], [6, 3, 3]]])  # doctest: +NORMALIZE_WHITESPACE
    [((0, 0, 0), 3), ((0, 0, 1), 2), ((0, 1, 0), 4), ((0, 1, 1), 1),
     ((1, 0, 0), 8), ((1, 1, 0), 6), ((1, 1, 1), 3), ((1, 1, 2), 3)]

    Given non-iterable inputs, it still returns a result, although how useful
    this might be is debatable:

    >>> ndenum(2)
    [((), 2)]
    """
    return list(nditer(obj, iter_dict=iter_dict, use_dict_keys=use_dict_keys))
import sympy
import sympy.matrices.matrices
from sympy import sympify
from ndenum import ndenum, nditer
def symbolify(obj, prefix):
    """Map an object (which could be an array or a constant) to symbol names.

    Each element of `obj` is keyed by `prefix`, followed by an underscore and
    its coordinates joined with underscores. A non-iterable `obj` (a constant)
    is keyed by `prefix` alone. Every value is passed through `sympify`. The
    result is useful for substitutions when working with `sympy`.

    Parameters
    ----------
    obj: A Python object.
        A constant, a (possibly ragged) nested iterable, or a sympy matrix.
    prefix: str
        Base name for the generated keys.

    Returns
    -------
    dict
        Maps each generated name (a `str`) to the sympified element.

    Examples
    --------
    >>> symbolify([[3, 2], [4, 1]], 'c')
    {'c_0_0': 3, 'c_0_1': 2, 'c_1_0': 4, 'c_1_1': 1}
    >>> symbolify([[3, 2], [4]], 'd')
    {'d_0_0': 3, 'd_0_1': 2, 'd_1_0': 4}
    >>> symbolify(5, 'k')
    {'k': 5}
    >>> symbolify(sympy.Matrix([[1, 2], [3, 4]]), 'm')
    {'m_0_0': 1, 'm_0_1': 2, 'm_1_0': 3, 'm_1_1': 4}
    """
    if isinstance(obj, sympy.matrices.MatrixBase):
        # Iterating a sympy matrix directly yields its entries flattened, which
        # would give 1-D coordinates. Enumerate (row, column) pairs explicitly
        # instead, matching what `numpy.ndenumerate` would produce. This avoids
        # depending on numpy, which this module does not import.
        rows, cols = obj.shape
        iterable = (
            ((i, j), obj[i, j]) for i in range(rows) for j in range(cols)
        )
    else:
        iterable = nditer(obj)

    ret = {}
    for ixs, x in iterable:
        suffix = "_".join(map(str, ixs))
        # An empty suffix means `obj` was not iterable, so treat it as a
        # constant named by the bare prefix.
        name = prefix + "_" + suffix if suffix else prefix
        ret[name] = sympify(x)
    return ret
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment