Skip to content

Instantly share code, notes, and snippets.

@jwickens
Last active March 28, 2022 11:33
Show Gist options
  • Save jwickens/7be655d478f546f8262de0037a70b7ce to your computer and use it in GitHub Desktop.
Pytorch-like iterable Dataset example backed by async iterator (postgres)
import asyncpg
async def create_pool():
    """Create the asyncpg connection pool for the research database.

    Each new connection is initialized through ``setup_connection`` so that
    queries resolve against the intended schema.  The pool is fixed at 32
    connections (min == max), matching the dataset's batch concurrency.
    """
    return await asyncpg.create_pool(
        database="research",
        user="jwickens",
        setup=setup_connection,
        min_size=32,
        max_size=32,
    )
async def setup_connection(conn):
    """Per-connection init hook: point the search path at the endofday schema."""
    await conn.execute("set search_path to endofday")
from datetime import datetime, date
from torch.utils.data import DataLoader
from typing import NamedTuple, List, Optional, AsyncIterator, Iterator
from enum import Enum
from asyncpg.pool import Pool
from asyncio import AbstractEventLoop
import asyncio
from utils import wrap_async_iter
# A single price bar: open/high/low/close prices plus traded volume.
OHLCV = NamedTuple(
    "OHLCV",
    [
        ("open", float),
        ("high", float),
        ("low", float),
        ("close", float),
        ("volume", int),
    ],
)
# An OHLCV bar tagged with the trading date it belongs to.
DatedOHLCV = NamedTuple("DatedOHLCV", [("date", date), ("ohlcv", OHLCV)])
# A time-ordered run of dated bars for a single ticker.
OHLCVSequence = NamedTuple(
    "OHLCVSequence",
    [("ticker", str), ("sequence", List[DatedOHLCV])],
)
class OHLCVSequenceDataset:
    """PyTorch-style iterable dataset of OHLCV sequences backed by Postgres.

    Start points are sampled server-side with ``TABLESAMPLE``; for each one a
    fixed-length run of price bars is fetched.  Batches are produced by the
    asynchronous iterator (``__aiter__``) and bridged to a synchronous one
    (``__iter__``) via ``wrap_async_iter`` so an ordinary training loop (e.g.
    a torch ``DataLoader``) can consume the dataset.
    """

    class Type(Enum):
        # Values are table names.  Interpolating them into SQL below is safe
        # only because this set is closed (never user-supplied input).
        training = 'training_sample'
        test = 'test_sample'

    batch_size: int          # number of sequences per yielded batch
    sequence_length: int     # rows fetched per start point
    percent_sample: float    # TABLESAMPLE percentage, 0-100
    pool: Pool               # asyncpg connection pool (shared, pre-created)
    loop: AbstractEventLoop  # background loop the async side runs on
    type: Type               # which sample table to draw start points from

    def __init__(self,
                 type: Type,
                 loop: AbstractEventLoop,
                 pool: Pool,
                 percent_sample: float = 100,
                 sequence_length: int = 32,
                 batch_size: int = 32
                 ):
        super().__init__()
        self.type = type
        self.loop = loop
        self.batch_size = batch_size
        self.pool = pool
        self.percent_sample = percent_sample
        self.sequence_length = sequence_length

    async def iter_start_point_cursor(self):
        """Yield (ticker, ticker_id, date_id) records for sampled start points."""
        # TABLESAMPLE arguments and table names cannot be bind parameters,
        # so they are interpolated; float() ensures the percentage is numeric
        # and cannot smuggle SQL into the query text.
        sample_pct = float(self.percent_sample)
        async with self.pool.acquire() as connection:
            # asyncpg cursors require an open transaction.
            async with connection.transaction():
                async for record in connection.cursor(f"""
                    SELECT ticker, ticker_id, date_id
                    FROM stock_data, {self.type.value} TABLESAMPLE SYSTEM ({sample_pct})
                    WHERE stock_data.id = {self.type.value}.stock_data_id
                """):
                    yield record

    async def get_sequence(self, start_point) -> OHLCVSequence:
        """Fetch up to ``sequence_length`` bars starting at ``start_point``.

        ``start_point`` is a (ticker, ticker_id, date_id) record as produced
        by :meth:`iter_start_point_cursor`.
        """
        ticker, ticker_id, date_id = start_point
        async with self.pool.acquire() as connection:
            # Bind parameters ($1..$3) instead of f-string interpolation:
            # safe against injection and lets Postgres cache the plan.
            result = await connection.fetch("""
                SELECT
                    date,
                    open,
                    high,
                    low,
                    close,
                    volume
                FROM stock_data
                WHERE stock_data.ticker_id = $1
                    AND stock_data.date_id >= $2
                ORDER BY stock_data.date_id
                LIMIT $3
            """, ticker_id, date_id, self.sequence_length)

        def convert_row(row):
            # Column order matches the SELECT list above.
            return DatedOHLCV(
                date=row[0],
                ohlcv=OHLCV(*row[1:])
            )

        return OHLCVSequence(
            ticker=ticker,
            sequence=list(map(convert_row, result))
        )

    async def __aiter__(self) -> AsyncIterator[List[OHLCVSequence]]:
        """Yield batches of ``batch_size`` sequences; the final batch may be short."""
        start_points: List = []

        async def flush() -> List[OHLCVSequence]:
            # Fetch all pending start points concurrently, then reset the buffer.
            tasks = map(self.get_sequence, start_points)
            batch = await asyncio.gather(*tasks)
            start_points.clear()
            return batch

        async for start_point in self.iter_start_point_cursor():
            start_points.append(start_point)
            if len(start_points) == self.batch_size:
                yield await flush()
        # Emit whatever partial batch remains after the cursor is exhausted.
        tail = await flush()
        if len(tail) > 0:
            yield tail

    def __iter__(self) -> Iterator[List[OHLCVSequence]]:
        """Synchronous view: drive the async iterator on ``self.loop``."""
        return wrap_async_iter(self, self.loop)
from dataset import OHLCVSequenceDataset
from torch.utils.data import DataLoader
from datetime import datetime
from database import create_pool
import threading
import asyncio
# --- Example / benchmark driver -------------------------------------------
# Create a dedicated asyncio loop and run it forever on a daemon thread so
# the dataset's async side has somewhere to execute while the main thread
# consumes batches synchronously.
# (asyncio.get_event_loop() is deprecated for this purpose; build a fresh loop.)
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

# Block until the connection pool is ready on the background loop.
pool = asyncio.run_coroutine_threadsafe(create_pool(), loop=loop).result()

start = datetime.now()
d = OHLCVSequenceDataset(
    type=OHLCVSequenceDataset.Type.test,
    percent_sample=0.1,
    loop=loop,
    pool=pool)

# Time-to-first-batch and throughput measurements.
i = 0
for x in d:
    if i == 0:
        print(f"first batch in {datetime.now() - start}")
        print(len(x))
        print(len(x[0]))
        print(x[0].ticker)
        print(x[0].sequence[0].date)
        print(x[0].sequence[0].ohlcv)
    if i == 1:
        print(f"second batch in {datetime.now() - start}")
    if i == 10:
        print(f"10 batches in {datetime.now() - start}")
    i += 1
print(f"{i} batches in {datetime.now() - start}")
# https://stackoverflow.com/a/55164899
def wrap_async_iter(ait, loop):
    """Wrap an asynchronous iterator into a synchronous one"""
    # NOTE(review): relies on a module-level `queue` import not visible in
    # this chunk, and on `loop` already running in another thread — confirm.
    # The queue is unbounded: a slow consumer lets the producer buffer the
    # entire stream in memory.
    q = queue.Queue()
    _END = object()  # sentinel marking producer completion

    def yield_queue_items():
        # Plain generator the caller iterates on its own thread; blocks on
        # q.get() until the producer coroutine pushes the next item.
        while True:
            next_item = q.get()
            if next_item is _END:
                break
            yield next_item
        # After observing _END we know the aiter_to_queue coroutine has
        # completed. Invoke result() for side effect - if an exception
        # was raised by the async iterator, it will be propagated here.
        async_result.result()

    async def aiter_to_queue():
        # Producer: drains the async iterator on the event-loop thread.
        try:
            async for item in ait:
                q.put(item)
        finally:
            # Always signal the consumer, even if the iterator raised.
            q.put(_END)

    # Schedule the producer on the running loop BEFORE handing back the
    # generator, so items start flowing as soon as the caller iterates.
    async_result = asyncio.run_coroutine_threadsafe(aiter_to_queue(), loop)
    return yield_queue_items()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment