
@nickva
Last active November 9, 2023 04:34
Test script to generate a lot of CouchDB regular and admin users
#!/usr/bin/env python
#
# make && ./dev/run -n 1 --admin=adm:pass
# ./multiuser.py --users-scheme=simple --users-salt="abc" --users-hash-sha1 --tries 100
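#
# A few more illustrative invocations (all flags are defined in _args() below):
#   ./multiuser.py --users-scheme=pbkdf2 --users-iterations=10 --tries 10
#   ./multiuser.py --clear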
import copy
import sys
import time
import os
import argparse
import random
import hashlib
import requests
from requests.exceptions import HTTPError
from multiprocessing.dummy import Pool as ThreadPool
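# Tunables: docs per _bulk_docs request, default admin credentials,
# per-request timeout (seconds), default server URL(s) and DB name prefix.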
BATCH = 100
USER = 'adm'
PASS = 'pass'
TIMEOUT = 60
DB_URLS = [
'http://127.0.0.1:15984'
]
DB_NAME = 'db'
def log(*args):
sargs = []
for a in args:
try:
sargs.append(str(a))
        except Exception:
sargs.append('?')
msg = " ".join(sargs)
sys.stderr.write(msg + '\n')
sys.stderr.flush()
def pick_server(urls):
if isinstance(urls, list):
return random.choice(urls)
return urls
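# Roughly 10% of the time authenticate as one of the extra admins,
# otherwise as one of the generated regular users.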
def pick_admin_or_user(args):
if random.random() > 0.9:
return get_admin(args, random.randrange(0, args.admin_count))
else:
return get_user(args, random.randrange(0, args.users_count))
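# Thin wrapper around a requests.Session bound to one CouchDB URL.
# All verb helpers except head() raise requests.exceptions.HTTPError
# on non-2xx responses.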
class Server:
def __init__(self, auth, url, timeout=TIMEOUT):
self.sess = requests.Session()
self.sess.auth = auth
self.url = url.rstrip('/')
self.timeout = timeout
def _apply_timeout(self, kw):
if self.timeout is not None and 'timeout' not in kw:
kw['timeout'] = self.timeout
return kw
def get(self, path = '', **kw):
kw = self._apply_timeout(kw)
r = self.sess.get(f'{self.url}/{path}', **kw)
r.raise_for_status()
return r.json()
def post(self, path, **kw):
kw = self._apply_timeout(kw)
r = self.sess.post(f'{self.url}/{path}', **kw)
r.raise_for_status()
return r.json()
def put(self, path, **kw):
kw = self._apply_timeout(kw)
r = self.sess.put(f'{self.url}/{path}', **kw)
r.raise_for_status()
return r.json()
def delete(self, path, **kw):
kw = self._apply_timeout(kw)
r = self.sess.delete(f'{self.url}/{path}', **kw)
r.raise_for_status()
return r.json()
def head(self, path, **kw):
kw = self._apply_timeout(kw)
r = self.sess.head(f'{self.url}/{path}', **kw)
return r.status_code
def version(self):
return self.get()['version']
def membership(self):
return self.get('_membership')
def cluster_setup(self, req):
return self.post('_cluster_setup', json = req)
def create_db(self, dbname):
if dbname not in self:
try:
self.put(dbname, timeout=TIMEOUT)
except HTTPError as err:
response = err.response
                if response is None:
                    raise Exception(f"{dbname} could not be created")
if response.status_code == 412:
log(f" -> {dbname} PUT returned a 412. DB is already created")
return True
raise err
if dbname not in self:
raise Exception(f"{dbname} could not be created")
else:
return True
    def bulk_docs(self, dbname, docs, timeout=TIMEOUT):
        return self.post(f'{dbname}/_bulk_docs', json = {'docs': docs}, timeout=timeout)
    def bulk_get(self, dbname, docs, timeout=TIMEOUT):
        return self.post(f'{dbname}/_bulk_get', json = {'docs': docs}, timeout=timeout)
    def compact(self, dbname, **kw):
        kw = self._apply_timeout(kw)
        r = self.sess.post(f'{self.url}/{dbname}/_compact', json = {}, **kw)
r.raise_for_status()
return r.json()
def config_set(self, section, key, val):
url = f'_node/_local/_config/{section}/{key}'
return self.put(url, data='"'+val+'"')
def config_get(self, section, key):
url = f'_node/_local/_config/{section}/{key}'
return self.get(url)
def __iter__(self):
dbs = self.get('_all_dbs')
return iter(dbs)
    def __str__(self):
        return "<Server:%s %s>" % (self.url, self.sess.auth[0])
def __contains__(self, dbname):
res = self.head(dbname)
if res == 200:
return True
if res == 404:
return False
raise Exception(f"Unexpected head status code {res}")
def worker_fun(args, tid, i, url):
dbname = get_dbname(args, tid)
auth = pick_admin_or_user(args)
srv = Server(auth, url)
srv.version()
num_docs = args.num_docs
doc_id = i * num_docs
all_docs = {}
batches = num_docs // BATCH
for b in range(batches):
(doc_id, docs) = generate_docs(doc_id, BATCH, all_docs)
bulk_docs_with_retry(srv, dbname, docs)
log(f" -> batch tid:{tid} i:{i} batch:{b} {auth[0]}")
    left = num_docs - batches * BATCH
    if left > 0:
        (doc_id, docs) = generate_docs(doc_id, left, all_docs)
        bulk_docs_with_retry(srv, dbname, docs)
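# Retry _bulk_docs indefinitely on 500 responses, backing off one second
# per attempt; any other HTTP error is re-raised to the caller.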
def bulk_docs_with_retry(srv, dbname, docs):
try:
return srv.bulk_docs(dbname, docs)
except HTTPError as err:
response = err.response
if response and response.status_code == 500:
log(f" -> retrying _bulk_docs due to a 500 error {dbname} {len(docs)} {response}")
time.sleep(1.0)
return bulk_docs_with_retry(srv, dbname, docs)
raise err
def generate_docs(doc_id, batch, all_docs):
docs = []
for i in range(batch):
doc_id += 1
doc = generate_doc(doc_id)
docs.append(doc)
all_docs[doc['_id']] = doc
return (doc_id, docs)
def generate_doc(doc_id):
doc_id_str = '%012d' % doc_id
return {
'_id': doc_id_str,
'data': doc_id
}
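# Per-thread entry point: create this thread's database, set its _security
# object, then run `tries` rounds of worker_fun against a random server URL.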
def thread_worker(args):
tid = args.tid
url = pick_server(args.urls)
srv = Server(args.auth, url)
srv.version()
dbname = get_dbname(args, tid)
srv.create_db(dbname)
sec = {
'admins': {'names': [], 'roles': ["_admin", args.users_prefix]},
'members': {'names': [], 'roles': ["_admin", args.users_prefix]}
}
srv.put(f'{dbname}/_security', json = sec)
tries = args.tries
for i in range(tries):
worker_fun(args, tid, i, url)
return tid
def get_dbname(args, tid):
return "%s_%s" % (args.dbname, tid)
def set_worker_id(args, tid):
args = copy.deepcopy(args)
args.tid = tid
return args
def wait_urls(args):
srvs = []
for url in args.urls:
srv = wait_url(args, url)
log(" >> Server up", url, srv.version())
srvs.append(srv)
return srvs
def wait_url(args, url):
while True:
srv = Server(args.auth, url)
try:
srv.version()
return srv
        except Exception:
log(">>> Waiting for server", url)
time.sleep(1.0)
def clear(args):
    prefix = args.dbname + '_'
    srv = Server(args.auth, args.urls[0])
    for db in srv:
        if db.startswith(prefix):
            srv.delete(db)
def _get_auth(args):
if args.auth:
args.auth = tuple(args.auth.split(':'))
elif 'AUTH' in os.environ:
authstr = os.environ['AUTH']
args.auth = tuple(authstr.split(':'))
log(" ! using auth", username," from AUTH env var")
else:
args.auth = (USER, PASS)
return args
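# Create the extra admin accounts by PUT-ing them into the node-local config
# of every server URL; CouchDB hashes the plaintext value on write.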
def add_admins(args):
for url in args.urls:
srv = Server(args.auth, url)
for i in range(args.admin_count):
name, pw = get_admin(args, i)
srv.config_set("admins", name, pw)
def get_admin(args, i):
prefix = args.admin_prefix
return (f"{prefix}{i}", prefix)
def add_users(args):
srv = Server(args.auth, args.urls[0])
for i in range(args.users_count):
doc = get_user_doc(args, i)
doc = maybe_hash(args, doc)
_id = doc['name']
doc_path = f'_users/org.couchdb.user:{_id}'
status_code = srv.head(doc_path)
if status_code == 200:
doc['_rev'] = srv.get(doc_path)['_rev']
srv.put(f'_users/org.couchdb.user:{_id}', json = doc)
def get_user_doc(args, i):
(name, pw) = get_user(args, i)
doc = {
'name': name,
'password': pw,
'type': 'user',
'roles': [args.users_prefix]
}
if args.users_salt:
doc["salt"] = args.users_salt
if args.users_scheme:
doc["password_scheme"] = args.users_scheme
if args.users_iterations:
doc["iterations"] = int(args.users_iterations)
if args.users_prf:
doc["pbkdf2_prf"] = args.users_prf
return doc
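# Pre-hash the password client-side using the legacy "simple" scheme,
# i.e. sha1(password + salt); requires --users-salt to be set as well.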
def maybe_hash(args, doc):
if not args.users_hash_sha1:
return doc
hasher = hashlib.sha1()
pw = doc['password']
del doc['password']
hasher.update(pw.encode('utf-8'))
salt = doc['salt']
hasher.update(salt.encode('utf-8'))
password_sha = hasher.hexdigest()
doc['password_sha'] = password_sha
doc['password_scheme'] = 'simple'
return doc
def get_user(args, i):
prefix = args.users_prefix
return (f"{prefix}{i}", prefix)
def _args():
p = argparse.ArgumentParser()
    # list of CouchDB server URLs to connect to (usually one, but if multiple
    # are given, each worker picks one at random)
p.add_argument('-u', '--urls', action="append", default=[], help = "Server URL(s)")
    # db name prefix; databases are created with a <dbname>_<thread-id> pattern
p.add_argument('-d', '--dbname', default=DB_NAME, help = "DB name")
# how many regular docs to insert, docs are inserted in a batch
# default batch size is BATCH
p.add_argument('-n', '--num_docs', type=int, default=1000)
# how many worker threads to start per process
p.add_argument('-w', '--worker-count', type=int, default=10)
    # how many times to repeat the operation per-thread (create docs, query)
p.add_argument('-t', '--tries', type=int, default=1)
# how many processes to start, each process will start `worker-count` threads
p.add_argument('-p', '--processes', type=int, default=1)
    # user:pass basic auth creds for the main admin
p.add_argument('-a', '--auth', default=None)
# extra admins
p.add_argument('--admin-count', type=int, default=1)
p.add_argument('--admin-prefix', type=str, default="adm")
p.add_argument('--users-count', type=int, default=100)
p.add_argument('--users-prefix', type=str, default="usr")
p.add_argument('--users-scheme')
p.add_argument('--users-iterations')
p.add_argument('--users-salt')
p.add_argument('--users-prf')
p.add_argument('--users-hash-sha1', action="store_true", default=False)
    # delete all databases starting with the dbname prefix (default: db_), then exit
p.add_argument('-c', '--clear', action="store_true", default=False)
return p.parse_args()
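# Overall flow: wait for all servers, clear any leftover test databases,
# create the extra admins and users, then fan the work out over a thread pool.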
def main(args):
    if not args.urls:
        args.urls = DB_URLS
    args = _get_auth(args)
    wait_urls(args)
    clear(args)
    if args.clear:
        # -c / --clear: just clean up the test databases and exit
        return
add_admins(args)
add_users(args)
wcount = args.worker_count
tpool = ThreadPool(wcount)
worker_args = [set_worker_id(args, i) for i in range(wcount)]
pending = tpool.imap_unordered(thread_worker, worker_args)
log(" ... running")
return list(pending)
if __name__=='__main__':
main(_args())