Last active
April 28, 2017 19:27
-
-
Save BarDweller/224c2235212e955e0ed82252e509786d to your computer and use it in GitHub Desktop.
An attempt at a Spring OAuth2 / JWT filter that expects a JWT as the access token and populates the authentication / user details from the JWT claims.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
private Filter ssoFilter() { | |
OAuth2ClientAuthenticationProcessingFilter appidFilter = new OAuth2ClientAuthenticationProcessingFilter( | |
"/login/appid"); | |
OAuth2RestTemplate appidTemplate = new OAuth2RestTemplate(appid(), oauth2ClientContext); | |
appidFilter.setRestTemplate(appidTemplate); | |
DefaultTokenServices tokenServices = new DebugTokenServices(); | |
try { | |
//obtain json web key for app id. | |
String location = getJwkUri(); | |
SimpleResponse simpleResponse = new Get().get(location); | |
System.out.println("S: "+simpleResponse.getStatusCode()+" B: "+simpleResponse.getBody()); | |
Map<String,Object> parsed = JsonUtil.parseJson(simpleResponse.getBody()); | |
RsaJsonWebKey rsa = new RsaJsonWebKey(parsed); | |
RSAPublicKey key = rsa.getRsaPublicKey(); | |
System.out.println("KEY: "+key); | |
JwtAccessTokenConverter converter = new JwtAccessTokenConverter(){ | |
//override decode, because our jwt didn't work with the default one.. | |
protected Map<String, Object> decode(String token) { | |
try { | |
JwtConsumer jwtConsumer = new JwtConsumerBuilder() | |
.setRequireExpirationTime() | |
.setAllowedClockSkewInSeconds(30) | |
.setVerificationKey(key) //set the key we loaded via the jwks endpoint | |
.setSkipDefaultAudienceValidation() | |
.setJwsAlgorithmConstraints(new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.WHITELIST, | |
AlgorithmIdentifiers.RSA_USING_SHA256)) | |
.build(); | |
JwtClaims jwtClaims = jwtConsumer.processToClaims(token); | |
Map<String, Object> map = jwtClaims.getClaimsMap(); | |
if(map.containsKey("exp") && map.get("exp") instanceof Integer) { | |
Integer intValue = (Integer)map.get("exp"); | |
map.put("exp", new Long((long)intValue.intValue())); | |
} | |
return map; | |
} catch (Exception e) { | |
throw new InvalidTokenException("Cannot convert access token to JSON", e); | |
} | |
} | |
@Override | |
public OAuth2Authentication extractAuthentication(Map<String, ?> map) { | |
Map<String, String> parameters = new HashMap(); | |
Set<String> scope = Collections.emptySet(); | |
//String clientId = (String)map.get("client_id"); | |
//parameters.put("client_id", clientId); | |
String clientId = map.get("amr").toString(); | |
Set<String> resourceIds = Collections.emptySet(); | |
Collection<? extends GrantedAuthority> authorities = null; | |
UserDetails userdetails = new UserDetails(){ | |
@Override | |
public Collection<? extends GrantedAuthority> getAuthorities() { | |
return authorities; | |
} | |
@Override | |
public String getPassword() { | |
return null; | |
} | |
@Override | |
public String getUsername() { | |
return map.get("sub").toString(); | |
} | |
@Override | |
public boolean isAccountNonExpired() { | |
return true; | |
} | |
@Override | |
public boolean isAccountNonLocked() { | |
return true; | |
} | |
@Override | |
public boolean isCredentialsNonExpired() { | |
return true; | |
} | |
@Override | |
public boolean isEnabled() { | |
return true; | |
} | |
}; | |
Authentication user = new UsernamePasswordAuthenticationToken(userdetails, "N/A", authorities); | |
OAuth2Request request = new OAuth2Request(parameters, clientId, authorities, true, scope, resourceIds, (String)null, (Set)null, (Map)null); | |
return new OAuth2Authentication(request, user); | |
} | |
}; | |
JwtTokenStore store = new JwtTokenStore(converter); | |
//JwkTokenStore store = new JwkTokenStore(getJwkUri()); | |
tokenServices.setTokenStore(store); | |
appidFilter.setTokenServices(tokenServices); | |
return appidFilter; | |
}catch(IOException e){ | |
System.out.println("FAIL: "+e.getMessage()); | |
throw new RuntimeException(e); | |
}catch(JoseException e){ | |
System.out.println("FAIL: "+e.getMessage()); | |
throw new RuntimeException(e); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment