Skip to content

Instantly share code, notes, and snippets.

@emmanuelnk
Last active May 9, 2024 07:46
Show Gist options
  • Save emmanuelnk/db62507184125ddfe24844bb552fc26d to your computer and use it in GitHub Desktop.
Save emmanuelnk/db62507184125ddfe24844bb552fc26d to your computer and use it in GitHub Desktop.
Python SQLAlchemy: basic model, session, and DB connection classes
from sqlalchemy import event
import os
import logging
import sqlalchemy
import boto3
import base64
import json
from botocore.exceptions import ClientError
# Raise the root logger's threshold to INFO at import time; `logger` is that
# same root logger object.
# NOTE(review): configuring the root logger from library code affects the
# whole process — confirm this is intended.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()
class DB:
    """Process-wide holder of a single SQLAlchemy engine for PostgreSQL.

    Obtain the instance with ``DB.create()``; constructing ``DB()`` a second
    time raises. Credentials come from ``POSTGRESQL_*`` environment
    variables, or from AWS Secrets Manager when ``SECRETSMANAGER_RDS_PG_ID``
    is set.
    """

    __instance = None

    def __init__(self):
        """Virtually private constructor; use ``DB.create()`` instead.

        Raises:
            Exception: if an instance has already been created.
        """
        if DB.__instance is not None:
            raise Exception(
                "This class is a singleton, use DB.create()")
        # Build the engine first and only then register the singleton, so a
        # failure while fetching credentials or creating the engine does not
        # leave behind an instance without ``engine`` that every later
        # DB.create() call would hand out.
        self.engine = self.create_engine()
        DB.__instance = self

    @staticmethod
    def create():
        """Return the shared DB instance, creating it on first use."""
        if DB.__instance is None:
            DB.__instance = DB()
        return DB.__instance

    @staticmethod
    def get_secret(secret_name):
        """Fetch and JSON-decode a secret from AWS Secrets Manager.

        Args:
            secret_name: Secret name or ARN, passed as ``SecretId``.

        Returns:
            The decoded JSON payload, taken from ``SecretString``, or from the
            base64-decoded ``SecretBinary`` when no string is present.

        Raises:
            botocore.exceptions.ClientError: for every Secrets Manager error.
        """
        client = boto3.client('secretsmanager')
        # Let every ClientError propagate. Swallowing unrecognised error codes
        # (e.g. AccessDeniedException) would hide the real failure and make
        # the caller crash later with a confusing error.
        response = client.get_secret_value(SecretId=secret_name)
        if 'SecretString' in response:
            secret = response['SecretString']
        else:
            secret = base64.b64decode(response['SecretBinary'])
        return json.loads(secret)

    @staticmethod
    def get_credentials():
        """Return connection settings as a dict.

        Uses AWS Secrets Manager when ``SECRETSMANAGER_RDS_PG_ID`` is set.
        Otherwise it reads ``POSTGRESQL_*`` environment variables, falling
        back to local-development defaults (e.g. for testing).

        Returns:
            dict with keys ``username``, ``password``, ``host``, ``port`` and
            ``database``.
        """
        secret_id = os.getenv('SECRETSMANAGER_RDS_PG_ID')
        if secret_id is None:
            return {
                'username': os.getenv('POSTGRESQL_USER', 'postgres'),
                'password': os.getenv('POSTGRESQL_PASSWORD', 'some_password'),
                'host': os.getenv('POSTGRESQL_HOST', 'localhost'),
                'port': os.getenv('POSTGRESQL_PORT', 5432),
                'database': os.getenv('POSTGRESQL_DATABASE', 'user_database'),
            }
        # The secret stores the database name under 'dbname'.
        credentials = DB.get_secret(secret_id)
        return {
            'username': credentials['username'],
            'password': credentials['password'],
            'host': credentials['host'],
            'port': credentials['port'],
            'database': credentials['dbname'],
        }

    @staticmethod
    def _build_url(credentials, driver='postgresql+psycopg2'):
        """Build a SQLAlchemy database URL from a credentials dict.

        Username and password are percent-encoded, so characters such as
        ``@``, ``:`` or ``/`` inside them cannot corrupt the URL.
        """
        from urllib.parse import quote

        return '{engine}://{user}:{password}@{host}:{port}/{database}'.format(
            engine=driver,
            user=quote(str(credentials['username']), safe=''),
            password=quote(str(credentials['password']), safe=''),
            host=credentials['host'],
            port=int(credentials['port']),
            database=credentials['database'],
        )

    @staticmethod
    def _env_flag(name, default=False):
        """Interpret environment variable ``name`` as a boolean.

        ``1``, ``true``, ``yes`` and ``on`` (case-insensitive) are true and
        any other value is false. ``default`` is returned when the variable is
        unset.
        """
        value = os.getenv(name)
        if value is None:
            return default
        return value.strip().lower() in ('1', 'true', 'yes', 'on')

    def create_engine(self):
        """Create the pooled SQLAlchemy engine for the configured database.

        Uses the ``postgresql+psycopg2`` dialect name; the older ``postgres``
        alias is rejected by SQLAlchemy 1.4+. Set ``POSTGRESQL_DEBUG`` to a
        true-ish value (``1``, ``true``, ...) to echo SQL.
        """
        credentials = DB.get_credentials()
        return sqlalchemy.create_engine(
            DB._build_url(credentials),
            pool_size=200,
            # No connections are opened beyond pool_size.
            max_overflow=0,
            # bool() of any non-empty string (even "false" or "0") is True,
            # so the variable is parsed explicitly.
            echo=DB._env_flag('POSTGRESQL_DEBUG'),
        )

    def connect(self):
        """Open and return a new connection from ``self.engine``."""
        return self.engine.connect()
from dataclasses import dataclass
from sqlalchemy import Column, Integer, String, DateTime, UniqueConstraint, func
from sqlalchemy.ext.declarative import declarative_base
from .db import DB
# The shared DB singleton, plus the declarative base that every model in
# this module registers its table with; `engine` is the singleton's engine.
db, Base = DB.create(), declarative_base()
engine = db.engine
# ORM model for the `user` table.
# @dataclass is stacked on the declarative class so that rows can be turned
# into dicts with dataclasses.asdict(). Only *annotated* attributes become
# dataclass fields, so asdict() yields first_name, last_name, id_card_no and
# email; id, created_at and updated_at are left out.
# NOTE(review): @dataclass may also add __init__/__eq__/__hash__ on top of
# SQLAlchemy's own instrumentation — confirm the combination is intended.
@dataclass
class User(Base):
    __tablename__ = 'user'
    # Composite unique constraint: each (email, id_card_no) *pair* may be
    # stored only once. It does NOT stop one email from appearing with
    # several id cards, or one id card with several emails.
    # NOTE(review): the constraint is unnamed; upserts that target a
    # constraint by name would need one — confirm.
    __table_args__ = (UniqueConstraint('email', 'id_card_no'),)
    # Integer primary key (unannotated, so not a dataclass field).
    id = Column(Integer, primary_key=True)
    first_name: str = Column(String)
    last_name: str = Column(String)
    id_card_no: str = Column(String)
    email: str = Column(String)
    # Timezone-aware timestamps defaulting to SQL now(); updated_at is also
    # set to now() whenever SQLAlchemy issues an UPDATE for the row.
    created_at = Column(DateTime(timezone=True), default=func.now())
    updated_at = Column(DateTime(timezone=True),
                        default=func.now(), onupdate=func.now())
# Emit CREATE TABLE for every model registered on Base whose table does not
# exist yet (checkfirst=True skips existing tables). Existing tables are
# never altered: after changing a model, drop its table (or use a migration
# tool) so that it is recreated with the new definition.
Base.metadata.create_all(bind=engine, checkfirst=True)
from sqlalchemy.orm import sessionmaker
from .session import SessionHandler
from .models import User
from .db import DB
# Example: insert one User row inside a single transaction.
db = DB.create()
engine = db.engine
Session = sessionmaker(bind=engine)
session = Session()
try:
    user_session = SessionHandler.create(session, User)
    # Stage a new User row; session.commit() below persists it.
    user_session.add({
        "first_name": "john",
        "last_name": "doe",
        "id_card_no": "1234598765",
        "email": "something@business.com"
    })
    session.commit()
except Exception:
    # Undo the failed transaction, then re-raise the original exception
    # unchanged (bare `raise`, instead of `raise e`, adds no extra frame).
    session.rollback()
    raise
finally:
    # Always close the session, whether the commit succeeded or not.
    session.close()
import json
import datetime
import time
from dataclasses import asdict
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy import UniqueConstraint
from . import DB
class SchemaEncoder(json.JSONEncoder):
    """JSON encoder for model dicts that may contain dates and datetimes.

    * ``datetime.datetime`` becomes ``"YYYY-MM-DDTHH:MM:SSZ"`` in UTC. Aware
      values are converted to UTC; naive values are assumed to already be UTC.
    * ``datetime.date`` becomes ``"YYYY-MM-DD"``.

    Everything else is delegated to ``json.JSONEncoder`` (which raises
    ``TypeError`` for unsupported objects).
    """

    def default(self, obj):
        # datetime is a subclass of date, so it must be checked first.
        if isinstance(obj, datetime.datetime):
            return time.strftime('%Y-%m-%dT%H:%M:%SZ', obj.utctimetuple())
        # Plain dates have no utctimetuple(); they previously crashed with
        # AttributeError.
        if isinstance(obj, datetime.date):
            return obj.isoformat()
        return super().default(obj)
class SessionHandler:
    """Thin CRUD helper that binds one SQLAlchemy session to one model class.

    Read methods return plain dicts built with ``dataclasses.asdict`` (so the
    model must be a dataclass), or JSON strings when ``to_json`` is truthy.
    Single-row reads return ``None`` when nothing matches. Write methods only
    stage or execute statements on the session; committing is left to the
    caller.
    """

    # The most recently created handler. This is NOT a singleton: every
    # create() call builds a new handler and stores it here.
    __instance = None

    def __init__(self, session, model):
        """Bind ``session`` and ``model``; ``SessionHandler.create`` is the usual entry point."""
        SessionHandler.__instance = self
        self.model = model
        self.session = session

    @staticmethod
    def create(session, model):
        """Create and return a new handler for ``model`` on ``session``."""
        SessionHandler.__instance = SessionHandler(session, model)
        return SessionHandler.__instance

    def add(self, record_dict):
        """Stage ``model(**record_dict)`` on the session."""
        record_model = self.model(**record_dict)
        self.session.add(record_model)

    def insert_many(self, record_list):
        """Execute one PostgreSQL INSERT ... ON CONFLICT DO NOTHING per record.

        Returns:
            List of execution results, one per record.
        """
        # All statements are built before any is executed.
        statements = [pg_insert(self.model).values(record_dict).on_conflict_do_nothing()
                      for record_dict in record_list]
        return [self.session.execute(statement) for statement in statements]

    def add_many(self, record_list):
        """Stage one ``model(**record_dict)`` per record via ``session.add_all``."""
        return self.session.add_all([self.model(**record_dict) for record_dict in record_list])

    def update(self, query_dict, update_dict):
        """Apply ``update_dict`` to every row matching ``query_dict``; returns the query's update result."""
        return self.session.query(self.model).filter_by(**query_dict).update(update_dict)

    def upsert(self, record_dict, set_dict, constraint):
        """PostgreSQL INSERT ``record_dict``; on conflict with ``constraint`` apply ``set_dict``."""
        statement = pg_insert(self.model).values(record_dict).on_conflict_do_update(
            constraint=constraint,
            set_=set_dict
        )
        return self.session.execute(statement)

    def get(self, id, to_json=None):
        """Return the row with primary key ``id`` serialized, or ``None`` if absent."""
        result = self.session.query(self.model).get(id)
        return self._serialize(result, to_json)

    def get_one(self, query_dict, to_json=None):
        """Return the first row matching ``query_dict`` serialized, or ``None``."""
        result = self.session.query(self.model).filter_by(**query_dict).first()
        return self._serialize(result, to_json)

    def get_latest(self, query_dict, to_json=None):
        """Return the matching row with the greatest ``updated_at``, or ``None``.

        Assumes the model has an ``updated_at`` column.
        """
        result = (self.session.query(self.model)
                  .filter_by(**query_dict)
                  .order_by(self.model.updated_at.desc())
                  .first())
        return self._serialize(result, to_json)

    def get_count(self, query_dict, to_json=None):
        """Return the number of rows matching ``query_dict``.

        ``to_json`` is accepted for signature symmetry and ignored.
        """
        return self.session.query(self.model).filter_by(**query_dict).count()

    def get_all(self, query_dict, to_json=None):
        """Return every row matching ``query_dict``, each serialized."""
        results = self.session.query(self.model).filter_by(**query_dict).all()
        return [self._serialize(result, to_json) for result in results]

    def delete(self, query_dict):
        """Delete every row matching ``query_dict``; returns the query's delete result."""
        return self.session.query(self.model).filter_by(**query_dict).delete()

    def to_json(self, record_obj):
        """Serialize a dataclass model instance to a JSON string."""
        return json.dumps(asdict(record_obj), cls=SchemaEncoder, ensure_ascii=False)

    def _serialize(self, result, to_json=None):
        """Shared read-path serializer: ``None`` stays ``None``, else dict or JSON.

        ``to_json`` is treated as a flag, so ``to_json=False`` yields a dict
        (an ``is None`` check would have returned JSON for it).
        """
        if result is None:
            return None
        return self.to_json(result) if to_json else asdict(result)
@yilmazali32
Copy link

I think this is a very nice implementation.

I need to auto-create models from existing database tables, as Entity Framework does.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment