-
-
Save JBrVJxsc/e7991665d14478cfe99311c5e966b3e0 to your computer and use it in GitHub Desktop.
Multithread Executor for Dagster
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import queue | |
import threading | |
from dagster import executor, Field, Int, check, DagsterEvent | |
from dagster.core.definitions.executor import check_cross_process_constraints | |
from dagster.core.events import EngineEventData | |
from dagster.core.execution.retries import get_retries_config, Retries | |
from dagster.core.executor.child_process_executor import ChildProcessEvent, ChildProcessSystemErrorEvent, ChildProcessCrashException, ChildProcessCommand, \ | |
ChildProcessDoneEvent, PROCESS_DEAD_AND_QUEUE_EMPTY, _poll_for_event, _execute_command_in_child_process | |
from dagster.core.executor.init import InitExecutorContext | |
from dagster.core.executor.multiprocess import InProcessExecutorChildProcessCommand, MultiprocessExecutor | |
# Marker attached to engine events while a step's worker is being spun up
# (follows the multiprocess executor's delegate-marker convention).
DELEGATE_MARKER = "multithread_subprocess_init"
@executor(
    name="multithread",
    config_schema={
        "max_concurrent": Field(Int, is_required=False, default_value=0),
        "retries": get_retries_config(),
    },
)
def multithread_executor(init_context):
    """Construct a :class:`MultithreadExecutor` from validated executor config.

    Args:
        init_context (InitExecutorContext): context carrying the pipeline and
            the resolved ``max_concurrent`` / ``retries`` config values.

    Returns:
        MultithreadExecutor: executor that runs steps in threads.
    """
    check.inst_param(init_context, "init_context", InitExecutorContext)
    check_cross_process_constraints(init_context)

    executor_config = init_context.executor_config
    return MultithreadExecutor(
        pipeline=init_context.pipeline,
        max_concurrent=executor_config["max_concurrent"],
        retries=Retries.from_config(executor_config["retries"]),
    )
class MultithreadExecutor(MultiprocessExecutor):
    """Executor that reuses ``MultiprocessExecutor``'s orchestration loop but
    runs each step's command in a thread of the current process rather than a
    child process.

    Only ``execute_step_out_of_process`` is overridden; step scheduling,
    retries, and termination handling are inherited from the base class.
    """

    def execute_step_out_of_process(self, step_context, step, errors, term_events):
        """Run a single step in a worker thread, yielding Dagster events.

        Args:
            step_context: per-step execution context; supplies the run config,
                pipeline run, and instance ref handed to the command.
            step: the execution step to run; ``step.key`` identifies it.
            errors (dict): out-parameter collecting error info from system
                errors raised by the worker, keyed by ``pid``.
            term_events (dict): per-step termination events, keyed by step key.
        """
        command = InProcessExecutorChildProcessCommand(
            run_config=step_context.run_config,
            pipeline_run=step_context.pipeline_run,
            step_key=step.key,
            instance_ref=step_context.instance.get_ref(),
            term_event=term_events[step.key],
            recon_pipeline=self.pipeline,
            retries=self.retries,
        )

        # FIX: the original message said "Launching subprocess", but this
        # executor launches a thread -- the old text was misleading in event logs.
        yield DagsterEvent.engine_event(
            step_context,
            "Launching thread for {}".format(step.key),
            EngineEventData(marker_start=DELEGATE_MARKER),
            step_key=step.key,
        )

        for ret in execute_command_in_thread(command):
            if ret is None or isinstance(ret, DagsterEvent):
                yield ret
            elif isinstance(ret, ChildProcessEvent):
                if isinstance(ret, ChildProcessSystemErrorEvent):
                    # NOTE(review): ret.pid is the *process* id, which all
                    # threads share -- concurrent failing steps may collide on
                    # this key. Kept as-is for base-class compatibility; verify
                    # against how MultiprocessExecutor consumes `errors`.
                    errors[ret.pid] = ret.error_info
            else:
                check.failed("Unexpected return value from child process {}".format(type(ret)))
def execute_command_in_thread(command):
    """Execute *command* in a worker thread, yielding the events it produces.

    The worker (``_execute_command_in_child_process``) pushes events onto a
    queue which is drained here via ``_poll_for_event``.  Yielded values may
    be ``None``, ``DagsterEvent`` instances, or ``ChildProcessEvent``
    instances; the caller is responsible for interpreting them.

    Raises:
        ChildProcessCrashException: if the worker dies before emitting a
            ``ChildProcessDoneEvent`` or ``ChildProcessSystemErrorEvent``.
    """
    check.inst_param(command, "command", ChildProcessCommand)

    event_queue = queue.Queue()
    thread = threading.Thread(
        target=_execute_command_in_child_process, args=(event_queue, command)
    )
    thread.start()
    try:
        completed_properly = False
        while not completed_properly:
            event = _poll_for_event(thread, event_queue)
            if event == PROCESS_DEAD_AND_QUEUE_EMPTY:
                break
            yield event
            if isinstance(event, (ChildProcessDoneEvent, ChildProcessSystemErrorEvent)):
                completed_properly = True

        if not completed_properly:
            # TODO Figure out what to do about stderr/stdout
            raise ChildProcessCrashException(exit_code=1)
    finally:
        # BUG FIX: the original called event_queue.task_done() here.
        # Queue.task_done() must pair with a Queue.get() (and Queue.join(),
        # which is never used in this file); with no unfinished tasks it
        # raises ValueError, masking an in-flight ChildProcessCrashException.
        # The correct cleanup is joining the worker thread -- the original's
        # thread.join() was unreachable on the crash path.  By this point the
        # worker has either finished or died, so join() returns promptly.
        thread.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment