Created August 18, 2020 17:05
-
-
Save pdxjohnny/8da8f0c88d9fadeb8b9af000081fa1c9 to your computer and use it in GitHub Desktop.
Custom input set contexts with dict notation for dataflow run (allow BaseInputSetContext instances, not just strings, as the dict keys)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/dffml/df/memory.py b/dffml/df/memory.py | |
index ee6bcf4e..c5b74f62 100644 | |
--- a/dffml/df/memory.py | |
+++ b/dffml/df/memory.py | |
@@ -1288,6 +1288,10 @@ class MemoryOrchestratorContext(BaseOrchestratorContext): | |
ctx: Optional[BaseInputSetContext] = None, | |
input_set: Optional[Union[List[Input], BaseInputSet]] = None, | |
) -> BaseInputSetContext: | |
+ if ctx is not None and not isinstance(ctx, BaseInputSetContext): | |
+ raise TypeError( | |
+ f"ctx {ctx} is of type {type(ctx)}, should be BaseInputSetContext" | |
+ ) | |
self.logger.debug("Seeding dataflow with input_set: %s", input_set) | |
if input_set is None: | |
# Create a list if extra inputs were not given | |
@@ -1378,7 +1382,9 @@ class MemoryOrchestratorContext(BaseOrchestratorContext): | |
await self.forward_inputs_to_subflow(input_set) | |
ctxs.append( | |
await self.seed_inputs( | |
- ctx=StringInputSetContext(ctx_string), | |
+ ctx=StringInputSetContext(ctx_string) | |
+ if isinstance(ctx_string, str) | |
+ else ctx_string, | |
input_set=input_set, | |
) | |
) | |
diff --git a/tests/test_df.py b/tests/test_df.py | |
index d496453f..012d01d6 100644 | |
--- a/tests/test_df.py | |
+++ b/tests/test_df.py | |
@@ -145,6 +145,10 @@ class TestMemoryOperationImplementationNetwork(AsyncTestCase): | |
await ctx.run(None, None, add.op, {"numbers": [40, 2]}) | |
+class CustomInputSetContext(StringInputSetContext): | |
+ pass | |
+ | |
+ | |
class TestOrchestrator(AsyncTestCase): | |
""" | |
create_octx and run exist so that we can subclass from them in | |
@@ -163,7 +167,11 @@ class TestOrchestrator(AsyncTestCase): | |
async def test_run(self): | |
calc_strings_check = {"add 40 and 2": 42, "multiply 42 and 10": 420} | |
# TODO(p0) Implement and test asyncgenerator | |
- callstyles_no_expand = ["asyncgenerator", "dict"] | |
+ callstyles_no_expand = [ | |
+ "asyncgenerator", | |
+ "dict", | |
+ "dict_custom_input_set_context", | |
+ ] | |
callstyles = { | |
"dict": { | |
to_calc: [ | |
@@ -177,6 +185,18 @@ class TestOrchestrator(AsyncTestCase): | |
] | |
for to_calc in calc_strings_check.keys() | |
}, | |
+ "dict_custom_input_set_context": { | |
+ CustomInputSetContext(to_calc): [ | |
+ Input( | |
+ value=to_calc, definition=parse_line.op.inputs["line"] | |
+ ), | |
+ Input( | |
+ value=[add.op.outputs["sum"].name], | |
+ definition=GetSingle.op.inputs["spec"], | |
+ ), | |
+ ] | |
+ for to_calc in calc_strings_check.keys() | |
+ }, | |
"list_input_sets": [ | |
MemoryInputSet( | |
MemoryInputSetConfig( | |
@@ -228,6 +248,10 @@ class TestOrchestrator(AsyncTestCase): | |
), | |
) | |
else: | |
+ if callstyle == "dict_custom_input_set_context": | |
+ self.assertTrue( | |
+ isinstance(ctx, CustomInputSetContext) | |
+ ) | |
self.assertEqual( | |
calc_strings_check[ctx_str], | |
results[add.op.outputs["sum"].name], |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.