@vivek-balakrishnan-rovio
Created February 18, 2022 10:54
Script to clean overshadowed segments from Metadata and Deep Storage
#
# Copyright 2021 Rovio Entertainment Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import functools
import logging
import sys
from xml.sax import SAXParseException
import requests
from boto.exception import BotoServerError
from boto.s3.connection import S3Connection
from mysql.connector import connect
# Internal util for retrying with backoff.
from utils.retry_util import retry_backoff
# Internal util for accessing secrets from AWS parameter store.
from utils.ssm_utils import SSMClient
DELETE_UNUSED_SEGMENTS_SQL = """
DELETE
FROM druid.druid_segments
WHERE used=false AND datasource='{datasource}' {VERSION}
ORDER BY id
{LIMIT}
"""
SELECT_USED_SEGMENT_SQL = """
SELECT id
FROM druid.druid_segments
WHERE used=true AND datasource='{datasource}'
"""
USED_SEGMENT_MAX_VERSION_PER_DATASOURCES_SQL = """
SELECT datasource, max(version)
FROM druid.druid_segments
WHERE used=true
GROUP BY 1
"""


def get_bucket(bucket_name):
    s3_host = "s3.amazonaws.com"
    conn = S3Connection(host=s3_host)
    return conn.get_bucket(bucket_name)


def get_datasources(include_unused=False):
    hostname = f"druid.{get_base_domain()}"
    url = f"http://{hostname}:8081/druid/coordinator/v1/metadata/datasources"
    if include_unused:
        url += "?includeUnused"
    resp = requests.get(url, timeout=20)
    resp.raise_for_status()
    return resp.json()
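
# Illustrative only: the coordinator's /druid/coordinator/v1/metadata/datasources endpoint
# returns a JSON array of datasource names (e.g. ["datasource_a", "datasource_b"]),
# so get_datasources() yields a plain list of strings.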


def query_used_segment_max_version_per_datasource(mysql_conn):
    segment_max_version = {}
    with mysql_conn.cursor() as cursor:
        cursor.execute(USED_SEGMENT_MAX_VERSION_PER_DATASOURCES_SQL)
        for item in cursor:
            segment_max_version[item[0]] = item[1]
    return segment_max_version


def query_used_segments(conn, datasource):
    select_sql = SELECT_USED_SEGMENT_SQL.format(datasource=datasource)
    with conn.cursor() as cursor:
        cursor.execute(select_sql)
        return [item[0] for item in cursor]


def delete_unused_segment_rows(conn, datasource, max_version):
    version_placeholder = f" AND version < '{max_version}'" if max_version else ""
    delete_sql = DELETE_UNUSED_SEGMENTS_SQL.format(
        datasource=datasource, VERSION=version_placeholder, LIMIT="LIMIT 10000"
    )
    logging.info("Deleting unused segment rows for %s", datasource)
    logging.info("sql: %s", delete_sql)
    deleted_count = 0
    with conn.cursor() as cursor:
        # Delete in batches of 10000 rows until no unused rows remain.
        while True:
            cursor.execute(delete_sql)
            conn.commit()
            if cursor.rowcount > 0:
                logging.debug("Deleted %s rows from %s", cursor.rowcount, datasource)
                deleted_count += cursor.rowcount
            else:
                break
    logging.info(
        "Deleted all unused segment rows for %s, total rows deleted = %s",
        datasource,
        deleted_count,
    )


def parse_segment_id_and_version(s3_root, key_name):
    """
    Parse segment_id and version from an s3 key_name.

    The segment path is of the form
        <s3_root_prefix>/<datasource>/<start_interval>_<end_interval>/<version>/<partition>/file
    and the segment_id is
        <datasource>_<start_interval>_<end_interval>_<version>_<partition>
    e.g. for the key
        druid/segments/test/2018-10-01T00:00:00.000Z_2018-10-02T00:00:00.000Z/2020-05-11T07:09:28.722Z/1/descriptor.json
    the segment_id is
        test_2018-10-01T00:00:00.000Z_2018-10-02T00:00:00.000Z_2020-05-11T07:09:28.722Z_1
    (see the illustrative call right after this function).

    Returns: tuple of (segment_id, version)
    """
    segment_id = key_name.split(f"{s3_root}/", 1)[1].rsplit("/", 1)[0]
    segment_id_split = segment_id.split("/")
    version = (
        segment_id_split[2]
        if len(segment_id_split) >= 3
        else "1970-01-01T00:00:00.000Z"
    )
    # Partition 0 is not used in the segment_id.
    segment_id = segment_id.replace("/0", "")
    # The segment_id delimiter is "_".
    segment_id = segment_id.replace("/", "_")
    return segment_id, version
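
# Illustrative only, using the sample key from the docstring above:
#   parse_segment_id_and_version(
#       "druid/segments",
#       "druid/segments/test/2018-10-01T00:00:00.000Z_2018-10-02T00:00:00.000Z"
#       "/2020-05-11T07:09:28.722Z/1/descriptor.json",
#   )
# returns:
#   ("test_2018-10-01T00:00:00.000Z_2018-10-02T00:00:00.000Z_2020-05-11T07:09:28.722Z_1",
#    "2020-05-11T07:09:28.722Z")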


def is_key_candidate_for_deletion(s3_root, key_name, used_segment_ids, max_version):
    """
    Check whether key_name is a candidate for deletion.

    A key is a candidate for deletion if its segment_id is not in used_segment_ids
    and its version is less than max_version (see the note after this function).
    """
    segment_id, version = parse_segment_id_and_version(s3_root, key_name)
    should_delete = True
    if used_segment_ids:
        should_delete = segment_id not in used_segment_ids
    if should_delete and max_version:
        should_delete = version < max_version
    return should_delete
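
# Note (illustrative): versions are ISO-8601 timestamps, so the "version < max_version"
# check above is a plain string comparison; any key whose segment_id appears in
# used_segment_ids is always kept.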


def delete_s3_files(datasource, max_version, used_segment_ids):
    """
    List s3 files by datasource prefix and delete them in batches, based on the
    is_key_candidate_for_deletion() check (see the batching note after this function).
    """
    bucket_path = get_druid_s3_bucket_params()
    s3_root = f"{bucket_path[1]}/{datasource}/"
    bucket = get_bucket(bucket_path[0])
    logging.info("Delete s3 path for prefix %s", s3_root)
    logging.info(
        "%s has %s used segment_ids with max version: %s",
        datasource,
        len(used_segment_ids),
        max_version,
    )
    keys_to_delete = [
        key
        for key in list(bucket.list(prefix=s3_root))
        if is_key_candidate_for_deletion(
            bucket_path[1], key.name, used_segment_ids, max_version
        )
    ]
    logging.info("%s s3 objects to be deleted for %s", len(keys_to_delete), datasource)

    def _should_retry(e):
        # AWS sometimes returns an empty body, which fails XML parsing and raises
        # SAXParseException. AWS can also return an S3 SlowDown (503) error.
        # Both usually succeed on retry, so back off and retry until attempts are exhausted.
        return e and (
            isinstance(e, SAXParseException)
            or (isinstance(e, BotoServerError) and str(e.status) == "503")
        )

    @retry_backoff(should_retry=_should_retry, max_retries=20)
    def _delete(key_list):
        bucket.delete_keys(key_list, quiet=True)
        return len(key_list)

    deleted_count = 0
    if keys_to_delete:
        chunk_size = 1000
        chunks = [
            keys_to_delete[x : x + chunk_size]
            for x in range(0, len(keys_to_delete), chunk_size)
        ]
        for chunk in chunks:
            deleted_count += _delete(chunk)
    logging.info(
        "Deleted %s s3 objects in total with prefix %s", deleted_count, s3_root
    )
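
# Note: boto's Bucket.delete_keys() uses the S3 multi-object delete API, which accepts
# at most 1000 keys per request; that is why keys_to_delete is chunked into batches of 1000.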


def validate(datasource, used_segment_ids):
    """
    Check that there is some s3 object for every used segment_id in the metastore db.

    Raise an error if any used segment_id from the metastore db has no s3 object.
    Also warn if there are extra s3 objects missing from the metastore db; that case
    is not raised as an exception because of s3 eventual consistency.
    """
    bucket_params = get_druid_s3_bucket_params()
    s3_root = f"{bucket_params[1]}/{datasource}/"
    bucket = get_bucket(bucket_params[0])
    logging.info("Validate s3 path for prefix %s", s3_root)
    keys = bucket.list(prefix=s3_root)
    s3_segment_ids = [
        parse_segment_id_and_version(bucket_params[1], key.name)[0]
        for key in list(keys)
    ]
    diff_s3_db = list(set(s3_segment_ids) - set(used_segment_ids))
    diff_db_s3 = list(set(used_segment_ids) - set(s3_segment_ids))
    logging.debug("diff s3->db : %s", diff_s3_db)
    logging.debug("diff db->s3 : %s", diff_db_s3)
    if diff_s3_db:
        logging.warning(
            "For %s, diff segment_ids s3->db : %s. "
            "Note: this may be due to s3 eventual consistency for delete",
            datasource,
            diff_s3_db,
        )
    if diff_db_s3:
        logging.error("For %s, diff db->s3 : %s", datasource, diff_db_s3)
        raise Exception(
            f"Mismatch in known segments between db and s3 path for {datasource}:\n"
            f" diff segment_ids db->s3 : {len(diff_db_s3)}"
        )
    logging.info("Validated used segments in DB & s3 for %s", datasource)


def clean_unused_segment_for_datasource(mysql_conn, datasource, max_version):
    # Delete unused segment rows from the metadata db.
    delete_unused_segment_rows(mysql_conn, datasource, max_version)
    used_segment_ids = query_used_segments(mysql_conn, datasource)
    # Delete the corresponding s3 files.
    delete_s3_files(datasource, max_version, used_segment_ids)
    logging.info("Metadata DB and s3 files cleaned up for %s", datasource)
    validate(datasource, used_segment_ids)


def init_logging():
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)


def get_rds_connection_params():
    # profile_name not needed because this is executed on EMR
    params = SSMClient(profile_name=None).get_params("druid/")
    jdbc_uri = params["metadata_db/uri"]
    # jdbc uri is of the form jdbc:mysql://host:port/dbname
    host_and_port = jdbc_uri.replace("jdbc:mysql://", "").split("/")[0].split(":")
    return {
        "user": params["metadata_db/username"],
        "password": params["metadata_db/password"],
        "host": host_and_port[0],
        "port": int(host_and_port[1]),
    }


@functools.lru_cache(maxsize=None)
def get_druid_s3_bucket_params():
    # profile_name not needed because this is executed on EMR
    params = SSMClient(profile_name=None).get_params("druid/")
    bucket_name = params["deep_storage/bucket"]
    if bucket_name[-1] == "/":
        bucket_name = bucket_name[:-1]
    base_key = params["deep_storage/basekey"]
    if base_key[-1] == "/":
        base_key = base_key[:-1]
    return bucket_name, base_key


def get_base_domain():
    # profile_name not needed because this is executed on EMR
    params = SSMClient(profile_name=None).get_params("dns/")
    return params["base_domain"]


def main():
    init_logging()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--datasources",
        nargs="*",
        type=str,
        help=(
            "list of datasources to clean; when not provided, all datasources are"
            " considered"
        ),
    )
    args = parser.parse_args()
    connection_params = get_rds_connection_params()
    datasources_override = args.datasources if args.datasources else []
    logging.info("datasources override : %s", datasources_override)
    all_datasources = get_datasources(include_unused=True)
    if datasources_override:
        all_datasources = [d for d in all_datasources if d in datasources_override]
    logging.info("datasources to clean up: %s", all_datasources)
    with connect(**connection_params) as mysql_conn:
        segment_max_version = query_used_segment_max_version_per_datasource(mysql_conn)
        for datasource in all_datasources:
            clean_unused_segment_for_datasource(
                mysql_conn,
                datasource,
                segment_max_version.get(datasource, None),
            )
            logging.info("*" * 80)


if __name__ == "__main__":
    main()
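
# Usage (illustrative; the file name depends on how this script is saved):
#   python clean_unused_druid_segments.py                      # clean all datasources
#   python clean_unused_druid_segments.py -d datasource_a datasource_b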