Skip to content

Instantly share code, notes, and snippets.

@angelmartz
Forked from ttddyy/AsyncResourceRetriever.java
Created February 22, 2024 22:14
Show Gist options
  • Save angelmartz/2fa74dd0a7e955b5dba78a2c45bc1d74 to your computer and use it in GitHub Desktop.
Save angelmartz/2fa74dd0a7e955b5dba78a2c45bc1d74 to your computer and use it in GitHub Desktop.
Async JWKS retriever
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import com.nimbusds.jose.util.Resource;
import com.nimbusds.jose.util.ResourceRetriever;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
 * {@link ResourceRetriever} implementation to retrieve JWKS asynchronously.
 *
 * <p>
 * Actual JWKS retrieval is delegated to the {@link ReactorJwkSetRetriever}.
 *
 * <h3>Initial Request</h3>
 * <p>
 * By default, the request to obtain the JWKS is made on the first call to
 * {@link #retrieveResource(URL)}. If you want to pre-populate the JWKS, call the
 * {@link #updateJwkSet(URL, boolean)} method from outside. For example, call the method
 * in an {@link org.springframework.boot.ApplicationRunner} bean, so that the JWKS is
 * retrieved at start-up time, before the application becomes ready.
 *
 * <h3>Lifetime of retrieved JWKS</h3>
 * <p>
 * When a JWKS is retrieved, the clock starts ticking. Once the renewal time has passed,
 * a call to {@link #retrieveResource(URL)} triggers an asynchronous JWKS retrieval.
 * (Concurrency is controlled, so at most one asynchronous retrieval is in flight at any
 * time.) While the asynchronous retrieval is being performed, incoming requests are
 * still served the existing JWKS. If the renewal request fails for any reason, the
 * failure is logged and the existing JWKS keeps being served until its expiration time;
 * subsequent requests trigger the renewal again. The expiration is the maximum lifetime
 * of the existing JWKS. Once the existing JWKS has reached the expiration time, it is
 * not served anymore: the next call to {@link #retrieveResource(URL)} performs a
 * synchronous JWKS retrieval and other requests are blocked until the retrieval
 * finishes.
 *
 * <h3>Async JWKS retrieval</h3>
 * <p>
 * It is expected that the underlying {@link ReactorJwkSetRetriever} retrieves the JWKS
 * in a different thread. For example, if {@code WebClient} is used, the retrieval
 * operation happens in a reactor-netty thread. If retrieval does not happen in a
 * separate thread, specify a {@link Scheduler} via {@link #setScheduler(Scheduler)}. The
 * scheduler is used for subscribing the {@link ReactorJwkSetRetriever#retrieve(URL)}
 * operation for async JWKS retrieval.
 *
 * @author Tadaya Tsuyukubo
 */
@Slf4j
public class AsyncResourceRetriever implements ResourceRetriever {

	private final ReactorJwkSetRetriever reactorJwkSetRetriever;

	private Duration expirationDuration = Duration.ofMinutes(15);

	private Duration renewalDuration = Duration.ofMinutes(5);

	private Clock clock = Clock.systemDefaultZone();

	// a scheduler to subscribe async jwks retrieval
	@Nullable
	private Scheduler scheduler;

	// Written while holding the lock (blocking retrieval) or from the subscriber thread
	// (async retrieval), but read without the lock in retrieveResource(). "volatile"
	// guarantees that readers observe the most recently published context.
	@Nullable
	private volatile JwkSetContext jwkSetContext;

	// "true" while an asynchronous retrieval is in flight. Set under the lock in
	// updateJwkSet() and cleared when the retrieval terminates (success, error, cancel).
	private volatile boolean renewalInProgress;

	/**
	 * Create a retriever that delegates the actual JWKS retrieval.
	 * @param reactorJwkSetRetriever delegate performing the JWKS retrieval
	 */
	public AsyncResourceRetriever(ReactorJwkSetRetriever reactorJwkSetRetriever) {
		Assert.notNull(reactorJwkSetRetriever, "reactorJwkSetRetriever can not be null");
		this.reactorJwkSetRetriever = reactorJwkSetRetriever;
	}

	/**
	 * Return the current JWKS, retrieving or renewing it when necessary.
	 * <p>
	 * The JWKS is retrieved synchronously when none is available yet or the current one
	 * has expired. When the current one has only passed its renewal time, it is returned
	 * as is and an asynchronous renewal is triggered.
	 * @param url JWKS endpoint, handed over to the {@link ReactorJwkSetRetriever}
	 * @return the JWKS resource
	 * @throws IOException when the synchronous retrieval yields no JWKS
	 */
	@Override
	public Resource retrieveResource(URL url) throws IOException {
		Instant now = this.clock.instant();

		// Retrieve JWKS synchronously(blocking) if:
		// - JWKS is empty(first time deferred initialization), or
		// - Failed to renew the jwks and reached to the expiration
		BooleanSupplier obtainCondition = () -> {
			JwkSetContext current = this.jwkSetContext;
			return current == null || current.isExpired(now);
		};

		// Trigger JWKS renewal if it has passed the renewal time
		BooleanSupplier renewCondition = () -> {
			JwkSetContext current = this.jwkSetContext;
			Assert.notNull(current, "jwkSetContext can not be null");
			return current.requireRenew(now);
		};

		if (obtainCondition.getAsBoolean()) {
			// blocking call to populate resource
			updateJwkSet(obtainCondition, false, now, url);
		}
		// Skip the lock entirely while a renewal is already in flight; the current
		// JWKS is served in the meantime.
		else if (!this.renewalInProgress && renewCondition.getAsBoolean()) {
			// async call to get JWKS
			updateJwkSet(renewCondition, true, now, url);
		}

		// single read, so that the null check and the returned resource refer to the
		// same context even if an async renewal publishes a new one concurrently
		JwkSetContext current = this.jwkSetContext;
		Assert.notNull(current, "jwkSetContext can not be null");
		return current.getResource();
	}

	/**
	 * Force retrieve JWKS.
	 * <p>
	 * When {@code async} parameter is set to {@code true}, this method immediately
	 * returns and JWKS is retrieved asynchronously. If an asynchronous retrieval is
	 * already in flight, no additional request is issued.
	 * <p>
	 * Provided {@code url} parameter may or may not be used based on the underlying
	 * {@link ReactorJwkSetRetriever} implementation.
	 * @param url JWKS endpoint. may or may not be used.
	 * @param async {@code true} to retrieve JWKS asynchronously.
	 * @throws RuntimeException when JWKS retrieval failed
	 */
	public void updateJwkSet(URL url, boolean async) {
		try {
			updateJwkSet(() -> true, async, this.clock.instant(), url);
		}
		catch (IOException ex) {
			throw new RuntimeException("Failed to retrieve JWKS", ex);
		}
	}

	/**
	 * Perform JWKS update under the lock(synchronized).
	 * @param condition condition to proceed(make a call); re-evaluated after the lock
	 * has been acquired
	 * @param async asynchronous or not
	 * @param now current time; renewal and expiration times are calculated from it
	 * @param url JWKS endpoint. may or may not be used.
	 * @throws IOException when retrieved jwks is empty(null)
	 */
	private synchronized void updateJwkSet(BooleanSupplier condition, boolean async, Instant now, URL url)
			throws IOException {
		// re-check under the lock: another thread may have updated the JWKS while this
		// one was waiting
		if (!condition.getAsBoolean()) {
			return;
		}

		// upon successful JWKS retrieval, populate JwkSetContext
		Consumer<String> consumer = (jwkSet) -> {
			Instant nextRenew = now.plus(this.renewalDuration);
			Instant expiration = now.plus(this.expirationDuration);
			this.jwkSetContext = new JwkSetContext(jwkSet, nextRenew, expiration);
			log.debug("Updated the current JWKS. time={}, renew={}, expiration={}", now, nextRenew, expiration);
		};

		if (async) {
			// The lock only covers the subscription, not the retrieval itself. Without
			// this guard, every request arriving while the retrieval is in flight would
			// issue another one.
			if (this.renewalInProgress) {
				log.debug("Asynchronous JWKS retrieval is already in progress");
				return;
			}
			log.debug("Making an asynchronous call to retrieve JWKS");
			Mono<String> mono = this.reactorJwkSetRetriever.retrieve(url);
			if (this.scheduler != null) {
				mono = mono.subscribeOn(this.scheduler);
			}
			this.renewalInProgress = true;
			try {
				// doFinally runs after the signal has been delivered to the subscriber,
				// so the flag is cleared only after the new context has been published.
				mono.doFinally((signal) -> this.renewalInProgress = false)
					.subscribe(consumer, (ex) -> log.warn("Failed to retrieve JWKS asynchronously", ex));
			}
			catch (RuntimeException ex) {
				// never leave the flag set if the subscription itself blows up
				this.renewalInProgress = false;
				throw ex;
			}
		}
		else {
			// blocking call to retrieve JWKS
			log.debug("Making a blocking call to retrieve JWKS");
			String jwkSet = this.reactorJwkSetRetriever.retrieve(url).block();
			if (jwkSet == null) {
				throw new IOException("Retrieved JWKS is null");
			}
			consumer.accept(jwkSet);
		}
	}

	/**
	 * Max duration to use the current JWKS.
	 *
	 * <p>
	 * When a new JWKS is retrieved, calculate the max time(lifetime) by adding this
	 * duration to the current time(retrieval time).
	 *
	 * <p>
	 * In normal case, the JWKS renewal comes before this expiration time. However, if
	 * the renewal failed for any reason, the current JWKS is served until it reaches
	 * this max time.
	 * @param expirationDuration duration after which the JWKS expires
	 */
	public void setExpirationDuration(Duration expirationDuration) {
		Assert.notNull(expirationDuration, "expirationDuration can not be null");
		this.expirationDuration = expirationDuration;
	}

	/**
	 * Time to start triggering JWKS renewal asynchronously.
	 *
	 * <p>
	 * When a new JWKS is retrieved, calculate the renewal time by adding this duration to
	 * the current time(retrieval time).
	 *
	 * <p>
	 * When a request comes in and the current time has passed the renewal time, the
	 * request uses existing JWKS but also triggers the JWKS renewal asynchronously.
	 * @param renewalDuration duration after which the JWKS is renewed
	 */
	public void setRenewalDuration(Duration renewalDuration) {
		Assert.notNull(renewalDuration, "renewalDuration can not be null");
		this.renewalDuration = renewalDuration;
	}

	/**
	 * {@link Clock} used to obtain the current time for renewal/expiration decisions.
	 * @param clock a clock to use
	 */
	public void setClock(Clock clock) {
		Assert.notNull(clock, "clock can not be null");
		this.clock = clock;
	}

	/**
	 * {@link Scheduler} on which the asynchronous
	 * {@link ReactorJwkSetRetriever#retrieve(URL)} operation is subscribed.
	 * <p>
	 * When the {@link ReactorJwkSetRetriever} implementation performs JWKS retrieval in
	 * separate thread (e.g. {@code WebClient} uses reactor-netty thread), it is not
	 * required to set this scheduler. On the other hand, if JWKS retrieval implementation
	 * does not use a different thread, you need to specify a {@link Scheduler} to perform
	 * the JWKS retrieval on it. Otherwise, the retrieval operation happens on the
	 * caller's thread due to the invocation of {@link Mono#subscribe()}.
	 * @param scheduler a scheduler to use
	 */
	public void setScheduler(Scheduler scheduler) {
		this.scheduler = scheduler;
	}

	/**
	 * Immutable holder of a retrieved JWKS together with its renewal and expiration
	 * times.
	 */
	private static final class JwkSetContext {

		private final Instant expiration;

		private final Instant nextRenew;

		private final Resource resource;

		JwkSetContext(String jwkSet, Instant nextRenew, Instant expiration) {
			this.nextRenew = nextRenew;
			this.expiration = expiration;
			this.resource = new Resource(jwkSet, StandardCharsets.UTF_8.name());
		}

		/** Whether {@code now} is strictly after the expiration time. */
		boolean isExpired(Instant now) {
			return now.isAfter(this.expiration);
		}

		/** Whether {@code now} is strictly after the renewal time. */
		boolean requireRenew(Instant now) {
			return now.isAfter(this.nextRenew);
		}

		Resource getResource() {
			return this.resource;
		}

	}

}
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.net.URL;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import com.nimbusds.jose.util.Resource;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
/**
* Test for {@link AsyncResourceRetriever}.
*
* <p>
* The delegate {@link ReactorJwkSetRetriever} is a Mockito mock. "Slow" retrievals are
* simulated with a {@code Mono} whose supplier waits on an {@link AtomicBoolean} latch,
* and the passage of time is simulated by replacing the retriever's {@link Clock} with a
* mock returning a fixed future instant.
*
* @author Tadaya Tsuyukubo
*/
class AsyncResourceRetrieverTest {
// Without pre-population, the first "retrieveResource" call retrieves the JWKS and the
// second call is served from the already retrieved one.
@Test
void deferredInitialization() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
Mono<String> mono = Mono.just("jwks");
URL url = new URL("http://example.com");
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
Resource resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// first call invokes jwksRetriever
verify(jwksRetriever).retrieve(url);
resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// verify only called once ("verify" without a mode expects exactly one invocation)
verify(jwksRetriever).retrieve(url);
}
// When the JWKS has been pre-populated via "updateJwkSet", a following
// "retrieveResource" call does not hit the delegate again.
@Test
void initializedBeforeAccess() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
Mono<String> mono = Mono.just("jwks");
URL url = new URL("http://example.com");
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
// initialize
retriever.updateJwkSet(url, false);
verify(jwksRetriever).retrieve(url);
Resource resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// verify "retrieveResource" does not retrieve JWKS (still exactly one invocation)
verify(jwksRetriever).retrieve(url);
}
// After the renewal time (but before expiration), "retrieveResource" keeps serving the
// old JWKS while triggering an asynchronous renewal; the renewed JWKS is served once
// the renewal has finished.
@Test
void renew() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
Mono<String> mono = Mono.just("jwks");
URL url = new URL("http://example.com");
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
retriever.setExpirationDuration(Duration.ofDays(1));
retriever.setRenewalDuration(Duration.ofMinutes(30));
// Since jwks retrieval for this test doesn't run on a separate thread, specify
// a scheduler to perform the retrieval.
retriever.setScheduler(Schedulers.boundedElastic());
// initialize
retriever.updateJwkSet(url, false);
// forget the initial invocation so that later verifications only see the renewal
reset(jwksRetriever);
// mimic slow JWKS response (specify a scheduler to make the op run async)
AtomicBoolean asyncCallTriggered = new AtomicBoolean();
AtomicBoolean latch = new AtomicBoolean();
Mono<String> monoRenewed = Mono.fromSupplier(() -> {
asyncCallTriggered.set(true);
// hold the retrieval until the test releases the latch
await().untilTrue(latch);
return "jwks-renewed";
});
when(jwksRetriever.retrieve(url)).thenReturn(monoRenewed);
// set the clock to mimic it has passed the renewal but not expiration
Clock clock = mock(Clock.class);
when(clock.instant()).thenReturn(Instant.now().plus(Duration.ofHours(1)));
retriever.setClock(clock);
// should serve the old one
Resource resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// verify an async call has been made
await().untilTrue(asyncCallTriggered);
// release the JWKS call
latch.set(true);
// verify "retrieveResource" has triggered the call
verify(jwksRetriever).retrieve(url);
reset(jwksRetriever);
// any further retrieval attempt from here on is a test failure
when(jwksRetriever.retrieve(url)).thenThrow(new RuntimeException("should not be called"));
// since renew happens asynchronously, poll until the renewed JWKS is served
await().until(() -> {
String content = retriever.retrieveResource(url).getContent();
return "jwks-renewed".equals(content);
});
verifyNoInteractions(jwksRetriever);
}
// When the asynchronous renewal fails, the old JWKS keeps being served and the next
// "retrieveResource" call triggers the renewal again.
@Test
void failedRenewal() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
Mono<String> mono = Mono.just("jwks");
URL url = new URL("http://example.com");
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
retriever.setExpirationDuration(Duration.ofDays(1));
retriever.setRenewalDuration(Duration.ofMinutes(30));
// Since jwks retrieval for this test doesn't run on a separate thread, specify
// a scheduler to perform the retrieval.
retriever.setScheduler(Schedulers.boundedElastic());
// initialize
retriever.updateJwkSet(url, false);
reset(jwksRetriever);
// mimic JWKS call failure; the flag records that the failing Mono was subscribed
AtomicBoolean firstCall = new AtomicBoolean();
Mono<String> firstRenewMono = Mono.error(() -> {
firstCall.set(true);
return new RuntimeException("failure first");
});
when(jwksRetriever.retrieve(url)).thenReturn(firstRenewMono);
// set the clock to mimic it has passed the renewal but not expiration
Clock clock = mock(Clock.class);
when(clock.instant()).thenReturn(Instant.now().plus(Duration.ofHours(1)));
retriever.setClock(clock);
// should serve the old one
Resource resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// verify an async call has been made
await().untilTrue(firstCall);
// verify "retrieveResource" triggered the call
verify(jwksRetriever).retrieve(url);
reset(jwksRetriever);
// again mock the async call failure
// NOTE(review): "secondCall" is set by the Mono but never awaited or asserted below
AtomicBoolean secondCall = new AtomicBoolean();
Mono<String> secondRenewMono = Mono.error(() -> {
secondCall.set(true);
return new RuntimeException("failure second");
});
when(jwksRetriever.retrieve(url)).thenReturn(secondRenewMono);
// second call should again serve the old one
resource = retriever.retrieveResource(url);
assertThat(resource.getContent()).isEqualTo("jwks");
// verify async call made again
verify(jwksRetriever).retrieve(url);
}
// Once the JWKS has expired, "retrieveResource" blocks until a new JWKS has been
// retrieved and then returns the new one.
@Test
void expired() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
Mono<String> mono = Mono.just("jwks");
URL url = new URL("http://example.com");
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
retriever.setExpirationDuration(Duration.ofDays(1));
retriever.setRenewalDuration(Duration.ofMinutes(30));
// initialize
retriever.updateJwkSet(url, false);
reset(jwksRetriever);
// mimic slow JWKS response that completes only after the latch is released
AtomicBoolean monoCalled = new AtomicBoolean();
AtomicBoolean latch = new AtomicBoolean();
Mono<String> monoRenewed = Mono.fromSupplier(() -> {
monoCalled.set(true);
await().untilTrue(latch);
return "jwks-renewed";
});
when(jwksRetriever.retrieve(url)).thenReturn(monoRenewed);
// set the clock to mimic it has passed the expiration
Clock clock = mock(Clock.class);
when(clock.instant()).thenReturn(Instant.now().plus(Duration.ofDays(2)));
retriever.setClock(clock);
// make a call on a separate thread since it is expected to block
AtomicBoolean callFinished = new AtomicBoolean();
AtomicReference<String> callResult = new AtomicReference<>();
new Thread(() -> {
try {
String result = retriever.retrieveResource(url).getContent();
callResult.set(result);
callFinished.set(true);
}
catch (IOException ex) {
// "callFinished" stays false, so the await below times out on failure
callResult.set("FAILED");
}
}).start();
// verify jwks call has been made
await().untilTrue(monoCalled);
// synchronous call has not finished
assertThat(callFinished).isFalse();
latch.set(true);
// synchronous call finished
await().untilTrue(callFinished);
assertThat(callResult).hasValue("jwks-renewed");
}
// Two concurrent first-time calls result in a single JWKS retrieval; both callers
// block until it finishes and both receive the retrieved JWKS.
@Test
void initialCallBlocksOthers() throws Exception {
ReactorJwkSetRetriever jwksRetriever = mock(ReactorJwkSetRetriever.class);
URL url = new URL("http://example.com");
AtomicBoolean asyncCallTriggered = new AtomicBoolean();
// counts how many times the Mono has actually been subscribed
AtomicInteger counter = new AtomicInteger();
AtomicBoolean latch = new AtomicBoolean();
Mono<String> mono = Mono.fromSupplier(() -> {
asyncCallTriggered.set(true);
counter.incrementAndGet();
await().untilTrue(latch);
return "jwks";
});
when(jwksRetriever.retrieve(url)).thenReturn(mono);
AsyncResourceRetriever retriever = new AsyncResourceRetriever(jwksRetriever);
// make initial call
AtomicBoolean firstCallFinished = new AtomicBoolean();
AtomicReference<String> firstCallResult = new AtomicReference<>();
new Thread(() -> {
try {
String result = retriever.retrieveResource(url).getContent();
firstCallResult.set(result);
firstCallFinished.set(true);
}
catch (IOException ex) {
firstCallResult.set("FAILED");
}
}).start();
// make a second, concurrent call
AtomicBoolean secondCallFinished = new AtomicBoolean();
AtomicReference<String> secondCallResult = new AtomicReference<>();
new Thread(() -> {
try {
String result = retriever.retrieveResource(url).getContent();
secondCallResult.set(result);
secondCallFinished.set(true);
}
catch (IOException ex) {
secondCallResult.set("FAILED");
}
}).start();
// the retrieval is in progress, so neither caller may have returned yet
await().untilTrue(asyncCallTriggered);
assertThat(firstCallFinished).isFalse();
assertThat(secondCallFinished).isFalse();
// release the retrieval
latch.set(true);
await().untilTrue(firstCallFinished);
await().untilTrue(secondCallFinished);
// only one retrieval has been performed and its result is shared by both callers
assertThat(counter).hasValue(1);
assertThat(firstCallResult).hasValue("jwks");
assertThat(secondCallResult).hasValue("jwks");
verify(jwksRetriever).retrieve(url);
}
}
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.net.URI;
import java.net.URL;
import reactor.core.publisher.Mono;
/**
* Asynchronously retrieve JWKS.
*
* <p>
* Strategy used by {@code AsyncResourceRetriever} to perform the actual JWKS retrieval.
* Being a functional interface, it can be implemented with a lambda.
*
* @author Tadaya Tsuyukubo
*/
@FunctionalInterface
public interface ReactorJwkSetRetriever {
/**
* Retrieve the JWKS.
* @param url JWKS endpoint; an implementation may or may not use it
* @return a {@link Mono} emitting the JWKS as a string
*/
Mono<String> retrieve(URL url);
}
// Configuration to integrate the AsyncResourceRetriever bean
/**
 * Build a {@link JwtDecoder} whose JWS key selector obtains the JWKS through the given
 * {@link AsyncResourceRetriever}.
 * @param jwtValidator validator applied to decoded JWTs
 * @param asyncResourceRetriever retriever used to load the JWKS
 * @return the configured decoder
 * @throws Exception when the key selector can not be created
 */
@Bean
public JwtDecoder myJwtDecoder(OAuth2TokenValidator<Jwt> jwtValidator, AsyncResourceRetriever asyncResourceRetriever) throws Exception {
	final String endpoint = "...";

	// key selector backed by the custom async JWKS retriever
	final JWSKeySelector<SecurityContext> selector = createJWSKeySelector(endpoint, asyncResourceRetriever);

	// replace the default key selector of the JWT processor with the custom one
	NimbusJwtDecoder decoder = NimbusJwtDecoder.withJwkSetUri(endpoint)
		.jwtProcessorCustomizer(processor -> processor.setJWSKeySelector(selector))
		.build();
	decoder.setJwtValidator(jwtValidator);
	return decoder;
}
// Similar to "jwkSource()" and "jwsKeySelector()" in "NimbusJwtDecoder"
/**
 * Create an RS256 verification key selector whose keys come from a
 * {@link RemoteJWKSet} that loads the JWKS via the given retriever.
 * @param jwksEndpoint JWKS endpoint URL
 * @param asyncResourceRetriever retriever handed to the {@link RemoteJWKSet}
 * @return the key selector
 * @throws MalformedURLException when {@code jwksEndpoint} is not a valid URL
 */
private JWSKeySelector<SecurityContext> createJWSKeySelector(String jwksEndpoint,
		AsyncResourceRetriever asyncResourceRetriever) throws MalformedURLException {
	JWKSource<SecurityContext> remoteKeys = new RemoteJWKSet<>(new URL(jwksEndpoint), asyncResourceRetriever,
			new NoOpJwkSetCache());
	return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, remoteKeys);
}
// Then, specify the JwtDecoder bean as the decoder of the OAuth2 resource server.
// "myJwtDecoder" is presumably the JwtDecoder instance produced by the bean method above;
// the trailing dots stand for the rest of the security configuration, which is omitted.
http.oauth2ResourceServer()
.jwt()
.decoder(myJwtDecoder)
.....
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment