Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
DataFlow source accumulator operation: partial modifications to the source for record(), etc.
diff --git a/dffml/source/df.py b/dffml/source/df.py
index 380e6312..af9e44c8 100644
--- a/dffml/source/df.py
+++ b/dffml/source/df.py
@@ -17,17 +17,70 @@ class DataFlowSourceConfig:
source: BaseSource
dataflow: DataFlow
features: Features
+ length: str = field("Definition name to add as source length", default=None)
orchestrator: BaseOrchestrator = MemoryOrchestrator.withconfig({})
+ example = Operation(
+     name="example",
+     inputs={
+         "stop_words": Definition("stop_words", "string"),
+         # The source length is an integer: DataFlowSourceContext feeds it in
+         # with primitive="int", so the definitions must agree to be matched.
+         "length": Definition("source_length", "int"),
+     },
+     outputs={"all": Definition("all_sentences", "List[string]")},
+     conditions=[],
+ )
+
+
+ class ExampleContext(OperationImplementationContext):
+     async def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Accumulate one ``stop_words`` value per input-set context and block
+         until ``length`` values have been collected, then return all of them.
+
+         ``inputs`` is keyed by the input names declared on ``example``
+         (``stop_words`` and ``length``); the returned dict is keyed by the
+         output name declared on ``example`` (``all``).
+         """
+         parent = self.parent
+         async with parent.lock:
+             if parent.length is None:
+                 parent.length = inputs["length"]
+             parent.list.append(inputs["stop_words"])
+             # Checked while holding the lock so the count cannot change
+             # underneath us; ">=" (not "==") avoids waiting forever should
+             # more values than expected ever be appended.
+             if len(parent.list) >= parent.length:
+                 parent.event.set()
+         # Wait outside the lock so the remaining contexts can still append.
+         await parent.event.wait()
+         # Give every context its own copy so no caller can mutate the shared
+         # accumulator that the other contexts are returning.
+         return {"all": list(parent.list)}
+
+
+ class Example(OperationImplementation):
+     """
+     Implementation of the ``example`` accumulator operation.
+
+     Holds the state shared by every ``ExampleContext`` created while this
+     implementation is entered: the collected values, the expected count, and
+     the lock/event the contexts coordinate with.
+     """
+
+     op = example
+     CONTEXT = ExampleContext
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.lock = None
+         self.length = None
+         self.event = None
+         self.list = []
+
+     async def __aenter__(self) -> "Example":
+         # Build the asyncio primitives inside the running async context
+         # rather than in __init__, and reset the accumulator so re-entering
+         # does not start with values or an already-set event from a prior run.
+         self.lock = asyncio.Lock()
+         self.event = asyncio.Event()
+         self.length = None
+         self.list = []
+         return self
+
+     async def __aexit__(self, exc_type, exc_value, traceback):
+         # Drop all per-run state.
+         self.lock = None
+         self.event = None
+         self.length = None
+         self.list = []
+
+
class DataFlowSourceContext(BaseSourceContext):
async def update(self, record: Record):
+ # Updates are passed straight through to the wrapped source context.
await self.sctx.update(record)
- async def records(self) -> AsyncIterator[Record]:
- async for record in self.sctx.records():
+ # TODO Implement this method. We forgot to implement it when we initially
+ # added the DataFlowSourceContext
+ async def record(self, key: str) -> AsyncIterator[Record]:
+ # Look up the record stored under ``key`` and apply the dataflow's results.
+ # NOTE(review): this is written as an async generator (it ``yield``s), but a
+ # source's ``record(key)`` presumably returns a single Record -- confirm
+ # against BaseSourceContext before merging.
+ if self.parent.config.all_for_single:
+ # NOTE(review): ``all_for_single`` is not among the config fields visible
+ # in this hunk (source, dataflow, features, length, orchestrator) -- add it
+ # or drop this branch.
+ # Presumably: run the flow over every record so accumulating operations
+ # see them all, then pick out the requested key.
+ # NOTE(review): ``self.records()`` yields Record objects, not
+ # ``(ctx, result)`` pairs, and ``record`` is never bound in this branch.
+ async for ctx, result in self.records():
+ if (await ctx.handle()).as_string() == key:
+ yield record
+ else:
+ # NOTE(review): as written the call below does not parse or run -- the
+ # ``RecordContext(record): [...]`` key/value pair needs enclosing ``{}``,
+ # ``value=await self.sctx.length()`` is missing its trailing comma,
+ # ``self.config`` should be ``self.parent.config`` as used elsewhere in
+ # this class, and a plain list such as ``[self.sctx.record(key)]`` cannot
+ # be iterated with ``async for``.
async for ctx, result in self.octx.run(
- [
+ RecordContext(record): [
Input(
value=record.feature(feature.name),
definition=Definition(
@@ -35,12 +88,44 @@ class DataFlowSourceContext(BaseSourceContext):
),
)
for feature in self.parent.config.features
- ]
+ ] + ([] if not self.config.length else [
+ Input(
+ value=await self.sctx.length()
+ definition=Definition(
+ name=self.config.length, primitive="int"
+ ),
+ )
+ ])
+ async for record in [self.sctx.record(key)]
):
if result:
record.evaluated(result)
yield record
+     async def records(self) -> AsyncIterator[Record]:
+         """
+         Yield every record from the wrapped source after running the dataflow
+         over all of them in a single orchestrator run.
+
+         All records are submitted together, one input-set context per record
+         key, so operations that accumulate across records (such as
+         ``example``) can observe every record before any of them completes.
+         """
+         config = self.parent.config
+         # The source length is the same for every record: fetch it once.
+         source_length = await self.sctx.length() if config.length else None
+         by_key = {}
+         input_sets = {}
+         async for record in self.sctx.records():
+             by_key[record.key] = record
+             record_inputs = [
+                 Input(
+                     value=record.feature(feature.name),
+                     definition=Definition(
+                         name=feature.name, primitive=str(feature.dtype())
+                     ),
+                 )
+                 for feature in config.features
+             ]
+             if config.length:
+                 record_inputs.append(
+                     Input(
+                         value=source_length,
+                         definition=Definition(
+                             name=config.length, primitive="int"
+                         ),
+                     )
+                 )
+             input_sets[record.key] = record_inputs
+         async for ctx, result in self.octx.run(input_sets):
+             # Map each finished context back to the record it was created
+             # for; a comprehension's loop variable is not visible here.
+             record = by_key[(await ctx.handle()).as_string()]
+             if result:
+                 record.evaluated(result)
+             yield record
+
async def __aenter__(self) -> "DataFlowSourceContext":
self.sctx = await self.parent.source().__aenter__()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment