Skip to content

Instantly share code, notes, and snippets.

@pippolino
Last active April 17, 2024 21:43
Show Gist options
  • Save pippolino/de7ef74afe0ade6d0f4682664637038a to your computer and use it in GitHub Desktop.
Redshift Data API Query Executor
import time
import boto3
import pandas as pd
class RedshiftQueryExecutor:
    """Execute SQL on an Amazon Redshift cluster via the Redshift Data API
    and return the result set as a pandas DataFrame.

    The ``redshift-data`` API is asynchronous: a statement is submitted,
    polled until it reaches a terminal status, and its result pages are
    then streamed and parsed into Python values.
    """

    # Seconds to wait between describe_statement polls.
    _POLL_INTERVAL_SECONDS = 1

    def __init__(self, cluster_id, database, user, region):
        """Bind a Data API client to one cluster/database/user.

        :param cluster_id: Redshift cluster identifier.
        :param database: database name to run statements against.
        :param user: database user passed as ``DbUser`` (temporary credentials).
        :param region: AWS region hosting the cluster.
        """
        self.client = boto3.client('redshift-data', region_name=region)
        self.cluster_id = cluster_id
        self.database = database
        self.user = user

    def execute_query(self, sql_query, timeout_seconds=300):
        """Run *sql_query* and block until it completes.

        :param sql_query: SQL text to execute.
        :param timeout_seconds: maximum seconds to wait for completion.
        :returns: ``pandas.DataFrame`` with one column per result column.
        :raises TimeoutError: if the query does not finish within the timeout.
        :raises RuntimeError: if the query fails or is aborted.
        """
        response = self.client.execute_statement(
            ClusterIdentifier=self.cluster_id,
            Database=self.database,
            DbUser=self.user,
            Sql=sql_query,
        )
        return self.__wait_for_query_completion(response['Id'], timeout_seconds)

    def __wait_for_query_completion(self, query_id, timeout_seconds):
        """Poll ``describe_statement`` until the query reaches a terminal state.

        Uses a monotonic deadline so system clock adjustments cannot
        lengthen or shorten the effective timeout.
        """
        deadline = time.monotonic() + timeout_seconds
        while True:
            if time.monotonic() > deadline:
                raise TimeoutError("Query execution exceeded the timeout limit.")
            status_response = self.client.describe_statement(Id=query_id)
            status = status_response['Status']
            if status in ('FINISHED', 'FAILED', 'ABORTED'):
                return self.__handle_query_status(status, query_id, status_response)
            time.sleep(self._POLL_INTERVAL_SECONDS)

    def __handle_query_status(self, status, query_id, status_response):
        """Map a terminal status to a result DataFrame or a raised error."""
        if status == 'FINISHED':
            return self.__fetch_all_results(query_id)
        if status == 'FAILED':
            # Surface Redshift's own error text when it is available.
            raise RuntimeError(
                f"Query failed: {status_response.get('ErrorMessage', 'No error message provided')}")
        raise RuntimeError("Query was aborted")

    def __fetch_all_results(self, query_id):
        """Stream every ``get_statement_result`` page into one DataFrame.

        Records are yielded lazily as dicts keyed by column name so pandas
        builds the frame without an intermediate list of all rows.
        """
        column_metadata = None

        def result_generator():
            nonlocal column_metadata
            next_token = None
            while True:
                kwargs = {'Id': query_id}
                if next_token:
                    kwargs['NextToken'] = next_token
                result_response = self.client.get_statement_result(**kwargs)
                # Column metadata is identical on every page; capture it once.
                if column_metadata is None:
                    column_metadata = {col['name']: col['typeName']
                                       for col in result_response['ColumnMetadata']}
                for record in result_response['Records']:
                    yield {col: RedshiftQueryExecutor.__parse_field_value(field, column_metadata[col])
                           for col, field in zip(column_metadata, record)}
                next_token = result_response.get('NextToken')
                if not next_token:
                    break

        return pd.DataFrame(result_generator())

    @staticmethod
    def __parse_field_value(field, col_type):
        """Extract the Python value from one Data API ``Field`` dict.

        :param field: dict holding one typed value key (e.g. ``'longValue'``)
            and optionally an ``'isNull'`` flag.
        :param col_type: Redshift type name from ``ColumnMetadata``; not used
            by the conversion itself — kept for type-aware parsing hooks.
        :returns: the converted value, or ``None`` for SQL NULL / empty fields.
        """
        if field.get('isNull'):
            return None
        for data_type, value in field.items():
            if data_type != 'isNull':
                return RedshiftQueryExecutor.__convert_to_type(value, data_type)
        return None

    @staticmethod
    def __convert_to_type(value, data_type):
        """Coerce a raw API value to the Python type implied by its key.

        Unknown keys pass the value through unchanged.
        """
        converters = {
            'stringValue': str,
            'booleanValue': bool,
            'doubleValue': float,
            'longValue': int,
            'blobValue': bytes,
        }
        converter = converters.get(data_type)
        return converter(value) if converter else value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment