Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
PySpark script to ingest a Hive table into Druid using the rovio-ingest library
#
# Copyright 2021 Rovio Entertainment Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import logging
from datetime import datetime
import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql import DataFrame, SparkSession
from rovio_ingest import DRUID_SOURCE
from rovio_ingest.extensions.dataframe_extension import (
ConfKeys,
add_dataframe_druid_extension,
normalize_date,
)
def is_numeric(column_type):
    """Return True if *column_type* is a numeric Spark SQL dtype string.

    :param column_type: dtype name as reported by ``DataFrame.dtypes``
    """
    numeric_dtypes = {"smallint", "int", "bigint", "float", "double"}
    return column_type in numeric_dtypes
def string_cols_to_list(str_cols):
    """Split a comma-separated string of column names into a list.

    Surrounding whitespace is trimmed from each name and empty entries
    are dropped, so ``""`` yields an empty list.

    :param str_cols: Comma-separated column names as string
    """
    stripped_names = (name.strip() for name in str_cols.split(","))
    return [name for name in stripped_names if name]
def run(spark, args):
    """Ingest a Hive table into Druid as a single datasource.

    Reads ``args.source_table_name``, drops excluded columns, normalizes the
    time/dimension columns, optionally restricts the processdate range, adds a
    ``_count`` helper metric, and writes the result via the rovio-ingest
    Druid datasource.

    :param spark: active SparkSession (with Hive support enabled)
    :param args: parsed CLI arguments; see ``main()`` for the full list
    :raises RuntimeError: if the time column has an unsupported type
    :raises Exception: if the source already contains a ``_count`` column
    """
    add_dataframe_druid_extension()
    df = spark.table(args.source_table_name)

    excluded_columns = string_cols_to_list(args.excluded_columns)
    if excluded_columns:
        logging.info("{} are excluded from ingestion".format(excluded_columns))
        df = df.drop(*excluded_columns)

    # Druid requires a valid timestamp on every row; null time rows cannot be ingested.
    df = df.filter(f.col(args.time_column).isNotNull())

    df = _normalize_columns(df, args)
    df = _filter_process_dates(spark, df, args)

    # When rollup is enabled original row count is lost during ingestion.
    # Hence, we add count as a separate field. DataExplorer uses this when count is required.
    if "_count" in df.columns:
        raise Exception("Conflicting column _count exists in source df")
    df = df.withColumn("_count", f.lit(1))

    metrics_spec = parse_and_validate_metrics_spec(df, args.metrics)
    logging.info(f"Metric spec = {metrics_spec}")

    _write_to_druid(df, args, metrics_spec)


def _normalize_columns(df, args):
    """Cast the time column to timestamp and force selected columns to string.

    The dtype list is snapshotted once up front, so every column is judged by
    its ORIGINAL dtype even though ``df`` is rebuilt inside the loop (this
    matters if the time column is also listed in forced dimensions).
    """
    forced_dimension_columns = string_cols_to_list(args.forced_dimension_columns)
    for (column_name, dtype) in df.dtypes:
        if column_name == args.time_column:
            if dtype != "timestamp":
                # Extra logic for time_column casting to timestamp type if it isn't by default
                if dtype in ["date", "string"]:
                    df = df.withColumn(
                        args.time_column,
                        f.col(args.time_column).cast(t.TimestampType()),
                    )
                elif dtype == "int":
                    # int time columns are interpreted as yyyyMMdd date ordinals
                    df = df.withColumn(
                        args.time_column,
                        f.to_timestamp(
                            f.col(args.time_column).cast(t.StringType()), "yyyyMMdd"
                        ),
                    )
                else:
                    raise RuntimeError(
                        "{}:{} in table {} not a INT/Date/String/Timestamp type".format(
                            args.time_column, dtype, args.source_table_name
                        )
                    )
        if (column_name in forced_dimension_columns) and (dtype != "string"):
            logging.info(
                "Forcing '{}:{}' column to be a dimension column while casting "
                "it to string!".format(column_name, dtype)
            )
            df = df.withColumn(column_name, f.col(column_name).cast(t.StringType()))
    return df


def _filter_process_dates(spark, df, args):
    """Restrict df to the [start, end] processdate window when one is given.

    The start date is normalized down to a segment-granularity boundary so the
    first (possibly partial) segment is re-ingested completely. ``main()``
    guarantees start and end are either both present or both absent.
    """
    if args.start_processdate:
        start_processdate = datetime.strptime(args.start_processdate, "%Y%m%d")
        end_processdate = datetime.strptime(args.end_processdate, "%Y%m%d")
        normalized_start_date = normalize_date(
            spark, start_processdate, args.segment_granularity
        )
        logging.info(
            f"normalized start date {start_processdate} => {normalized_start_date}"
        )
        df = df.filter(
            f.col(args.time_column).between(normalized_start_date, end_processdate)
        )
    return df


def _write_to_druid(df, args, metrics_spec):
    """Repartition by Druid segment size and overwrite the target datasource."""
    df_prepared: DataFrame = df.repartition_by_druid_segment_size(
        args.time_column,
        segment_granularity=args.segment_granularity,
        rows_per_segment=args.rows_per_segment,
        exclude_columns_with_unknown_types=args.exclude_columns_with_unknown_types,
    )
    writer = (
        df_prepared.write.mode("overwrite")
        .format(DRUID_SOURCE)
        .option(ConfKeys.DATA_SOURCE, args.data_source_name)
        .option(ConfKeys.TIME_COLUMN, args.time_column)
        .option(ConfKeys.DEEP_STORAGE_S3_BUCKET, args.metastore_s3_bucket)
        .option(ConfKeys.DEEP_STORAGE_S3_BASE_KEY, args.metastore_s3_basekey)
        .option(ConfKeys.METADATA_DB_URI, args.meta_db_uri)
        .option(ConfKeys.METADATA_DB_USERNAME, args.meta_db_username)
        .option(ConfKeys.METADATA_DB_PASSWORD, args.meta_db_password)
        .option(ConfKeys.SEGMENT_GRANULARITY, args.segment_granularity)
        .option(ConfKeys.QUERY_GRANULARITY, args.query_granularity)
        .option(ConfKeys.DATASOURCE_INIT, str(args.init))
        .option(ConfKeys.METRICS_SPEC, json.dumps(metrics_spec))
    )
    writer.save()
def parse_and_validate_metrics_spec(df, metrics_json_str):
    """Build the Druid metricsSpec for *df*, merging user overrides with defaults.

    Uses sum aggregator for all metric-type columns unless specified otherwise.
    If there is no aggregator defined in metrics_json_str for some column, this
    function appends a sum aggregator.

    :param df: source DataFrame (only ``columns``/``dtypes`` are inspected)
    :param metrics_json_str: optional JSON array of Druid metric aggregators;
        each entry must carry "type", "fieldName" and "name"
    :return: list of metric aggregator dicts
    :raises Exception: if the JSON is not an array or an entry is missing
        mandatory keys
    """
    metrics_spec = []
    if metrics_json_str:
        metrics_json_str = metrics_json_str.replace("\n", "")
        parsed_metrics = json.loads(metrics_json_str)
    else:
        parsed_metrics = []
    if parsed_metrics:
        if not isinstance(parsed_metrics, list):
            raise Exception(
                f"Metrics should be passed as json array, got {metrics_json_str}"
            )
        for metric in parsed_metrics:
            if not isinstance(metric, dict) or any(
                k not in metric for k in ("type", "fieldName", "name")
            ):
                raise Exception(f"Missing mandatory args in metric json {metric}")
            if metric["fieldName"] not in df.columns:
                # Tolerate stale specs: skip aggregators whose column is gone.
                logging.warning(
                    f"Ignored metric aggregator {metric}, {metric['fieldName']} not"
                    " found in df"
                )
            else:
                metrics_spec.append(metric)
    metrics_field_name_to_name = {m["fieldName"]: m["name"] for m in metrics_spec}
    for (column_name, dtype) in df.dtypes:
        if (
            is_numeric(dtype)
            and metrics_field_name_to_name.get(column_name, None) is None
        ):
            # Add default sum aggregation for columns that are not excluded and doesn't have any metricSpec.
            # NOTE: Spark dtype strings are "smallint"/"bigint" (never "short"/"long"),
            # so the previous ("short", "int", "long") check silently assigned
            # doubleSum to smallint/bigint columns. Match is_numeric's names.
            aggregation = (
                "longSum" if dtype in ("smallint", "int", "bigint") else "doubleSum"
            )
            metrics_spec.append(
                {"type": aggregation, "name": column_name, "fieldName": column_name}
            )
    return metrics_spec
def _str2bool(value):
    """Parse a boolean CLI argument explicitly.

    argparse's ``type=bool`` treats ANY non-empty string — including the
    literal "False" — as True, so parse the text instead. Accepts the usual
    truthy spellings; everything else is False.
    """
    return value.strip().lower() in ("true", "1", "yes", "y")


def main():
    """CLI entry point: parse arguments, build a SparkSession and run ingestion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-table-name", required=True, help="Hive table name")
    parser.add_argument("--data-source-name", required=True, help="Druid dataSource")
    parser.add_argument("--time-column", required=True, help="Druid time column")
    parser.add_argument(
        "--segment-granularity", default="DAY", help="Druid segment granularity"
    )
    parser.add_argument(
        "--query-granularity", default="DAY", help="Druid query granularity"
    )
    parser.add_argument("--meta-db-uri", help="Override Druid Metadata db uri")
    parser.add_argument(
        "--meta-db-username", help="Override Druid Metadata db username"
    )
    parser.add_argument(
        "--meta-db-password", help="Override Druid Metadata db password"
    )
    parser.add_argument(
        "--metastore-s3-bucket", help="Override Druid deep storage bucket"
    )
    parser.add_argument(
        "--metastore-s3-basekey", help="Override Druid deep storage prefix"
    )
    parser.add_argument(
        "--start-processdate", help="Start processdate filter in YYYYMMDD format"
    )
    parser.add_argument(
        "--end-processdate", help="End processdate filter in YYYYMMDD format"
    )
    parser.add_argument(
        "--excluded-columns",
        default="",
        help="Comma separated list of source columns that are excluded from ingestion",
    )
    parser.add_argument(
        "--forced-dimension-columns",
        default="",
        help=(
            "Comma separated list of source columns. rovio-ingest library "
            "treats numeric columns as metrics by default. This arg can be used "
            "to force a numeric column as a dimension."
        ),
    )
    parser.add_argument("--rows-per-segment", type=int, default=5000000)
    # type=bool would parse "--...-types False" as True (any non-empty string
    # is truthy), so use an explicit parser that understands False/True text.
    parser.add_argument(
        "--exclude-columns-with-unknown-types",
        type=_str2bool,
        default=False,
        help="Exclude source columns with unknown types",
    )
    parser.add_argument(
        "--init",
        action="store_true",
        default=False,
        help="(Re)Init datasource, skips granularity check",
    )
    parser.add_argument(
        "--env", help="Rovio environment", choices=["cloud", "mist", "smoke"]
    )
    parser.add_argument(
        "--metrics", help="Metrics spec passed as json string", type=str
    )
    args = parser.parse_args()
    # The processdate filter only makes sense as a complete range.
    if (args.start_processdate and not args.end_processdate) or (
        not args.start_processdate and args.end_processdate
    ):
        raise RuntimeError(
            "Either both start-processdate and end-processdate required or neither"
        )
    spark = (
        SparkSession.builder.appName(
            f"druid_ingest {args.source_table_name} -> {args.data_source_name}"
        )
        .config("spark.sql.shuffle.partitions", 2000)
        .config("spark.sql.session.timeZone", "UTC")
        .enableHiveSupport()
        .getOrCreate()
    )
    try:
        run(spark, args)
    finally:
        # Always release the Spark application, even if ingestion fails.
        spark.stop()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment