// AWS-signed HTTP client for elastic4s (5.4.x) using aws-request-signer
// (https://github.com/ticofab/aws-request-signer).
// Source: GitHub gist missingfaktor/664b01b7989e91033657e2e703f14ed6.
// Note: GitHub flags this file as containing bidirectional Unicode text that may be
// interpreted or compiled differently than it appears; review it in an editor that
// reveals hidden Unicode characters.
import java.time.{LocalDateTime, ZoneId}

import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicAWSCredentials, InstanceProfileCredentialsProvider}
import com.amazonaws.internal.StaticCredentialsProvider
import com.amazonaws.util.IOUtils
import com.sksamuel.elastic4s.ElasticsearchClientUri
import com.sksamuel.elastic4s.http.{HttpClient, NoOpRequestConfigCallback}
import io.ticofab.AwsSigner
import org.apache.http.client.methods.HttpRequestWrapper
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
import org.apache.http.message.BasicHeader
import org.apache.http.protocol.HttpContext
import org.apache.http.{HttpEntityEnclosingRequest, HttpHost, HttpRequest, HttpRequestInterceptor}
import org.elasticsearch.client.RestClient
import org.elasticsearch.client.RestClientBuilder.HttpClientConfigCallback
/**
 * [[EsClientFactory]] for an AWS Elasticsearch Service domain.
 *
 * Every host listed in the URI is contacted over HTTPS, and every request goes through
 * [[SignedClientConfig]], which installs the AWS request-signing interceptor.
 */
class AWSESClientFactory extends EsClientFactory {

  /**
   * Builds a signed elastic4s client.
   *
   * @param elasticsearchUri elastic4s client URI (parsed with `ElasticsearchClientUri`)
   * @return an [[HttpClient]] wrapping a low-level `RestClient` configured for signing
   */
  override def getClient(elasticsearchUri: String): HttpClient = {
    val parsedUri = ElasticsearchClientUri(elasticsearchUri)
    // AWS domains are reached over TLS, so the scheme is fixed to https for every host.
    val httpsHosts = for ((host, port) <- parsedUri.hosts) yield new HttpHost(host, port, "https")
    val restClient = RestClient
      .builder(httpsHosts: _*)
      .setHttpClientConfigCallback(SignedClientConfig)
      .setRequestConfigCallback(NoOpRequestConfigCallback)
      .build()
    HttpClient.fromRestClient(restClient)
  }
}
/**
 * Customises the async HTTP client that the REST client builds by appending
 * [[AWSSigningRequestInterceptor]] as its last request interceptor, so signing runs after
 * any interceptors registered earlier.
 */
private object SignedClientConfig extends HttpClientConfigCallback {
  override def customizeHttpClient(httpClientBuilder: HttpAsyncClientBuilder): HttpAsyncClientBuilder =
    httpClientBuilder.addInterceptorLast(AWSSigningRequestInterceptor)
}
/**
 * Apache HTTP request interceptor that signs each outgoing request for the AWS
 * Elasticsearch Service (`es`) with aws-request-signer.
 *
 * The request's raw path, method, query parameters, headers and body are handed to the
 * signer. The request's headers are then replaced by the signed header set it returns.
 */
private object AWSSigningRequestInterceptor extends HttpRequestInterceptor {
  // Placeholder values: in real code all of these come from configuration.
  lazy val key: Option[String] = Option("HOWEVER_YOU_GET_YOUR_KEY")
  lazy val region: String = "HOWEVER_YOU_GET_YOUR_REGION"
  val service: String = "es"

  /** Static key/secret credentials when a key is configured, otherwise the EC2 instance profile. */
  val awsCredentialProvider: AWSCredentialsProvider =
    key.fold[AWSCredentialsProvider](new InstanceProfileCredentialsProvider(false)) { k =>
      val secret = "HOWEVER_YOU_GET_YOUR_SECRET"
      // AWSStaticCredentialsProvider is the supported replacement for the deprecated
      // com.amazonaws.internal.StaticCredentialsProvider.
      new AWSStaticCredentialsProvider(new BasicAWSCredentials(k, secret))
    }

  /** Signer for `service` in `region`; `clock` is supplied as its time source. */
  val signer = AwsSigner(awsCredentialProvider, region, service, clock)

  /**
   * Replaces the request's headers with the AWS-signed header set.
   *
   * @throws IllegalArgumentException if the request is not an [[HttpRequestWrapper]]; the
   *         path, query and original entity are all read through the wrapper
   */
  override def process(request: HttpRequest, context: HttpContext): Unit = {
    val rw = request match {
      case wrapper: HttpRequestWrapper => wrapper
      case other =>
        throw new IllegalArgumentException(
          s"Cannot sign request of type ${other.getClass.getName}; expected an HttpRequestWrapper")
    }
    val signedHeaders = signer.getSignedHeaders(
      rw.getURI.getRawPath,
      rw.getRequestLine.getMethod,
      params(rw),
      mapHeaders(rw),
      body(rw)
    )
    request.setHeaders(signedHeaders.map { case (name, value) => new BasicHeader(name, value) }.toArray)
  }

  /** Current time in UTC; passed to the signer as its clock. */
  private def clock(): LocalDateTime = LocalDateTime.now(ZoneId.of("UTC"))

  /**
   * Bytes of the original request's entity, or `None` when the request carries no entity.
   *
   * The content stream is closed after reading, so stream-backed entities (e.g. files) do not leak.
   * NOTE(review): reading the content assumes the entity is repeatable. A non-repeatable entity
   * would be consumed here before it is sent. Presumably elastic4s sends in-memory entities — confirm.
   */
  private def body(rw: HttpRequestWrapper): Option[Array[Byte]] =
    rw.getOriginal match {
      case enclosing: HttpEntityEnclosingRequest =>
        for {
          entity  <- Option(enclosing.getEntity)
          content <- Option(entity.getContent)
        } yield {
          try IOUtils.toByteArray(content) finally content.close()
        }
      case _ => None
    }

  /**
   * Query parameters as a name -> value map.
   *
   * Empty segments (`?`, `a=1&&b=2`) are skipped. A parameter without `=` (e.g. `?pretty`) maps to
   * an empty value. Only the first `=` separates name from value, so values may contain `=`.
   */
  private def params(rw: HttpRequestWrapper): Map[String, String] =
    Option(rw.getURI.getQuery).fold(Map.empty[String, String]) { query =>
      query
        .split("&")
        .iterator
        .filter(_.nonEmpty)
        .map { pair =>
          val parts = pair.split("=", 2)
          parts(0) -> (if (parts.length > 1) parts(1) else "")
        }
        .toMap
    }

  /**
   * Request headers as a name -> value map (for duplicate names, the last one wins).
   *
   * ElasticsearchClientUri requires a port, but the `Host` header signed for AWS must not include
   * one. The trailing `:port` is therefore stripped from `Host`. The header name is matched
   * case-insensitively, and only a port at the end of the value is removed, so IPv6 literals
   * such as `[::1]:9200` stay intact.
   */
  private def mapHeaders(rw: HttpRequestWrapper): Map[String, String] =
    Option(rw.getAllHeaders).fold(Map.empty[String, String]) { headers =>
      headers.map { h =>
        val value =
          if (h.getName.equalsIgnoreCase("Host")) h.getValue.replaceFirst(":[0-9]+$", "")
          else h.getValue
        h.getName -> value
      }.toMap
    }
}
// ---- Second file of the gist: EsClientFactory and its implementations ----
// Note: GitHub flags this file as containing bidirectional Unicode text that may be
// interpreted or compiled differently than it appears; review it in an editor that
// reveals hidden Unicode characters.
import com.sksamuel.elastic4s.ElasticsearchClientUri | |
import com.sksamuel.elastic4s.http.HttpClient | |
/** Strategy for creating elastic4s [[HttpClient]]s from an Elasticsearch URI string. */
trait EsClientFactory {

  /**
   * Creates a client for the given URI.
   *
   * @param elasticsearchUri elastic4s client URI
   * @return an elastic4s HTTP client for that URI
   */
  def getClient(elasticsearchUri: String): HttpClient
}
/** [[EsClientFactory]] that builds a plain elastic4s client via `HttpClient(uri)`, with no request signing. */
class BaseEsClientFactory extends EsClientFactory {
  override def getClient(elasticsearchUri: String): HttpClient = {
    val parsedUri = ElasticsearchClientUri(elasticsearchUri)
    HttpClient(parsedUri)
  }
}

/** Shared default instance of [[BaseEsClientFactory]]. */
object BaseEsClientFactory extends BaseEsClientFactory
/** Creates clients through whichever [[EsClientFactory]] the caller supplies implicitly. */
object EsClientProvider {

  /**
   * @param elasticsearchUri elastic4s client URI
   * @param esClientFactory  factory that actually creates the client
   * @return the client produced by `esClientFactory` for `elasticsearchUri`
   */
  def client(elasticsearchUri: String)(implicit esClientFactory: EsClientFactory): HttpClient = {
    val factory: EsClientFactory = esClientFactory
    factory.getClient(elasticsearchUri)
  }
}
// (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.")