bm_copy.py: pre-release fast copy app for Cassandra tables
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
bm_copy
Author: Michael P Laing, michael.laing@nytimes.com
Date: 2014-06-08
A preview version from the nyt⨍aбrik 'rabbit_helpers' framework.
Apache License Version 2.0, January 2004 http://www.apache.org/licenses/
"""
import os
import sys
import json
import uuid
import logging
import argparse
import traceback
import multiprocessing
from time import time, sleep
from datetime import datetime
from collections import deque
from threading import Event, Lock
from multiprocessing import Manager, Process
from cassandra.cluster import Cluster
from cassandra import ConsistencyLevel
from cassandra.policies import (
RetryPolicy,
RoundRobinPolicy,
TokenAwarePolicy,
DCAwareRoundRobinPolicy,
DowngradingConsistencyRetryPolicy
)
LOG_FORMAT = (
'%(levelname) -10s %(asctime)s %(name) -30s %(process)d '
'%(funcName) -35s %(lineno) -5d: %(message)s'
)
JSON_DUMPS_ARGS = {'ensure_ascii': False, 'indent': 4}
class CopyService(object):
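# Overview: connect to the source and dest clusters, page asynchronously
# through the source rows in this worker's token range, and write each row to
# the dest table with at most self._concurrency updates in flight; finish when
# no pages, queued rows, or in-flight updates remain.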
def __init__(self, args, token_range, logger):
self._logger = logger
self._logger.name = __name__
self._logger.debug("")
self._args = {
"source": {
k[7:]: v for k, v in args.__dict__.items()
if k[:7] == "source_"
},
"dest": {
k[5:]: v for k, v in args.__dict__.items()
if k[:5] == "dest_"
}
}
for source_or_dest in ['source', 'dest']:
self._args[source_or_dest]['consistency_level'] = (
ConsistencyLevel.name_to_value[
self._args[source_or_dest]['consistency_level']
]
)
self._dest_token_aware = args.dest_token_aware
self._fetch_size = args.fetch_size
self._concurrency = args.concurrency
self._throttle_rate = args.throttle_rate
self._worker_count = args.worker_count
self._token_range = token_range
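# Target page-fetch rate for this worker: the aggregate rows/sec target spread
# across all workers, divided by the rows delivered per page.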
self._page_rate = (
float(self._throttle_rate)
/ (self._fetch_size * self._worker_count)
)
self._cql_cluster = {}
self._cql_session = {}
self._future = None # source future
self._finished_event = Event()
self._lock = Lock()
self._concurrent_updates = 0
self._stopped = False
self._page = 0
self._row_count = 0
self._rows = deque()
self._start_time = 0
self._stop_time = 0
self._os_times_start = None
self._os_times_stop = None
self._query = {}
u"""
-- my source_data and source_data_copy tables look like this
--
-- use your own tables and modify:
-- . source stmt
-- . dest stmt
-- . map_fields method
-- . map_routing_key method (if using token_aware option)
--
CREATE TABLE source_data (
hash_key text,
message_id timeuuid,
body blob,
metadata text,
PRIMARY KEY (hash_key, message_id)
);
"""
self._stmt = {
'source': u""" -- these ?'s will be filled in automatically
SELECT * -- Note: token range is inclusive
FROM benchmark.source_data
WHERE TOKEN(hash_key) >= ? AND TOKEN(hash_key) <= ?
""",
'dest': u""" -- these ?'s must be mapped by 'map_fields' below
UPDATE benchmark.source_data_copy
SET metadata = ?, body = ?
WHERE hash_key = ? AND message_id = ?
"""
}
def map_fields(self, source_row):
self._logger.debug("")
return ( # return a tuple in the order of the dest ?'s above
source_row.metadata,
source_row.body,
source_row.hash_key,
source_row.message_id
)
def map_routing_key(self, source_row): # required for token_aware
self._logger.debug("")
return (source_row.hash_key,) # return a tuple of the partition key
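# Hypothetical example: for a dest table keyed on (user_id, day) with a
# 'payload' column, map_fields might return
# (source_row.payload, source_row.user_id, source_row.day) to match the ?'s
# of your dest statement, and map_routing_key would return
# (source_row.user_id,) as the partition key tuple.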
def fetch_now_or_later(self):
actual = time() - self._start_time
target = self._page / self._page_rate
self._logger.debug(
"actual: {0}; target: {1}; diff: {2}".format(
actual, target, target - actual
)
)
if target > actual: # fetching faster than target?
sleep(target - actual) # sleep until actual == target (overshoot)
self._future.start_fetching_next_page() # now fetch
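# Callback chain: each completed dest write re-enters update_or_finish, which
# refills the work queue from prefetched pages and keeps the number of
# in-flight updates at or below self._concurrency.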
def update_or_finish(self, _):
self._logger.debug(
"len(self._rows): {0}; "
"self._concurrent_updates: {1}; ".format(
len(self._rows),
self._concurrent_updates
)
)
try:
self._row_count += 1
self._concurrent_updates -= 1
# is work available or queued?
if self._future.has_more_pages or self._rows:
# is the current row within the latest page?
if (self._row_count - 1) / self._fetch_size + 1 > self._page:
# if so, maybe prefetch another page
self._logger.info(
'page: {0}; row_count: {1}'.format(
self._page, self._row_count
)
)
# is more work available?
if self._future.has_more_pages:
if self._throttle_rate: # is throttling in effect?
self.fetch_now_or_later() # maybe delay fetching
else: # fetch now while processing continues async
self._future.start_fetching_next_page()
self._page += 1
if self._rows: # is work queued?
self.update_dest_table() # process it
elif self._concurrent_updates: # is work in progress?
pass # wait for callbacks
else: # work is all done
self.finish()
except Exception as exc:
self.stop_and_raise(exc)
def update_dest_table(self):
self._logger.debug("")
with self._lock: # called from multiple threads, so use lock
while self._rows: # is work queued?
self._logger.debug(
"len(self._rows): {0}; "
"self._concurrent_updates: {1}; "
"self._concurrency: {2}".format(
len(self._rows),
self._concurrent_updates,
self._concurrency
)
)
# are we at the limit of concurrency?
if self._concurrent_updates >= self._concurrency:
break # enough work is in progress
else: # submit more work
self._concurrent_updates += 1
row = self._rows.pop()
subvars = self.map_fields(row)
if self._dest_token_aware:
self._query[
'dest'
].routing_key = self.map_routing_key(row)
future = self._cql_session[
'dest'
].execute_async(self._query['dest'], subvars)
future.add_callback(self.update_or_finish)
future.add_errback(self.stop_and_raise)
def stop_and_raise(self, exc):
self._logger.debug("")
error_msg = 'traceback: {}'.format(traceback.format_exc())
self._logger.error(error_msg)
self.stop()
raise exc
def more_rows_or_finish(self, new_rows):
self._logger.debug(
"len(new_rows): {0}; "
"len(self._rows): {1}; "
"self._concurrent_updates: {2}; ".format(
len(new_rows), len(self._rows), self._concurrent_updates
)
)
try:
if new_rows: # is there new work?
self._rows.extend(new_rows) # extend the work queue
self.update_dest_table() # process it
elif self._rows: # is work queued?
self.update_dest_table() # process it
elif self._concurrent_updates: # is work in progress?
pass # wait for callbacks
else: # work is all done
self.finish()
except Exception as exc:
self.stop_and_raise(exc)
def select_from_source_table(self):
self._logger.debug("")
self._start_time = time()
self._os_times_start = os.times()
try:
self._logger.info(
"self._token_range: {}".format(self._token_range)
)
self._future = self._cql_session[
'source'
].execute_async(self._query['source'], self._token_range)
self._future.add_callback(self.more_rows_or_finish)
self._future.add_errback(self.stop_and_raise)
except Exception as exc:
self.stop_and_raise(exc)
def stop(self):
self._logger.debug("")
if self._stopped:
return
self._stop_time = time()
self._os_times_stop = os.times()
self._logger.info("Stopping service.")
for source_or_dest in ['source', 'dest']:
if self._cql_cluster[source_or_dest]:
try:
self._cql_cluster[source_or_dest].shutdown()
except Exception as exc:
error_msg = traceback.format_exc()
self._logger.info(
"Exception on cql_cluster.shutdown(): {0}".format(
error_msg
)
)
raise
def connection(self, source_or_dest):
params = self._args[source_or_dest]
self._logger.debug(
"source_or_dest: {0}; params: {1}".format(
source_or_dest, params
)
)
try:
# defaults
load_balancing_policy = RoundRobinPolicy()
retry_policy = RetryPolicy()
if params['dc_aware']:
load_balancing_policy = DCAwareRoundRobinPolicy(
params['local_dc'],
used_hosts_per_remote_dc=params['remote_dc_hosts']
)
if params['token_aware']:
load_balancing_policy = TokenAwarePolicy(
load_balancing_policy
)
if params['retry']:
retry_policy = DowngradingConsistencyRetryPolicy()
self._cql_cluster[source_or_dest] = Cluster(
params['cql_host_list'],
load_balancing_policy=load_balancing_policy,
default_retry_policy=retry_policy
)
self._cql_session[source_or_dest] = self._cql_cluster[
source_or_dest
].connect()
if source_or_dest == 'source':
self._cql_session[
'source'
].default_fetch_size = self._fetch_size
self._query[source_or_dest] = self._cql_session[
source_or_dest
].prepare(self._stmt[source_or_dest])
self._query[source_or_dest].consistency_level = params[
'consistency_level'
]
except Exception as exc:
error_msg = 'Cassandra init error; traceback: {}'.format(
traceback.format_exc()
)
self._logger.error(error_msg)
self.stop()
raise
self._logger.info('Connected to Cassandra - {}'.format(source_or_dest))
def finish(self):
self._logger.info("Finished")
self._finished_event.set()
def run(self):
self._logger.info("Starting service")
self.connection('source')
self.connection('dest')
self.select_from_source_table()
self._finished_event.wait()
self._logger.info("Stopping service")
def main(args, worker_index, tokens, rate, results, logger):
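# Each worker process runs one CopyService over its own token sub-range and
# reports its row counts and timings through the shared 'results' dict.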
logger.info("Initializing...")
exitcode = 0
try:
service = CopyService(args, tokens, logger)
service.run()
except KeyboardInterrupt:
logger.info("Service terminated by SIGINT.")
except Exception as exc:
error_msg = traceback.format_exc()
logger.error("Runtime exception: {0}".format(error_msg))
exitcode = 1
service.stop()
logger.info("Terminated.")
if service._row_count:
elapsed = service._stop_time - service._start_time
rate = int(service._row_count // elapsed)
else:
elapsed = 0
rate = 0
os_times_diff = map(
lambda stop, start: stop - start,
service._os_times_stop,
service._os_times_start
)
results[worker_index] = { # add results to the shared special dict
'row_count': service._row_count,
'start_time': service._start_time,
'stop_time': service._stop_time,
'elapsed': elapsed,
'rate': rate,
'os_times_stop': service._os_times_stop,
'os_times_start': service._os_times_start,
'os_times_diff': os_times_diff
}
sys.exit(exitcode)
def print_results(results):
print("\nWorker Rows Elapsed Rows/sec")
for worker_index in sorted(results.keys()):
r = results[worker_index]
print(
"{0:6d}{1:10d}{2:9.3f}{3:11d}".format(
worker_index, r['row_count'], r['elapsed'], r['rate']
)
)
print("CPU: {:4.0f}%".format(results[99]['cpu'] * 100))
def print_arguments(args):
print("arguments:")
for k, v in sorted(args.__dict__.items()):
print(" {0}: {1}".format(k, v))
def print_json(args, results):
output = {
"uuid": str(uuid.uuid1()),
"timestamp": datetime.utcnow().isoformat() + 'Z',
"args": args.__dict__,
"results": results
}
print(json.dumps(output, **JSON_DUMPS_ARGS))
def analyze_results(results, os_times_stop, os_times_start):
worker_indices = [worker_index for worker_index in results.keys()]
row_count = sum([
results[worker_index]['row_count']
for worker_index in worker_indices
])
start_time = min([
results[worker_index]['start_time']
for worker_index in worker_indices
if results[worker_index]['start_time'] != 0
])
stop_time = max([
results[worker_index]['stop_time']
for worker_index in worker_indices
if results[worker_index]['start_time'] != 0
])
if row_count:
elapsed = stop_time - start_time
rate = int(row_count // elapsed)
else:
elapsed = 0
rate = 0
os_times_diff = map(
lambda stop, start: stop - start, os_times_stop, os_times_start
)
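# os.times() returns (user, system, children_user, children_system, elapsed);
# indices 2 and 3 are the CPU time consumed by the worker child processes and
# index 4 is wall-clock time, so 'cpu' is the workers' aggregate CPU utilization.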
cpu = (os_times_diff[2] + os_times_diff[3]) / os_times_diff[4]
results[99] = { # add summary results
'row_count': row_count,
'start_time': start_time,
'stop_time': stop_time,
'elapsed': elapsed,
'rate': rate,
'cpu': cpu,
'os_times_stop': os_times_stop,
'os_times_start': os_times_start,
'os_times_diff': os_times_diff
}
return {k: v for k, v in results.items()} # return a standard dict
def multiprocess(args):
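# Split the inclusive token range [token_lo, token_hi] into worker_count
# contiguous, non-overlapping, inclusive sub-ranges, one per worker process.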
step = ((args.token_hi - args.token_lo) / args.worker_count) + 1
tr1 = range(args.token_lo, args.token_hi, step) # intermediate points
tr2 = [(t, t + 1) for t in tr1[1:]] # pairs of adjacent points
tr3 = [t for st in tr2 for t in st] # flatten
tr4 = [args.token_lo] + tr3 + [args.token_hi] # add end points
token_ranges = [tr4[i:i + 2] for i in range(0, len(tr4), 2)] # make pairs
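# e.g. (hypothetically) token_lo=0, token_hi=9, worker_count=2 gives step=5,
# tr1=[0, 5], tr4=[0, 5, 6, 9], and token_ranges=[[0, 5], [6, 9]].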
rate = args.throttle_rate / args.worker_count
formatter = logging.Formatter(LOG_FORMAT)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = multiprocessing.get_logger()
logger.addHandler(handler)
level = logging.getLevelName(args.logging_level) # undocumented feature
logger.setLevel(level)
manager = Manager()
results = manager.dict() # create a special shared dict to gather results
workers = [
Process(
target=main,
args=(
args, worker_index, token_ranges[worker_index], rate, results,
logger
)
)
for worker_index in range(args.worker_count)
]
os_times_start = os.times()
for worker in workers:
worker.start()
for worker in workers:
worker.join()
os_times_stop = os.times()
exitcode = 0
for worker in workers:
if worker.exitcode:
exitcode = worker.exitcode
break
if results:
results_dict = analyze_results(results, os_times_stop, os_times_start)
if args.json_output:
print_json(args, results_dict)
else:
print_arguments(args)
print_results(results_dict)
return(exitcode)
if __name__ == "__main__":
description = """
Copy/transform rows from one Cassandra table to another in the same or
different clusters.
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--source-cql-host-list",
default=['localhost'],
dest="source_cql_host_list",
nargs='*',
metavar='CQL_HOST',
help=(
"source: the initial cql hosts to contact "
"(default = ['localhost'])"
)
)
parser.add_argument(
"--source-dc-aware",
dest="source_dc_aware",
action="store_true",
help="source: favor hosts in the local datacenter (default = False)"
)
parser.add_argument(
"--source-local-dc",
dest="source_local_dc",
default='datacenter1',
help=(
"source: if dc_aware, the local datacenter "
"(default = 'datacenter1')"
)
)
parser.add_argument(
"--source-remote-dc-hosts",
type=int,
default=0,
dest="source_remote_dc_hosts",
help=(
"source: if dc_aware, the number of hosts to connect to as "
"remote hosts (default = 0)"
)
)
parser.add_argument(
"--source-token-aware",
dest="source_token_aware",
action="store_true",
help=(
"source: route queries to known replicas by the tokens of their "
"partition keys (default = False)"
)
)
parser.add_argument(
"--source-consistency-level",
dest="source_consistency_level",
default='ONE',
choices=[
'ONE', 'TWO', 'THREE', 'QUORUM', 'ALL', 'LOCAL_QUORUM',
'EACH_QUORUM', 'SERIAL', 'LOCAL_SERIAL', 'LOCAL_ONE'
],
help="dest: consistency level (default = 'ONE')"
)
parser.add_argument(
"--source-retry",
dest="source_retry",
action="store_true",
help="source: downgrade consistency level and retry (default = False)"
)
parser.add_argument(
"--dest-cql-host-list",
default=['localhost'],
dest="dest_cql_host_list",
nargs='*',
metavar='CQL_HOST',
help=(
"dest: the initial cql hosts to contact "
"(default = ['localhost'])"
)
)
parser.add_argument(
"--dest-dc-aware",
dest="dest_dc_aware",
action="store_true",
help="dest: favor hosts in the local datacenter (default = False)"
)
parser.add_argument(
"--dest-local-dc",
default='datacenter1',
dest="dest_local_dc",
help=(
"dest: if dc_aware, the local datacenter "
"(default = 'datacenter1')"
)
)
parser.add_argument(
"--dest-remote-dc-hosts",
type=int,
default=0,
dest="dest_remote_dc_hosts",
help=(
"dest: if dc_aware, the number of hosts to be connected to as "
"remote hosts (default = 0)"
)
)
parser.add_argument(
"--dest-token-aware",
dest="dest_token_aware",
action="store_true",
help=(
"dest: route queries to known replicas by the tokens of their "
"partition keys (default = False)"
)
)
parser.add_argument(
"--dest-consistency-level",
dest="dest_consistency_level",
default='ONE',
choices=[
'ANY', 'ONE', 'TWO', 'THREE', 'QUORUM', 'ALL', 'LOCAL_QUORUM',
'EACH_QUORUM', 'SERIAL', 'LOCAL_SERIAL', 'LOCAL_ONE'
],
help="dest: consistency level (default = 'ONE')"
)
parser.add_argument(
"--dest-retry",
dest="dest_retry",
action="store_true",
help="dest: downgrade consistency level and retry (default = False)"
)
parser.add_argument(
"--token-hi",
type=int,
default=2 ** 63 - 1,
dest="token_hi",
help=(
"the high token in the inclusive partition key token range to "
"split among worker processes (default = 2 ** 63 - 1)"
)
)
parser.add_argument(
"--token-lo",
type=int,
default=-2 ** 63,
dest="token_lo",
help=(
"the low token in the inclusive partition key token range to "
"split among worker processes (default = -2 ** 63)"
)
)
parser.add_argument(
"-c",
"--concurrency",
type=int,
default=50,
dest="concurrency",
help=(
"the number of updates to launch concurrently in each process "
"using callback chaining (default = 50)"
)
)
parser.add_argument(
"-w",
"--worker-count",
type=int,
default=2,
dest="worker_count",
help=(
"the number of asynchronous worker processes to spawn - each "
"will handle an equal range of partition key tokens (default = 2)"
)
)
parser.add_argument(
"-f",
"--fetch-size",
type=int,
default=1000,
dest="fetch_size",
help="the number of rows to fetch in each page (default = 1000)"
)
parser.add_argument(
"-t",
"--throttle-rate",
type=int,
default=1000,
dest="throttle_rate",
help=(
"the aggregate rows per second to target across all workers, "
"for unlimited use 0 (default = 1000)"
)
)
parser.add_argument(
"-j",
"--json-output",
dest="json_output",
action="store_true",
help=(
"suppress formatted printing; output args and results as json "
"(default = False)"
)
)
parser.add_argument(
"-l",
"--logging-level",
dest="logging_level",
default='INFO',
choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'],
help="logging level (default = 'INFO')"
)
parser.set_defaults(source_dc_aware=False)
parser.set_defaults(source_token_aware=False)
parser.set_defaults(source_retry=False)
parser.set_defaults(dest_dc_aware=False)
parser.set_defaults(dest_token_aware=False)
parser.set_defaults(dest_retry=False)
parser.set_defaults(json_output=False)
args = parser.parse_args(args=sys.argv[1:])
sys.exit(multiprocess(args))
@michaelplaing (Author):
Revision 2 fixes bugs, reworks command-line args somewhat, and adds throttling.
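
For anyone trying it out, a hypothetical invocation (the host names are placeholders, and the source/dest statements and mapping methods in the script must already match your own tables):

    python bm_copy.py \
        --source-cql-host-list cass-src-1 cass-src-2 \
        --dest-cql-host-list cass-dst-1 cass-dst-2 \
        --dest-token-aware \
        --worker-count 4 \
        --fetch-size 1000 \
        --throttle-rate 5000 \
        --logging-level INFO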
