Last active
June 2, 2016 04:18
-
-
Save yusukemurayama/e034a98ff06d7c1db19b83808e2e876a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
from contextlib import contextmanager | |
from datetime import datetime | |
from sqlalchemy import ( | |
create_engine, Column, ForeignKey, Integer, Float, String, | |
Date, DateTime | |
) | |
from sqlalchemy.ext.declarative import declarative_base, declared_attr | |
from sqlalchemy.orm import sessionmaker, reconstructor, relationship | |
Base = declarative_base() | |
engine = create_engine('sqlite://', echo=True) | |
# engine = create_engine('sqlite:///foo.sqlite', echo=True) | |
Session = sessionmaker(bind=engine, autocommit=False) | |
class Stock(Base): | |
__tablename__ = 'stock' | |
id = Column(Integer, primary_key=True) | |
name = Column(String(length=64), nullable=False) | |
symbol = Column(String(length=32), nullable=False, unique=True) | |
created_at = Column(DateTime, default=datetime.now()) | |
updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now()) | |
histories = relationship("History", backref='stock', cascade='delete') | |
@reconstructor | |
def initialize(self): | |
pass | |
class History(Base): | |
__tablename__ = 'history' | |
stock_id = Column(Integer, ForeignKey(Stock.id), primary_key=True) | |
date = Column(Date, primary_key=True) | |
price = Column(Float, nullable=False) | |
volume = Column(Integer, nullable=False) | |
Base.metadata.create_all(engine, checkfirst=True) | |
@contextmanager | |
def start_session(commit=False): | |
"""セッションを開始します。 | |
Args: | |
commit: Trueにするとセッション終了時にcommitします。 | |
Yields: | |
SqlAlcyemyのセッション | |
Usage: | |
with start_session() as session: | |
q = session.query(Table)... | |
""" | |
session = None | |
try: | |
# トランザクションを開始します。 | |
# ※autocommit=Falseなので、自動的にトランザクションが開始されます。 | |
session = Session() | |
try: | |
yield session | |
if commit: | |
session.commit() | |
except: | |
# 例外発生時はトランザクションをロールバックして、その例外をそのまま投げます。 | |
session.rollback() | |
raise | |
finally: | |
if session is not None: | |
session.close() | |
class DynamicTableMixin(object): | |
id = Column(Integer, primary_key=True) | |
@declared_attr | |
def stock_id(cls): | |
return Column(Integer, ForeignKey('stock.id')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import unittest | |
from datetime import date, timedelta | |
from sqlalchemy_samples import engine, Base, Session, Stock, History, DynamicTableMixin, start_session | |
class Test1(unittest.TestCase): | |
def setUp(self): | |
self.num_stocks = 10 # テスト前に用意するStockの数を定義します。 | |
self.num_histories = 10 # テスト前に用意するHistoryの数を定義します。 | |
Base.metadata.create_all(engine, checkfirst=True) | |
self.session = Session() | |
for i in range(self.num_stocks): | |
# Stockを追加します。 | |
stock = Stock() | |
stock.name = 'Test_{}'.format(i) | |
stock.symbol = 'test-{}'.format(i) | |
self.session.add(stock) | |
# stock.idを取得するためflush -> refreshを実施します。 | |
self.session.flush() | |
self.session.refresh(stock) | |
for j in range(self.num_histories): | |
# Historyを追加します。 | |
hist = History() | |
hist.stock_id = stock.id | |
hist.date = date.today() - timedelta(days=j) | |
hist.price = (j + 1) * 100.12 | |
hist.volume = (j + 1) * 1000 | |
self.session.add(hist) | |
self.session.commit() | |
def tearDown(self): | |
self.session.close() | |
Base.metadata.drop_all(engine) # テストごとにテーブルを削除します。 | |
def test_select(self): | |
"""SELECTのテストケースです。""" | |
stocks = self.session.query(Stock).all() | |
for stock in stocks: | |
symbol = stock.symbol | |
self.assertTrue(symbol.startswith('test-')) | |
# filterメソッドで絞り込めることを確認します。 | |
stock = self.session.query(Stock).filter(Stock.symbol == symbol) | |
self.assertEqual(stock[0].symbol, symbol) | |
# filter_byメソッドで絞り込めることを確認します。 | |
stock = self.session.query(Stock).filter_by(symbol=symbol) | |
self.assertEqual(stock[0].symbol, symbol) | |
def test_count(self): | |
"""SELECT COUNT ...のテストケースです。""" | |
self.assertEqual(self.num_stocks, self.session.query(Stock).count()) | |
self.assertEqual(self.num_stocks*self.num_histories, self.session.query(History).count()) | |
for stock in self.session.query(Stock).all(): | |
self.assertEqual(self.num_stocks, len(stock.histories)) | |
def test_update(self): | |
"""UPDATEのテストケースです。""" | |
for stock in self.session.query(Stock).all(): | |
self.assertFalse(stock.name.startswith('Mod_')) | |
for stock in self.session.query(Stock).all(): | |
stock.name = 'Mod_' + stock.name | |
for stock in self.session.query(Stock).all(): | |
self.assertTrue(stock.name.startswith('Mod_')) | |
def test_delete(self): | |
"""DELETEのテストケースです。""" | |
# レコードが0以上あることを確認します。 | |
q = self.session.query(Stock) | |
self.assertTrue(self.session.query(q.exists()).scalar()) | |
self.assertGreater(self.session.query(History).count(), 0) | |
# Stockを削除されることを確認します。 | |
for stock in q.all(): | |
self.session.delete(stock) | |
# ON DELETE CASCADEによってHistoryも削除されていることを確認します。 | |
self.assertFalse(self.session.query(q.exists()).scalar()) | |
self.assertEqual(self.session.query(History).count(), 0) | |
def test_start_session(self): | |
"""start_sessionのテストケースです。""" | |
def get_stock(idx): | |
stock = Stock() | |
stock.name = 'Test {}'.format(idx) | |
stock.symbol = 'test-{}'.format(idx) | |
return stock | |
# 初期状態を確認します。 | |
self.assertEqual(self.num_stocks, self.session.query(Stock).count()) | |
# commitをoffにして、Stockが追加されていないことを確認します。 | |
with start_session() as session: | |
stock = get_stock(idx=self.num_stocks+1) | |
session.add(stock) | |
self.assertEqual(self.num_stocks, self.session.query(Stock).count()) | |
# commitをONにして、Stockが追加されたことを確認します。 | |
with start_session(commit=True) as session: | |
stock = get_stock(idx=self.num_stocks+2) | |
session.add(stock) | |
self.assertEqual(self.num_stocks+1, self.session.query(Stock).count()) | |
def test_dynamictable1(self): | |
"""テーブル名を動的に変更するテストケースです。""" | |
from random import randint | |
tablename = 'new_table_{}'.format(randint(0, 10000)) # テーブル名を決定します。 | |
self.assertFalse(engine.has_table(tablename)) # テーブルが無いことを確認します。 | |
# テーブルクラスを定義します。 | |
class NewTable(DynamicTableMixin, Base): | |
__tablename__ = tablename | |
Base.metadata.create_all(engine) # new_tableを作成します。 | |
self.assertTrue(engine.has_table(tablename)) # テーブルが作成されたことを確認します。 | |
# レコードを追加します。 | |
stock = Stock() | |
stock.name = 'Test' | |
stock.symbol = 'test' | |
self.session.add(stock) | |
self.session.flush() | |
self.session.refresh(stock) | |
# 定義したテーブルにデータが登録されていないことを確認します。 | |
self.assertEqual(0, self.session.query(NewTable).count()) | |
# レコードを挿入し、登録されることを確認します。 | |
t = NewTable() | |
t.stock_id = stock.id | |
self.session.add(t) | |
self.assertEqual(1, self.session.query(NewTable).count()) | |
def test_dynamictable2(self): | |
"""テーブル名を動的に変更するテストケースです。""" | |
from random import randint | |
tablename = 'new_table_{}'.format(randint(0, 10000)) # テーブル名を決定します。 | |
self.assertFalse(engine.has_table(tablename)) # テーブルが無いことを確認します。 | |
klass = type('NewTable', (DynamicTableMixin, Base), {'__tablename__': tablename}) | |
Base.metadata.create_all(engine) # new_tableを作成します。 | |
self.assertTrue(engine.has_table(tablename)) # テーブルが作成されたことを確認します。 | |
# レコードを追加します。 | |
stock = Stock() | |
stock.name = 'Test' | |
stock.symbol = 'test' | |
self.session.add(stock) | |
self.session.flush() | |
self.session.refresh(stock) | |
# 定義したテーブルにデータが登録されていないことを確認します。 | |
self.assertEqual(0, self.session.query(klass).count()) | |
# レコードを挿入し、登録されることを確認します。 | |
t = klass() | |
t.stock_id = stock.id | |
self.session.add(t) | |
self.assertEqual(1, self.session.query(klass).count()) | |
if __name__ == '__main__': | |
unittest.main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment