# sqlmagic -- IPython %sql / %%sql magics for SQLAlchemy engines
# gist by @douglas-larocca, created May 18, 2015

import os
import json
import pandas as pd
import sqlalchemy
from datetime import datetime
from sqlalchemy.schema import MetaData
from sqlalchemy.orm import sessionmaker
from sqlacodegen.codegen import CodeGenerator
from IPython.testing.skipdoctest import skip_doctest
from IPython.core import magic_arguments
from IPython.core.display import display, Javascript
from IPython.core.magic import Magics, magics_class, cell_magic, line_magic

@magics_class
class SqlMagic(Magics):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.limit = 5
        self.session = sessionmaker()
        self.engines = []
        self.engine = None
        self.schema = 'public'
        self._query_count = 0

    @skip_doctest
    @line_magic
    @cell_magic
    def sql(self, line, cell=''):
        """
        IPython magics %sql (line) and %%sql (cell).

        %%sql [database name] [options]

        (Database engines need to be configured manually; see
        `add_engine`.)

        :database name:
            may be the full database name, or any prefix of it
            that leaves no ambiguity. For example, with four
            databases configured, "cars", "bikes", "trains", and
            "buses", writing

            ```
            %%sql t
            select ...
            ```

            is enough to specify "trains" as the target database
            for the query.

        Options are

        :history:
        :save history:
        :generate model:
        :select:

        :history: shows the series of queries executed in the
        session. Returns a DataFrame. Use this when accidentally
        deleting view definitions etc.

        :save history: saves the history DataFrame to be used
        across sessions. (not finished)

        :generate model: takes an optional additional argument
        :tables:. This is used only with the line magic (%sql,
        not %%sql). The result is a sqlalchemy class or table
        declaration for each table given. Example:

        ```
        %sql trains generate model tables schedule
        ```

        ...replaces the cell with...

        ```
        from sqlalchemy import BigInteger, Column, DateTime, Float, MetaData, Table, Text

        metadata = MetaData()

        t_schedule = Table(
            'schedule', metadata,
            Column('index', BigInteger, index=True),
            Column('depart', Text),
            [...]
            schema='public'
        )
        ```

        When a table has no primary key, a declarative-style
        class model isn't generated; a plain Table declaration
        is emitted instead.
        """
        #args = magic_arguments.parse_argstring
        if 'save history' in line:
            line = ''.join(line.split('save history'))
            self._save_history(line)
        if 'history' in line:
            return self.history
        if 'generate model' in line:
            line, *args = line.replace(',', '').split(' ')
            if not self._is_sql_engine(self.engine):
                if line != 'generate':
                    self._find_engine_starting_with(line)
                else:
                    raise ValueError('need to set engine')
            if 'tables' in args:
                only = args[args.index('tables') + 1:]
            else:
                only = None
            return self.sqlalchemy_model(only=only)
        if 'select' in line:
            self._select_statement(line)
            return
        if cell:
            self._execute_query(line, cell)

    def add_engine(self, connection_string):
        self.engines.append(sqlalchemy.create_engine(connection_string))

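    # Engines are configured manually with ordinary SQLAlchemy URLs
    # (the connection strings below are hypothetical), e.g.
    #
    #     magic = get_ipython().magics_manager.registry['SqlMagic']
    #     magic.add_engine('postgresql://user:pass@localhost/trains')
    #     magic.add_engine('postgresql://user:pass@localhost/cars')
    #
    # %sql / %%sql then resolve any unambiguous prefix of a database
    # name ('t' -> 'trains') via _find_engine_starting_with below.
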
    def _execute_query(self, line, cell):
        try:
            line, *args = line.split(' ')
        except ValueError:
            # no arguments were passed
            if not self._is_sql_engine(self.engine):
                self._find_engine_starting_with(line)
            else:
                # engine is already set
                pass
        else:
            if not self._is_sql_engine(self.engine):
                self._find_engine_starting_with(line)
            assert self._is_sql_engine(self.engine)
            if not self.engine.url.database.startswith(line):
                self._find_engine_starting_with(line)
        finally:
            try:
                # parse an integer 'limit N' argument if one was given;
                # int() guarantees the result is an int
                limit = int(args[args.index('limit') + 1])
            except (ValueError, IndexError):
                # no limit given, or malformed
                pass
            else:
                # got a limit argument, update self.limit
                self.limit = limit
            finally:
                self._record(cell)
                with self.engine.begin() as tx:
                    queries = [''.join((s, ';')) for s in
                               cell.split(';') if s.strip(' \t\r\n')]
                    for query in queries:
                        cur = tx.execute(self._clean_sql(query))
                        if cur.returns_rows:
                            if 'nolimit' in args:
                                display(pd.DataFrame(cur.fetchall(),
                                                     columns=cur.keys()))
                            else:
                                # fetch at most self.limit rows
                                display(pd.DataFrame(cur.fetchmany(self.limit),
                                                     columns=cur.keys()))
                        elif cur.is_insert:
                            print(cur.rowcount)

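    # For a hypothetical cell
    #
    #     %%sql tr limit 20
    #     select * from schedule;
    #
    # the flow above is: resolve 'tr' to the matching engine, set
    # self.limit to 20, record the cell in the history, split the cell
    # on ';', strip comments from each statement, and display each
    # row-returning result as a DataFrame (fetchall() when 'nolimit'
    # is passed, fetchmany(self.limit) otherwise).
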
    @staticmethod
    def _is_sql_engine(engine):
        return isinstance(engine, sqlalchemy.engine.base.Engine)

    def _record(self, cell):
        self._history = getattr(self, '_history', [])
        self._history.append(
            [self._query_count,
             self.engine.url.database,
             cell,
             datetime.now()]
        )
        self._query_count += 1

    def _find_engine_starting_with(self, line):
        # match on the database name of each configured engine, so
        # any unambiguous prefix resolves to a single engine
        choices = []
        for engine in self.engines:
            if (self._is_sql_engine(engine)
                    and engine.url.database.startswith(line)):
                choices.append(engine)
        try:
            engine, *other_engines = choices
        except ValueError:
            raise ValueError("no engines found for '{}'".format(line))
        if other_engines:
            msg = ("too many engines for '{}':\n\n\t".format(line),
                   ('{}\n\t' * len(choices)).format(*choices))
            raise ValueError(''.join(msg))
        self.engine = engine
        return self.engine

    def _save_history(self, line):
        print('not implemented', line)

    @property
    def history(self):
        return pd.DataFrame(getattr(self, '_history', []),
                            columns=['query_number',
                                     'database',
                                     'query',
                                     'timestamp'])

    def _append_to_active_cell(self, value):
        js = (r"var cell = IPython.notebook.get_selected_cell();"
              r"var code = cell.code_mirror.getValue();"
              r"cell.code_mirror.setValue(code+'\n\n'+{});")
        value = json.dumps(value)
        return display(Javascript(js.format(value)))

    def _replace_active_cell(self, value):
        js = (r'var cell = IPython.notebook.get_selected_cell();'
              r'cell.code_mirror.setValue({});')
        value = json.dumps(value)
        return display(Javascript(js.format(value)))

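    # For example (hypothetical value), _replace_active_cell('x = 1')
    # runs the following in the notebook frontend:
    #
    #     var cell = IPython.notebook.get_selected_cell();
    #     cell.code_mirror.setValue("x = 1");
    #
    # json.dumps quotes and escapes the value so it arrives at
    # setValue as a JavaScript string literal.
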
    def sqlalchemy_model_gen(self, only=None, views=True, **kwargs):
        """yields imports, the metadata declaration, then each
        table or class declaration
        """
        metadata = MetaData()
        metadata.reflect(bind=self.engine,
                         schema=self.schema,
                         only=only,
                         views=views)
        try:
            generator = CodeGenerator(metadata)
        except Exception:
            # fall back to noconstraints=True; drop any caller-supplied
            # value so the keyword isn't passed twice
            kwargs.pop('noconstraints', None)
            generator = CodeGenerator(metadata, noconstraints=True, **kwargs)
        yield generator.render_imports()
        yield generator.render_metadata_declarations()
        for model in generator.models:
            if isinstance(model, generator.class_model):
                yield generator.render_class(model)
            if isinstance(model, generator.table_model):
                yield generator.render_table(model)

    def sqlalchemy_model(self, *args, **kwargs):
        model_code = self.sqlalchemy_model_gen(*args, **kwargs)
        return self._replace_active_cell('\n'.join(model_code))

    def _select_statement(self, line):
        """
        writing

        >>> %sql db select table1 table2 table3

        will replace the cell with

        >>> %%sql db
        ... select
        ...     col1::type
        ...     ,col2::type
        ...     ...
        ...     ,colN::type
        ... from table1;
        ... select
        ...     col1::type
        ...     ,col2::type
        ...     ...
        ...     ,colN::type
        ... from table2;
        ... select
        ...     col1::type
        ...     ,col2::type
        ...     ...
        ...     ,colN::type
        ... from table3;

        This is motivated by the tedious process of copying a
        list of columns, quoting each field, separating them by
        commas, etc., after having explored a table using
        select *.
        """
        if ' select' in line:
            line, *args = line.replace(',', '').split(' ')
            # everything after the 'select' keyword is a table name
            args = [arg for arg in args if arg != 'select']
            if not self._is_sql_engine(self.engine):
                if line:
                    self._find_engine_starting_with(line)
                else:
                    raise ValueError('need to set engine')
            header = '%%sql {}'.format(self.engine.url.database)
            stmts = ''.join(self._gen_table_select_statement(*args))
            return self._replace_active_cell(''.join((header, stmts)))

    def _gen_table_select_statement(self, *table_names):
        """
        takes vararg table names as strings, gets the standard
        full select query with types for each, and yields each
        one
        """
        for table_name in table_names:
            if table_name not in self.engine.table_names():
                # unknown table: yield nothing for it
                yield ''
                continue
            cur = self.engine.execute(r'''
                select
                    column_name
                    ,data_type
                from information_schema.columns
                where table_name = '{}';
            '''.format(table_name))
            columns = []
            for name, dtype in cur.fetchall():
                # quote each field and annotate it with its type
                name = ''.join(('"', name, '"'))
                columns.append('::'.join((name, dtype)))
            columns = '\n    ,'.join(columns)
            cfg = dict(columns=columns, table_name=table_name)
            tpl = '\nselect {columns}\nfrom "{table_name}";'
            yield tpl.format(**cfg)

    @staticmethod
    def _clean_sql(x):
        """
        Clean up a query string before sending it to the
        database engine.

        MSSQL sometimes chokes on newlines and tabs, postgres
        not as much, so we strip these. The adapters don't
        handle inline comments or comment blocks, so we strip
        out comments as well.

        The `strip_comments` helper works by streaming the
        input query string, watching for comments or comment
        blocks and filtering them out. This approach catches
        bad input: if we simply split on '/*' etc., there may
        be unbalanced open-close comment brackets. The
        iter/yield control flow does a conditional look-ahead:
        if we find the character '/' in the stream, we may be
        entering a comment block, so we pull one more character
        from the stream to decide. If we are, we don't yield
        either character and set the comment_block flag to
        True; otherwise we yield them both.
        """
        def strip_comments(query):
            unopened_comment_error_msg_tpl = (
                lambda arr_len=30: (
                    'closed comment without opening\n{}{} '
                    '<' + ('-' * arr_len) + ' pos {}\n{}'
                )
            )
            g = iter(query)
            comment_block = False
            lc = 0
            while True:
                # take the next character if it exists,
                # otherwise jump out
                try:
                    c = next(g)
                    lc += 1
                except StopIteration:
                    break
                # entering a comment block?
                if c == '/':
                    try:
                        c_ = next(g)
                        lc += 1
                    except StopIteration:
                        if not comment_block:
                            yield c
                        break
                    if c_ == '*':
                        comment_block = True
                        continue
                    if not comment_block:
                        yield from (c, c_)
                    continue
                # leaving a comment block?
                if c == '*':
                    try:
                        c_ = next(g)
                        lc += 1
                    except StopIteration:
                        if not comment_block:
                            yield c
                        break
                    if c_ == '/':
                        if not comment_block:
                            msg = unopened_comment_error_msg_tpl(20)
                            msg = msg.format(query[:lc-2],
                                             query[lc-2:lc],
                                             lc,
                                             query[lc+1:])
                            raise ValueError(msg)
                        comment_block = False
                        continue
                    if not comment_block:
                        yield from (c, c_)
                    continue
                # stream all chars outside comment blocks
                if not comment_block:
                    yield c
        # cut single-line '--' comments per line, then strip blocks
        x = '\n'.join([line.split('--')[0] for line in x.split('\n')])
        return ''.join(strip_comments(x))

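# A quick illustration of _clean_sql's behavior (for exposition; not
# part of the original gist):
#
#     >>> SqlMagic._clean_sql('select 1 /* note */ -- trailing\nfrom t;')
#     'select 1  \nfrom t;'
#
# '--' comments are cut per line first, then '/* ... */' blocks are
# filtered out of the character stream; a '*/' with no matching '/*'
# raises ValueError with a position marker.
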
def load_ipython_extension(ip):
    """Load the extension in IPython."""
    ip.register_magics(SqlMagic)


from IPython import get_ipython
load_ipython_extension(get_ipython())
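
# A sketch of a full session, assuming this file is importable as
# sqlmagic.py and the hypothetical connection string is replaced with
# a real one:
#
#     In [1]: %load_ext sqlmagic
#     In [2]: magic = get_ipython().magics_manager.registry['SqlMagic']
#     In [3]: magic.add_engine('postgresql://user:pass@localhost/trains')
#     In [4]: %%sql t limit 20
#        ...: select * from schedule;
#
# Importing the module directly also works, since the call above
# registers the magics on the running IPython instance.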