Skip to content

Instantly share code, notes, and snippets.

@erichiller
Created January 9, 2019 20:17
Show Gist options
  • Save erichiller/a42bc93578eaf1869e8c1f293d0085bf to your computer and use it in GitHub Desktop.
Save erichiller/a42bc93578eaf1869e8c1f293d0085bf to your computer and use it in GitHub Desktop.
from logging import getLogger
from typing import List, Any, Dict
import random
from json import JSONDecoder, JSONEncoder
from collections import UserList

import sqlalchemy
# Imported explicitly: ``sessionmaker`` below lives in ``sqlalchemy.orm``, which a
# bare ``import sqlalchemy`` does not guarantee to load.
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declared_attr, as_declarative, AbstractConcreteBase, declarative_base, has_inherited_table
from sqlalchemy import Column, Integer, String, Float, JSON

# Module-level logging shortcuts, all bound to this module's logger.
_logger = getLogger(__name__)
info = _logger.info
debug = _logger.debug
# ``Logger.warn`` is a deprecated alias of ``Logger.warning``; the ``warn`` name
# is kept for existing callers but now points at the supported method.
warn = _logger.warning

print(sqlalchemy.__version__)

# engine = sqlalchemy.create_engine('sqlite:///:memory:', echo=True)
# File-backed SQLite database (rows persist between runs); echo=True logs the SQL.
engine = sqlalchemy.create_engine('sqlite:///test.db', echo=True)
# Session factory bound to ``engine``.
Session = sqlalchemy.orm.sessionmaker(bind=engine)
# Single module-wide session used by query() and the __main__ demo.
_sql_session = Session()
class SourceChain(sqlalchemy.types.TypeDecorator):
    """Column type for ``source_chain`` values, stored through SQLAlchemy's JSON type.

    Both hooks currently pass values through unchanged and print a debug dump
    (value, its Python type, the dialect) to stdout.

    Positional constructor arguments are kept on ``self.chain``; nothing in this
    class reads them yet.
    """

    impl = sqlalchemy.types.JSON
    # The pass-through hooks do not depend on ``self.chain``, so compiled SQL for
    # this type may be cached (SQLAlchemy 1.4+ warns when this is left unset).
    cache_ok = True

    def __init__(self, *args) -> None:
        self.chain = list(args)
        super().__init__()

    @staticmethod
    def _debug_dump(label, value, dialect):
        """Print a framed dump of ``value``, its type and ``dialect`` under ``label``."""
        from pprint import pprint
        print(label.center(80, '*'))
        pprint(value)
        print(type(value))
        pprint(dialect)
        print('*' * 80)

    def process_bind_param(self, value, dialect):
        """Python -> database: receives a value before it is bound to a statement.

        Returns ``value`` unchanged; the JSON ``impl`` serialises it afterwards.
        """
        self._debug_dump("process_bind_param (to SQL)", value, dialect)
        return value

    def process_result_value(self, value, dialect):
        """Database -> Python: receives a value fetched from a result row.

        Returns ``value`` unchanged (the JSON ``impl`` has already decoded it).
        """
        self._debug_dump("process_result_value (from SQL)", value, dialect)
        return value

    def coerce_compared_value(self, op, value):
        """Delegate to the JSON impl so index/comparison expressions keep JSON typing."""
        return self.impl.coerce_compared_value(op, value)
class JsonExt(sqlalchemy.types.TypeDecorator):
    """JSON column type whose hooks pass values through unchanged.

    Each hook prints the value, its Python type and the dialect to stdout, so it
    can be used to watch what is exchanged with the underlying JSON type.
    """

    impl = sqlalchemy.JSON
    # No constructor state affects the generated SQL, so compiled statements may
    # be cached (SQLAlchemy 1.4+ warns when this is left unset).
    cache_ok = True

    @staticmethod
    def _debug_dump(label, value, dialect):
        """Print a framed dump of ``value``, its type and ``dialect`` under ``label``."""
        from pprint import pprint
        print(label.center(80, '*'))
        pprint(value)
        print(type(value))
        pprint(dialect)
        print('*' * 80)

    def process_bind_param(self, value, dialect):
        """Python -> database: receives a value before it is bound to a statement.

        Returns ``value`` unchanged; the JSON ``impl`` serialises it afterwards.
        """
        self._debug_dump("process_bind_param (to SQL)", value, dialect)
        return value

    def process_result_value(self, value, dialect):
        """Database -> Python: receives a value fetched from a result row.

        Returns ``value`` unchanged (the JSON ``impl`` has already decoded it).
        """
        self._debug_dump("process_result_value (from SQL)", value, dialect)
        return value

    def coerce_compared_value(self, op, value):
        """Delegate to the JSON impl so index/comparison expressions keep JSON typing."""
        return self.impl.coerce_compared_value(op, value)
@as_declarative()
class BaseExtended:
    """Declarative base that builds ``uid`` and ``data`` columns from annotations.

    Every mapped subclass must annotate ``uid`` and ``data`` with a Python type
    that is a key of ``type_map``; the mapped SQLAlchemy type becomes the column
    type (``uid`` is the primary key). Subclasses are mapped with concrete-table
    inheritance and get a snake_case table name derived from the class name.
    """

    # Python annotation -> SQLAlchemy column type for ``uid`` / ``data``.
    # NOTE(review): ``SourceChain`` maps to plain ``JSON`` here, so columns
    # annotated ``SourceChain`` never go through the ``SourceChain``
    # TypeDecorator hooks -- confirm whether that is intended.
    type_map = {
        int: Integer,
        float: Float,
        str: String,
        dict: JSON,
        SourceChain: JSON
    }

    # Inherited by every subclass: each one is mapped to its own complete table.
    __mapper_args__ = {
        'concrete': True
    }

    @classmethod
    def _annotated_column_type(cls, attr_name):
        """Map ``cls``'s annotation for ``attr_name`` to a SQLAlchemy column type.

        Also prints every annotation visible on ``cls`` (debug output).

        Raises:
            TypeError: ``cls`` has no annotations, has no ``attr_name``
                annotation, or annotates it with a type missing from ``type_map``.
        """
        # getattr rather than cls.__dict__, so annotations reachable through
        # inheritance are still honoured where Python exposes them.
        annotations = getattr(cls, "__annotations__", None)
        if not annotations:
            # Subclasses MUST declare ``uid`` / ``data`` annotations.
            raise TypeError(f"!!!! no __annotations__ on {cls} found")
        for name, value in annotations.items():
            print(f"{cls.__name__}.{attr_name} :: name={name} value={value} value_type={type(value)}")
        if attr_name not in annotations:
            raise TypeError(f"{cls.__name__} must annotate {attr_name!r}")
        py_type = annotations[attr_name]
        try:
            return cls.type_map[py_type]
        except (KeyError, TypeError):
            raise TypeError(
                f"{cls.__name__}.{attr_name}: annotation {py_type!r} has no entry in type_map"
            ) from None

    @declared_attr.cascading
    def uid(cls):
        """Primary-key column typed from the class's ``uid`` annotation."""
        r_type = cls._annotated_column_type("uid")
        print(f"{cls.__name__}.uid :: uid return type={r_type}")
        return Column("uid", r_type, primary_key=True)

    @declared_attr.cascading
    def data(cls):
        """Payload column typed from the class's ``data`` annotation."""
        r_type = cls._annotated_column_type("data")
        print(f"{cls.__name__}.data :: data return type={r_type}")
        return Column("data", r_type)

    @declared_attr
    def __tablename__(cls):
        """Derive the table name from the class name.

        Each ASCII capital becomes ``_`` + its lowercase form, then any leading
        underscores are stripped (``BaseExtendedChild`` -> ``base_extended_child``).
        """
        import string
        print(f"{cls.__name__:20} in __tablename__")
        table_name = ''.join(
            "_" + char.lower() if char in string.ascii_uppercase else char
            for char in cls.__name__
        ).lstrip("_")
        print(f"{cls.__name__:20} has {table_name}")
        return table_name
class BaseExtendedChild(BaseExtended):
    """Concrete mapped class for table ``base_extended_child``.

    Its ``uid`` and ``data`` columns are generated by ``BaseExtended`` from the
    annotations below (both ``int`` -> ``Integer``; ``uid`` is the primary key).
    It is also the parent class of the other ``BaseExtended*`` demo models.
    """
    uid: int
    data: int
class BaseExtendedChild2(BaseExtended):
    """Concrete mapped class for table ``base_extended_child2``.

    Its ``uid`` and ``data`` columns are generated by ``BaseExtended`` from the
    annotations below (both ``int`` -> ``Integer``; ``uid`` is the primary key).
    """
    uid: int
    data: int
class BaseExtendedGrandchild(BaseExtendedChild):
    """Concrete table ``base_extended_grandchild``: ``int`` uid key, ``str`` data.

    Unlike the declarative default constructor, this one also accepts ``uid``
    and ``data`` positionally.
    """

    uid: int
    data: str

    def __init__(self, uid: int, data: str) -> None:
        # Hand the values to the keyword-only declarative constructor.
        column_values = {"uid": uid, "data": data}
        super().__init__(**column_values)
class BaseExtendedJson(BaseExtendedChild):
    """Concrete table ``base_extended_json``: ``int`` uid key, ``dict`` data stored as JSON.

    Accepts ``uid`` and ``data`` positionally as well as by keyword.
    """

    uid: int
    data: dict

    def __init__(self, uid: int, data: dict) -> None:
        # Hand the values to the keyword-only declarative constructor.
        column_values = {"uid": uid, "data": data}
        super().__init__(**column_values)
class BaseExtendedSourceChain(BaseExtendedChild):
    """Concrete table ``base_extended_source_chain``: ``int`` uid key, JSON data.

    ``data`` is annotated ``SourceChain``, which ``BaseExtended.type_map`` maps
    to a plain ``JSON`` column. Accepts ``uid`` and ``data`` positionally.
    """

    uid: int
    data: SourceChain

    def __init__(self, uid: int, data: SourceChain) -> None:
        # Hand the values to the keyword-only declarative constructor.
        column_values = {"uid": uid, "data": data}
        super().__init__(**column_values)
class BaseExtendedSourceChainParents(BaseExtendedChild):
    """Concrete table ``base_extended_source_chain_parents``: ``int`` uid key, JSON data.

    ``data`` is annotated ``SourceChain``, which ``BaseExtended.type_map`` maps
    to a plain ``JSON`` column. Accepts ``uid`` and ``data`` positionally.
    """

    uid: int
    data: SourceChain
    # TODO(review): a ``source_features: List[str]`` column was sketched here;
    # ``List[str]`` has no entry in ``BaseExtended.type_map`` yet.

    def __init__(self, uid: int, data: SourceChain) -> None:
        # Hand the values to the keyword-only declarative constructor.
        column_values = {"uid": uid, "data": data}
        super().__init__(**column_values)
# Issue CREATE TABLE for every table registered on BaseExtended's metadata.
# This MUST run after all model classes above are defined; a class defined later
# would not yet have its table on the metadata.
# NOTE(review): an 'extend_existing=True' option was noted here -- intent unclear.
metadata = BaseExtended.metadata
metadata.create_all(bind=engine)
def query(key="key2", value="val2"):
    """Print demo query results from the tables filled by the ``__main__`` block.

    1. Rows of ``BaseExtendedGrandchild`` and ``BaseExtendedChild2`` sharing a ``uid``.
    2. Every ``BaseExtendedJson`` row.
    3. ``BaseExtendedJson`` rows whose JSON ``data[key]`` equals ``value``.

    Args:
        key: JSON object key inspected in step 3 (must not contain ``"``).
        value: text that ``data[key]`` must equal in step 3. The default matches
            the ``"key2": 'val2'`` row inserted by the demo.
    """
    joined = (
        _sql_session.query(BaseExtendedGrandchild, BaseExtendedChild2)
        .filter(BaseExtendedGrandchild.uid == BaseExtendedChild2.uid)
        .all()
    )
    for grandchild, child2 in joined:
        print(f"({child2.uid:12}) {child2.data} | ({grandchild.uid:12}) {grandchild.data}")

    print(" json query ALL ".center(80, '*'))
    print(
        _sql_session.query(BaseExtendedJson).all()
    )

    print(" json query ".center(80, '*'))
    # On SQLite, ``data[key]`` renders as JSON_QUOTE(JSON_EXTRACT(...)), so a CAST
    # to String yields the quoted text '"val2"' and never equals 'val2'.
    # json_extract() returns the unquoted text of a JSON string value instead.
    # (SQLite-specific; the engine above is SQLite.)
    print(
        _sql_session.query(BaseExtendedJson)
        .filter(sqlalchemy.func.json_extract(BaseExtendedJson.data, f'$."{key}"') == value)
        .all()
    )
if __name__ == "__main__":
    print('*' * 20 + "beginning adds" + '*' * 20)
    # One random uid shared by most demo rows, so query() can join them on uid.
    # NOTE(review): test.db persists between runs, so this uid may collide with
    # a primary key left by an earlier run and make commit() fail.
    uid = random.randint(2, 2000000)
    _sql_session.add_all([
        # NOTE(review): 0.555 is a float, but both classes annotate ``data: int``.
        BaseExtendedChild(uid=uid, data=0.555),
        BaseExtendedChild2(uid=uid, data=0.555),
        BaseExtendedGrandchild(uid=uid, data="eric"),
        # Second grandchild with an offset uid, built via the positional constructor.
        BaseExtendedGrandchild(uid + random.randint(200000, 4000000), "Bob"),
        BaseExtendedJson(uid, {"key1": "val1", "key2": 'val2'}),
        BaseExtendedSourceChain(uid, [{"key1": "val1", "key2": 'val2'}]),
    ])
    _sql_session.commit()
    print(" Query test ".center(80, "*"))
    query()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment