Skip to content

Instantly share code, notes, and snippets.

@noboomu
Created August 23, 2019 05:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noboomu/097c2dae7e154e60c310aa19aeab3ab3 to your computer and use it in GitHub Desktop.
Save noboomu/097c2dae7e154e60c310aa19aeab3ab3 to your computer and use it in GitHub Desktop.
Provides an HttpRequestInterceptor that signs Elasticsearch RestClient requests using aws-sdk-java-v2
import com.google.inject.Inject;
import com.google.inject.name.Named;
import org.apache.http.*;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.entity.BufferedHttpEntity;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthRequest;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.common.Priority;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

import javax.inject.Singleton;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Clock;
import java.util.*;

import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;
/**
 * Apache {@link HttpRequestInterceptor} that signs every outgoing request with AWS Signature
 * Version 4, using the AWS SDK for Java v2 {@link Aws4Signer}.
 *
 * <p>For each request the Apache request is copied into an {@link SdkHttpFullRequest}, that copy
 * is signed, and the header set returned by the signer replaces the Apache request's headers.
 */
public class AwsRequestSigningInterceptor implements HttpRequestInterceptor {

    /** Body supplier used for requests that carry no entity. */
    private static final ContentStreamProvider EMPTY_BODY =
            () -> new ByteArrayInputStream(new byte[0]);

    private final Region region;
    private final String service;
    private final Aws4Signer signer;
    private final StaticCredentialsProvider awsCredentialsProvider;
    private final Clock signingOverrideClock = Clock.systemDefaultZone();

    /**
     * @param service                signing name of the service that we're connecting to
     * @param region                 region the requests are signed for
     * @param signer                 particular signer implementation
     * @param awsCredentialsProvider source of AWS credentials; resolved again for every request
     * @throws NullPointerException if any argument is {@code null}
     */
    public AwsRequestSigningInterceptor(final String service, final Region region,
                                        final Aws4Signer signer,
                                        final StaticCredentialsProvider awsCredentialsProvider) {
        this.service = Objects.requireNonNull(service, "service");
        this.region = Objects.requireNonNull(region, "region");
        this.signer = Objects.requireNonNull(signer, "signer");
        this.awsCredentialsProvider =
                Objects.requireNonNull(awsCredentialsProvider, "awsCredentialsProvider");
    }

    /**
     * Signs {@code request} in place: its headers are replaced by the signed header set.
     *
     * <p>A non-repeatable request entity is replaced by a buffered, repeatable copy before signing,
     * so that reading the body for signing does not leave nothing to send.
     *
     * @throws IOException if the request URI is malformed or the entity cannot be buffered
     */
    @Override
    public void process(final HttpRequest request, final HttpContext context)
            throws HttpException, IOException {
        final URIBuilder uriBuilder;
        final String rawPath;
        try {
            uriBuilder = new URIBuilder(request.getRequestLine().getUri());
            rawPath = uriBuilder.build().getRawPath();
        } catch (URISyntaxException e) {
            throw new IOException("Invalid URI", e);
        }

        // Copy the Apache request into an SDK request the signer understands.
        // uri(...) must come first: it sets scheme/host/port, and the path/query set below refine it.
        final SdkHttpFullRequest.Builder sdkRequestBuilder = SdkHttpFullRequest.builder();
        final HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST);
        if (host != null) {
            sdkRequestBuilder.uri(URI.create(host.toURI()));
        }
        sdkRequestBuilder.method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod()));
        sdkRequestBuilder.encodedPath(rawPath);
        sdkRequestBuilder.rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams()));
        for (Header header : request.getAllHeaders()) {
            sdkRequestBuilder.appendHeader(header.getName(), header.getValue());
        }
        sdkRequestBuilder.contentStreamProvider(contentStreamProvider(request));

        // Sign it.
        final Aws4SignerParams signerParams = Aws4SignerParams.builder()
                .signingRegion(region)
                .signingName(service)
                .signingClockOverride(signingOverrideClock)
                .awsCredentials(awsCredentialsProvider.resolveCredentials())
                .build();
        final SdkHttpFullRequest signedRequest = signer.sign(sdkRequestBuilder.build(), signerParams);

        // Write the signer's output headers back onto the Apache request.
        request.setHeaders(mapToHeaderArray(signedRequest.headers()));
    }

    /**
     * Returns the body supplier handed to the signer for {@code request}.
     *
     * <p>The signer may read the body (to hash the payload). A non-repeatable entity would be
     * drained by that read, so it is first swapped for a {@link BufferedHttpEntity}. That type
     * keeps the content in memory and can be read repeatedly. Repeatable entities are left as-is.
     *
     * @param request the Apache request being signed; its entity may be replaced
     * @return a provider of fresh body streams, or an empty body when there is no entity
     * @throws IOException if a non-repeatable entity cannot be buffered
     */
    private static ContentStreamProvider contentStreamProvider(final HttpRequest request)
            throws IOException {
        if (!(request instanceof HttpEntityEnclosingRequest)) {
            return EMPTY_BODY;
        }
        final HttpEntityEnclosingRequest enclosingRequest = (HttpEntityEnclosingRequest) request;
        HttpEntity entity = enclosingRequest.getEntity();
        if (entity == null) {
            return EMPTY_BODY;
        }
        if (!entity.isRepeatable()) {
            entity = new BufferedHttpEntity(entity);
            enclosingRequest.setEntity(entity);
        }
        final HttpEntity body = entity;
        return () -> {
            try {
                return body.getContent();
            } catch (IOException e) {
                // ContentStreamProvider cannot throw checked exceptions; fail loudly instead of
                // handing the signer a null stream.
                throw new UncheckedIOException("Error getting content stream", e);
            }
        };
    }

    /**
     * Groups query parameters by name. Names are case-sensitive, as SigV4 requires.
     *
     * @param params list of HTTP query params as NameValuePairs
     * @return a multimap of HTTP query params, preserving the original value order per name
     */
    private static Map<String, List<String>> nvpToMapParams(final List<NameValuePair> params) {
        final Map<String, List<String>> parameterMap = new LinkedHashMap<>();
        for (NameValuePair nvp : params) {
            parameterMap.computeIfAbsent(nvp.getName(), k -> new ArrayList<>()).add(nvp.getValue());
        }
        return parameterMap;
    }

    /**
     * Flattens a header multimap into Apache headers, emitting one header per value.
     *
     * @param mapHeaders Map of header entries
     * @return modeled Header objects
     */
    private static Header[] mapToHeaderArray(final Map<String, List<String>> mapHeaders) {
        final List<Header> headers = new ArrayList<>();
        for (Map.Entry<String, List<String>> headerEntry : mapHeaders.entrySet()) {
            for (String value : headerEntry.getValue()) {
                headers.add(new BasicHeader(headerEntry.getKey(), value));
            }
        }
        return headers.toArray(new Header[0]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment