Last active
November 23, 2018 16:02
-
-
Save justinfay/243bf00f3e83e84b23fcc968689832d6 to your computer and use it in GitHub Desktop.
A simple graph query using generators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def id_generator(start=0):
    """
    Return an infinite iterator of consecutive integer IDs.

    :param start: first ID produced (default 0, matching the original
        behaviour).
    :returns: an iterator yielding ``start, start + 1, start + 2, ...``;
        each call returns a fresh, independent iterator.
    """
    # itertools.count already is exactly this counter; no closure needed.
    import itertools
    return itertools.count(start)
class Graph:
    """
    An in-memory directed graph whose vertices are plain dicts.

    Vertex dicts are annotated in place with ``_id`` (their key in
    ``self.vertices``) and with ``_in`` / ``_out`` lists holding the
    adjacent vertex dicts.
    """

    def __init__(self, ids=None):
        """
        Create an empty graph.

        :param ids: iterator that supplies vertex IDs via ``next()``;
            defaults to a fresh ``id_generator()``.
        """
        if ids is None:
            ids = id_generator()
        self.id_generator = ids
        self.vertices = {}  # vertex ID -> vertex dict
        self.edges = []  # (in_, out) vertex-dict pairs, in insertion order

    def add_edge(self, in_, out):
        """
        Add a directed edge from ``in_`` to ``out``.

        Both vertices must already have been added with ``add_vertex``;
        they are resolved through their ``_id`` so the stored dicts are
        linked. ``out`` is appended to ``in_['_out']``, ``in_`` is appended
        to ``out['_in']`` and the pair is recorded in ``self.edges``.

        :raises KeyError: if either vertex has no ``_id`` or its ID is not
            registered in this graph.
        """
        in_ = self.get_vertex_by_id(in_['_id'])
        out = self.get_vertex_by_id(out['_id'])
        in_['_out'].append(out)
        out['_in'].append(in_)
        # Keep the edge list in sync with the adjacency lists.
        self.edges.append((in_, out))

    def add_vertex(self, vertex):
        """
        Add a vertex to the graph.

        The dict is mutated in place: ``_in`` and ``_out`` are set to new
        empty lists and ``_id`` to the next value from
        ``self.id_generator``. Adding the same dict twice overwrites those
        keys and registers it under a second ID.
        """
        vertex['_in'] = []
        vertex['_out'] = []
        id_ = next(self.id_generator)
        vertex['_id'] = id_
        self.vertices[id_] = vertex

    def add_vertices(self, *vertices):
        """
        Add each of ``vertices`` in order via ``add_vertex``.
        """
        for vertex in vertices:
            self.add_vertex(vertex)

    def get_vertex_by_id(self, id_):
        """
        Get the vertex with the given ID.

        :raises KeyError: if no vertex has that ID.
        """
        return self.vertices[id_]

    def query(self):
        """
        Return a new ``Query`` whose first stage iterates over all vertices.
        """
        return Query(self)
class Query:
    """
    A lazy, chainable query on the graph.

    Any name registered in ``PIPETYPES`` can be called as a method; each
    call wraps the current last stage's iterator in a new stage and returns
    ``self`` so calls chain. Iterating the query pulls items from the last
    stage.
    """

    def __init__(self, graph):
        # Stage 0 iterates over every vertex dict stored in the graph.
        self.steps = [iter(graph.vertices.values())]

    def __getattr__(self, attr):
        """
        Resolve an otherwise-missing attribute as a ``PIPETYPES`` stage.

        :returns: a function that appends the stage (called with the
            previous stage's iterator plus the given arguments) and returns
            ``self``.
        :raises AttributeError: for ``_``-prefixed names and for names not
            registered in ``PIPETYPES``, so ``hasattr`` and ``copy`` work.
        """
        # Private/special names are never stages. Protocol lookups such as
        # ``__setstate__`` (made by ``copy.copy``) must see AttributeError,
        # not an exception from the registry lookup.
        if attr.startswith('_'):
            raise AttributeError(
                '%r object has no attribute %r' % (type(self).__name__, attr))
        try:
            pipetype = PIPETYPES[attr]
        except KeyError:
            raise AttributeError(
                '%r object has no attribute %r' % (type(self).__name__, attr)
            ) from None

        def add_stage(*args, **kwargs):
            self.steps.append(pipetype(self.steps[-1], *args, **kwargs))
            return self
        return add_stage

    def __iter__(self):
        return self

    def __next__(self):
        # Only the last stage is pulled; it drives all earlier stages.
        return next(self.steps[-1])
def all_(pipeline):
    """
    Pass every item of ``pipeline`` through unchanged.

    This is the identity stage registered as ``all`` in ``PIPETYPES``.
    """
    yield from pipeline
def filter_attrs(pipeline, **kwargs):
    """
    Yield only the items whose attributes match every keyword given.

    An item matches when, for every ``key=value`` in ``kwargs``, it has
    ``key`` and ``item[key] == value``. Items lacking a key do not match
    (they are skipped rather than raising ``KeyError``). With no keywords
    every item passes.

    Example::

        graph.query().filter_attrs(name='jim')
    """
    for obj in pipeline:
        # Check all keys before deciding; the item is dropped on the first
        # key that is missing or has a different value.
        if all(k in obj and obj[k] == v for k, v in kwargs.items()):
            yield obj
def parents(pipeline):
    """
    Yield the ``_out`` neighbours of each vertex in ``pipeline``, in order.

    ``Graph.add_edge(a, b)`` appends ``b`` to ``a['_out']``, so this stage
    follows edges forward from ``a`` to ``b``.
    """
    for vertex in pipeline:
        for neighbour in vertex['_out']:
            yield neighbour
def children(pipeline):
    """
    Yield the ``_in`` neighbours of each vertex in ``pipeline``, in order.

    ``Graph.add_edge(a, b)`` appends ``a`` to ``b['_in']``, so this stage
    follows edges backward from ``b`` to ``a``.
    """
    yield from (
        neighbour
        for vertex in pipeline
        for neighbour in vertex['_in']
    )
def select_attrs(pipeline, *fields):
    """
    Project each item onto ``fields``.

    Yields one new dict per item, containing only the requested fields in
    the order given. Raises ``KeyError`` if an item lacks a field.
    """
    for vertex in pipeline:
        projection = {}
        for name in fields:
            projection[name] = vertex[name]
        yield projection
# Registry of pipeline stages exposed as chainable Query methods.
# Query.__getattr__ looks a method name up here and calls the stage as
# stage(previous_stage_iterator, *args, **kwargs); each stage is a generator.
PIPETYPES = dict(
    all=all_,
    filter_attrs=filter_attrs,
    parents=parents,
    children=children,
    select_attrs=select_attrs,
)
def test():
    """
    Build a four-vertex demo graph and print a sample query's result.

    Edges: jim -> bob, bob -> tim, bob -> pat. The query selects the vertex
    named 'bob', follows its ``parents`` stage and prints the ``name``
    field of each result.
    """
    graph = Graph()
    jim, bob, tim, pat = ({'name': who} for who in ('jim', 'bob', 'tim', 'pat'))
    graph.add_vertices(jim, bob, tim, pat)
    for source, target in ((jim, bob), (bob, tim), (bob, pat)):
        graph.add_edge(source, target)
    result = (
        graph.query()
        .filter_attrs(name='bob')
        .parents()
        .select_attrs('name')
    )
    print(list(result))


if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment