Skip to content

Instantly share code, notes, and snippets.

@Szer
Created May 5, 2023 13:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Szer/fda52c4afaaca9c3a4aee35b7870d2ed to your computer and use it in GitHub Desktop.
Save Szer/fda52c4afaaca9c3a4aee35b7870d2ed to your computer and use it in GitHub Desktop.
SAML IDP with custom certificate
package com.thriveglobal.identity.keycloak.idp
import com.thriveglobal.identity.keycloak.utils.Utils.logger
import java.security.PrivateKey
import java.security.PublicKey
import java.util.*
import java.util.stream.Stream
import org.keycloak.broker.saml.SAMLIdentityProvider
import org.keycloak.broker.saml.SAMLIdentityProviderConfig
import org.keycloak.broker.saml.SAMLIdentityProviderFactory
import org.keycloak.crypto.KeyUse
import org.keycloak.crypto.KeyWrapper
import org.keycloak.keys.KeyProvider
import org.keycloak.models.IdentityProviderModel
import org.keycloak.models.KeyManager
import org.keycloak.models.KeycloakSession
import org.keycloak.models.RealmModel
/**
 * [KeyManager] that answers RS256 signing-key lookups with a fixed [keyToBeUsed] and delegates everything else.
 *
 * [getActiveKey] and [getKeysStream] substitute [keyToBeUsed] for exactly the same `(KeyUse.SIG, "RS256")` request.
 * The realm argument is ignored. Because both overrides use one predicate, the active signing key and the returned
 * signing-key stream cannot disagree.
 * [getActiveRsaKey] always returns [keyToBeUsed].
 * All other [KeyManager] methods are forwarded, via Kotlin interface delegation, to the key manager obtained from
 * `wrappedSession.keys()` when this object is constructed.
 *
 * @param wrappedSession session whose key manager serves every non-overridden lookup
 * @param keyToBeUsed key substituted for RS256 signing lookups. [getActiveRsaKey] casts its private/public keys to
 *   [PrivateKey]/[PublicKey] and reads its certificate, so it is expected to be a full RSA key.
 */
class KeyManagerWithKeyOverride(
    private val wrappedSession: KeycloakSession,
    private val keyToBeUsed: KeyWrapper,
) : KeyManager by wrappedSession.keys() {

    override fun getActiveKey(realm: RealmModel?, use: KeyUse?, algorithm: String?): KeyWrapper =
        if (isOverridden(use, algorithm)) {
            keyToBeUsed
        } else {
            wrappedSession.keys().getActiveKey(realm, use, algorithm)
        }

    @Deprecated("Deprecated in Java")
    override fun getActiveRsaKey(realm: RealmModel?): KeyManager.ActiveRsaKey =
        KeyManager.ActiveRsaKey(
            keyToBeUsed.kid,
            keyToBeUsed.privateKey as PrivateKey,
            keyToBeUsed.publicKey as PublicKey,
            keyToBeUsed.certificate,
        )

    override fun getKeysStream(realm: RealmModel?, use: KeyUse?, algorithm: String?): Stream<KeyWrapper> =
        if (isOverridden(use, algorithm)) {
            Stream.of(keyToBeUsed)
        } else {
            wrappedSession.keys().getKeysStream(realm, use, algorithm)
        }

    /**
     * True when a lookup for [use]/[algorithm] should be answered with [keyToBeUsed].
     *
     * This matches the fixed request pair rather than `keyToBeUsed.use` / `keyToBeUsed.algorithm`. A key whose
     * `algorithm` field is unset would otherwise never match in [getKeysStream] while still matching in
     * [getActiveKey].
     * NOTE(review): confirm whether RSA keys may carry a null algorithm in the Keycloak version in use.
     */
    private fun isOverridden(use: KeyUse?, algorithm: String?): Boolean =
        use == KeyUse.SIG && algorithm == "RS256"
}
/**
 * [KeycloakSession] wrapper whose [keys] returns a [KeyManagerWithKeyOverride] backed by [keyToBeUsed].
 *
 * Every other call is forwarded to [session] via Kotlin interface delegation. Delegation is not virtual dispatch:
 * code running inside [session] that calls its own `keys()` reaches the original key manager, not this override.
 *
 * @param session the real session all non-overridden calls are forwarded to
 * @param keyToBeUsed key handed to the overriding key manager
 */
class SessionWithKeyOverride(
    private val session: KeycloakSession,
    private val keyToBeUsed: KeyWrapper,
) : KeycloakSession by session {

    // Built once on first use and reused, instead of allocating a new wrapper (and re-resolving
    // the delegate key manager) on every keys() call.
    private val overriddenKeys: KeyManagerWithKeyOverride by lazy {
        KeyManagerWithKeyOverride(session, keyToBeUsed)
    }

    override fun keys(): KeyManagerWithKeyOverride = overriddenKeys
}
/**
 * SAML identity-provider config with an optional [kid] that selects a realm key for signing.
 *
 * The kid is stored in the provider's config map under [CUSTOM_SIGN_CERTIFICATE_KID]. A missing or blank value
 * means "no custom key".
 */
class SAMLIdentityProviderWithCertConfig : SAMLIdentityProviderConfig {

    constructor(model: IdentityProviderModel?) : super(model)

    constructor()

    /**
     * Kid of the realm key to sign with, or `null` when none is configured.
     *
     * Blank values are reported as `null`, so a cleared field behaves like an unset one instead of naming a kid
     * that can never match.
     */
    var kid: String?
        get() = config[CUSTOM_SIGN_CERTIFICATE_KID]?.takeIf { it.isNotBlank() }
        set(value) {
            config[CUSTOM_SIGN_CERTIFICATE_KID] = value
        }

    /**
     * Runs the standard SAML config validation. If a [kid] is configured, it then checks that some [KeyProvider]
     * component of [realm] exposes a key with that id.
     *
     * @throws IllegalArgumentException if a kid is configured but no key-provider component matches it
     */
    override fun validate(realm: RealmModel) {
        super.validate(realm)
        val expectedKid = kid ?: return
        val keyExists = realm
            .getComponentsStream(realm.id, KeyProvider::class.java.name)
            .filter(Objects::nonNull)
            .anyMatch { component ->
                // Prefer the KeyWrapper stored as a component note.
                // Fall back to the component's own "kid" config entry.
                val keyWrapper = component.getNote<KeyWrapper>(KeyWrapper::class.java.name)
                val componentKid = if (keyWrapper != null) keyWrapper.kid else component.config.getFirst("kid")
                componentKid == expectedKid
            }
        require(keyExists) { "Key with such ID does not exist in the realm!" }
    }

    companion object {
        /** Config-map key under which the custom signing key id is stored. */
        const val CUSTOM_SIGN_CERTIFICATE_KID = "kid"
    }
}
/**
 * Factory for the `saml-with-cert` identity provider.
 *
 * When a provider's config names a kid, the session passed on to [SAMLIdentityProviderFactory.create] is replaced
 * by a [SessionWithKeyOverride] carrying the realm's RS256 signing key with that kid.
 */
class SAMLIdentityProviderWithCertFactory : SAMLIdentityProviderFactory() {

    override fun createConfig(): SAMLIdentityProviderConfig = SAMLIdentityProviderWithCertConfig()

    /** Builds the SAML provider, swapping in a key-overriding session when a custom kid is configured. */
    override fun create(session: KeycloakSession, model: IdentityProviderModel?): SAMLIdentityProvider {
        val certConfig = SAMLIdentityProviderWithCertConfig(model)
        val effectiveSession = wrapSession(session, certConfig)
        return super.create(effectiveSession, model)
    }

    override fun getId() = PROVIDER_ID

    override fun getName() = "SAML v2.0 with custom certificate"

    companion object {
        const val PROVIDER_ID = "saml-with-cert"
        private val logger = logger()

        /**
         * Returns [originalSession] unchanged when [config] has no kid. Otherwise returns a
         * [SessionWithKeyOverride] around it, using the realm's RS256 signing key with that kid.
         *
         * @throws IllegalArgumentException if a kid is configured but the realm has no matching RS256 signing key
         */
        private fun wrapSession(
            originalSession: KeycloakSession,
            config: SAMLIdentityProviderWithCertConfig,
        ): KeycloakSession {
            val realm = originalSession.context.realm
            val customKid = config.kid ?: return originalSession
            // A kid was requested explicitly, so a missing key is an error rather than a silent fallback.
            val signingKey = originalSession.keys().getKey(realm, customKid, KeyUse.SIG, "RS256")
                ?: throw IllegalArgumentException("Key $customKid does not exist in the realm!")
            logger.debug("Key with kid $customKid is found for SAMLIdentityProviderWithCert ${config.alias}. Using wrapped session")
            return SessionWithKeyOverride(originalSession, signingKey)
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment