|
from multiprocessing import Process, Event |
|
import requests |
|
import unittest |
|
|
|
from BaseHTTPServer import HTTPServer |
|
from dbserver import InMemroryDBRequestHandler |
|
|
|
# Host and first port the in-process test server listens on.
TEST_SERVER_NAME = 'localhost'
TEST_SERVER_PORT = 4001

# Root URL of the test server; rebuilt by increment_test_server_port()
# whenever the port changes.
BASE_URL = 'http://%s:%s' % (TEST_SERVER_NAME, TEST_SERVER_PORT)
|
|
|
def increment_test_server_port():
    """Advance the module-level test port by one and rebuild ``BASE_URL``.

    Workaround for Mac OS X holding on to sockets for too long, which was
    causing `socket.error: [Errno 48] Address already in use`.
    """
    global TEST_SERVER_PORT, BASE_URL
    next_port = TEST_SERVER_PORT + 1
    TEST_SERVER_PORT = next_port
    BASE_URL = 'http://%s:%s' % (TEST_SERVER_NAME, next_port)
|
|
|
|
|
class TestServerProcess(Process):
    """
    A special process to workaround the `serve_forever` shenanigans.

    The listening socket is bound in the parent (inside ``__init__``), before
    the process is started, so clients may connect as soon as ``start()``
    returns: their connections wait in the listen backlog until the child
    begins handling requests.

    Usage:

        p = TestServerProcess()
        p.start()
        # ... run your tests...
        p.shutdown()
    """

    # Seconds handle_request() waits for a connection before returning, so
    # the serving loop regularly re-checks the exit event set by shutdown().
    POLL_INTERVAL = 0.1

    # Seconds shutdown() waits for the child to exit before terminating it.
    JOIN_TIMEOUT = 5.0

    def __init__(self):
        Process.__init__(self)
        # Set by shutdown() in the parent; polled by run() in the child.
        self.exit = Event()
        # Construct unbound so allow_reuse_address actually takes effect:
        # TCPServer reads it inside server_bind(), which the default
        # constructor calls immediately.
        self.server = HTTPServer(('', TEST_SERVER_PORT),
                                 InMemroryDBRequestHandler,
                                 bind_and_activate=False)
        self.server.allow_reuse_address = True
        try:
            self.server.server_bind()
            self.server.server_activate()
        except Exception:
            # Don't leak the socket if the port is unavailable.
            self.server.server_close()
            raise
        # Without a timeout, handle_request() blocks until a request arrives,
        # so the exit flag would never be re-checked and the child would
        # never stop.
        self.server.timeout = self.POLL_INTERVAL

        def silent_logger(self, format, *args):
            return

        # NOTE: this patches the handler class itself, so logging is silenced
        # for every user of InMemroryDBRequestHandler in this process.
        self.server.RequestHandlerClass.log_message = silent_logger
        self.daemon = True

    def run(self):
        """Serve one request at a time until the exit event is set."""
        try:
            while not self.exit.is_set():
                self.server.handle_request()
        finally:
            # The child holds its own copy of the listening socket; closing
            # the parent's copy in shutdown() does not release this one.
            self.server.server_close()

    def shutdown(self):
        """Stop the child process, release the socket and move to a new port."""
        self.exit.set()
        # Close the parent's copy of the listening socket.
        self.server.server_close()
        if self.pid is not None:  # only a started process can be joined
            self.join(self.JOIN_TIMEOUT)
            if self.is_alive():
                self.terminate()
                self.join()
        increment_test_server_port()
|
|
|
|
|
|
|
class BasicTestCase(unittest.TestCase):
    """End-to-end checks of the /set and /get endpoints, sharing one server
    process across every test in the class."""

    # Seconds to wait for the test server before failing instead of hanging.
    REQUEST_TIMEOUT = 5

    @classmethod
    def setUpClass(cls):
        cls.server_process = TestServerProcess()
        cls.server_process.start()

    @classmethod
    def tearDownClass(cls):
        cls.server_process.shutdown()

    def _get(self, path):
        """GET ``path`` from the test server with a bounded timeout.

        BASE_URL is looked up at call time because TestServerProcess.shutdown()
        moves the server to a new port.
        """
        return requests.get(BASE_URL + path, timeout=self.REQUEST_TIMEOUT)

    def test_basic_set_get(self):
        set_response = self._get('/set?key1=val1')
        self.assertEqual(200, set_response.status_code)

        get_response = self._get('/get?key=key1')
        self.assertEqual(200, get_response.status_code)
        self.assertEqual('val1', get_response.text)

    def test_404_missing(self):
        response = self._get('/get?key=somethingdoesntexit')
        self.assertEqual(404, response.status_code)

    def test_500_onmalformed(self):
        response = self._get('/malformed')
        self.assertEqual(500, response.status_code)

    def test_set_overwrites(self):
        set_response = self._get('/set?key1=val1')
        self.assertEqual(200, set_response.status_code)

        set_response = self._get('/set?key1=val2')
        self.assertEqual(200, set_response.status_code)

        get_response = self._get('/get?key=key1')
        self.assertEqual(200, get_response.status_code)
        self.assertEqual('val2', get_response.text)
|
|
|
|
|
|
|
class DataPersistenceTest(unittest.TestCase):
    """Checks that a stored value can still be read after the server process
    is stopped and a fresh one is started. Each test gets its own server."""

    # Seconds to wait for the test server before failing instead of hanging.
    REQUEST_TIMEOUT = 5

    def setUp(self):
        self.start_test_server()

    def tearDown(self):
        self.stop_test_server()

    def start_test_server(self):
        self.server_process = TestServerProcess()
        self.server_process.start()

    def stop_test_server(self):
        self.server_process.shutdown()

    def _get(self, path):
        """GET ``path`` from the test server with a bounded timeout.

        BASE_URL is looked up at call time because stopping the server moves
        it to a new port.
        """
        return requests.get(BASE_URL + path, timeout=self.REQUEST_TIMEOUT)

    def test_set_restart_get(self):
        set_response = self._get('/set?key1=val1')
        self.assertEqual(200, set_response.status_code)

        # Restart: the value must come back from a brand-new server process.
        self.stop_test_server()
        self.start_test_server()

        get_response = self._get('/get?key=key1')
        self.assertEqual(200, get_response.status_code)
        self.assertEqual('val1', get_response.text)
|
|
|
|
|
if __name__ == '__main__':
    # Discover and run every TestCase defined in this script.
    unittest.main(module='__main__')
|
|