Created
April 13, 2020 17:13
-
-
Save jazzido/4c34fa2e56c3b403a6851791b8c752dc to your computer and use it in GitHub Desktop.
Prefect ShellTask with STDIN support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import subprocess | |
import tempfile | |
from typing import Any | |
import prefect | |
from prefect.utilities.tasks import defaults_from_attrs | |
class ShellTaskWithStdin(prefect.Task):
    """
    Task for running arbitrary shell commands, optionally piping a string to the
    command's standard input.

    The command (preceded by `helper_script`, if any) is written to a temporary
    file which is then executed with `shell`.

    Args:
        - command (string, optional): shell command to be executed; can also be
            provided post-initialization by calling this task instance
        - env (dict, optional): dictionary of environment variables to use for
            the subprocess; can also be provided at runtime. These are merged on
            top of a copy of the current process environment
        - helper_script (str, optional): a string representing a shell script, which
            will be executed prior to the `command` in the same process. Can be used to
            change directories, define helper functions, etc. when re-using this Task
            for different commands in a Flow
        - shell (string, optional): shell to run the command with; defaults to "bash"
        - return_all (bool, optional): boolean specifying whether this task should
            return all lines of stdout as a list, or just the last line as a string;
            defaults to `False`
        - stdin (string, optional): an optional string piped to the command's standard
            input; can also be provided at runtime
        - **kwargs: additional keyword arguments to pass to the Task constructor

    Example:
        ```python
        from prefect import Flow

        task = ShellTaskWithStdin(helper_script="cd ~")
        with Flow("My Flow") as f:
            # both tasks will be executed in home directory
            contents = task(command='ls')
            shouted = task(command='tr a-z A-Z', stdin='hello')

        out = f.run()
        ```
    """

    def __init__(
        self,
        command: str = None,
        env: dict = None,
        helper_script: str = None,
        shell: str = "bash",
        return_all: bool = False,
        stdin: str = None,
        **kwargs: Any
    ):
        self.command = command
        self.env = env
        self.helper_script = helper_script
        self.shell = shell
        self.return_all = return_all
        self.stdin = stdin
        super().__init__(**kwargs)

    @defaults_from_attrs("command", "env", "stdin")
    def run(self, command: str = None, env: dict = None, stdin: str = None) -> str:
        """
        Run the shell command.

        Args:
            - command (string): shell command to be executed; can also be
                provided at task initialization. Any variables / functions defined in
                `self.helper_script` will be available in the same process this command
                runs in
            - env (dict, optional): dictionary of environment variables to use for
                the subprocess
            - stdin (string, optional): an optional string piped to the command's standard
                input

        Returns:
            - stdout (string): if `return_all` is `False` (the default), only
                the last line of stdout is returned, otherwise all lines are returned
                as a list, which is useful for passing result of shell command to other
                downstream tasks. If there is no output, `None` is returned (or an
                empty list when `return_all` is `True`).

        Raises:
            - TypeError: if no command was provided at initialization or runtime
            - prefect.engine.signals.FAIL: if command has an exit code other
                than 0
        """
        if command is None:
            raise TypeError("run() missing required argument: 'command'")

        # Caller-supplied variables override the inherited environment.
        current_env = os.environ.copy()
        current_env.update(env or {})

        with tempfile.NamedTemporaryFile(prefix="prefect-") as tmp:
            if self.helper_script:
                tmp.write(self.helper_script.encode())
                tmp.write("\n".encode())
            tmp.write(command.encode())
            # Make sure the script is on disk before the shell opens it.
            tmp.flush()

            # No `check=True`: the exit code is inspected below so that the
            # failure message can include it together with stderr.
            sub_process = subprocess.run(
                [self.shell, tmp.name],
                input=stdin or None,
                env=current_env,
                capture_output=True,
                # Text mode with an explicit codec: stdout/stderr come back
                # as `str`, so no further `.decode()` is needed (or possible).
                text=True,
                encoding="utf-8",
            )

        lines = []
        line = None
        for raw_line in sub_process.stdout.splitlines():
            line = raw_line.rstrip()
            if self.return_all:
                lines.append(line)
            else:
                # if we're returning all, we don't log every line
                self.logger.debug(line)

        if sub_process.returncode:
            # Prefer stderr for the diagnostic; fall back to the last stdout line.
            detail = (sub_process.stderr or "").strip() or line
            msg = "Command failed with exit code {0}: {1}".format(
                sub_process.returncode, detail
            )
            self.logger.error(msg)
            raise prefect.engine.signals.FAIL(msg)  # type: ignore

        return lines if self.return_all else line
Why don't you commit this upstream?
@BonaBeavis Yeah, that would be a good idea. Unfortunately I've since switched tasks, and I'm not working with Prefect anymore. IIRC, I didn't get around to doing it because I couldn't find the time to add tests.
I can fork Prefect, add this, you add the tests, and we submit a PR upstream.
OK, let's do this.
Cool. I created a new branch in my fork, to which I've just added you as a collaborator.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why don't you commit this upstream?