Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
TLS authentication with rethinkdb python client
from rethinkdb.net import ConnectionInstance, SocketWrapper, Connection, decodeUTF
from rethinkdb.errors import *
import socket
import time
import ssl
class TLSConnectionInstance(ConnectionInstance):
    """ConnectionInstance whose ``connect`` builds a ``TLSSocketWrapper``.

    Everything except socket creation is inherited from ``ConnectionInstance``.
    """

    def connect(self, timeout):
        """Create the (optionally TLS-wrapped) socket and return the parent.

        :param timeout: seconds handed to ``TLSSocketWrapper``, which uses it
            for the TCP connect and the handshake deadline.
        :returns: ``self._parent``, the owning connection object.
        """
        wrapper = TLSSocketWrapper(self, timeout)
        self._socket = wrapper
        return self._parent
class TLSSocketWrapper(SocketWrapper):
    """Socket wrapper that optionally wraps the TCP connection in TLS.

    TLS options come from the parent connection's ``ssl`` dict. Recognised
    keys:

    * ``ca_certs`` - CA bundle used to verify the server certificate. When
      omitted, the platform's default trust store is loaded instead.
    * ``certfile`` - client certificate chain for mutual TLS (optional).
    * ``keyfile`` - private key for ``certfile``; may be omitted when the key
      is stored inside ``certfile``.
    * ``check_hostname`` - when true, the server certificate must match the
      host being connected to. Defaults to false to preserve this module's
      historical behaviour.

    An empty (or ``None``) ``ssl`` value means a plain TCP connection.
    """

    def __init__(self, parent, timeout):
        """Connect, optionally negotiate TLS, then run the driver handshake.

        :param parent: the ConnectionInstance creating this wrapper; its
            ``_parent`` supplies ``host``, ``port``, ``ssl`` and ``handshake``.
        :param timeout: seconds for the TCP connect; also used to compute the
            deadline for reading the handshake response.
        :raises ReqlAuthError: the server reported an incorrect auth key.
        :raises ReqlTimeoutError: connecting or the handshake timed out.
        :raises ReqlDriverError: any other connection, TLS or handshake error.
        """
        self.host = parent._parent.host
        self.port = parent._parent.port
        self._read_buffer = None
        self._socket = None
        self.ssl = parent._parent.ssl
        deadline = time.time() + timeout

        try:
            self._socket = socket.create_connection(
                (self.host, self.port), timeout)
            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            if self.ssl:
                ssl_context = self._get_ssl_context(self.ssl.get("ca_certs"))
                try:
                    # When the context has check_hostname enabled, the
                    # certificate is matched against server_hostname during
                    # the handshake itself (no separate match_hostname call).
                    self._socket = ssl_context.wrap_socket(
                        self._socket, server_hostname=self.host)
                except IOError as exc:
                    # wrap_socket failed, so self._socket is still the raw
                    # TCP socket; close it before reporting.
                    self._socket.close()
                    raise ReqlDriverError(
                        "SSL handshake failed: %s" % (str(exc),))

            self.sendall(parent._parent.handshake)

            # The response from the server is a null-terminated string.
            response = b''
            while True:
                char = self.recvall(1, deadline)
                if char == b'\0':
                    break
                response += char
        except ReqlAuthError:
            raise
        except ReqlTimeoutError:
            raise
        except ReqlDriverError:
            # Already a driver error (TLS handshake or send/receive failure):
            # release the socket and propagate it unchanged.
            self.close()
            raise
        except socket.timeout:
            self.close()
            raise ReqlTimeoutError(self.host, self.port)
        except Exception as ex:
            self.close()
            raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                                  (self.host, self.port, ex))

        if response != b"SUCCESS":
            self.close()
            message = decodeUTF(response).strip()
            if message == "ERROR: Incorrect authorization key.":
                raise ReqlAuthError(self.host, self.port)
            raise ReqlDriverError(
                "Server dropped connection with message: \"%s\"" % (message, ))

    def _get_ssl_context(self, ca_certs):
        """Build a client SSLContext from ``self.ssl``.

        :param ca_certs: path to a CA bundle for verifying the server, or
            ``None`` to load the platform's default trust store.
        :returns: an ``ssl.SSLContext`` that requires a verified server
            certificate and negotiates TLS 1.2 or newer.
        """
        if hasattr(ssl, "TLSVersion"):
            # PROTOCOL_TLSv1_2 is deprecated and pins the session to exactly
            # TLS 1.2; auto-negotiate instead, with TLS 1.2 as the floor.
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        else:
            # Older interpreters without TLSVersion: keep the original pin.
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)

        # verify_mode must be CERT_REQUIRED before check_hostname may be True.
        ctx.verify_mode = ssl.CERT_REQUIRED
        # NOTE(security): hostname checking stays off unless requested, to
        # keep the historical behaviour. With it off, any certificate signed
        # by a trusted CA is accepted for any host. Prefer enabling it via
        # ssl={'check_hostname': True, ...}.
        ctx.check_hostname = bool(self.ssl.get("check_hostname", False))

        if ca_certs is None:
            ctx.load_default_certs()
        else:
            ctx.load_verify_locations(ca_certs)

        certfile = self.ssl.get("certfile")
        if certfile:
            # keyfile may be None when the key is bundled inside certfile.
            ctx.load_cert_chain(certfile=certfile,
                                keyfile=self.ssl.get("keyfile"))
        return ctx
# ssl = {ca_certs, keyfile, certfile, check_hostname}
def connect(host='localhost', port=28015, db=None, auth_key="", timeout=20,
            ssl=None, **kwargs):
    """Open a RethinkDB connection that uses ``TLSConnectionInstance``.

    :param host: server host name or address.
    :param port: server driver port.
    :param db: default database, forwarded to ``Connection``.
    :param auth_key: authorization key sent in the handshake.
    :param timeout: seconds allowed for connecting; also passed to
        ``reconnect``.
    :param ssl: dict of TLS options (``ca_certs``, ``certfile``,
        ``keyfile``, ``check_hostname``). ``None`` (the default) is treated
        as an empty dict, i.e. no TLS. Note that this parameter shadows the
        ``ssl`` module inside this function.
    :returns: the result of ``conn.reconnect(timeout=timeout)``, presumably
        the connected ``Connection``.
    """
    # Build a fresh dict per call; the former ``ssl=dict()`` default was one
    # object shared by every call and stored on every Connection.
    if ssl is None:
        ssl = {}
    conn = Connection(TLSConnectionInstance, host, port, db, auth_key,
                      timeout, ssl, **kwargs)
    return conn.reconnect(timeout=timeout)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment