Skip to content

Instantly share code, notes, and snippets.

@tschlegel
Last active November 5, 2021 12:22
Show Gist options
  • Save tschlegel/2c96146cdb57e57b6acad26615868ecf to your computer and use it in GitHub Desktop.
Save tschlegel/2c96146cdb57e57b6acad26615868ecf to your computer and use it in GitHub Desktop.
package org.springframework.security.oauth2.client.web.reactive.function.client;
import java.net.URI;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.ExchangeFunctions;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
/**
 * Token Relay Gateway Filter with automatic access token refresh. This can be removed when issue
 * <a href="https://github.com/spring-cloud/spring-cloud-security/issues/175">spring-cloud-security#175</a>
 * is closed.
 * <p>
 * For each request whose principal is an {@link OAuth2AuthenticationToken}, the filter loads the
 * {@link OAuth2AuthorizedClient} for that token's client registration. If the access token is expired
 * (or within {@code accessTokenExpiresSkew} of expiring) and a refresh token is present, it performs a
 * {@code refresh_token} grant and saves the refreshed client. It then forwards the access token
 * downstream as an {@code Authorization: Bearer} header. If the principal is not an
 * {@link OAuth2AuthenticationToken}, or the lookup completes empty, the request is forwarded unchanged.
 * Errors from the lookup or the token endpoint are not handled here; they propagate to the caller.
 * <p>
 * Implementation based on {@link ServerOAuth2AuthorizedClientExchangeFilterFunction}. The class is
 * declared in this Spring Security package, presumably so it can use {@link OAuth2AuthorizedClientResolver}
 * (appears to be package-private - verify against the Spring Security version in use).
 */
@Component
public class TokenRelayWithRefreshGatewayFilterFactory extends AbstractGatewayFilterFactory<Object> {

    private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;

    private final OAuth2AuthorizedClientResolver authorizedClientResolver;

    /** Sends refresh-token requests to the provider's token endpoint; created once and reused. */
    private final ExchangeFunction tokenEndpointExchange;

    private Clock clock = Clock.systemUTC();

    private Duration accessTokenExpiresSkew = Duration.ofSeconds(3);

    /**
     * Creates the filter factory.
     *
     * @param authorizedClientRepository repository used to load and save authorized clients
     * @param clientRegistrationRepository repository used by the resolver to look up client registrations
     * @param connector HTTP connector used to call the token endpoint when refreshing
     */
    public TokenRelayWithRefreshGatewayFilterFactory(ServerOAuth2AuthorizedClientRepository authorizedClientRepository,
            ReactiveClientRegistrationRepository clientRegistrationRepository, ClientHttpConnector connector) {
        super(Object.class);
        this.authorizedClientRepository = authorizedClientRepository;
        this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository,
                authorizedClientRepository);
        this.tokenEndpointExchange = ExchangeFunctions.create(connector);
    }

    /**
     * Sets the clock used to decide whether the access token has expired.
     * Defaults to {@link Clock#systemUTC()}.
     *
     * @param clock the clock to use
     * @throws IllegalArgumentException if {@code clock} is {@code null}
     */
    public void setClock(Clock clock) {
        if (clock == null) {
            throw new IllegalArgumentException("clock cannot be null");
        }
        this.clock = clock;
    }

    /**
     * Sets how long before the access token's {@code expiresAt} it is already treated as expired, so a
     * token is not relayed when it is about to expire. Defaults to 3 seconds.
     *
     * @param accessTokenExpiresSkew the skew to subtract from the token's expiry instant
     * @throws IllegalArgumentException if {@code accessTokenExpiresSkew} is {@code null}
     */
    public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
        if (accessTokenExpiresSkew == null) {
            throw new IllegalArgumentException("accessTokenExpiresSkew cannot be null");
        }
        this.accessTokenExpiresSkew = accessTokenExpiresSkew;
    }

    /**
     * Creates the filter without configuration.
     *
     * @return the gateway filter
     */
    public GatewayFilter apply() {
        return apply((Object) null);
    }

    @Override
    public GatewayFilter apply(Object config) {
        return (exchange, chain) -> exchange.getPrincipal()
                .filter(principal -> principal instanceof OAuth2AuthenticationToken)
                .cast(OAuth2AuthenticationToken.class)
                .flatMap(authentication -> loadAuthorizedClient(exchange, authentication))
                .map(OAuth2AuthorizedClient::getAccessToken)
                .map(token -> withBearerAuth(exchange, token))
                // TODO: adjustable behavior if empty
                .defaultIfEmpty(exchange)
                .flatMap(chain::filter);
    }

    /** Returns a copy of the exchange whose request carries the access token as a Bearer header. */
    private ServerWebExchange withBearerAuth(ServerWebExchange exchange, OAuth2AccessToken accessToken) {
        return exchange.mutate()
                .request(r -> r.headers(headers -> headers.setBearerAuth(accessToken.getTokenValue())))
                .build();
    }

    /**
     * Loads the authorized client for the authentication's registration and refreshes it if needed.
     * A single resolver request is built and reused, so the load and the save after a refresh use the
     * same authentication and exchange.
     */
    private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(ServerWebExchange exchange,
            OAuth2AuthenticationToken authentication) {
        String clientRegistrationId = authentication.getAuthorizedClientRegistrationId();
        // Pass the known authentication rather than null, so the resolver does not have to look it up
        // from the reactive security context.
        return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange)
                .flatMap(request -> this.authorizedClientResolver.loadAuthorizedClient(request)
                        .flatMap(authorizedClient -> refreshIfNecessary(authorizedClient, request)));
    }

    /** Returns the refreshed client if {@link #shouldRefresh} says so, otherwise the client unchanged. */
    private Mono<OAuth2AuthorizedClient> refreshIfNecessary(OAuth2AuthorizedClient authorizedClient,
            OAuth2AuthorizedClientResolver.Request request) {
        return shouldRefresh(authorizedClient)
                ? refreshAuthorizedClient(authorizedClient, request)
                : Mono.just(authorizedClient);
    }

    /**
     * Decides whether a refresh is needed. This requires a repository to save into, a refresh token,
     * and an access token with a known expiry that is past (now + skew).
     */
    private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
        if (this.authorizedClientRepository == null || authorizedClient.getRefreshToken() == null) {
            return false;
        }
        Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
        return expiresAt != null
                && this.clock.instant().isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    /**
     * Performs a {@code refresh_token} grant against the registration's token URI, using HTTP Basic
     * client authentication. It then saves the resulting client in the repository and emits it.
     * If the response omits a new refresh token, the previous one is kept. If it omits scopes, the
     * previous scopes are kept.
     */
    private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(OAuth2AuthorizedClient authorizedClient,
            OAuth2AuthorizedClientResolver.Request request) {
        ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        URI tokenUri = URI.create(clientRegistration.getProviderDetails().getTokenUri());
        ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, tokenUri)
                .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
                .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(),
                        clientRegistration.getClientSecret()))
                .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
                .build();
        return this.tokenEndpointExchange.exchange(refreshRequest)
                .flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
                .map(accessTokenResponse -> {
                    OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken())
                            .orElse(authorizedClient.getRefreshToken());
                    OAuth2AccessToken accessToken = withPreviousScopesIfMissing(
                            accessTokenResponse.getAccessToken(), authorizedClient.getAccessToken());
                    return new OAuth2AuthorizedClient(clientRegistration, authorizedClient.getPrincipalName(),
                            accessToken, refreshToken);
                })
                .flatMap(refreshed -> this.authorizedClientRepository
                        .saveAuthorizedClient(refreshed, request.getAuthentication(), request.getExchange())
                        .thenReturn(refreshed));
    }

    /**
     * Returns {@code refreshed} unchanged if it carries scopes. Otherwise returns a copy with the
     * previous token's scopes. Per RFC 6749 sections 5.1 and 6, an omitted {@code scope} in a refresh
     * response means the granted scope is unchanged.
     */
    private static OAuth2AccessToken withPreviousScopesIfMissing(OAuth2AccessToken refreshed,
            OAuth2AccessToken previous) {
        if (!refreshed.getScopes().isEmpty()) {
            return refreshed;
        }
        return new OAuth2AccessToken(refreshed.getTokenType(), refreshed.getTokenValue(),
                refreshed.getIssuedAt(), refreshed.getExpiresAt(), previous.getScopes());
    }

    /** Builds the form body {@code grant_type=refresh_token&refresh_token=...}. */
    private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
        return BodyInserters
                .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
                .with("refresh_token", refreshToken);
    }
}
@gongchao8888
Copy link

Thank you very much. I created the package org.springframework.security.oauth2.client.web.reactive.function.client and it works.
Thank you very much.
My English is poor.

@gongchao8888
Copy link

Thank you very much. My code can refresh tokens now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment