@spitz-dan-l
Created January 21, 2018 20:26
from distributed import Client, Queue, Variable, get_client, secede, rejoin, wait, fire_and_forget
from time import sleep
import gc
from functools import partial
from collections import Counter, defaultdict, namedtuple
from tornado import gen
import pandas as pd

def filter_dict_by_key_range(dct, key_range):
    key_start, key_end = key_range
    result = {}
    for (k, v) in dct.items():
        if (key_start is None or k >= key_start) and (key_end is None or k < key_end):
            result[k] = v
    return result

def compute_key_range_splits(keys, key_range, n_splits):
    key_start, key_end = key_range
    new_key_ranges = []
    split_start = key_start
    for i in range(1, n_splits):
        split_end_index = i * len(keys) // n_splits
        split_end = keys[split_end_index]
        new_key_ranges.append((split_start, split_end))
        split_start = split_end
    new_key_ranges.append((split_start, key_end))
    return new_key_ranges

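# Illustrative only (not part of the original gist): with four sorted keys and
# two splits over the open range (None, None), each child receives a half-open
# key interval.
#
#   >>> compute_key_range_splits(['a', 'b', 'c', 'd'], (None, None), 2)
#   [(None, 'c'), ('c', None)]
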
@gen.coroutine
def queue_put_batch(queue, batch, timeout=None):
    for x in batch:
        yield queue._put(x, timeout)

class CountAccumulator:
    data_class = NotImplemented

    def __init__(self, initial_data=None):
        if self.data_class is NotImplemented:
            raise NotImplementedError("you must set the class' data_class attribute to something")
        if initial_data is None:
            self.data = self.data_class()
        else:
            self.data = self.data_class(initial_data)

    def update_from_data(self, data):
        raise NotImplementedError()

    def update_from_batch(self, batch):
        for data in batch:
            self.update_from_data(data)

    @classmethod
    def _prepare_data(cls, data):
        return data

    def get_prepared_data(self):
        return self._prepare_data(self.data)

    def __len__(self):
        return len(self.data)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        raise NotImplementedError()

    @classmethod
    def filter_batch_by_key_range(cls, batch, key_range):
        result_batch = []
        for data in batch:
            result = cls.filter_data_by_key_range(data, key_range)
            result_batch.append(result)
        return result_batch

    def filter_by_key_range(self, key_range):
        result1 = self.filter_data_by_key_range(self.data, key_range)
        return self._prepare_data(result1)

    def compute_key_range_splits(self, key_range, n_splits):
        raise NotImplementedError()

class CounterAccumulator(CountAccumulator):
    data_class = Counter

    def update_from_data(self, data):
        self.data.update(data)

    @classmethod
    def _prepare_data(cls, data):
        return dict(data)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        return filter_dict_by_key_range(data, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        keys = sorted(self.data.keys())
        return compute_key_range_splits(keys, key_range, n_splits)

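# Illustrative only (not part of the original gist): a quick local check of
# CounterAccumulator.
#
#   >>> acc = CounterAccumulator()
#   >>> acc.update_from_batch([{'apple': 1, 'banana': 2}, {'apple': 3}])
#   >>> acc.get_prepared_data()
#   {'apple': 4, 'banana': 2}
#   >>> acc.filter_by_key_range(('a', 'b'))   # half-open range [start, end)
#   {'apple': 4}
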
TermAndDocCounts = namedtuple('TermAndDocCounts', ['term_count', 'doc_count'])

class TermAndDocAccumulator(CountAccumulator):
    data_class = partial(defaultdict, (lambda: TermAndDocCounts(0, 0)))

    def __init__(self, initial_data=None):
        if initial_data is not None:
            for k, v in initial_data.items():
                initial_data[k] = TermAndDocCounts._make(v)
        super().__init__(initial_data)

    def update_from_data(self, other):
        for k, v in other.items():
            current_counts = self.data[k]
            self.data[k] = TermAndDocCounts(current_counts.term_count + v[0], current_counts.doc_count + v[1])

    @classmethod
    def _prepare_data(cls, data):
        return {k: tuple(v) for k, v in data.items()}

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        return filter_dict_by_key_range(data, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        keys = sorted(self.data.keys())
        return compute_key_range_splits(keys, key_range, n_splits)

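# Illustrative only (not part of the original gist): merging per-document
# (term_count, doc_count) pairs.
#
#   >>> acc = TermAndDocAccumulator()
#   >>> acc.update_from_batch([{'dog': (3, 1)}, {'dog': (2, 1), 'cat': (1, 1)}])
#   >>> acc.get_prepared_data()
#   {'dog': (5, 2), 'cat': (1, 1)}
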
class DataFrameAccumulator(CountAccumulator):
    data_class = pd.DataFrame

    def update_from_data(self, other_df):
        if other_df.size == 0:
            print('skipping update of empty input')
            return
        self.data = self.data.add(other_df, fill_value=0)

    @classmethod
    def _consolidate_batch(cls, batch):
        if isinstance(batch, list):
            if len(batch) == 1:
                batch_df = batch[0]
            else:
                batch_df = pd.concat(batch)
            if batch_df.size > 0 and batch_df.index.has_duplicates:
                batch_df = batch_df.groupby(batch_df.index).sum()
        else:
            batch_df = batch
        return batch_df

    def update_from_batch(self, batch):
        batch_df = self._consolidate_batch(batch)
        self.update_from_data(batch_df)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        if not data.index.is_monotonic_increasing:
            data.sort_index(inplace=True)
        start, stop = key_range
        # Check whether this range would filter out nothing.
        if (start is None or start <= data.index.min()) and (stop is None or stop > data.index.max()):
            print('skipped a .loc')
            return data
        filtered = data.loc[start:stop]
        if stop is not None and filtered.size > 0 and filtered.index[-1] == stop:
            filtered = filtered.iloc[:-1]
        return filtered

    @classmethod
    def filter_batch_by_key_range(cls, batch, key_range):
        batch_df = cls._consolidate_batch(batch)
        return cls.filter_data_by_key_range(batch_df, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        if not self.data.index.is_monotonic_increasing:
            self.data.sort_index(inplace=True)
        return compute_key_range_splits(self.data.index, key_range, n_splits)

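# Illustrative only (not part of the original gist): a local check of
# DataFrameAccumulator, assuming string index labels and numeric columns.
#
#   >>> acc = DataFrameAccumulator(pd.DataFrame({'term_count': [1], 'doc_count': [1]}, index=['dog']))
#   >>> acc.update_from_batch([pd.DataFrame({'term_count': [2], 'doc_count': [1]}, index=['dog'])])
#   >>> int(acc.data.loc['dog', 'term_count'])   # summed via DataFrame.add(..., fill_value=0)
#   3
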
class DistributedCountAccumulator:
    MAGIC_SLEEP_NUMBER = 1
    counter_class = CounterAccumulator

    def __init__(self, key_range=(None, None), split_at=500000, n_splits=4, initial_counts=None):
        self.key_range = key_range
        self.split_at = split_at
        self.n_splits = n_splits
        self.initial_counts = initial_counts

    def ideal_input_queue_maxsize(self):
        return self.n_splits

    def iter_count_batches(self, input_q, stop_v, batches_of_batches=False):
        # Yield futures of count batches from the input queue until the stop
        # variable is set and the queue is drained. With batches_of_batches=True,
        # drain whatever is currently queued in one go (used in multiplexer mode).
        while self.should_continue(stop_v, input_q):
            qs = input_q.qsize()
            if qs == 0:
                sleep(self.MAGIC_SLEEP_NUMBER)
                continue
            if not batches_of_batches:
                count_batch_f = input_q.get()
                yield count_batch_f
            else:
                count_batch_fs = input_q.get(batch=qs)
                yield count_batch_fs

    def should_continue(self, stop_v, input_q):
        return not stop_v.get() or input_q.qsize() > 0

    def run(self, input_q, output_q, stop_v):
        if self.initial_counts is None:
            accum_counts = self.counter_class()
        else:
            accum_counts = self.counter_class(self.initial_counts)
        client = get_client()
        secede()  # We will be spending most of our time seceded from the dask worker pool.
        for count_batch_f in self.iter_count_batches(input_q, stop_v):
            count_batch = count_batch_f.result()
            rejoin()
            filtered_count_batch = self.counter_class.filter_batch_by_key_range(count_batch, key_range=self.key_range)
            accum_counts.update_from_batch(filtered_count_batch)
            # Check if the key count exceeds the split threshold.
            if len(accum_counts) > self.split_at:
                break
            secede()
        else:
            # If we are here, we didn't have to split.
            result_f, = client.scatter([accum_counts.get_prepared_data()])  # Convert to dict since it serializes way faster.
            output_q.put(result_f)
            # Return the number of separate accumulators.
            # Since we didn't split, there's only one of us.
            return 1

        # If we are here, it means we need to split.
        # We are also currently in the worker pool, which is what we want right now.
        child_key_ranges = accum_counts.compute_key_range_splits(self.key_range, self.n_splits)
        child_stop_v = Variable(client=client)
        child_stop_v.set(False)
        child_accumulators = []
        child_input_qs = []
        for child_key_range in child_key_ranges:
            child_acc_f, child_input_q = self.submit_child_accumulator(client, accum_counts, child_key_range, output_q, child_stop_v)
            child_accumulators.append(child_acc_f)
            child_input_qs.append(child_input_q)
        # Free up memory now that we've passed our counts down to our children.
        del accum_counts
        gc.collect()
        secede()  # Now that we have constructed the child accumulators, we go into multiplexer mode and sit in the background.
        for count_batch_fs in self.iter_count_batches(input_q, stop_v, batches_of_batches=True):
            for child_input_q in child_input_qs:
                client.sync(queue_put_batch, child_input_q, count_batch_fs)
        child_stop_v.set(True)
        return sum(client.gather(child_accumulators))

    def submit_child_accumulator(self, client, accum_counts, child_key_range, output_q, child_stop_v):
        child_initial_count = accum_counts.filter_by_key_range(child_key_range)
        child_input_q = Queue(client=client, maxsize=self.ideal_input_queue_maxsize())
        child_acc = self.__class__(
            child_key_range,
            self.split_at,
            self.n_splits,
            child_initial_count
        )
        child_acc_f = client.submit(
            child_acc.run,
            child_input_q,
            output_q,
            child_stop_v
        )
        return child_acc_f, child_input_q

class DistributedTermAndDocCountAccumulator(DistributedCountAccumulator):
    counter_class = TermAndDocAccumulator

class DistributedDataFrameAccumulator(DistributedCountAccumulator):
    counter_class = DataFrameAccumulator

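# Usage note (summarizing the classes above; the __main__ block below is a
# concrete driver): the driver submits an accumulator's run() as a dask task,
# feeds it futures of count batches through an input Queue, and flips a stop
# Variable to True once the input is exhausted. An accumulator that grows past
# `split_at` keys splits its key range across `n_splits` children and becomes a
# multiplexer; each leaf accumulator scatters its final counts and puts the
# resulting future onto the output Queue. run() returns the number of leaf
# accumulators that produced a result.
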
if __name__ == '__main__':
    import uuid
    import random
    from time import time
    from nltk.corpus import wordnet

    words = list(wordnet.words())

    def random_key():
        return random.choice(words) + '_' + random.choice(words) + '_' + random.choice(words)

    def random_count():
        return random.randint(1, 20)

    def random_word_counts(n_keys_min=5, n_keys_max=50):
        n_keys = random.randint(n_keys_min, n_keys_max)
        result = {}
        for i in range(n_keys):
            result[random_key()] = random_count()
        return result

    def random_word_counts_chunk(chunksize=100):
        result = []
        for i in range(chunksize):
            result.append(random_word_counts())
        return result

    def random_term_and_doc_counts(n_keys_min=5, n_keys_max=50):
        n_keys = random.randint(n_keys_min, n_keys_max)
        result = {}
        for i in range(n_keys):
            result[random_key()] = tuple(TermAndDocCounts(term_count=random_count(), doc_count=1))
        return result

    def random_term_and_doc_counts_chunk(chunksize=1000):
        result = []
        for i in range(chunksize):
            result.append(random_term_and_doc_counts())
        return result

    def random_term_doc_counts_df(n_keys_min=5, n_keys_max=50):
        dct = random_term_and_doc_counts(n_keys_min, n_keys_max)
        df = pd.DataFrame.from_dict(dct, orient='index')
        df.columns = ['term_count', 'doc_count']
        return df

    def random_term_doc_counts_df_chunk(chunksize=1000):
        result = []
        for i in range(chunksize):
            result.append(random_term_doc_counts_df())
        result_df = pd.concat(result)
        return [result_df.groupby(result_df.index).sum()]

    print('connecting to the cluster')
    client = Client('127.0.0.1:8786')
    accumulator = DistributedDataFrameAccumulator(split_at=100000)
    input_q = Queue(client=client, maxsize=accumulator.ideal_input_queue_maxsize())
    output_q = Queue(client=client)
    stop_v = Variable(client=client)
    stop_v.set(False)
    accumulator_f = client.submit(accumulator.run, input_q, output_q, stop_v)

    start = time()
    orig_inputs = []
    print('populating the queue')
    for i in range(100):
        f = client.submit(random_term_doc_counts_df_chunk, pure=False)
        input_q.put(f)
        print('added one to the queue', i)
    stop_v.set(True)

    print('waiting for accumulators')
    total_accums = accumulator_f.result()
    end = time()
    print('done')

    results = []
    while output_q.qsize() > 0:
        results.append(output_q.get())

    result_key_set = set()
    for res in results:
        result_key_set.update(res.result().index)

    print('total keys:', len(result_key_set))
    print('total accumulators:', total_accums)
    print('accumulation took:', end - start)