Skip to content

Instantly share code, notes, and snippets.

@amidvidy
Created June 26, 2012 17:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amidvidy/92e4b1fd8485ab9008f7 to your computer and use it in GitHub Desktop.
Save amidvidy/92e4b1fd8485ab9008f7 to your computer and use it in GitHub Desktop.
load simulator
"""
(C) Copyright 2012, 10gen
"""
from __future__ import print_function
import getopt
import logging
import logging.handlers
import threading
import time
import sys
import math
from os import path
from pymongo import Connection
from pymongo import ReadPreference
from pymongo import ReplicaSetConnection
from pymongo.errors import CollectionInvalid
from pymongo.errors import ConnectionFailure
from pymongo.errors import OperationFailure
# Set up logging
# Module-level logger shared by every worker thread; handlers and level are
# attached later by configureLogger().
logger = logging.getLogger('load_simulator')
class WorkerThread(threading.Thread):
    """Base class for all worker threads.

    'replicaSet', 'host', and 'port' are connection settings for the mongod
    ('replicaSet' may be None for a standalone/mongos connection).  'dbName'
    and 'colName' are the database name and collection name to operate on.
    'waveFunction' maps elapsed seconds to a multiplier used to oscillate the
    load, and 'sleepTime' is the base sleep amount used in the mainloop --
    decrease this parameter if you want the thread to clobber the db less
    often.  'junkAmt' bounds the cyclic document-index generator.  Remaining
    args/kwargs are forwarded to threading.Thread."""

    def __init__(self, replicaSet, host, port, dbName, colName,
                 junkAmt, waveFunction, sleepTime, *args, **kwargs):
        # set up DB stuff
        self.dbName = dbName
        self.colName = colName
        self.host = host
        self.port = port
        self.replicaSet = replicaSet
        self.connection = None
        self.collection = None
        # NOTE(review): connects during construction, so __init__ raises if
        # the mongod is unreachable.
        self.connectToDB()
        # set up oscillation stuff
        self.junkAmt = junkAmt
        self.startTime = time.time()
        # cyclic generator of document indexes in [0, junkAmt)
        self.indexGen = indexGenerator(self.junkAmt)
        # generator yielding waveFunction(seconds elapsed since startTime)
        self.oscillator = waveGenerator(waveFunction, self.startTime)
        self.sleepTime = sleepTime
        # for error handling / log messages
        self.threadType = self.__class__.__name__
        # pass other args to Thread constructor
        super(WorkerThread, self).__init__(*args, **kwargs)

    def connectToDB(self):
        """Connects this WorkerThread to the mongod it will be acting on.  A
        pymongo.Connection object will be used for connecting to standalone
        instances or to a mongos, and a ReplicaSetConnection will be used for
        replica sets.  Connection params are those passed when constructing
        the WorkerThread instance."""
        if self.replicaSet:
            self.connection = ReplicaSetConnection(
                max_pool_size=1,
                replicaSet=self.replicaSet,
                host=self.host,
                port=self.port)
        else:
            self.connection = Connection(
                max_pool_size=1,
                host=self.host,
                port=self.port)
        # allow reads from secondaries so read load spreads across the set
        self.connection.read_preference = ReadPreference.SECONDARY
        self.collection = self.connection[self.dbName][self.colName]

    def process(self):
        # One unit of db work per main-loop iteration; subclasses override.
        raise NotImplementedError("WorkerThread is abstract"
                                  "and should not be initialized")

    def run(self):
        """Main loop for the WorkerThread object.  The thread will continually
        call time.sleep() for varying lengths determined by oscillator.  If
        any exception is thrown, the WorkerThread will reconnect to the
        database and continue looping.  Also, the WorkerThread will
        automatically reconnect every hour."""
        # hourlyReset suppresses the error log for the expected exception
        # raised right after the hourly connection teardown below.
        hourlyReset = False
        while True:
            try:
                # sleep a wave-modulated multiple of the base sleep time
                time.sleep(self.sleepTime * self.oscillator.next())
                self.process()
            except NotImplementedError:
                # should crash if a base WorkerThread is launched
                raise
            except Exception:
                if not hourlyReset:
                    logger.error("{0} encountered an unexpected error."
                                 .format(self.threadType),
                                 exc_info=True)
                try:
                    # close connection and reset
                    if self.connection is not None:
                        self.connection.close()
                    self.connectToDB()
                    hourlyReset = False
                except Exception:
                    # best-effort reconnect: stay alive and retry on the
                    # next loop iteration
                    pass
            finally:
                # reset connection every hour: fires when elapsed time is
                # within one second of an hour boundary
                timeElapsed = time.time() - self.startTime
                if timeElapsed > 1.0 and timeElapsed % 3600 < 1.0:
                    # sleep past the one-second window so this branch only
                    # triggers once per hour
                    time.sleep(1.0)
                    logger.debug(
                        "{0} has completed its hourly connection reset."
                        .format(self.threadType))
                    hourlyReset = True
                    self.connection.close()
                    # next process() call will fail and the except-path
                    # above reconnects without logging an error
                    self.connection = None
                    self.collection = None
class InsertionThread(WorkerThread):
    """Insert worker; the only thread that targets the capped collection."""

    def process(self):
        # Write one document tagged with the next cyclic index value.
        self.collection.insert({"index": next(self.indexGen),
                                "thread": "insert"})
class DeletionThread(WorkerThread):
    """Delete worker for the uncapped collection.  The query matches no
    stored documents (they are tagged "delete", not what the collection
    holds), so nothing is actually removed on the server, but the delete
    opcounters are incremented nonetheless."""

    def process(self):
        # Issue a remove for the next cyclic index; matches nothing.
        self.collection.remove({"index": next(self.indexGen),
                                "thread": "delete"})
class UpdateThread(WorkerThread):
    """Update worker for the uncapped collection; upserts a single doc."""

    def process(self):
        # Rewrite the single "update"-tagged document's index field,
        # creating it on first use (upsert).
        newIndex = next(self.indexGen)
        self.collection.update(
            {"thread": "update"},
            {"$set": {"index": newIndex}},
            upsert=True)
class RetrieveThread(WorkerThread):
    """Read worker issuing find_one queries against the uncapped
    collection."""

    def process(self):
        # Alternate between the low and high ends of the index range so
        # queries should hit multiple shards if the collection is sharded.
        idx = next(self.indexGen)
        target = idx if idx % 2 == 0 else self.junkAmt - idx
        self.collection.find_one({"index": target,
                                  "thread": "retrieve"})
def indexGenerator(limit):
    """Cyclic counter: yields 0, 1, ..., limit - 1 and wraps back to 0.

    Deliberately avoids the modulus of an ever-growing counter so the
    running value stays bounded."""
    value = 0
    while True:
        yield value
        # advance, wrapping to 0 once the next value would reach the limit
        value = value + 1 if value + 1 < limit else 0
def waveGenerator(oscillatorFunction, startTime):
    """Infinite generator: each next() yields oscillatorFunction evaluated
    at the wall-clock time elapsed since startTime."""
    while True:
        yield oscillatorFunction(time.time() - startTime)
def sawToothWave(amplitude, period):
    """Build a sawtooth wave function that ramps from 0 toward 'amplitude'
    over each 'period', then drops back to 0."""
    def valueAt(t):
        # fraction of the way through the current period, scaled
        phase = (t % period) / period
        return phase * amplitude
    return valueAt
def sineWave(amplitude, period):
    """Build a sine wave function oscillating between 0 and 2*amplitude
    (offset by +amplitude so it never goes negative) with the given period."""
    angular = (2 * math.pi) / period

    def valueAt(t):
        return amplitude * math.sin(angular * t) + amplitude
    return valueAt
def launchThreads(replicaSet, host, port, dbName,
                  capColName, testColName, junkAmt):
    """Launches the worker threads to read/write on the given collections."""
    logger.debug("Spawning threads...")
    # (thread class, target collection, wave function, base sleep time);
    # only InsertionThread uses the capped collection, all others uncapped.
    specs = [
        (InsertionThread, capColName, sineWave(0.5, 3600), 0.1),
        (DeletionThread, testColName, sawToothWave(60, 3600), 0.002),
        (UpdateThread, testColName, sineWave(0.5, 3600), 0.05),
        (RetrieveThread, testColName, sawToothWave(100, 3600), 0.002),
    ]
    threads = [cls(replicaSet, host, port, dbName, colName, junkAmt,
                   wave, sleep)
               for cls, colName, wave, sleep in specs]
    started = 0
    for worker in threads:
        try:
            # Daemon threads get killed automatically when the main
            # thread dies.
            worker.daemon = True
            worker.start()
            logger.debug("Started {0}".format(worker.threadType))
            started += 1
        except RuntimeError:
            logger.error("Could not start {0}".format(worker.threadType))
    logger.debug("{0}/{1} worker threads initialized successfully"
                 .format(started, len(threads)))
def setupCollection(db, collectionName, isCapped, collectionSizeMB):
    """Create (or reuse) a collection needed for the simulation.

    'db' is the pymongo database object, 'collectionName' the collection to
    create, 'isCapped' whether it should be capped, and 'collectionSizeMB'
    the capped-collection size in megabytes.  Returns the collection, or
    exits the process if creation fails outright."""
    collection = None
    try:
        collection = db.create_collection(
            collectionName,
            capped=isCapped,
            # BUG FIX: create_collection's 'size' is in bytes; the old code
            # passed 1024 * collectionSizeMB, i.e. kilobytes instead of the
            # megabytes the parameter name promises.
            size=(1024 * 1024 * collectionSizeMB))
    except CollectionInvalid:
        # the collection exists already, clean it up then use it
        if not isCapped:
            cleanupCollection(db, collectionName)
        collection = db[collectionName]
        collection.create_index("index")
    except OperationFailure:
        logger.critical("Collection creation failed; error below:",
                        exc_info=True)
        sys.exit(1)
    return collection
def cleanupCollection(db, collectionName):
    """Removes all documents from the test collection on a restart"""
    # remove({}) matches every document but leaves the collection (and any
    # indexes on it) in place.
    db[collectionName].remove({})
def configureLogger(l):
    """Configure logger to save all logs to stdout and supplied log file"""
    logFile = path.abspath(l)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    # Rotate the log file at 100MB, keeping up to 100 old files around.
    fileHandler = logging.handlers.RotatingFileHandler(
        logFile,
        maxBytes=(100 * 1024 ** 2),
        backupCount=100)
    consoleHandler = logging.StreamHandler()
    for handler in (fileHandler, consoleHandler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.info("Saving logs to {0}".format(logFile))
def processArguments():
    """Parse command-line options and return the settings dict.

    Unknown or malformed options print the usage text and exit; any option
    not supplied keeps its default value below."""
    # Default options
    settings = {
        "capped": "capped",
        "uncapped": "uncapped",
        "logfile": "load_db.log",
        "database": "load_db",
        "hostname": "localhost",
        "port": 27017,
        "replicaSet": None,
        "capSize": 10,
        "junkDataSize": 1000,
    }
    # Parse arguments
    try:
        opts, args = getopt.getopt(
            sys.argv[1:], "d:c:u:s:l:h:p:r:j:",
            ["help", "database=", "cappedCollection=", "testCollection=",
             "size=", "logFile=", "hostname=", "port=", "replicaSet=",
             "junkDataSize="])
    # BUG FIX: was 'except getopt.error, msg' -- Python-2-only syntax;
    # GetoptError (aliased as getopt.error) with 'as' works on 2.6+ and 3.
    except getopt.GetoptError as msg:
        print("\n", msg)
        usage()
        sys.exit(1)

    def intArg(arg):
        # Convert an option argument to int; bad values exit via usage().
        try:
            return int(arg)
        except ValueError:
            usage()

    # Process arguments
    for option, arg in opts:
        # BUG FIX: was 'option in ("--help")' -- without a trailing comma
        # that is a SUBSTRING test against the string "--help", so "-h"
        # (hostname) also matched and exited through usage().
        if option == "--help":
            usage()
        elif option in ("-c", "--cappedCollection"):
            settings["capped"] = arg
        elif option in ("-u", "--testCollection"):
            settings["uncapped"] = arg
        elif option in ("-l", "--logFile"):
            settings["logfile"] = arg
        elif option in ("-d", "--database"):
            settings["database"] = arg
        elif option in ("-h", "--hostname"):
            settings["hostname"] = arg
        elif option in ("-p", "--port"):
            settings["port"] = intArg(arg)
        elif option in ("-r", "--replicaSet"):
            settings["replicaSet"] = arg
        elif option in ("-s", "--size"):
            settings["capSize"] = intArg(arg)
        elif option in ("-j", "--junkDataSize"):
            settings["junkDataSize"] = intArg(arg)
    return settings
def usage():
    """Print the command-line help text to stdout, then exit with status 0."""
    # One entry per output line; printed in order below.
    messages = (
        "\nDESCRIPTION: \n",
        " python " + sys.argv[0] + " [--help]"
        "[-d database | --database database] "
        "[-c | --cappedCollection] [-u | --testCollection] [-s | --size] "
        "[-l | --logFile] [-p | --port] [-h | --hostname] "
        "[-r | --replicaSet] "
        "[-j | --junkDataSize]\n",
        " The following options are available:\n",
        " --help h is to print this message\n",
        " -d (str) is to indicate which database the script will use. "
        "By default, the script will use 'load_db'\n",
        " -c (str) is to indicate which capped collection to use. "
        "By default, the script will use 'capped'\n",
        " -u (str) is to indicate which uncapped collection to use. "
        "By default, the script will use 'uncapped'\n",
        " -l (str) is to indicate which log file to use. "
        "By default, the script will use load_db.log\n",
        " -s (int) is to indicate how large the capped collection"
        "should be (MB). By default, the script will use 10\n",
        " -p (int) is to indicate what port to connect to. "
        "By default, the script will use 27017\n",
        " -h (str) is to indicate which host (mongod) to connect to. "
        "By default, the script will use localhost\n",
        " -r (str) is to indicate the replica set to connect to"
        "By default, the script does not connect to any "
        "replica sets.\n",
        " -j (int) is to indicate the amount of junk data to insert "
        "for RetrieveThread. By default, this parameter is 1000.\n",
        "EXAMPLES:",
        " python loadSimulator.py -d load_db --capped=cappedCollection "
        "-s 10 --logFile=loadSimulator.log\n",
        " Will cause the script to use 10MB capped collection "
        "'capped' in the standalone 'load_db' database and "
        "save logs to loadSimulator.log\n",
        " python loadSimulator.py\n",
        " Will cause the script to use default values\n",
        "Note that long options should be followed by an equal sign ('=')\n",
        "Written by Adinoyi Omuya and Adam Midvidy",
        "(C) Copyright 2012, 10gen",
    )
    for message in messages:
        print(message)
    sys.exit(0)
def main():
    """Main script entry point"""
    # get settings from the command line (falls back to defaults)
    settings = processArguments()
    configureLogger(settings["logfile"])
    try:
        if settings["replicaSet"]:
            # NOTE(review): uses Connection with a replicaSet kwarg here,
            # while WorkerThread.connectToDB uses ReplicaSetConnection for
            # the same case -- confirm this asymmetry is intentional.
            connection = Connection(
                max_pool_size=1,
                host=settings["hostname"],
                replicaSet=settings["replicaSet"],
                port=settings["port"])
            logger.info("Using database {0} on replica set \"{1}\" on port "
                        "\"{2}\" with capped collection \"{3}\" ({4}MB) and "
                        "uncapped collection \"{5}\"."
                        .format(settings["database"], settings["replicaSet"],
                                settings["port"], settings["capped"],
                                settings["capSize"], settings["uncapped"]))
        else:
            connection = Connection(
                max_pool_size=1,
                host=settings["hostname"],
                port=settings["port"])
            logger.info("Using database \"{0}\" with capped collection "
                        "\"{1}\" ({2}MB) and uncapped collection \"{3}\"."
                        .format(settings["database"], settings["capped"],
                                settings["capSize"], settings["uncapped"]))
        # setup collections: one capped (insert thread), one uncapped
        db = connection[settings["database"]]
        setupCollection(db, settings["capped"], True, settings["capSize"])
        uncapped = setupCollection(db, settings["uncapped"],
                                   False, settings["capSize"])
        # add junk data for retrieve thread to query against
        logger.debug("Inserting junk data for retrieve thread...")
        for index in xrange(0, settings["junkDataSize"]):
            uncapped.insert({"index": index, "thread": "retrieve"})
        # main thread no longer needs a connection, let it get gc'd
        connection.close()
        connection = None
    except ConnectionFailure:
        logger.error("Could not connect to MongoDB server. Exiting...")
        sys.exit(1)
    # Launch the daemon worker threads
    launchThreads(settings["replicaSet"], settings["hostname"],
                  settings["port"], settings["database"], settings["capped"],
                  settings["uncapped"], settings["junkDataSize"])
    try:
        # do nothing so that we can kill the workers by killing the main thread
        while True:
            time.sleep(1e7)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
        sys.exit(0)
# Only run the simulator when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment