Skip to content

Instantly share code, notes, and snippets.

@mmalone
Created April 22, 2010 22:35
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 13 You must be signed in to fork a gist
  • Save mmalone/375916 to your computer and use it in GitHub Desktop.
Save mmalone/375916 to your computer and use it in GitHub Desktop.
# In-memory Cassandra-ish thingy... useful for unit tests. Maybe useful for other
# stuff too? No support for SuperColumns, but that should be easy enough to add.
import bisect
import copy
from cassandra.ttypes import NotFoundException, Column, ColumnPath, ColumnOrSuperColumn
class SSTable(object):
def __init__(self, data=None):
if data is None:
data = {}
self.data = data
self.sorted_keys = sorted(self.data.iterkeys())
def __setitem__(self, key, value):
if key not in self.data:
bisect.insort(self.sorted_keys, key)
self.data[key] = value
def __delitem__(self, key):
del self.data[key]
del self.sorted_keys[bisect.bisect_left(self.sorted_keys, key)]
def __getitem__(self, key):
return self.data[key]
def __contains__(self, key):
return self.data.__contains__(key)
def __len__(self):
return len(self.sorted_keys)
def get(self, *args, **kwargs):
return self.data.get(*args, **kwargs)
def clear(self):
self.data.clear()
del self.sorted_keys[:]
def copy(self):
return self.__class__(self.data.copy())
def key_slice(self, start, finish, count=None, reversed=False):
start = bisect.bisect_left(self.sorted_keys, start) if start else 0
finish = bisect.bisect_left(self.sorted_keys, finish) if finish else len(self.sorted_keys)
keys = self.sorted_keys[start:finish]
if reversed:
keys.reverse()
if count is not None:
keys = keys[:count]
return keys
def slice(self, start, finish, count=None, reversed=False):
return (self[key] for key in self.key_slice(start, finish, count, reversed))
class Cassandra(object):
def __init__(self, keyspaces=None):
self.database = {}
self.keyspaces = {}
if keyspaces is not None:
for keyspace, column_families in keyspaces.iteritems():
self.add_keyspace(keyspace, column_families)
def add_keyspace(self, keyspace, column_families):
self.database[keyspace] = SSTable()
self.keyspaces[keyspace] = column_families
def _slice_range(self, row, predicate):
if predicate.slice_range:
slice_range = predicate.slice_range
columns = list(row.slice(
start=slice_range.start,
finish=slice_range.finish,
count=slice_range.count,
reversed=slice_range.reversed
))
else:
columns = [row[column_name] for column_name in predicate.column_names if column_name in row]
return [ColumnOrSuperColumn(column=copy.deepcopy(column)) for column in columns]
def _get(self, keyspace, key, column_path, consistency_level):
key = self.database[keyspace].get(key)
if key is not None:
column_family = key[column_path.column_family] # Should raise misconfiguration exception here if KeyError.
try:
return column_family[column_path.column]
except KeyError:
pass
raise NotFoundException()
def get(self, keyspace, key, column_path, consistency_level):
return copy.deepcopy(self._get(keyspace, key, column_path, consistency_level))
def get_slice(self, keyspace, key, column_parent, predicate, consistency_level):
key = self.database[keyspace].get(key)
return self._slice_range(key[column_parent.column_family], predicate) if key is not None else []
def multiget_slice(self, keyspace, keys, column_parent, predicate, consistency_level):
return dict((key, self.get_slice(keyspace, key, column_parent, predicate, consistency_level)) for key in keys)
def get_range_slice(self, keyspace, column_parent, predicate, start_key, finish_key, row_count, consistency_level):
keys = self.database[keyspace].key_slice(start_key, finish_key, count=row_count)
return self.multiget_slice(keyspace, keys, column_parent, predicate, consistency_level)
def get_range_slices(self, keyspace, column_parent, predicate, key_range, consistency_level):
return self.get_range_slice(keyspace, column_parent, predicate, key_range.start_key, key_range.end_key, key_range.count, consistency_level)
def multiget_key_range(self, keyspace, column_family, key_ranges, count, consistency_level):
keys = []
for key_range in key_ranges:
key_slice = self.database[keyspace].key_slice(key_range.start_key, key_range.end_key)
keys.extend(key for key in key_slice if len(self.database[keyspace][key][column_family]))
if len(keys) >= count:
break
return keys[:count]
def insert(self, keyspace, key, column_path, value, timestamp, consistency_level):
try:
column = self._get(keyspace, key, column_path, consistency_level)
if timestamp > column.timestamp:
column.value = value
except NotFoundException:
if key not in self.database[keyspace]:
self.database[keyspace][key] = dict((column_family, SSTable()) for column_family in self.keyspaces[keyspace])
self.database[keyspace][key][column_path.column_family][column_path.column] = Column(column_path.column, value, timestamp)
def remove(self, keyspace, key, column_path, timestamp, consistency_level):
key = self.database[keyspace].get(key)
if key is not None:
column_family = key[column_path.column_family] # Should raise misconfiguration exception here if KeyError.
if column_path.column in column_family and timestamp >= column_family[column_path.column].timestamp:
del column_family[column_path.column]
def batch_insert(self, keyspace, key, cfmap, consistency_level):
for column_family, columns in cfmap.iteritems():
for column in columns:
self.insert(
keyspace,
key,
ColumnPath(column_family=column_family, column=column.column.name),
column.column.value,
column.column.timestamp,
consistency_level
)
def batch_mutate(self, keyspace, mutation_map, consistency_level):
for key, cfmap in mutation_map.iteritems():
for column_family, mutations in cfmap.iteritems():
for mutation in mutations:
if mutation.deletion:
if mutation.deletion.predicate.slice_range:
raise Exception('Slice range on batch mutations not supported yet.')
for column in mutation.deletion.predicate.column_names:
self.remove(
keyspace,
key,
ColumnPath(column_family=column_family, column=column),
mutation.deletion.timestamp,
consistency_level
)
else:
column = mutation.column_or_supercolumn.column
self.insert(
keyspace,
key,
ColumnPath(column_family=column_family, column=column.name),
column.value,
column.timestamp,
consistency_level
)
@mager
Copy link

mager commented Apr 22, 2010

whiz (imagine gist whizzing over my head)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment