Skip to content

Instantly share code, notes, and snippets.

@bmchild
Last active May 23, 2019 09:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save bmchild/9b53d0f0b648f5db4577 to your computer and use it in GitHub Desktop.
Save bmchild/9b53d0f0b648f5db4577 to your computer and use it in GitHub Desktop.
package com.bmchild.stack.commons.utilities.security;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.HttpStatus;
import org.apache.commons.httpclient.NameValuePair;
import org.apache.commons.httpclient.methods.DeleteMethod;
import org.apache.commons.httpclient.methods.PostMethod;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
/**
 * Static helpers for the CAS REST ticket API: obtaining a ticket-granting ticket (TGT),
 * exchanging a TGT for a service ticket, and deleting a TGT.
 *
 * <p>The HTTP-calling methods log and return {@code null} when the CAS server replies with an
 * unexpected status code or when the exchange fails with an {@link IOException}.
 */
public final class CasAuthenticationUtils {
    private static final Logger LOGGER = Logger.getLogger(CasAuthenticationUtils.class);
    private static final int ONE_K = 1024;
    private static final String INVALID_RESPONSE = "Invalid response code (%d) from CAS server!";
    private static final String RESPONSE_ONE_K = "Response (1k): %s";

    /** Explicit charset for Basic credentials so decoding does not depend on the platform default. */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /**
     * Matches the {@code action="..."} attribute in the body of a "TGT created" response and
     * captures the last, non-empty path segment of its URL (the TGT id). It is compiled once and
     * used with {@link Matcher#find()}, so line breaks elsewhere in the body do not prevent a match.
     */
    private static final Pattern TGT_ACTION_PATTERN = Pattern.compile("action=\"[^\"]*/([^\"/]+)\"");

    private CasAuthenticationUtils() {
        throw new UnsupportedOperationException("CasAuthenticationUtils is a utility class and should not be instantiated.");
    }

    /**
     * Obtains a TGT using the credentials carried in an HTTP Basic token.
     *
     * @param casRestTicketUrl the CAS REST tickets endpoint
     * @param securityToken the Base64 part of a Basic {@code Authorization} header, encoding
     *        {@code username:password}; may be {@code null}
     * @return the TGT id, or {@code null} if the token is missing, not Base64, malformed, or
     *         rejected by CAS
     */
    public static String getTicketGrantingTicket(final String casRestTicketUrl, final String securityToken) {
        if (securityToken == null) {
            return null;
        }
        final byte[] encoded = securityToken.getBytes(UTF_8);
        if (!Base64.isBase64(encoded)) {
            return null;
        }
        final String decoded = new String(Base64.decodeBase64(encoded), UTF_8);
        // RFC 7617: the user-id cannot contain ':' but the password may, so split on the FIRST
        // colon only, and require both halves to be non-empty.
        final int colon = decoded.indexOf(':');
        if (colon <= 0 || colon == decoded.length() - 1) {
            return null;
        }
        return getTicketGrantingTicket(casRestTicketUrl, decoded.substring(0, colon), decoded.substring(colon + 1));
    }

    /**
     * POSTs {@code username}/{@code password} to the CAS REST endpoint to create a TGT.
     *
     * @param casRestTicketUrl the CAS REST tickets endpoint
     * @param username the user name sent as the {@code username} form field
     * @param password the password sent as the {@code password} form field
     * @return the TGT id parsed from a 201 (Created) response, or {@code null} on any other
     *         status, when no ticket can be found in the body, or on I/O failure
     */
    public static String getTicketGrantingTicket(final String casRestTicketUrl, final String username, final String password) {
        final HttpClient client = new HttpClient();
        final PostMethod post = new PostMethod(casRestTicketUrl);
        post.setRequestBody(new NameValuePair[] { new NameValuePair("username", username), new NameValuePair("password", password) });
        try {
            client.executeMethod(post);
            final int status = post.getStatusCode();
            // May be null when the response carries no body.
            final String response = post.getResponseBodyAsString();
            if (status == HttpStatus.SC_CREATED) {
                final String ticket = extractTicketGrantingTicket(response);
                if (ticket != null) {
                    return ticket;
                }
                LOGGER.warn("Successful ticket granting request, but no ticket found!");
                logResponseExcerpt(response);
            } else {
                logUnexpectedResponse(status, response);
            }
        } catch (final IOException e) {
            LOGGER.warn(e.getMessage(), e);
        } finally {
            post.releaseConnection();
        }
        return null;
    }

    /**
     * Exchanges a TGT for a service ticket for {@code serviceUrl}.
     *
     * @param casRestTicketUrl the CAS REST tickets endpoint
     * @param serviceUrl the service the ticket is requested for
     * @param ticketGrantingTicket the TGT id; {@code null} short-circuits to {@code null}
     * @return the response body of a 200 (OK) reply (the service ticket), or {@code null}
     */
    public static String getServiceTicket(final String casRestTicketUrl, final String serviceUrl, final String ticketGrantingTicket) {
        if (ticketGrantingTicket == null) {
            return null;
        }
        final HttpClient client = new HttpClient();
        final PostMethod post = new PostMethod(casRestTicketUrl + "/" + ticketGrantingTicket);
        post.setRequestBody(new NameValuePair[] { new NameValuePair("service", serviceUrl) });
        try {
            client.executeMethod(post);
            final int status = post.getStatusCode();
            final String response = post.getResponseBodyAsString();
            if (status == HttpStatus.SC_OK) {
                return response;
            }
            logUnexpectedResponse(status, response);
        } catch (final IOException e) {
            LOGGER.warn(e.getMessage(), e);
        } finally {
            post.releaseConnection();
        }
        return null;
    }

    /**
     * Sends an HTTP DELETE for the given TGT.
     *
     * @param casRestTicketUrl the CAS REST tickets endpoint
     * @param ticketGrantingTicket the TGT id to delete
     * @return the response body of a 200 (OK) reply, or {@code null} otherwise
     * @throws IllegalStateException if {@code ticketGrantingTicket} is {@code null}
     */
    public static String deleteTicketGrantingTicket(final String casRestTicketUrl, final String ticketGrantingTicket) {
        if (ticketGrantingTicket == null) {
            throw new IllegalStateException("Ticket parameter cannot be null when logging it out");
        }
        final HttpClient client = new HttpClient();
        final DeleteMethod delete = new DeleteMethod(casRestTicketUrl + "/" + ticketGrantingTicket);
        try {
            client.executeMethod(delete);
            final int status = delete.getStatusCode();
            final String response = delete.getResponseBodyAsString();
            if (status == HttpStatus.SC_OK) {
                return response;
            }
            logUnexpectedResponse(status, response);
        } catch (final IOException e) {
            LOGGER.warn(e.getMessage(), e);
        } finally {
            delete.releaseConnection();
        }
        return null;
    }

    /** Returns the TGT id found in a "created" response body, or {@code null} if there is none. */
    private static String extractTicketGrantingTicket(final String response) {
        if (response == null) {
            return null;
        }
        final Matcher matcher = TGT_ACTION_PATTERN.matcher(response);
        return matcher.find() ? matcher.group(1) : null;
    }

    /** Logs an unexpected status code followed by the first 1k of the (possibly null) body. */
    private static void logUnexpectedResponse(final int statusCode, final String response) {
        LOGGER.warn(String.format(INVALID_RESPONSE, statusCode));
        logResponseExcerpt(response);
    }

    /** Logs at most the first {@link #ONE_K} characters of a (possibly null) response body. */
    private static void logResponseExcerpt(final String response) {
        // StringUtils.left is null-safe, unlike the substring/length calls it replaces.
        LOGGER.info(String.format(RESPONSE_ONE_K, StringUtils.left(response, ONE_K)));
    }
}
package com.bmchild.security.filter;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.httpclient.HttpStatus;
import org.apache.commons.lang.Validate;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;
import org.jasig.cas.client.util.AbstractCasFilter;
import com.bmchild.data.shared.common.util.RequestUtil;
import com.bmchild.stack.commons.utilities.security.CasAuthenticationUtils;
/**
 * A filter that authenticates a request via HTTP Basic Authorization. <br />
 * It does this by using a REST call to CAS to get a ticket granting ticket.
 * It then gets a service ticket and adds it to the request
 * to be used by CAS's Authentication filter down the chain.
 * Requests without a Basic {@code Authorization} header pass through untouched.
 * @author bchild
 *
 */
public class CasBasicAuthenticationFilter extends AbstractCasFilter {
    private static final Logger LOGGER = Logger.getLogger(CasBasicAuthenticationFilter.class.getName());

    /** Scheme prefix of a Basic {@code Authorization} header; matched case-insensitively (RFC 7235). */
    private static final String BASIC_PREFIX = "Basic ";

    /**
     * The URL to the CAS REST end point.
     */
    private final String casRestTicketUrl;

    /**
     * @param casRestTicketUrl the CAS REST tickets URL; must not be null or empty
     * @throws IllegalArgumentException if {@code casRestTicketUrl} is null or empty
     */
    public CasBasicAuthenticationFilter(String casRestTicketUrl) {
        Validate.notEmpty(casRestTicketUrl, "casRestTicketUrl is required");
        this.casRestTicketUrl = casRestTicketUrl;
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("casRestTicketUrl set to " + casRestTicketUrl);
        }
    }

    /**
     * Exchanges Basic credentials for a CAS service ticket and forwards a request wrapper that
     * exposes the ticket. Bad credentials are answered with 401 and a JSON error body. If no
     * service ticket can be obtained, the original request continues down the chain.
     */
    @Override
    public void doFilter(final ServletRequest req, final ServletResponse res, final FilterChain chain) throws IOException, ServletException {
        final HttpServletRequest request = (HttpServletRequest) req;
        final HttpServletResponse response = (HttpServletResponse) res;
        final String header = request.getHeader("Authorization");
        if (!isBasicAuthorization(header)) {
            chain.doFilter(request, response);
            return;
        }
        final String tgt = CasAuthenticationUtils.getTicketGrantingTicket(casRestTicketUrl, getBase64Token(header));
        if (tgt == null) {
            response.setStatus(HttpStatus.SC_UNAUTHORIZED);
            RequestUtil.sendJsonResponse(response, "error", "Bad Credentials");
            return;
        }
        // NOTE(review): every Basic-authenticated request creates a fresh TGT that is never
        // deleted here; they remain until the CAS server expires them. Confirm this is acceptable.
        final String serviceUrl = getServiceUrl(request);
        final String serviceTicket = CasAuthenticationUtils.getServiceTicket(casRestTicketUrl, serviceUrl, tgt);
        if (serviceTicket == null) {
            chain.doFilter(request, response);
            return;
        }
        final String newQueryString = appendServiceTicket(request.getQueryString(), serviceTicket);
        final HttpServletRequestWrapper wrapper = addServiceTicketToRequest(request, serviceTicket, newQueryString);
        chain.doFilter(wrapper, response);
    }

    /** True when {@code header} starts with the Basic scheme, ignoring case. */
    private static boolean isBasicAuthorization(final String header) {
        return header != null && header.regionMatches(true, 0, BASIC_PREFIX, 0, BASIC_PREFIX.length());
    }

    /** Returns the Base64 credentials following the scheme prefix, without surrounding whitespace. */
    private String getBase64Token(String header) {
        return header.substring(BASIC_PREFIX.length()).trim();
    }

    /**
     * Reconstructs the requested URL (including any query string), which is used as the service
     * the service ticket is requested for.
     *
     * @param request the incoming request
     * @return the request URL plus {@code ?query} when a query string is present
     */
    private String getServiceUrl(HttpServletRequest request) {
        final StringBuilder serviceUrlBuilder = new StringBuilder().append(request.getRequestURL());
        if (StringUtils.isNotEmpty(request.getQueryString())) {
            serviceUrlBuilder.append("?").append(request.getQueryString());
        }
        return serviceUrlBuilder.toString();
    }

    /**
     * The CAS {@link org.jasig.cas.client.authentication.AuthenticationFilter} needs the service
     * ticket to do its thing. This wraps the request so the artifact parameter resolves to the
     * service ticket and the query string includes it.
     * @param request the original request
     * @param serviceTicket the service ticket to expose
     * @param newQueryString the query string including the ticket parameter
     * @return the wrapping request
     */
    private HttpServletRequestWrapper addServiceTicketToRequest(final HttpServletRequest request, final String serviceTicket,
            final String newQueryString) {
        return new HttpServletRequestWrapper(request) {
            @Override
            public String getParameter(String name) {
                if (getArtifactParameterName().equals(name)) {
                    return serviceTicket;
                }
                return super.getParameter(name);
            }

            // Keep getParameterValues consistent with getParameter for the artifact parameter.
            @Override
            public String[] getParameterValues(String name) {
                if (getArtifactParameterName().equals(name)) {
                    return new String[] { serviceTicket };
                }
                return super.getParameterValues(name);
            }

            @Override
            public String getQueryString() {
                return newQueryString;
            }
        };
    }

    /**
     * Adds the service ticket to the query string, using the same artifact parameter name that
     * the request wrapper exposes.
     * @param queryString the original query string; may be null or empty
     * @param serviceTicket the service ticket
     * @return the query string with {@code <artifactParameterName>=<serviceTicket>} appended
     */
    private String appendServiceTicket(String queryString, String serviceTicket) {
        final String ticketParameter = getArtifactParameterName() + "=" + serviceTicket;
        return StringUtils.isEmpty(queryString) ? ticketParameter : queryString + "&" + ticketParameter;
    }
}
<!-- Example of how the filter might be wired into a Spring application context. -->
<!-- CAS filter chain: casBasicAuthenticationFilter must come before casAuthenticationFilter
     (org.springframework.security.cas.web.CasAuthenticationFilter), because it adds the service
     ticket to the request for that filter to use further down the chain. -->
<bean id="springSecurityFilterChain" class="org.springframework.security.web.FilterChainProxy">
<sec:filter-chain-map path-type="ant">
<!-- The next two patterns apply no security filters (filters="none"). -->
<sec:filter-chain pattern="/**/role/user/**" filters="none" />
<sec:filter-chain pattern="/**/nosec/**" filters="none" />
<sec:filter-chain pattern="/" filters="casValidationFilter, wrappingFilter" />
<sec:filter-chain pattern="/j_spring_security_logout" filters="logoutFilter, etf" />
<!-- Catch-all chain: casBasicAuthenticationFilter is listed first so it runs before casAuthenticationFilter. -->
<sec:filter-chain pattern="/**" filters="casBasicAuthenticationFilter, casAuthenticationFilter, casValidationFilter, wrappingFilter, sif, j2eePreAuthFilter, logoutFilter, etf, fsi"/>
</sec:filter-chain-map>
</bean>
<!-- Authenticates requests carrying HTTP Basic credentials (Base64 of "username:password" in the
     Authorization header) by exchanging them for CAS tickets over the REST API. The constructor
     argument is the CAS REST tickets URL; casTicketServiceUrl is presumably a String bean defined
     elsewhere. TODO confirm. -->
<bean id="casBasicAuthenticationFilter" class="com.bmchild.security.filter.CasBasicAuthenticationFilter">
<constructor-arg index="0" ref="casTicketServiceUrl" />
</bean>
package com.bmchild.data.shared.common.util;
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.log4j.Logger;
import org.springframework.security.web.util.ELRequestMatcher;
import org.springframework.security.web.util.RequestMatcher;
/**
 * Utilities related to server requests and responses
 *
 * @author bchild
 *
 */
public final class RequestUtil {
    private static final Logger LOGGER = Logger.getLogger(RequestUtil.class.getName());
    public static final String JSON_VALUE = "{\"%s\": \"%s\"}";

    private RequestUtil() {}

    /**
     * Writes a single-property JSON object {@code {"key": "message"}} to the response, with
     * content type {@code application/json;charset=UTF-8} and {@code Cache-Control: no-cache}.
     * Both strings are JSON-escaped so quotes, backslashes and control characters cannot break
     * (or inject into) the document. A {@code null} argument is rendered as the text "null".
     * Write failures are logged, not rethrown.
     *
     * @param response the response to write to
     * @param key the property name
     * @param message the property value
     */
    public static void sendJsonResponse(HttpServletResponse response, String key, String message) {
        response.setContentType("application/json;charset=UTF-8");
        response.setHeader("Cache-Control", "no-cache");
        try {
            response.getWriter().write(String.format(JSON_VALUE, escapeJson(key), escapeJson(message)));
        } catch (IOException e) {
            LOGGER.error("error writing json to response", e);
        }
    }

    /** Escapes a value for use inside a JSON string literal; null becomes "null" (as String.format did). */
    private static String escapeJson(final String value) {
        if (value == null) {
            return "null";
        }
        final StringBuilder escaped = new StringBuilder(value.length() + 8);
        for (int i = 0; i < value.length(); i++) {
            final char c = value.charAt(i);
            switch (c) {
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                case '\b':
                    escaped.append("\\b");
                    break;
                case '\f':
                    escaped.append("\\f");
                    break;
                default:
                    // JSON forbids raw control characters below U+0020.
                    if (c < 0x20) {
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
                    break;
            }
        }
        return escaped.toString();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment