Skip to content

Instantly share code, notes, and snippets.

@kwatch
Created July 5, 2016 00:22
Show Gist options
  • Save kwatch/53d991c47c42132e6352a6d0520ceb92 to your computer and use it in GitHub Desktop.
Save kwatch/53d991c47c42132e6352a6d0520ceb92 to your computer and use it in GitHub Desktop.
SQLおじさんのサンプルSQLをO/Rマッパーで書いてみた
# -*- coding: utf-8 -*-
"""
SQLAlchemy example code.
Requirements:
* Python3
* PostgreSQL
* SQLAlchemy
* Psycopg2
"""
import sys, os, re
import sqlalchemy
import sqlalchemy.orm
import sqlalchemy.ext.declarative
class Config(object):
    """Settings used to build the SQLAlchemy engine below."""

    ## SQLAlchemy
    # Database URL passed to create_engine(); adjust for your environment.
    # The dialect name must be "postgresql": the "postgres" alias was removed
    # in SQLAlchemy 1.4, whereas "postgresql://" is accepted by all versions.
    sa_url = 'postgresql://user2@localhost/example2'  # CHANGE HERE
    # Passed as create_engine(echo=...): log every emitted SQL statement.
    sa_echo = True
# Module-level database wiring: one engine built from the settings, the
# declarative base that the mapped classes inherit from, and a session
# factory already bound to that engine.
config = Config()
engine = sqlalchemy.create_engine(config.sa_url, echo=config.sa_echo)
Base = sqlalchemy.ext.declarative.declarative_base()
DBSession = sqlalchemy.orm.sessionmaker(bind=engine)
from sqlalchemy import (
Column, ForeignKey, UniqueConstraint,
String, Text, Integer, Date, DateTime, Time, Boolean,
)
from sqlalchemy.orm import relationship, backref
class Invoice(Base):
    """Invoice header, mapped to the ``invoices`` table."""

    # Reference copy of the table definition (nothing in this script runs it).
    DDL = r"""
CREATE TABLE invoices (
id serial PRIMARY KEY
, customer_id integer NOT NULL --REFERENCES customers(id)
, total_amount integer NOT NULL
, total_tax integer NOT NULL
);
"""
    __tablename__ = "invoices"
    id = Column(Integer, primary_key=True)
    customer_id = Column(Integer, nullable=False)  # ForeignKey('customers.id')
    total_amount = Column(Integer, nullable=False, default=0)
    total_tax = Column(Integer, nullable=False, default=0)
    # Lines of this invoice (joined through InvoiceLine.invoice_id).
    lines = relationship('InvoiceLine', uselist=True)

    # Tax rate applied to each line amount (item_count * unit_price) by the
    # InvoiceOp queries.
    TAX_RATE = 0.08

    def __repr__(self):
        return ("Invoice(id=%r, customer_id=%r, total_amount=%r, total_tax=%r)"
                % (self.id, self.customer_id, self.total_amount, self.total_tax))
class InvoiceLine(Base):
    """One line of an invoice, mapped to the ``invoice_lines`` table."""

    # Reference copy of the table definition (nothing in this script runs it).
    # The (invoice_id, line_no) uniqueness is added by the ALTER TABLE
    # statement; the previous truncated "UNIQUE (invoi" fragment made the
    # CREATE TABLE statement invalid and has been removed.
    DDL = r"""
CREATE TABLE invoice_lines (
id serial NOT NULL PRIMARY KEY
, invoice_id integer NOT NULL REFERENCES invoices(id)
, line_no integer NOT NULL
, item_id integer NOT NULL --REFERENCES items(id)
, item_count integer NOT NULL
, unit_price integer NOT NULL
);
ALTER TABLE invoice_lines ADD CONSTRAINT invoices_lines_compound_uniq UNIQUE(invoice_id, line_no);
"""
    __tablename__ = "invoice_lines"
    # Line numbers are unique within an invoice (mirrors the DDL constraint).
    __table_args__ = (UniqueConstraint('invoice_id', 'line_no'), )
    id = Column(Integer, primary_key=True)
    invoice_id = Column(Integer, ForeignKey(Invoice.id), nullable=False)
    line_no = Column(Integer, nullable=False)
    item_id = Column(Integer, nullable=False)  # ForeignKey('items.id')
    item_count = Column(Integer, nullable=False)
    unit_price = Column(Integer, nullable=False)
    #item = relationship('Item', uselist=False)
    # The invoice this line belongs to.
    header = relationship('Invoice', uselist=False)

    def __repr__(self):
        return ("InvoiceLine(id=%r, invoice_id=%r, line_no=%r, item_id=%r, "
                "item_count=%r, unit_price=%r)"
                % (self.id, self.invoice_id, self.line_no, self.item_id,
                   self.item_count, self.unit_price))
class TX(object):
    """Context manager that opens a session and ends its transaction.

    Usage::

        with TX() as db:
            for invoice in db.query(Invoice).all():
                print(invoice)

    On normal exit the session is committed; if the ``with`` body raises, the
    session is rolled back and the exception propagates.  In both cases the
    session is closed afterwards, even if commit/rollback itself fails, so
    its connection resources are released.
    """

    def __enter__(self):
        self.db = db = DBSession()
        return db

    def __exit__(self, exclass, ex, extraceback):
        try:
            if exclass is None:
                self.db.commit()
            else:
                self.db.rollback()
        finally:
            self.db.close()
        # Implicitly returns None, so exceptions from the body are not suppressed.
def row_number(xs, partition_by=lambda x: x):
    """Yield ``(rank, item)`` pairs, emulating SQL's ``row_number()`` window.

    The rank restarts at 1 whenever ``partition_by(item)`` differs from the
    key of the current run, so *xs* is expected to be ordered such that items
    of the same partition are adjacent (runs are numbered, not global groups).

    :param xs: iterable of items
    :param partition_by: key function selecting the partition (identity by default)
    :return: generator of ``(rank, item)`` tuples, rank starting at 1

    ex::

        >>> items = [("a", "X"), ("b", "X"), ("c", "X"),
        ...          ("d", "Y"), ("e", "Y"),]
        >>> for i, (k, v) in row_number(items, lambda t: t[1]):
        ...     print([i, k, v])
        ...
        [1, 'a', 'X']
        [2, 'b', 'X']
        [3, 'c', 'X']
        [1, 'd', 'Y']
        [2, 'e', 'Y']
    """
    rank = 0              # 0 means "no item seen yet"; real ranks start at 1
    partition_key = None  # key of the run currently being numbered
    for item in xs:
        key = partition_by(item)
        starts_new_partition = rank == 0 or partition_key != key
        if starts_new_partition:
            rank, partition_key = 1, key
        else:
            rank += 1
        yield rank, item
class Operation:
    """Common base for operation classes: keeps the session they query with.

    :param db: session object stored as ``self.db`` for subclasses to use
    """

    def __init__(self, db):
        self.db = db
class InvoiceOp(Operation):
    """Queries that decide which invoice lines receive the tax rounding difference.

    For each invoice the difference is ``total_tax`` minus the sum of the
    per-line taxes ``trunc(item_count * unit_price * TAX_RATE)``.  Lines are
    ranked within their invoice by amount (descending, ties broken by
    ``line_no``), and a line is flagged when its rank is ``<=`` that
    difference.  The ``query_invoice_lines*`` methods compute this same flag
    in different ways.
    """

    def query_invoice_lines1(self):
        """Compute the flag in SQL using window functions (Oracle, PostgreSQL).

        ref: http://qiita.com/kantomi/items/5e07641016615c073b9f#-%E7%BF%BB%E8%A8%B3%E3%81%97%E3%81%9Fsql

        Returns the Query object itself (not a list).  Each row is
        (invoice_id, line_no, item_id, unit_price, item_count, amount, flag),
        where flag is 1 if the line receives the difference and 0 otherwise.
        Note that this has nothing to do with the DB optimizer.

        Emitted SQL::

            SELECT a.invoice_id AS a_invoice_id
                 , a.line_no AS a_line_no
                 , a.item_id AS a_item_id
                 , a.unit_price AS a_unit_price
                 , a.item_count AS a_item_count
                 , a.amount AS a_amount
                 , CASE
                     WHEN (a.ord <= a.diff) THEN 1
                     ELSE 0
                   END AS anon_1
            FROM (
              SELECT invoice_lines.id AS id
                   , invoice_lines.invoice_id AS invoice_id
                   , invoice_lines.line_no AS line_no
                   , invoice_lines.item_id AS item_id
                   , invoice_lines.item_count AS item_count
                   , invoice_lines.unit_price AS unit_price
                   , invoice_lines.item_count * invoice_lines.unit_price
                       AS amount
                   , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                       OVER (partition BY invoice_lines.invoice_id)
                       AS diff
                   , row_number()
                       OVER (PARTITION BY invoice_lines.invoice_id
                             ORDER BY invoice_lines.item_count * invoice_lines.unit_price DESC, invoice_lines.line_no)
                       AS ord
              FROM invoice_lines
              JOIN invoices ON invoices.id = invoice_lines.invoice_id
            ) AS a
        """
        from sqlalchemy import func as fn, desc, case
        from sqlalchemy.orm import aliased
        h = Invoice      # invoice header (or: h = aliased(Invoice, name='h'))
        d = InvoiceLine  # invoice line   (or: d = aliased(InvoiceLine, name='d'))
        # item_count * unit_price -- an SQL expression, not a Python value
        amount = d.item_count * d.unit_price
        tax_rate = Invoice.TAX_RATE
        subcolumns = [
            amount.label('amount'),
            # difference between the invoice tax and the summed per-line taxes
            (
                h.total_tax
                - fn.sum(fn.trunc(amount * tax_rate)).over(partition_by=d.invoice_id)
            ).label('diff'),
            # rank of the line inside its invoice, by amount desc then line_no
            fn.row_number().over(
                partition_by=d.invoice_id, order_by=(desc(amount), d.line_no),
            ).label('ord'),
        ]
        # Sub-query: invoice_lines INNER JOIN invoices ON invoice_id = invoices.id
        subquery = (self.db.query(InvoiceLine, *subcolumns)
                    .join(InvoiceLine.header)
                    ).subquery()
        a = aliased(subquery, name='a')
        # Main query
        columns = [
            a.c.invoice_id,
            a.c.line_no,
            a.c.item_id,
            a.c.unit_price,
            a.c.item_count,
            a.c.amount,
            # 1 if the line receives the difference, otherwise 0.
            # WHEN clauses are passed positionally: the list form is
            # deprecated in SQLAlchemy 1.4 and rejected by 2.0.
            case((a.c.ord <= a.c.diff, 1), else_=0),
        ]
        query = self.db.query(*columns)
        # query = query.filter(...)  # a real application would normally narrow this down
        return query

    def query_invoice_lines2(self):
        """Compute the flag with a grouped sub-query instead of window functions.

        Intended for databases without window functions (MySQL, SQLite).
        NOTE(review): the SQL still calls trunc(); confirm that the target
        database provides that function.

        Returns a generator of InvoiceLine objects, each with an extra
        ``distribute_flag`` attribute (True if the line receives the
        difference).  Note that this has nothing to do with the DB optimizer.

        Emitted SQL::

            SELECT invoice_lines.id AS invoice_lines_id
                 , invoice_lines.invoice_id AS invoice_lines_invoice_id
                 , invoice_lines.line_no AS invoice_lines_line_no
                 , invoice_lines.item_id AS invoice_lines_item_id
                 , invoice_lines.item_count AS invoice_lines_item_count
                 , invoice_lines.unit_price AS invoice_lines_unit_price
                 , anon_1.tax_diff AS anon_1_tax_diff
            FROM invoice_lines
            JOIN (
              SELECT invoices.id AS invoice_id
                   , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                       AS tax_diff
              FROM invoices
              JOIN invoice_lines
                ON invoices.id = invoice_lines.invoice_id
              GROUP BY invoices.id,
                       invoices.total_tax
            ) AS anon_1 ON invoice_lines.invoice_id = anon_1.invoice_id
            ORDER BY invoice_lines.invoice_id
                   , invoice_lines.item_count * invoice_lines.unit_price DESC
                   , invoice_lines.line_no
        """
        from sqlalchemy import func as fn, desc
        tax_rate = Invoice.TAX_RATE
        amount = InvoiceLine.item_count * InvoiceLine.unit_price
        tax_diff_ = Invoice.total_tax - fn.sum(fn.trunc(amount * tax_rate))
        subcolumns = [
            Invoice.id.label('invoice_id'),
            tax_diff_.label('tax_diff'),
        ]
        # Sub-query: the rounding difference of each invoice.
        subq = (self.db.query(*subcolumns)
                .join(Invoice.lines)
                # .filter(...)  # a real application would normally narrow this down
                .group_by(Invoice.id, Invoice.total_tax)
                ).subquery()
        d = InvoiceLine
        # Lines with their invoice's difference, ordered for row_number().
        qry = (self.db.query(InvoiceLine, subq.c.tax_diff)
               .join(subq, d.invoice_id == subq.c.invoice_id)
               .order_by(d.invoice_id, desc(amount), d.line_no)
               )
        for i, (invoice_line, tax_diff) in row_number(qry, lambda x: x[0].invoice_id):
            invoice_line.distribute_flag = i <= tax_diff  # True if the line receives it
            yield invoice_line

    def query_invoice_lines3(self):
        """Compute the flag without window functions or sub-queries.

        Issues two SQL statements, so it costs an extra round trip.  Returns
        a generator of InvoiceLine objects with ``distribute_flag`` set, using
        the same ``rank <= difference`` rule as query_invoice_lines1/2.
        Note that this has nothing to do with the DB optimizer.

        Emitted SQL::

            SELECT invoices.id AS invoices_id
                 , invoices.total_tax - sum(trunc(invoice_lines.item_count * invoice_lines.unit_price * 0.08))
                     AS anon_1
            FROM invoices
            JOIN invoice_lines ON invoices.id = invoice_lines.invoice_id
            GROUP BY invoices.id, invoices.total_tax
            ;
            SELECT invoice_lines.id AS invoice_lines_id
                 , invoice_lines.invoice_id AS invoice_lines_invoice_id
                 , invoice_lines.line_no AS invoice_lines_line_no
                 , invoice_lines.item_id AS invoice_lines_item_id
                 , invoice_lines.item_count AS invoice_lines_item_count
                 , invoice_lines.unit_price AS invoice_lines_unit_price
            FROM invoice_lines
            WHERE invoice_lines.invoice_id IN (101, 102, 103, ....)
            ORDER BY invoice_lines.invoice_id
                   , invoice_lines.item_count * invoice_lines.unit_price DESC
                   , invoice_lines.line_no
            ;
        """
        from sqlalchemy import func as fn, desc
        tax_rate = Invoice.TAX_RATE
        d = InvoiceLine
        amount = d.item_count * d.unit_price
        tax_diff_ = Invoice.total_tax - fn.sum(fn.trunc(amount * tax_rate))
        # 1st query: rounding difference per invoice.
        qry1 = (self.db.query(Invoice.id, tax_diff_)
                .join(Invoice.lines)
                # .filter(...)  # a real application would normally narrow this down
                .group_by(Invoice.id, Invoice.total_tax)
                )
        tax_diffs = {invoice_id: tax_diff for invoice_id, tax_diff in qry1}
        if not tax_diffs:  # no invoices: nothing to yield
            return
        # 2nd query: the lines of those invoices, ordered for ranking.
        invoice_ids = tuple(tax_diffs)
        qry2 = (self.db.query(InvoiceLine)
                .filter(d.invoice_id.in_(invoice_ids))
                .order_by(d.invoice_id, desc(amount), d.line_no)
                )
        yield from self._flag_lines(qry2, tax_diffs)

    def query_invoice_lines4(self, batch_size=100):
        """Like query_invoice_lines3(), but reads invoices ``batch_size`` at a time.

        The two SQL statements are built by _query_invoices4() and
        _query_invoice_lines4().  Returns a generator of InvoiceLine objects
        with ``distribute_flag`` set.  Note that this has nothing to do with
        the DB optimizer.

        The SQL is almost the same as in query_invoice_lines3(); the invoice
        query additionally has ``ORDER BY invoices.id`` (so that the pages are
        stable) plus LIMIT and OFFSET.

        :param batch_size: number of invoices fetched per round (default 100)
        """
        offset = 0
        while True:
            qry1 = self._query_invoices4(batch_size, offset)
            tax_diffs = {invoice_id: tax_diff for invoice_id, tax_diff in qry1}
            if not tax_diffs:  # no more invoices
                break
            qry2 = self._query_invoice_lines4(tuple(tax_diffs))
            yield from self._flag_lines(qry2, tax_diffs)
            # A short page means the last invoice has been read.
            if len(tax_diffs) < batch_size:
                break
            offset += batch_size

    @staticmethod
    def _flag_lines(invoice_lines, tax_diffs):
        """Set ``distribute_flag`` on each line and yield it.

        :param invoice_lines: lines ordered by invoice_id, amount desc, line_no
        :param tax_diffs: dict mapping invoice_id to that invoice's difference
        """
        for rank, invoice_line in row_number(invoice_lines, lambda x: x.invoice_id):
            tax_diff = tax_diffs[invoice_line.invoice_id]
            # Same rule as query_invoice_lines1/2: the top `tax_diff` lines get it.
            invoice_line.distribute_flag = rank <= tax_diff
            yield invoice_line

    def _query_invoices4(self, limit, offset):
        """Return a query of (invoice_id, difference) for one page of invoices."""
        from sqlalchemy import func as fn
        tax_rate = Invoice.TAX_RATE
        d = InvoiceLine
        amount = d.item_count * d.unit_price
        tax_diff = Invoice.total_tax - fn.sum(fn.trunc(amount * tax_rate))
        return (self.db.query(Invoice.id, tax_diff)
                .join(Invoice.lines)
                # .filter(...)  # a real application would normally narrow this down
                .group_by(Invoice.id, Invoice.total_tax)
                # Without a deterministic order, LIMIT/OFFSET pages may skip or
                # repeat invoices between rounds.
                .order_by(Invoice.id)
                .limit(limit)
                .offset(offset)
                )

    def _query_invoice_lines4(self, invoice_ids):
        """Return a query of the lines of *invoice_ids*, ordered for ranking."""
        from sqlalchemy import desc
        d = InvoiceLine
        order_by = [d.invoice_id, desc(d.item_count * d.unit_price), d.line_no]
        return (self.db.query(InvoiceLine)
                .filter(d.invoice_id.in_(invoice_ids))
                .order_by(*order_by)
                )
def _main(args):
if not args:
script_name = os.path.basename(sys.argv[0])
sys.stderr.write("Usage: python %s [1-4]\n" % script_name)
return 1
arg = args[0]
with TX() as db:
if arg == '1':
for row in InvoiceOp(db).query_invoice_lines1():
print(row)
elif arg == '2':
for invoice_line in InvoiceOp(db).query_invoice_lines2():
print(invoice_line)
print(invoice_line.distribute_flag)
elif arg == '3':
for invoice_line in InvoiceOp(db).query_invoice_lines3():
print(invoice_line)
print(invoice_line.distribute_flag)
elif arg == '4':
for invoice_line in InvoiceOp(db).query_invoice_lines4():
print(invoice_line)
print(invoice_line.distribute_flag)
else:
sys.stderr.write("%s: Unexpected argument.\n" % arg)
return 1
return 0
if __name__ == '__main__':
    # Exit explicitly only on failure; a zero status just ends the script.
    exit_code = _main(sys.argv[1:])
    if exit_code != 0:
        sys.exit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment