Skip to content

Instantly share code, notes, and snippets.

@ajdavis
Created April 10, 2012 13:15
Show Gist options
  • Save ajdavis/2351293 to your computer and use it in GitHub Desktop.
Save ajdavis/2351293 to your computer and use it in GitHub Desktop.
Hanging Gevent test with PyMongo
import random
import threading
import os
import unittest
from gevent import Greenlet, monkey
from pymongo.connection import Connection
# DB: database the pooling tests read from and write to.
# N: iterations performed by each worker; SaveAndFind also draws its random
#    values from the inclusive range [0, N].
DB, N = "pymongo-pooling-tests", 50
def get_connection(*args, **kwargs):
    """Open a Connection to the test server.

    The host is read from the DB_IP environment variable (default
    "localhost") and the port from DB_PORT (default 27017, converted with
    int()). Any extra positional/keyword arguments are forwarded unchanged
    to Connection after host and port.
    """
    env = os.environ
    host, port = env.get("DB_IP", "localhost"), int(env.get("DB_PORT", 27017))
    return Connection(host, port, *args, **kwargs)
class TestThread(object):
    """Run ``self.run`` concurrently on a Greenlet or on an OS thread.

    Subclasses provide ``run``. ``start`` launches it on the worker type
    chosen by ``use_greenlets`` and ``join`` waits for it to finish.
    """

    def __init__(self, use_greenlets):
        # True -> gevent Greenlet, False -> threading.Thread.
        self.use_greenlets = use_greenlets

    def start(self):
        """Create the worker for ``self.run``, remember it, and start it."""
        if not self.use_greenlets:
            worker = threading.Thread(target=self.run)
        else:
            worker = Greenlet(self.run)
        self.thread = worker
        worker.start()

    def join(self):
        """Wait for the worker to finish, then drop the reference to it."""
        worker = self.thread
        worker.join()
        self.thread = None
class MongoThread(TestThread):
    """A TestThread that runs one MongoDB workload for a test case.

    The worker shares the test case's connection (``test_case.c``) and
    uses the ``DB`` database on it. ``passed`` becomes True only when
    ``run_mongo_thread`` returns without raising. This lets the driver
    detect failures that happened on another thread or greenlet.
    """

    def __init__(self, test_case):
        super(MongoThread, self).__init__(test_case.use_greenlets)
        conn = test_case.c
        self.connection = conn
        self.db = conn[DB]
        self.ut = test_case
        self.passed = False

    def run(self):
        """Execute the workload; mark success only if nothing escaped."""
        self.run_mongo_thread()
        self.passed = True  # not reached when run_mongo_thread() raises

    def run_mongo_thread(self):
        """Workload hook; every concrete subclass must override this."""
        raise NotImplementedError()
class SaveAndFind(MongoThread):
    """Worker that saves documents with a random ``x`` and reads each back.

    Each of the N iterations does two things. It saves ``{"x": rand}`` into
    the ``sf`` collection with ``safe=True``. It then looks the document up
    with ``find_one`` using the value ``save`` returned, and asserts that
    ``x`` round-tripped.
    """

    def run_mongo_thread(self):
        for _ in xrange(N):
            rand = random.randint(0, N)
            # Named doc_id (not `id`) so the builtin id() is not shadowed.
            doc_id = self.db.sf.save({"x": rand}, safe=True)
            self.ut.assertEqual(rand, self.db.sf.find_one(doc_id)["x"])
class Unique(MongoThread):
    """Worker that inserts fresh documents and expects no server error.

    Each of the N iterations does the following:
    - opens a request;
    - inserts an empty document into the ``unique`` collection;
    - asserts that ``db.error()`` returns None;
    - ends the request.
    """

    def run_mongo_thread(self):
        for _ in xrange(N):
            self.connection.start_request()
            try:
                self.db.unique.insert({})
                # Presumably the request keeps insert() and error() on the
                # same socket so error() reports on this insert -- confirm
                # against the PyMongo request documentation.
                self.ut.assertEqual(None, self.db.error())
            finally:
                # Always end the request, even if the insert or the
                # assertion raised, so it is never left open.
                self.connection.end_request()
class NonUnique(MongoThread):
    """Worker that repeats a duplicate insert and expects an error each time.

    The test fixture inserts a document with ``_id`` "mike" into the
    ``unique`` collection. Inserting that ``_id`` again should therefore make
    ``db.error()`` return something other than None. Each check happens
    inside its own request.
    """

    def run_mongo_thread(self):
        for _ in xrange(N):
            self.connection.start_request()
            try:
                self.db.unique.insert({"_id": "mike"})
                self.ut.assertNotEqual(None, self.db.error())
            finally:
                # End the request even when the insert or assertion raised.
                self.connection.end_request()
class NoRequest(MongoThread):
    """Worker that repeats a duplicate insert N times within one request.

    It counts the iterations where ``db.error()`` returned None, meaning the
    duplicate ``_id`` "mike" went unreported, and asserts that count is 0.

    NOTE(review): despite its name this worker does call start_request();
    confirm whether the name or the behavior is what was intended.
    """

    def run_mongo_thread(self):
        self.connection.start_request()
        errors = 0
        try:
            for _ in xrange(N):
                self.db.unique.insert({"_id": "mike"})
                if self.db.error() is None:
                    errors += 1
        finally:
            # End the request even if an insert/error() call raised.
            self.connection.end_request()
        self.ut.assertEqual(0, errors)
def run_cases(ut, cases, threads_per_case=10):
    """Start workers for every case, wait for all of them, then check results.

    :param ut: test case passed to each worker and used for the final
        assertions (must provide ``assertTrue``).
    :param cases: callables (normally MongoThread subclasses). ``case(ut)``
        must return an object with ``start()``, ``join()`` and a ``passed``
        attribute.
    :param threads_per_case: how many workers to launch per case
        (default 10, the previously hard-coded value).
    :raises AssertionError: via ``ut.assertTrue`` if any worker's ``passed``
        is false after joining.
    """
    threads = []
    for case in cases:
        for _ in range(threads_per_case):
            t = case(ut)
            t.start()
            threads.append(t)

    # Join every worker before checking any of them, so a single failure
    # does not leave the remaining workers running unobserved.
    for t in threads:
        t.join()

    for t in threads:
        # ut.assertTrue instead of a bare `assert`: the check must not
        # disappear when Python runs with -O.
        ut.assertTrue(
            t.passed, "%s.run_mongo_thread() threw an exception" % repr(t))
class TestPoolingBase(object):
    """Mixin containing the connection-pooling test; combine with TestCase.

    ``use_greenlets`` has three effects:
    - it selects Greenlet or OS-thread workers;
    - it is passed through to Connection as the ``use_greenlets`` keyword;
    - when true, it makes setUp monkey-patch the socket module.
    """

    use_greenlets = False

    def setUp(self):
        if self.use_greenlets:
            # Note we don't do patch_thread() or patch_all() - we're
            # testing here that patch_thread() is unnecessary for
            # the connection pool to work properly.
            monkey.patch_socket()

        self.c = self.get_connection(auto_start_request=False)

        # Reset the db: empty both collections, then seed the document with
        # _id "mike" (which the duplicate-insert workers rely on) and 1000
        # empty documents.
        db = self.c[DB]
        db.unique.drop()
        db.test.drop()
        db.unique.insert({"_id": "mike"})
        db.test.insert([{} for _ in range(1000)])

    def tearDown(self):
        # Release the sockets held by the Connection opened in setUp;
        # otherwise every test run leaks a connection pool.
        self.c.disconnect()

    def get_connection(self, *args, **kwargs):
        """Call the module-level get_connection() with this class's
        ``use_greenlets`` setting added to the keyword arguments."""
        kwargs = dict(kwargs, use_greenlets=self.use_greenlets)
        return get_connection(*args, **kwargs)

    def test_no_disconnect(self):
        run_cases(self, [NoRequest, NonUnique, Unique, SaveAndFind])
class TestPoolingGevent(TestPoolingBase, unittest.TestCase):
    """Pooling test run with ``use_greenlets`` enabled.

    Workers are Greenlets. setUp patches only the socket module, not
    threading.
    """

    use_greenlets = True
# When executed as a script, hand control to unittest's command-line runner,
# which discovers and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment