Skip to content

Instantly share code, notes, and snippets.

@hamnis
Created May 9, 2022 19:47
Show Gist options
  • Save hamnis/3a2a30d136f7016b71ca16a5e443a179 to your computer and use it in GitHub Desktop.
Save hamnis/3a2a30d136f7016b71ca16a5e443a179 to your computer and use it in GitHub Desktop.
package reloadable
import cats.effect._
import cats.syntax.all._
import org.typelevel.log4cats.LoggerFactory
import java.io.ByteArrayInputStream
import java.net.Socket
import java.nio.charset.StandardCharsets
import java.security.KeyStore
import java.security.cert.{CertificateFactory, X509Certificate}
import java.util.UUID
import java.util.concurrent.atomic.AtomicReference
import javax.net.ssl.{SSLContext, SSLEngine, TrustManager, TrustManagerFactory, X509ExtendedTrustManager}
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.util.control.NonFatal
/** An [[X509ExtendedTrustManager]] that first consults an optional, swappable custom
  * trust manager and, if that one rejects the chain (or none is installed), delegates
  * to `defaultTrustManager`.
  *
  * The custom manager is read from `trustManager` on every call, so replacing the
  * reference's contents takes effect for the next check without rebuilding the
  * owning SSLContext.
  *
  * @param defaultTrustManager manager used when no custom manager is installed or the
  *                            custom one throws
  * @param trustManager        holder for the current custom manager; `None` means
  *                            "use only the default manager"
  */
final class ReloadableX509TrustManager private (
    defaultTrustManager: X509ExtendedTrustManager,
    trustManager: AtomicReference[Option[X509ExtendedTrustManager]]
) extends X509ExtendedTrustManager {
  private[this] val logger = org.log4s.getLogger

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String): Unit =
    runOp(_.checkClientTrusted(chain, authType), "no client trust cert found")

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String, socket: Socket): Unit =
    runOp(_.checkClientTrusted(chain, authType, socket), "no client trust cert found")

  override def checkClientTrusted(chain: Array[X509Certificate], authType: String, engine: SSLEngine): Unit =
    runOp(_.checkClientTrusted(chain, authType, engine), "no client trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String): Unit =
    runOp(_.checkServerTrusted(chain, authType), "no server trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String, socket: Socket): Unit =
    runOp(_.checkServerTrusted(chain, authType, socket), "no server trust cert found")

  override def checkServerTrusted(chain: Array[X509Certificate], authType: String, engine: SSLEngine): Unit =
    runOp(_.checkServerTrusted(chain, authType, engine), "no server trust cert found")

  /** Issuers accepted by the custom manager (when installed) followed by the default's. */
  override def getAcceptedIssuers: Array[X509Certificate] =
    trustManager.get() match {
      case Some(custom) => custom.getAcceptedIssuers ++ defaultTrustManager.getAcceptedIssuers
      case None         => defaultTrustManager.getAcceptedIssuers
    }

  /** Runs the check `f` against the custom manager, falling back to the default one
    * if the custom manager throws.
    *
    * If the default manager also throws, its exception propagates with the custom
    * manager's exception attached as suppressed, so the caller sees why BOTH managers
    * rejected the chain instead of losing the first reason.
    *
    * @param f       the trust check to perform
    * @param onError message prefix logged when the custom manager rejects
    */
  private def runOp(f: X509ExtendedTrustManager => Unit, onError: String): Unit =
    trustManager.get() match {
      case None => f(defaultTrustManager)
      case Some(custom) =>
        try f(custom)
        catch {
          case customFailure: Exception =>
            logger.warn(customFailure)(s"$onError, trying default trust manager")
            try f(defaultTrustManager)
            catch {
              case defaultFailure: Exception =>
                // addSuppressed rejects self-suppression; guard against a shared instance.
                if (defaultFailure ne customFailure) defaultFailure.addSuppressed(customFailure)
                throw defaultFailure
            }
        }
    }
}
object ReloadableX509TrustManager {

  /** Builds an [[SSLContext]] whose trust manager consults a periodically reloaded set of
    * custom certificates before the JVM default trust store.
    *
    * The background reload fiber is cancelled when the resource is released.
    *
    * @param certs    effect yielding the current custom certificates, one encoded
    *                 certificate (or bundle) per string — presumably PEM text; confirm
    *                 with callers
    * @param duration delay between successive polls of `certs`
    */
  def SSLContextForResource[F[_]](certs: F[List[String]], duration: FiniteDuration = 1.minute)(implicit
      A: Async[F],
      S: Spawn[F],
      loggerFactory: LoggerFactory[F]
  ): Resource[F, SSLContext] =
    Resource.make(SSLContextFor(certs, duration))(_._2.cancel).map(_._1)

  /** Builds an [[SSLContext]] backed by a [[ReloadableX509TrustManager]] and starts a
    * fiber that re-reads `certs` every `duration` and swaps the custom trust manager
    * when the list changes.
    *
    * The initial `certs` evaluation happens before the context is returned; a failure
    * there fails this effect. Later poll failures are logged and the loop keeps running.
    *
    * @return the context and the running reload fiber; the caller must cancel the fiber
    *         (see [[SSLContextForResource]] for a managed variant)
    */
  def SSLContextFor[F[_]](certs: F[List[String]], duration: FiniteDuration)(implicit
      A: Async[F],
      S: Spawn[F],
      loggerFactory: LoggerFactory[F]
  ): F[(SSLContext, Fiber[F, Throwable, Unit])] = {
    // Shared with the ReloadableX509TrustManager built below; None = default trust only.
    val trustManager = new AtomicReference[Option[X509ExtendedTrustManager]](None)

    // Records `newCerts` as the current list, then installs (or clears) the custom
    // trust manager. Failures are logged and swallowed, leaving the previously
    // installed manager in place.
    def make(newCerts: List[String], currentCerts: Ref[F, List[String]]): F[Unit] = {
      val install: F[Unit] =
        if (newCerts.isEmpty)
          loggerFactory.getLogger.debug("No custom certificates, disabling custom trust manager") *>
            Sync[F].delay(trustManager.set(None))
        else
          makeTrustManager[F](newCerts).flatMap(tm => Sync[F].delay(trustManager.set(Some(tm))))

      (currentCerts.set(newCerts) *> install).recoverWith { case NonFatal(e) =>
        loggerFactory.getLogger.warn(e)("Exception raised while creating trustmanager")
      }
    }

    // A single poll: fetch the list and rebuild only when it differs from the current one.
    // Errors are logged and swallowed so one failed poll never stops the loop.
    def reloadOnce(currentCerts: Ref[F, List[String]]): F[Unit] =
      (for {
        current  <- currentCerts.get
        newCerts <- certs
        _        <- if (current != newCerts) make(newCerts, currentCerts) else A.unit
      } yield ()).recoverWith { case NonFatal(e) =>
        loggerFactory.getLogger.warn(e)("Exception raised while reloading")
      }

    // Poll forever. foreverM loops via tailRecM, so no frames accumulate per iteration
    // (unlike recursing inside a for-comprehension wrapped in an error handler).
    def reload(currentCerts: Ref[F, List[String]]): F[Unit] =
      (A.sleep(duration) *> reloadOnce(currentCerts)).foreverM[Unit]

    for {
      ref <- Ref.of[F, List[String]](List.empty[String])
      _   <- certs.flatMap(newCerts => make(newCerts, ref))
      // A null KeyStore makes TrustManagerFactory use the JVM default trust store.
      default <- getTrustManager[F](null)
      ctx <- Sync[F].blocking {
        val context = SSLContext.getInstance("TLS")
        context.init(
          null,
          Array[TrustManager](new ReloadableX509TrustManager(default, trustManager)),
          null
        )
        context
      }
      // Started last: if anything above fails, no orphaned reload fiber is left running.
      fiber <- S.start(reload(ref))
    } yield (ctx, fiber)
  }

  /** Builds a trust manager that trusts exactly the certificates in `additionalCerts`. */
  private def makeTrustManager[F[_]: Sync: LoggerFactory](
      additionalCerts: List[String]
  ): F[X509ExtendedTrustManager] =
    for {
      _ <- LoggerFactory
        .getLogger[F]
        .info(s"Making new trustmanager from ${additionalCerts.size} custom certs")
      ks      <- keystoreFor[F](additionalCerts)
      manager <- getTrustManager[F](ks)
    } yield manager

  /** Initializes a TrustManagerFactory (platform default algorithm) with `keyStore` and
    * returns its X509 trust manager. Passing `null` selects the JVM default trust store.
    *
    * @throws IllegalStateException (in `F`) if the factory yields no X509 trust manager
    */
  private def getTrustManager[F[_]: Sync](keyStore: KeyStore): F[X509ExtendedTrustManager] =
    Sync[F].blocking {
      val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
      tmf.init(keyStore)
      tmf.getTrustManagers
        .collectFirst { case x: X509ExtendedTrustManager => x }
        .getOrElse(throw new IllegalStateException("No X509TrustManager in TrustManagerFactory"))
    }

  /** Builds an in-memory KeyStore holding every certificate found in `certificates`.
    *
    * Each string may contain several certificates (e.g. a PEM bundle); all of them are
    * added under random aliases.
    *
    * @throws java.security.cert.CertificateException (in `F`) if a string cannot be
    *         parsed or contains no certificate
    */
  private def keystoreFor[F[_]: Sync](certificates: List[String]): F[KeyStore] =
    Sync[F].blocking {
      val ks = KeyStore.getInstance(KeyStore.getDefaultType)
      ks.load(null) // initializes an empty keystore
      val cf = CertificateFactory.getInstance("X.509")
      certificates.foreach { encoded =>
        // generateCertificates (plural) keeps every certificate in a bundle;
        // generateCertificate would silently keep only the first one.
        val parsed =
          cf.generateCertificates(new ByteArrayInputStream(encoded.getBytes(StandardCharsets.UTF_8)))
        if (parsed.isEmpty)
          throw new java.security.cert.CertificateException(
            "No X.509 certificate found in custom certificate input"
          )
        val it = parsed.iterator()
        while (it.hasNext) ks.setCertificateEntry(UUID.randomUUID.toString, it.next())
      }
      ks
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment