Created
May 18, 2015 04:15
-
-
Save douglas-larocca/dd44ce744a0f3524d3af to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import pandas as pd | |
import sqlalchemy | |
from datetime import datetime | |
from sqlalchemy.schema import MetaData | |
from sqlalchemy.orm import sessionmaker | |
from sqlacodegen.codegen import CodeGenerator | |
from IPython.testing.skipdoctest import skip_doctest | |
from IPython.core import magic_arguments | |
from IPython.core.display import display, Javascript | |
from IPython.core.magic import Magics, magics_class, cell_magic, line_magic | |
@magics_class
class SqlMagic(Magics):
    """IPython ``%sql`` / ``%%sql`` magics that run SQL through SQLAlchemy.

    Engines are registered with :meth:`add_engine` (or by assigning
    ``self.engines``) and chosen per call by a database-name prefix on the
    magic line; see :meth:`sql` for the supported options.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rows displayed per result set unless ``nolimit`` is given;
        # ``limit N`` on a ``%%sql`` line updates it for later queries.
        self.limit = 10
        self.session = sessionmaker()
        # Registered engines. A list/tuple/set (as filled by add_engine) is
        # matched on each engine's database name; a dict, or any other object
        # used as an attribute namespace, is matched on its keys/attributes.
        self.engines = []
        # Currently selected engine; None until a database is chosen.
        self.engine = None
        # Schema reflected by ``generate model``.
        self.schema = 'public'
        self._query_count = 0
        # Rows of [query_number, database, query, timestamp]; see `history`.
        self._history = []

    @skip_doctest
    @line_magic
    @cell_magic
    def sql(self, line, cell=''):
        """
        ipython magics %sql, %%sql

        %%sql [database name] [options]

        (Databases need to be configured manually, e.g. with ``add_engine``.)

        :database name:
            can be the full database name or a prefix long enough to be
            unambiguous. For example, with four databases configured,
            "cars", "bikes", "trains" and "buses", writing
            ```
            %%sql t
            select ...
            ```
            is enough to select "trains" as the target database for the
            query. An exact name always wins over longer names that start
            with it.

        Options are
            :history:
            :save history:
            :generate model:
            :select:
            :limit N:
            :nolimit:

        :history: shows the series of queries executed in this session as
            a DataFrame. Use this when accidentally deleting view
            definitions etc.
        :save history: meant to save the history DataFrame for use across
            sessions (not implemented yet).
        :generate model: takes the optional additional argument :tables:
            This is used only with line magic (%sql, not %%sql). The cell
            is replaced with a sqlalchemy class or table declaration for
            each table given (all tables of ``self.schema`` otherwise).
            Example:
            ```
            %sql trains generate model tables schedule
            ```
            ...replaces the cell with...
            ```
            from sqlalchemy import BigInteger, Column, DateTime, Float, MetaData, Table, Text
            metadata = MetaData()
            t_schedule = Table(
                'schedule', metadata,
                Column('index', BigInteger, index=True),
                Column('depart', Text),
                [...]
                schema = 'public'
            )
            ```
            When tables have no primary keys, declarative-style class
            models aren't generated.
        :select: line magic only. ``%sql trains select schedule stops``
            replaces the cell with a ``%%sql`` cell holding a ``select``
            of every (typed) column of each listed table.
        :limit N: (cell magic) display at most N rows per result set and
            keep N as the default for later queries.
        :nolimit: (cell magic) display every row of each result set.
        """
        if 'save history' in line:
            line = ''.join(line.split('save history'))
            self._save_history(line)
        # Options are matched as whole words so that e.g. a database called
        # "history_db" is not mistaken for the `history` option.
        words = self._tokens(line)
        if 'history' in words:
            return self.history
        if 'generate model' in line:
            return self._generate_model(words)
        if 'select' in words:
            self._select_statement(line)
            return None
        if cell:
            self._execute_query(line, cell)
        return None

    @staticmethod
    def _tokens(line):
        """Split a magic line into words; commas count as separators."""
        return line.replace(',', ' ').split()

    def _generate_model(self, words):
        """
        Handle ``%sql [database] generate model [tables t1 t2 ...]``.

        Selects the named database (or requires an already selected engine
        when the line starts with ``generate``), then replaces the active
        cell with the generated model source.
        """
        db_name, rest = words[0], words[1:]
        if db_name != 'generate':
            self._use_engine(db_name)
        elif not self._is_sql_engine(self.engine):
            raise ValueError('need to set engine')
        only = None
        if 'tables' in rest:
            # "tables" with no names after it falls back to all tables.
            only = rest[rest.index('tables') + 1:] or None
        return self.sqlalchemy_model(only=only)

    def add_engine(self, connection_string):
        """Create an engine for ``connection_string`` and register it."""
        self.engines.append(sqlalchemy.create_engine(connection_string))

    def _execute_query(self, line, cell):
        """
        Run every ``;``-separated statement of ``cell`` in one transaction.

        ``line`` is ``[database] [limit N] [nolimit]``. Result sets are
        displayed as DataFrames (at most ``self.limit`` rows unless
        ``nolimit`` is given); for inserts the row count is printed.
        Comments are stripped before splitting, so a ``;`` inside a comment
        does not split a statement (a ``;`` inside a string literal still
        does).

        Raises ValueError if no unique engine matches the database name.
        """
        words = line.split()
        db_name = words[0] if words else ''
        args = words[1:]
        if 'limit' in args:
            try:
                limit = int(args[args.index('limit') + 1])
            except (IndexError, ValueError):
                limit = None  # no limit given, or malformed: keep the old one
            if limit is not None and limit > 0:
                self.limit = limit
        engine = self._use_engine(db_name)
        self._record(cell)
        statements = [stmt + ';'
                      for stmt in self._clean_sql(cell).split(';')
                      if stmt.strip()]
        with engine.begin() as tx:
            for statement in statements:
                cur = tx.execute(statement)
                if cur.returns_rows:
                    if 'nolimit' in args:
                        rows = cur.fetchall()
                    else:
                        rows = cur.fetchmany(self.limit)
                    display(pd.DataFrame(rows, columns=cur.keys()))
                elif cur.is_insert:
                    print(cur.rowcount)

    @staticmethod
    def _is_sql_engine(engine):
        """Return True if ``engine`` is a SQLAlchemy ``Engine``."""
        return isinstance(engine, sqlalchemy.engine.base.Engine)

    def _record(self, cell):
        """Append ``cell`` to the query history for the current engine."""
        self._history = getattr(self, '_history', [])
        self._history.append(
            [self._query_count,
             self.engine.url.database,
             cell,
             datetime.now()]
        )
        self._query_count += 1

    def _iter_engines(self):
        """Yield ``(name, engine)`` for every engine in ``self.engines``."""
        engines = self.engines
        if isinstance(engines, dict):
            pairs = engines.items()
        elif isinstance(engines, (list, tuple, set)):
            # add_engine stores bare engines: match on their database name.
            pairs = ((engine.url.database or '', engine)
                     for engine in engines if self._is_sql_engine(engine))
        else:
            # Attribute namespace, e.g. engines.cars = create_engine(...).
            pairs = ((attr, getattr(engines, attr)) for attr in dir(engines))
        for name, engine in pairs:
            if self._is_sql_engine(engine):
                yield name, engine

    def _use_engine(self, name):
        """
        Return the engine for database ``name``, switching if necessary.

        The current engine is kept while its database name starts with
        ``name`` (so an empty ``name`` keeps it); otherwise the engine is
        looked up with ``_find_engine_starting_with``.
        """
        current = self.engine
        if (self._is_sql_engine(current)
                and (current.url.database or '').startswith(name)):
            return current
        return self._find_engine_starting_with(name)

    def _find_engine_starting_with(self, line):
        """
        Select and return the registered engine whose name starts with ``line``.

        An exact name match wins over longer names sharing the prefix.
        Raises ValueError when nothing matches or the prefix is ambiguous;
        ``self.engine`` is only changed on success.
        """
        matches = [(name, engine) for name, engine in self._iter_engines()
                   if name.startswith(line)]
        exact = [engine for name, engine in matches if name == line]
        choices = exact or [engine for _, engine in matches]
        if not choices:
            raise ValueError("no engines found for '{}'".format(line))
        if len(choices) > 1:
            msg = ("too many engines for '{}':\n\n\t".format(line),
                   ('{}\n\t' * len(choices)).format(*choices))
            raise ValueError(''.join(msg))
        self.engine = choices[0]
        return self.engine

    def _save_history(self, line):
        """Placeholder for persisting the history; not implemented yet."""
        print('not implemented', line)

    @property
    def history(self):
        """Queries run so far, as a DataFrame (empty before the first query)."""
        return pd.DataFrame(getattr(self, '_history', []),
                            columns=['query_number',
                                     'database',
                                     'query',
                                     'timestamp'])

    def _append_to_active_cell(self, value):
        """Append the string ``value`` to the selected notebook cell."""
        js = (r"var cell = IPython.notebook.get_selected_cell();"
              r"var code = cell.code_mirror.getValue();"
              r"cell.code_mirror.setValue(code+'\n\n'+{});")
        # json.dumps yields a quoted, escaped JavaScript string literal.
        value = json.dumps(value)
        return display(Javascript(js.format(value)))

    def _replace_active_cell(self, value):
        """Replace the contents of the selected notebook cell with ``value``."""
        js = (r'var cell = IPython.notebook.get_selected_cell();'
              r'cell.code_mirror.setValue({});')
        # json.dumps yields a quoted, escaped JavaScript string literal.
        value = json.dumps(value)
        return display(Javascript(js.format(value)))

    def sqlalchemy_model_gen(self, only=None, views=True, **kwargs):
        """
        Yield sqlacodegen source for the tables of ``self.schema``.

        Yields the imports, the metadata declaration, then one class or
        ``Table`` declaration per reflected model.

        :only: table names to reflect (None reflects all of them).
        :views: also reflect views.
        :kwargs: extra options passed to ``CodeGenerator``.
        """
        metadata = MetaData()
        metadata.reflect(bind=self.engine,
                         schema=self.schema,
                         only=only,
                         views=views)
        try:
            generator = CodeGenerator(metadata, **kwargs)
        except Exception:
            # Retry without constraints, presumably what fails on some
            # schemas; drop any caller value so the keyword doesn't clash.
            kwargs.pop('noconstraints', None)
            generator = CodeGenerator(metadata, noconstraints=True, **kwargs)
        yield generator.render_imports()
        yield generator.render_metadata_declarations()
        for model in generator.models:
            if isinstance(model, generator.class_model):
                yield generator.render_class(model)
            if isinstance(model, generator.table_model):
                yield generator.render_table(model)

    def sqlalchemy_model(self, *args, **kwargs):
        """Replace the active cell with the output of ``sqlalchemy_model_gen``."""
        model_code = '\n'.join(self.sqlalchemy_model_gen(*args, **kwargs))
        return self._replace_active_cell(model_code)

    def _select_statement(self, line):
        """
        Replace the active cell with typed ``select`` statements.

        Writing
        ```
        %sql db select table1 table2 table3
        ```
        replaces the cell with
        ```
        %%sql db
        select "col1"::type
         ,"col2"::type
         ...
        from "table1";
        select "col1"::type
         ...
        from "table2";
        ...
        ```
        The database name may be omitted once an engine is selected.
        Names that are not tables of the engine are skipped.

        This is motivated by the tedious process of copying a list of
        columns, quoting each field, separating by commas, etc, after
        having worked with a table using select *.
        """
        words = self._tokens(line)
        if 'select' not in words:
            return None
        at = words.index('select')
        db_name = words[0] if at > 0 else ''
        if db_name:
            self._use_engine(db_name)
        elif not self._is_sql_engine(self.engine):
            raise ValueError('need to set engine')
        header = '%%sql {}'.format(self.engine.url.database)
        stmts = ''.join(self._gen_table_select_statement(*words[at + 1:]))
        return self._replace_active_cell(''.join((header, stmts)))

    def _gen_table_select_statement(self, *table_names):
        """
        Yield one ``select`` statement per table name.

        Each statement lists every column, double-quoted and cast
        (PostgreSQL ``::`` syntax) to its ``information_schema`` data type,
        in column order. Names missing from ``engine.table_names()`` yield
        an empty string.
        """
        if not table_names:
            return
        known = set(self.engine.table_names())
        # Bound parameter instead of string formatting for the table name.
        query = sqlalchemy.text(
            'select column_name, data_type'
            ' from information_schema.columns'
            ' where table_name = :table_name'
            ' order by ordinal_position')
        for table_name in table_names:
            if table_name not in known:
                yield ''
                continue
            rows = self.engine.execute(query, table_name=table_name).fetchall()
            columns = '\n ,'.join('"{}"::{}'.format(name, dtype)
                                  for name, dtype in rows)
            yield '\nselect {columns}\nfrom "{table_name}";'.format(
                columns=columns, table_name=table_name)

    @staticmethod
    def _clean_sql(x):
        """
        Strip SQL comments from ``x`` before sending it to a database engine.

        Some adapters don't handle inline comments or comment blocks, so
        ``-- ...`` line comments (up to, but not including, the newline)
        and ``/* ... */`` block comments are removed. The text is scanned
        in a single pass so that comment markers inside quoted literals
        ('...' or "...", with doubled quotes as escapes) are left alone and
        block comments may contain ``--``. An unterminated block comment
        swallows the rest of the text.

        Raises ValueError when ``*/`` appears without a matching ``/*``.
        """
        out = []
        i, n = 0, len(x)
        while i < n:
            ch = x[i]
            pair = x[i:i + 2]
            if ch in ("'", '"'):
                # Copy the quoted literal verbatim.
                j = i + 1
                while j < n:
                    if x[j] == ch:
                        if x[j + 1:j + 2] != ch:
                            break
                        j += 1  # doubled quote is an escaped quote
                    j += 1
                out.append(x[i:j + 1])
                i = j + 1
            elif pair == '--':
                newline = x.find('\n', i)
                i = n if newline == -1 else newline
            elif pair == '/*':
                close = x.find('*/', i + 2)
                i = n if close == -1 else close + 2
            elif pair == '*/':
                arrow = '<' + '-' * 20
                raise ValueError(
                    'closed comment without opening\n{}{} {} pos {}\n{}'.format(
                        x[:i], pair, arrow, i + 2, x[i + 2:]))
            else:
                out.append(ch)
                i += 1
        return ''.join(out)
def load_ipython_extension(ip):
    """Register :class:`SqlMagic` with the IPython shell ``ip``.

    This is the hook IPython calls for ``%load_ext``; it is also invoked
    below when the module is imported inside a running IPython session.
    """
    ip.register_magics(SqlMagic)


from IPython import get_ipython

# get_ipython() returns None outside an IPython session (plain `python`,
# test runners), so only auto-register when a shell actually exists.
if get_ipython() is not None:
    load_ipython_extension(get_ipython())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment