Created
May 5, 2023 13:45
-
-
Save Szer/fda52c4afaaca9c3a4aee35b7870d2ed to your computer and use it in GitHub Desktop.
SAML IDP with custom certificate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.thriveglobal.identity.keycloak.idp | |
import com.thriveglobal.identity.keycloak.utils.Utils.logger | |
import java.security.PrivateKey | |
import java.security.PublicKey | |
import java.util.* | |
import java.util.stream.Stream | |
import org.keycloak.broker.saml.SAMLIdentityProvider | |
import org.keycloak.broker.saml.SAMLIdentityProviderConfig | |
import org.keycloak.broker.saml.SAMLIdentityProviderFactory | |
import org.keycloak.crypto.KeyUse | |
import org.keycloak.crypto.KeyWrapper | |
import org.keycloak.keys.KeyProvider | |
import org.keycloak.models.IdentityProviderModel | |
import org.keycloak.models.KeyManager | |
import org.keycloak.models.KeycloakSession | |
import org.keycloak.models.RealmModel | |
/**
 * [KeyManager] that exposes a fixed [keyToBeUsed] for one (use, algorithm) pair and delegates
 * every other call to the key manager obtained from `wrappedSession.keys()` at construction time.
 *
 * Both [getActiveKey] and [getKeysStream] intercept on the *same* pair
 * ([overriddenUse], [overriddenAlgorithm]), so the "active" key and the "available" keys
 * reported for that pair always agree.
 *
 * @param wrappedSession session whose own key manager handles all non-overridden lookups.
 * @param keyToBeUsed key returned for the overridden (use, algorithm) pair.
 * @param overriddenUse key use that is intercepted; defaults to [KeyUse.SIG].
 * @param overriddenAlgorithm algorithm that is intercepted; defaults to `"RS256"`.
 */
class KeyManagerWithKeyOverride @JvmOverloads constructor(
    private val wrappedSession: KeycloakSession,
    private val keyToBeUsed: KeyWrapper,
    private val overriddenUse: KeyUse = KeyUse.SIG,
    private val overriddenAlgorithm: String = "RS256",
) : KeyManager by wrappedSession.keys() {

    /** True when a lookup for ([use], [algorithm]) should be answered with [keyToBeUsed]. */
    private fun isOverridden(use: KeyUse?, algorithm: String?): Boolean =
        use == overriddenUse && algorithm == overriddenAlgorithm

    /** Returns [keyToBeUsed] for the overridden pair, otherwise the wrapped manager's active key. */
    override fun getActiveKey(realm: RealmModel?, use: KeyUse?, algorithm: String?): KeyWrapper =
        if (isOverridden(use, algorithm)) {
            keyToBeUsed
        } else {
            wrappedSession.keys().getActiveKey(realm, use, algorithm)
        }

    /**
     * Legacy RSA accessor: always answers with [keyToBeUsed], whatever the realm.
     *
     * The casts throw if [keyToBeUsed] does not carry a private/public key of those types.
     */
    @Deprecated("Deprecated in Java")
    override fun getActiveRsaKey(realm: RealmModel?): KeyManager.ActiveRsaKey =
        KeyManager.ActiveRsaKey(
            keyToBeUsed.kid,
            keyToBeUsed.privateKey as PrivateKey,
            keyToBeUsed.publicKey as PublicKey,
            keyToBeUsed.certificate,
        )

    /** Returns only [keyToBeUsed] for the overridden pair, otherwise the wrapped manager's keys. */
    override fun getKeysStream(realm: RealmModel?, use: KeyUse?, algorithm: String?): Stream<KeyWrapper> =
        if (isOverridden(use, algorithm)) {
            Stream.of(keyToBeUsed)
        } else {
            wrappedSession.keys().getKeysStream(realm, use, algorithm)
        }
}
/**
 * [KeycloakSession] wrapper whose [keys] returns a [KeyManagerWithKeyOverride] that exposes
 * [keyToBeUsed]; every other session method is delegated to [session].
 *
 * @param session the real session being wrapped.
 * @param keyToBeUsed key handed to the overriding key manager.
 */
class SessionWithKeyOverride(
    private val session: KeycloakSession,
    private val keyToBeUsed: KeyWrapper,
) : KeycloakSession by session {

    // Built on first access and then reused, so repeated keys() calls return the same manager
    // instead of allocating a new wrapper each time.
    private val keyManager: KeyManagerWithKeyOverride by lazy {
        KeyManagerWithKeyOverride(session, keyToBeUsed)
    }

    override fun keys(): KeyManagerWithKeyOverride = keyManager
}
/**
 * SAML identity-provider configuration with one extra, optional setting: [kid], the ID of a
 * realm key to use instead of the realm's active key.
 *
 * The value is stored in the provider's config map under [CUSTOM_SIGN_CERTIFICATE_KID].
 */
class SAMLIdentityProviderWithCertConfig : SAMLIdentityProviderConfig {
    constructor(model: IdentityProviderModel?) : super(model)
    constructor()

    /**
     * Key ID of the custom key, or `null` when none is configured.
     *
     * A blank stored value is reported as `null`, so an emptied form field means "no custom
     * key" rather than "look up a key whose ID is the empty string".
     */
    var kid: String?
        get() = config[CUSTOM_SIGN_CERTIFICATE_KID]?.takeIf { it.isNotBlank() }
        set(value) {
            config[CUSTOM_SIGN_CERTIFICATE_KID] = value
        }

    /**
     * Runs the standard SAML validation, then, if [kid] is set, checks that some key-provider
     * component of [realm] reports that kid.
     *
     * The kid of a component is taken from its [KeyWrapper] note when present, otherwise from
     * its `kid` config entry.
     *
     * @throws IllegalArgumentException if [kid] is set but no matching key is found.
     */
    override fun validate(realm: RealmModel) {
        super.validate(realm)
        val requestedKid = kid ?: return

        val keyExists = realm
            .getComponentsStream(realm.id, KeyProvider::class.java.name)
            .filter { it != null }
            .map { component ->
                component.getNote<KeyWrapper>(KeyWrapper::class.java.name)?.kid
                    ?: component.config.getFirst("kid")
            }
            .anyMatch { it == requestedKid }

        require(keyExists) { "Key with such ID does not exist in the realm!" }
    }

    companion object {
        /** Config-map key under which the custom key ID is stored. */
        const val CUSTOM_SIGN_CERTIFICATE_KID = "kid"
    }
}
/**
 * Identity-provider factory registered under [PROVIDER_ID].
 *
 * It delegates to [SAMLIdentityProviderFactory], with two differences:
 * - it creates [SAMLIdentityProviderWithCertConfig] configs;
 * - when a kid is configured, the provider is built with a session that overrides the RS256
 *   signing key.
 */
class SAMLIdentityProviderWithCertFactory : SAMLIdentityProviderFactory() {

    /** Returns the extended config type instead of the plain SAML config. */
    override fun createConfig(): SAMLIdentityProviderConfig = SAMLIdentityProviderWithCertConfig()

    /**
     * Builds the provider through the parent factory, passing it the session returned by
     * [wrapSession].
     *
     * @throws IllegalArgumentException if a kid is configured but no SIG/RS256 key with that kid
     *   is found in the session's current realm.
     */
    override fun create(session: KeycloakSession, model: IdentityProviderModel?): SAMLIdentityProvider {
        val certConfig = SAMLIdentityProviderWithCertConfig(model)
        val effectiveSession = wrapSession(session, certConfig)
        return super.create(effectiveSession, model)
    }

    override fun getId(): String = PROVIDER_ID

    override fun getName(): String = "SAML v2.0 with custom certificate"

    companion object {
        const val PROVIDER_ID = "saml-with-cert"

        private val logger = logger()

        /**
         * Returns [originalSession] unchanged when [config] has no kid.
         *
         * Otherwise it returns a [SessionWithKeyOverride] exposing the realm's SIG/RS256 key with
         * that kid.
         *
         * @throws IllegalArgumentException when a kid is set but no such key is found.
         */
        private fun wrapSession(
            originalSession: KeycloakSession,
            config: SAMLIdentityProviderWithCertConfig,
        ): KeycloakSession {
            val currentRealm = originalSession.context.realm
            val customKid = config.kid ?: return originalSession

            // Signing requires an RS256 key with SIG use; anything else is treated as missing.
            val signingKey = originalSession.keys().getKey(currentRealm, customKid, KeyUse.SIG, "RS256")
                ?: throw IllegalArgumentException("Key $customKid does not exist in the realm!")

            logger.debug("Key with kid $customKid is found for SAMLIdentityProviderWithCert ${config.alias}. Using wrapped session")
            return SessionWithKeyOverride(originalSession, signingKey)
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment