Created February 18, 2020 09:47
Save agronholm/a2af7dbd4f09c36ae4a936a104796623 to your computer and use it in GitHub Desktop.
WIP PostgreSQL data store
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import platform | |
import threading | |
from datetime import datetime, timezone, timedelta | |
from typing import Optional, Dict, AsyncGenerator, List, Any, Union, AsyncIterable | |
import anyio | |
import sniffio | |
from asyncpg import create_pool, Connection | |
from asyncpg.pool import Pool | |
from apscheduler.abc import ScheduleStore, Schedule | |
# noinspection SqlResolve | |
class PostgresqlDataStore(ScheduleStore):
    """Schedule store that keeps schedules in a PostgreSQL database (via asyncpg).

    Schedules are stored in ``<schema>.schedules``. Whenever the earliest known fire
    time may have changed, it is broadcast with ``pg_notify()`` on ``listener_channel``.
    Every pooled connection listens on that channel (see ``_init_connection``), so a
    :meth:`release_due_schedules` call that is waiting can wake up early.

    Requires the asyncio event loop (checked in :meth:`start`).

    .. note:: ``schema`` is interpolated directly into SQL statements, so it must come
       from trusted configuration, never from user input.
    """

    _pool: Optional[Pool] = None
    _event: anyio.Event  # replaced with a fresh event every time it is set
    _earliest_fire_time: Optional[datetime] = None
    _identity: str  # written to schedules.acquired_by by acquire_due_schedules()

    # Columns read back into Schedule objects. previous_fire_time is aliased to
    # last_fire_time because that is the attribute _schedule_to_args() reads from
    # Schedule; presumably it is also the keyword Schedule() accepts -- verify against
    # apscheduler.abc.Schedule. The bookkeeping columns (acquired_by, acquired_at) are
    # deliberately excluded so they are never passed to Schedule().
    _schedule_columns = ('id, task_id, trigger, misfire_grace_time, '
                         'previous_fire_time AS last_fire_time, next_fire_time')

    def __init__(self, *, dsn: str, schema: str = 'public',
                 listener_channel: str = 'apscheduler-schedules',
                 poll_delay: Union[float, timedelta] = 5,
                 max_schedules_at_once: Optional[int] = 50):
        """
        :param dsn: asyncpg connection string
        :param schema: database schema holding the tables (trusted identifier; it is
            interpolated into SQL)
        :param listener_channel: notification channel used to broadcast fire time changes
        :param poll_delay: longest time :meth:`release_due_schedules` waits, given in
            seconds or as a timedelta
        :param max_schedules_at_once: maximum number of schedules acquired by one
            :meth:`acquire_due_schedules` call (``None`` means no limit)
        """
        if not isinstance(poll_delay, timedelta):
            poll_delay = timedelta(seconds=poll_delay)

        self.dsn = dsn
        self.schema = schema
        self.listener_channel = listener_channel
        self.poll_delay = poll_delay
        self.max_schedules_at_once = max_schedules_at_once

    @staticmethod
    def _schedule_from_record(schedule: Dict[str, Any]) -> Schedule:
        """Build a Schedule from a row (asyncpg Record or plain mapping) of schedules."""
        # asyncpg Records are immutable, so copy into a dict before replacing the trigger
        kwargs = dict(schedule)
        # NOTE(review): unpickling trusts the database contents; only use this store
        # with a database that untrusted parties cannot write to.
        kwargs['trigger'] = pickle.loads(kwargs['trigger'])
        return Schedule(**kwargs)

    @staticmethod
    def _schedule_to_args(schedule: Schedule) -> tuple:
        """Return the bind arguments for a schedule row.

        Order (matching ``$1``..``$6`` in the SQL below): id, task_id, pickled trigger,
        misfire_grace_time, last fire time (``previous_fire_time`` column), next_fire_time.
        """
        trigger = pickle.dumps(schedule.trigger, 4)
        return (schedule.id, schedule.task_id, trigger, schedule.misfire_grace_time,
                schedule.last_fire_time, schedule.next_fire_time)

    async def _update_earliest_fire_time(self, conn: Connection, pid: int, channel: str,
                                         payload: str) -> None:
        """Notification callback: wake waiters if an earlier fire time was announced.

        ``payload`` is expected to be a POSIX timestamp (as sent by
        ``_notify_fire_time``); anything unparseable is ignored.
        """
        try:
            fire_time = datetime.fromtimestamp(float(payload), timezone.utc)
        except (ValueError, OverflowError, OSError):
            return

        if self._earliest_fire_time is None or fire_time < self._earliest_fire_time:
            # Remember the new earliest time so repeated notifications of the same
            # time do not keep waking the waiter
            self._earliest_fire_time = fire_time
            # Swap in a fresh event before setting the old one, so the next wait()
            # starts from an unset event
            event, self._event = self._event, anyio.create_event()
            await event.set()

    async def _init_connection(self, conn: Connection) -> None:
        """Register the fire time listener on each new pooled connection."""
        await conn.add_listener(self.listener_channel, self._update_earliest_fire_time)

    @staticmethod
    async def _setup_connection(conn: Connection) -> None:
        """Validate that the connection still works (not currently passed to create_pool)."""
        await conn.fetchval("SELECT 1", timeout=10)

    async def _notify_fire_time(self, executor: Any, fire_time: datetime) -> None:
        """Broadcast ``fire_time`` on the listener channel.

        :param executor: a Pool or Connection (anything with ``execute(query, *args)``)
        """
        # NOTIFY cannot take bind parameters, and the default channel name contains a
        # hyphen; pg_notify() handles both and takes the payload as text.
        await executor.execute('SELECT pg_notify($1, $2)', self.listener_channel,
                               str(fire_time.timestamp()))

    async def start(self) -> None:
        """Connect to the database and create the tables on first use.

        :raises RuntimeError: if not running on asyncio, or if the stored schema version
            is newer than this code supports
        """
        try:
            asynclib = sniffio.current_async_library()
        except sniffio.AsyncLibraryNotFoundError:
            asynclib = '(unknown)'

        if asynclib != 'asyncio':
            raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}')

        self._identity = f'{platform.node()}-{os.getpid()}-{threading.get_ident()}'
        self._event = anyio.create_event()
        self._pool = await create_pool(self.dsn, min_size=1, max_size=1,
                                       init=self._init_connection)
        try:
            async with self._pool.acquire() as conn, conn.transaction():
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.schema}.metadata (
                        schema_version INTEGER NOT NULL
                    )
                """)
                version = await conn.fetchval(
                    f"SELECT schema_version FROM {self.schema}.metadata")
                if version is None:
                    await conn.execute(f"INSERT INTO {self.schema}.metadata VALUES (1)")
                    await conn.execute(f"""
                        CREATE TABLE {self.schema}.schedules (
                            id TEXT PRIMARY KEY,
                            task_id TEXT NOT NULL,
                            trigger BYTEA NOT NULL,
                            misfire_grace_time INTERVAL,
                            previous_fire_time TIMESTAMP WITH TIME ZONE,
                            next_fire_time TIMESTAMP WITH TIME ZONE,
                            acquired_by TEXT,
                            acquired_at TIMESTAMP WITH TIME ZONE
                        ) WITH (fillfactor = 80);
                        CREATE INDEX ON {self.schema}.schedules (next_fire_time);
                    """)
                    await conn.execute(f"""
                        CREATE TABLE {self.schema}.jobs (
                            id UUID PRIMARY KEY,
                            task_id TEXT NOT NULL,
                            func TEXT NOT NULL,
                            scheduled_start_time TIMESTAMP WITH TIME ZONE,
                            deadline TIMESTAMP WITH TIME ZONE,
                            result BYTEA,
                            exception BYTEA
                        );
                        CREATE INDEX ON {self.schema}.jobs (scheduled_start_time);
                        CREATE INDEX ON {self.schema}.jobs (deadline);
                    """)
                elif version > 1:
                    raise RuntimeError(f'Unexpected schema version ({version}); '
                                       'only version 1 is supported by this version of APScheduler')
                else:
                    self._earliest_fire_time = await conn.fetchval(f"""
                        SELECT next_fire_time FROM {self.schema}.schedules
                        WHERE next_fire_time IS NOT NULL
                        ORDER BY next_fire_time
                        LIMIT 1
                    """)
        except BaseException:
            # Don't leak the pool if initialization fails
            await self.stop()
            raise

    async def stop(self) -> None:
        """Close the connection pool, if open. Safe to call more than once."""
        if self._pool is not None:
            await self._pool.close()
            self._pool = None

    async def add_or_update_schedule(self, schedule: Schedule) -> None:
        """Insert the schedule, or overwrite the stored one with the same id.

        If it has a next fire time, that time is broadcast on the listener channel.
        """
        await self._pool.execute(f"""
            INSERT INTO {self.schema}.schedules
                (id, task_id, trigger, misfire_grace_time, previous_fire_time, next_fire_time)
            VALUES ($1, $2, $3, $4, $5, $6)
            ON CONFLICT (id) DO UPDATE SET task_id = $2, trigger = $3,
                misfire_grace_time = $4, previous_fire_time = $5, next_fire_time = $6
        """, *self._schedule_to_args(schedule))
        if schedule.next_fire_time is not None:
            await self._notify_fire_time(self._pool, schedule.next_fire_time)

    async def remove_schedule(self, schedule_id: str) -> None:
        """Delete the schedule with the given id (no error if it does not exist)."""
        await self._pool.execute(f"DELETE FROM {self.schema}.schedules WHERE id = $1", schedule_id)

    async def remove_all_schedules(self) -> None:
        """Delete every stored schedule."""
        await self._pool.execute(f"DELETE FROM {self.schema}.schedules")

    async def get_all_schedules(self) -> List[Schedule]:
        """Return all stored schedules, ordered by id."""
        records = await self._pool.fetch(
            f"SELECT {self._schedule_columns} FROM {self.schema}.schedules ORDER BY id")
        return [self._schedule_from_record(r) for r in records]

    async def acquire_due_schedules(self) -> List[Schedule]:
        """Claim up to ``max_schedules_at_once`` unclaimed schedules whose fire time has come.

        Claimed rows get ``acquired_by``/``acquired_at`` set so other schedulers skip
        them until :meth:`release_due_schedules` clears the claim.
        """
        now = datetime.now(timezone.utc)
        async with self._pool.acquire() as conn, conn.transaction():
            # SKIP LOCKED lets concurrent schedulers pick disjoint rows. LIMIT NULL
            # (max_schedules_at_once=None) means no limit in PostgreSQL.
            # TODO(review): claims left behind by a crashed process are never reclaimed.
            records = await conn.fetch(f"""
                SELECT {self._schedule_columns} FROM {self.schema}.schedules
                WHERE next_fire_time IS NOT NULL AND next_fire_time <= $1
                    AND acquired_by IS NULL
                ORDER BY next_fire_time
                LIMIT $2
                FOR NO KEY UPDATE SKIP LOCKED
            """, now, self.max_schedules_at_once)
            if records:
                await conn.execute(f"""
                    UPDATE {self.schema}.schedules SET acquired_by = $1, acquired_at = $2
                    WHERE id = ANY($3)
                """, self._identity, now, [r['id'] for r in records])

        return [self._schedule_from_record(r) for r in records]

    async def release_due_schedules(self, schedules: List[Schedule]) -> Optional[datetime]:
        """Write back processed schedules, then wait until the next one may be due.

        Schedules without a next fire time are deleted; the rest are updated and their
        claim cleared. The earliest remaining fire time is broadcast, then this waits
        until that time, an earlier fire time is announced, or ``poll_delay`` elapses,
        whichever comes first.

        :return: the earliest next fire time in the store, or ``None`` if there is none
        """
        finished_ids = [s.id for s in schedules if s.next_fire_time is None]
        update_args = [self._schedule_to_args(s) for s in schedules
                       if s.next_fire_time is not None]
        async with self._pool.acquire() as conn, conn.transaction():
            if finished_ids:
                await conn.execute(
                    f"DELETE FROM {self.schema}.schedules WHERE id = ANY($1)", finished_ids)

            if update_args:
                await conn.executemany(f"""
                    UPDATE {self.schema}.schedules SET
                        task_id = $2, trigger = $3, misfire_grace_time = $4,
                        previous_fire_time = $5, next_fire_time = $6,
                        acquired_by = NULL, acquired_at = NULL
                    WHERE id = $1
                """, update_args)

            next_fire_time = await conn.fetchval(f"""
                SELECT next_fire_time FROM {self.schema}.schedules
                WHERE next_fire_time IS NOT NULL
                ORDER BY next_fire_time
                LIMIT 1
            """)
            self._earliest_fire_time = next_fire_time
            if next_fire_time is not None:
                await self._notify_fire_time(conn, next_fire_time)

        # Wait outside the transaction so the pooled connection is not held meanwhile
        wait_time = self.poll_delay
        if next_fire_time is not None:
            time_left = next_fire_time - datetime.now(timezone.utc)
            wait_time = min(wait_time, max(time_left, timedelta(0)))

        async with anyio.move_on_after(wait_time.total_seconds()):
            await self._event.wait()

        return next_fire_time
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.