Created
January 16, 2023 00:34
-
-
Save Wind010/dd124dacefefff1e201843f3a7d48593 to your computer and use it in GitHub Desktop.
Standardizing Input Values with Pyspark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Databricks notebook source | |
from pyspark.sql import SparkSession, DataFrame, DataFrameWriter | |
from pyspark.sql.functions import udf, to_timestamp | |
import pyspark.sql.functions as f | |
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BooleanType, TimestampType | |
from typing import List, Set, Tuple, Dict | |
from datetime import datetime | |
from delta.tables import * | |
from pyspark.sql import Row | |
# Build (or reuse) the shared SparkSession for this notebook, sized for a
# fairly large driver/executor footprint.
# NOTE(review): "spark.memory.offHeap.size" is set to 12g while
# "spark.memory.offHeap.enabled" is False, so the off-heap allocation is
# inert — confirm whether off-heap memory was actually intended.
# NOTE(review): "spark.sql.execution.arrow.enabled" is the legacy Arrow flag;
# recent Spark versions use "spark.sql.execution.arrow.pyspark.enabled".
spark_session: SparkSession = SparkSession.builder \
    .config("spark.sql.execution.arrow.enabled", "true") \
    .config("spark.cores.max", "16") \
    .config("spark.executor.heartbeatInterval", "3600s") \
    .config("spark.network.timeout", "4200s") \
    .config("spark.driver.memory", "20g") \
    .config("spark.executor.memory", "12g") \
    .config("spark.memory.offHeap.enabled", False) \
    .config("spark.memory.offHeap.size", "12g") \
    .config("spark.ui.showConsoleProgress", "false") \
    .config("spark.sql.legacy.timeParserPolicy", "LEGACY") \
    .config("spark.databricks.delta.retentionDurationCheck.enabled", "false") \
    .getOrCreate()

# Checkpoint directory for truncating long DataFrame lineages.
spark_session.sparkContext.setCheckpointDir('/mnt/.chkpt')
# COMMAND ---------- | |
# MAGIC %md # Setup | |
# COMMAND ---------- | |
# Sample data: (a, b, tender_type, tender_amount) rows whose tender types use
# inconsistent casing, plus non-card tenders, to exercise the normalizers below.
# Fix: use the explicitly built `spark_session` instead of the Databricks-only
# implicit `spark` global, so this also runs as a plain PySpark script.
df = spark_session.createDataFrame(
    [(1, 2, 'Visa', 4), (1, 2, 'VISA', 5), (1, 2, 'MasterCard', 6),
     (1, 2, 'Amex', 7), (1, 2, 'Cash', 100), (1, 2, 'Weird - CAD $', 101)],
    ['a', 'b', 'tender_type', 'tender_amount'])
df.show()

# Column under normalization, and the canonical spellings we normalize to.
TENDER_TYPE = 'tender_type'
CREDIT_AND_DEBIT_CARD_TYPES = ['AMEX', 'Mastercard', 'VISA', 'Discover', 'Debit', 'Debit/Credit']
PAYMENT_TYPES = ['Cash', 'NA', 'Weird - CAD $', 'Weird - USD $', *CREDIT_AND_DEBIT_CARD_TYPES]
# COMMAND ---------- | |
# MAGIC %md # Option 1 - Loop with Regex Replace | |
# COMMAND ---------- | |
import re  # local import, used only to escape regex metacharacters below

# Option 1: one regexp_replace per canonical payment type.
# Fixes vs. the original:
#  * trim once before the loop instead of re-trimming on every iteration;
#  * re.escape() each payment type — 'Weird - CAD $' contains the regex
#    metacharacter '$', so the unescaped pattern could never match it;
#  * anchor with ^...$ instead of \b...\b, because \b does not match after a
#    non-word character such as '$' (values are fully trimmed, so a
#    whole-string match is the intent).
df1 = df.withColumn(TENDER_TYPE, f.trim(f.col(TENDER_TYPE)))
for canonical in PAYMENT_TYPES:
    df1 = df1.withColumn(
        TENDER_TYPE,
        f.regexp_replace(TENDER_TYPE, fr"(?i)^{re.escape(canonical)}$", canonical))
df1.display()
df1.explain(mode="cost")
# COMMAND ---------- | |
# MAGIC %md # Option 2 - Lowercase and map | |
# COMMAND ---------- | |
# This solution offers better performance at the cost of setting everything to lower and expecting everything to be values within PAYMENT_TYPES. | |
# Option 2: lowercase the column and translate through a literal map column.
# Faster than per-type regexes, at the cost of assuming every (trimmed) value
# is one of PAYMENT_TYPES — values outside the map come back as NULL.
d_payment_types = {pt.lower(): pt for pt in PAYMENT_TYPES}

df3 = df.withColumn(TENDER_TYPE, f.trim(f.col(TENDER_TYPE))) \
        .withColumn(TENDER_TYPE, f.lower(f.col(TENDER_TYPE)))

# create_map takes alternating key/value column literals; flatten the
# {lowercased: canonical} dict pairs directly instead of the original
# zip + quadratic sum(..., ()) tuple concatenation.
map_col = f.create_map(*[f.lit(x) for kv in d_payment_types.items() for x in kv])

# Look each lowercased tender type up in the map to get its canonical spelling.
df4 = df3.withColumn(TENDER_TYPE, map_col[f.col(TENDER_TYPE)])
df4.show()
df4.explain(mode="cost")
# COMMAND ---------- | |
# MAGIC %md # Options 3 - Programmatically create case mapping | |
# COMMAND ---------- | |
from functools import reduce | |
# Option 3: fold a case-insensitive when/otherwise rewrite over the frame,
# one pass of column logic per canonical payment type.
df5 = df.withColumn(TENDER_TYPE, f.trim(f.col(TENDER_TYPE)))

def add_column_safely(df, payment_type):
    """Return `df` with TENDER_TYPE rewritten to `payment_type`'s canonical
    spelling wherever it matches case-insensitively; other rows unchanged."""
    return df.withColumn(
        TENDER_TYPE,
        f.when(f.lower(f.col(TENDER_TYPE)) == payment_type.lower(), f.lit(payment_type))
         .otherwise(f.col(TENDER_TYPE)))

# Unlike Option 2, values outside PAYMENT_TYPES pass through untouched
# instead of becoming NULL.
df6 = reduce(add_column_safely, PAYMENT_TYPES, df5)
df6.show()
df6.explain(mode="cost")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.