Created
August 4, 2020 17:50
-
-
Save pdxjohnny/7a4dfdfd4944ac33fa4a0848739bc532 to your computer and use it in GitHub Desktop.
DataFlow source accumulator operation: partial modifications to the source for record(), etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/dffml/source/df.py b/dffml/source/df.py | |
index 380e6312..af9e44c8 100644 | |
--- a/dffml/source/df.py | |
+++ b/dffml/source/df.py | |
@@ -17,17 +17,70 @@ class DataFlowSourceConfig: | |
source: BaseSource | |
dataflow: DataFlow | |
features: Features | |
+ length: str = field("Definition name to add as source length", default=None) | |
orchestrator: BaseOrchestrator = MemoryOrchestrator.withconfig({}) | |
+# Declaration of the accumulator example operation: it takes one | |
+# "stop_words" value plus the source length and outputs a list under "all". | |
+example = Operation( | |
+    name="example", | |
+    inputs={ | |
+        "stop_words": Definition("stop_words", "string"), | |
+        # Primitive is "int" so that it agrees with the Input the source | |
+        # context builds for the length definition (primitive="int") and | |
+        # with the len() comparison done in ExampleContext.run. | |
+        "length": Definition("source_length", "int"), | |
+    }, | |
+    outputs={"all": Definition("all_sentences", "List[string]")}, | |
+    conditions=[], | |
+) | |
+ | |
+ | |
+class ExampleContext(OperationImplementationContext): | |
+    """Collect one "stop_words" value per run and return the full list. | |
+ | |
+    State (lock, event, length, list) lives on ``self.parent`` so that it | |
+    is shared by every run of the operation. | |
+    """ | |
+ | |
+    async def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |
+        """Append this run's value, then wait until all values have arrived. | |
+ | |
+        Returns a dict holding the shared list under the "all" output key. | |
+        """ | |
+        async with self.parent.lock: | |
+            if self.parent.length is None: | |
+                # The operation declares this input under the key "length"; | |
+                # "source_length" is only the name of its Definition. | |
+                self.parent.length = inputs["length"] | |
+            self.parent.list.append(inputs["stop_words"]) | |
+ | |
+            # The run that delivers the last expected value releases | |
+            # every run waiting on the event below. | |
+            if len(self.parent.list) == self.parent.length: | |
+                self.parent.event.set() | |
+ | |
+        # Wait outside of the lock so that other runs can still append. | |
+        await self.parent.event.wait() | |
+ | |
+        # The operation declares its output under the key "all"; | |
+        # "all_sentences" is only the name of its Definition. | |
+        return {"all": self.parent.list} | |
+ | |
+ | |
+class Example(OperationImplementation): | |
+    """Implementation of ``example`` holding the state its contexts share.""" | |
+ | |
+    op = example | |
+    CONTEXT = ExampleContext | |
+ | |
+    def __init__(self, *args, **kwargs): | |
+        super().__init__(*args, **kwargs) | |
+        # The asyncio primitives are created in __aenter__. | |
+        self.lock = None | |
+        self.event = None | |
+        # Expected number of values; taken from the first run's inputs. | |
+        self.length = None | |
+        # Values accumulated across runs. | |
+        self.list = [] | |
+ | |
+    async def __aenter__(self) -> "Example": | |
+        """Create the lock and event used by the contexts; return self.""" | |
+        self.lock = asyncio.Lock() | |
+        self.event = asyncio.Event() | |
+        return self | |
+ | |
+    async def __aexit__(self, exc_type, exc_value, traceback): | |
+        """Drop both primitives created in __aenter__.""" | |
+        self.lock = None | |
+        self.event = None | |
+ | |
class DataFlowSourceContext(BaseSourceContext): | |
async def update(self, record: Record): | |
+ """Pass the record on to the update() of the source context in self.sctx.""" | |
await self.sctx.update(record) | |
- async def records(self) -> AsyncIterator[Record]: | |
- async for record in self.sctx.records(): | |
+ # TODO Implement this method. We forgot to implement it when we initially | |
+ # added the DataFlowSourceContext | |
+ async def record(self, key: str) -> AsyncIterator[Record]: | |
+ if self.parent.config.all_for_single: | |
+ async for ctx, result in self.records(): | |
+ if (await ctx.handle()).as_string() == key: | |
+ yield record | |
+ else: | |
async for ctx, result in self.octx.run( | |
- [ | |
+ RecordContext(record): [ | |
Input( | |
value=record.feature(feature.name), | |
definition=Definition( | |
@@ -35,12 +88,44 @@ class DataFlowSourceContext(BaseSourceContext): | |
), | |
) | |
for feature in self.parent.config.features | |
- ] | |
+ ] + ([] if not self.config.length else [ | |
+ Input( | |
+ value=await self.sctx.length() | |
+ definition=Definition( | |
+ name=self.config.length, primitive="int" | |
+ ), | |
+ ) | |
+ ]) | |
+ async for record in [self.sctx.record(key)] | |
): | |
if result: | |
record.evaluated(result) | |
yield record | |
+    async def records(self) -> AsyncIterator[Record]: | |
+        """Run the dataflow over all records of self.sctx in a single run. | |
+ | |
+        Each record gets its own RecordContext whose inputs are the | |
+        record's configured features, plus the source length when the | |
+        ``length`` definition name is configured. Yields every record | |
+        the orchestrator reports back, evaluated with its result if any. | |
+        """ | |
+        config = self.parent.config | |
+        # Only ask the source for its length when it is going to be used. | |
+        length = await self.sctx.length() if config.length else None | |
+        # Inputs for the orchestrator, keyed by per-record context. | |
+        inputs_by_ctx = {} | |
+        # Remembers which record each context was created for. | |
+        record_by_ctx = {} | |
+        async for record in self.sctx.records(): | |
+            inputs = [ | |
+                Input( | |
+                    value=record.feature(feature.name), | |
+                    definition=Definition( | |
+                        name=feature.name, primitive=str(feature.dtype()) | |
+                    ), | |
+                ) | |
+                for feature in config.features | |
+            ] | |
+            if config.length: | |
+                inputs.append( | |
+                    Input( | |
+                        value=length, | |
+                        definition=Definition( | |
+                            name=config.length, primitive="int" | |
+                        ), | |
+                    ) | |
+                ) | |
+            ctx = RecordContext(record) | |
+            inputs_by_ctx[ctx] = inputs | |
+            record_by_ctx[ctx] = record | |
+        async for ctx, result in self.octx.run(inputs_by_ctx): | |
+            # NOTE(review): assumes the orchestrator yields back the same | |
+            # context objects that were passed in as keys - TODO confirm. | |
+            record = record_by_ctx[ctx] | |
+            if result: | |
+                record.evaluated(result) | |
+            yield record | |
+ | |
+ | |
async def __aenter__(self) -> "DataFlowSourceContext": | |
self.sctx = await self.parent.source().__aenter__() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment