Skip to content

Instantly share code, notes, and snippets.

@jihunchoi
Created April 12, 2023 05:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jihunchoi/d85ac30ba9f8ab20bf076d84507fc592 to your computer and use it in GitHub Desktop.
Save jihunchoi/d85ac30ba9f8ab20bf076d84507fc592 to your computer and use it in GitHub Desktop.
import os
from typing import List, Optional
import boto3
from hydra.core.object_type import ObjectType
from hydra.plugins.config_source import ConfigResult, ConfigSource
from omegaconf import OmegaConf
from smart_open import open, parse_uri
class S3ConfigSource(ConfigSource):
    """Hydra ``ConfigSource`` plugin that loads config files from S3.

    The object keys found under the configured prefix are listed once, at
    construction time, and cached in ``self._paths`` as keys *relative to that
    prefix*.  ``available``, ``is_group``, ``is_config`` and ``list`` all
    answer from this cache; only ``load_config`` reads object contents.
    """

    def __init__(self, provider: str, path: str) -> None:
        """Create the source and cache the keys stored under ``path``.

        Args:
            provider: Provider name, forwarded to ``ConfigSource``.
            path: Location of the config root; combined with ``scheme()`` it
                forms the ``s3://`` URI that is listed.
        """
        # Force a trailing slash so that the prefix denotes a "directory" and
        # joining a relative config path onto it never needs a separator.
        if not path.endswith("/"):
            path = path + "/"
        super().__init__(provider=provider, path=path)
        self._paths = self._list_s3_directory(s3_uri=self.full_path())

    @staticmethod
    def scheme() -> str:
        """Return the URI scheme handled by this source (``"s3"``)."""
        return "s3"

    def load_config(self, config_path: str) -> ConfigResult:
        """Read one config object from S3 and parse it with OmegaConf.

        Args:
            config_path: Config path relative to this source's root; it is
                passed through ``_normalize_file_name`` before being joined
                onto the root URI.

        Returns:
            A ``ConfigResult`` holding the parsed config, the header parsed
            from the first 512 characters, this source's URI and provider.
        """
        normalized_config_path = self._normalize_file_name(config_path)
        s3_uri = os.path.join(self.full_path(), normalized_config_path)
        # ``open`` is the smart_open function imported at the top of the file.
        with open(s3_uri, "r", encoding="utf-8") as f:
            # Only the beginning of the file is handed to the header parser;
            # rewind afterwards so OmegaConf sees the whole document.
            header_text = f.read(512)
            header = self._get_header_dict(header_text)
            f.seek(0)
            cfg = OmegaConf.load(f)
        return ConfigResult(
            config=cfg,
            path=f"{self.scheme()}://{self.path}",
            provider=self.provider,
            header=header,
        )

    @staticmethod
    def _list_s3_directory(s3_uri: str) -> List[str]:
        """List relative object keys whose prefix is ``s3_uri``.

        Args:
            s3_uri: S3 URI starting with s3://.

        Returns:
            Relative object keys; suffixes after ``s3_uri``.  Empty when no
            object matches the prefix.
        """
        if not s3_uri.endswith("/"):
            s3_uri = s3_uri + "/"
        s3_uri_parsed = parse_uri(s3_uri)
        # Every returned key starts with the prefix; cut that many characters.
        offset = len(s3_uri_parsed.key_id)

        s3_client = boto3.client("s3")
        paginator = s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=s3_uri_parsed.bucket_id, Prefix=s3_uri_parsed.key_id
        )

        paths: List[str] = []
        for page in pages:
            # A page has no "Contents" entry when nothing matches the prefix;
            # treat that as an empty page instead of raising KeyError.
            paths.extend(obj["Key"][offset:] for obj in page.get("Contents", []))
        return paths

    def available(self) -> bool:
        """Return True when at least one key exists under the S3 path."""
        # If no config file exists in the given S3 path, return False
        return bool(self._paths)

    def is_group(self, config_path: str) -> bool:
        """Return True when ``config_path`` names a config group (a prefix).

        The root (``""``) is always a group.  Any other path is a group when
        some cached key lies below ``config_path + "/"``.  This covers both an
        explicit ``"group/"`` marker object and prefixes that exist only
        implicitly through keys such as ``"group/item.yaml"``.
        """
        if config_path == "":
            return True
        prefix = config_path + "/"
        return any(path.startswith(prefix) for path in self._paths)

    def is_config(self, config_path: str) -> bool:
        """Return True when ``config_path`` names a config object.

        The path is normalized with ``_normalize_file_name`` first, exactly as
        ``load_config`` does, so both methods agree on which names resolve.
        """
        normalized_config_path = self._normalize_file_name(config_path)
        # A key ending in "/" is a directory marker, never a config file.
        if normalized_config_path.endswith("/"):
            return False
        return normalized_config_path in self._paths

    def list(self, config_path: str, results_filter: Optional[ObjectType]) -> List[str]:
        """List the immediate children of ``config_path``.

        Args:
            config_path: Group path relative to this source's root (``""``
                for the root itself).
            results_filter: Forwarded to the inherited ``_list_add_result``
                helper, which decides which entries are kept.

        Returns:
            Sorted, de-duplicated names collected by ``_list_add_result``.
        """
        files: List[str] = []
        prefix = config_path + "/" if config_path else ""

        # Reduce every cached key below the prefix to its first path segment,
        # so nested keys contribute their top-level group name only.
        children = set()
        for key in self._paths:
            if not key.startswith(prefix):
                continue
            child = key[len(prefix):].split("/", 1)[0]
            if child:
                children.add(child)

        for child in sorted(children):
            # ``file_path`` must be relative to the source root, because
            # ``is_group`` / ``is_config`` compare against the relative keys
            # cached in ``self._paths``.
            # NOTE(review): assumes the inherited ``_list_add_result``
            # classifies entries via those two methods -- confirm against the
            # installed Hydra version.
            self._list_add_result(
                files=files,
                file_path=f"{prefix}{child}",
                file_name=child,
                results_filter=results_filter,
            )
        return sorted(set(files))
from hydra.core.config_search_path import ConfigSearchPath
from hydra.plugins.search_path_plugin import SearchPathPlugin
class S3SearchPathPlugin(SearchPathPlugin):
    """Search-path plugin that restores well-formed ``s3://`` URIs.

    NOTE(review): presumably needed because some upstream path handling
    collapses ``s3://`` into ``s3:/`` and/or prepends a local prefix --
    confirm against the Hydra version in use.
    """

    def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
        """Rewrite, in place, every search-path entry that contains ``s3:/``.

        Everything before the marker is dropped and the remainder is
        re-prefixed with exactly ``s3://``.  Entries without the marker are
        left untouched.

        Args:
            search_path: The search path whose elements are updated.
        """
        marker = "s3:/"
        for el in search_path.get_path():
            idx = el.path.find(marker)
            if idx == -1:
                continue
            # Strip any slashes that follow the marker so an entry that
            # already reads "s3://bucket/..." is not turned into
            # "s3:///bucket/..."; this keeps the rewrite idempotent.
            suffix = el.path[idx + len(marker):].lstrip("/")
            el.path = f"s3://{suffix}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment