PySpark script to ingest a Hive table into Druid using the rovio-ingest library
#
# Copyright 2021 Rovio Entertainment Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import logging
from datetime import datetime

import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql import DataFrame, SparkSession
from rovio_ingest import DRUID_SOURCE
from rovio_ingest.extensions.dataframe_extension import (
    ConfKeys,
    add_dataframe_druid_extension,
    normalize_date,
)


def is_numeric(column_type):
    return column_type in ("smallint", "int", "bigint", "float", "double")


def string_cols_to_list(str_cols):
    """Returns a list from a string of column names

    :param str_cols: Comma separated column names as string
    """
    return [col.strip() for col in str_cols.split(",") if col.strip()]
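
# Illustrative example of the helper above (the column names are placeholders):
#   string_cols_to_list("country, platform, ,os") -> ["country", "platform", "os"]
# Whitespace is trimmed and empty entries are dropped.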


def run(spark, args):
    meta_db_uri = args.meta_db_uri
    meta_db_username = args.meta_db_username
    meta_db_password = args.meta_db_password
    metastore_s3_bucket = args.metastore_s3_bucket
    metastore_s3_basekey = args.metastore_s3_basekey

    # Register the rovio-ingest DataFrame extensions
    # (enables repartition_by_druid_segment_size used below).
    add_dataframe_druid_extension()

    df = spark.table(args.source_table_name)

    excluded_columns = string_cols_to_list(args.excluded_columns)
    if excluded_columns:
        logging.info("{} are excluded from ingestion".format(excluded_columns))
        df = df.drop(*excluded_columns)

    df = df.filter(f.col(args.time_column).isNotNull())

    forced_dimension_columns = string_cols_to_list(args.forced_dimension_columns)
    for (column_name, dtype) in df.dtypes:
        if column_name == args.time_column:
            if dtype != "timestamp":
                # Cast time_column to timestamp if it is not one already
                if dtype in ["date", "string"]:
                    df = df.withColumn(
                        args.time_column,
                        f.col(args.time_column).cast(t.TimestampType()),
                    )
                elif dtype == "int":
                    df = df.withColumn(
                        args.time_column,
                        f.to_timestamp(
                            f.col(args.time_column).cast(t.StringType()), "yyyyMMdd"
                        ),
                    )
                else:
                    raise RuntimeError(
                        "{}:{} in table {} is not an INT/Date/String/Timestamp type".format(
                            args.time_column, dtype, args.source_table_name
                        )
                    )
        if (column_name in forced_dimension_columns) and (dtype != "string"):
            logging.info(
                "Forcing '{}:{}' column to be a dimension column while casting "
                "it to string!".format(column_name, dtype)
            )
            df = df.withColumn(column_name, f.col(column_name).cast(t.StringType()))

    if args.start_processdate:
        start_processdate = datetime.strptime(args.start_processdate, "%Y%m%d")
        end_processdate = datetime.strptime(args.end_processdate, "%Y%m%d")
        normalized_start_date = normalize_date(
            spark, start_processdate, args.segment_granularity
        )
        logging.info(
            f"normalized start date {start_processdate} => {normalized_start_date}"
        )
        df = df.filter(
            f.col(args.time_column).between(normalized_start_date, end_processdate)
        )

    # When rollup is enabled, the original row count is lost during ingestion.
    # Hence, we add the count as a separate field. DataExplorer uses this when a count is required.
    if "_count" in df.columns:
        raise Exception("Conflicting column _count exists in source df")
    df = df.withColumn("_count", f.lit(1))

    metrics_spec = parse_and_validate_metrics_spec(df, args.metrics)
    logging.info(f"Metric spec = {metrics_spec}")

    df_prepared: DataFrame = df.repartition_by_druid_segment_size(
        args.time_column,
        segment_granularity=args.segment_granularity,
        rows_per_segment=args.rows_per_segment,
        exclude_columns_with_unknown_types=args.exclude_columns_with_unknown_types,
    )

    df_prepared.write.mode("overwrite").format(DRUID_SOURCE).option(
        ConfKeys.DATA_SOURCE, args.data_source_name
    ).option(
        ConfKeys.TIME_COLUMN, args.time_column
    ).option(
        ConfKeys.DEEP_STORAGE_S3_BUCKET, metastore_s3_bucket
    ).option(
        ConfKeys.DEEP_STORAGE_S3_BASE_KEY, metastore_s3_basekey
    ).option(
        ConfKeys.METADATA_DB_URI, meta_db_uri
    ).option(
        ConfKeys.METADATA_DB_USERNAME, meta_db_username
    ).option(
        ConfKeys.METADATA_DB_PASSWORD, meta_db_password
    ).option(
        ConfKeys.SEGMENT_GRANULARITY, args.segment_granularity
    ).option(
        ConfKeys.QUERY_GRANULARITY, args.query_granularity
    ).option(
        ConfKeys.DATASOURCE_INIT, str(args.init)
    ).option(
        ConfKeys.METRICS_SPEC, json.dumps(metrics_spec)
    ).save()


def parse_and_validate_metrics_spec(df, metrics_json_str):
    """
    Uses a sum aggregator for all metric-type columns unless specified otherwise.

    If metrics_json_str defines no aggregator for some numeric column, this function appends a sum aggregator for it.
    """
    metrics_spec = []
    if metrics_json_str:
        metrics_json_str = metrics_json_str.replace("\n", "")
        parsed_metrics = json.loads(metrics_json_str)
    else:
        parsed_metrics = []
    if parsed_metrics:
        if not isinstance(parsed_metrics, list):
            raise Exception(
                f"Metrics should be passed as a json array, got {metrics_json_str}"
            )
        for metric in parsed_metrics:
            if not isinstance(metric, dict) or any(
                k not in metric for k in ("type", "fieldName", "name")
            ):
                raise Exception(f"Missing mandatory args in metric json {metric}")
            if metric["fieldName"] not in df.columns:
                logging.warning(
                    f"Ignored metric aggregator {metric}, {metric['fieldName']} not"
                    " found in df"
                )
            else:
                metrics_spec.append(metric)
    metrics_field_name_to_name = {m["fieldName"]: m["name"] for m in metrics_spec}
    for (column_name, dtype) in df.dtypes:
        if (
            is_numeric(dtype)
            and metrics_field_name_to_name.get(column_name, None) is None
        ):
            # Add a default sum aggregation for numeric columns that are not excluded
            # and don't have an aggregator in the user-provided metrics spec.
            aggregation = (
                "longSum" if dtype in ("smallint", "int", "bigint") else "doubleSum"
            )
            metrics_spec.append(
                {"type": aggregation, "name": column_name, "fieldName": column_name}
            )
    return metrics_spec
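
# Illustrative example of a --metrics value (the column and metric names are
# placeholders, not from a real table). Each entry must carry "type", "fieldName"
# and "name"; numeric columns without an entry still get the default sum
# aggregator added above:
#
#   [
#     {"type": "longSum", "name": "purchase_count", "fieldName": "purchases"},
#     {"type": "doubleMax", "name": "max_score", "fieldName": "score"}
#   ]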


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-table-name", required=True, help="Hive table name")
    parser.add_argument("--data-source-name", required=True, help="Druid dataSource")
    parser.add_argument("--time-column", required=True, help="Druid time column")
    parser.add_argument(
        "--segment-granularity", default="DAY", help="Druid segment granularity"
    )
    parser.add_argument(
        "--query-granularity", default="DAY", help="Druid query granularity"
    )
    parser.add_argument("--meta-db-uri", help="Override Druid metadata db uri")
    parser.add_argument(
        "--meta-db-username", help="Override Druid metadata db username"
    )
    parser.add_argument(
        "--meta-db-password", help="Override Druid metadata db password"
    )
    parser.add_argument(
        "--metastore-s3-bucket", help="Override Druid deep storage bucket"
    )
    parser.add_argument(
        "--metastore-s3-basekey", help="Override Druid deep storage prefix"
    )
    parser.add_argument(
        "--start-processdate", help="Start processdate filter in YYYYMMDD format"
    )
    parser.add_argument(
        "--end-processdate", help="End processdate filter in YYYYMMDD format"
    )
    parser.add_argument(
        "--excluded-columns",
        default="",
        help="Comma separated list of source columns that are excluded from ingestion",
    )
    parser.add_argument(
        "--forced-dimension-columns",
        default="",
        help=(
            "Comma separated list of source columns. The rovio-ingest library "
            "treats numeric columns as metrics by default. This arg can be used "
            "to force a numeric column to be a dimension."
        ),
    )
    parser.add_argument("--rows-per-segment", type=int, default=5000000)
    parser.add_argument(
        "--exclude-columns-with-unknown-types",
        action="store_true",
        default=False,
        help="Exclude source columns with unknown types",
    )
    parser.add_argument(
        "--init",
        action="store_true",
        default=False,
        help="(Re)Init datasource, skips granularity check",
    )
    parser.add_argument(
        "--env", help="Rovio environment", choices=["cloud", "mist", "smoke"]
    )
    parser.add_argument(
        "--metrics", help="Metrics spec passed as a json string", type=str
    )
    args = parser.parse_args()

    if (args.start_processdate and not args.end_processdate) or (
        not args.start_processdate and args.end_processdate
    ):
        raise RuntimeError(
            "Either both start-processdate and end-processdate are required or neither"
        )

    spark = (
        SparkSession.builder.appName(
            f"druid_ingest {args.source_table_name} -> {args.data_source_name}"
        )
        .config("spark.sql.shuffle.partitions", 2000)
        .config("spark.sql.session.timeZone", "UTC")
        .enableHiveSupport()
        .getOrCreate()
    )
    try:
        run(spark, args)
    finally:
        spark.stop()


if __name__ == "__main__":
    main()
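
# Illustrative invocation (the script name, package coordinates, table, datasource
# and connection values below are placeholders, not taken from a real deployment):
#
#   spark-submit \
#     --packages <rovio-ingest maven coordinates> \
#     hive_to_druid.py \
#     --source-table-name analytics.daily_events \
#     --data-source-name daily_events \
#     --time-column event_date \
#     --segment-granularity DAY \
#     --query-granularity DAY \
#     --meta-db-uri jdbc:mysql://druid-metadata-host:3306/druid \
#     --meta-db-username druid \
#     --meta-db-password '...' \
#     --metastore-s3-bucket my-druid-deep-storage \
#     --metastore-s3-basekey segments \
#     --start-processdate 20210101 \
#     --end-processdate 20210131 \
#     --metrics '[{"type": "doubleSum", "name": "revenue", "fieldName": "revenue"}]'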