Skip to content

Instantly share code, notes, and snippets.

@plammens
Last active January 2, 2021 01:49
Show Gist options
  • Save plammens/45c84971bb0132b7f9ddb457906817d4 to your computer and use it in GitHub Desktop.
Save plammens/45c84971bb0132b7f9ddb457906817d4 to your computer and use it in GitHub Desktop.
Recursive apply: apply a function recursively to all sub-objects
"""
Algorithm to recursively apply a function to a composite object.
Dependencies:
- multimethod (https://pypi.org/project/multimethod/)
"""
# MIT License
#
# Copyright (c) 2020 Paolo Lammens
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import dataclasses
import itertools
import sys
import types
import warnings
from typing import Any, Callable, Dict, List, TypeVar

import multimethod
from multimethod import isa
class RecursionWarning(UserWarning):
    """
    Warning category used by ``recursive_apply`` when it finds a reference cycle
    and attempts to resolve it by patching the new copies afterwards.
    """


# generic type of the object passed to (and returned by) ``recursive_apply``
T = TypeVar("T")
def recursive_apply(
    obj: T, predicate: Callable[[Any], bool], function: Callable[[Any], Any]
) -> T:
    """
    Recursively apply a function to all sub-objects that satisfy a predicate.

    The object graph rooted at ``obj`` is traversed depth-first:

    - an object for which ``predicate`` is true is replaced by ``function(object)``
      and is not descended into;
    - dicts (keys and values), tuples (including named tuples), lists, sets and
      frozensets are traversed element by element;
    - any other object is traversed through its instance attributes (as returned
      by ``_instance_attributes``), skipping names that start with ``"__"``.

    A container or object is shallow-copied (``dataclasses.replace`` for
    dataclasses, the container's own type or ``copy.copy`` otherwise) only if at
    least one of its sub-objects changed. Each sub-object is processed at most
    once (memoized by ``id``), so shared references stay shared in the result.

    Reference cycles through list items, dict entries or object attributes are
    resolved by inserting a placeholder and patching the new copy once the
    referred object is done; this emits a ``RecursionWarning``, and the objects
    along the cycle are always copied. Cycles through tuples, sets or frozensets
    cannot be patched and raise ``RecursionError``.

    :param obj: Root of the object graph to process.
    :param predicate: Selects the sub-objects to transform.
    :param function: Transformation applied to each selected sub-object.
    :return: Shallow copy of the object with replaced sub-objects, or the same object
        unmodified if no sub-objects satisfying the predicate were found.
    :raises RecursionError: If a reference cycle is found that cannot be patched.
    """
    # mapping of sub-object id to the transformed sub-object
    processed: Dict[int, Any] = {}
    # placeholder for cyclic references
    cyclic_reference_placeholder = object()
    # map of referred object id to patch functions to update cyclic references to it
    patches: Dict[int, List[Callable[[Any], None]]] = {}
    # stack of objects being processed (id(o): o); used to detect reference cycles
    processing: Dict[int, Any] = {}

    def apply(o: Any, patch_func: Callable[[Any], None] = None):
        """
        Return the (memoized) transformed version of ``o``.

        :param patch_func: Callback that replaces, inside the parent's new copy,
            the reference to ``o``. It is only used when ``o`` is currently being
            processed (i.e. the reference closes a cycle); ``None`` means such a
            reference cannot be patched.
        """
        key = id(o)
        if key in processing:
            # there is a reference cycle; we try to deal with it by returning a
            # placeholder for now and leaving a memo to remember updating this
            # reference when the referred object has been fully processed too
            cycle = " --> ".join(map(repr, itertools.chain(processing.values(), [o])))
            if patch_func is None:
                # no patch function available; can't deal with the cyclic reference
                raise RecursionError(
                    f"Cycle detected and I'm not able to deal with it:\n{cycle}"
                )
            # determine stacklevel for warning
            stacklevel = _compute_stacklevel(public_call_site=recursive_apply)
            warnings.warn(
                f"Attempting to resolve detected cycle "
                f"(this might have undesired side effects):\n{cycle}",
                category=RecursionWarning,
                stacklevel=stacklevel,
            )
            patches.setdefault(key, []).append(patch_func)
            return cyclic_reference_placeholder

        processing[key] = o
        try:
            if key in processed:
                maybe_transformed = processed[key]
            else:
                maybe_transformed = _do_apply(o)
                processed[key] = maybe_transformed
            # restore the correct value in any cyclic reference placeholders
            for patch in patches.pop(key, ()):
                patch(maybe_transformed)
            return maybe_transformed
        finally:
            del processing[key]

    # _do_apply is overloaded based on whether the object satisfies the predicate first,
    # and its type. Overloads are checked in reverse order of registration.
    @multimethod.overload
    def _do_apply(o):
        # fallback for other objects; descend into their instance attributes
        def attribute_patch(attr: str):
            def patch_func(maybe_transformed):
                object.__setattr__(attribute_patch.the_object, attr, maybe_transformed)

            return patch_func

        filled_values = {}
        for name, value in _instance_attributes(o).items():
            if name.startswith("__"):
                continue
            transformed = apply(value, patch_func=attribute_patch(name))
            if transformed is not value:
                filled_values[name] = transformed
        if not filled_values:
            return o

        if dataclasses.is_dataclass(o):
            # replace() rejects init=False fields; set those on the new instance
            init_fields = {f.name for f in dataclasses.fields(o) if f.init}
            transformed = dataclasses.replace(
                o, **{n: v for n, v in filled_values.items() if n in init_fields}
            )
            for name, value in filled_values.items():
                if name not in init_fields:
                    object.__setattr__(transformed, name, value)
        else:
            transformed = copy.copy(o)
            for name, value in filled_values.items():
                setattr(transformed, name, value)
        attribute_patch.the_object = transformed
        return transformed

    @_do_apply.register
    def _do_apply(o: isa(dict)):
        # Each entry's current key in the new dict is kept in a one-element list
        # shared by its key and value patches, so they work in either order.
        # NOTE: several cyclic-reference keys in one dict would collide, since
        # they are all represented by the same placeholder object.
        def key_patch(entry: List[Any]):
            def patch_func(maybe_transformed):
                d = key_patch.the_dict
                # warning: this might change the order
                value = d.pop(entry[0])
                d[maybe_transformed] = value
                entry[0] = maybe_transformed

            return patch_func

        def value_patch(entry: List[Any]):
            def patch_func(maybe_transformed):
                value_patch.the_dict[entry[0]] = maybe_transformed

            return patch_func

        items = []
        # "maybe" is because the changes might just be cyclic reference placeholders
        maybe_modified = False
        for k, v in o.items():
            entry: List[Any] = [k]
            new_k = apply(k, patch_func=key_patch(entry))
            # the value is stored under the (possibly placeholder) new key
            entry[0] = new_k
            new_v = apply(v, patch_func=value_patch(entry))
            items.append((new_k, new_v))
            maybe_modified = maybe_modified or new_k is not k or new_v is not v

        if not maybe_modified:
            return o
        # noinspection PyArgumentList
        transformed = type(o)(items)
        value_patch.the_dict = key_patch.the_dict = transformed
        return transformed

    @_do_apply.register
    def _do_apply(o: isa(tuple, list, set, frozenset)):
        if isinstance(o, list):

            def item_patch(index: int):
                def patch_func(maybe_transformed):
                    item_patch.the_list[index] = maybe_transformed

                return patch_func

        else:
            # items of tuples and sets can't be replaced in place, so cyclic
            # references through them can't be patched (apply raises instead)
            def item_patch(index: int):
                return None

        items = [apply(x, patch_func=item_patch(i)) for i, x in enumerate(o)]
        # "maybe" is because the changes might just be cyclic reference placeholders;
        # compare element-wise before building the new container so that elements
        # collapsing into one (e.g. in a set) can't hide a change
        if all(new is old for new, old in zip(items, o)):
            return o
        if isinstance(o, tuple) and hasattr(o, "_fields"):
            # named tuples take their fields as separate arguments
            transformed = type(o)._make(items)
        else:
            transformed = type(o)(items)
        item_patch.the_list = transformed
        return transformed

    @_do_apply.register
    def _do_apply(o: predicate):
        return function(o)

    result = apply(obj)
    # every cyclic reference must have been patched once its target was done
    assert len(patches) == 0
    return result
def _instance_attributes(obj: Any) -> Dict[str, Any]:
"""Get a name-to-value dictionary of instance attributes of an arbitrary object."""
try:
return vars(obj)
except TypeError:
pass
# object doesn't have __dict__, try with __slots__
try:
slots = obj.__slots__
except AttributeError:
# doesn't have __dict__ nor __slots__, probably a builtin like str or int
return {}
# collect all slots attributes (some might not be present)
attrs = {}
for name in slots:
try:
attrs[name] = getattr(obj, name)
except AttributeError:
continue
return attrs
def _compute_stacklevel(public_call_site: callable) -> int:
"""Compute the stacklevel necessary to emit a warning at the "public" call site."""
function = (
public_call_site
if isinstance(public_call_site, types.FunctionType) # noqa
else public_call_site.__call__
)
stacklevel = 2
frame = sys._getframe() # noqa
while frame.f_code is not function.__code__:
frame = frame.f_back
stacklevel += 1
return stacklevel - 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment