bm_copy.py: pre-release fast copy app for Cassandra tables
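Example usage (a hypothetical invocation: the host addresses and rates below are placeholders, but every flag shown is defined by the script's argument parser):

    python bm_copy.py --source-cql-host-list 10.0.0.10 --dest-cql-host-list 10.0.1.10 --worker-count 4 --fetch-size 1000 --throttle-rate 5000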
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
bm_copy

Author: Michael P Laing, michael.laing@nytimes.com
Date: 2014-06-08

A preview version from the nyt⨍aбrik 'rabbit_helpers' framework.

Apache License Version 2.0, January 2004 http://www.apache.org/licenses/
"""
import os
import sys
import json
import uuid
import logging
import argparse
import traceback
import multiprocessing

from time import time, sleep
from datetime import datetime
from collections import deque
from threading import Event, Lock
from multiprocessing import Manager, Process

from cassandra.cluster import Cluster
from cassandra import ConsistencyLevel
from cassandra.policies import (
    RetryPolicy,
    RoundRobinPolicy,
    TokenAwarePolicy,
    DCAwareRoundRobinPolicy,
    DowngradingConsistencyRetryPolicy
)

LOG_FORMAT = (
    '%(levelname) -10s %(asctime)s %(name) -30s %(process)d '
    '%(funcName) -35s %(lineno) -5d: %(message)s'
)

JSON_DUMPS_ARGS = {'ensure_ascii': False, 'indent': 4}


class CopyService(object):
    def __init__(self, args, token_range, logger):
        self._logger = logger
        self._logger.name = __name__
        self._logger.debug("")

        self._args = {
            "source": {
                k[7:]: v for k, v in args.__dict__.items()
                if k[:7] == "source_"
            },
            "dest": {
                k[5:]: v for k, v in args.__dict__.items()
                if k[:5] == "dest_"
            }
        }

        for source_or_dest in ['source', 'dest']:
            self._args[source_or_dest]['consistency_level'] = (
                ConsistencyLevel.name_to_value[
                    self._args[source_or_dest]['consistency_level']
                ]
            )

        self._dest_token_aware = args.dest_token_aware
        self._fetch_size = args.fetch_size
        self._concurrency = args.concurrency
        self._throttle_rate = args.throttle_rate
        self._worker_count = args.worker_count
        self._token_range = token_range
        self._page_rate = (
            float(self._throttle_rate)
            / (self._fetch_size * self._worker_count)
        )
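        # Worked example (hypothetical numbers, not from the original
        # comments): with throttle_rate=1000 rows/sec, fetch_size=1000
        # rows/page and worker_count=2, page_rate = 1000 / (1000 * 2) = 0.5
        # pages/sec per worker, i.e. each worker may start a new page
        # roughly every 2 seconds.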
        self._cql_cluster = {}
        self._cql_session = {}
        self._future = None  # source future
        self._finished_event = Event()
        self._lock = Lock()
        self._concurrent_updates = 0
        self._stopped = False
        self._page = 0
        self._row_count = 0
        self._rows = deque()
        self._start_time = 0
        self._stop_time = 0
        self._os_times_start = None
        self._os_times_stop = None
        self._query = {}

        u"""
        -- my source_data and source_data_copy tables look like this
        --
        -- use your own tables and modify:
        --     . source stmt
        --     . dest stmt
        --     . map_fields method
        --     . map_routing_key method (if using token_aware option)
        --
        CREATE TABLE source_data (
            hash_key text,
            message_id timeuuid,
            body blob,
            metadata text,
            PRIMARY KEY (hash_key, message_id)
        );
        """
        self._stmt = {
            'source': u"""  -- these ?'s will be filled in automatically
                SELECT *    -- Note: token range is inclusive
                FROM benchmark.source_data
                WHERE TOKEN(hash_key) >= ? AND TOKEN(hash_key) <= ?
            """,
            'dest': u"""    -- these ?'s must be mapped by 'map_fields' below
                UPDATE benchmark.source_data_copy
                SET metadata = ?, body = ?
                WHERE hash_key = ? AND message_id = ?
            """
        }
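
    # Overview of the copy flow (summary comment, not in the original code):
    # run() connects to both clusters and calls select_from_source_table();
    # result pages arrive in more_rows_or_finish(), queued rows are written
    # by update_dest_table(), and each completed write re-enters
    # update_or_finish(), which prefetches the next page (optionally
    # throttled by fetch_now_or_later()) until finish() is reached.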
    def map_fields(self, source_row):
        self._logger.debug("")

        return (  # return a tuple in the order of the dest ?'s above
            source_row.metadata,
            source_row.body,
            source_row.hash_key,
            source_row.message_id
        )

    def map_routing_key(self, source_row):  # required for token_aware
        self._logger.debug("")

        return (source_row.hash_key,)  # return a tuple of the partition key

    def fetch_now_or_later(self):
        actual = time() - self._start_time
        target = self._page / self._page_rate

        self._logger.debug(
            "actual: {0}; target: {1}; diff: {2}".format(
                actual, target, target - actual
            )
        )

        if target > actual:  # fetching faster than target?
            sleep(target - actual)  # sleep until actual == target (overshoot)

        self._future.start_fetching_next_page()  # now fetch
    def update_or_finish(self, _):
        self._logger.debug(
            "len(self._rows): {0}; "
            "self._concurrent_updates: {1}; ".format(
                len(self._rows),
                self._concurrent_updates
            )
        )

        try:
            self._row_count += 1
            self._concurrent_updates -= 1

            # is work available or queued?
            if self._future.has_more_pages or self._rows:
                # is the current row within the latest page?
                if (self._row_count - 1) / self._fetch_size + 1 > self._page:
                    # if so, maybe prefetch another page
                    self._logger.info(
                        'page: {0}; row_count: {1}'.format(
                            self._page, self._row_count
                        )
                    )

                    # is more work available?
                    if self._future.has_more_pages:
                        if self._throttle_rate:  # is throttling in effect?
                            self.fetch_now_or_later()  # maybe delay fetching
                        else:  # fetch now while processing continues async
                            self._future.start_fetching_next_page()

                    self._page += 1

                if self._rows:  # is work queued?
                    self.update_dest_table()  # process it

            elif self._concurrent_updates:  # is work in progress?
                pass  # wait for callbacks
            else:  # work is all done
                self.finish()
        except Exception as exc:
            self.stop_and_raise(exc)

    def update_dest_table(self):
        self._logger.debug("")

        with self._lock:  # called from multiple threads, so use lock
            while self._rows:  # is work queued?
                self._logger.debug(
                    "len(self._rows): {0}; "
                    "self._concurrent_updates: {1}; "
                    "self._concurrency: {2}".format(
                        len(self._rows),
                        self._concurrent_updates,
                        self._concurrency
                    )
                )

                # are we at the limit of concurrency?
                if self._concurrent_updates >= self._concurrency:
                    break  # enough work is in progress
                else:  # submit more work
                    self._concurrent_updates += 1
                    row = self._rows.pop()
                    subvars = self.map_fields(row)

                    if self._dest_token_aware:
                        self._query[
                            'dest'
                        ].routing_key = self.map_routing_key(row)

                    future = self._cql_session[
                        'dest'
                    ].execute_async(self._query['dest'], subvars)

                    future.add_callback(self.update_or_finish)
                    future.add_errback(self.stop_and_raise)
    def stop_and_raise(self, exc):
        self._logger.debug("")
        # traceback.format_exc() takes no exception argument; log the
        # exception itself alongside any current traceback
        error_msg = 'exception: {0!r}; traceback: {1}'.format(
            exc, traceback.format_exc()
        )
        self._logger.error(error_msg)
        self.stop()

        raise exc
    def more_rows_or_finish(self, new_rows):
        self._logger.debug(
            "len(new_rows): {0}; "
            "len(self._rows): {1}; "
            "self._concurrent_updates: {2}; ".format(
                len(new_rows), len(self._rows), self._concurrent_updates
            )
        )

        try:
            if new_rows:  # is there new work?
                self._rows.extend(new_rows)  # extend the work queue
                self.update_dest_table()  # process it
            elif self._rows:  # is work queued?
                self.update_dest_table()  # process it
            elif self._concurrent_updates:  # is work in progress?
                pass  # wait for callbacks
            else:  # work is all done
                self.finish()
        except Exception as exc:
            self.stop_and_raise(exc)

    def select_from_source_table(self):
        self._logger.debug("")
        self._start_time = time()
        self._os_times_start = os.times()

        try:
            self._logger.info(
                "self._token_range: {}".format(self._token_range)
            )
            self._future = self._cql_session[
                'source'
            ].execute_async(self._query['source'], self._token_range)
            self._future.add_callback(self.more_rows_or_finish)
            self._future.add_errback(self.stop_and_raise)
        except Exception as exc:
            self.stop_and_raise(exc)
    def stop(self):
        self._logger.debug("")

        if self._stopped:
            return

        self._stopped = True  # guard against re-entry
        self._stop_time = time()
        self._os_times_stop = os.times()
        self._logger.info("Stopping service.")

        for source_or_dest in ['source', 'dest']:
            # use .get() so a failed connection() doesn't raise KeyError here
            if self._cql_cluster.get(source_or_dest):
                try:
                    self._cql_cluster[source_or_dest].shutdown()
                except Exception:
                    error_msg = traceback.format_exc()
                    self._logger.info(
                        "Exception on cql_cluster.shutdown(): {0}".format(
                            error_msg
                        )
                    )

                    raise
    def connection(self, source_or_dest):
        params = self._args[source_or_dest]
        self._logger.debug(
            "source_or_dest: {0}; params: {1}".format(
                source_or_dest, params
            )
        )

        try:
            # defaults
            load_balancing_policy = RoundRobinPolicy()
            retry_policy = RetryPolicy()

            if params['dc_aware']:
                load_balancing_policy = DCAwareRoundRobinPolicy(
                    params['local_dc'],
                    used_hosts_per_remote_dc=params['remote_dc_hosts']
                )

            if params['token_aware']:
                load_balancing_policy = TokenAwarePolicy(
                    load_balancing_policy
                )

            if params['retry']:
                retry_policy = DowngradingConsistencyRetryPolicy()

            self._cql_cluster[source_or_dest] = Cluster(
                params['cql_host_list'],
                load_balancing_policy=load_balancing_policy,
                default_retry_policy=retry_policy
            )
            self._cql_session[source_or_dest] = self._cql_cluster[
                source_or_dest
            ].connect()

            if source_or_dest == 'source':
                self._cql_session[
                    'source'
                ].default_fetch_size = self._fetch_size

            self._query[source_or_dest] = self._cql_session[
                source_or_dest
            ].prepare(self._stmt[source_or_dest])
            self._query[source_or_dest].consistency_level = params[
                'consistency_level'
            ]
        except Exception:
            error_msg = 'Cassandra init error; traceback: {}'.format(
                traceback.format_exc()
            )
            self._logger.error(error_msg)
            self.stop()

            raise

        self._logger.info('Connected to Cassandra - {}'.format(source_or_dest))
    def finish(self):
        self._logger.info("Finished")
        self._finished_event.set()

    def run(self):
        self._logger.info("Starting service")
        self.connection('source')
        self.connection('dest')
        self.select_from_source_table()
        self._finished_event.wait()
        self._logger.info("Stopping service")

def main(args, worker_index, tokens, rate, results, logger):
    logger.info("Initializing...")
    exitcode = 0

    try:
        service = CopyService(args, tokens, logger)
        service.run()
    except KeyboardInterrupt:
        logger.info("Service terminated by SIGINT.")
    except Exception:
        error_msg = traceback.format_exc()
        logger.error("Runtime exception: {0}".format(error_msg))
        exitcode = 1

    service.stop()
    logger.info("Terminated.")

    if service._row_count:
        elapsed = service._stop_time - service._start_time
        rate = int(service._row_count // elapsed)
    else:
        elapsed = 0
        rate = 0

    os_times_diff = map(
        lambda stop, start: stop - start,
        service._os_times_stop,
        service._os_times_start
    )

    results[worker_index] = {  # add results to the shared special dict
        'row_count': service._row_count,
        'start_time': service._start_time,
        'stop_time': service._stop_time,
        'elapsed': elapsed,
        'rate': rate,
        'os_times_stop': service._os_times_stop,
        'os_times_start': service._os_times_start,
        'os_times_diff': os_times_diff
    }

    sys.exit(exitcode)

def print_results(results):
    print("\nWorker      Rows  Elapsed   Rows/sec")

    for worker_index in sorted(results.keys()):
        r = results[worker_index]
        print(
            "{0:6d}{1:10d}{2:9.3f}{3:11d}".format(
                worker_index, r['row_count'], r['elapsed'], r['rate']
            )
        )

    print("CPU: {:4.0f}%".format(results[99]['cpu'] * 100))

def print_arguments(args):
    print("arguments:")

    for k, v in sorted(args.__dict__.items()):
        print(" {0}: {1}".format(k, v))


def print_json(args, results):
    output = {
        "uuid": str(uuid.uuid1()),
        "timestamp": datetime.utcnow().isoformat() + 'Z',
        "args": args.__dict__,
        "results": results
    }

    print(json.dumps(output, **JSON_DUMPS_ARGS))

def analyze_results(results, os_times_stop, os_times_start):
    worker_indices = [worker_index for worker_index in results.keys()]

    row_count = sum([
        results[worker_index]['row_count']
        for worker_index in worker_indices
    ])

    start_time = min([
        results[worker_index]['start_time']
        for worker_index in worker_indices
        if results[worker_index]['start_time'] != 0
    ])

    stop_time = max([
        results[worker_index]['stop_time']
        for worker_index in worker_indices
        if results[worker_index]['start_time'] != 0
    ])

    if row_count:
        elapsed = stop_time - start_time
        rate = int(row_count // elapsed)
    else:
        elapsed = 0
        rate = 0

    os_times_diff = map(
        lambda stop, start: stop - start, os_times_stop, os_times_start
    )
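    # Explanatory comment (added): os.times() returns (user, system,
    # children_user, children_system, elapsed). The workers are child
    # processes, so summing the children's user+system CPU deltas and
    # dividing by the elapsed wall-clock delta approximates aggregate
    # worker CPU utilization.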
    cpu = (os_times_diff[2] + os_times_diff[3]) / os_times_diff[4]

    results[99] = {  # add summary results
        'row_count': row_count,
        'start_time': start_time,
        'stop_time': stop_time,
        'elapsed': elapsed,
        'rate': rate,
        'cpu': cpu,
        'os_times_stop': os_times_stop,
        'os_times_start': os_times_start,
        'os_times_diff': os_times_diff
    }

    return {k: v for k, v in results.items()}  # return a standard dict

def multiprocess(args):
    step = ((args.token_hi - args.token_lo) / args.worker_count) + 1
    tr1 = range(args.token_lo, args.token_hi, step)  # intermediate points
    tr2 = [(t, t + 1) for t in tr1[1:]]  # pairs of adjacent points
    tr3 = [t for st in tr2 for t in st]  # flatten
    tr4 = [args.token_lo] + tr3 + [args.token_hi]  # add end points
    token_ranges = [tr4[i:i + 2] for i in range(0, len(tr4), 2)]  # make pairs
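    # Worked example (hypothetical small numbers, not the real token range):
    # token_lo=0, token_hi=100, worker_count=2 -> step=51,
    # tr1=[0, 51], tr2=[(51, 52)], tr3=[51, 52], tr4=[0, 51, 52, 100],
    # token_ranges=[[0, 51], [52, 100]] -- adjacent, non-overlapping,
    # inclusive ranges covering [token_lo, token_hi].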
    rate = args.throttle_rate / args.worker_count

    formatter = logging.Formatter(LOG_FORMAT)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger = multiprocessing.get_logger()
    logger.addHandler(handler)
    level = logging.getLevelName(args.logging_level)  # undocumented feature
    logger.setLevel(level)

    manager = Manager()
    results = manager.dict()  # create a special shared dict to gather results

    workers = [
        Process(
            target=main,
            args=(
                args, worker_index, token_ranges[worker_index], rate, results,
                logger
            )
        )
        for worker_index in range(args.worker_count)
    ]

    os_times_start = os.times()

    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()

    os_times_stop = os.times()
    exitcode = 0

    for worker in workers:
        if worker.exitcode:
            exitcode = worker.exitcode
            break

    if results:
        results_dict = analyze_results(results, os_times_stop, os_times_start)

        if args.json_output:
            print_json(args, results_dict)
        else:
            print_arguments(args)
            print_results(results_dict)

    return exitcode

if __name__ == "__main__":
    description = """
        Copy/transform rows from one Cassandra table to another in the same or
        different clusters.
    """

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "--source-cql-host-list",
        default=['localhost'],
        dest="source_cql_host_list",
        nargs='*',
        metavar='CQL_HOST',
        help=(
            "source: the initial cql hosts to contact "
            "(default = ['localhost'])"
        )
    )
    parser.add_argument(
        "--source-dc-aware",
        dest="source_dc_aware",
        action="store_true",
        help="source: favor hosts in the local datacenter (default = False)"
    )
    parser.add_argument(
        "--source-local-dc",
        dest="source_local_dc",
        default='datacenter1',
        help=(
            "source: if dc_aware, the local datacenter "
            "(default = 'datacenter1')"
        )
    )
    parser.add_argument(
        "--source-remote-dc-hosts",
        type=int,
        default=0,
        dest="source_remote_dc_hosts",
        help=(
            "source: if dc_aware, the number of hosts to connect to as "
            "remote hosts (default = 0)"
        )
    )
    parser.add_argument(
        "--source-token-aware",
        dest="source_token_aware",
        action="store_true",
        help=(
            "source: route queries to known replicas by the tokens of their "
            "partition keys (default = False)"
        )
    )
    parser.add_argument(
        "--source-consistency-level",
        dest="source_consistency_level",
        default='ONE',
        choices=[
            'ONE', 'TWO', 'THREE', 'QUORUM', 'ALL', 'LOCAL_QUORUM',
            'EACH_QUORUM', 'SERIAL', 'LOCAL_SERIAL', 'LOCAL_ONE'
        ],
help="dest: consistency level (default = 'ONE')" | |
    )
    parser.add_argument(
        "--source-retry",
        dest="source_retry",
        action="store_true",
        help="source: downgrade consistency level and retry (default = False)"
    )
    parser.add_argument(
        "--dest-cql-host-list",
        default=['localhost'],
        dest="dest_cql_host_list",
        nargs='*',
        metavar='CQL_HOST',
        help=(
            "dest: the initial cql hosts to contact "
            "(default = ['localhost'])"
        )
    )
    parser.add_argument(
        "--dest-dc-aware",
        dest="dest_dc_aware",
        action="store_true",
        help="dest: favor hosts in the local datacenter (default = False)"
    )
    parser.add_argument(
        "--dest-local-dc",
        default='datacenter1',
        dest="dest_local_dc",
        help=(
            "dest: if dc_aware, the local datacenter "
            "(default = 'datacenter1')"
        )
    )
    parser.add_argument(
        "--dest-remote-dc-hosts",
        type=int,
        default=0,
        dest="dest_remote_dc_hosts",
        help=(
            "dest: if dc_aware, the number of hosts to be connected to as "
            "remote hosts (default = 0)"
        )
    )
    parser.add_argument(
        "--dest-token-aware",
        dest="dest_token_aware",
        action="store_true",
        help=(
            "dest: route queries to known replicas by the tokens of their "
            "partition keys (default = False)"
        )
    )
    parser.add_argument(
        "--dest-consistency-level",
        dest="dest_consistency_level",
        default='ONE',
        choices=[
            'ANY', 'ONE', 'TWO', 'THREE', 'QUORUM', 'ALL', 'LOCAL_QUORUM',
            'EACH_QUORUM', 'SERIAL', 'LOCAL_SERIAL', 'LOCAL_ONE'
        ],
        help="dest: consistency level (default = 'ONE')"
    )
    parser.add_argument(
        "--dest-retry",
        dest="dest_retry",
        action="store_true",
        help="dest: downgrade consistency level and retry (default = False)"
    )
    parser.add_argument(
        "--token-hi",
        type=int,
        default=2 ** 63 - 1,
        dest="token_hi",
        help=(
            "the high token in the inclusive partition key token range to "
            "split among worker processes (default = 2 ** 63 - 1)"
        )
    )
    parser.add_argument(
        "--token-lo",
        type=int,
        default=-2 ** 63,
        dest="token_lo",
        help=(
            "the low token in the inclusive partition key token range to "
            "split among worker processes (default = -2 ** 63)"
        )
    )
    parser.add_argument(
        "-c",
        "--concurrency",
        type=int,
        default=50,
        dest="concurrency",
        help=(
            "the number of updates to launch concurrently in each process "
            "using callback chaining (default = 50)"
        )
    )
    parser.add_argument(
        "-w",
        "--worker-count",
        type=int,
        default=2,
        dest="worker_count",
        help=(
            "the number of asynchronous worker processes to spawn - each "
            "will handle an equal range of partition key tokens (default = 2)"
        )
    )
    parser.add_argument(
        "-f",
        "--fetch-size",
        type=int,
        default=1000,
        dest="fetch_size",
        help="the number of rows to fetch in each page (default = 1000)"
    )
    parser.add_argument(
        "-t",
        "--throttle-rate",
        type=int,
        default=1000,
        dest="throttle_rate",
        help=(
            "the aggregate rows per second to target across all workers, "
            "for unlimited use 0 (default = 1000)"
        )
    )
    parser.add_argument(
        "-j",
        "--json-output",
        dest="json_output",
        action="store_true",
        help=(
            "suppress formatted printing; output args and results as json "
            "(default = False)"
        )
    )
    parser.add_argument(
        "-l",
        "--logging-level",
        dest="logging_level",
        default='INFO',
        choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'],
        help="logging level (default = 'INFO')"
    )

    parser.set_defaults(source_dc_aware=False)
    parser.set_defaults(source_token_aware=False)
    parser.set_defaults(source_retry=False)
    parser.set_defaults(dest_dc_aware=False)
    parser.set_defaults(dest_token_aware=False)
    parser.set_defaults(dest_retry=False)
    parser.set_defaults(json_output=False)

    args = parser.parse_args(args=sys.argv[1:])

    sys.exit(multiprocess(args))
Revision 2 fixes bugs, reworks command-line args somewhat, and adds throttling.