@andrewgross
Created July 24, 2020 14:57
import os
from urllib.parse import urlparse

from pyspark.sql.functions import desc, asc
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    LongType,
    TimestampType,
)

from yipit_databricks_utils.helpers import get_spark_session
from yipit_glue.sessions import get_s3_client


def browse_s3(s3_path, delimiter="/", token=None):
    """List the immediate children of an S3 prefix as a Spark DataFrame."""
    s3_path = _normalize_s3_path(s3_path)
    bucket, prefix = parse_location(s3_path)
    resp = _get_s3_objects(bucket, prefix, delimiter=delimiter, token=token)
    if resp["IsTruncated"]:
        print("Limited to 1000 Results")
    df = _convert_resp_to_dataframe(bucket, resp, delimiter=delimiter)
    return df


def parse_location(s3_location):
    parsed = urlparse(s3_location)
    bucket = parsed.netloc
    prefix = parsed.path
    prefix = prefix[1:]  # Trim leading /
    return bucket, prefix


def _normalize_s3_path(s3_path):
    """
    Normalize a path by adding a trailing slash if necessary and prepending
    s3:// when the path does not already have an s3:// or dbfs:/ scheme.
    """
    if not s3_path.endswith("/"):
        s3_path += "/"
    if not (s3_path.startswith("s3://") or s3_path.startswith("dbfs:/")):
        s3_path = "s3://{}".format(s3_path)
    return s3_path


def _convert_resp_to_dataframe(bucket, resp, delimiter="/"):
    bucket = _normalize_s3_path(bucket)
    cleaned_rows = []
    # "Folders": common prefixes returned when listing with a delimiter
    for row in resp.get("CommonPrefixes", []):
        # These paths always end with the delimiter, so take the second-to-last segment
        path = row["Prefix"].split(delimiter)[-2]
        cleaned_rows.append(
            {
                "path": path,
                "type": "Folder",
                "modified": None,
                "size": None,
                "etag": "",
                "full_path": row["Prefix"],
                "full_s3_prefix": os.path.join(bucket, row["Prefix"]),
            }
        )
    # "Files": object keys directly under the prefix
    for row in resp.get("Contents", []):
        path = row["Key"].split(delimiter)[-1]
        # Exclude empty paths (e.g. a folder marker key that ends with the delimiter)
        if not path:
            continue
        cleaned_rows.append(
            {
                "path": path,
                "type": "File",
                "modified": row["LastModified"],
                "size": row["Size"],
                "etag": row["ETag"].replace('"', ""),
                "full_path": row["Key"],
                "full_s3_prefix": os.path.join(bucket, row["Key"]),
            }
        )
    schema = _get_s3_preview_schema()
    spark = get_spark_session()
    df = spark.createDataFrame(cleaned_rows, schema=schema)
    return df.sort(desc("type"), asc("path"))


def _get_s3_objects(bucket, prefix, delimiter="/", token=None):
    kwargs = {
        "Bucket": bucket,
        "Prefix": prefix,
        "Delimiter": delimiter,
    }
    if token is not None:
        kwargs["ContinuationToken"] = token
    client = get_s3_client()
    resp = client.list_objects_v2(**kwargs)
    return resp


def _get_s3_preview_schema():
    columns = [
        StructField("path", StringType(), True),
        StructField("type", StringType(), True),
        StructField("modified", TimestampType(), True),
        StructField("size", LongType(), True),
        StructField("etag", StringType(), True),
        StructField("full_path", StringType(), True),
        StructField("full_s3_prefix", StringType(), True),
    ]
    return StructType(columns)
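
For reference, a minimal usage sketch follows. It assumes the code above runs in a Databricks/Spark environment where the yipit_databricks_utils and yipit_glue helpers are importable and AWS credentials are available to the S3 client; the bucket and prefix below are placeholders, not real locations.

# Minimal usage sketch (hypothetical bucket/prefix):
listing = browse_s3("s3://my-example-bucket/some/prefix")
listing.show(truncate=False)  # Folders sort before files, then entries sort by name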