-
-
Save stevenschlansker/0a1a8e2e6d773efd9a681ffb599eeb6d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static javax.servlet.http.HttpServletResponse.SC_BAD_GATEWAY; | |
import static javax.servlet.http.HttpServletResponse.SC_GATEWAY_TIMEOUT; | |
import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; | |
import static javax.servlet.http.HttpServletResponse.SC_SERVICE_UNAVAILABLE; | |
import java.io.IOException; | |
import java.net.SocketTimeoutException; | |
import java.net.URI; | |
import java.util.Collection; | |
import java.util.Collections; | |
import java.util.HashSet; | |
import java.util.List; | |
import java.util.Locale; | |
import java.util.Map; | |
import java.util.Map.Entry; | |
import java.util.Set; | |
import java.util.concurrent.CompletableFuture; | |
import java.util.concurrent.RejectedExecutionException; | |
import java.util.function.Supplier; | |
import javax.servlet.http.HttpServletRequest; | |
import javax.servlet.http.HttpServletResponse; | |
import com.google.common.base.Joiner; | |
import com.google.common.base.MoreObjects; | |
import com.google.common.collect.ImmutableMap; | |
import com.google.common.collect.Iterators; | |
import com.google.common.net.HttpHeaders; | |
import org.apache.commons.lang3.StringUtils; | |
import org.eclipse.jetty.client.HttpClient; | |
import org.eclipse.jetty.client.api.Request; | |
import org.eclipse.jetty.client.api.Response; | |
import org.eclipse.jetty.proxy.AsyncMiddleManServlet; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; | |
@SuppressFBWarnings("SE_BAD_FIELD") | |
public class RealFrontdoorProxyServlet extends AsyncMiddleManServlet { | |
private static final Logger LOG = LoggerFactory.getLogger(RealFrontdoorProxyServlet.class); | |
private static final long serialVersionUID = 1L; | |
private HttpClient httpClient; | |
public void setHttpClient(HttpClient httpClient) { | |
this.httpClient = httpClient; | |
} | |
@Override | |
protected HttpClient getHttpClient() { | |
return httpClient; | |
} | |
@SuppressWarnings("unchecked") | |
@Override | |
protected String rewriteTarget(HttpServletRequest clientRequest) { | |
final String nextUri = ((Supplier<URI>) clientRequest.getAttribute("UriSupplier")).get().toString(); | |
clientRequest.setAttribute(HttpClientConnectionProxy.DESTINATION_URI_KEY, nextUri); | |
return nextUri; | |
} | |
@Override | |
protected void addProxyHeaders(HttpServletRequest clientRequest, Request proxyRequest) { | |
final ProxyAction action = getAction(clientRequest); | |
action.getRequestHeaders().forEach(proxyRequest::header); | |
final String host; | |
if (action.getHost().isPresent()) { | |
host = action.getHost().get(); | |
} else { | |
host = clientRequest.getHeader(HttpHeaders.HOST); | |
} | |
if (host != null) { | |
proxyRequest.header(HttpHeaders.HOST, host); | |
} | |
addXForwardedHeaders(clientRequest, proxyRequest); | |
} | |
@Override | |
protected Set<String> findConnectionHeaders(HttpServletRequest clientRequest) { | |
final Set<String> removed = new HashSet<>(); | |
final Set<String> superConnHeaders = super.findConnectionHeaders(clientRequest); | |
if (superConnHeaders != null) { | |
removed.addAll(superConnHeaders); | |
} | |
final Set<String> ignorePrefix = new HashSet<String>((Collection<String>) clientRequest.getAttribute("IgnoreHeaderPrefix")); | |
ignorePrefix.add("X-Forwarded-"); | |
Iterators.forEnumeration(clientRequest.getHeaderNames()) | |
.forEachRemaining(h -> { | |
for (String prefix : ignorePrefix) { | |
if (StringUtils.startsWithIgnoreCase(h, prefix)) { | |
removed.add(h.toLowerCase(Locale.ROOT)); | |
} | |
} | |
}); | |
return removed; | |
} | |
@Override | |
protected void addXForwardedHeaders(HttpServletRequest clientRequest, Request proxyRequest) { | |
final List<String> forwardFor = Collections.list(clientRequest.getHeaders(HttpHeaders.X_FORWARDED_FOR)); | |
forwardFor.add(clientRequest.getRemoteAddr()); | |
proxyRequest.header(HttpHeaders.X_FORWARDED_FOR, Joiner.on(", ").join(forwardFor)); | |
proxyRequest.header(HttpHeaders.X_FORWARDED_PROTO, MoreObjects.firstNonNull( | |
clientRequest.getHeader(HttpHeaders.X_FORWARDED_PROTO), | |
clientRequest.isSecure() ? "https" : "http")); | |
} | |
@Override | |
protected void onServerResponseHeaders(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) { | |
super.onServerResponseHeaders(clientRequest, proxyResponse, serverResponse); | |
getAction(clientRequest).getResponseHeaders().forEach(proxyResponse::addHeader); | |
} | |
@Override | |
protected void onClientRequestFailure(HttpServletRequest clientRequest, Request proxyRequest, HttpServletResponse proxyResponse, Throwable failure) { | |
LOG.info("client request fail: {}", clientRequest, failure); | |
sendError(proxyResponse, failure); | |
finished(clientRequest).complete(true); | |
} | |
@SuppressWarnings("unchecked") | |
private CompletableFuture<Boolean> finished(HttpServletRequest clientRequest) { | |
return (CompletableFuture<Boolean>) clientRequest.getAttribute("Finished"); | |
} | |
private ProxyAction getAction(HttpServletRequest clientRequest) { | |
return (ProxyAction) clientRequest.getAttribute("ProxyAction"); | |
} | |
@Override | |
protected void onProxyResponseFailure(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse, Throwable failure) { | |
LOG.info("proxy response fail: {}", serverResponse, failure); | |
sendError(proxyResponse, failure); | |
finished(clientRequest).complete(true); | |
} | |
@Override | |
protected void onProxyResponseSuccess(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) { | |
LOG.info("proxy response success! {}", serverResponse); | |
finished(clientRequest).complete(true); | |
} | |
private void sendError(HttpServletResponse proxyResponse, Throwable failure) { | |
int status = SC_INTERNAL_SERVER_ERROR; | |
for (Entry<Class<? extends Throwable>, Integer> e : statusByCause.entrySet()) { | |
if (e.getKey().isInstance(failure)) { | |
status = e.getValue(); | |
} | |
} | |
try { | |
proxyResponse.sendError(status); | |
} catch (IOException e1) { | |
LOG.error("while sending error", e1); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment