Skip to content

Instantly share code, notes, and snippets.

@JBrVJxsc
Created December 10, 2020 06:20
Show Gist options
  • Save JBrVJxsc/e7991665d14478cfe99311c5e966b3e0 to your computer and use it in GitHub Desktop.
Multithread Executor for Dagster
import queue
import threading
from dagster import executor, Field, Int, check, DagsterEvent
from dagster.core.definitions.executor import check_cross_process_constraints
from dagster.core.events import EngineEventData
from dagster.core.execution.retries import get_retries_config, Retries
from dagster.core.executor.child_process_executor import ChildProcessEvent, ChildProcessSystemErrorEvent, ChildProcessCrashException, ChildProcessCommand, \
ChildProcessDoneEvent, PROCESS_DEAD_AND_QUEUE_EMPTY, _poll_for_event, _execute_command_in_child_process
from dagster.core.executor.init import InitExecutorContext
from dagster.core.executor.multiprocess import InProcessExecutorChildProcessCommand, MultiprocessExecutor
DELEGATE_MARKER = "multithread_subprocess_init"
@executor(
    name="multithread",
    config_schema={
        "max_concurrent": Field(Int, is_required=False, default_value=0),
        "retries": get_retries_config(),
    },
)
def multithread_executor(init_context):
    """Build a :class:`MultithreadExecutor` from the executor init context.

    Validates the context, enforces cross-process constraints, and wires the
    ``max_concurrent`` and ``retries`` settings from the executor config into
    the executor instance.
    """
    check.inst_param(init_context, "init_context", InitExecutorContext)
    check_cross_process_constraints(init_context)
    executor_config = init_context.executor_config
    return MultithreadExecutor(
        pipeline=init_context.pipeline,
        max_concurrent=executor_config["max_concurrent"],
        retries=Retries.from_config(executor_config["retries"]),
    )
class MultithreadExecutor(MultiprocessExecutor):
    """Executor that runs each step on a thread instead of a child process.

    Inherits the orchestration loop from ``MultiprocessExecutor`` and
    overrides only the per-step launch hook so the "child process" command
    is executed via ``execute_command_in_thread``.
    """

    def execute_step_out_of_process(self, step_context, step, errors, term_events):
        """Execute a single step on a background thread, yielding its events.

        Args:
            step_context: Context for the step being executed.
            step: The execution step; ``step.key`` identifies it.
            errors (dict): Populated with ``error_info`` keyed by pid when a
                ``ChildProcessSystemErrorEvent`` is received.
            term_events (dict): Per-step termination events keyed by step key;
                this step's event is handed to the command so it can be
                cancelled.

        Yields:
            DagsterEvent (or None ticks) produced while the step runs.
        """
        command = InProcessExecutorChildProcessCommand(
            run_config=step_context.run_config,
            pipeline_run=step_context.pipeline_run,
            step_key=step.key,
            instance_ref=step_context.instance.get_ref(),
            term_event=term_events[step.key],
            recon_pipeline=self.pipeline,
            retries=self.retries,
        )
        # Bug fix: this executor delegates to a *thread*, not a subprocess;
        # the previous "Launching subprocess for ..." message was misleading
        # in the event log.
        yield DagsterEvent.engine_event(
            step_context,
            "Launching thread for {}".format(step.key),
            EngineEventData(marker_start=DELEGATE_MARKER),
            step_key=step.key,
        )
        for ret in execute_command_in_thread(command):
            if ret is None or isinstance(ret, DagsterEvent):
                # None is a heartbeat tick; DagsterEvents pass straight through.
                yield ret
            elif isinstance(ret, ChildProcessEvent):
                # Only system-error events carry information we must record;
                # other ChildProcessEvents are consumed silently.
                if isinstance(ret, ChildProcessSystemErrorEvent):
                    errors[ret.pid] = ret.error_info
            else:
                check.failed("Unexpected return value from child process {}".format(type(ret)))
def execute_command_in_thread(command):
    """Execute *command* on a daemonless background thread, yielding events.

    Starts a thread running ``_execute_command_in_child_process`` with a
    ``queue.Queue`` standing in for the multiprocessing event queue, then
    polls the queue and yields each event to the caller until a
    ``ChildProcessDoneEvent`` or ``ChildProcessSystemErrorEvent`` signals
    orderly completion.

    Args:
        command (ChildProcessCommand): The command to execute.

    Yields:
        Events pulled from the queue (or None ticks from ``_poll_for_event``).

    Raises:
        ChildProcessCrashException: If the thread dies and the queue drains
            without a done/system-error event having been observed.
    """
    check.inst_param(command, "command", ChildProcessCommand)
    event_queue = queue.Queue()
    try:
        thread = threading.Thread(
            target=_execute_command_in_child_process, args=(event_queue, command)
        )
        thread.start()
        completed_properly = False
        while not completed_properly:
            event = _poll_for_event(thread, event_queue)
            if event == PROCESS_DEAD_AND_QUEUE_EMPTY:
                break
            yield event
            if isinstance(event, (ChildProcessDoneEvent, ChildProcessSystemErrorEvent)):
                completed_properly = True
        if not completed_properly:
            # TODO Figure out what to do about stderr/stdout
            raise ChildProcessCrashException(exit_code=1)
        thread.join()
    finally:
        # Bug fix: queue.Queue has no close(); task_done() raises ValueError
        # when called more times than items were retrieved (e.g. if the
        # generator is torn down before any event was consumed), which would
        # mask the original exception. Suppress that bookkeeping error --
        # nothing here calls event_queue.join(), so task_done() accounting
        # is inert anyway.
        try:
            event_queue.task_done()
        except ValueError:
            pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment