Cleansing and transforming schema-drifted CSV files into relational data with incremental loads in Azure Databricks
# Python/PySpark code for cleansing and transforming schema-drifted CSV files into relational data with incremental loads in Azure Databricks
# Author: Dhyanendra Singh Rathore
# Define the variables used for creating connection strings
adlsAccountName = "dlscsvdataproject"
adlsContainerName = "csv-data-store"
adlsFolderName = "covid19-data"
mountPoint = "/mnt/csvFiles"
# Application (Client) ID
applicationId = dbutils.secrets.get(scope="CSVProjectKeyVault",key="ClientId")
# Application (Client) Secret Key
authenticationKey = dbutils.secrets.get(scope="CSVProjectKeyVault",key="ClientSecret")
# Directory (Tenant) ID
tenantId = dbutils.secrets.get(scope="CSVProjectKeyVault",key="TenantId")
endpoint = "https://login.microsoftonline.com/" + tenantId + "/oauth2/token"
source = "abfss://" + adlsContainerName + "@" + adlsAccountName + ".dfs.core.windows.net/" + adlsFolderName
# Connecting using Service Principal secrets and OAuth
configs = {"fs.azure.account.auth.type": "OAuth",
"fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
"fs.azure.account.oauth2.client.id": applicationId,
"fs.azure.account.oauth2.client.secret": authenticationKey,
"fs.azure.account.oauth2.client.endpoint": endpoint}
# Mounting ADLS Storage to DBFS
# Mount only if it's not already mounted
if not any(mount.mountPoint == mountPoint for mount in dbutils.fs.mounts()):
    dbutils.fs.mount(
        source = source,
        mount_point = mountPoint,
        extra_configs = configs)
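# Optional sanity check (not part of the original script): list the mount point to
# confirm the CSV files are visible before continuing. Assumes the mount above succeeded.
display(dbutils.fs.ls(mountPoint))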
# Declare variables for creating JDBC URL
jdbcHostname = "sql-csv-data-server.database.windows.net"
jdbcPort = 1433
jdbcDatabase = "syn-csv-data-dw"
jdbcTable = "csvData.covidcsvdata"
# Connection secrets from vault
jdbcUsername = dbutils.secrets.get(scope="CSVProjectKeyVault",key="SQLAdmin")
jdbcPassword = dbutils.secrets.get(scope="CSVProjectKeyVault",key="SQLAdminPwd")
# Create JDBC URL
jdbcUrl = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbcHostname, jdbcPort, jdbcDatabase)
connectionProperties = {
    "user": jdbcUsername,
    "password": jdbcPassword,
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
}
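# Optional connectivity check (illustrative sketch, not part of the original script):
# run a trivial query through the JDBC connection to fail fast if the credentials or
# server firewall rules are wrong. The "(SELECT 1 AS ok) t" subquery is a placeholder,
# not a real table.
spark.read.jdbc(url=jdbcUrl, table="(SELECT 1 AS ok) t", properties=connectionProperties).show()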
# import lit() to create new columns in our dataframe
from pyspark.sql.functions import lit
# Function to flatten the column names by removing (' ', '/', '_') and converting them to lowercase letters
def rename_columns(rename_df):
    for column in rename_df.columns:
        new_column = column.replace(' ','').replace('/','').replace('_','')
        rename_df = rename_df.withColumnRenamed(column, new_column.lower())
    return rename_df
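# Illustrative usage (not part of the original script): rename_columns() flattens drifted
# headers such as "Province/State" and "Country_Region" to lowercase without separators.
# The sample dataframe below is purely hypothetical.
demo_df = spark.createDataFrame([("Delhi", "India")], ["Province/State", "Country_Region"])
print(rename_columns(demo_df).columns)  # ['provincestate', 'countryregion']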
# List all the files we have in our store to iterate through them one by one
file_list = [file.name for file in dbutils.fs.ls("dbfs:{}".format(mountPoint))]
# Find out the last loaded file to use as cut-off for our incremental load
lastLoadedFileQuery = "(SELECT MAX(sourcefile) as sourcefile FROM csvData.covidcsvdata) t"
lastFileDf = spark.read.jdbc(url=jdbcUrl, table=lastLoadedFileQuery, properties=connectionProperties)
lastFile = lastFileDf.collect()[0][0]
# Find the index of the file from the list
loadFrom = file_list.index('{}.csv'.format(lastFile)) + 1 if lastFile else 0
# Trim the list keeping only the files that should be processed
file_list = file_list[loadFrom:]
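# Example of the cut-off logic (hypothetical file names): if the store holds
# ['01-01-2021.csv', '01-02-2021.csv', '01-03-2021.csv'] and the last loaded
# sourcefile is '01-02-2021', then loadFrom is 2 and only '01-03-2021.csv' is
# processed; if the table is empty, lastFile is None and every file is loaded.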
# Iterate through the files
for file in file_list:
    loadFile = "{0}/{1}".format(mountPoint, file)
    # Read the csv files with first line as header, comma (,) as separator, and detect schema from the file
    csvDf = spark.read.format("csv") \
        .option("inferSchema", "true") \
        .option("header", "true") \
        .option("sep", ",") \
        .load(loadFile)
    csvDf = rename_columns(csvDf)
    # Check dataframe and add/rename columns to fit our database table structure
    if 'lat' in csvDf.columns:
        csvDf = csvDf.withColumnRenamed('lat', 'latitude')
    if 'long' in csvDf.columns:
        csvDf = csvDf.withColumnRenamed('long', 'longitude')
    if 'active' not in csvDf.columns:
        csvDf = csvDf.withColumn('active', lit(None).cast("int"))
    if 'latitude' not in csvDf.columns:
        csvDf = csvDf.withColumn('latitude', lit(None).cast("decimal"))
    if 'longitude' not in csvDf.columns:
        csvDf = csvDf.withColumn('longitude', lit(None).cast("decimal"))
    # Add the source file name (without the extension) as an additional column to help us keep track of the data source
    csvDf = csvDf.withColumn("sourcefile", lit(file.split('.')[0]))
    csvDf = csvDf.select("provincestate", "countryregion", "lastupdate", "confirmed", "deaths", "recovered", "active", "latitude", "longitude", "sourcefile")
    # Write the cleansed data to the database
    csvDf.createOrReplaceTempView("finalCsvData")
    spark.table("finalCsvData").write.mode("append").jdbc(url=jdbcUrl, table=jdbcTable, properties=connectionProperties)
# Unmount only if directory is mounted
if any(mount.mountPoint == mountPoint for mount in dbutils.fs.mounts()):
    dbutils.fs.unmount(mountPoint)
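# Optional post-load check (illustrative sketch, not part of the original script): read back
# the row count and the latest sourcefile over JDBC to confirm the incremental append landed.
verifyQuery = "(SELECT COUNT(*) AS total_rows, MAX(sourcefile) AS last_sourcefile FROM csvData.covidcsvdata) t"
spark.read.jdbc(url=jdbcUrl, table=verifyQuery, properties=connectionProperties).show()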