Last active
April 5, 2020 16:53
-
-
Save JakeTheCorn/857e822c7e2b475b75b68d8bc1d90c13 to your computer and use it in GitHub Desktop.
A utility for traversing arbitrary nested values (for example, trees of dicts and lists) and applying a transformation to each node.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
import unittest
from typing import Any, Callable, Optional
def _accepts_key(visitor: Callable[..., Any]) -> bool:
    """Return True if ``visitor`` can be called positionally as ``visitor(node, key)``."""
    try:
        params = inspect.signature(visitor).parameters.values()
    except (TypeError, ValueError):
        # Some builtins expose no signature; fall back to the one-argument form.
        return False
    # Keyword-only and **kwargs parameters cannot receive a positional key, so
    # only positional parameters (and *args) count toward the two slots.
    positional = [
        p
        for p in params
        if p.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)
    ]
    return len(positional) >= 2


def build_traverser(*, visitor: Callable[..., Any]) -> Callable[[Any], Any]:
    """Build a function that walks a nested value and transforms its nodes.

    The returned ``traverse(value)`` rebuilds ``value`` as follows:

    * ``dict``: each value is passed to the visitor together with its key.
      If the visitor returns a dict or a non-empty list, that result is walked
      in turn; any other result is used as-is.
    * non-empty ``list``: each item is walked (with no key); the list object
      itself is not passed to the visitor.
    * anything else, including an empty list: passed to the visitor with a
      key of ``None``.

    Every node is handed to the visitor at most once.  ``traverse`` builds new
    dicts/lists for the containers it walks instead of mutating them.

    Args:
        visitor: Called as ``visitor(node, key)`` where ``key`` is the dict key
            the node was found under, or ``None`` for list items and the root.
            A visitor that cannot take a second positional argument (or whose
            signature cannot be inspected) is called as ``visitor(node)``.

    Returns:
        The ``traverse`` function described above.
    """
    # NOTE: the original author noted a preference for passing the visitor's
    # arguments by name (a mapping) instead of positionally; not implemented.
    if _accepts_key(visitor):
        func = visitor
    else:
        def func(node, _key=None):
            # Adapt a one-argument visitor to the (node, key) convention.
            return visitor(node)

    def descend(node):
        """Rebuild a node the visitor has already seen, walking its children."""
        if isinstance(node, dict) or (isinstance(node, list) and node):
            return traverse(node)
        # Leaf (or empty list): it was already visited, so do not visit it again.
        return node

    def traverse(val):
        if isinstance(val, dict):
            return {prop: descend(func(child, prop)) for prop, child in val.items()}
        if isinstance(val, list) and val:
            return [traverse(item) for item in val]
        # Pass the key explicitly so two-argument visitors without a default work.
        return func(val, None)

    return traverse
class Test(unittest.TestCase):
    """Unit tests for build_traverser covering dicts, lists and scalars."""

    @staticmethod
    def _email_visitor(node, key=None):
        """Swap whatever sits under an 'email' key for an 'active' marker."""
        return {'active':1} if key == 'email' else node

    def test_objects(self):
        traverse = build_traverser(visitor=self._email_visitor)
        self.assertEqual(
            traverse({'key': 1, 'email': {}}),
            {'key': 1, 'email': {'active': 1}},
        )

    def test_lists(self):
        traverse = build_traverser(visitor=self._email_visitor)
        self.assertEqual(
            traverse([{'email': {}}]),
            [{'email': {'active': 1}}],
        )

    def test_primitives(self):
        # Each pair is (input scalar, value the visitor should map it to).
        replacements = ((1, 'ONE'), ('a', 'A'))

        def primitive_visitor(node):
            for original, replacement in replacements:
                if node == original:
                    return replacement
            return node

        traverse = build_traverser(visitor=primitive_visitor)
        for arg, expectation in replacements:
            self.assertEqual(traverse(arg), expectation)
if __name__ == '__main__':
    # Run as a script: hand control to unittest's command-line runner.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment