Created
April 14, 2024 02:08
-
-
Save Wind010/2091213ae2f2b788e0d28de53f195efc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from datetime import datetime
from typing import Any, List

from pyspark.sql import DataFrame, SparkSession

from jdbc_configuration import JdbcConfiguration
# https://www.youtube.com/watch?v=_p73PZIDQuA
# Message of the ValueError raised when parameter binding is attempted without
# a PreparedStatement.
NOT_PREPARED: str = "The prepared_statement param is None!"
# strftime pattern used to render datetime parameters as text before they are
# bound with setString.
SQL_FORMAT: str = '%Y-%m-%d %H:%M:%S'
class JdbcClient:
    """
    JDBC CRUD operations for Microsoft SQL Server, built on a SparkSession.

    Queries and DataFrame writes go through Spark's JDBC data source.
    Single-row parameterised statements go through ``java.sql.DriverManager``
    on the driver JVM, reached via ``SparkSession._jvm``.
    https://spark.apache.org/docs/3.5.1/sql-data-sources-jdbc.html
    """

    # Use an abstract base class if support for other databases requires more
    # than a driver change and different methods.

    _logger = logging.getLogger(__name__)

    # Bounds of a Java ``int``. Python ints outside this range are bound with
    # setLong instead of setInt.
    _INT32_MIN = -(2 ** 31)
    _INT32_MAX = 2 ** 31 - 1

    def __init__(self, spark_session: "SparkSession", jdbc_config: "JdbcConfiguration"):
        """
        Initialize the JDBC client from a Spark session and JDBC configuration.

        :param spark_session: Session used for reads and for JVM access.
        :param jdbc_config: Supplies hostname, port, database, username,
            password and batch_size.
        """
        self.session = spark_session
        self.jdbc_config = jdbc_config
        self.jdbc_url = (
            f"jdbc:sqlserver://{self.jdbc_config.hostname}:{self.jdbc_config.port}"
            f";database={self.jdbc_config.database}"
        )
        self.properties = {
            "user": self.jdbc_config.username,
            "password": self.jdbc_config.password,
            "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        }

    def execute_query(self, statement: str) -> "DataFrame":
        """
        Run a SQL query through Spark's JDBC reader, without a transaction.

        :param statement: The SQL query to execute.
        :returns: The DataFrame containing the result of the query.
        """
        return (
            self.session.read
            .format("jdbc")
            .option("url", self.jdbc_url)
            .option("query", statement)
            .option("user", self.jdbc_config.username)
            .option("password", self.jdbc_config.password)
            .option("driver", self.properties["driver"])
            # "batchsize" only applies to writes; "fetchsize" is the read-side
            # rows-per-round-trip option.
            .option("fetchsize", self.jdbc_config.batch_size)
            .load()
        )

    def execute_upsert(self, statement: str, data: List[Any]) -> bool:
        """
        Execute one parameterised SQL statement inside a JDBC transaction.

        :param statement: The SQL insert/update statement with ``?`` placeholders.
        :param data: Values for a single row, in placeholder order.
        :returns: True once the statement has been executed and committed.
        :raises Exception: Any error from binding or executing is re-raised
            after a rollback attempt.
        :notes: Offers flexibility in running SQL statements, but not
            optimized for many records.
        :example:
            data = [1, "John", 30]
            statement = "INSERT INTO your_table_name (id, name, age) VALUES (?, ?, ?)"
        """
        connection = self._get_connection()
        prepared_statement = None
        try:
            prepared_statement = connection.prepareStatement(statement)
            self._map_parameters_to_data_types(prepared_statement, data)
            prepared_statement.executeUpdate()
            connection.commit()
            return True
        except Exception as error:
            # A failing rollback must not hide the error that triggered it.
            try:
                connection.rollback()
            except Exception:
                self._logger.exception("Rollback failed while handling error: %s", error)
            else:
                self._logger.error("Transaction rolled back due to error: %s", error)
            raise
        finally:
            # Close the connection even if closing the statement fails.
            try:
                if prepared_statement is not None:
                    prepared_statement.close()
            finally:
                connection.close()

    def write(self, df: "DataFrame", table_name: str, schema="dbo", mode: str = "Append") -> bool:
        """
        Write the dataframe to the specified table via ``df.write.jdbc``.

        The write uses the connections Spark's JDBC writer opens itself, so
        this method does not wrap it in a driver-side transaction.

        :param df: The dataframe to write to the table.
        :param table_name: The table name to write to. If it already contains
            a dot (``schema.tablename``) it is used as given.
        :param schema: Schema prefixed to an unqualified ``table_name``.
        :param mode: The save mode passed to Spark:
            Append: Appends the data to the existing data in the target table.
            Overwrite: Replaces the existing data in the target table.
            Ignore: Does nothing if the target table already exists.
            Error: Raises if the target table already exists.
        :returns: True when the write call returns without raising.
        """
        target = table_name if "." in table_name else f"{schema}.{table_name}"
        # Property values are passed as strings; batchsize carries the
        # configured batch size through to the writer.
        write_properties = dict(self.properties, batchsize=str(self.jdbc_config.batch_size))
        df.write.jdbc(url=self.jdbc_url, table=target, mode=mode, properties=write_properties)
        return True

    def _map_parameters_to_data_types(self, prepared_statement: Any, data: List[Any]):
        """
        Set parameters on a PreparedStatement based on each value's Python type.

        :param prepared_statement: java.sql.PreparedStatement
            The PreparedStatement object to set parameters on.
        :param data: list[any]
            Values for one row, bound to placeholders 1..len(data) in order.
        :return: None
        :raises ValueError: If prepared_statement is None.
        :raises TypeError: If a value's exact type has no mapping.
        """
        if prepared_statement is None:
            raise ValueError(NOT_PREPARED)

        type_map = {
            # Python floats are double precision; setFloat would truncate them.
            float: prepared_statement.setDouble,
            str: prepared_statement.setString,
            bool: prepared_statement.setBoolean,
            bytes: prepared_statement.setBytes,
            # Add more mappings as needed.
        }

        # PreparedStatement indexes start from 1.
        for position, value in enumerate(data, start=1):
            # Exact-type dispatch keeps bool (an int subclass) on setBoolean.
            data_type = type(value)
            if data_type is int:
                fits_int32 = self._INT32_MIN <= value <= self._INT32_MAX
                setter = prepared_statement.setInt if fits_int32 else prepared_statement.setLong
            elif data_type is datetime:
                value = value.strftime(SQL_FORMAT)
                setter = prepared_statement.setString
            elif data_type in type_map:
                setter = type_map[data_type]
            else:
                raise TypeError("No appropriate mapping found in type_map for passed in data type.")
            setter(position, value)

    def _get_connection(self):
        """
        Open a JDBC connection on the driver JVM with auto-commit disabled.

        :returns: The java.sql.Connection proxy; the caller must close it.
        """
        connection = self.session._jvm.java.sql.DriverManager.getConnection(  # type: ignore
            self.jdbc_url, self.properties["user"], self.properties["password"]
        )
        try:
            connection.setAutoCommit(False)
        except Exception:
            # Do not leak the connection if it cannot be configured.
            connection.close()
            raise
        return connection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class JdbcConfiguration:
    """Connection settings consumed by the JDBC client."""

    def __init__(self, hostname, port, database, username, password, batch_size: int = 1000):
        """
        Initialize JDBC configuration.

        Args:
            hostname (str): The hostname or IP address of the database server.
            port (str): The port number for the database server.
            database (str): The name of the database to connect to.
            username (str): The username for authentication.
            password (str): The password for authentication.
            batch_size (int): Rows per JDBC round trip. Defaults to 1000,
                the value that was previously hard-coded.

        Raises:
            ValueError: If batch_size is not a positive number.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        self.hostname = hostname
        self.port = port
        self.database = database
        self.username = username
        self.password = password
        self.batch_size = batch_size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment