Skip to content

Instantly share code, notes, and snippets.

@zbyte64
Created March 10, 2016 22:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zbyte64/2378034a4ecfca1d71d8 to your computer and use it in GitHub Desktop.
Save zbyte64/2378034a4ecfca1d71d8 to your computer and use it in GitHub Desktop.
TLS authentication with the RethinkDB Python client
from rethinkdb.net import ConnectionInstance, SocketWrapper, Connection, decodeUTF
from rethinkdb.errors import *
import socket
import time
import ssl
class TLSConnectionInstance(ConnectionInstance):
    """ConnectionInstance variant whose socket is a TLSSocketWrapper."""

    def connect(self, timeout):
        """Open the (optionally TLS-wrapped) socket and hand back the owner.

        :param timeout: forwarded to TLSSocketWrapper, which uses it both for
            the TCP connect and as the deadline for the handshake reply.
        :returns: ``self._parent`` (the object that created this instance).
        """
        wrapper = TLSSocketWrapper(self, timeout)
        self._socket = wrapper
        return self._parent
class TLSSocketWrapper(SocketWrapper):
    """Socket wrapper that optionally secures the RethinkDB connection with TLS.

    Re-implements the socket setup instead of calling
    ``SocketWrapper.__init__`` (presumably because the base initializer opens
    its own plain connection), so the TCP socket can be wrapped with an
    ``ssl.SSLContext`` before the RethinkDB handshake is sent.

    TLS options are read from the owning connection's ``ssl`` attribute.
    A falsy value (empty dict or ``None``) means plain TCP. Recognised keys:

    * ``ca_certs`` (required when TLS is enabled): CA bundle used to verify
      the server certificate.
    * ``certfile`` and ``keyfile``: client certificate and private key. A
      client certificate is only loaded when both keys are present.
    * ``check_hostname`` (optional, default ``True``): set to ``False`` to
      skip matching the server certificate against ``host``.
    """

    def __init__(self, parent, timeout):
        """Connect, perform TLS (if configured) and the RethinkDB handshake.

        :param parent: the connection instance creating this socket; host,
            port, ssl options and handshake bytes are read from its
            ``_parent``.
        :param timeout: seconds for the TCP connect; also used to compute the
            deadline for reading the handshake reply.
        :raises ReqlAuthError: the server rejected the authorization key.
        :raises ReqlTimeoutError: connecting or the handshake timed out.
        :raises ReqlDriverError: any other connection, TLS or handshake error.
        """
        connection = parent._parent
        self.host = connection.host
        self.port = connection.port
        self._read_buffer = None
        self._socket = None
        self.ssl = connection.ssl
        deadline = time.time() + timeout
        try:
            self._socket = \
                socket.create_connection((self.host, self.port), timeout)
            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Truthiness (not len() > 0) so that ssl=None also means plain TCP.
            if self.ssl:
                self._wrap_tls()
            self.sendall(connection.handshake)
            response = self._read_handshake_response(deadline)
        except ReqlAuthError:
            raise
        except ReqlTimeoutError:
            raise
        except ReqlDriverError as ex:
            self.close()
            error = str(ex)\
                .replace('receiving from', 'during handshake with')\
                .replace('sending to', 'during handshake with')
            # Raise with the handshake-specific wording computed above; a bare
            # `raise` here would silently discard it.
            raise ReqlDriverError(error)
        except socket.timeout:
            self.close()
            raise ReqlTimeoutError(self.host, self.port)
        except Exception as ex:
            self.close()
            raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                                  (self.host, self.port, ex))

        if response != b"SUCCESS":
            self.close()
            message = decodeUTF(response).strip()
            if message == "ERROR: Incorrect authorization key.":
                raise ReqlAuthError(self.host, self.port)
            raise ReqlDriverError("Server dropped connection with message: \"%s\"" %
                                  (message, ))

    def _wrap_tls(self):
        """Replace the plain TCP socket with a certificate-verified TLS socket.

        :raises KeyError: ``ca_certs`` is missing from the ssl options.
        :raises ReqlDriverError: the TLS handshake failed (IOError/SSLError).
        """
        ssl_context = self._get_ssl_context(self.ssl["ca_certs"])
        try:
            # server_hostname enables SNI and is the name check_hostname
            # verifies the server certificate against.
            self._socket = ssl_context.wrap_socket(self._socket,
                                                   server_hostname=self.host)
        except IOError as exc:
            self._socket.close()
            raise ReqlDriverError("SSL handshake failed: %s" % (str(exc),))

    def _read_handshake_response(self, deadline):
        """Read the server's null-terminated handshake reply, minus the NUL."""
        chunks = []
        while True:
            char = self.recvall(1, deadline)
            if char == b'\0':
                return b''.join(chunks)
            chunks.append(char)

    def _get_ssl_context(self, ca_certs):
        """Build the client-side SSL context used to wrap the socket.

        :param ca_certs: path to the CA bundle that must have signed the
            server certificate.
        :returns: an ``ssl.SSLContext`` pinned to TLS 1.2 that requires a
            valid server certificate and, unless ``self.ssl['check_hostname']``
            is false, one that matches the host given to ``wrap_socket``.
        """
        # NOTE(review): PROTOCOL_TLSv1_2 is deprecated since Python 3.10;
        # consider PROTOCOL_TLS_CLIENT with minimum_version once the supported
        # Python range allows it.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        ctx.verify_mode = ssl.CERT_REQUIRED
        # Without hostname checking, any certificate signed by ca_certs would
        # be accepted for any server, which defeats TLS server authentication.
        ctx.check_hostname = bool(self.ssl.get("check_hostname", True))
        ctx.load_verify_locations(ca_certs)
        if 'certfile' in self.ssl and 'keyfile' in self.ssl:
            ctx.load_cert_chain(certfile=self.ssl['certfile'],
                                keyfile=self.ssl['keyfile'])
        return ctx
def connect(host='localhost', port=28015, db=None, auth_key="", timeout=20, ssl=None, **kwargs):
    """Open a RethinkDB connection that uses TLSConnectionInstance for its socket.

    :param host: server host name (also used for TLS server_hostname).
    :param port: server driver port.
    :param db: default database, passed through to ``Connection``.
    :param auth_key: authorization key, passed through to ``Connection``.
    :param timeout: connect timeout in seconds; also forwarded to
        ``reconnect``.
    :param ssl: TLS options dict of the form
        ``{'ca_certs': ..., 'certfile': ..., 'keyfile': ...}``. ``None`` (the
        default) is replaced by a fresh empty dict, which means plain TCP.
    :param kwargs: extra keyword arguments forwarded to ``Connection``.
    :returns: the value of ``conn.reconnect(timeout=timeout)``; presumably
        the connected ``Connection`` (TLSConnectionInstance.connect returns
        its parent) -- verify against the driver version in use.
    """
    if ssl is None:
        # A fresh dict per call: a shared mutable default would be stored on
        # every Connection and leak any mutation across connections.
        ssl = {}
    conn = Connection(TLSConnectionInstance, host, port, db, auth_key, timeout, ssl, **kwargs)
    return conn.reconnect(timeout=timeout)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment