Skip to content

Instantly share code, notes, and snippets.

@cmrfrd
Created March 9, 2020 21:28
Show Gist options
  • Save cmrfrd/603900832fa51db42cba367935458c94 to your computer and use it in GitHub Desktop.
Save cmrfrd/603900832fa51db42cba367935458c94 to your computer and use it in GitHub Desktop.
bloop
import time
import asyncio
import operator
from functools import wraps
from itertools import repeat, count
from aioitertools import iter as aiter
from aioitertools import islice
from peewee import *
from peewee import Expression
from playhouse.fields import PickleField
# Use CSqliteExtDatabase under the SqliteExtDatabase name when it can be
# imported (presumably the C-extension-backed variant -- confirm against the
# installed peewee); otherwise fall back to the plain SqliteExtDatabase.
try:
    from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
except ImportError:
    from playhouse.sqlite_ext import SqliteExtDatabase
# Marker type used as the "no value supplied" default for KeyValue arguments;
# it is compared against directly and never instantiated in this file.
class Sentinel(object):
    pass


# Second marker type; not referenced elsewhere in the code visible here.
class EMPTY(object):
    pass
class KeyValue(object):
    """
    Persistent dictionary.

    :param Field key_field: field to use for key. Defaults to CharField.
    :param Field value_field: field to use for value. Defaults to PickleField.
    :param bool ordered: data should be returned in key-sorted order.
    :param Database database: database where key/value data is stored.
        Defaults to an in-memory SqliteExtDatabase.
    :param str table_name: table name for data.
    :param default: value returned by single-key lookups when the key is
        missing. Leave as ``Sentinel`` to get dict-like ``KeyError`` instead.
    """

    def __init__(self, key_field=None, value_field=None, ordered=False,
                 database=None, table_name='keyvalue', default=Sentinel):
        if key_field is None:
            key_field = CharField(max_length=255, primary_key=True)
        if not key_field.primary_key:
            raise ValueError('key_field must have primary_key=True.')

        if value_field is None:
            value_field = PickleField()

        self._key_field = key_field
        self._value_field = value_field
        self._ordered = ordered
        self._database = database or SqliteExtDatabase(':memory:')
        self._table_name = table_name

        # Postgres gets the explicit conflict-target upsert; every other
        # backend uses on_conflict('replace').
        if isinstance(self._database, PostgresqlDatabase):
            self.upsert = self._postgres_upsert
            self.update = self._postgres_update
        else:
            self.upsert = self._upsert
            self.update = self._update

        self.model = self.create_model()
        self.key = self.model.key
        self.value = self.model.value

        # Ensure table exists.
        self.model.create_table()

        self.default = default

    def create_model(self):
        """Build the model class (key + value columns) backing this store."""
        class KeyValue(Model):
            key = self._key_field
            value = self._value_field

            class Meta:
                database = self._database
                table_name = self._table_name

        return KeyValue

    def query(self, *select):
        """Return a select query yielding plain tuples.

        The query is ordered by key when the store was created with
        ``ordered=True``.
        """
        query = self.model.select(*select).tuples()
        if self._ordered:
            query = query.order_by(self.key)
        return query

    def convert_expression(self, expr):
        """Normalise a lookup argument into ``(expression, is_single)``.

        A plain key becomes ``key == expr`` with ``is_single=True``; an
        existing ``Expression`` is passed through with ``is_single=False``.
        """
        if not isinstance(expr, Expression):
            return (self.key == expr), True
        return expr, False

    def __contains__(self, key):
        expr, _ = self.convert_expression(key)
        return self.model.select().where(expr).exists()

    def __len__(self):
        # Delegates to len() on the model class (expected to be a row count
        # in peewee -- confirm for the installed version).
        return len(self.model)

    def __getitem__(self, expr):
        """Look up a single key, or all values matching an expression.

        :returns: the stored value for a plain key, or a list of values when
            ``expr`` is an ``Expression``.
        :raises KeyError: for a missing plain key when no store-wide default
            was configured.
        """
        converted, is_single = self.convert_expression(expr)
        query = self.query(self.value).where(converted)
        result = [row[0] for row in query]
        if not is_single:
            return result
        if result:
            return result[0]
        # Identity check: the default may be an object whose ``!=`` does not
        # return a plain bool.
        if self.default is not Sentinel:
            return self.default
        raise KeyError(expr)

    def _upsert(self, key, value):
        """Insert ``key``/``value`` using ``on_conflict('replace')``."""
        (self.model
         .insert(key=key, value=value)
         .on_conflict('replace')
         .execute())

    def _postgres_upsert(self, key, value):
        """Insert ``key``/``value`` with an explicit conflict target on key.

        The value column is listed in ``preserve``.
        """
        (self.model
         .insert(key=key, value=value)
         .on_conflict(conflict_target=[self.key],
                      preserve=[self.value])
         .execute())

    def __setitem__(self, expr, value):
        """Set one key (upsert) or update every row matching an expression."""
        if isinstance(expr, Expression):
            self.model.update(value=value).where(expr).execute()
        else:
            self.upsert(expr, value)

    def __delitem__(self, expr):
        """Delete the row(s) matching a key or expression.

        Runs a delete query filtered by the key or expression; a plain key
        that matches nothing does not raise here.
        """
        converted, _ = self.convert_expression(expr)
        self.model.delete().where(converted).execute()

    def __iter__(self):
        """Iterate over ``(key, value)`` tuples."""
        return iter(self.query().execute())

    def keyspace(self, space):
        """Return an iterator over keys that start with ``space``."""
        query = self.query(self.key).where(self.key.startswith(space))
        return map(operator.itemgetter(0), query)

    ks = keyspace

    async def ikeyspace(self, space):
        """Async-generator variant of :meth:`keyspace`."""
        # A ``map`` object is not an async iterable, so ``async for`` over it
        # raises TypeError; iterate it synchronously inside the generator.
        query = self.query(self.key).where(self.key.startswith(space))
        for key in map(operator.itemgetter(0), query):
            yield key

    iks = ikeyspace

    def keys(self, expr=None):
        """Return a list of keys, optionally filtered by key/expression."""
        return list(self.ikeys(expr))

    def ikeys(self, expr=None):
        """Return an iterator of keys, optionally filtered by key/expression."""
        query = self.query(self.key)
        # ``is not None`` rather than truthiness so that falsy keys such as
        # 0 or '' filter by that key instead of matching every row.
        if expr is not None:
            converted, _ = self.convert_expression(expr)
            query = query.where(converted)
        return map(operator.itemgetter(0), query)

    def values(self):
        """Return a list of all stored values."""
        return list(self.ivalues())

    def ivalues(self):
        """Return an iterator over all stored values."""
        return map(operator.itemgetter(0), self.query(self.value))

    def items(self, expr=None):
        """Return a list of ``(key, value)`` tuples, optionally filtered."""
        return list(self.iitems(expr))

    def iitems(self, expr=None):
        """Return an iterator of ``(key, value)`` tuples, optionally filtered."""
        # See ikeys(): falsy keys must still act as a filter.
        if expr is not None:
            converted, _ = self.convert_expression(expr)
            return iter(self.query().where(converted))
        return iter(self.query().execute())

    def _update(self, __data=None, **mapping):
        """Bulk upsert from a mapping and/or keyword arguments."""
        if __data is not None:
            mapping.update(__data)
        return (self.model
                .insert_many(list(mapping.items()),
                             fields=[self.key, self.value])
                .on_conflict('replace')
                .execute())

    def _postgres_update(self, __data=None, **mapping):
        """Bulk upsert using an explicit conflict target on key."""
        if __data is not None:
            mapping.update(__data)
        return (self.model
                .insert_many(list(mapping.items()),
                             fields=[self.key, self.value])
                .on_conflict(conflict_target=[self.key],
                             preserve=[self.value])
                .execute())

    def get(self, key, default=None):
        """Like ``self[key]`` but return ``default`` instead of raising.

        When a store-wide default was configured, ``self[key]`` never raises
        for a missing key, so that store-wide default is returned and the
        ``default`` argument is not used.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def iget(self, key, default=None):
        """Yield ``self.get(k, default)`` for every key matching ``key``."""
        converted, _ = self.convert_expression(key)
        # Rows come back as 1-tuples; unpack so get() receives the bare key.
        for (matched_key,) in self.query(self.key).where(converted):
            yield self.get(matched_key, default)

    def setdefault(self, key, default=None):
        """Return ``self[key]``, storing ``default`` first if it is missing."""
        try:
            return self[key]
        except KeyError:
            self[key] = default
            return default

    def pop(self, key, default=Sentinel):
        """Remove ``key`` and return its value, inside a transaction.

        :raises KeyError: if the key is missing, no ``default`` was passed and
            no store-wide default was configured.
        """
        with self._database.atomic():
            try:
                # With a store-wide default configured this never raises; the
                # default is returned and the delete below matches nothing.
                result = self[key]
            except KeyError:
                if default is Sentinel:
                    raise
                return default
            del self[key]
        return result

    def ipop(self, key, default=Sentinel):
        """Yield ``self.pop(k, default)`` for each key matching ``key``."""
        # Snapshot the keys first: pop() deletes rows from the same table the
        # key query would otherwise still be reading from.
        for k in self.keys(key):
            yield self.pop(k, default)

    async def apop_expect(self, key, default=Sentinel, expect=-1,
                          break_on=Sentinel, timeout=10.0):
        """Poll the store, popping and yielding values that match ``key``.

        Polling stops when ``expect`` values have been yielded, when a popped
        value equals ``break_on``, or when ``timeout`` seconds pass without a
        value being yielded.

        :param key: key or expression selecting what to pop.
        :param default: forwarded to :meth:`pop`; values equal to it (or to
            the store-wide default) are not yielded or counted.
        :param int expect: number of values to yield; a negative number means
            no limit.
        :param break_on: value that ends the generator when popped.
        :param float timeout: idle timeout in seconds; must be positive.
        """
        assert timeout > 0, 'Must have non negative timeout'
        # Monotonic clock: the idle timeout must not depend on wall-clock
        # adjustments.
        last_activity = time.monotonic()
        while (time.monotonic() - last_activity) < timeout and expect != 0:
            async for item in aiter(self.ipop(key, default)):
                if item == break_on:
                    return
                if item not in [self.default, default]:
                    expect -= 1
                    last_activity = time.monotonic()
                    yield item
                    # Stop exactly at the requested count; otherwise one pass
                    # could push ``expect`` below zero and the limit would
                    # never trigger again.
                    if expect == 0:
                        return
            # Give other tasks a chance to run between polling passes.
            await asyncio.sleep(0)

    async def aitems_expect(self, expr=None, default=Sentinel, expect=-1,
                            break_on=Sentinel, timeout=10.0):
        """Poll the store, yielding ``(key, value)`` pairs matching ``expr``.

        Items are read, not removed, so every polling pass yields the rows
        that currently match again. Polling stops when ``expect`` pairs have
        been yielded, when a value equals ``break_on``, or when ``timeout``
        seconds pass without a pair being yielded.

        :param expr: optional key or expression used to filter items.
        :param default: values equal to it (or to the store-wide default) are
            not yielded or counted.
        :param int expect: number of pairs to yield; a negative number means
            no limit.
        :param break_on: value that ends the generator when encountered.
        :param float timeout: idle timeout in seconds; must be positive.
        """
        assert timeout > 0, 'Must have non negative timeout'
        last_activity = time.monotonic()
        while (time.monotonic() - last_activity) < timeout and expect != 0:
            async for key, value in aiter(self.iitems(expr)):
                if value == break_on:
                    return
                if value not in [self.default, default]:
                    expect -= 1
                    last_activity = time.monotonic()
                    yield key, value
                    # Stop exactly at the requested count (see apop_expect).
                    if expect == 0:
                        return
            # Give other tasks a chance to run between polling passes.
            await asyncio.sleep(0)

    def clear(self):
        """Delete every row in the backing table."""
        self.model.delete().execute()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment