Skip to content

Instantly share code, notes, and snippets.

@Hornswoggles
Created December 10, 2018 21:05
Show Gist options
  • Save Hornswoggles/a22726720882567631da49b7cd5cc5cb to your computer and use it in GitHub Desktop.
Save Hornswoggles/a22726720882567631da49b7cd5cc5cb to your computer and use it in GitHub Desktop.
Hot reload a cert on update
# pantheon/ssl_context_factory.py
import logging
import os
import time

from OpenSSL import SSL
from twisted.internet import ssl
class ReloadingSSLContextFactory(ssl.DefaultOpenSSLContextFactory, object):
    """
    L{ReloadingSSLContextFactory} is a factory for server-side SSL context
    objects. These objects define certain parameters related to SSL
    handshakes and the subsequent connection.

    This class behaves like ssl.DefaultOpenSSLContextFactory, with the
    addition that the certificate and private key are reloaded ("hot
    reload") when the certificate file on disk is modified.

    A new certificate/key pair is first loaded into a scratch context and
    checked for consistency. It is applied to the live context only if that
    succeeds; otherwise the previously loaded pair keeps being served and the
    reload is retried at the next check. This makes non-atomic rotations safe,
    e.g. the certificate written before its key, or the file briefly missing
    during a replace.
    """

    # Minimum number of seconds between two stat() calls on the certificate.
    _check_interval = 10
    # Value of self._clock() when the certificate mtime was last examined.
    _last_check = None
    # mtime of the certificate file as of the last successful (re)load.
    _last_reload = None
    # Monotonic clock where available, so a backwards wall-clock step cannot
    # suspend the periodic check; falls back to time.time() on Python 2.
    _clock = staticmethod(getattr(time, "monotonic", time.time))
    _logger = logging.getLogger(__name__)

    def cacheContext(self):
        """
        Record the certificate's mtime, then build the initial context.

        Called by ssl.DefaultOpenSSLContextFactory.__init__() to perform the
        initial certificate and key loading. The mtime is read *before* the
        files are loaded. A modification racing with this load therefore
        shows up as a changed mtime at the next check instead of being missed.
        """
        self._last_reload = os.path.getmtime(self.certificateFileName)
        self._last_check = self._clock()
        super(ReloadingSSLContextFactory, self).cacheContext()

    def _load_key_pair(self, context):
        """
        Load self.certificateFileName and self.privateKeyFileName into
        *context* and verify that the key matches the certificate.

        Raises SSL.Error if either file cannot be loaded or the pair does
        not match.
        """
        context.use_certificate_file(self.certificateFileName)
        context.use_privatekey_file(self.privateKeyFileName)
        context.check_privatekey()

    def _reload(self):
        """
        Reload the certificate and private key if the certificate changed.

        To avoid excessive stat() calls, self.certificateFileName is
        examined at most once every self._check_interval seconds. If its
        mtime differs from the one recorded at the last successful load, the
        new pair is validated on a throwaway context and then loaded into
        the live context.

        Failures are logged and the reload is abandoned: a missing file, an
        unparsable file, or a key that does not match the certificate.
        Because self._last_reload is not advanced on failure, the reload is
        retried at the next interval.
        """
        now = self._clock()
        if now < self._last_check + self._check_interval:
            return
        self._last_check = now

        try:
            mtime = os.path.getmtime(self.certificateFileName)
        except (OSError, IOError) as e:
            # The file may be mid-replace (removed before the new one is
            # moved into place); keep serving the current pair.
            self._logger.warning(
                "Cannot stat certificate %s, keeping current one: %s",
                self.certificateFileName, e)
            return

        # '!=' rather than '>': a replacement that carries an older mtime
        # (restored backup, copy with preserved timestamps) is still a change.
        if mtime == self._last_reload:
            return

        try:
            # Validate on a scratch context first so that a half-written or
            # mismatched pair never reaches the live context.
            self._load_key_pair(SSL.Context(self.sslmethod))
            # The files were just validated; applying them to the live
            # context can only fail if they change again in between.
            self._load_key_pair(self._context)
        except (SSL.Error, OSError, IOError) as e:
            self._logger.warning(
                "Failed to reload certificate %s / key %s, will retry: %s",
                self.certificateFileName, self.privateKeyFileName, e)
            return

        self._last_reload = mtime

    def getContext(self):
        """
        Return the SSL context. Before returning it, reload the certificate
        and key from disk if they have changed (see _reload).
        """
        self._reload()
        return self._context
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment