Skip to content

Instantly share code, notes, and snippets.

@skrawcz
Last active August 6, 2022 22:22
Show Gist options
  • Save skrawcz/6fc5df9a5fd538e9b7070453658eeb6b to your computer and use it in GitHub Desktop.
Code to get Hamilton to run asynchronously -- no parallelization
from typing import Any
import pandas as pd
"""
Notes:
1. This file is used for all the [ray|dask|spark]/hello_world examples.
2. It therefore showcases how you can write something once and not only scale it, but also port it
to different frameworks with ease!
"""
import asyncio
import time
import logging
logger = logging.getLogger(__name__)
async def _log_result(result: Any) -> None:
    """Pretend to ship *result* to some remote logging system asynchronously."""
    logger.warning(f'started {result}')
    # Stand-in for the latency of a real logging backend.
    await asyncio.sleep(2)
    logger.warning(f'finished {result}')
async def log_result(avg_3wk_spend: pd.Series) -> dict:
    """Await the (simulated) logging of the rolling average and report success."""
    await _log_result(avg_3wk_spend)
    return {'result': True}
def avg_3wk_spend(spend: pd.Series) -> pd.Series:
    """Rolling 3 week average spend.

    :param spend: weekly spend values.
    :return: trailing 3-week mean of ``spend``; the first two entries are NaN
        because the rolling window is not yet full.
    """
    # Dead commented-out event-loop experiment code removed; this node is
    # intentionally synchronous (the graph adapter handles async nodes).
    return spend.rolling(3).mean()
def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:
    """The cost per signup in relation to spend."""
    cost_ratio = spend.div(signups)
    return cost_ratio
def spend_mean(spend: pd.Series) -> float:
    """Scalar-producing node: the mean over the entire spend column."""
    column_mean = spend.mean()
    return column_mean
def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:
    """Scalar-consuming node: spend shifted so its mean becomes zero."""
    return spend.sub(spend_mean)
def spend_std_dev(spend: pd.Series) -> float:
    """Sample standard deviation of the spend column."""
    deviation = spend.std()
    return deviation
def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:
    """Standardize spend: scale the zero-mean series by its standard deviation."""
    return spend_zero_mean.div(spend_std_dev)
import asyncio
import typing
from types import ModuleType
from typing import Union, Collection, Dict, Any, List
import pandas as pd
from fastapi import FastAPI
from hamilton import driver, base, graph, node
import my_functions
import logging
# Module-level logger and the FastAPI application that the routes below attach to.
logger = logging.getLogger(__name__)
app = FastAPI()
class AsyncDictResultBuilder(base.ResultMixin):
    """Awaitable result builder that hands back the outputs as a plain dict."""

    @staticmethod
    async def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Dict:
        """Return the mapping of output name -> computed value unchanged."""
        return outputs
class AsyncSimplePythonGraphAdapter(base.SimplePythonDataFrameGraphAdapter):
    """Graph adapter whose node execution and result building are awaitable.

    Coroutine node functions are awaited; plain functions are invoked directly.
    Result building is delegated to the ResultMixin supplied at construction.
    """

    def __init__(self, result_builder: base.ResultMixin):
        if result_builder is None:
            raise ValueError('You must provide a ResultMixin object for `result_builder`.')
        self.result_builder = result_builder

    async def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
        """Invoke the node's callable, awaiting it if it is a coroutine function."""
        fn = node.callable
        if not asyncio.iscoroutinefunction(fn):
            return fn(**kwargs)
        return await fn(**kwargs)

    async def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
        """Delegates to the result builder function supplied."""
        return await self.result_builder.build_result(**outputs)
class AsyncFunctionGraph(graph.FunctionGraph):
    """FunctionGraph subclass whose execution path is made of coroutines, so
    that `async def` node functions can be awaited via the adapter."""

    # NOTE(review): declared without @staticmethod and without `self`; it is only
    # ever called as AsyncFunctionGraph.async_execute_static(...), which works in
    # Python 3 because a plain function accessed on the class is just a function.
    async def async_execute_static(nodes: Collection[node.Node],
                                   inputs: Dict[str, Any],
                                   adapter: base.HamiltonGraphAdapter,
                                   computed: Dict[str, Any] = None,
                                   overrides: Dict[str, Any] = None):
        """Executes computation on the given graph, inputs, and memoized computation.

        Effectively this is a "private" function and should be viewed as such.
        To override a value, utilize `overrides`.
        To pass in a value to ensure we don't compute data twice, use `computed`.
        Don't use `computed` to override a value, you will not get the results you expect.

        :param nodes: the graph to traverse for execution.
        :param inputs: the inputs provided. These will only be called if a node is "user-defined"
        :param adapter: object that adapts execution based on context it knows about.
        :param computed: memoized storage to speed up computation. Usually an empty dict.
        :param overrides: any inputs we want to user to override actual computation
        :return: the passed in dict for memoized storage.
        """
        if overrides is None:
            overrides = {}
        if computed is None:
            computed = {}

        # Post-order DFS: compute every dependency of a node before the node
        # itself, memoizing results into `computed`.
        # NOTE(review): the parameter name `node` shadows the hamilton `node`
        # module inside this coroutine's body, so `node.DependencyType.OPTIONAL`
        # below resolves against the Node *instance* — confirm Node exposes
        # DependencyType (the default in the signature is evaluated in the
        # enclosing scope, where `node` is still the module, so it is fine).
        async def dfs_traverse(node: node.Node, dependency_type: node.DependencyType = node.DependencyType.REQUIRED):
            if node.name in computed:
                return  # already memoized
            if node.name in overrides:
                # Overrides win over actual computation.
                computed[node.name] = overrides[node.name]
                return
            for n in node.dependencies:
                if n.name not in computed:
                    _, node_dependency_type = node.input_types[n.name]
                    await dfs_traverse(n, node_dependency_type)

            logger.debug(f'Computing {node.name}.')
            if node.user_defined:
                if node.name not in inputs:
                    if dependency_type != node.DependencyType.OPTIONAL:
                        raise NotImplementedError(f'{node.name} was expected to be passed in but was not.')
                    return  # optional input not supplied: skip it
                value = inputs[node.name]
            else:
                kwargs = {}  # construct signature
                for dependency in node.dependencies:
                    if dependency.name in computed:
                        kwargs[dependency.name] = computed[dependency.name]
                try:
                    # Adapter awaits coroutine node functions; plain ones run synchronously.
                    value = await adapter.execute_node(node, kwargs)
                except Exception as e:
                    logger.exception(f'Node {node.name} encountered an error')
                    raise
            computed[node.name] = value

        for final_var_node in nodes:
            dep_type = node.DependencyType.REQUIRED
            if final_var_node.user_defined:
                # from the top level, we don't know if this UserInput is required. So mark as optional.
                dep_type = node.DependencyType.OPTIONAL
            await dfs_traverse(final_var_node, dep_type)
        return computed

    async def execute(self,
                      nodes: Collection[node.Node] = None,
                      computed: Dict[str, Any] = None,
                      overrides: Dict[str, Any] = None,
                      inputs: Dict[str, Any] = None
                      ) -> Dict[str, Any]:
        """Executes the DAG, given potential inputs/previously computed components.

        :param nodes: Nodes to compute
        :param computed: Nodes that have already been computed
        :param overrides: Overrides for nodes in the DAG
        :param inputs: Inputs to the DAG -- have to be disjoint from config.
        :return: The result of executing the DAG (a dict of node name to node result)
        """
        if nodes is None:
            nodes = self.get_nodes()
        if inputs is None:
            inputs = {}
        # Merge config values with runtime inputs (inherited helper) before the
        # async traversal.
        return await AsyncFunctionGraph.async_execute_static(
            nodes=nodes,
            inputs=AsyncFunctionGraph.combine_config_and_inputs(self.config, inputs),
            adapter=self.adapter,
            computed=computed,
            overrides=overrides,
        )
class AsyncDriver(driver.Driver):
    """Hamilton driver whose execute path is fully awaitable."""

    def __init__(self, config: Dict[str, Any], *modules: ModuleType, adapter: base.HamiltonGraphAdapter = None):
        """Constructor: creates a DAG given the configuration & modules to crawl.

        :param config: This is a dictionary of initial data & configuration.
            The contents are used to help create the DAG.
        :param modules: Python module objects you want to inspect for Hamilton Functions.
        :param adapter: Optional. A way to wire in another way of "executing" a hamilton graph.
            Defaults to using original Hamilton adapter which is single threaded in memory python.
        """
        super().__init__(config, *modules, adapter=adapter)
        try:
            # Replace the graph built by the base class with the async-aware
            # subclass so execute() can await coroutine nodes.
            self.graph = AsyncFunctionGraph(*modules, config=config, adapter=self.adapter)
        except Exception:
            logger.error(driver.SLACK_ERROR_MESSAGE)
            raise  # bare raise keeps the original traceback intact

    async def execute(self,
                      final_vars: List[str],
                      overrides: Dict[str, Any] = None,
                      display_graph: bool = False,
                      inputs: Dict[str, Any] = None,
                      ) -> Any:
        """Executes computation.

        :param final_vars: the final list of variables we want to compute.
        :param overrides: values that will override "nodes" in the DAG.
        :param display_graph: DEPRECATED. Whether we want to display the graph being computed.
        :param inputs: Runtime inputs to the DAG.
        :return: an object consisting of the variables requested, matching the type returned by the GraphAdapter.
            See constructor for how the GraphAdapter is initialized. The default one right now returns a pandas
            dataframe.
        """
        if display_graph:
            logger.warning('display_graph=True is deprecated. It will be removed in the 2.0.0 release. '
                           'Please use visualize_execution().')
        try:
            outputs = await self.raw_execute(final_vars, overrides, display_graph, inputs=inputs)
            # Shape the raw node outputs via the adapter's (awaitable) result builder.
            return await self.adapter.build_result(**outputs)
        except Exception:
            logger.error(driver.SLACK_ERROR_MESSAGE)
            raise

    async def raw_execute(self,
                          final_vars: List[str],
                          overrides: Dict[str, Any] = None,
                          display_graph: bool = False,
                          inputs: Dict[str, Any] = None) -> Dict[str, Any]:
        """Raw execute function that does the meat of execute.

        It does not try to stitch anything together. Thus allowing wrapper executes around this to shape the output
        of the data.

        :param final_vars: Final variables to compute
        :param overrides: Overrides to run.
        :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
        :param inputs: Runtime inputs to the DAG
        :return: dict of node name -> computed value, restricted to final_vars.
        """
        nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
        self.validate_inputs(user_nodes, inputs)  # TODO -- validate within the function graph itself
        if display_graph:  # deprecated flow.
            logger.warning('display_graph=True is deprecated. It will be removed in the 2.0.0 release. '
                           'Please use visualize_execution().')
            self.visualize_execution(final_vars, 'test-output/execute.gv', {'view': True})
        if self.has_cycles(final_vars):  # here for backwards compatible driver behavior.
            # Fixed typo in the user-facing message: 'in you graph' -> 'in your graph'.
            raise ValueError('Error: cycles detected in your graph.')
        memoized_computation = dict()  # memoized storage
        computed = await self.graph.execute(nodes, memoized_computation, overrides, inputs)
        outputs = {c: computed[c] for c in final_vars}  # only want request variables in df.
        del memoized_computation  # trying to cleanup some memory
        return outputs
# Wire the dict result builder into the async adapter, then build the async
# driver over the my_functions module with an empty config.
async_adapter = AsyncSimplePythonGraphAdapter(AsyncDictResultBuilder())
dr = AsyncDriver({}, my_functions, adapter=async_adapter)
@app.get("/")
async def read_root():
result = await dr.execute(['log_result'], inputs={'spend': pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])})
return result
# @app.get("/items/{item_id}")
# async def read_item(item_id: int, q: Union[str, None] = None):
# return {"item_id": item_id, "q": q}
# Script entry point: serve the FastAPI app with uvicorn when run directly.
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment