Skip to content

Instantly share code, notes, and snippets.

@chicago-joe
Last active October 4, 2021 21:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chicago-joe/98c165a596c6eb12337a07073f8e9b4d to your computer and use it in GitHub Desktop.
Save chicago-joe/98c165a596c6eb12337a07073f8e9b4d to your computer and use it in GitHub Desktop.
MySQL connector class
#!/usr/bin/env python
# coding=utf-8
import os, logging, pendulum
import pandas as pd
from prefect.tasks.secrets import PrefectSecret
from prefect.tasks.aws import AWSSecretsManager
from pymysql import Connect, MySQLError
import time
# --------------------------------------------------------------------------------------------------
# mysql class
class MySQLController(object):
    """
    Python class for connecting to MySQL servers and querying, bulk-loading or deleting rows.

    Connection settings are fetched in ``__auth`` from the AWS Secrets Manager secret
    ``MYSQL_DB_AUTH`` and are looked up first by ``run_type`` and then by the upper-cased
    ``schema``. Every public method opens its own connection and closes it again before
    returning, including when the statement raises.
    """

    __instance = None
    __session = None      # cursor of the currently open connection, or None
    __connection = None   # currently open connection, or None
    __run_type = None
    __schema = None
    __config = None       # keyword arguments handed to pymysql ``Connect``

    def __init__(self, run_type = 'dev', schema = None):
        """
        Store the run type / schema and load the matching connection settings.

        :param run_type: key of the environment section inside the secret.
        :param schema: schema key inside the environment section (upper-cased before lookup).
                       When empty, a hint is printed and the controller is left unconfigured.
        """
        if not schema:
            print('\n PLEASE SELECT A SCHEMA!\nTry SRA or SRSE\n')
        else:
            self.__run_type = run_type
            self.__schema = schema.upper()
            self.__auth()
    ## End def __init__

    def __auth(self):
        """Fetch the connection settings for ``run_type``/``schema`` from AWS Secrets Manager."""
        cred = PrefectSecret("AWS_CREDENTIALS")
        auth = AWSSecretsManager(boto_kwargs = { "region_name":"us-east-2" }).run(secret = "MYSQL_DB_AUTH", credentials = cred.run())
        db_config = auth[self.__run_type]
        self.__config = db_config[self.__schema]
    ## End def __auth

    def __open(self):
        """
        Open a new connection and cursor using the stored settings.

        :raises MySQLError: logged and re-raised, so callers never continue with a
                            missing or stale connection.
        """
        try:
            cnx = Connect(**self.__config)
            self.__connection = cnx
            self.__session = cnx.cursor()
        except MySQLError as e:
            logging.error("MySQL connection error: %s", e)
            raise
    ## End def __open

    def __close(self):
        """Close the cursor and connection if they are open; safe to call when nothing is open."""
        session, connection = self.__session, self.__connection
        # Drop the references first so a closed handle can never be reused by a later call.
        self.__session = None
        self.__connection = None
        try:
            if session is not None:
                session.close()
        finally:
            if connection is not None:
                connection.close()
    ## End def __close

    def __check_table_locks(self, db='sradb', table=None, maxRetries=5, waitTime=10):
        """
        Poll until a table is available, retrying up to ``maxRetries`` times.

        Tables whose name starts with ``beta_sheet`` only need to exist in
        ``information_schema.tables``; every other table must be listed by
        ``SHOW OPEN TABLES ... WHERE In_use = 0``.

        example usage:
            if self.__check_table_locks(db='sradb', table='testtable'):
                do_other_stuff()

        :param db: schema name (lower-cased before use).
        :param table: table name (lower-cased before use).
        :param maxRetries: number of re-checks after the first attempt.
        :param waitTime: seconds to sleep between attempts.
        :return: True when the table is available, False when the retries are exhausted,
                 None when no table name was given.
        """
        if not table:
            print('Please choose a table name')
            return
        db = db.lower()
        table = table.lower()
        attempts = maxRetries + 1  # the first check plus maxRetries re-checks

        # for beta sheets, just check if table exists
        existence_only = table.startswith('beta_sheet')
        if existence_only:
            query = """
                SELECT COUNT(*) as existingTbl
                FROM information_schema.tables
                WHERE table_schema = '%s'
                AND table_name = '%s'
                ;
            """ % (db, table)
        ## for all other tables, check In_use status
        else:
            query = """SHOW OPEN TABLES in %s WHERE In_use = 0""" % db

        self.__open()
        try:
            for i in range(attempts):
                frame = pd.read_sql(query, self.__connection)
                if existence_only:
                    available = int(frame['existingTbl'].iloc[0]) == 1
                else:
                    available = table in frame['Table'].values
                if available:
                    return True
                if i == attempts - 1:
                    logging.error('===== MAXIMUM RETRIES EXCEEDED =====')
                    logging.error('FAILURE CONNECTING TO: %s.%s' % (db, table))
                    break
                logging.warning('Missing table %s.%s. Waiting %s seconds to reattempt %d of %s.' %
                                (db, table, waitTime, i+1, attempts-1))
                time.sleep(waitTime)
            # Retries exhausted: report failure instead of claiming the table is ready.
            return False
        finally:
            self.__close()
    ## end def check_table_locks

    def query(self, select="*", db=None, table=None, where=None, *args, **kwargs):
        """
        Run ``SELECT <select> FROM <db>.<table> WHERE 1=1 [AND <constraint> ...]``.

        :param select: comma-separated string or list/tuple of column expressions.
        :param db: schema name.
        :param table: table name.
        :param where: one SQL condition string, or a list/tuple of them (AND-ed together).
        :param args: unused, accepted for backward compatibility.
        :param kwargs: unused, accepted for backward compatibility.
        :return: pandas DataFrame with the result set.

        NOTE: all pieces are interpolated into the SQL text as-is; do not pass untrusted input.
        """
        if isinstance(select, str):
            columns = [val.strip(" ") for val in select.split(",")]
        else:
            columns = list(select)
        strSELECT = "SELECT " + ", ".join(columns)
        strFROM = f" FROM {db}.{table} "

        if not where:
            constraints = []
        elif isinstance(where, (list, tuple)):
            constraints = list(where)
        else:
            constraints = [where]
        strWHERE = " " + "".join(f" AND {constraint}" for constraint in constraints)

        query = f"""
            {strSELECT}
            {strFROM}
            WHERE 1=1
            {strWHERE}
        """
        self.__open()
        try:
            return pd.read_sql_query(query, self.__connection)
        finally:
            self.__close()
    ## End def simple_query

    def advanced_query(self, query = None, *args, **kwargs):
        """
        Run an arbitrary SQL statement and return its result set.

        :param query: complete SQL text.
        :param args: unused, accepted for backward compatibility.
        :param kwargs: unused, accepted for backward compatibility.
        :return: pandas DataFrame with the result set.
        """
        self.__open()
        try:
            return pd.read_sql_query(query, self.__connection)
        finally:
            self.__close()
    ## End def advanced_query

    def upload(self, df = None, db = None, table = None, mode = 'REPLACE', colNames = None, unlinkFile = True):
        """
        Bulk-load a DataFrame into ``db.table`` through a temp file and ``LOAD DATA LOCAL INFILE``.

        :param df: DataFrame to upload.
        :param db: target schema name.
        :param table: target table name.
        :param mode: duplicate-key keyword placed after the file name in the statement.
        :param colNames: optional subset of dataframe columns to load. Used only when every
                         listed column exists in the table; otherwise a warning is logged and
                         the full-table column layout is used instead.
        :param unlinkFile: delete the temp file after a successful load.
        :return: None
        """
        # project-local helpers, imported lazily as the original code did
        from sraPy.common import setLogging
        from sraPy.common import setOutputFilePath

        dttm = pendulum.now().to_datetime_string().replace('-', '_').replace(':', '_').replace(' ', '-')
        setLogging()
        tmpFile = setOutputFilePath(OUTPUT_SUBDIRECTORY = "upload", OUTPUT_FILE_NAME = f"{table} {dttm}-{os.getpid()}.txt")
        logging.info(f"Creating temp file: {tmpFile}")

        # The connection has to be open before the table layout can be read.
        self.__open()
        try:
            colsSQL = pd.read_sql(f"""SELECT * FROM {db}.{table} LIMIT 0;""", self.__connection).columns.tolist()

            colsUpload = None
            colClause = ""
            if colNames:
                # check columns in db table vs dataframe
                colsDF = df[colNames].columns.tolist()
                if set(colsSQL).symmetric_difference(colsDF):
                    logging.warning(f'----- COLUMN MISMATCH WHEN ATTEMPTING UPLOAD TO {table} -----')
                unknown = set(colsDF) - set(colsSQL)
                if unknown:
                    logging.warning(f'Columns in dataframe not found in {db}.{table}: \n{list(unknown)}')
                else:
                    colsUpload = colsDF
                    # An explicit, comma-separated column list; a Python list repr is not valid SQL.
                    colClause = " (%s)" % ", ".join("`%s`" % col for col in colsDF)

            if colsUpload is None:
                # check columns in db table vs dataframe
                colsDF = df.columns.tolist()
                if set(colsSQL).symmetric_difference(colsDF):
                    logging.warning('----- COLUMN MISMATCH WHEN ATTEMPTING TO UPLOAD %s -----' % table)
                missing = set(colsSQL) - set(colsDF)
                if missing:
                    logging.warning('Columns in %s.%s not found in dataframe: %s' % (db, table, list(missing)))
                unknown = set(colsDF) - set(colsSQL)
                if unknown:
                    logging.warning('Columns in dataframe not found in %s.%s: %s' % (db, table, list(unknown)))
                colsUpload = colsSQL

            df[colsUpload].to_csv(tmpFile, sep = "\t", na_rep = "\\N", float_format = "%.8g", header = True, index = False, doublequote = False)
            # to_csv ends lines with os.linesep by default, so the statement must name the same
            # terminator ('\r\n' on Windows, '\n' elsewhere) or rows are not split correctly.
            query = """LOAD DATA LOCAL INFILE '%s' %s INTO TABLE %s.%s LINES TERMINATED BY '%s' IGNORE 1 LINES%s""" % \
                    (tmpFile.replace('\\', '/'), mode, db, table, os.linesep, colClause)
            logging.debug(query)
            rv = self.__session.execute(query)
            # Commit explicitly (as delete() does) so the load survives closing the connection.
            self.__connection.commit()
        finally:
            self.__close()

        logging.info("Number of rows affected: %s" % rv)
        if unlinkFile:
            os.unlink(tmpFile)
            logging.info("Deleting temporary file: {}".format(tmpFile))
        logging.info("DONE")
        return
    ## End def upload

    def delete(self, db, table, where = None, *args):
        """
        Delete rows from ``db.table`` and commit.

        :param db: schema name.
        :param table: table name.
        :param where: optional SQL condition; may contain ``%s`` placeholders.
                      When omitted, every row of the table is deleted.
        :param args: values bound to the placeholders in ``where``.
        :return: number of rows deleted.
        """
        query = f"DELETE FROM {db}.{table}"
        if where:
            query += ' WHERE %s' % where
        values = tuple(args)
        self.__open()
        try:
            self.__session.execute(query, values)
            self.__connection.commit()
            # Obtain rows affected
            return self.__session.rowcount
        finally:
            self.__close()
    ## End def delete
## End class
# --------------------------------------------------------------------------------------------------
# sradb controller class
# noinspection SqlNoDataSourceInspection
class sraController(MySQLController):
    """MySQLController bound to the ``SRA`` schema, with helpers that read ``sradb`` tables."""

    def __init__(self, run_type = 'dev'):
        """
        :param run_type: environment key forwarded to MySQLController.
        """
        super().__init__(run_type, schema = 'SRA')

    def get_config(self, configType = 'account'):
        """
        Return every row of ``sradb.msgSRA_<configType>config``.

        :param configType: middle part of the config table name.
        :return: pandas DataFrame with the table contents.
        """
        db = "sradb"
        msgName = f"msgSRA_{configType}config"
        query = f"""
            SELECT *
            FROM
            {db}.{msgName}
        """
        self._MySQLController__open()
        try:
            return pd.read_sql_query(query, self._MySQLController__connection)
        finally:
            self._MySQLController__close()

    def get_risk_table(self, table = 'beta_sheet_strategy_risk', days = None, accnts = None, tickers = None, noTest = True):
        """
        Read rows from one of the risk tables in ``sradb``, with optional filters.

        Table names listed by the original author:
            'beta_sheet_strategy_risk'
            'beta_sheet_daily_strategy_risk'
            'beta_sheet_daily_strategy_risk_bf'
            'beta_sheet_intra_strategy_risk'

        :param table: table name inside ``sradb``.
        :param days: keep rows whose DATE is within the last ``days`` days.
        :param accnts: account or accounts to keep (string or iterable of strings).
        :param tickers: ticker or tickers to keep (string or iterable of strings).
        :param noTest: drop rows whose ``accnt`` starts with 'T'.
        :return: pandas DataFrame, or None when the table is unavailable or the result is empty.

        NOTE: filter values are interpolated into the SQL text as-is; do not pass untrusted input.
        """
        db = "sradb"
        msgName = table

        # A bare string would otherwise be joined character by character.
        if isinstance(accnts, str):
            accnts = [accnts]
        if isinstance(tickers, str):
            tickers = [tickers]

        filters = []
        if noTest:
            filters.append(" AND substr(accnt,1,1) != 'T'")
        if days:
            filters.append(" AND DATE >= DATE_SUB(CURDATE(), INTERVAL %s DAY)" % days)
        if accnts:
            filters.append(" AND accnt IN ('%s')" % "', '".join(accnts))
        if tickers:
            filters.append(" AND ticker_tk IN ('%s')" % "', '".join(tickers))
        strWHERE = ' ' + ''.join(filters)

        table = "%s.%s" % (db, msgName)

        # Check availability before opening, so no connection is left open on failure.
        if not self._MySQLController__check_table_locks(db='sradb', table=msgName):
            return None

        query = f"""
            SELECT *
            FROM {table}
            WHERE 1=1 {strWHERE}
            ;
        """
        self._MySQLController__open()
        try:
            result = pd.read_sql_query(query, self._MySQLController__connection)
            if result.empty:
                # The private attribute must be addressed through the base-class mangled name.
                print("%s %s: " % (self._MySQLController__connection, table))
                print("SQL SELECT ERROR")
                print("statement: %s" % query)
                print("result: ")
                print(result)
                print("")
                return None
            return result
        finally:
            self._MySQLController__close()
    # end load_risk_table def
# end class def
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment