Created
October 3, 2017 20:13
-
-
Save tiltec/a57c8024e05e93051693f50f58828860 to your computer and use it in GitHub Desktop.
concurrent test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import multiprocessing | |
import requests | |
from rest_framework import status | |
def do_requests_process(iq, oq, cookies=None, headers=None, pickup_url=None):
    """Worker loop: POST to the pickup "add" endpoint once per queued task.

    Tasks are read from ``iq`` until the sentinel string ``'STOP'`` is
    received. For every task a ``(task, response)`` tuple is put on ``oq``.

    Args:
        iq: Input queue yielding task identifiers, terminated by ``'STOP'``.
        oq: Output queue receiving ``(task, response)`` tuples.
        cookies: Optional cookie jar to install on the session (e.g. the
            cookies of an already authenticated session). When ``None`` the
            session keeps its own empty jar.
        headers: Optional mapping of extra headers merged into the session
            headers. When ``None`` no headers are added.
        pickup_url: Base URL of the pickup date, ending in ``/``;
            ``'add/'`` is appended to build the request URL.
    """
    # Built once: the URL is the same for every task.
    add_url = pickup_url + 'add/'

    # The context manager closes the session even if a request raises.
    with CSRFSession() as client:
        if cookies is not None:
            client.cookies = cookies
        if headers is not None:
            client.headers.update(headers)

        for task in iter(iq.get, 'STOP'):
            response = client.post(add_url)
            oq.put((task, response))
class CSRFSession(requests.Session):
    """A ``requests.Session`` that echoes the CSRF cookie back as a header.

    After every request the current ``csrftoken`` cookie (if any) is copied
    into the session's ``X-CSRFToken`` header so that subsequent requests
    carry it.
    """

    def request(self, *args, **kwargs):
        """Perform the request, then refresh ``X-CSRFToken`` from the cookies.

        Arguments are passed through unchanged to
        ``requests.Session.request``; its response is returned.
        """
        response = super().request(*args, **kwargs)
        # Use .get() so a missing CSRF cookie does not raise KeyError and
        # discard a response that was already received.
        csrftoken = self.cookies.get('csrftoken')
        if csrftoken is not None:
            self.headers.update({'X-CSRFToken': csrftoken})
        return response
class TestPickupDatesAPIConcurrently(unittest.TestCase):
    """Fires concurrent "add" requests at one pickup date from several processes.

    The test talks to a server at ``http://localhost:8000`` and logs in with
    hard-coded credentials, so it needs that server and account to exist.
    """

    def test_join_pickup_as_member(self):
        """Exactly one of many concurrent join requests must succeed.

        Four forked worker processes share the logged-in session's cookies
        and headers and send 40 ``add/`` requests in total. The test expects
        one ``200 OK`` and ``403 FORBIDDEN`` for all the others.
        """
        mp = multiprocessing.get_context('fork')
        credentials = {'email': '357462christopher64@roberson.biz', 'password': '123'}
        pickup_url = 'http://localhost:8000/api/pickup-dates/127/'
        n = 4
        workload = 40

        taskq = mp.Queue(30)
        responseq = mp.Queue(30)

        workers = []
        responses = []
        try:
            with CSRFSession() as client:
                # Initial GET, presumably to obtain the CSRF cookie before
                # logging in -- TODO confirm against the server.
                client.get('http://localhost:8000/api/auth/status/')
                r = client.post('http://localhost:8000/api/auth/', json=credentials)
                self.assertEqual(r.status_code, status.HTTP_201_CREATED)

                # Presumably resets state so that one join can succeed;
                # the response is intentionally not checked.
                client.post(pickup_url + 'remove/')

                for _ in range(n):
                    worker = mp.Process(
                        target=do_requests_process,
                        kwargs={
                            'iq': taskq,
                            'oq': responseq,
                            'cookies': client.cookies,
                            'headers': client.headers,
                            'pickup_url': pickup_url,
                        },
                    )
                    worker.start()
                    workers.append(worker)

            for task_id in range(workload):
                taskq.put('task{}'.format(task_id))

            for _ in range(workload):
                _task, response = responseq.get()
                responses.append(response)
        finally:
            # Always release the workers, even if something above failed,
            # so no child process is left blocked on the task queue.
            for _ in workers:
                taskq.put('STOP')
            for worker in workers:
                worker.join()

        # Diagnostics: show full details for unexpected responses first,
        # then every status code.
        for response in responses:
            if response.status_code not in (status.HTTP_200_OK, status.HTTP_403_FORBIDDEN):
                print(response, response.text)

        for response in responses:
            print(response.status_code)

        self.assertEqual(1, sum(1 for r in responses if r.status_code == status.HTTP_200_OK))
        self.assertEqual(workload - 1, sum(1 for r in responses if r.status_code == status.HTTP_403_FORBIDDEN))
# Run the test case when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment