Skip to content

Instantly share code, notes, and snippets.

@Wind010
Created April 14, 2024 02:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Wind010/2091213ae2f2b788e0d28de53f195efc to your computer and use it in GitHub Desktop.
Save Wind010/2091213ae2f2b788e0d28de53f195efc to your computer and use it in GitHub Desktop.
from datetime import datetime
from typing import Any, List, Tuple
from pyspark.sql import DataFrame, SparkSession
from jdbc_configuration import JdbcConfiguration
# https://www.youtube.com/watch?v=_p73PZIDQuA
# Error message used when parameter binding is attempted without a prepared statement.
NOT_PREPARED: str = "The prepared_statement param is None!"
# strftime pattern used to render datetime parameters as strings before they are bound.
SQL_FORMAT: str = '%Y-%m-%d %H:%M:%S'
class JdbcClient:
    """
    JDBC CRUD operations on top of a Spark session. Currently supports SQL Server. Composable.

    https://spark.apache.org/docs/3.5.1/sql-data-sources-jdbc.html
    """

    # Use an abstract base class if support for other databases requires more
    # than a driver change and different methods.

    # Bounds of a Java ``int``. Python ints outside this range are sent by Py4J
    # as Java ``long`` values and therefore have to be bound with ``setLong``.
    _JAVA_INT_MIN = -(2 ** 31)
    _JAVA_INT_MAX = 2 ** 31 - 1

    def __init__(self, spark_session: SparkSession, jdbc_config: JdbcConfiguration):
        """
        Initialize the JDBC client from a Spark session and JDBC configuration.

        :param spark_session: Session used for DataFrame reads and for reaching the driver JVM.
        :param jdbc_config: Contains the JDBC connection properties.
        """
        self.session = spark_session
        self.jdbc_config: JdbcConfiguration = jdbc_config
        self.jdbc_url = (
            f"jdbc:sqlserver://{self.jdbc_config.hostname}:{self.jdbc_config.port};"
            f"database={self.jdbc_config.database}"
        )
        self.properties = {
            "user": self.jdbc_config.username,
            "password": self.jdbc_config.password,
            "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        }

    def execute_query(self, statement: str) -> DataFrame:
        """
        Execute the specified SQL query through Spark's JDBC data source (no explicit transaction).

        :param statement: The SQL query to execute.
        :returns: The DataFrame containing the result of the query.
        """
        # NOTE(review): Spark documents ``batchsize`` as a write-only option
        # (``fetchsize`` is the read-side counterpart). The option is kept as
        # it was so behaviour is unchanged; confirm which one was intended.
        reader = (
            self.session.read
            .format("jdbc")
            .option("url", self.jdbc_url)
            .option("query", statement)
            .option("user", self.jdbc_config.username)
            .option("password", self.jdbc_config.password)
            .option("batchsize", self.jdbc_config.batch_size)
        )
        return reader.load()

    def execute_upsert(self, statement: str, data: List[Any]) -> bool:
        """
        Execute one parameterized SQL statement inside a JDBC transaction.

        The statement is committed on success and rolled back on any error.

        :param statement: The SQL insert/update statement to execute, with ``?`` placeholders.
        :param data: The values to bind, in placeholder order (one row).
        :returns: True when the statement was executed and committed.
        :raises Exception: Any binding or execution error is re-raised after the rollback attempt.
        :notes: Offers flexibility in running SQL statements, but not optimized for many records.
        :example:
            data = [1, "John", 30]
            statement = "INSERT INTO your_table_name (id, name, age) VALUES (?, ?, ?)"
        """
        # Auto-commit is disabled by _get_connection, so this starts a transaction.
        connection = self._get_connection()  # type: ignore
        prepared_statement = None
        try:
            prepared_statement = connection.prepareStatement(statement)
            self._map_parameters_to_data_types(prepared_statement, data)
            prepared_statement.executeUpdate()
            connection.commit()
            return True
        except Exception as e:
            # A failing rollback must not hide the error that triggered it.
            try:
                connection.rollback()
            except Exception as rollback_error:
                print("Rollback failed:", str(rollback_error))
            else:
                print("Transaction rolled back due to error:", str(e))
            raise
        finally:
            # Close the connection even when closing the statement fails.
            try:
                if prepared_statement is not None:
                    prepared_statement.close()
            finally:
                connection.close()

    def write(self, df: DataFrame, table_name: str, schema: str = "dbo", mode: str = "Append") -> bool:
        """
        Write the dataframe to ``schema.table_name`` through Spark's JDBC writer.

        Spark performs the write over its own JDBC connections, so a connection
        opened here cannot wrap it in a transaction; none is opened. The
        configured batch size is passed to the writer as ``batchsize``.

        :param df: The dataframe to write to the table.
        :param table_name: The unqualified table name; ``schema`` is always prepended.
        :param schema: The schema that owns the table.
        :param mode: The Spark save mode:
            Append: Appends the data to the existing data in the target table.
                If the table does not exist, it will be created.
            Overwrite: Replaces the existing data in the target table with the data from the DataFrame.
                If the table does not exist, it will be created.
            Ignore: Does nothing if the target table already exists.
            Error: Raises an error if the target table already exists.
        :returns: True when the write completed; failures propagate as exceptions.
        """
        # Property values are handed to java.util.Properties, hence the str().
        writer_properties = {**self.properties, "batchsize": str(self.jdbc_config.batch_size)}
        df.write.jdbc(
            url=self.jdbc_url,
            table=f"{schema}.{table_name}",
            mode=mode,
            properties=writer_properties,
        )
        return True

    def _map_parameters_to_data_types(self, prepared_statement: Any, data: List[Any]) -> None:
        """
        Bind each value in ``data`` to the PreparedStatement using a setter chosen by its Python type.

        :param prepared_statement: The java.sql.PreparedStatement to set parameters on.
        :param data: Values for a single row, in placeholder order.
        :return: None
        :raises ValueError: If ``prepared_statement`` is None.
        :raises TypeError: If a value's type (or any of its base classes) has no mapping.
        """
        if prepared_statement is None:
            raise ValueError(NOT_PREPARED)

        type_map = {
            bool: prepared_statement.setBoolean,
            int: prepared_statement.setInt,
            # Python floats are double precision; setFloat would truncate them.
            float: prepared_statement.setDouble,
            str: prepared_statement.setString,
            bytes: prepared_statement.setBytes,
            # Datetimes are bound as formatted strings (see SQL_FORMAT).
            datetime: prepared_statement.setString,
            # Add more mappings as needed.
        }

        # PreparedStatement parameter indexes start from 1.
        for position, value in enumerate(data, start=1):
            # Walk the MRO so subclasses (e.g. IntEnum) resolve to their base
            # type; bool precedes int in its own MRO, so booleans stay booleans.
            matched = next((base for base in type(value).__mro__ if base in type_map), None)
            if matched is None:
                raise TypeError("No appropriate mapping found in type_map for passed in data type.")

            set_method = type_map[matched]
            if matched is datetime:
                value = value.strftime(SQL_FORMAT)
            elif matched is int and not (self._JAVA_INT_MIN <= value <= self._JAVA_INT_MAX):
                # Too wide for a Java int; bind as a long instead.
                set_method = prepared_statement.setLong
            set_method(position, value)

    def _get_connection(self) -> Any:
        """
        Open a java.sql.Connection through the Spark JVM gateway with auto-commit disabled.

        :returns: The open connection; the caller is responsible for closing it.
        """
        driver_manager = self.session._jvm.java.sql.DriverManager  # type: ignore
        connection = driver_manager.getConnection(
            self.jdbc_url, self.properties["user"], self.properties["password"]
        )
        try:
            connection.setAutoCommit(False)
        except Exception:
            # Do not leak the connection when it cannot be configured.
            connection.close()
            raise
        return connection
class JdbcConfiguration:
    """Connection settings for a JDBC database."""

    def __init__(
        self,
        hostname: str,
        port: str,
        database: str,
        username: str,
        password: str,
        batch_size: int = 1000,
    ):
        """
        Initialize JDBC configuration.

        Args:
            hostname (str): The hostname or IP address of the database server.
            port (str): The port number for the database server.
            database (str): The name of the database to connect to.
            username (str): The username for authentication.
            password (str): The password for authentication.
            batch_size (int): The JDBC batch size. Defaults to 1000, the value
                that was previously hard-coded.

        Raises:
            ValueError: If batch_size is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        self.hostname = hostname
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.batch_size = batch_size
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment