Skip to content

Instantly share code, notes, and snippets.

@yusukemurayama
Last active June 2, 2016 04:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yusukemurayama/e034a98ff06d7c1db19b83808e2e876a to your computer and use it in GitHub Desktop.
Save yusukemurayama/e034a98ff06d7c1db19b83808e2e876a to your computer and use it in GitHub Desktop.
# coding: utf-8
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import (
create_engine, Column, ForeignKey, Integer, Float, String,
Date, DateTime
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, reconstructor, relationship
Base = declarative_base()
engine = create_engine('sqlite://', echo=True)
# engine = create_engine('sqlite:///foo.sqlite', echo=True)
Session = sessionmaker(bind=engine, autocommit=False)
class Stock(Base):
__tablename__ = 'stock'
id = Column(Integer, primary_key=True)
name = Column(String(length=64), nullable=False)
symbol = Column(String(length=32), nullable=False, unique=True)
created_at = Column(DateTime, default=datetime.now())
updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now())
histories = relationship("History", backref='stock', cascade='delete')
@reconstructor
def initialize(self):
pass
class History(Base):
__tablename__ = 'history'
stock_id = Column(Integer, ForeignKey(Stock.id), primary_key=True)
date = Column(Date, primary_key=True)
price = Column(Float, nullable=False)
volume = Column(Integer, nullable=False)
Base.metadata.create_all(engine, checkfirst=True)
@contextmanager
def start_session(commit=False):
"""セッションを開始します。
Args:
commit: Trueにするとセッション終了時にcommitします。
Yields:
SqlAlcyemyのセッション
Usage:
with start_session() as session:
q = session.query(Table)...
"""
session = None
try:
# トランザクションを開始します。
# ※autocommit=Falseなので、自動的にトランザクションが開始されます。
session = Session()
try:
yield session
if commit:
session.commit()
except:
# 例外発生時はトランザクションをロールバックして、その例外をそのまま投げます。
session.rollback()
raise
finally:
if session is not None:
session.close()
class DynamicTableMixin(object):
id = Column(Integer, primary_key=True)
@declared_attr
def stock_id(cls):
return Column(Integer, ForeignKey('stock.id'))
# coding: utf-8
import unittest
from datetime import date, timedelta
from sqlalchemy_samples import engine, Base, Session, Stock, History, DynamicTableMixin, start_session
class Test1(unittest.TestCase):
def setUp(self):
self.num_stocks = 10 # テスト前に用意するStockの数を定義します。
self.num_histories = 10 # テスト前に用意するHistoryの数を定義します。
Base.metadata.create_all(engine, checkfirst=True)
self.session = Session()
for i in range(self.num_stocks):
# Stockを追加します。
stock = Stock()
stock.name = 'Test_{}'.format(i)
stock.symbol = 'test-{}'.format(i)
self.session.add(stock)
# stock.idを取得するためflush -> refreshを実施します。
self.session.flush()
self.session.refresh(stock)
for j in range(self.num_histories):
# Historyを追加します。
hist = History()
hist.stock_id = stock.id
hist.date = date.today() - timedelta(days=j)
hist.price = (j + 1) * 100.12
hist.volume = (j + 1) * 1000
self.session.add(hist)
self.session.commit()
def tearDown(self):
self.session.close()
Base.metadata.drop_all(engine) # テストごとにテーブルを削除します。
def test_select(self):
"""SELECTのテストケースです。"""
stocks = self.session.query(Stock).all()
for stock in stocks:
symbol = stock.symbol
self.assertTrue(symbol.startswith('test-'))
# filterメソッドで絞り込めることを確認します。
stock = self.session.query(Stock).filter(Stock.symbol == symbol)
self.assertEqual(stock[0].symbol, symbol)
# filter_byメソッドで絞り込めることを確認します。
stock = self.session.query(Stock).filter_by(symbol=symbol)
self.assertEqual(stock[0].symbol, symbol)
def test_count(self):
"""SELECT COUNT ...のテストケースです。"""
self.assertEqual(self.num_stocks, self.session.query(Stock).count())
self.assertEqual(self.num_stocks*self.num_histories, self.session.query(History).count())
for stock in self.session.query(Stock).all():
self.assertEqual(self.num_stocks, len(stock.histories))
def test_update(self):
"""UPDATEのテストケースです。"""
for stock in self.session.query(Stock).all():
self.assertFalse(stock.name.startswith('Mod_'))
for stock in self.session.query(Stock).all():
stock.name = 'Mod_' + stock.name
for stock in self.session.query(Stock).all():
self.assertTrue(stock.name.startswith('Mod_'))
def test_delete(self):
"""DELETEのテストケースです。"""
# レコードが0以上あることを確認します。
q = self.session.query(Stock)
self.assertTrue(self.session.query(q.exists()).scalar())
self.assertGreater(self.session.query(History).count(), 0)
# Stockを削除されることを確認します。
for stock in q.all():
self.session.delete(stock)
# ON DELETE CASCADEによってHistoryも削除されていることを確認します。
self.assertFalse(self.session.query(q.exists()).scalar())
self.assertEqual(self.session.query(History).count(), 0)
def test_start_session(self):
"""start_sessionのテストケースです。"""
def get_stock(idx):
stock = Stock()
stock.name = 'Test {}'.format(idx)
stock.symbol = 'test-{}'.format(idx)
return stock
# 初期状態を確認します。
self.assertEqual(self.num_stocks, self.session.query(Stock).count())
# commitをoffにして、Stockが追加されていないことを確認します。
with start_session() as session:
stock = get_stock(idx=self.num_stocks+1)
session.add(stock)
self.assertEqual(self.num_stocks, self.session.query(Stock).count())
# commitをONにして、Stockが追加されたことを確認します。
with start_session(commit=True) as session:
stock = get_stock(idx=self.num_stocks+2)
session.add(stock)
self.assertEqual(self.num_stocks+1, self.session.query(Stock).count())
def test_dynamictable1(self):
"""テーブル名を動的に変更するテストケースです。"""
from random import randint
tablename = 'new_table_{}'.format(randint(0, 10000)) # テーブル名を決定します。
self.assertFalse(engine.has_table(tablename)) # テーブルが無いことを確認します。
# テーブルクラスを定義します。
class NewTable(DynamicTableMixin, Base):
__tablename__ = tablename
Base.metadata.create_all(engine) # new_tableを作成します。
self.assertTrue(engine.has_table(tablename)) # テーブルが作成されたことを確認します。
# レコードを追加します。
stock = Stock()
stock.name = 'Test'
stock.symbol = 'test'
self.session.add(stock)
self.session.flush()
self.session.refresh(stock)
# 定義したテーブルにデータが登録されていないことを確認します。
self.assertEqual(0, self.session.query(NewTable).count())
# レコードを挿入し、登録されることを確認します。
t = NewTable()
t.stock_id = stock.id
self.session.add(t)
self.assertEqual(1, self.session.query(NewTable).count())
def test_dynamictable2(self):
"""テーブル名を動的に変更するテストケースです。"""
from random import randint
tablename = 'new_table_{}'.format(randint(0, 10000)) # テーブル名を決定します。
self.assertFalse(engine.has_table(tablename)) # テーブルが無いことを確認します。
klass = type('NewTable', (DynamicTableMixin, Base), {'__tablename__': tablename})
Base.metadata.create_all(engine) # new_tableを作成します。
self.assertTrue(engine.has_table(tablename)) # テーブルが作成されたことを確認します。
# レコードを追加します。
stock = Stock()
stock.name = 'Test'
stock.symbol = 'test'
self.session.add(stock)
self.session.flush()
self.session.refresh(stock)
# 定義したテーブルにデータが登録されていないことを確認します。
self.assertEqual(0, self.session.query(klass).count())
# レコードを挿入し、登録されることを確認します。
t = klass()
t.stock_id = stock.id
self.session.add(t)
self.assertEqual(1, self.session.query(klass).count())
if __name__ == '__main__':
unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment