Created
February 20, 2024 13:18
-
-
Save eliorc/4edcd45cd20a513aea7682e5142a0824 to your computer and use it in GitHub Desktop.
Open any file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Protocols whose URL "host" component is really a bucket/container name and
# therefore has to be kept as the first segment of the resulting path.
CLOUD_PROTOCOLS = ("s3", "s3n", "s3a", "gcs", "gs", "adl", "abfs", "abfss", "gdrive")
# Protocols for which the full URL (including the scheme) is the load path.
HTTP_PROTOCOLS = ("http", "https")
# All spellings of the S3 protocol.
S3_PROTOCOLS = ("s3", "s3a", "s3n")
# Separator between the protocol (scheme) and the rest of a URL.
PROTOCOL_DELIMITER = "://"
def _parse_filepath(filepath: str) -> dict[str, str]: | |
""" | |
Split filepath on protocol and path. Based on `fsspec.utils.infer_storage_options`. | |
:param filepath: Either local absolute file path or URL (s3://bucket/file.csv) | |
:return: Parsed filepath. | |
""" | |
if re.match(r"^[a-zA-Z]:[\\/]", filepath) or re.match(r"^[a-zA-Z0-9]+://", filepath) is None: | |
return {"protocol": "file", "path": filepath} | |
parsed_path = urlsplit(filepath) | |
protocol = parsed_path.scheme or "file" | |
if protocol in HTTP_PROTOCOLS: | |
return {"protocol": protocol, "path": filepath} | |
path = parsed_path.path | |
if protocol == "file": | |
windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path) | |
if windows_path: | |
path = ":".join(windows_path.groups()) | |
options = {"protocol": protocol, "path": path} | |
if parsed_path.netloc and protocol in CLOUD_PROTOCOLS: | |
host_with_port = parsed_path.netloc.rsplit("@", 1)[-1] | |
host = host_with_port.rsplit(":", 1)[0] | |
options["path"] = host + options["path"] | |
# Azure Data Lake Storage Gen2 URIs can store the container name in the | |
# 'username' field of a URL (@ syntax), so we need to add it to the path | |
if protocol == "abfss" and parsed_path.username: | |
options["path"] = parsed_path.username + "@" + options["path"] | |
return options | |
def get_protocol_and_path(filepath: str) -> tuple[str, str]:
    """
    Parses filepath on protocol and path.

    :param filepath: raw filepath e.g.: ``gcs://bucket/test.json``.
    :return: Protocol and path.
    """
    options_dict = _parse_filepath(filepath)
    path = options_dict["path"]
    protocol = options_dict["protocol"]

    if protocol in HTTP_PROTOCOLS:
        # For HTTP(S) the parsed path is still the full URL; keep only the part
        # after the "://" delimiter.
        path = path.split(PROTOCOL_DELIMITER, 1)[-1]

    return protocol, path
def get_filepath_str(path: PurePath, protocol: str) -> str:
    """
    Returns filepath. Returns full filepath (with protocol) if protocol is HTTP(s).

    :param path: filepath without protocol.
    :param protocol: protocol.
    :return: Filepath string.
    """
    # Use a separate name so the ``PurePath`` argument is not rebound to a ``str``.
    filepath = path.as_posix()

    if protocol in HTTP_PROTOCOLS:
        # Put the "<protocol>://" prefix back in front of the path.
        filepath = "".join((protocol, PROTOCOL_DELIMITER, filepath))

    return filepath
@contextmanager
def open_any_file(filepath: str, mode: str = "r", **kwargs) -> t.Generator[t.IO, None, None]:
    """
    Open file and close it after use. Works for local, remote, http, https, s3, gcs, etc.

    The protocol is taken from ``filepath`` and the matching ``fsspec``
    filesystem is used to open the file.

    :param filepath: Filepath.
    :param mode: Mode.
    :param kwargs: Keyword arguments, forwarded to the filesystem's ``open``.
        ``content_type`` is set to ``"application/json"`` for ``.json`` paths
        unless the caller already supplied one.
    :return: File object.
    """
    protocol, path = get_protocol_and_path(filepath)
    # Keep the parsed path under its own name instead of rebinding the ``str``
    # parameter ``filepath`` to a ``PurePosixPath``.
    posix_path = PurePosixPath(path)
    filesystem = fsspec.filesystem(protocol)
    load_path = get_filepath_str(posix_path, protocol)

    # Figure out content type
    if "content_type" not in kwargs and posix_path.suffix == ".json":
        kwargs["content_type"] = "application/json"

    with filesystem.open(load_path, mode=mode, **kwargs) as f:
        yield f
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment