Skip to content

Instantly share code, notes, and snippets.

@drj42
Last active January 18, 2022 17:47
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save drj42/a6da7d6d503a97fc3eef to your computer and use it in GitHub Desktop.
Save drj42/a6da7d6d503a97fc3eef to your computer and use it in GitHub Desktop.
sftp remote target for luigi
"""sftp.py - PySftp connections wrapped up in a luigi.Target.
TODO: get rid of the redundant stuff, write some tests, contribute to luigi
upstream.
"""
# -*- coding: utf-8 -*-
import io
import logging
import os
import posixpath
import random
import sys
import tempfile

import luigi
import luigi.file
import luigi.target
from luigi.format import FileWrapper, MixedUnicodeBytes, get_default_format
logger = logging.getLogger('luigi-interface')
try:
import pysftp
except ImportError:
logger.warning('Please install pysftp to use this package')
class RemoteFileSystem(luigi.target.FileSystem):
    """Remote file system backed by pysftp.

    Each method opens its own ``pysftp.Connection`` with the arguments
    captured by the constructor and leaves the ``with`` block (closing the
    connection) before returning.
    """

    # Exclusive upper bound of the random number embedded in temporary file
    # names.  It must be an int: ``random.randrange`` no longer accepts floats
    # such as ``1e10`` (TypeError on Python 3.12+).
    _TMP_SUFFIX_RANGE = 10 ** 10

    def __init__(self, host,
                 username=None, port=22, private_key=None, **kwargs):
        """
        Store the connection arguments; no connection is opened here.

        :param host: SFTP server host name.
        :param username: login name, or ``None``.
        :param port: server port; coerced with ``int()``.
        :param private_key: private key argument passed on to pysftp.
        :param kwargs: additional parameters that the ``pysftp.Connection``
                       constructor accepts; they override the values above
                       when the same key is given.
        """
        self._conn_args = {
            'host': host,
            'username': username,
            'port': int(port),
            'private_key': private_key,
        }
        self._conn_args.update(**kwargs)

    @classmethod
    def _tmp_token(cls):
        """Return a zero-padded random number used in temp-file names."""
        return '{:09d}'.format(random.randrange(0, cls._TMP_SUFFIX_RANGE))

    def exists(self, path, mtime=None):
        """
        Return ``True`` if a file or directory exists at ``path``.

        :param path: remote path to test.
        :param mtime: optional timestamp; when truthy, the path only counts
                      as existing if its remote ``st_mtime`` is strictly
                      greater than this value.
        :rtype: bool
        """
        with pysftp.Connection(**self._conn_args) as sftp:
            if not sftp.exists(path):
                return False
            if mtime:
                return sftp.stat(path).st_mtime > mtime
            return True

    def remove(self, path, recursive=False):
        """
        Remove file or directory at location ``path``.

        :param path: a path within the FileSystem to remove.
        :type path: str
        :param recursive: if the path is a directory, recursively remove the
                          directory and all of its descendants.
                          Defaults to ``False``.
        :type recursive: bool
        :raises RuntimeError: if ``path`` is not a regular file and
                              ``recursive`` is not set.
        """
        with pysftp.Connection(**self._conn_args) as sftp:
            if sftp.isfile(path):
                sftp.unlink(path)
                return

            if not recursive:
                raise RuntimeError(("Path is not a regular file, "
                                    "and recursive option is not set"))

            # Walk the tree: files and unknown entries are unlinked as they
            # are encountered, directories are only collected because they
            # cannot be removed until they are empty.
            directories = []
            sftp.walktree(path, sftp.unlink,
                          directories.append, sftp.unlink)
            # Directories were collected parents-first, so delete them in
            # reverse order (children before parents), then the root itself.
            for directory in reversed(directories):
                sftp.rmdir(directory)
            sftp.rmdir(path)

    def put(self, local_path, remote_path):
        """
        Upload ``local_path`` to ``remote_path``.

        The file is first uploaded under a temporary name in the destination
        directory and then renamed, so a partially uploaded file never
        appears under the final name.

        :param local_path: path of the local file to upload.
        :param remote_path: destination path on the server.
        """
        # Remote SFTP paths are POSIX-style regardless of the local platform,
        # so use posixpath rather than os.path to manipulate them.
        normpath = posixpath.normpath(remote_path)
        directory = posixpath.dirname(normpath)
        tmp_path = posixpath.join(
            directory, 'luigi-tmp-' + self._tmp_token())
        with pysftp.Connection(**self._conn_args) as sftp:
            if directory:
                sftp.makedirs(directory)
            sftp.put(local_path, tmp_path)
            sftp.rename(tmp_path, normpath)

    def get(self, remote_path, local_path):
        """
        Download ``remote_path`` to ``local_path``.

        The file is downloaded under a temporary name next to the destination
        and renamed once the transfer has finished.  If the transfer or the
        rename fails, the temporary file is removed and the error re-raised.

        :param remote_path: path of the file on the server.
        :param local_path: local destination path; missing parent
                           directories are created.
        """
        normpath = os.path.normpath(local_path)
        directory = os.path.dirname(normpath)
        # ``directory`` is empty when local_path has no directory component;
        # os.makedirs('') would raise, and there is nothing to create.
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

        tmp_local_path = local_path + '-luigi-tmp-' + self._tmp_token()
        try:
            with pysftp.Connection(**self._conn_args) as sftp:
                sftp.get(remote_path, tmp_local_path)
            os.rename(tmp_local_path, local_path)
        except Exception:
            # Do not leave a partial download behind.
            if os.path.exists(tmp_local_path):
                os.remove(tmp_local_path)
            raise

    def listdir(self, path):
        """
        Yield the names of the entries in the remote directory ``path``.

        The listing is fetched in full and the connection is closed before
        the first name is yielded, so a partially consumed generator does not
        keep a connection open.
        """
        with pysftp.Connection(**self._conn_args) as sftp:
            names = sftp.listdir(remotepath=path)
        for file_path in names:
            yield file_path

    def isdir(self, path):
        """Return the result of pysftp's ``isdir`` check for ``path``."""
        with pysftp.Connection(**self._conn_args) as sftp:
            return sftp.isdir(path)
class AtomicSecureFtpFile(luigi.target.AtomicLocalFile):
    """
    Local temporary file that is uploaded through an SFTP file system.

    Writing happens on a local temp file provided by the
    ``luigi.target.AtomicLocalFile`` base class;
    ``move_to_final_destination`` then hands that file to the file system's
    ``put`` so it ends up at ``path`` on the server.  According to the
    original author's note the upload happens on ``close()`` and the temp
    file is cleaned up if ``close()`` is never invoked (both are provided by
    the base class, not by this subclass).
    """

    def __init__(self, fs, path):
        """
        Initialize an AtomicSecureFtpFile instance.

        :param fs: file system object whose ``put(local_path, remote_path)``
                   performs the upload.
        :param path: remote destination path.
        """
        self._fs = fs
        super(AtomicSecureFtpFile, self).__init__(path)

    @property
    def fs(self):
        """The file system object used for the upload."""
        return self._fs

    def move_to_final_destination(self):
        """Upload the local temp file to the remote destination path."""
        local_source = self.tmp_path
        remote_destination = self.path
        self._fs.put(local_source, remote_destination)
class RemoteTarget(luigi.target.FileSystemTarget):
    """
    Target for a file that lives on a remote SFTP server.

    Reading downloads the remote file to a local temporary file and returns
    a reader on that copy.  Writing goes through ``AtomicSecureFtpFile``,
    which uploads a local temporary file to the remote path.
    """

    def __init__(self, path, host, format=None, username=None, port=22,
                 private_key=None, mtime=None, **kwargs):
        """
        :param path: path of the file on the server.
        :param host: SFTP server host name.
        :param format: luigi format object; luigi's default format is used
                       when ``None``.
        :param username: login name, or ``None``.
        :param port: server port; coerced with ``int()``.
        :param private_key: private key argument passed on to pysftp.
        :param mtime: optional timestamp forwarded to the file system's
                      ``exists`` check by :meth:`exists`.
        :param kwargs: additional parameters that the pysftp connection
                       constructor accepts.
        """
        # Set first so that __del__ is safe even if construction fails below.
        self.__tmp_path = None

        sftp_args = {
            'username': username,
            'port': int(port),
            'private_key': private_key,
        }
        sftp_args.update(**kwargs)
        self.mtime = mtime

        if format is None:
            format = get_default_format()

        # Allow to write unicode in file for retrocompatibility
        if sys.version_info[:2] <= (2, 6):
            format = format >> MixedUnicodeBytes

        self.path = path
        self.format = format
        self._fs = RemoteFileSystem(host, **sftp_args)

    def __del__(self):
        """Best-effort removal of the last local copy made by ``open('r')``."""
        tmp_path = getattr(self, '_RemoteTarget__tmp_path', None)
        if tmp_path:
            try:
                os.remove(tmp_path)
            except OSError:
                # Already gone, or still held open by a reader on a platform
                # that forbids deleting open files; nothing more to do from
                # a finalizer.
                pass

    @property
    def fs(self):
        """The ``RemoteFileSystem`` used by this target."""
        return self._fs

    def open(self, mode):
        """
        Open the FileSystem target.

        This method returns a file-like object which can either be read from
        or written to depending on the specified mode.

        :param mode: ``'w'`` for writing or ``'r'`` for reading.
        :raises ValueError: if ``mode`` is neither ``'r'`` nor ``'w'``.
        """
        if mode == 'w':
            return self.format.pipe_writer(
                AtomicSecureFtpFile(self._fs, self.path))

        if mode == 'r':
            # Download into the system temp directory rather than to a local
            # path that mirrors the remote one: an absolute remote path such
            # as '/data/x' is usually not writable (or even present) locally.
            # The upper bound is an int because random.randrange rejects
            # floats such as 1e10 on Python 3.12+.
            local_name = self.path.lstrip('/') + '-luigi-tmp-{:09d}'.format(
                random.randrange(0, 10 ** 10))
            self.__tmp_path = os.path.join(
                tempfile.gettempdir(), 'luigi-contrib-sftp', local_name)
            # download file to local
            self._fs.get(self.path, self.__tmp_path)
            return self.format.pipe_reader(
                FileWrapper(
                    io.BufferedReader(io.FileIO(self.__tmp_path, 'r'))))

        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError('mode must be r/w')

    def exists(self):
        """Return whether the remote path exists, honouring ``self.mtime``."""
        return self.fs.exists(self.path, self.mtime)

    def put(self, local_path):
        """Upload ``local_path`` to this target's remote path."""
        self.fs.put(local_path, self.path)

    def get(self, local_path):
        """Download this target's remote file to ``local_path``."""
        self.fs.get(self.path, local_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment