Skip to content

Instantly share code, notes, and snippets.

@svaponi
Created January 9, 2024 08:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save svaponi/586a252f19a3189cd3dcc1c23cd5c9ac to your computer and use it in GitHub Desktop.
Save svaponi/586a252f19a3189cd3dcc1c23cd5c9ac to your computer and use it in GitHub Desktop.
package io.github.svaponi.auth0
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.MapperFeature
import com.fasterxml.jackson.databind.json.JsonMapper
import com.sun.net.httpserver.HttpExchange
import com.sun.net.httpserver.HttpHandler
import com.sun.net.httpserver.HttpServer
import mu.KLogging
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import java.io.File
import java.io.FileInputStream
import java.io.IOException
import java.math.BigInteger
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.net.URL
import java.nio.charset.StandardCharsets
import java.security.InvalidKeyException
import java.security.KeyFactory
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import java.security.Signature
import java.security.SignatureException
import java.security.cert.X509Certificate
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import java.time.Duration
import java.time.Instant
import java.util.Base64
import java.util.Date
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import java.util.function.Function
import java.util.regex.Pattern
/**
 * Shared Base64 encoders and JSON mapper used throughout this file.
 */
private object Utils {
    /** Standard Base64 alphabet, padded output. */
    val encoder: Base64.Encoder = Base64.getEncoder()

    /** URL-safe Base64 alphabet, padded output. */
    val urlEncoder: Base64.Encoder = Base64.getUrlEncoder()

    /** Lenient mapper: unknown JSON properties are ignored instead of failing deserialization. */
    val mapper: JsonMapper = JsonMapper.builder()
        .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
        .build()
}
/**
 * Local stand-in for an Auth0 tenant, meant for tests.
 *
 * Routes (matched on path only, any HTTP method):
 * - `/health` returns `{"status":"UP"}`
 * - `/oauth/token` issues client-credentials tokens for the clients declared in the classpath
 *   resource `auth0_server.json`
 * - `/.well-known/jwks.json` publishes the public key that verifies those tokens
 * - `/.well-known/openid-configuration` publishes `issuer` and `jwks_uri`
 *
 * The constructor does not start listening; call [start].
 *
 * @param port port to listen on; values `<= 0` select a free port.
 */
class Auth0Server(port: Int = -1) {
    /** Signs tokens with this server's private key; can be used directly to mint tokens without HTTP. */
    val tokenCreator: TokenCreator
    private val keyPair: KeyPair
    private val serverCert: X509Certificate
    private val kid: String
    private val server: Server

    init {
        val certDir = File(System.getProperty("java.io.tmpdir"), javaClass.getCanonicalName())
        logger.debug("Certificates dir: {}", certDir)
        val certProvider = CertProvider(certDir)
        keyPair = certProvider.keyPair
        serverCert = certProvider.certificate
        kid = UUID.randomUUID().toString()
        // `port` here is the constructor parameter; the public `port` property below reads it back from the server.
        server = Server(port)
            .registerResponseProvider(Pattern.compile("/health")) { buildHealthResponse() }
            .registerResponseProvider(Pattern.compile("/oauth/token")) { buildTokenResponse(it) }
            .registerResponseProvider(Pattern.compile("/.well-known/jwks.json")) { buildJwksResponse() }
            .registerResponseProvider(Pattern.compile("/.well-known/openid-configuration")) { buildOpenIdConfigResponse() }
        // `issuer` reads `server.port`, so this must run after `server` is assigned.
        tokenCreator = TokenCreator(keyPair.private as RSAPrivateKey, kid, issuer)
    }

    /** Starts listening; returns this server for chaining. */
    fun start(): Auth0Server = apply { server.start() }

    /** Stops listening; returns this server for chaining. */
    fun stop(): Auth0Server = apply { server.stop() }

    /** Issuer URL used as the `iss` claim; always ends with `/`. */
    val issuer: String
        get() = "http://localhost:$port/"

    /** Port the underlying server is bound to. */
    val port: Int
        get() = server.port

    private val status: String = "UP"

    private fun buildHealthResponse(): Server.Response = jsonResponse(mapOf("status" to status))

    private val jwks: Map<*, *> by lazy {
        val publicKey = keyPair.public as RSAPublicKey
        // JWK members are base64url WITHOUT padding (RFC 7515 §2 / RFC 7518 §6.3.1).
        val base64Url = Utils.urlEncoder.withoutPadding()
        val certBytes = serverCert.encoded
        val key = mapOf(
            "alg" to "RS256",
            "kty" to "RSA",
            "use" to "sig",
            "n" to base64Url.encodeToString(publicKey.modulus.toUnsignedBytes()),
            "e" to base64Url.encodeToString(publicKey.publicExponent.toUnsignedBytes()),
            "kid" to kid,
            // SHA-1 thumbprint of the DER certificate.
            "x5t" to base64Url.encodeToString(MessageDigest.getInstance("SHA-1").digest(certBytes)),
            // x5c carries standard (not URL-safe) Base64 of the DER certificate.
            "x5c" to listOf(Utils.encoder.encodeToString(certBytes)),
        )
        mapOf("keys" to listOf(key))
    }

    private fun buildJwksResponse(): Server.Response = jsonResponse(jwks)

    private val openIdConfig: Map<*, *> by lazy {
        mapOf(
            "jwks_uri" to "$issuer.well-known/jwks.json",
            "issuer" to issuer,
        )
    }

    private fun buildOpenIdConfigResponse(): Server.Response = jsonResponse(openIdConfig)

    /**
     * Handles a JSON client-credentials token request.
     *
     * Validation failures (bad grant type, missing fields, unknown client or audience, wrong secret, malformed
     * JSON) yield `400`; anything else yields `500`. Both use an `{error, error_description}` body.
     */
    private fun buildTokenResponse(httpExchange: HttpExchange): Server.Response = try {
        // May be null or a MissingNode for an empty body; the field lookups below then report what is missing.
        val requestBody: JsonNode? = Utils.mapper.readTree(httpExchange.requestBody)
        // `require`, not `assert`: Kotlin assertions are no-ops unless the JVM runs with -ea, and AssertionError
        // is not an Exception, so it would escape the catch below.
        require("client_credentials".equals(requestBody.textAt("/grant_type"), ignoreCase = true)) { "Unsupported grant_type" }
        val clientId = requestBody.requireText("/client_id", "Missing client_id")
        val clientSecret = requestBody.requireText("/client_secret", "Missing client_secret")
        val audience = requestBody.requireText("/audience", "Missing audience")
        val clientConfig = ClientConfigProvider.getByClientId(clientId)
        require(clientSecret == clientConfig.clientSecret) { "Invalid credentials" }
        val tokenConfig = clientConfig.audiences[audience]
            ?: throw IllegalArgumentException("Unknown audience '$audience' for client_id '$clientId'")
        // `claims` is optional in the config, so fall back to an empty map.
        val claims: MutableMap<String, Any> = LinkedHashMap(tokenConfig.claims.orEmpty())
        claims["gty"] = "client-credentials" // necessary for machine to machine tokens
        val accessToken = tokenCreator.create()
            // if missing subject, use `{clientId}@clients` to simulate Auth0 behaviour
            .subject(tokenConfig.subject ?: "$clientId@clients")
            .audience(audience)
            .claims(claims)
            .expiresIn(tokenConfig.expiresIn)
            .sign()
        jsonResponse(
            linkedMapOf(
                "access_token" to accessToken,
                "scope" to claims.getOrDefault("scope", ""),
                "expires_in" to tokenConfig.expiresIn,
                "token_type" to "Bearer",
            )
        )
    } catch (e: Exception) {
        val isClientError = e is IllegalArgumentException || e is JsonProcessingException
        jsonResponse(
            linkedMapOf(
                "error" to e.javaClass.getCanonicalName(),
                "error_description" to e.message,
            ),
            status = if (isClientError) 400 else 500,
        )
    }

    /** Serializes [payload] to JSON and wraps it in a response carrying [defaultHeaders]. */
    private fun jsonResponse(payload: Any, status: Int = 200): Server.Response = Server.Response(
        status = status,
        body = Utils.mapper.valueToTree<JsonNode>(payload).toString(),
        headers = defaultHeaders,
    )

    /** Text at JSON [pointer] (e.g. `/client_id`), or null when the body or the field is absent. */
    private fun JsonNode?.textAt(pointer: String): String? = this?.at(pointer)?.asText(null)

    /** Like [textAt], but an absent or blank value fails with [IllegalArgumentException] carrying [message]. */
    private fun JsonNode?.requireText(pointer: String, message: String): String =
        requireNotNull(textAt(pointer)?.takeIf { it.isNotBlank() }) { message }

    /**
     * Big-endian magnitude without the extra sign byte that [BigInteger.toByteArray] prepends when the
     * top bit is set (JWK `n`/`e` must not carry leading zero octets).
     */
    private fun BigInteger.toUnsignedBytes(): ByteArray {
        val bytes = toByteArray()
        return if (bytes.size > 1 && bytes[0] == 0.toByte()) bytes.copyOfRange(1, bytes.size) else bytes
    }

    companion object : KLogging() {
        private val defaultHeaders: Map<String, List<String>> = mapOf(
            "Content-Type" to listOf("application/json"),
            "Cache-Control" to listOf("private, no-store, no-cache, must-revalidate, post-check=0, pre-check=0, no-transform")
        )
    }
}
/**
 * Client registry loaded once from the classpath resource `auth0_server.json` (a JSON array of [ClientConfig]).
 * When the resource is absent the registry is empty, so every lookup fails.
 */
private object ClientConfigProvider {
    private const val RESOURCE_NAME = "auth0_server.json"

    private val configs: Collection<ClientConfig> by lazy {
        val classLoader = Thread.currentThread().contextClassLoader ?: ClientConfigProvider::class.java.classLoader
        // Read via the URL stream rather than java.io.File so the resource also resolves inside a JAR or under a
        // path containing URL-escaped characters; `use` closes the stream (the original leaked it).
        classLoader.getResource(RESOURCE_NAME)
            ?.openStream()
            ?.use { input -> Utils.mapper.readValue(input, object : TypeReference<Collection<ClientConfig>>() {}) }
            ?: emptyList()
    }

    /**
     * Returns the configuration registered for [clientId].
     *
     * @throws IllegalArgumentException if no client with that id is configured.
     */
    fun getByClientId(clientId: String): ClientConfig {
        return configs.firstOrNull { clientId == it.clientId }
            ?: throw IllegalArgumentException("Unknown client_id '$clientId'")
    }
}
/**
 * One OAuth client known to the mock server, as deserialized from `auth0_server.json`.
 *
 * @property clientId identifier matched against the request's `client_id`.
 * @property clientSecret value the request's `client_secret` must equal.
 * @property audiences token settings per allowed `audience`; audiences not listed here are rejected.
 */
private data class ClientConfig(
    val clientId: String? = null,
    val clientSecret: String? = null,
    val audiences: Map<String, TokenConfig> = hashMapOf(),
)
/**
 * Token settings for one audience of a client.
 *
 * @property subject `sub` claim; when null the token endpoint falls back to `<clientId>@clients`.
 * @property claims extra payload claims copied into every token; may be null.
 * @property expiresIn token lifetime in seconds, also reported as `expires_in`.
 */
private data class TokenConfig(
    val subject: String? = null,
    val claims: Map<String, Any>? = null,
    val expiresIn: Long = 3_600L,
)
/**
 * Creates RS256-signed JSON Web Tokens with a fixed RSA private key.
 *
 * Every token header carries `alg` = `RS256`, `typ` = `JWT` and `kid` = [kid]. Unless the caller already set them,
 * the payload also receives `iss` = [issuer], a random `jti` and the current `iat`.
 *
 * @property privateKey key used to sign every token.
 * @property kid key id written to the `kid` header so verifiers can select the matching public key.
 * @property issuer default value of the `iss` claim.
 */
class TokenCreator(val privateKey: RSAPrivateKey, val kid: String, val issuer: String) {

    /** Names of the registered JWT header and payload claims. */
    object PublicClaims {
        // Header
        const val ALGORITHM = "alg"
        const val CONTENT_TYPE = "cty"
        const val TYPE = "typ"
        const val KEY_ID = "kid"

        // Payload
        const val ISSUER = "iss"
        const val SUBJECT = "sub"
        const val EXPIRES_AT = "exp"
        const val NOT_BEFORE = "nbf"
        const val ISSUED_AT = "iat"
        const val JWT_ID = "jti"
        const val AUDIENCE = "aud"
    }

    /** Thrown by [Token.sign] when the claims cannot be serialized or the signature cannot be computed. */
    class TokenException(message: String, cause: Throwable) : Exception(message, cause)

    /** Starts a new, empty token builder. */
    fun create(): Token = Token()

    /**
     * Single-use builder for one JWT. Setters return `this` for chaining, and [sign] may be called only once.
     * Only the one-shot check in [sign] is atomic; the claim maps themselves are not synchronized.
     */
    inner class Token {
        private val signed: AtomicBoolean = AtomicBoolean(false)
        private val payloadClaims: MutableMap<String, Any> = HashMap()
        private val headerClaims: MutableMap<String, Any> = HashMap()

        /** Sets the Subject (`sub`) claim. */
        fun subject(subject: String): Token = apply { payloadClaims[PublicClaims.SUBJECT] = subject }

        /** Sets a single-valued Audience (`aud`) claim. */
        fun audience(audience: String): Token = apply { payloadClaims[PublicClaims.AUDIENCE] = audience }

        /** Sets a multi-valued Audience (`aud`) claim. */
        fun audience(audience: Collection<String?>): Token = apply { payloadClaims[PublicClaims.AUDIENCE] = audience }

        /** Sets the Expires At (`exp`) claim, in epoch seconds. */
        fun expiresAt(expiresAt: Instant): Token = apply { payloadClaims[PublicClaims.EXPIRES_AT] = expiresAt.epochSecond }

        /**
         * Sets `exp` relative to now.
         *
         * @param expiresIn time to live in seconds.
         */
        fun expiresIn(expiresIn: Long): Token = expiresAt(Instant.now().plusSeconds(expiresIn))

        /**
         * Adds custom claims to the payload; `null` is ignored. Entries overwrite claims already set on this
         * builder, and `iss`, `jti` or `iat` supplied here take precedence over the defaults added by [sign].
         */
        fun claims(claims: Map<String, Any>?): Token = apply {
            if (claims != null) payloadClaims.putAll(claims)
        }

        /**
         * Serializes header and payload, signs them with `SHA256withRSA` and returns the compact JWS
         * `base64url(header).base64url(payload).base64url(signature)` (unpadded base64url).
         *
         * The builder's claims are cleared afterwards, whether or not signing succeeded.
         *
         * @return the signed JWT.
         * @throws TokenException if the signature algorithm is unavailable, the key is invalid, signing fails,
         *   or a claim cannot be serialized to JSON.
         * @throws IllegalStateException if this token has already been signed.
         */
        @Throws(TokenException::class)
        fun sign(): String {
            check(signed.compareAndSet(false, true)) { "Token already signed" }
            return try {
                // Resolve the algorithm first so an unsupported JVM fails before any serialization work.
                val signer = Signature.getInstance(SIGNATURE)
                headerClaims.putIfAbsent(PublicClaims.ALGORITHM, ALGORITHM)
                headerClaims.putIfAbsent(PublicClaims.TYPE, "JWT")
                headerClaims.putIfAbsent(PublicClaims.KEY_ID, kid)
                val header = ENCODER.encodeToString(JSON_MAPPER.writeValueAsBytes(headerClaims))
                payloadClaims.putIfAbsent(PublicClaims.ISSUER, issuer)
                payloadClaims.putIfAbsent(PublicClaims.JWT_ID, UUID.randomUUID().toString())
                payloadClaims.putIfAbsent(PublicClaims.ISSUED_AT, Instant.now().epochSecond)
                val payload = ENCODER.encodeToString(JSON_MAPPER.writeValueAsBytes(payloadClaims))
                val signingInput = "$header.$payload"
                signer.initSign(privateKey)
                signer.update(signingInput.toByteArray(StandardCharsets.UTF_8))
                "$signingInput.${ENCODER.encodeToString(signer.sign())}"
            } catch (e: NoSuchAlgorithmException) {
                throw TokenException("Algorithm $SIGNATURE not found", e)
            } catch (e: SignatureException) {
                throw TokenException("Signing error", e)
            } catch (e: InvalidKeyException) {
                throw TokenException("Invalid key pair", e)
            } catch (e: JsonProcessingException) {
                throw TokenException("Some of the Claims couldn't be converted to a valid JSON format.", e)
            } finally {
                headerClaims.clear()
                payloadClaims.clear()
            }
        }
    }

    companion object {
        private const val SIGNATURE = "SHA256withRSA"
        private const val ALGORITHM = "RS256"

        // JWS compact serialization uses base64url without padding.
        private val ENCODER: Base64.Encoder = Base64.getUrlEncoder().withoutPadding()

        // NOTE(review): SORT_PROPERTIES_ALPHABETICALLY governs POJO properties; the claim *maps* serialized here are
        // likely unaffected (Jackson orders maps via SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS) — confirm.
        private val JSON_MAPPER: JsonMapper = JsonMapper.builder()
            .configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true)
            .build()
    }
}
/**
 * Minimal HTTP server on top of the JDK [HttpServer]. Responses are registered per path regex via
 * [registerResponse] or [registerResponseProvider]. Each request is answered by the first registered pattern that
 * fully matches the request path (query string excluded), or with `404` when none matches.
 *
 * @param port port to bind; any value `<= 0` selects a currently free port.
 * @see Server.registerResponse
 * @see Server.registerResponseProvider
 */
private class Server(port: Int = -1) {
    // Insertion-ordered so that, when several patterns match, the first registered one wins deterministically.
    private val pathPatternToResponseProvider: MutableMap<Pattern, Function<HttpExchange, Response>> = LinkedHashMap()

    // Resolved lazily: a free port is probed only when the address is first needed.
    private val address: InetSocketAddress by lazy { InetSocketAddress(if (port > 0) port else findFreePort()) }

    // Non-null while running.
    private var httpServer: HttpServer? = null

    /**
     * A canned HTTP response.
     *
     * @property status HTTP status code.
     * @property body response body, sent UTF-8 encoded; `null` (or status 204) sends no body.
     * @property headers extra response headers, if any.
     */
    data class Response(
        val status: Int = 0,
        val body: String? = null,
        val headers: Map<String, List<String>>? = null,
    )

    /** Always answers requests matching [pathPattern] with the same [status], [body] and [headers]. */
    fun registerResponse(
        pathPattern: Pattern,
        status: Int,
        body: String?,
        headers: Map<String, List<String>>?,
    ): Server = registerResponseProvider(pathPattern) { Response(status, body, headers) }

    /** Answers requests matching [pathPattern] with whatever [responseProvider] builds from the exchange. */
    fun registerResponseProvider(pathPattern: Pattern, responseProvider: Function<HttpExchange, Response>): Server {
        pathPatternToResponseProvider[pathPattern] = responseProvider
        return this
    }

    /** The bound port; when none was given, a free port is picked on first access. */
    val port: Int
        get() = address.port

    /**
     * Recreates an instance of HttpServer. From Javadoc: "Once stopped, a HttpServer cannot be re-used."
     *
     * @see HttpServer.stop
     */
    private fun createHttpServer(): HttpServer {
        return try {
            System.setProperty("sun.net.httpserver.maxReqTime", "1000")
            System.setProperty("sun.net.httpserver.maxRspTime", "1000")
            HttpServer.create(address, 0).apply {
                createContext("/", HttpHandlerImpl()) // handles any request /**
            }
        } catch (e: IOException) {
            throw IllegalStateException("impossible to start server: " + e.message, e)
        }
    }

    private inner class HttpHandlerImpl : HttpHandler {
        @Throws(IOException::class)
        override fun handle(httpExchange: HttpExchange) {
            val method = httpExchange.requestMethod
            val uri = httpExchange.requestURI
            // Match on the path only, so `/health?x=1` still routes to `/health`.
            val path: String = uri.rawPath ?: uri.rawSchemeSpecificPart
            try {
                val responseProvider = pathPatternToResponseProvider.entries
                    .firstOrNull { it.key.matcher(path).matches() }
                    ?.value
                if (responseProvider == null) {
                    httpExchange.sendResponseHeaders(404, -1)
                    logger.warn("$method $path >>> 404 NOT_FOUND")
                    return
                }
                val response: Response = responseProvider.apply(httpExchange)
                response.headers?.let { httpExchange.responseHeaders.putAll(it) }
                val body = response.body
                if (body == null || response.status == 204) {
                    httpExchange.sendResponseHeaders(response.status, -1)
                    logger.info("$method $path >>> ${response.status}")
                } else {
                    // Content-Length must count encoded bytes, not chars.
                    val bytes = body.toByteArray(StandardCharsets.UTF_8)
                    httpExchange.sendResponseHeaders(response.status, bytes.size.toLong())
                    httpExchange.responseBody.use { os -> os.write(bytes) }
                    logger.info("$method $path >>> ${response.status} $body")
                }
            } catch (e: Exception) {
                logger.error("$method $path >>> 500 SERVER_ERROR ${e.javaClass.getSimpleName()} ${e.message}")
                // A status can be sent only once; if it already went out the client just sees a truncated reply.
                if (httpExchange.responseCode == -1) {
                    httpExchange.sendResponseHeaders(500, -1)
                }
            } finally {
                httpExchange.close()
            }
        }
    }

    /** Asks the OS for an ephemeral port by briefly binding port 0; retries a few times on I/O errors. */
    private fun findFreePort(): Int {
        var lastError: IOException? = null
        repeat(10) {
            try {
                return ServerSocket(0).use { it.localPort }
            } catch (e: IOException) {
                lastError = e
            }
        }
        throw IllegalStateException("unable to find a free port", lastError)
    }

    /** Starts listening unless already running; returns this server for chaining. */
    fun start(): Server {
        if (httpServer != null) {
            logger.debug("Already running on http://localhost:{}", this.port)
            return this
        }
        logger.debug("Starting on http://localhost:{}", this.port)
        // Assign only after a successful start, so a failed start leaves the server in the stopped state.
        httpServer = createHttpServer().also { it.start() }
        logger.info("Started on http://localhost:{}", this.port)
        pathPatternToResponseProvider.keys.forEach { pathPattern ->
            logger.debug("Active url http://localhost:{}{}", this.port, pathPattern.toString())
        }
        return this
    }

    /** Stops immediately if running; a later [start] creates a fresh [HttpServer]. */
    fun stop() {
        val running = httpServer
        if (running == null) {
            logger.debug("Already stopped on http://localhost:{}", this.port)
            return
        }
        logger.debug("Stopping on http://localhost:{}", this.port)
        running.stop(0)
        httpServer = null
        logger.info("Stopped on http://localhost:{}", this.port)
    }

    companion object : KLogging()
}
/**
 * Supplies the server's key pair and a matching self-signed certificate; both are computed on first access.
 *
 * The key pair is cached in [certDir] as `<algorithm>.key` (PKCS#8 encoding) and `<algorithm>.pub` (X.509 encoding).
 * When both files exist they are loaded. Otherwise a fresh 4096-bit pair is generated and written there.
 *
 * @param certDir directory holding the cached key files (created when generating a new pair).
 * @param algorithm key algorithm for [KeyFactory] / [KeyPairGenerator]. NOTE(review): the fixed 4096-bit size and
 *   the `SHA256WithRSA` certificate signer suggest only `"RSA"` works end to end — confirm before using another value.
 */
private class CertProvider(private val certDir: File, private val algorithm: String = "RSA") {

    val keyPair: KeyPair by lazy {
        val privateKeyFile = File(certDir, "$algorithm.key")
        val publicKeyFile = File(certDir, "$algorithm.pub")
        when {
            privateKeyFile.exists() && publicKeyFile.exists() -> loadKeyPair(privateKeyFile, publicKeyFile)
            else -> createKeyPair(privateKeyFile, publicKeyFile)
        }
    }

    /**
     * Self-signed certificate for [keyPair] with a critical basic-constraints extension marking it as a CA.
     * It is valid for one day from first access, and its subject and issuer are both
     * `CN=<this class's canonical name>`.
     */
    val certificate: X509Certificate by lazy {
        val validFrom = Instant.now()
        val signer = JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.private)
        val name = X500Name("CN=" + javaClass.getCanonicalName())
        val holder = JcaX509v3CertificateBuilder(
            /* issuer = */ name,
            /* serial = */ BigInteger.valueOf(validFrom.toEpochMilli()),
            /* notBefore = */ Date.from(validFrom),
            /* notAfter = */ Date.from(validFrom.plus(Duration.ofDays(1))),
            /* subject = */ name,
            /* publicKey = */ keyPair.public,
        ).addExtension(Extension.basicConstraints, true, BasicConstraints(true))
            .build(signer)
        JcaX509CertificateConverter()
            .setProvider(BouncyCastleProvider())
            .getCertificate(holder)
    }

    /** Rebuilds the key pair from its PKCS#8 (private) and X.509 (public) encodings on disk. */
    private fun loadKeyPair(privateKeyFile: File, publicKeyFile: File): KeyPair {
        val keyFactory = KeyFactory.getInstance(algorithm)
        val privateBytes = privateKeyFile.readBytes()
        val publicBytes = publicKeyFile.readBytes()
        val privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateBytes))
        val publicKey = keyFactory.generatePublic(X509EncodedKeySpec(publicBytes))
        return KeyPair(publicKey, privateKey)
    }

    /** Generates a new 4096-bit pair and persists both encoded halves for later runs. */
    private fun createKeyPair(privateKeyFile: File, publicKeyFile: File): KeyPair {
        // parent directories must exist before writing
        listOf(privateKeyFile, publicKeyFile).forEach { it.parentFile.mkdirs() }
        val generator = KeyPairGenerator.getInstance(algorithm)
        generator.initialize(4096)
        val generated = generator.generateKeyPair()
        privateKeyFile.writeBytes(generated.private.encoded)
        publicKeyFile.writeBytes(generated.public.encoded)
        return generated
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment