Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Spring Framework servlet request parsing bug
/**
 * Prepare a builder by copying the scheme, host, port, path, and
 * query string of an HttpServletRequest.
 * <p>Scheme, host, and port are resolved via {@link #scheme}, {@link #host},
 * and {@link #port}. Those methods consult the {@code X-Forwarded-*} headers
 * when an {@code X-Forwarded-Host} header is present, and otherwise use the
 * servlet container's values. The port is only set on the builder when it is
 * not the default port for the resolved scheme (see {@link #isDefaultSchemePort}).
 *
 * @param request the current request
 * @return a builder pre-populated from {@code request}
 */
public static CorrectedServletUriComponentsBuilder fromRequest(HttpServletRequest request) {
    String scheme = scheme(request);
    String host = host(request);
    int port = port(request);
    CorrectedServletUriComponentsBuilder builder = new CorrectedServletUriComponentsBuilder();
    builder.scheme(scheme);
    builder.host(host);
    if (!isDefaultSchemePort(scheme, port)) {
        // Reuse the port resolved above rather than re-parsing the request headers.
        builder.port(port);
    }
    builder.pathFromRequest(request);
    builder.query(request.getQueryString());
    return builder;
}
// Well-known schemes and their default ports, grouped per scheme.
private static final String HTTP_SCHEME = "http";
private static final int HTTP_PORT_DEFAULT = 80;
private static final String HTTPS_SCHEME = "https";
private static final int HTTPS_PORT_DEFAULT = 443;
/**
 * Returns whether {@code port} is the default port for {@code scheme}:
 * 80 for http or 443 for https. The scheme is compared case-insensitively.
 *
 * @param scheme the URI scheme; a {@code null} scheme yields {@code false}
 * @param port the port number to test
 * @return {@code true} if {@code port} is the default for {@code scheme}
 */
static boolean isDefaultSchemePort(String scheme, int port) {
    // Constant-first comparison so a null scheme returns false instead of throwing.
    if (port == HTTP_PORT_DEFAULT) {
        return HTTP_SCHEME.equalsIgnoreCase(scheme);
    }
    if (port == HTTPS_PORT_DEFAULT) {
        return HTTPS_SCHEME.equalsIgnoreCase(scheme);
    }
    return false;
}
/**
 * Returns the default port for {@code scheme}: 80 for http or 443 for https,
 * with the scheme compared case-insensitively.
 *
 * @param scheme the URI scheme; may be {@code null}
 * @return the default port, or {@code -1} if the scheme is {@code null} or not http/https
 */
static int defaultSchemePort(String scheme) {
    // Constant-first comparison so a null scheme returns -1 instead of throwing.
    if (HTTP_SCHEME.equalsIgnoreCase(scheme)) {
        return HTTP_PORT_DEFAULT;
    }
    if (HTTPS_SCHEME.equalsIgnoreCase(scheme)) {
        return HTTPS_PORT_DEFAULT;
    }
    return -1;
}
/**
 * Reports whether the request should be treated as forwarded by a proxy. This
 * is the case when its {@code X-Forwarded-Host} header has text, as judged by
 * {@code StringUtils.hasText}.
 */
static boolean isForwarded(HttpServletRequest request) {
    return StringUtils.hasText(request.getHeader("X-Forwarded-Host"));
}
/**
 * Resolves the scheme of the original client request.
 * <p>For a forwarded request (see {@link #isForwarded}) with a non-empty
 * {@code X-Forwarded-Proto} header, the first comma-separated entry of that
 * header is used, trimmed. Otherwise the servlet container's scheme is returned.
 *
 * @param request the current request
 * @return the resolved scheme
 */
static String scheme(HttpServletRequest request) {
    if (isForwarded(request)) {
        String protocolHeader = request.getHeader("X-Forwarded-Proto");
        if (StringUtils.hasText(protocolHeader)) {
            // A chain of proxies may produce a list such as "https, http". Use the
            // first entry, just as host() uses the first X-Forwarded-Host entry.
            String protocol = StringUtils.commaDelimitedListToStringArray(protocolHeader)[0].trim();
            if (!protocol.isEmpty()) {
                return protocol;
            }
        }
    }
    return request.getScheme();
}
/**
 * Resolves the host name of the original client request.
 * <p>For a forwarded request (see {@link #isForwarded}), the first
 * comma-separated entry of {@code X-Forwarded-Host} is used, trimmed, with any
 * {@code :port} suffix removed. A colon only counts as a port separator when
 * it appears after the closing bracket of an IPv6 literal. For example,
 * {@code [::1]:8080} yields {@code [::1]}. Otherwise the servlet container's
 * server name is returned.
 *
 * @param request the current request
 * @return the resolved host
 */
static String host(HttpServletRequest request) {
    if (isForwarded(request)) {
        String hostHeader = request.getHeader("X-Forwarded-Host");
        String host = StringUtils.commaDelimitedListToStringArray(hostHeader)[0].trim();
        int portSeparator = host.lastIndexOf(':');
        // In "[::1]" the last ':' sits inside the brackets, so it is not a port separator.
        if (portSeparator > host.lastIndexOf(']')) {
            return host.substring(0, portSeparator);
        }
        return host;
    }
    // The request was not forwarded.
    return request.getServerName();
}
/**
 * Resolves the port of the original client request.
 * <p>For a forwarded request (see {@link #isForwarded}), the sources are tried
 * in this order:
 * <ol>
 * <li>the first comma-separated entry of {@code X-Forwarded-Port}, if that
 * header has text;</li>
 * <li>otherwise the {@code :port} suffix of the first {@code X-Forwarded-Host}
 * entry, where a colon inside IPv6 brackets does not count;</li>
 * <li>otherwise {@code -1}, meaning no explicit port.</li>
 * </ol>
 * For a direct request, the container's server port is used. If the container
 * reports {@code -1}, the default port of the resolved scheme is used instead.
 *
 * @param request the current request
 * @return the resolved port, or {@code -1} if none could be determined
 * @throws NumberFormatException if a forwarded port value is not an integer
 */
static int port(HttpServletRequest request) {
    if (isForwarded(request)) {
        String portHeader = request.getHeader("X-Forwarded-Port");
        if (StringUtils.hasText(portHeader)) {
            // Multiple proxies may produce "443, 80". Use the first entry, as for the host.
            String firstPort = StringUtils.commaDelimitedListToStringArray(portHeader)[0].trim();
            return Integer.parseInt(firstPort);
        }
        String hostHeader = request.getHeader("X-Forwarded-Host");
        String host = StringUtils.commaDelimitedListToStringArray(hostHeader)[0].trim();
        int portSeparator = host.lastIndexOf(':');
        // Only a colon after an IPv6 literal's closing ']' separates a port. An empty
        // suffix (e.g. "host:") means no explicit port, consistent with host().
        if (portSeparator > host.lastIndexOf(']') && portSeparator < host.length() - 1) {
            return Integer.parseInt(host.substring(portSeparator + 1));
        }
        return -1;
    }
    int port = request.getServerPort();
    if (port == -1) {
        return defaultSchemePort(scheme(request));
    }
    return port;
}
@Test
public void fromRequestWithForwardedHostWithDefaultPort() {
    // The container listens on a non-default port, but the forwarded host carries
    // no port, so the built URI is expected to have no explicit port.
    request.addHeader("X-Forwarded-Host", "webtest.foo.bar.com");
    request.setServerPort(10080);
    request.setQueryString("foo=123");
    request.setRequestURI("/mvc-showcase/data/param");

    CorrectedServletUriComponentsBuilder builder = CorrectedServletUriComponentsBuilder.fromRequest(request);
    UriComponents components = builder.build();

    assertEquals("webtest.foo.bar.com", components.getHost());
    assertEquals("Should have used the default port of the forwarded request.",
            -1, components.getPort());
}
@Test
public void fromRequestWithForwardedHostWithForwardedScheme() {
    // The container sees plain http on port 10080. The proxy reports https for a host
    // without a port, so the URI should use https and carry no explicit port.
    request.addHeader("X-Forwarded-Host", "webtest.foo.bar.com");
    request.addHeader("X-Forwarded-Proto", "https");
    request.setScheme("http");
    request.setServerPort(10080);
    request.setQueryString("foo=123");
    request.setRequestURI("/mvc-showcase/data/param");

    CorrectedServletUriComponentsBuilder builder = CorrectedServletUriComponentsBuilder.fromRequest(request);
    UriComponents components = builder.build();

    assertEquals("webtest.foo.bar.com", components.getHost());
    assertEquals("Should have derived scheme from header.",
            "https", components.getScheme());
    assertEquals("Should have used the default port of the forwarded request.",
            -1, components.getPort());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment