import os
import io
import json
import math
import time
import random
import numpy as np
import cachetools.func
import sqlite3
from loguru import logger
# Modeled after persist-queue
# https://github.com/peter-wangxu/persist-queue


def adapt_array(arr):
    """Serialize a numpy array to a binary blob for sqlite3.

    http://stackoverflow.com/a/31312102/190597 (SoulNibbler)
    """
    out = io.BytesIO()
    np.save(out, arr)
    out.seek(0)
    return sqlite3.Binary(out.read())


def convert_array(text):
    """Inverse of adapt_array: deserialize a blob back into a numpy array."""
    out = io.BytesIO(text)
    out.seek(0)
    return np.load(out)


# Converts np.ndarray values to BLOBs when inserting
sqlite3.register_adapter(np.ndarray, adapt_array)
# Converts columns declared as ARRAY back to np.ndarray when selecting
sqlite3.register_converter("ARRAY", convert_array)
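

# Illustrative sketch (not part of the original gist): with the adapter and
# converter registered above, any column declared with the type name ARRAY
# round-trips numpy arrays transparently, provided the connection is opened
# with detect_types=sqlite3.PARSE_DECLTYPES (as SQLiteAckQueue does below).
# The table and column names here are hypothetical.
def _example_array_roundtrip():
    con = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
    con.execute("CREATE TABLE demo (vec ARRAY)")
    con.execute("INSERT INTO demo (vec) VALUES (?)", (np.arange(4),))
    ((vec,),) = con.execute("SELECT vec FROM demo").fetchall()
    assert isinstance(vec, np.ndarray) and vec.sum() == 6
    con.close()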


class AckStatus(object):
    inited = "0"
    ready = "1"
    unack = "2"  # Message has been popped off but the receiver has not ack'd
    acked = "5"  # Message is popped and the receiver has ack'd; assumed done
    ack_failed = "9"  # Receiver has marked the message as failed
    ack_done = "17"  # Receiver has explicitly marked the message as done
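

# Status lifecycle (summary added for clarity; not in the original gist):
# rows are inserted as `inited`, stay visible to gets() while status < unack,
# move to `unack` when checked out, and end at `acked`/`ack_done` (or
# `ack_failed`). _apply_timeout() recycles stale `unack` rows back to `ready`.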


class DummySerializer:
    def loads(self, x):
        return x

    def dumps(self, x):
        return x


class DynamicList(list):
    """A list that grows on demand: indexing past the end pads with zeros."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        self.expand(idx)
        return super().__getitem__(idx)

    def __setitem__(self, idx, val):
        self.expand(idx)
        return super().__setitem__(idx, val)

    def expand(self, idx):
        if isinstance(idx, int):
            idx += 1
        elif isinstance(idx, slice):
            idx = max(idx.start, idx.stop)
        if idx > len(self):
            self.extend([0] * (idx - len(self)))
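

# Example of DynamicList behavior (illustrative; not in the original gist):
#   d = DynamicList()
#   d[3] = 7   # the list grows automatically -> [0, 0, 0, 7]
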
dummy_serializer = DummySerializer()


class SQLiteAckQueue:
    """A persistent, ack-based queue stored in a single SQLite table.

    Modeled after persist-queue. Items are dicts: each key becomes a table
    column, plain Python lists are flattened into one column per element,
    and numpy arrays are stored whole via the ARRAY adapter registered above.
    """

    columns = []
    _TABLE_NAME = "ack_unique_queue_default"
    _KEY_COLUMN = "_id"
    _SQL_CREATE_UNIQUE = (
        "CREATE TABLE IF NOT EXISTS {table_name} ("
        "{key_column} INTEGER PRIMARY KEY AUTOINCREMENT, "
        "timestamp FLOAT, status INTEGER, {unique_column} TEXT, UNIQUE ({unique_column}))"
    )
    _SQL_CREATE = (
        "CREATE TABLE IF NOT EXISTS {table_name} ("
        "{key_column} INTEGER PRIMARY KEY AUTOINCREMENT, "
        "timestamp FLOAT, status INTEGER)"
    )
    _SQL_SELECT = (
        "SELECT {key_column}, timestamp, status {table_columns} FROM {table_name} "
        "WHERE status < %s "
        "ORDER BY {key_column} ASC LIMIT {limit} OFFSET {offset}" % AckStatus.unack
    )
    _SQL_SELECT_ALL = (
        "SELECT {key_column}, timestamp, status {table_columns} FROM {table_name} "
        "ORDER BY {key_column} ASC LIMIT {limit} OFFSET {offset}"
    )
    _SQL_MARK_ACK_SELECT = """
        SELECT _id, data FROM {table_name}
        WHERE {key_column} IN ({indices})
    """
    _SQL_MARK_ACK_UPDATE = """
        UPDATE {table_name} SET status = {status}
        WHERE {key_column} IN ({indices})
    """
    _SQL_UPDATE_SINGLE_ROW = """
        UPDATE {table_name} SET {column_name} = {column_value}
        WHERE {row_id_col} = {row_id_val}
    """
    _SQL_DELETE = """
        DELETE FROM {table_name}
        WHERE {key_column} IN ({indices})
    """
    _SQL_INSERT = (
        "INSERT OR IGNORE INTO {table_name} (timestamp, status, {table_columns})"
        " VALUES (?, %s, {table_values}) " % AckStatus.inited
    )
    _SQL_LAST_INSERT = "SELECT last_insert_rowid()"
    _SQL_COUNT = "SELECT COUNT(*) FROM {table_name}"
    _SQL_FREE = "SELECT COUNT(*) FROM {table_name} WHERE status < %s" % AckStatus.unack
    _SQL_DONE = "SELECT COUNT(*) FROM {table_name} WHERE status > %s" % AckStatus.unack
    _SQL_ACTIVE = (
        "SELECT COUNT(*) FROM {table_name} WHERE status >= %s AND status < %s"
        % (AckStatus.unack, AckStatus.ack_failed)
    )
    _SQL_TIMEOUT = """
        UPDATE {table_name}
        SET status = %s
        WHERE status = %s
        AND timestamp < {timeout}
    """ % (
        AckStatus.ready,
        AckStatus.unack,
    )
    _SQL_CREATE_COLUMN = "ALTER TABLE {table_name} ADD {column_name} {column_type}"
    _SQL_READ_COLUMNS = "PRAGMA table_info({table_name})"
    _con = None
    _last_count_update = -1
    last_timeout_application = 0
    serializer = json
    do_debug = False
    created = False
    initialized = False

    def __init__(
        self,
        path,
        unique_column=None,
        timeout=300,
        max_size=None,
        delete_on_ack=False,
        serializer=json,
        table_name=None,
        verbose=False,
    ):
        self.timeout = timeout
        self.path = path
        self.max_size = max_size
        self.delete_on_ack = delete_on_ack
        self.serializer = serializer
        self.verbose = verbose
        self.unique_column = unique_column
        if table_name:
            self._TABLE_NAME = table_name

    def initialize(self):
        if self.initialized:
            return
        unique_column = self.unique_column
        self.sql = self._SQL_CREATE_UNIQUE if unique_column else self._SQL_CREATE
        if self._con is None:
            self._con = sqlite3.connect(self.path, detect_types=sqlite3.PARSE_DECLTYPES)
        self._con.execute(
            self.sql.format(
                table_name=self._TABLE_NAME,
                key_column=self._KEY_COLUMN,
                unique_column=unique_column,
            )
        )
        self.columns = self._read_columns()
        if unique_column is not None and unique_column not in self.columns:
            self.columns.append(unique_column)
        self._con.commit()
        self.initialized = True

    @property
    def con(self):
        self._apply_timeout()
        if self._con is None:
            self._con = sqlite3.connect(self.path, detect_types=sqlite3.PARSE_DECLTYPES)
            self.initialize()
        return self._con

    def execute(self, *args, **kwargs):
        if self.verbose:
            logger.debug(f"Query: {args[0]}")
        ret = self.con.execute(*args, **kwargs)
        if self.verbose:
            logger.debug(f"Returned: {ret}")
        return ret

    def get(self):
        return self.gets(1)

    def gets(self, n, random_offset=False, ack=True, return_keys=False,
             read_all=False):
        self.initialize()
        offset = 0
        if random_offset:
            offset = random.randint(0, n * 100)
        # Select rows to update
        rows = self.select(n, offset, read_all=read_all)
        # Skip the _id, timestamp & status fields by reading only
        # from the 4th field onward
        items = self._process_rows(rows)
        items = self.unflatten_array_columns(items)
        keys = [row[0] for row in rows]
        # Mark them as checked out
        if ack:
            self.updates(keys, AckStatus.unack)
        self.con.commit()
        if return_keys:
            return keys, items
        return items

    def _process_rows(self, rows):
        items = [{k: v for (k, v) in zip(self.columns, row[3:])} for row in rows]
        return items

    def select(self, n, offset=0, read_all=False):
        self.initialize()
        qwhere = self._SQL_SELECT_ALL if read_all else self._SQL_SELECT
        qwhere = qwhere.format(
            table_name=self._TABLE_NAME,
            key_column=self._KEY_COLUMN,
            table_columns="," + ", ".join(self.columns) if len(self.columns) > 0 else "",
            limit=n,
            offset=offset,
        )
        cursor = self.execute(qwhere)
        rows = list(cursor.fetchall())
        return rows

    def put(self, item):
        key, = self.puts([item])
        return key

    def puts(self, items):
        self.initialize()
        if len(items) == 0:
            return []
        if not all(isinstance(i, dict) for i in items):
            raise ValueError("Items must be dicts")
        if not all(len(i) > 0 for i in items):
            raise ValueError("Dicts cannot be empty")
        self._max_size_block()
        items = self.flatten_array_columns(items)
        self._update_table_schema(items[0])
        items = self._reorder_to_match_table_schema(items)
        cols_str = ", ".join(self.columns)
        vals_str = ", ".join("?" for _ in self.columns)
        keys = []
        for item in items:
            insert = self._SQL_INSERT.format(
                table_name=self._TABLE_NAME,
                table_columns=cols_str,
                table_values=vals_str,
                key_column=self._KEY_COLUMN,
            )
            cursor = self.execute(insert, item)
            if cursor.rowcount > 0:
                # Get the key of the inserted row
                cursor = self.execute(self._SQL_LAST_INSERT)
                key, = cursor.fetchone()
                keys.append(key)
        self.con.commit()
        return keys

    def flatten_array_columns(self, items):
        new_items = []
        for item in items:
            new_item = {}
            for key, value in item.items():
                if isinstance(value, list):
                    # Expand a Python list into one column per element
                    for idim, element in enumerate(value):
                        new_item[f'{key}_dim_{idim:04d}'] = element
                else:
                    new_item[key] = value
            new_items.append(new_item)
        return new_items

    def unflatten_array_columns(self, items):
        new_items = []
        for item in items:
            new_item = {}
            arrays = {}
            for key, value in item.items():
                if '_dim_' in key:
                    # Split on the last '_dim_' so column names that themselves
                    # contain underscores are handled correctly
                    column_name, idim_str = key.rsplit('_dim_', 1)
                    column_idim = int(idim_str)
                    arr = arrays.get(column_name, DynamicList())
                    arr[column_idim] = value
                    arrays[column_name] = arr
                else:
                    new_item[key] = value
            for column, dyn_array in arrays.items():
                arr = list(dyn_array)
                assert all(x is not None for x in arr)
                new_item[column] = arr
            new_items.append(new_item)
        return new_items
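
    # Example (illustrative; not in the original gist): flatten_array_columns
    # turns {'vec': [1, 2, 3]} into
    # {'vec_dim_0000': 1, 'vec_dim_0001': 2, 'vec_dim_0002': 3}, and
    # unflatten_array_columns reverses that mapping when rows are read back.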

    def _update_table_schema(self, row):
        """Add any columns present in `row` but missing from the table."""
        for k, v in row.items():
            if k not in self.columns:
                self._create_column(k, v)

    def _create_column(self, name, value):
        # Pick a SQLite column type based on the Python type of the value
        if isinstance(value, str):
            v_type = "TEXT"
        elif isinstance(value, float):
            v_type = "REAL"
        elif isinstance(value, int):
            v_type = "INTEGER"
        elif isinstance(value, dict):
            raise ValueError("Cannot have nested dictionaries")
        elif isinstance(value, np.ndarray):
            v_type = "ARRAY"
        else:
            v_type = "TEXT"
        query = self._SQL_CREATE_COLUMN.format(
            table_name=self._TABLE_NAME, column_name=name, column_type=v_type
        )
        self.execute(query)
        self.columns.append(name)

    def _reorder_to_match_table_schema(self, rows):
        new_rows = []
        for i, row in enumerate(rows):
            new_row = [time.time()]
            for column in self.columns:
                new_row.append(row.pop(column, None))
            assert len(row) == 0, f"Extra columns not present in table found in {i}th row"
            new_rows.append(new_row)
        return new_rows

    def _read_columns(self):
        cursor = self.execute(self._SQL_READ_COLUMNS.format(table_name=self._TABLE_NAME))
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]
        column_names = [n for n in column_names if n not in ("_id", "timestamp", "status")]
        return column_names

    def _max_size_block(self):
        """Block until the approximate row count drops back under max_size."""
        if self.max_size:
            i = 0
            while self.approx_count() > self.max_size:
                i += 1
                time.sleep(1)
                # Only log at powers of two to avoid flooding the log
                if int(math.log2(i)) == math.log2(i):
                    logger.info(f"Waited {i} sec so far for queue to deplete")
            if i > 1:
                logger.info(f"Finished waiting after {i} sec")

    def updates(self, keys, status=AckStatus.unack):
        self.initialize()
        indices = ",".join(str(r) for r in keys)
        qupdat = self._SQL_MARK_ACK_UPDATE.format(
            table_name=self._TABLE_NAME,
            key_column=self._KEY_COLUMN,
            status=status,
            indices=indices,
        )
        cursor = self.execute(qupdat)
        if cursor.rowcount != len(keys):
            raise KeyError("Could not update all keys")
        self.con.commit()

    def set(self, row_key_dict, **field_dict):
        return self.sets([row_key_dict], [field_dict])

    def sets(self, row_key_dicts, field_dicts):
        self.initialize()
        for row_key_dict, field_dict in zip(row_key_dicts, field_dicts):
            (row_id_col, row_id_val), = list(row_key_dict.items())
            for column_name, column_value in field_dict.items():
                if column_name not in self.columns:
                    self._create_column(column_name, column_value)
                # Note: the value is interpolated directly into the SQL, so this
                # only works for numeric (or otherwise SQL-literal) values
                qry = self._SQL_UPDATE_SINGLE_ROW.format(
                    table_name=self._TABLE_NAME,
                    row_id_col=row_id_col,
                    row_id_val=row_id_val,
                    column_name=column_name,
                    column_value=column_value,
                )
                cursor = self.execute(qry)
                assert cursor.rowcount == 1, f"Did not find row for {row_id_col}={row_id_val}"
        self.con.commit()

    def delete(self, keys):
        self.initialize()
        indices = ",".join(str(r) for r in keys)
        qdel = self._SQL_DELETE.format(
            table_name=self._TABLE_NAME,
            key_column=self._KEY_COLUMN,
            indices=indices,
        )
        self.execute(qdel)
        self.con.commit()

    def acks(self, keys, status=AckStatus.acked):
        self.updates(keys, status)
        if self.delete_on_ack:
            self.delete(keys)

    def _apply_timeout(self):
        # Change timed-out unack messages back to ready
        # Don't apply the timeout if the connection isn't open yet
        if self._con is None:
            return
        # Make sure we do not apply the timeout logic too frequently
        dt = time.time() - self.last_timeout_application
        if dt < self.timeout:
            return
        if self.do_debug:
            logger.debug(f"Applying timeout on old unack messages on {self.path}")
            logger.debug(f"Last applied timeout {dt:1.1f} sec ago")
        time_cutoff = time.time() - self.timeout
        qtimeout = self._SQL_TIMEOUT.format(
            table_name=self._TABLE_NAME, timeout=time_cutoff
        )
        self._con.execute(qtimeout)
        self._con.commit()
        self.last_timeout_application = time.time()
        if self.do_debug:
            logger.debug(f"Finished recycling messages at {self.last_timeout_application}")

    def free(self):
        self.initialize()
        cursor = self.execute(self._SQL_FREE.format(table_name=self._TABLE_NAME))
        (n,) = cursor.fetchone()
        self.con.commit()
        return n

    def done(self):
        self.initialize()
        cursor = self.execute(self._SQL_DONE.format(table_name=self._TABLE_NAME))
        (n,) = cursor.fetchone()
        self.con.commit()
        return n

    def active(self):
        self.initialize()
        cursor = self.execute(self._SQL_ACTIVE.format(table_name=self._TABLE_NAME))
        (n,) = cursor.fetchone()
        self.con.commit()
        return n

    @cachetools.func.ttl_cache(maxsize=1, ttl=10)
    def approx_count(self):
        return self._count()

    def _count(self):
        cursor = self.execute(self._SQL_COUNT.format(table_name=self._TABLE_NAME))
        (n,) = cursor.fetchone()
        return n

    def count(self):
        self.initialize()
        return self._count()
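

# Usage sketch (added for illustration; not part of the original gist). A
# typical producer/consumer round trip; the path, unique column, and field
# names below are hypothetical.
def _example_queue_usage():
    q = SQLiteAckQueue("example_queue.db", unique_column="url", timeout=60)
    # Producer: enqueue work items; duplicates on the unique column are ignored
    q.puts([{"url": "https://example.com/a"}, {"url": "https://example.com/b"}])
    # Consumer: check items out (they move to unack), process them, then ack
    keys, items = q.gets(2, return_keys=True)
    for item in items:
        pass  # ... do work with item["url"] ...
    # Un-acked items are recycled back to `ready` after `timeout` seconds
    q.acks(keys)
    os.remove("example_queue.db")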


def test():
    if os.path.exists("temp.db"):
        os.remove("temp.db")
    # Initialized queue should be zero sized
    q = SQLiteAckQueue("temp.db", unique_column="id")
    assert q.count() == 0
    # Getting from an empty queue returns an empty list
    items = q.gets(1)
    # Acking a key that does not exist raises a KeyError
    try:
        q.acks([7])
        raise RuntimeError("Expected to raise KeyError")
    except KeyError:
        pass
    # Cannot put empty dicts
    try:
        q.puts([{} for _ in range(10)])
        raise RuntimeError("Expected to raise ValueError")
    except ValueError:
        pass
    # Initialize list
    q.puts([{'id': i} for i in range(10)])
    assert q.count() == 10
    # Won't duplicate items
    q.puts([{'id': i} for i in range(10)])
    assert q.count() == 10
    assert q.free() == 10
    assert q.done() == 0
    # Will add a new column
    q.puts([{'id': i, 'color': str(i + 100)} for i in range(10, 21)])
    assert q.count() == 21
    assert q.free() == 21
    assert q.done() == 0
    # Get items
    keys, items = q.gets(7, return_keys=True)
    assert len(keys) == len(items) == 7
    assert q.count() == 21
    assert q.free() == 14
    assert len(keys) == len(items) == 7
    all_items = [i for i in items]
    # We have finished processing keys; now ack them
    q.acks(keys)
    assert q.count() == 21
    assert q.free() == 14
    # Ack them again -- should be idempotent
    q.acks(keys)
    assert q.count() == 21
    assert q.free() == 14
    # This should get the remainder of the items
    keys, items = q.gets(50, return_keys=True)
    assert len(keys) == len(items) == 14
    assert q.count() == 21
    assert q.free() == 0
    # Test that we got back our original items
    all_items.extend(items)
    items_1 = [i for i in all_items if i['color'] is None]
    items_2 = [i for i in all_items if i['color'] is not None]
    assert len(all_items) == 21
    assert len(items_1) == 10
    assert len(items_2) == 11
    # Will update item fields in place
    assert q.count() == 21
    q.sets([{'id': i} for i in range(8)],
           [{"id2": i + 500} for i in range(8)])
    assert q.count() == 21
    items = q.gets(50, read_all=True)
    assert len([i for i in items if i['id2'] is not None]) == 8
    os.remove('temp.db')


def test_vec():
    if os.path.exists("temp.db"):
        os.remove("temp.db")
    # Initialized queue should be zero sized
    q = SQLiteAckQueue("temp.db", unique_column="id")
    # Auto-expands list vectors into many columns
    q.puts([{'vec': [1, 2, 3]}])
    row, = q.gets(1)
    assert sum(row['vec']) == 6
    os.remove('temp.db')


def test_np():
    if os.path.exists("temp.db"):
        os.remove("temp.db")
    # Initialized queue should be zero sized
    q = SQLiteAckQueue("temp.db", unique_column="id")
    # Numpy arrays are stored whole via the ARRAY column type
    q.puts([{'vec': np.arange(8)}])
    row, = q.gets(1)
    assert sum(row['vec']) == 28
    os.remove('temp.db')


def test_real_data():
    if os.path.exists("temp.db"):
        os.remove("temp.db")
    items = [
        {'shopify_id': 6881964425309,
         'domain': 'www.farmrio.com',
         'shop_description': "Shop FARM Rio, Brazil’s beloved women's clothing & lifestyle brand. Free shipping on orders above $50 + free returns. Printed dresses, bottoms, tops & more!",
         'product_image': 'https://cdn.shopify.com/s/files/1/0077/6673/6963/products/301795_01.jpg?v=1660240205',
         'product_title': 'Black Flower Garden Organic Cotton Mini Dress',
         'product_type': 'Dresses',
         'product_title_umap': np.array([6.9717116, -12.974599], dtype=np.float32),
         'product_type_umap': np.array([-13.280279, 12.3995], dtype=np.float32),
         'shop_description_umap': np.array([10.05782, 6.426944], dtype=np.float32),
         'mixed_umap': np.array([4.6861753, -13.146825], dtype=np.float32)},
        {'shopify_id': 6881964425310,
         'domain': 'www.farmrio.com',
         'shop_description': "Shop FARM Rio, Brazil’s beloved women's clothing & lifestyle brand. Free shipping on orders above $50 + free returns. Printed dresses, bottoms, tops & more!",
         'product_image': 'https://cdn.shopify.com/s/files/1/0077/6673/6963/products/301795_01.jpg?v=1660240205',
         'product_title': 'Black Flower Garden Organic Cotton Mini Dress',
         'product_type': 'Dresses',
         'product_title_umap': np.array([6.9717116, -12.974599], dtype=np.float32),
         'product_type_umap': np.array([-13.280279, 12.3995], dtype=np.float32),
         'shop_description_umap': np.array([10.05782, 6.426944], dtype=np.float32),
         'mixed_umap': np.array([4.6861753, -13.146825], dtype=np.float32)}
    ]
    # Initialized queue should be zero sized
    q = SQLiteAckQueue("temp.db", unique_column="id")
    # Numpy arrays are stored whole as ARRAY columns
    q.puts(items)
    row, = q.gets(1)
    assert row['shopify_id'] == 6881964425309
    os.remove('temp.db')


if __name__ == "__main__":
    test_real_data()
    test_np()
    test_vec()
    test()