Last active
November 9, 2023 04:34
-
-
Save nickva/e733156c6fefdc94f6bd9051a456d9e5 to your computer and use it in GitHub Desktop.
Test script to generate a lot of CouchDB regular and admin users
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# Usage: start a one-node dev cluster first, then run this script:
#   make && ./dev/run -n 1 --admin=adm:pass
#   ./multiuser.py --users-scheme=simple --users-salt="abc" --users-hash-sha1 --tries 100
import copy | |
import sys | |
import time | |
import threading | |
import os | |
import argparse | |
import uuid | |
import random | |
import hashlib | |
import requests | |
from requests.exceptions import HTTPError | |
from multiprocessing.dummy import Pool as ThreadPool | |
# Docs are inserted via _bulk_docs in batches of this size.
BATCH = 100
# Default admin credentials (match `./dev/run --admin=adm:pass`).
USER='adm'
PASS='pass'
# Default per-request timeout in seconds, applied by Server._apply_timeout.
TIMEOUT = 60
# Default server URL(s) used when none are supplied via --urls.
DB_URLS = [
    'http://127.0.0.1:15984'
]
# Default database-name prefix; worker DBs are named <prefix>_<tid>.
DB_NAME = 'db'
def log(*args):
    """Write a space-joined message to stderr and flush immediately.

    Each argument is stringified defensively so a value with a broken
    __str__ cannot crash the logger; such values render as '?'.
    """
    sargs = []
    for a in args:
        try:
            sargs.append(str(a))
        except Exception:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Never let logging raise.
            sargs.append('?')
    msg = " ".join(sargs)
    sys.stderr.write(msg + '\n')
    sys.stderr.flush()
def pick_server(urls):
    """Return one server URL: a random element if given a list, else `urls` as-is."""
    if not isinstance(urls, list):
        return urls
    return random.choice(urls)
def pick_admin_or_user(args):
    """Randomly pick credentials: ~10% of the time an admin, otherwise a user."""
    use_admin = random.random() > 0.9
    if use_admin:
        return get_admin(args, random.randrange(0, args.admin_count))
    return get_user(args, random.randrange(0, args.users_count))
class Server:
    """Minimal CouchDB REST client bound to one base URL.

    Wraps a requests.Session carrying basic-auth credentials. Every
    request receives the default timeout unless the caller supplies one.
    JSON responses are returned as parsed objects; non-2xx responses
    raise requests.exceptions.HTTPError via raise_for_status().
    """

    def __init__(self, auth, url, timeout=TIMEOUT):
        self.sess = requests.Session()
        self.sess.auth = auth
        self.url = url.rstrip('/')
        self.timeout = timeout

    def _apply_timeout(self, kw):
        # Inject the default timeout unless the caller provided one.
        if self.timeout is not None and 'timeout' not in kw:
            kw['timeout'] = self.timeout
        return kw

    def get(self, path='', **kw):
        kw = self._apply_timeout(kw)
        r = self.sess.get(f'{self.url}/{path}', **kw)
        r.raise_for_status()
        return r.json()

    def post(self, path, **kw):
        kw = self._apply_timeout(kw)
        r = self.sess.post(f'{self.url}/{path}', **kw)
        r.raise_for_status()
        return r.json()

    def put(self, path, **kw):
        kw = self._apply_timeout(kw)
        r = self.sess.put(f'{self.url}/{path}', **kw)
        r.raise_for_status()
        return r.json()

    def delete(self, path, **kw):
        kw = self._apply_timeout(kw)
        r = self.sess.delete(f'{self.url}/{path}', **kw)
        r.raise_for_status()
        return r.json()

    def head(self, path, **kw):
        # HEAD has no body to decode; return the raw status code.
        kw = self._apply_timeout(kw)
        r = self.sess.head(f'{self.url}/{path}', **kw)
        return r.status_code

    def version(self):
        """Return the server version string from GET /."""
        return self.get()['version']

    def membership(self):
        return self.get('_membership')

    def cluster_setup(self, req):
        return self.post('_cluster_setup', json=req)

    def create_db(self, dbname):
        """Create `dbname` if missing; tolerate a concurrent-create 412.

        Raises if the database still does not exist afterwards.
        """
        if dbname not in self:
            try:
                self.put(dbname, timeout=TIMEOUT)
            except HTTPError as err:
                response = err.response
                # Fix: original tested `if not response:` — but
                # requests.Response.__bool__ is ok-based, so every error
                # response is falsy — and then built an Exception without
                # raising it. Check for a missing response explicitly.
                if response is None:
                    raise Exception(f"{dbname} could not be created") from err
                if response.status_code == 412:
                    log(f" -> {dbname} PUT returned a 412. DB is already created")
                    return True
                raise err
        if dbname not in self:
            raise Exception(f"{dbname} could not be created")
        return True

    def bulk_docs(self, dbname, docs, timeout=TIMEOUT):
        # Fix: `timeout` was previously accepted but silently ignored.
        return self.post(f'{dbname}/_bulk_docs', json = {'docs': docs}, timeout=timeout)

    def bulk_get(self, dbname, docs, timeout=TIMEOUT):
        # Fix: `timeout` was previously accepted but silently ignored.
        return self.post(f'{dbname}/_bulk_get', json = {'docs': docs}, timeout=timeout)

    def compact(self, dbname, **kw):
        r = self.sess.post(f'{self.url}/{dbname}/_compact', json = {}, **kw)
        r.raise_for_status()
        return r.json()

    def config_set(self, section, key, val):
        url = f'_node/_local/_config/{section}/{key}'
        # CouchDB config values are JSON strings, hence the added quotes.
        return self.put(url, data='"'+val+'"')

    def config_get(self, section, key):
        url = f'_node/_local/_config/{section}/{key}'
        return self.get(url)

    def __iter__(self):
        """Iterate over all database names on the server."""
        dbs = self.get('_all_dbs')
        return iter(dbs)

    def __str__(self):
        # Fix: original formatted a 2-tuple into a single %s placeholder
        # (TypeError) and read nonexistent self.auth (it lives on the session).
        user = self.sess.auth[0] if self.sess.auth else None
        return "<Server:%s user:%s>" % (self.url, user)

    def __contains__(self, dbname):
        """True if the database exists (HEAD 200), False on 404."""
        res = self.head(dbname)
        if res == 200:
            return True
        if res == 404:
            return False
        raise Exception(f"Unexpected head status code {res}")
def worker_fun(args, tid, i, url):
    """One round of a worker thread: insert num_docs docs in BATCH-sized chunks.

    Doc ids start at i * num_docs so successive rounds do not collide.
    """
    dbname = get_dbname(args, tid)
    auth = pick_admin_or_user(args)
    srv = Server(auth, url)
    srv.version()
    total = args.num_docs
    next_id = i * total
    seen = {}
    full_batches = total // BATCH
    for b in range(full_batches):
        next_id, docs = generate_docs(next_id, BATCH, seen)
        bulk_docs_with_retry(srv, dbname, docs)
        log(f" -> batch tid:{tid} i:{i} batch:{b} {auth[0]}")
    # Flush whatever is left over after the full batches.
    remainder = total - full_batches * BATCH
    next_id, docs = generate_docs(next_id, remainder, seen)
    bulk_docs_with_retry(srv, dbname, docs)
def bulk_docs_with_retry(srv, dbname, docs):
    """POST docs via _bulk_docs, retrying (with a 1s pause) on HTTP 500.

    Fixes two defects in the original:
    - `if response and ...` never retried, because requests.Response.__bool__
      is ok-based and a 500 response is falsy; use an explicit None check.
    - Retrying via recursion could exhaust the stack during a long outage;
      retry in a loop instead.
    """
    while True:
        try:
            return srv.bulk_docs(dbname, docs)
        except HTTPError as err:
            response = err.response
            if response is not None and response.status_code == 500:
                log(f" -> retrying _bulk_docs due to a 500 error {dbname} {len(docs)} {response}")
                time.sleep(1.0)
                continue
            raise err
def generate_docs(doc_id, batch, all_docs):
    """Generate `batch` sequential docs, ids continuing after `doc_id`.

    Each doc is also recorded in `all_docs` keyed by its _id. Returns the
    advanced id counter together with the list of new docs.
    """
    docs = []
    for _ in range(batch):
        doc_id += 1
        new_doc = generate_doc(doc_id)
        all_docs[new_doc['_id']] = new_doc
        docs.append(new_doc)
    return (doc_id, docs)
def generate_doc(doc_id):
    """Build a trivial doc whose _id is the zero-padded 12-digit doc_id."""
    return {'_id': f'{doc_id:012d}', 'data': doc_id}
def thread_worker(args):
    """Entry point for one worker thread.

    Creates this thread's database, sets its _security object, then runs
    `tries` rounds of doc insertion. Returns the thread id.
    """
    tid = args.tid
    srv_url = pick_server(args.urls)
    srv = Server(args.auth, srv_url)
    srv.version()
    dbname = get_dbname(args, tid)
    srv.create_db(dbname)
    role = args.users_prefix
    security = {
        'admins': {'names': [], 'roles': ["_admin", role]},
        'members': {'names': [], 'roles': ["_admin", role]},
    }
    srv.put(f'{dbname}/_security', json=security)
    for i in range(args.tries):
        worker_fun(args, tid, i, srv_url)
    return tid
def get_dbname(args, tid):
    """Database name for a worker thread: '<dbname>_<tid>'."""
    return f"{args.dbname}_{tid}"
def set_worker_id(args, tid):
    """Return an independent deep copy of `args` tagged with this thread id."""
    tagged = copy.deepcopy(args)
    tagged.tid = tid
    return tagged
def wait_urls(args):
    """Block until every configured URL responds; return Server handles."""
    servers = []
    for u in args.urls:
        server = wait_url(args, u)
        log(" >> Server up", u, server.version())
        servers.append(server)
    return servers
def wait_url(args, url):
    """Poll `url` once per second until GET / succeeds; return the Server."""
    while True:
        candidate = Server(args.auth, url)
        try:
            candidate.version()
        except Exception:
            log(">>> Waiting for server", url)
            time.sleep(1.0)
        else:
            return candidate
def clear(args):
    """Delete every database whose name starts with '<dbname>_'.

    Fix: the original ran a list comprehension purely for its DELETE side
    effects and bound the result to an unused name; use a plain loop.
    """
    prefix = args.dbname + '_'
    srv = Server(args.auth, args.urls[0])
    for db in srv:
        if db.startswith(prefix):
            srv.delete(db)
def _get_auth(args): | |
if args.auth: | |
args.auth = tuple(args.auth.split(':')) | |
elif 'AUTH' in os.environ: | |
authstr = os.environ['AUTH'] | |
args.auth = tuple(authstr.split(':')) | |
log(" ! using auth", username," from AUTH env var") | |
else: | |
args.auth = (USER, PASS) | |
return args | |
def add_admins(args):
    """Register the extra admin accounts in each node's local config."""
    for u in args.urls:
        node = Server(args.auth, u)
        for idx in range(args.admin_count):
            admin_name, admin_pw = get_admin(args, idx)
            node.config_set("admins", admin_name, admin_pw)
def get_admin(args, i):
    """Credentials for admin #i: name '<prefix><i>', password '<prefix>'."""
    p = args.admin_prefix
    return (p + str(i), p)
def add_users(args):
    """Create or update the regular user docs in the _users database."""
    srv = Server(args.auth, args.urls[0])
    for idx in range(args.users_count):
        user_doc = maybe_hash(args, get_user_doc(args, idx))
        _id = user_doc['name']
        doc_path = f'_users/org.couchdb.user:{_id}'
        if srv.head(doc_path) == 200:
            # Doc already exists: carry its _rev forward so the PUT updates it.
            user_doc['_rev'] = srv.get(doc_path)['_rev']
        srv.put(f'_users/org.couchdb.user:{_id}', json=user_doc)
def get_user_doc(args, i):
    """Build a _users document for user #i, honoring optional CLI settings."""
    name, pw = get_user(args, i)
    doc = {
        'name': name,
        'password': pw,
        'type': 'user',
        'roles': [args.users_prefix],
    }
    # Optional password-hashing knobs, included only when set.
    if args.users_salt:
        doc['salt'] = args.users_salt
    if args.users_scheme:
        doc['password_scheme'] = args.users_scheme
    if args.users_iterations:
        doc['iterations'] = int(args.users_iterations)
    if args.users_prf:
        doc['pbkdf2_prf'] = args.users_prf
    return doc
def maybe_hash(args, doc):
    """Optionally pre-hash the password client-side.

    When --users-hash-sha1 is set, replace the plaintext 'password' field
    with 'password_sha' = sha1(password + salt) under the legacy 'simple'
    scheme; otherwise return the doc unchanged.
    """
    if not args.users_hash_sha1:
        return doc
    plaintext = doc.pop('password')
    digest = hashlib.sha1()
    digest.update(plaintext.encode('utf-8'))
    digest.update(doc['salt'].encode('utf-8'))
    doc['password_sha'] = digest.hexdigest()
    doc['password_scheme'] = 'simple'
    return doc
def get_user(args, i):
    """Credentials for user #i: name '<prefix><i>', password '<prefix>'."""
    p = args.users_prefix
    return (p + str(i), p)
def _args():
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser()
    # CouchDB server URL(s); when several are given, workers pick one at random.
    parser.add_argument('-u', '--urls', action="append", default=[], help="Server URL(s)")
    # Databases are created per-thread as <dbname>_<tid>.
    parser.add_argument('-d', '--dbname', default=DB_NAME, help="DB name")
    # Docs inserted per round; written in batches of BATCH.
    parser.add_argument('-n', '--num_docs', type=int, default=1000)
    # Worker threads started per process.
    parser.add_argument('-w', '--worker-count', type=int, default=10)
    # Rounds of doc insertion per thread.
    parser.add_argument('-t', '--tries', type=int, default=1)
    # Processes to start; each starts worker-count threads.
    parser.add_argument('-p', '--processes', type=int, default=1)
    # Main admin credentials as user:pass for basic auth.
    parser.add_argument('-a', '--auth', default=None)
    # Extra admin accounts to create in each node's config.
    parser.add_argument('--admin-count', type=int, default=1)
    parser.add_argument('--admin-prefix', type=str, default="adm")
    # Regular users to create in the _users database, plus hashing knobs.
    parser.add_argument('--users-count', type=int, default=100)
    parser.add_argument('--users-prefix', type=str, default="usr")
    parser.add_argument('--users-scheme')
    parser.add_argument('--users-iterations')
    parser.add_argument('--users-salt')
    parser.add_argument('--users-prf')
    parser.add_argument('--users-hash-sha1', action="store_true", default=False)
    # Drop all databases matching the <dbname>_ prefix before starting.
    parser.add_argument('-c', '--clear', action="store_true", default=False)
    return parser.parse_args()
def main(args):
    """Prepare servers, admins, and users, then fan out the worker threads."""
    if not args.urls:
        args.urls = DB_URLS
    args = _get_auth(args)
    wait_urls(args)
    clear(args)
    add_admins(args)
    add_users(args)
    wcount = args.worker_count
    pool = ThreadPool(wcount)
    per_worker = [set_worker_id(args, i) for i in range(wcount)]
    results = pool.imap_unordered(thread_worker, per_worker)
    log(" ... running")
    # Materializing the iterator blocks until every worker finishes.
    return list(results)
# Script entry point: parse CLI arguments and run the load generator.
if __name__=='__main__':
    main(_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment