-
-
Save chicago-joe/98c165a596c6eb12337a07073f8e9b4d to your computer and use it in GitHub Desktop.
MySQL connector class: connects to a MySQL server using credentials from AWS Secrets Manager (via Prefect) and provides query, bulk-upload, and delete helpers that return pandas DataFrames.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding=utf-8 | |
import os, logging, pendulum | |
import pandas as pd | |
from prefect.tasks.secrets import PrefectSecret | |
from prefect.tasks.aws import AWSSecretsManager | |
from pymysql import Connect, MySQLError | |
import time | |
# -------------------------------------------------------------------------------------------------- | |
# mysql class | |
class MySQLController(object):
    """
    Connect to a MySQL server and run SELECT / bulk-upload / DELETE
    operations, returning query results as pandas DataFrames.

    Connection credentials are fetched from AWS Secrets Manager (secret
    "MYSQL_DB_AUTH") via Prefect secrets; the payload is keyed first by
    `run_type` (e.g. 'dev') and then by `schema` (e.g. 'SRA', 'SRSE').
    A fresh connection is opened for each operation and closed afterwards.
    """
    __instance = None      # reserved for singleton-style reuse (unused here)
    __session = None       # live pymysql cursor, managed by __open()/__close()
    __connection = None    # live pymysql connection, managed by __open()/__close()
    __run_type = None      # credential-set selector, e.g. 'dev'
    __schema = None        # upper-cased schema key within the credential set
    __config = None        # kwargs forwarded verbatim to pymysql.Connect

    def __init__(self, run_type = 'dev', schema = None):
        """Store run_type/schema and fetch credentials; warn when no schema is given."""
        if not schema:
            print('\n PLEASE SELECT A SCHEMA!\nTry SRA or SRSE\n')
        else:
            self.__run_type = run_type
            self.__schema = schema.upper()
            self.__auth()
    ## End def __init__

    def __auth(self):
        """Resolve pymysql connection kwargs for this run_type/schema from AWS."""
        cred = PrefectSecret("AWS_CREDENTIALS")
        auth = AWSSecretsManager(boto_kwargs = { "region_name": "us-east-2" }).run(
            secret = "MYSQL_DB_AUTH", credentials = cred.run())
        self.__config = auth[self.__run_type][self.__schema]
    ## End def __auth

    def __open(self):
        """Open a new connection and cursor; MySQL errors are printed, not raised."""
        try:
            cnx = Connect(**self.__config)
            self.__connection = cnx
            self.__session = cnx.cursor()
        except MySQLError as e:
            print("Error %d: %s" % (e.args[0], e.args[1]))
    ## End def __open

    def __close(self):
        """Close the cursor, then the connection."""
        self.__session.close()
        self.__connection.close()
    ## End def __close

    def __check_table_locks(self, db='sradb', table=None, maxRetries=5, waitTime=10):
        """
        Return True when `db`.`table` is usable, False otherwise.

        'beta_sheet*' tables are only checked for existence; all other tables
        are polled via SHOW OPEN TABLES until they report In_use = 0.  Up to
        `maxRetries` reattempts are made, sleeping `waitTime` seconds between
        attempts.

        Fixes vs original: the failure path used to fall through to
        `return True`, and the success path leaked the open connection; the
        method now returns False on failure and always closes the connection.
        """
        if not table:
            print('Please choose a table name')
            return False
        db = db.lower()
        table = table.lower()
        self.__open()
        try:
            if table.startswith('beta_sheet'):
                # For beta sheets, availability just means the table exists.
                query = """
                SELECT COUNT(*) as existingTbl
                FROM information_schema.tables
                WHERE table_schema = '%s'
                AND table_name = '%s'
                ;
                """ % (db, table)

                def table_ready():
                    # COUNT(*) is 1 exactly when the table exists
                    return pd.read_sql(query, self.__connection)['existingTbl'].values == 1
            else:
                ## For all other tables, check the In_use status.
                query = """SHOW OPEN TABLES in %s WHERE In_use = 0""" % db

                def table_ready():
                    # table appears in the unlocked (In_use = 0) list
                    return table in pd.read_sql(query, self.__connection)['Table'].values

            for attempt in range(maxRetries + 1):
                if table_ready():
                    return True
                if attempt == maxRetries:
                    logging.error('===== MAXIMUM RETRIES EXCEEDED =====')
                    logging.error('FAILURE CONNECTING TO: %s.%s' % (db, table))
                    return False
                logging.warning('Missing table %s.%s. Waiting %s seconds to reattempt %d of %s.' %
                                (db, table, waitTime, attempt + 1, maxRetries))
                time.sleep(waitTime)
        finally:
            self.__close()
    ## End def __check_table_locks

    def query(self, select="*", db=None, table=None, where=None, *args, **kwargs):
        """
        Run a simple SELECT against `db`.`table` and return a DataFrame.

        select : comma-separated string or list of column names
        where  : raw SQL constraint string, or a list of them, ANDed together

        NOTE(review): column names, table names and constraints are
        interpolated directly into the SQL text -- never pass untrusted input.
        """
        if not isinstance(select, list):
            select = [col.strip(" ") for col in select.split(",")]
        strSELECT = "SELECT " + ", ".join(select)
        strFROM = f" FROM {db}.{table} "
        strWHERE = " "
        if where:
            if not isinstance(where, list):
                where = [where]
            for constraint in tuple(where):
                strWHERE += f" AND {constraint}"
        # WHERE 1=1 lets every constraint be appended uniformly with AND.
        query = f"""
        {strSELECT}
        {strFROM}
        WHERE 1=1
        {strWHERE}
        """
        self.__open()
        try:
            return pd.read_sql_query(query, self.__connection)
        finally:
            self.__close()
    ## End def query

    def advanced_query(self, query = None, *args, **kwargs):
        """Run an arbitrary SQL statement and return the result as a DataFrame."""
        self.__open()
        try:
            return pd.read_sql_query(query, self.__connection)
        finally:
            self.__close()
    ## End def advanced_query

    def upload(self, df = None, db = None, table = None, mode = 'REPLACE', colNames = None, unlinkFile = True):
        """
        Bulk-load `df` into `db`.`table` with LOAD DATA LOCAL INFILE.

        df         : DataFrame to upload
        mode       : 'REPLACE' or 'IGNORE' duplicate-key handling
        colNames   : optional explicit subset of df columns to load; on a
                     column mismatch with the target table the upload aborts
        unlinkFile : delete the temporary tab-separated file afterwards

        Fixes vs original: the target table's columns are now read AFTER a
        connection is opened (previously queried on a closed/absent
        connection); with colNames the LOAD DATA column list is a proper
        comma-joined list instead of the repr of a Python list; the temp file
        is now removed on every path when unlinkFile is True.
        """
        dttm = pendulum.now().to_datetime_string().replace('-', '_').replace(':', '_').replace(' ', '-')
        from sraPy.common import setLogging
        from sraPy.common import setOutputFilePath
        setLogging()
        tmpFile = setOutputFilePath(OUTPUT_SUBDIRECTORY = "upload",
                                    OUTPUT_FILE_NAME = f"{table} {dttm}-{os.getpid()}.txt")
        logging.info(f"Creating temp file: {tmpFile}")
        self.__open()
        try:
            # Empty SELECT just to discover the target table's column set.
            colsSQL = pd.read_sql(f"""SELECT * FROM {db}.{table} LIMIT 0;""", self.__connection).columns.tolist()
            colsDF = df[colNames].columns.tolist() if colNames else df.columns.tolist()
            extraDF = set(colsDF) - set(colsSQL)
            extraSQL = set(colsSQL) - set(colsDF)
            if extraDF or extraSQL:
                logging.warning(f'----- COLUMN MISMATCH WHEN ATTEMPTING UPLOAD TO {table} -----')
                if extraSQL:
                    logging.warning('Columns in %s.%s not found in dataframe: %s' % (db, table, list(extraSQL)))
                if extraDF:
                    logging.warning('Columns in dataframe not found in %s.%s: %s' % (db, table, list(extraDF)))
                if colNames:
                    # With an explicit column list a mismatch aborts the upload
                    # (original behavior); finally-block still cleans up.
                    return
            loadCols = colsDF if colNames else colsSQL
            df[loadCols].to_csv(tmpFile, sep = "\t", na_rep = "\\N", float_format = "%.8g",
                                header = True, index = False, doublequote = False)
            query = """LOAD DATA LOCAL INFILE '%s' %s INTO TABLE %s.%s LINES TERMINATED BY '\r\n' IGNORE 1 LINES""" % \
                    (tmpFile.replace('\\', '/'), mode, db, table)
            if colNames:
                # Explicit column list: proper SQL identifier list, not list repr.
                query += " (%s)" % ", ".join(loadCols)
            logging.debug(query)
            self.__session.execute(query)
            # NOTE(review): original never commits here; LOAD DATA persistence
            # presumably relies on the server's autocommit setting -- confirm.
            logging.info("Number of rows affected: %s" % len(df))
        finally:
            self.__close()
            if unlinkFile and os.path.exists(tmpFile):
                os.unlink(tmpFile)
                logging.info("Deleting temporary file: {}".format(tmpFile))
        logging.info("DONE")
        return
    ## End def upload

    def delete(self, db, table, where = None, *args):
        """
        Delete rows from `db`.`table` and return the number of rows removed.

        where : raw SQL condition, optionally containing pymysql placeholders
                that are filled from *args (parameterized execution)
        """
        query = f"DELETE FROM {db}.{table}"
        if where:
            query += ' WHERE %s' % where
        self.__open()
        try:
            self.__session.execute(query, tuple(args))
            self.__connection.commit()
            # rowcount must be read before the cursor is closed
            return self.__session.rowcount
        finally:
            self.__close()
    ## End def delete
## End class
# -------------------------------------------------------------------------------------------------- | |
# sradb controller class | |
# noinspection SqlNoDataSourceInspection | |
class sraController(MySQLController):
    """
    Convenience controller permanently bound to the 'SRA' schema of sradb.
    """

    def __init__(self, run_type = 'dev'):
        # Fix vs original: the parent __init__ was chained with an extra
        # `.__init__()` call on its None return value; a single super() call
        # is all that is needed.
        super().__init__(run_type, schema = 'SRA')
    ## End def __init__

    def get_config(self, configType = 'account'):
        """Return the sradb msgSRA_<configType>config table as a DataFrame."""
        db = "sradb"
        msgName = f"msgSRA_{configType}config"
        query = f"""
        SELECT *
        FROM
            {db}.{msgName}
        """
        # Parent attributes are name-mangled, hence the explicit
        # _MySQLController prefix when reaching the private helpers.
        self._MySQLController__open()
        try:
            return pd.read_sql_query(query, self._MySQLController__connection)
        finally:
            self._MySQLController__close()
    ## End def get_config

    def get_risk_table(self, table = 'beta_sheet_strategy_risk', days = None, accnts = None, tickers = None, noTest = True):
        """
        Load one of the risk tables, optionally filtered.

        table   : one of
                  'beta_sheet_strategy_risk'
                  'beta_sheet_daily_strategy_risk'
                  'beta_sheet_daily_strategy_risk_bf'
                  'beta_sheet_intra_strategy_risk'
        days    : keep rows with DATE within the last `days` days
        accnts  : keep only these account codes
        tickers : keep only these ticker_tk values
        noTest  : drop test accounts (accnt starting with 'T')

        Returns a DataFrame, or None when the table is unavailable or the
        query comes back empty.
        """
        db = "sradb"
        strWHERE = ' '
        if noTest:
            strWHERE += " AND substr(accnt,1,1) != 'T'"
        if days:
            strWHERE += " AND DATE >= DATE_SUB(CURDATE(), INTERVAL %s DAY)" % days
        if accnts:
            strWHERE += " AND accnt IN ('%s')" % "', '".join(accnts)
        if tickers:
            strWHERE += " AND ticker_tk IN ('%s')" % "', '".join(tickers)
        fullName = "%s.%s" % (db, table)
        # Fix vs original: the connection was opened before the lock check and
        # leaked whenever the check failed; open only after a successful check.
        if not self._MySQLController__check_table_locks(db = db, table = table):
            return
        self._MySQLController__open()
        try:
            query = f"""
            SELECT *
            FROM {fullName}
            WHERE 1=1 {strWHERE}
            ;
            """
            result = pd.read_sql_query(query, self._MySQLController__connection)
            if pd.isnull(len(result)) or len(result) == 0:
                # Fix vs original: `self.__connection` here mangled to the
                # non-existent _sraController__connection and raised
                # AttributeError on this error path.
                print("%s %s: " % (self._MySQLController__connection, fullName))
                print("SQL SELECT ERROR")
                print("statement: %s" % query)
                print("result: ")
                print(result)
                print("")
                return
            return result
        finally:
            self._MySQLController__close()
    ## End def get_risk_table
## End class
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.