Created
September 30, 2023 07:24
-
-
Save eliorc/b1f465ecbd100a09b7253725b879870c to your computer and use it in GitHub Desktop.
Global open file (from local, remote, http etc.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Based on Kedro's code, from this example https://docs.kedro.org/en/stable/data/how_to_create_a_custom_dataset.html#the-complete-example
import re | |
import typing as t | |
from contextlib import contextmanager | |
from pathlib import PurePath, PurePosixPath | |
from urllib.parse import urlsplit | |
import fsspec | |
# URL schemes treated as cloud/object storage: for these, the host part of the
# URL (bucket/container) is prepended to the parsed path.
CLOUD_PROTOCOLS = ("s3", "s3n", "s3a", "gcs", "gs", "adl", "abfs", "abfss", "gdrive")
# URL schemes for which the full URL (scheme included) is rebuilt before opening.
HTTP_PROTOCOLS = ("http", "https")
# Separator between a URL scheme and the remainder of the URL.
PROTOCOL_DELIMITER = "://"
def _parse_filepath(filepath: str) -> dict[str, str]:
    """
    Split filepath on protocol and path. Based on `fsspec.utils.infer_storage_options`.

    Resolution rules, applied in order:

    * A Windows drive path (``C:/data/file.csv``) or any string without a
      ``scheme://`` prefix is returned unchanged with protocol ``file``.
    * HTTP(S) URLs keep the complete URL, scheme included, as their path.
    * ``file://`` URLs that point at a Windows drive (``/C:/...`` or
      ``/C|/...``) are rewritten to ``C:/...``.
    * Cloud URLs get the host (bucket/container) prepended to the path; a
      ``user@`` prefix and a ``:port`` suffix are stripped from the host.
    * For any other scheme only the path component of the URL is kept.

    :param filepath: Either local absolute file path or URL (s3://bucket/file.csv)
    :return: Parsed filepath, a dict with the keys ``protocol`` and ``path``.
    """
    if (
        re.match(r"^[a-zA-Z]:[\\/]", filepath)
        or re.match(r"^[a-zA-Z0-9]+://", filepath) is None
    ):
        return {"protocol": "file", "path": filepath}

    parsed_path = urlsplit(filepath)
    protocol = parsed_path.scheme or "file"

    if protocol in HTTP_PROTOCOLS:
        return {"protocol": protocol, "path": filepath}

    path = parsed_path.path
    if protocol == "file":
        # "file:///C:/dir/f" parses to the path "/C:/dir/f"; drop the leading
        # slash and normalise a "C|" drive marker to "C:".
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            path = ":".join(windows_path.groups())

    options = {"protocol": protocol, "path": path}

    if parsed_path.netloc and protocol in CLOUD_PROTOCOLS:
        # Keep only the host: strip credentials before "@" and a port after ":".
        host_with_port = parsed_path.netloc.rsplit("@", 1)[-1]
        host = host_with_port.rsplit(":", 1)[0]
        options["path"] = host + options["path"]
        # Azure Data Lake Storage Gen2 URIs can store the container name in the
        # 'username' field of a URL (@ syntax), so we need to add it to the path
        if protocol == "abfss" and parsed_path.username:
            options["path"] = parsed_path.username + "@" + options["path"]

    return options
def get_protocol_and_path(filepath: str) -> tuple[str, str]:
    """
    Parses filepath on protocol and path.

    For HTTP(S) inputs the returned path has the leading ``scheme://``
    removed; for every other protocol the path is returned exactly as
    produced by ``_parse_filepath``.

    :param filepath: raw filepath e.g.: ``gcs://bucket/test.json``.
    :return: Protocol and path.
    """
    options_dict = _parse_filepath(filepath)
    path = options_dict["path"]
    protocol = options_dict["protocol"]

    if protocol in HTTP_PROTOCOLS:
        # The parsed path of an HTTP(S) URL is the whole URL; keep only what
        # follows the first "://".
        path = path.split(PROTOCOL_DELIMITER, 1)[-1]

    return protocol, path
def get_filepath_str(path: PurePath, protocol: str) -> str:
    """
    Returns filepath. Returns full filepath (with protocol) if protocol is HTTP(s).

    :param path: filepath without protocol.
    :param protocol: protocol.
    :return: Filepath string: the POSIX form of ``path``, prefixed with
        ``<protocol>://`` when the protocol is HTTP(S).
    """
    # Use a separate name so the PurePath parameter is not rebound to a str.
    filepath_str = path.as_posix()
    if protocol in HTTP_PROTOCOLS:
        filepath_str = "".join((protocol, PROTOCOL_DELIMITER, filepath_str))
    return filepath_str
@contextmanager
def open_any_file(filepath: str, mode: str = "r", **kwargs: t.Any) -> t.Generator[t.IO, None, None]:
    """
    Open file and close it after use. Works for local, remote, http, https, s3, gcs, etc.

    The protocol is derived from ``filepath`` and used to pick the fsspec
    filesystem; the file is opened through that filesystem and released when
    the ``with`` block exits.

    :param filepath: Filepath.
    :param mode: Mode.
    :param kwargs: Keyword arguments, forwarded to the filesystem's ``open``.
    :return: File object (yielded by the context manager).
    """
    protocol, path = get_protocol_and_path(filepath)
    # Keep the str parameter intact; the path object gets its own name.
    posix_path = PurePosixPath(path)
    filesystem = fsspec.filesystem(protocol)
    load_path = get_filepath_str(posix_path, protocol)

    with filesystem.open(load_path, mode=mode, **kwargs) as f:
        yield f
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment