Skip to content

Instantly share code, notes, and snippets.

@dsc
Created July 19, 2010 20:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dsc/481966 to your computer and use it in GitHub Desktop.
Save dsc/481966 to your computer and use it in GitHub Desktop.
Modified stress.py for testing Cassandra
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# expects a Cassandra server to be running and listening on port 9160.
# (read tests expect insert tests to have run first too.)
# Workers are OS processes when multiprocessing is available and threads
# otherwise (the entry point warns that threaded results may be inaccurate).
# Either way the rest of the file relies on the same names: Thread, array,
# get_ident and have_multiproc.
try:
    from multiprocessing import Array as array, Process as Thread
    from uuid import uuid1 as get_ident
except ImportError:
    have_multiproc = False
    from threading import Thread
    from thread import get_ident
    from array import array
else:
    # Process only spells it is_alive(); Stress.run_test polls isAlive().
    Thread.isAlive = Thread.is_alive
    have_multiproc = True
from hashlib import md5
import time, random, sys, os
from random import randint, gauss
from optparse import OptionParser
from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.transport import THttpClient
from thrift.protocol import TBinaryProtocol
# Locate the generated Thrift bindings: first as an installed `cassandra`
# package, then inside a source checkout. If neither works, exit with a hint.
try:
    from cassandra import Cassandra
    from cassandra.ttypes import *
except ImportError:
    # add cassandra directory to sys.path: drop this file's name and its two
    # enclosing directories, then look under interface/thrift/gen-py.
    L = os.path.abspath(__file__).split(os.path.sep)[:-3]
    root = os.path.sep.join(L)
    _ipath = os.path.join(root, 'interface', 'thrift', 'gen-py')
    sys.path.append(os.path.join(_ipath, 'cassandra'))
    # This must be a nested try: an ImportError raised inside an except
    # handler is not caught by a sibling `except ImportError` clause, so the
    # message below was previously unreachable.
    try:
        import Cassandra
        from ttypes import *
    except ImportError:
        print("Cassandra thrift bindings not found, please run 'ant gen-thrift-py'")
        sys.exit(2)
# Probe for thrift's compiled `fastbinary` extension and warn if it is
# missing. The name itself is not used elsewhere in this file; presumably
# TBinaryProtocolAccelerated (see get_client) relies on it internally.
try:
    from thrift.protocol import fastbinary
except ImportError:
    print("WARNING: thrift binary extension not found, benchmark will not be accurate!")
# Command-line interface. Parsed values are exposed as the module-level
# `options` object (and leftover positional arguments as `args`).
parser = OptionParser()
parser.add_option('-n', '--num-keys', dest="numkeys", type="int",
                  default=1000**2, help="Number of keys")
parser.add_option("-a", "--start-key", dest="startkey", type="int",
                  default=0, help="Key at which to start")
parser.add_option('-t', '--threads', dest="threads", type="int",
                  default=50, help="Number of threads/procs to use")
parser.add_option('-c', '--columns', dest="columns", type="int",
                  default=5, help="Number of columns per key")
parser.add_option('-d', '--nodes', dest="nodes", type="string",
                  default="localhost", help="Host nodes (comma separated)")
parser.add_option('-s', '--stdev', dest="stdev", type="float",
                  default=0.1, help="standard deviation factor")
parser.add_option('-r', '--random', dest="random", action="store_true",
                  help="use random key generator (stdev will have no effect)")
parser.add_option('-f', '--file', dest="file", type="string",
                  help="write output to file")
parser.add_option('-p', '--port', dest="port", type="int",
                  default=9160, help="thrift port")
parser.add_option('-m', '--framed', dest="framed", action="store_true",
                  help="use framed transport")
parser.add_option('-o', '--operation', dest="operation", type="choice",
                  choices=('insert', 'read', 'rangeslice'), default="insert",
                  help="operation to perform")
parser.add_option('-u', '--supercolumns', dest="supers", type="int",
                  default=1, help="number of super columns per key")
parser.add_option('-y', '--family-type', dest="cftype", type="choice",
                  choices=('regular','super'), default='regular',
                  help="column family type")
parser.add_option('-k', '--keep-going', dest="ignore", action="store_true",
                  help="ignore errors inserting or reading")
parser.add_option('-i', '--progress-interval', dest="interval", type="int",
                  default=10, help="progress report interval (seconds)")
parser.add_option('-g', '--get-range-slice-count', dest="rangecount",
                  type="int", default=1000,
                  help="amount of keys to get_range_slice per call")
options, args = parser.parse_args()
# Settings derived from the command line, read as globals by the workers.
total_keys = options.numkeys
start_key = options.startkey
# Keys written by Inserter span [start_key, stop_key).
stop_key = options.startkey + options.numkeys
n_threads = options.threads
if n_threads < 1:
    # Otherwise the division below fails or yields negative ranges.
    parser.error("--threads must be at least 1")
# Floor division: the same value Python 2's int '/' gave, and still an int
# on Python 3. Rounded down when total_keys is not a multiple of n_threads.
keys_per_thread = total_keys // n_threads
columns_per_key = options.columns
supers_per_key = options.supers
# Hosts to spread connections over: each worker (see Operation) connects to
# one of these chosen at random. Whitespace around names is stripped and
# empty entries are ignored, so "-d a, b," yields ['a', 'b'].
nodes = [node.strip() for node in options.nodes.split(',') if node.strip()]
if not nodes:
    parser.error("--nodes must name at least one host")
# Parameters for key_generator_gauss: a bell curve centred on the middle of
# the inserted key range (start_key + total_keys/2). Remember that about 68%
# of draws fall within one stdev of the mean and about 95% within 2*stdev.
stdev = total_keys * options.stdev
mean = start_key + (total_keys // 2)
def key_generator_gauss():
    """Return one key drawn from a normal distribution over the key range.

    Samples gauss(mean, stdev) until the value falls inside the inserted
    range [start_key, stop_key), then zero-pads it to the same width the
    Inserter uses, so the returned key is one that was written.

    Raises:
        ValueError: if the key range is empty (no draw could ever be
            accepted, so the sampling loop would never end).
    """
    if stop_key <= start_key:
        raise ValueError("empty key range: nothing to generate")
    fmt = '%0' + str(len(str(total_keys))) + 'd'
    while True:
        guess = gauss(mean, stdev)
        # Bounds must follow start_key: the mean is offset by start_key, and
        # only keys in [start_key, stop_key) were inserted.
        if start_key <= guess < stop_key:
            return fmt % int(guess)
# Uniform key choice: every inserted key is equally likely. This is the
# worst case for caching.
def key_generator_random():
    """Return one uniformly chosen key from [start_key, stop_key).

    The key is zero-padded to the same width the Inserter uses. randint
    raises ValueError when the range is empty.
    """
    fmt = '%0' + str(len(str(total_keys))) + 'd'
    # Offset by start_key so reads target the keys that were inserted.
    return fmt % randint(start_key, stop_key - 1)
# Key chooser used by Reader: uniform with -r/--random, bell curve otherwise.
key_generator = key_generator_random if options.random else key_generator_gauss
def get_client(host='127.0.0.1', port=9160, framed=False):
socket = TSocket.TSocket(host, port)
if framed:
transport = TTransport.TFramedTransport(socket)
else:
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
client = Cassandra.Client(protocol)
client.transport = transport
return client
class Operation(Thread):
    """Base class for one benchmark worker (a Thread or a Process).

    Subclasses implement run(). Progress is reported through shared arrays
    indexed by the worker number.
    """
    def __init__(self, i, counts, latencies, errors):
        """Set up worker `i` and open its connection.

        Args:
            i: worker index in [0, n_threads); also this worker's slot in
                the shared arrays.
            counts, latencies, errors: shared per-worker accumulators.
        """
        Thread.__init__(self)
        # Split [start_key, stop_key) into n_threads contiguous slices whose
        # sizes differ by at most one. Every key belongs to exactly one
        # worker, even when total_keys is not a multiple of n_threads.
        # Previously the remainder was never inserted, although readers
        # still ask for it. When the division is exact the slices are
        # unchanged.
        lo = start_key + total_keys * i // n_threads
        hi = start_key + total_keys * (i + 1) // n_threads
        self.range = xrange(lo, hi)
        # we can't use a local counter, since that won't be visible to the parent
        # under multiprocessing. instead, the parent passes a "counts" array
        # and an index that is our assigned counter.
        self.idx = i
        self.counts = counts
        self.errors = errors
        # similarly, a shared array for latency totals
        self.latencies = latencies
        # random host for pseudo-load-balancing
        self.hostname = random.choice(nodes)
        # open client
        self.cclient = get_client(self.hostname, options.port, options.framed)
        self.cclient.transport.open()
class Inserter(Operation):
    """Write every key in this worker's range with batch_insert."""
    def run(self):
        """Insert each key of self.range, retrying a key until it succeeds.

        Without --keep-going the first failure propagates. With it, each
        failed attempt is printed and added to the error count, and the same
        key is retried. Only the successful attempt is timed and counted.
        """
        data = md5(str(get_ident())).hexdigest()
        columns = [Column(chr(ord('A') + j), data, 0) for j in xrange(columns_per_key)]
        fmt = '%0' + str(len(str(total_keys))) + 'd'
        # The mutation map does not depend on the key, so build it once
        # rather than once per key.
        if 'super' == options.cftype:
            supers = [SuperColumn(chr(ord('A') + j), columns) for j in xrange(supers_per_key)]
            cfmap = {'Super1': [ColumnOrSuperColumn(super_column=s) for s in supers]}
        else:
            cfmap = {'Standard1': [ColumnOrSuperColumn(column=c) for c in columns]}
        for i in self.range:
            key = fmt % i
            while True:
                start = time.time()
                try:
                    self.cclient.batch_insert('Keyspace1', key, cfmap, ConsistencyLevel.ONE)
                    break
                except KeyboardInterrupt:
                    sys.exit(1)
                except Exception:
                    if not options.ignore:
                        raise
                    # sys.exc_info() instead of `except ... as e` keeps this
                    # parseable on interpreters without that syntax.
                    e = sys.exc_info()[1]
                    self.errors[self.idx] += 1.0
                    print('%s %s %s' % (self.hostname, self.idx, e))
            self.latencies[self.idx] += time.time() - start
            self.counts[self.idx] += 1
class Reader(Operation):
    """Benchmark get_slice reads of keys chosen by key_generator.

    Each worker performs keys_per_thread lookups. Keys come from
    key_generator, not from self.range. For --family-type super, every
    lookup reads each of the supers_per_key super columns ('A', 'B', ...)
    of Super1 as a separate, separately timed call. Otherwise it reads
    Standard1 once.
    """
    def run(self):
        p = SlicePredicate(slice_range=SliceRange('', '', False, columns_per_key))
        if 'super' == options.cftype:
            parents = [ColumnParent('Super1', chr(ord('A') + j))
                       for j in xrange(supers_per_key)]
        else:
            parents = [ColumnParent('Standard1')]
        for _ in xrange(keys_per_thread):
            key = key_generator()
            for parent in parents:
                self._read(key, parent, p)

    def _read(self, key, parent, predicate):
        """Issue one get_slice call and record its outcome.

        An empty result counts as an error under --keep-going and raises
        RuntimeError otherwise. Any other exception is printed and counted
        under --keep-going and re-raised otherwise. Under --keep-going the
        call is timed and counted even when it failed.
        """
        start = time.time()
        try:
            r = self.cclient.get_slice('Keyspace1', key, parent, predicate, ConsistencyLevel.ONE)
            if not r:
                if options.ignore:
                    self.errors[self.idx] += 1.0
                else:
                    raise RuntimeError("Key %s not found" % key)
        except KeyboardInterrupt:
            sys.exit(1)
        except Exception:
            if not options.ignore:
                raise
            # sys.exc_info() instead of `except ... as e` keeps this
            # parseable on interpreters without that syntax.
            e = sys.exc_info()[1]
            self.errors[self.idx] += 1.0
            # Same "host worker error" line for every column family type.
            print('%s %s %s' % (self.hostname, self.idx, e))
        self.latencies[self.idx] += time.time() - start
        self.counts[self.idx] += 1
class RangeSlicer(Operation):
    """Read this worker's key range back with get_range_slice.

    self.range is walked in consecutive, non-overlapping windows of
    options.rangecount keys. Each window is requested once per column
    parent: every super column ('A', 'B', ...) of Super1 for
    --family-type super, otherwise Standard1. Each request is timed and
    counted individually.
    NOTE(review): assumes the server returns the keys of a [start, finish]
    request in key order (presumably an order-preserving partitioner) --
    confirm against the cluster configuration.
    """
    def run(self):
        if not self.range:
            # This worker was assigned no keys (more workers than keys);
            # self.range[0] would raise IndexError.
            return
        current = self.range[0]
        end = self.range[-1]  # last key of this worker, inclusive
        count = max(options.rangecount, 1)
        fmt = '%0' + str(len(str(total_keys))) + 'd'
        p = SlicePredicate(slice_range=SliceRange('', '', False, columns_per_key))
        if 'super' == options.cftype:
            parents = [ColumnParent('Super1', chr(ord('A') + j))
                       for j in xrange(supers_per_key)]
        else:
            parents = [ColumnParent('Standard1')]
        # A window holds at most `count` keys and is requested with the same
        # row limit, so one call covers the whole window. The next window
        # then starts right after it. Advancing by the window, not by the
        # result length, means no key is skipped or re-read, and the last
        # key (current == end) is included.
        while current <= end:
            window_end = min(current + count - 1, end)
            start = fmt % current
            finish = fmt % window_end
            for parent in parents:
                self._slice(parent, p, start, finish, count)
            current = window_end + 1

    def _slice(self, parent, predicate, start, finish, count):
        """Issue one get_range_slice call for keys [start, finish].

        An empty result raises RuntimeError. Under --keep-going, exceptions
        are printed and counted, and the call is still timed and counted.
        Otherwise they propagate.
        """
        begin = time.time()
        try:
            r = self.cclient.get_range_slice('Keyspace1', parent, predicate, start, finish, count, ConsistencyLevel.ONE)
            if not r:
                raise RuntimeError("Range not found:", start, finish)
        except KeyboardInterrupt:
            sys.exit(1)
        except Exception:
            if not options.ignore:
                raise
            # sys.exc_info() instead of `except ... as e` keeps this
            # parseable on interpreters without that syntax.
            e = sys.exc_info()[1]
            self.errors[self.idx] += 1.0
            print('%s %s %s' % (self.hostname, self.idx, e))
        self.latencies[self.idx] += time.time() - begin
        self.counts[self.idx] += 1
class OperationFactory:
    """Map an -o/--operation name to the worker class that performs it."""
    @staticmethod
    def create(type, i, counts, latencies, errors):
        """Instantiate worker `i` for operation `type`.

        Raises RuntimeError('Unsupported op!') for an unknown name.
        """
        # Names are compared in the same order as before: read, insert,
        # rangeslice.
        for name, worker_class in (('read', Reader),
                                   ('insert', Inserter),
                                   ('rangeslice', RangeSlicer)):
            if type == name:
                return worker_class(i, counts, latencies, errors)
        raise RuntimeError('Unsupported op!')
class Stress(object):
    """Start the workers for one operation and report their progress."""
    def __init__(self):
        # One slot per worker; each worker writes only its own slot.
        self.counts = array('i', [0] * n_threads)
        self.latencies = array('d', [0] * n_threads)
        # Double precision: a single-precision total stops counting exactly
        # beyond 2**24 errors.
        self.errors = array('d', [0.0] * n_threads)

    def create_threads(self, type):
        """Create and start n_threads workers for operation `type`."""
        threads = []
        for i in xrange(n_threads):
            th = OperationFactory.create(type, i, self.counts, self.latencies, self.errors)
            threads.append(th)
            th.start()
        return threads

    def run_test(self, filename, threads):
        """Poll the workers every options.interval seconds until all finish.

        Writes a tab-separated progress line per interval to `filename`, or
        to stdout when it is empty. The columns are elapsed seconds,
        operations completed so far, operations per second in the interval,
        err% and average latency in the interval. err% is
        100 * errors / completed operations in that interval. err% and
        avg_latency are written as NAN when no operation completed.
        Ctrl-C prints "Interrupted!" and exits with status 1.
        """
        start_t = time.time()
        if filename:
            outf = open(filename, 'w')
        else:
            outf = sys.stdout
        try:
            outf.write('secs\ttotal\top_rate\terr%\tavg_latency\n')
            outf.flush()
            total = latency = errors = 0
            while True:
                time.sleep(options.interval)
                old_total, old_latency, old_errors = total, latency, errors
                total = sum(self.counts[th.idx] for th in threads)
                latency = sum(self.latencies[th.idx] for th in threads)
                errors = sum(self.errors[th.idx] for th in threads)
                delta = total - old_total
                delta_latency = latency - old_latency
                delta_errors = errors - old_errors
                # Format the ratio columns here: passing the 'NAN'
                # placeholder to a %f conversion raised TypeError whenever an
                # interval completed no operations.
                if delta > 0:
                    err_field = '%.4f' % (100.0 * delta_errors / delta)
                    latency_field = '%.6f' % (delta_latency / delta)
                else:
                    err_field = latency_field = 'NAN'
                elapsed_t = int(time.time() - start_t)
                outf.write('%d\t%d\t%d\t%s\t%s\n'
                           % (elapsed_t, total, delta // options.interval, err_field, latency_field))
                # Flush so progress is visible while the run is in flight.
                outf.flush()
                if not any(th.isAlive() for th in threads):
                    break
        except KeyboardInterrupt:
            print("Interrupted!")
            sys.exit(1)
        finally:
            if outf is not sys.stdout:
                outf.close()

    def insert(self):
        """Run the insert benchmark."""
        self.run_test(options.file, self.create_threads('insert'))

    def read(self):
        """Run the get_slice read benchmark."""
        self.run_test(options.file, self.create_threads('read'))

    def rangeslice(self):
        """Run the get_range_slice benchmark."""
        self.run_test(options.file, self.create_threads('rangeslice'))
# Entry point. -o/--operation is restricted to 'insert', 'read' and
# 'rangeslice', each of which is a Stress method, so dispatch by name.
stresser = Stress()
benchmark = getattr(stresser, options.operation, None)
if have_multiproc is False:
    print("""WARNING: multiprocessing not present, threading will be used.
Benchmark may not be accurate!""")
benchmark()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment