Skip to content

Instantly share code, notes, and snippets.

@JakeTheCorn
Last active April 5, 2020 16:53
Show Gist options
  • Save JakeTheCorn/857e822c7e2b475b75b68d8bc1d90c13 to your computer and use it in GitHub Desktop.
Save JakeTheCorn/857e822c7e2b475b75b68d8bc1d90c13 to your computer and use it in GitHub Desktop.
Traverses arbitrary nested values (such as trees of dicts and lists) and applies a visitor transformation to each node.
import inspect
from typing import (Callable, Optional, Any)
import unittest
def _accepts_key(visitor: Callable[..., Any]) -> bool:
    """Return True if *visitor* can be called as ``visitor(node, key)``.

    Binding two dummy positional arguments against the signature is more
    accurate than counting parameters: ``(node, **kwargs)`` or a keyword-only
    second parameter cannot take ``key`` positionally, while ``(*args)`` can.
    """
    try:
        signature = inspect.signature(visitor)
    except ValueError:
        # Some C-implemented callables expose no introspectable signature;
        # fall back to the single-argument calling convention.
        return False
    try:
        signature.bind(None, None)
    except TypeError:
        return False
    return True


def _is_branch(value: Any) -> bool:
    """Return True if the traverser descends into *value* instead of visiting it.

    Empty lists are deliberately treated as leaves and handed to the visitor.
    """
    return isinstance(value, dict) or (isinstance(value, list) and len(value) > 0)


def build_traverser(*, visitor: Callable[[Any, Optional[str]], Any]) -> Callable[[Any], Any]:
    """Build a function that walks a nested value and applies *visitor* to its nodes.

    Traversal rules for the returned ``traverse(val)``:

    * A ``dict`` is rebuilt key by key. Each value is passed once to
      ``visitor(value, key)``. The visitor's result is then descended into if it
      is a ``dict`` or a non-empty ``list``; otherwise it is used as-is. The
      dict passed to ``traverse`` itself is never handed to the visitor.
    * A non-empty ``list`` is rebuilt by traversing each element. Elements are
      not handed to the visitor with a key.
    * Anything else, including an empty list, is passed to ``visitor(val)`` once
      and its result is returned.

    Every node is visited at most once, so a non-idempotent visitor behaves the
    same whether a leaf sits at the top level, in a list, or under a dict key.

    Args:
        visitor: Called as ``visitor(node, key)`` when it accepts a second
            positional argument, otherwise as ``visitor(node)``. ``key`` is the
            dict key the node was found under, or ``None``.

    Returns:
        A function that takes a value and returns the transformed value. The
        traversal builds new dicts/lists rather than mutating the input
        containers (the visitor itself may still mutate what it receives).
    """
    # TODO: pass node/key to the visitor as a mapping instead of positional
    # arguments, so more context can be added without breaking visitors.
    if _accepts_key(visitor):
        visit = visitor
    else:
        def visit(node, _key=None):
            return visitor(node)

    def traverse(val):
        if isinstance(val, dict):
            collector = {}
            for prop, child in val.items():
                visited = visit(child, prop)
                # The visitor has already seen this node; only descend into
                # containers so that leaves are not visited a second time.
                collector[prop] = traverse(visited) if _is_branch(visited) else visited
            return collector
        if isinstance(val, list) and len(val) > 0:
            return [traverse(item) for item in val]
        return visit(val)

    return traverse
class Test(unittest.TestCase):
    """Behavioural tests for ``build_traverser``."""

    @staticmethod
    def _email_visitor(node, key=None):
        """Replace any value stored under the ``'email'`` key with ``{'active': 1}``."""
        if key == 'email':
            return {'active': 1}
        return node

    def test_objects(self):
        email_traverser = build_traverser(visitor=self._email_visitor)
        obj = {'key': 1, 'email': {}}
        result = email_traverser(obj)
        expectation = {'key': 1, 'email': {'active': 1}}
        self.assertEqual(result, expectation)

    def test_nested_objects(self):
        email_traverser = build_traverser(visitor=self._email_visitor)
        result = email_traverser({'user': {'email': {}}})
        self.assertEqual(result, {'user': {'email': {'active': 1}}})

    def test_lists(self):
        email_traverser = build_traverser(visitor=self._email_visitor)
        arg = [{'email': {}}]
        result = email_traverser(arg)
        expectation = [{'email': {'active': 1}}]
        self.assertEqual(result, expectation)

    def test_primitives(self):
        def primitive_visitor(node):
            if node == 1:
                return 'ONE'
            if node == 'a':
                return 'A'
            return node

        primitive_traverse = build_traverser(visitor=primitive_visitor)
        tables = [
            (1, 'ONE'),
            ('a', 'A'),
        ]
        for arg, expectation in tables:
            # subTest reports every failing row instead of stopping at the first.
            with self.subTest(arg=arg):
                self.assertEqual(primitive_traverse(arg), expectation)
# Run the unittest suite defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment