Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Support custom input set contexts as dict keys when running a dataflow
diff --git a/dffml/df/memory.py b/dffml/df/memory.py
index ee6bcf4e..c5b74f62 100644
--- a/dffml/df/memory.py
+++ b/dffml/df/memory.py
@@ -1288,6 +1288,10 @@ class MemoryOrchestratorContext(BaseOrchestratorContext):
ctx: Optional[BaseInputSetContext] = None,
input_set: Optional[Union[List[Input], BaseInputSet]] = None,
) -> BaseInputSetContext:
+ if ctx is not None and not isinstance(ctx, BaseInputSetContext):
+ raise TypeError(
+ f"ctx {ctx} is of type {type(ctx)}, should be BaseInputSetContext"
+ )
self.logger.debug("Seeding dataflow with input_set: %s", input_set)
if input_set is None:
# Create a list if extra inputs were not given
@@ -1378,7 +1382,9 @@ class MemoryOrchestratorContext(BaseOrchestratorContext):
await self.forward_inputs_to_subflow(input_set)
ctxs.append(
await self.seed_inputs(
- ctx=StringInputSetContext(ctx_string),
+ ctx=StringInputSetContext(ctx_string)
+ if isinstance(ctx_string, str)
+ else ctx_string,
input_set=input_set,
)
)
diff --git a/tests/test_df.py b/tests/test_df.py
index d496453f..012d01d6 100644
--- a/tests/test_df.py
+++ b/tests/test_df.py
@@ -145,6 +145,10 @@ class TestMemoryOperationImplementationNetwork(AsyncTestCase):
await ctx.run(None, None, add.op, {"numbers": [40, 2]})
+class CustomInputSetContext(StringInputSetContext):
+ pass
+
+
class TestOrchestrator(AsyncTestCase):
"""
create_octx and run exist so that we can subclass from them in
@@ -163,7 +167,11 @@ class TestOrchestrator(AsyncTestCase):
async def test_run(self):
calc_strings_check = {"add 40 and 2": 42, "multiply 42 and 10": 420}
# TODO(p0) Implement and test asyncgenerator
- callstyles_no_expand = ["asyncgenerator", "dict"]
+ callstyles_no_expand = [
+ "asyncgenerator",
+ "dict",
+ "dict_custom_input_set_context",
+ ]
callstyles = {
"dict": {
to_calc: [
@@ -177,6 +185,18 @@ class TestOrchestrator(AsyncTestCase):
]
for to_calc in calc_strings_check.keys()
},
+ "dict_custom_input_set_context": {
+ CustomInputSetContext(to_calc): [
+ Input(
+ value=to_calc, definition=parse_line.op.inputs["line"]
+ ),
+ Input(
+ value=[add.op.outputs["sum"].name],
+ definition=GetSingle.op.inputs["spec"],
+ ),
+ ]
+ for to_calc in calc_strings_check.keys()
+ },
"list_input_sets": [
MemoryInputSet(
MemoryInputSetConfig(
@@ -228,6 +248,10 @@ class TestOrchestrator(AsyncTestCase):
),
)
else:
+ if callstyle == "dict_custom_input_set_context":
+ self.assertTrue(
+ isinstance(ctx, CustomInputSetContext)
+ )
self.assertEqual(
calc_strings_check[ctx_str],
results[add.op.outputs["sum"].name],
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.