from distributed import Client, Queue, Variable, get_client, secede, rejoin, wait, fire_and_forget
from time import sleep
import gc
from functools import partial
from collections import Counter, defaultdict, namedtuple
from tornado import gen
import pandas as pd


def filter_dict_by_key_range(dct, key_range):
    # Keep only the items whose key falls in the half-open range
    # [key_start, key_end); a bound of None means unbounded on that side.
    key_start, key_end = key_range
    result = {}
    for (k, v) in dct.items():
        if (key_start is None or k >= key_start) and (key_end is None or k < key_end):
            result[k] = v
    return result


def compute_key_range_splits(keys, key_range, n_splits):
    # Split a sorted sequence of keys into n_splits contiguous half-open
    # child ranges that together cover key_range.
    key_start, key_end = key_range
    new_key_ranges = []
    split_start = key_start
    for i in range(1, n_splits):
        split_end_index = i * len(keys) // n_splits
        split_end = keys[split_end_index]
        new_key_ranges.append((split_start, split_end))
        split_start = split_end
    new_key_ranges.append((split_start, key_end))
    return new_key_ranges
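
# A hypothetical worked example (not part of the original gist): splitting 8
# sorted keys into 3 half-open child ranges. Each child range starts where the
# previous one ends, and the outer (None, None) bounds are preserved.
def _demo_compute_key_range_splits():
    keys = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
    assert compute_key_range_splits(keys, (None, None), 3) == [
        (None, 'c'), ('c', 'f'), ('f', None),
    ]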

@gen.coroutine
def queue_put_batch(queue, batch, timeout=None):
    # Put every item of a batch onto a distributed Queue from a single coroutine.
    for x in batch:
        yield queue._put(x, timeout)

class CountAccumulator:
    # Base class for in-memory count accumulators. Subclasses pick a data_class
    # and define how to merge, filter, and split their counts by key range.
    data_class = NotImplemented

    def __init__(self, initial_data=None):
        if self.data_class is NotImplemented:
            raise NotImplementedError("you must set the class' data_class attribute to something")
        if initial_data is None:
            self.data = self.data_class()
        else:
            self.data = self.data_class(initial_data)

    def update_from_data(self, data):
        raise NotImplementedError()

    def update_from_batch(self, batch):
        for data in batch:
            self.update_from_data(data)

    @classmethod
    def _prepare_data(cls, data):
        return data

    def get_prepared_data(self):
        return self._prepare_data(self.data)

    def __len__(self):
        return len(self.data)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        raise NotImplementedError()

    @classmethod
    def filter_batch_by_key_range(cls, batch, key_range):
        result_batch = []
        for data in batch:
            result = cls.filter_data_by_key_range(data, key_range)
            result_batch.append(result)
        return result_batch

    def filter_by_key_range(self, key_range):
        result1 = self.filter_data_by_key_range(self.data, key_range)
        return self._prepare_data(result1)

    def compute_key_range_splits(self, key_range, n_splits):
        raise NotImplementedError()

class CounterAccumulator(CountAccumulator):
    data_class = Counter

    def update_from_data(self, data):
        self.data.update(data)

    @classmethod
    def _prepare_data(cls, data):
        return dict(data)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        return filter_dict_by_key_range(data, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        keys = sorted(self.data.keys())
        return compute_key_range_splits(keys, key_range, n_splits)
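
# A hypothetical local sketch (not part of the original gist): exercise
# CounterAccumulator without a dask cluster by accumulating two word-count
# dicts and splitting the accumulated keys into two child ranges.
def _demo_counter_accumulator():
    acc = CounterAccumulator()
    acc.update_from_batch([{'apple': 2, 'pear': 1}, {'apple': 1, 'zebra': 4}])
    assert acc.get_prepared_data() == {'apple': 3, 'pear': 1, 'zebra': 4}
    assert acc.compute_key_range_splits((None, None), 2) == [(None, 'pear'), ('pear', None)]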

TermAndDocCounts = namedtuple('TermAndDocCounts', ['term_count', 'doc_count'])


class TermAndDocAccumulator(CountAccumulator):
    # Accumulates (term_count, doc_count) pairs keyed by term.
    data_class = partial(defaultdict, (lambda: TermAndDocCounts(0, 0)))

    def __init__(self, initial_data=None):
        if initial_data is not None:
            for k, v in initial_data.items():
                initial_data[k] = TermAndDocCounts._make(v)
        super().__init__(initial_data)

    def update_from_data(self, other):
        for k, v in other.items():
            current_counts = self.data[k]
            self.data[k] = TermAndDocCounts(current_counts.term_count + v[0], current_counts.doc_count + v[1])

    @classmethod
    def _prepare_data(cls, data):
        return {k: tuple(v) for k, v in data.items()}

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        return filter_dict_by_key_range(data, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        keys = sorted(self.data.keys())
        return compute_key_range_splits(keys, key_range, n_splits)
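
# A hypothetical local sketch (not part of the original gist): merge two batches
# of (term_count, doc_count) pairs and read the totals back as plain tuples.
def _demo_term_and_doc_accumulator():
    acc = TermAndDocAccumulator()
    acc.update_from_batch([{'apple': (2, 1)}, {'apple': (3, 1), 'pear': (1, 1)}])
    assert acc.get_prepared_data() == {'apple': (5, 2), 'pear': (1, 1)}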

class DataFrameAccumulator(CountAccumulator):
    # Accumulates counts held in a pandas DataFrame indexed by key, one column
    # per count type.
    data_class = pd.DataFrame

    def update_from_data(self, other_df):
        if other_df.size == 0:
            print('skipping update of empty input')
            return
        self.data = self.data.add(other_df, fill_value=0)

    @classmethod
    def _consolidate_batch(cls, batch):
        # Concatenate a list of frames and sum any rows that share an index,
        # so the whole batch can be merged with a single DataFrame.add.
        if isinstance(batch, list):
            if len(batch) == 1:
                batch_df = batch[0]
            else:
                batch_df = pd.concat(batch)
            if batch_df.size > 0 and batch_df.index.has_duplicates:
                batch_df = batch_df.groupby(batch_df.index).sum()
        else:
            batch_df = batch
        return batch_df

    def update_from_batch(self, batch):
        batch_df = self._consolidate_batch(batch)
        self.update_from_data(batch_df)

    @classmethod
    def filter_data_by_key_range(cls, data, key_range):
        if not data.index.is_monotonic_increasing:
            data.sort_index(inplace=True)
        start, stop = key_range
        # Skip the .loc entirely if this range would not filter anything out.
        if (start is None or start <= data.index.min()) and (stop is None or stop > data.index.max()):
            print('skipped a .loc')
            return data
        filtered = data.loc[start:stop]
        # .loc slicing includes the stop label; drop it to keep the range half-open.
        if stop is not None and filtered.size > 0 and filtered.index[-1] == stop:
            filtered = filtered.iloc[:-1]
        return filtered

    @classmethod
    def filter_batch_by_key_range(cls, batch, key_range):
        batch_df = cls._consolidate_batch(batch)
        return cls.filter_data_by_key_range(batch_df, key_range)

    def compute_key_range_splits(self, key_range, n_splits):
        if not self.data.index.is_monotonic_increasing:
            self.data.sort_index(inplace=True)
        return compute_key_range_splits(self.data.index, key_range, n_splits)
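
# A hypothetical sketch (not part of the original gist): filter a small
# term/doc-count frame to the half-open key range ['b', 'd'). The upper bound
# is excluded, so only the rows indexed 'b' and 'c' remain.
def _demo_dataframe_key_range_filter():
    df = pd.DataFrame(
        {'term_count': [3, 1, 5, 2], 'doc_count': [1, 1, 2, 1]},
        index=['a', 'b', 'c', 'd'],
    )
    filtered = DataFrameAccumulator.filter_data_by_key_range(df, ('b', 'd'))
    assert list(filtered.index) == ['b', 'c']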

class DistributedCountAccumulator:
    # Accumulates count batches pulled from a distributed Queue. When the number
    # of accumulated keys exceeds split_at, the accumulator splits its key range
    # into n_splits child accumulators and multiplexes incoming batches to them.
    MAGIC_SLEEP_NUMBER = 1
    counter_class = CounterAccumulator

    def __init__(self, key_range=(None, None), split_at=500000, n_splits=4, initial_counts=None):
        self.key_range = key_range
        self.split_at = split_at
        self.n_splits = n_splits
        self.initial_counts = initial_counts

    def ideal_input_queue_maxsize(self):
        return self.n_splits

    def iter_count_batches(self, input_q, stop_v, batches_of_batches=False):
        while self.should_continue(stop_v, input_q):
            qs = input_q.qsize()
            if qs == 0:
                sleep(self.MAGIC_SLEEP_NUMBER)
                continue
            if not batches_of_batches:
                count_batch_f = input_q.get()
                yield count_batch_f
            else:
                count_batch_fs = input_q.get(batch=qs)
                yield count_batch_fs

    def should_continue(self, stop_v, input_q):
        return not stop_v.get() or input_q.qsize() > 0

    def run(self, input_q, output_q, stop_v):
        if self.initial_counts is None:
            accum_counts = self.counter_class()
        else:
            accum_counts = self.counter_class(self.initial_counts)
        client = get_client()
        secede()  # We will be spending most of our time seceded from the dask worker pool.
        for count_batch_f in self.iter_count_batches(input_q, stop_v):
            count_batch = count_batch_f.result()
            rejoin()
            filtered_count_batch = self.counter_class.filter_batch_by_key_range(count_batch, key_range=self.key_range)
            accum_counts.update_from_batch(filtered_count_batch)
            # Check whether the number of keys exceeds the split threshold.
            if len(accum_counts) > self.split_at:
                break
            secede()
        else:
            # If we are here, we didn't have to split.
            result_f, = client.scatter([accum_counts.get_prepared_data()])  # Convert to a plain dict since it serializes much faster.
            output_q.put(result_f)
            # Return the number of separate accumulators; since we didn't split, there is only one of us.
            return 1
        # If we are here, it means we need to split.
        # We are also currently in the worker pool, which is what we want right now.
        child_key_ranges = accum_counts.compute_key_range_splits(self.key_range, self.n_splits)
        child_stop_v = Variable(client=client)
        child_stop_v.set(False)
        child_accumulators = []
        child_input_qs = []
        for child_key_range in child_key_ranges:
            child_acc_f, child_input_q = self.submit_child_accumulator(client, accum_counts, child_key_range, output_q, child_stop_v)
            child_accumulators.append(child_acc_f)
            child_input_qs.append(child_input_q)
        # Free up memory now that we've passed our counts down to our children.
        del accum_counts
        gc.collect()
        secede()  # Now that we have constructed the child accumulators, we go into multiplexer mode and sit in the background.
        for count_batch_fs in self.iter_count_batches(input_q, stop_v, batches_of_batches=True):
            for child_input_q in child_input_qs:
                client.sync(queue_put_batch, child_input_q, count_batch_fs)
        child_stop_v.set(True)
        return sum(client.gather(child_accumulators))

    def submit_child_accumulator(self, client, accum_counts, child_key_range, output_q, child_stop_v):
        # Filter the current counts down to the child's key range, then launch
        # the child accumulator as its own dask task with its own input queue.
        child_initial_count = accum_counts.filter_by_key_range(child_key_range)
        child_input_q = Queue(client=client, maxsize=self.ideal_input_queue_maxsize())
        child_acc = self.__class__(
            child_key_range,
            self.split_at,
            self.n_splits,
            child_initial_count
        )
        child_acc_f = client.submit(
            child_acc.run,
            child_input_q,
            output_q,
            child_stop_v
        )
        return child_acc_f, child_input_q

class DistributedTermAndDocCountAccumulator(DistributedCountAccumulator):
    counter_class = TermAndDocAccumulator


class DistributedDataFrameAccumulator(DistributedCountAccumulator):
    counter_class = DataFrameAccumulator

if __name__ == '__main__':
    import uuid
    import random
    from time import time
    from nltk.corpus import wordnet

    words = list(wordnet.words())

    def random_key():
        return random.choice(words) + '_' + random.choice(words) + '_' + random.choice(words)

    def random_count():
        return random.randint(1, 20)

    def random_word_counts(n_keys_min=5, n_keys_max=50):
        n_keys = random.randint(n_keys_min, n_keys_max)
        result = {}
        for i in range(n_keys):
            result[random_key()] = random_count()
        return result

    def random_word_counts_chunk(chunksize=100):
        result = []
        for i in range(chunksize):
            result.append(random_word_counts())
        return result

    def random_term_and_doc_counts(n_keys_min=5, n_keys_max=50):
        n_keys = random.randint(n_keys_min, n_keys_max)
        result = {}
        for i in range(n_keys):
            result[random_key()] = tuple(TermAndDocCounts(term_count=random_count(), doc_count=1))
        return result

    def random_term_and_doc_counts_chunk(chunksize=1000):
        result = []
        for i in range(chunksize):
            result.append(random_term_and_doc_counts())
        return result

    def random_term_doc_counts_df(n_keys_min=5, n_keys_max=50):
        dct = random_term_and_doc_counts(n_keys_min, n_keys_max)
        df = pd.DataFrame.from_dict(dct, orient='index')
        df.columns = ['term_count', 'doc_count']
        return df

    def random_term_doc_counts_df_chunk(chunksize=1000):
        result = []
        for i in range(chunksize):
            result.append(random_term_doc_counts_df())
        result_df = pd.concat(result)
        return [result_df.groupby(result_df.index).sum()]

    print('connecting to the cluster')
    client = Client('127.0.0.1:8786')

    accumulator = DistributedDataFrameAccumulator(split_at=100000)
    input_q = Queue(client=client, maxsize=accumulator.ideal_input_queue_maxsize())
    output_q = Queue(client=client)
    stop_v = Variable(client=client)
    stop_v.set(False)
    accumulator_f = client.submit(accumulator.run, input_q, output_q, stop_v)

    start = time()
    orig_inputs = []
    print('populating the queue')
    for i in range(100):
        f = client.submit(random_term_doc_counts_df_chunk, pure=False)
        input_q.put(f)
        print('added one to the queue', i)
    stop_v.set(True)

    print('waiting for accumulators')
    total_accums = accumulator_f.result()
    end = time()
    print('done')

    results = []
    while output_q.qsize() > 0:
        results.append(output_q.get())

    result_key_set = set()
    for res in results:
        result_key_set.update(res.result().index)

    print('total keys:', len(result_key_set))
    print('total accumulators:', total_accums)
    print('accumulation took:', end - start)