-
-
Save amidvidy/92e4b1fd8485ab9008f7 to your computer and use it in GitHub Desktop.
load simulator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
(C) Copyright 2012, 10gen | |
""" | |
from __future__ import print_function | |
import getopt | |
import logging | |
import logging.handlers | |
import threading | |
import time | |
import sys | |
import math | |
from os import path | |
from pymongo import Connection | |
from pymongo import ReadPreference | |
from pymongo import ReplicaSetConnection | |
from pymongo.errors import CollectionInvalid | |
from pymongo.errors import ConnectionFailure | |
from pymongo.errors import OperationFailure | |
# Set up logging: one module-level logger shared by every worker thread.
# Handlers and the log level are attached later in configureLogger().
logger = logging.getLogger('load_simulator')
class WorkerThread(threading.Thread):
    """Base class for all worker threads.

    'replicaSet', 'host', and 'port' are connection settings for the
    mongod being acted on.  'dbName' and 'colName' are the database and
    collection to write to, respectively.  'junkAmt' bounds the cycling
    document index produced by indexGenerator.  'waveFunction' maps
    elapsed seconds to a sleep multiplier (see waveGenerator), and
    'sleepTime' is the base sleep amount used in the main loop --
    increase it if you want the thread to clobber the db less often.
    """

    def __init__(self, replicaSet, host, port, dbName, colName,
                 junkAmt, waveFunction, sleepTime, *args, **kwargs):
        # set up DB stuff
        self.dbName = dbName
        self.colName = colName
        self.host = host
        self.port = port
        self.replicaSet = replicaSet
        self.connection = None
        self.collection = None
        self.connectToDB()
        # set up oscillation stuff
        self.junkAmt = junkAmt
        self.startTime = time.time()
        self.indexGen = indexGenerator(self.junkAmt)
        self.oscillator = waveGenerator(waveFunction, self.startTime)
        self.sleepTime = sleepTime
        # class name, used in log/error messages
        self.threadType = self.__class__.__name__
        # pass other args to Thread constructor
        super(WorkerThread, self).__init__(*args, **kwargs)

    def connectToDB(self):
        """Connects this WorkerThread to the mongod it will be acting on.

        A pymongo.Connection object will be used for connecting to
        standalone instances or to a mongos, and a ReplicaSetConnection
        will be used for replica sets.  Connection params are those
        passed when constructing the WorkerThread instance.
        """
        if self.replicaSet:
            self.connection = ReplicaSetConnection(
                max_pool_size=1,
                replicaSet=self.replicaSet,
                host=self.host,
                port=self.port)
        else:
            self.connection = Connection(
                max_pool_size=1,
                host=self.host,
                port=self.port)
        # Allow reads from secondaries to spread load over a replica
        # set; harmless for a standalone/mongos connection.
        self.connection.read_preference = ReadPreference.SECONDARY
        self.collection = self.connection[self.dbName][self.colName]

    def process(self):
        """Perform one unit of work; subclasses must override."""
        # note the trailing space: adjacent literals are concatenated
        raise NotImplementedError("WorkerThread is abstract "
                                  "and should not be initialized")

    def run(self):
        """Main loop for the WorkerThread object.

        The thread continually sleeps for self.sleepTime scaled by the
        oscillator, then calls process().  If any exception is thrown,
        the WorkerThread reconnects to the database and continues
        looping.  The connection is also deliberately closed every
        hour; the resulting failure in process() is expected, so
        'hourlyReset' suppresses the error log for that one iteration.
        """
        hourlyReset = False
        while True:
            try:
                time.sleep(self.sleepTime * next(self.oscillator))
                self.process()
            except NotImplementedError:
                # should crash if a base WorkerThread is launched
                raise
            except Exception:
                if not hourlyReset:
                    logger.error("{0} encountered an unexpected error."
                                 .format(self.threadType),
                                 exc_info=True)
                try:
                    # close connection and reset
                    if self.connection is not None:
                        self.connection.close()
                    self.connectToDB()
                    hourlyReset = False
                except Exception:
                    # best effort: if the reconnect itself fails we just
                    # retry on the next loop iteration, but leave a trace
                    logger.debug("{0} failed to reconnect; will retry."
                                 .format(self.threadType), exc_info=True)
            finally:
                # reset connection every hour: close and null the handles
                # so the next process() call fails and the except branch
                # above reconnects quietly
                timeElapsed = time.time() - self.startTime
                if timeElapsed > 1.0 and timeElapsed % 3600 < 1.0:
                    # step past the 1-second boundary window so the
                    # reset fires only once per hour
                    time.sleep(1.0)
                    logger.debug(
                        "{0} has completed its hourly connection reset."
                        .format(self.threadType))
                    hourlyReset = True
                    self.connection.close()
                    self.connection = None
                    self.collection = None
class InsertionThread(WorkerThread):
    """Does a bunch of insertions.  Only this thread uses the capped
    collection."""

    def process(self):
        # cycling index keeps the set of distinct values bounded by junkAmt
        insertIndex = next(self.indexGen)
        self.collection.insert({"index": insertIndex, "thread": "insert"})
class DeletionThread(WorkerThread):
    """Does a bunch of deletes.  This thread uses the uncapped collection.
    Note that no deletes are actually performed on the server (nothing
    matches thread="delete"), but opcounters are incremented nonetheless."""

    def process(self):
        deleteIndex = next(self.indexGen)
        self.collection.remove({"index": deleteIndex, "thread": "delete"})
class UpdateThread(WorkerThread):
    """Does a bunch of updates.  This thread uses the uncapped collection."""

    def process(self):
        updateIndex = next(self.indexGen)
        # upsert so the single thread="update" document always exists
        self.collection.update(
            {"thread": "update"},
            {"$set": {"index": updateIndex}},
            upsert=True)
class RetrieveThread(WorkerThread):
    """Does a bunch of find queries.  This thread uses the uncapped
    collection."""

    def process(self):
        # alternate between low and high index values; this should hit
        # multiple shards if using a sharded collection
        retrieveIndex = next(self.indexGen)
        if retrieveIndex % 2 == 0:
            self.collection.find_one(
                {"index": retrieveIndex,
                 "thread": "retrieve"})
        else:
            self.collection.find_one(
                {"index": self.junkAmt - retrieveIndex,
                 "thread": "retrieve"})
def indexGenerator(limit):
    """Infinite generator cycling through 0..limit-1.

    If limit <= 0, yields 0 forever (matching the reset-on-bound
    behavior rather than raising).
    """
    current = 0
    while True:
        yield current
        # wrap back to 0 instead of using %, which would fail for limit 0
        current = current + 1 if current + 1 < limit else 0
def waveGenerator(oscillatorFunction, startTime):
    """Infinite generator yielding oscillatorFunction evaluated at the
    number of seconds elapsed since startTime."""
    while True:
        yield oscillatorFunction(time.time() - startTime)
def sawToothWave(amplitude, period):
    """Make a sawtooth wave function: rises linearly from 0 to amplitude
    over each period, then drops back to 0."""
    def valueAt(t):
        # fraction of the way through the current cycle, in [0, 1)
        phase = (t % period) / period
        return phase * amplitude
    return valueAt
def sineWave(amplitude, period):
    """Make a sine wave function with the given amplitude and period,
    shifted up by amplitude so its range is [0, 2 * amplitude]."""
    # angular frequency is constant per wave; compute it once
    angularFreq = (2 * math.pi) / period
    def valueAt(t):
        return amplitude * math.sin(angularFreq * t) + amplitude
    return valueAt
def launchThreads(replicaSet, host, port, dbName,
                  capColName, testColName, junkAmt):
    """Launches the worker threads to read/write on the given collections."""
    logger.debug("Spawning threads...")
    # (threadClass, collection, waveFunction, baseSleep); only the
    # InsertionThread uses the capped collection, all others uncapped
    configs = [
        (InsertionThread, capColName, sineWave(0.5, 3600), 0.1),
        (DeletionThread, testColName, sawToothWave(60, 3600), 0.002),
        (UpdateThread, testColName, sineWave(0.5, 3600), 0.05),
        (RetrieveThread, testColName, sawToothWave(100, 3600), 0.002),
    ]
    workers = [cls(replicaSet, host, port, dbName, colName, junkAmt,
                   waveFn, baseSleep)
               for cls, colName, waveFn, baseSleep in configs]
    started = 0
    for worker in workers:
        try:
            # Daemon thread: gets killed automatically if main thread dies
            worker.daemon = True
            worker.start()
            logger.debug("Started {0}".format(worker.threadType))
            started += 1
        except RuntimeError:
            logger.error("Could not start {0}".format(worker.threadType))
    logger.debug("{0}/{1} worker threads initialized successfully"
                 .format(started, len(workers)))
def setupCollection(db, collectionName, isCapped, collectionSizeMB):
    """Create (or reuse) one collection needed for the simulation.

    db -- pymongo database object
    collectionName -- name of the collection to create
    isCapped -- whether the collection should be capped
    collectionSizeMB -- capped collection size in megabytes

    Exits the process if collection creation fails with OperationFailure.
    """
    collection = None
    try:
        collection = db.create_collection(
            collectionName,
            capped=isCapped,
            # size is in bytes: MB -> bytes.  (The previous 1024 * MB
            # computed kilobytes, making a "10MB" cap actually 10KB.)
            size=(1024 * 1024 * collectionSizeMB))
    except CollectionInvalid:
        # the collection exists already, clean it up then use it
        if not isCapped:
            cleanupCollection(db, collectionName)
        collection = db[collectionName]
        # NOTE(review): the index is only built when the collection
        # pre-existed; a freshly created collection never gets it --
        # confirm whether that is intentional
        collection.create_index("index")
    except OperationFailure:
        logger.critical("Collection creation failed; error below:",
                        exc_info=True)
        sys.exit(1)
    return collection
def cleanupCollection(db, collectionName):
    """Removes all documents from the named collection on a restart."""
    collection = db[collectionName]
    collection.remove({})
def configureLogger(l):
    """Configure logger to save all logs to stdout and the supplied file.

    l -- path to the log file; up to 100 rotated files of 100MB each
    are kept.  The level is set to INFO, so this module's
    logger.debug() calls are suppressed.
    """
    logFile = path.abspath(l)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    fileHandler = logging.handlers.RotatingFileHandler(
        logFile,
        maxBytes=(100 * 1024 ** 2),
        backupCount=100)
    consoleHandler = logging.StreamHandler()
    for handler in (fileHandler, consoleHandler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.info("Saving logs to {0}".format(logFile))
def processArguments():
    """Parse command-line options into a settings dict.

    Returns a dict with keys: capped, uncapped, logfile, database,
    hostname, port, replicaSet, capSize, junkDataSize.  Unparseable
    arguments print the usage message and exit.
    """
    settings = {}
    # Default options
    settings["capped"] = "capped"
    settings["uncapped"] = "uncapped"
    settings["logfile"] = "load_db.log"
    settings["database"] = "load_db"
    settings["hostname"] = "localhost"
    settings["port"] = 27017
    settings["replicaSet"] = None
    settings["capSize"] = 10          # capped collection size, MB
    settings["junkDataSize"] = 1000   # docs pre-inserted for RetrieveThread
    # Parse arguments
    try:
        opts, args = getopt.getopt(sys.argv[1:], "d:c:u:s:l:h:p:r:j:",
            ["help", "database=", "cappedCollection=", "testCollection=",
             "size=", "logFile=", "hostname=", "port=", "replicaSet=",
             "junkDataSize="])
    except getopt.error as msg:  # 'except x, y' was Python-2-only syntax
        print("\n", msg)
        usage()
        sys.exit(1)
    # Process arguments
    for option, arg in opts:
        # Must be a tuple: 'option in ("--help")' was a substring test,
        # and "-h" is a substring of "--help", so the hostname flag used
        # to trigger usage()/exit instead of being applied.
        if option in ("--help",):
            usage()
        elif option in ("-c", "--cappedCollection"):
            settings["capped"] = arg
        elif option in ("-u", "--testCollection"):
            settings["uncapped"] = arg
        elif option in ("-l", "--logFile"):
            settings["logfile"] = arg
        elif option in ("-d", "--database"):
            settings["database"] = arg
        elif option in ("-h", "--hostname"):
            settings["hostname"] = arg
        elif option in ("-p", "--port"):
            try:
                settings["port"] = int(arg)
            except ValueError:
                usage()
        elif option in ("-r", "--replicaSet"):
            settings["replicaSet"] = arg
        elif option in ("-s", "--size"):
            try:
                settings["capSize"] = int(arg)
            except ValueError:
                usage()
        elif option in ("-j", "--junkDataSize"):
            try:
                settings["junkDataSize"] = int(arg)
            except ValueError:
                usage()
    return settings
def usage():
    """Print script usage to stdout and exit with status 0.

    (Missing spaces caused by adjacent-string-literal concatenation in
    the original help text have been fixed.)
    """
    print("\nDESCRIPTION: \n")
    print(" python " + sys.argv[0] + " [--help] "
          "[-d database | --database database] "
          "[-c | --cappedCollection] [-u | --testCollection] [-s | --size] "
          "[-l | --logFile] [-p | --port] [-h | --hostname] "
          "[-r | --replicaSet] "
          "[-j | --junkDataSize]\n")
    print(" The following options are available:\n")
    print(" --help h is to print this message\n")
    print(" -d (str) is to indicate which database the script will use. "
          "By default, the script will use 'load_db'\n")
    print(" -c (str) is to indicate which capped collection to use. "
          "By default, the script will use 'capped'\n")
    print(" -u (str) is to indicate which uncapped collection to use. "
          "By default, the script will use 'uncapped'\n")
    print(" -l (str) is to indicate which log file to use. "
          "By default, the script will use load_db.log\n")
    print(" -s (int) is to indicate how large the capped collection "
          "should be (MB). By default, the script will use 10\n")
    print(" -p (int) is to indicate what port to connect to. "
          "By default, the script will use 27017\n")
    print(" -h (str) is to indicate which host (mongod) to connect to. "
          "By default, the script will use localhost\n")
    print(" -r (str) is to indicate the replica set to connect to. "
          "By default, the script does not connect to any "
          "replica sets.\n")
    print(" -j (int) is to indicate the amount of junk data to insert "
          "for RetrieveThread. By default, this parameter is 1000.\n")
    print("EXAMPLES:")
    print(" python loadSimulator.py -d load_db --capped=cappedCollection "
          "-s 10 --logFile=loadSimulator.log\n")
    print(" Will cause the script to use 10MB capped collection "
          "'capped' in the standalone 'load_db' database and "
          "save logs to loadSimulator.log\n")
    print(" python loadSimulator.py\n")
    print(" Will cause the script to use default values\n")
    print("Note that long options should be followed by an equal sign ('=')\n")
    print("Written by Adinoyi Omuya and Adam Midvidy")
    print("(C) Copyright 2012, 10gen")
    sys.exit(0)
def main():
    """Main script entry point: parse args, set up logging and the
    collections, launch worker threads, then idle until interrupted."""
    # get settings
    settings = processArguments()
    configureLogger(settings["logfile"])
    try:
        if settings["replicaSet"]:
            # NOTE(review): this passes replicaSet= to Connection while
            # WorkerThread.connectToDB uses ReplicaSetConnection for the
            # same case -- confirm the main-thread connection behaves as
            # intended against a replica set
            connection = Connection(
                max_pool_size=1,
                host=settings["hostname"],
                replicaSet=settings["replicaSet"],
                port=settings["port"])
            logger.info("Using database {0} on replica set \"{1}\" on port "
                        "\"{2}\" with capped collection \"{3}\" ({4}MB) and "
                        "uncapped collection \"{5}\"."
                        .format(settings["database"], settings["replicaSet"],
                                settings["port"], settings["capped"],
                                settings["capSize"], settings["uncapped"]))
        else:
            connection = Connection(
                max_pool_size=1,
                host=settings["hostname"],
                port=settings["port"])
            logger.info("Using database \"{0}\" with capped collection "
                        "\"{1}\" ({2}MB) and uncapped collection \"{3}\"."
                        .format(settings["database"], settings["capped"],
                                settings["capSize"], settings["uncapped"]))
        # setup collections
        db = connection[settings["database"]]
        setupCollection(db, settings["capped"], True, settings["capSize"])
        uncapped = setupCollection(db, settings["uncapped"],
                                   False, settings["capSize"])
        # add junk data for retrieve thread
        logger.debug("Inserting junk data for retrieve thread...")
        # range instead of Python-2-only xrange (small, bounded count)
        for index in range(settings["junkDataSize"]):
            uncapped.insert({"index": index, "thread": "retrieve"})
        # main thread no longer needs a connection, let it get gc'd
        connection.close()
        connection = None
    except ConnectionFailure:
        logger.error("Could not connect to MongoDB server. Exiting...")
        sys.exit(1)
    # Launch the threads
    launchThreads(settings["replicaSet"], settings["hostname"],
                  settings["port"], settings["database"], settings["capped"],
                  settings["uncapped"], settings["junkDataSize"])
    try:
        # do nothing so that we can kill the workers by killing the main thread
        while True:
            time.sleep(1e7)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
        sys.exit(0)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment