Skip to content

Instantly share code, notes, and snippets.

@missingfaktor
Forked from alex1712/AWSClientFactory.scala
Created August 7, 2017 15:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save missingfaktor/664b01b7989e91033657e2e703f14ed6 to your computer and use it in GitHub Desktop.
Save missingfaktor/664b01b7989e91033657e2e703f14ed6 to your computer and use it in GitHub Desktop.
AWS-signed HTTP client for elastic4s (5.4.x), built on aws-request-signer (https://github.com/ticofab/aws-request-signer)
import java.time.{LocalDateTime, ZoneId}
import com.amazonaws.auth.{AWSCredentialsProvider, BasicAWSCredentials, InstanceProfileCredentialsProvider}
import com.amazonaws.internal.StaticCredentialsProvider
import com.amazonaws.util.IOUtils
import com.sksamuel.elastic4s.ElasticsearchClientUri
import com.sksamuel.elastic4s.http.{HttpClient, NoOpRequestConfigCallback}
import io.ticofab.AwsSigner
import org.apache.http.client.methods.HttpRequestWrapper
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
import org.apache.http.message.BasicHeader
import org.apache.http.protocol.HttpContext
import org.apache.http.{HttpEntityEnclosingRequest, HttpHost, HttpRequest, HttpRequestInterceptor}
import org.elasticsearch.client.RestClient
import org.elasticsearch.client.RestClientBuilder.HttpClientConfigCallback
/**
 * [[EsClientFactory]] whose clients talk HTTPS to every host in the URI and pass each request
 * through [[SignedClientConfig]], which installs the AWS request-signing interceptor.
 */
class AWSESClientFactory extends EsClientFactory {

  /**
   * Builds a signed elastic4s client.
   *
   * @param elasticsearchUri connection string parsed by `ElasticsearchClientUri`; every
   *                         host/port it lists is contacted over `https`
   * @return an elastic4s [[HttpClient]] wrapping the low-level `RestClient`
   */
  override def getClient(elasticsearchUri: String): HttpClient = {
    val httpsHosts = ElasticsearchClientUri(elasticsearchUri).hosts.map {
      case (hostname, portNumber) => new HttpHost(hostname, portNumber, "https")
    }
    val restClient = RestClient
      .builder(httpsHosts: _*)
      .setRequestConfigCallback(NoOpRequestConfigCallback)
      .setHttpClientConfigCallback(SignedClientConfig)
      .build()
    HttpClient.fromRestClient(restClient)
  }
}
/**
 * REST-client callback that appends [[AWSSigningRequestInterceptor]] to the async HTTP client's
 * interceptor chain.
 *
 * NOTE(review): registered "last", presumably so signing sees headers (such as Host) set by the
 * builder's standard interceptors. Confirm against the HttpAsyncClient version in use.
 */
private object SignedClientConfig extends HttpClientConfigCallback {
  override def customizeHttpClient(httpClientBuilder: HttpAsyncClientBuilder): HttpAsyncClientBuilder =
    httpClientBuilder.addInterceptorLast(AWSSigningRequestInterceptor)
}
/**
 * Apache HTTP request interceptor that signs each outgoing request with `io.ticofab.AwsSigner`.
 *
 * For every request it:
 *  - collects the path, method, query parameters, headers and body;
 *  - asks the signer for the signed header set;
 *  - replaces all of the request's headers with that set.
 */
private object AWSSigningRequestInterceptor extends HttpRequestInterceptor {
  // Placeholders: in real code all these values come from configuration.
  lazy val key = Option("HOWEVER_YOU_GET_YOUR_KEY")
  lazy val region = "HOWEVER_YOU_GET_YOUR_REGION"
  val service = "es"

  // Uses static key/secret credentials when a key is configured. Otherwise it falls back to the
  // instance-profile provider (with the placeholder above, `key` is always defined).
  // NOTE(review): StaticCredentialsProvider comes from the SDK-internal `com.amazonaws.internal`
  // package; consider the public `com.amazonaws.auth.AWSStaticCredentialsProvider` instead.
  val awsCredentialProvider = key.fold[AWSCredentialsProvider]({
    new InstanceProfileCredentialsProvider(false)
  })({ k =>
    val secret = "HOWEVER_YOU_GET_YOUR_SECRET"
    new StaticCredentialsProvider(new BasicAWSCredentials(k, secret))
  })

  val signer = AwsSigner(awsCredentialProvider, region, service, clock)

  /**
   * Signs `request` in place by overwriting all of its headers with the signer's output.
   *
   * Assumes the client hands interceptors an `HttpRequestWrapper`; any other request type
   * fails with a `ClassCastException`.
   */
  override def process(request: HttpRequest, context: HttpContext): Unit = {
    val rw = request.asInstanceOf[HttpRequestWrapper]
    val signedHeaders = signer.getSignedHeaders(
      rw.getURI.getRawPath,
      request.getRequestLine.getMethod,
      params(rw),
      mapHeaders(rw),
      body(rw)
    )
    request.setHeaders(signedHeaders.map { case (name, value) => new BasicHeader(name, value) }.toArray)
  }

  /** Current wall-clock time in UTC, used by the signer for request timestamps. */
  private def clock(): LocalDateTime = LocalDateTime.now(ZoneId.of("UTC"))

  /**
   * Reads the original request's entity bytes so they can be included in the signature.
   *
   * @return `None` when the request cannot carry a body, or when it has no entity or content.
   *         Otherwise returns the full content, and closes the content stream afterwards.
   *
   * NOTE(review): for a non-repeatable entity, reading the content here consumes it before it is
   * sent. The elastic4s REST client is expected to use repeatable entities; verify this if other
   * callers are added.
   */
  private def body(rw: HttpRequestWrapper): Option[Array[Byte]] =
    rw.getOriginal match {
      case enclosing: HttpEntityEnclosingRequest =>
        for {
          entity <- Option(enclosing.getEntity)
          stream <- Option(entity.getContent)
        } yield {
          try IOUtils.toByteArray(stream)
          finally stream.close()
        }
      case _ => None
    }

  /**
   * Parses the (decoded) query string into a name -> value map.
   *
   * Handles three cases that the naive `split("=")` did not:
   *  - empty segments (`?`, `a=1&&b=2`) are skipped;
   *  - a parameter without `=` (e.g. `?pretty`) maps to `""`, matching how AWS canonicalises
   *    valueless parameters;
   *  - only the first `=` splits name from value, so values containing `=` stay intact.
   * Duplicate names keep the last occurrence.
   */
  private def params(rw: HttpRequestWrapper): Map[String, String] =
    Option(rw.getURI.getQuery).fold(Map.empty[String, String]) { query =>
      query
        .split("&")
        .iterator
        .filter(_.nonEmpty)
        .map { pair =>
          pair.indexOf('=') match {
            case -1 => pair -> ""
            case i  => pair.substring(0, i) -> pair.substring(i + 1)
          }
        }
        .toMap
    }

  /**
   * Copies the request headers into a map, where the last value wins for repeated names.
   *
   * Any trailing `:port` is removed from the Host header, because `ElasticsearchClientUri`
   * requires a port but the signed Host header must not include one. The header name is matched
   * case-insensitively, since HTTP header names are case-insensitive. The port pattern is
   * anchored to the end of the value, so IPv6 literals such as `[::1]:443` are not mangled.
   */
  private def mapHeaders(rw: HttpRequestWrapper): Map[String, String] =
    Option(rw.getAllHeaders).fold(Map.empty[String, String]) { headers =>
      headers.map { h =>
        val value =
          if (h.getName.equalsIgnoreCase("Host")) h.getValue.replaceFirst(":[0-9]+$", "")
          else h.getValue
        h.getName -> value
      }.toMap
    }
}
import com.sksamuel.elastic4s.ElasticsearchClientUri
import com.sksamuel.elastic4s.http.HttpClient
/** Abstraction over how an elastic4s [[HttpClient]] is built from a connection URI. */
trait EsClientFactory {

  /**
   * Builds a client for the cluster described by `elasticsearchUri`.
   *
   * @param elasticsearchUri connection string. The implementations visible here parse it with
   *                         `ElasticsearchClientUri`; other implementations may differ, so confirm.
   * @return the elastic4s HTTP client to use
   */
  def getClient(elasticsearchUri: String): HttpClient
}
/** Plain, unsigned [[EsClientFactory]] that delegates to elastic4s' default `HttpClient` builder. */
class BaseEsClientFactory extends EsClientFactory {

  /**
   * @param elasticsearchUri connection string parsed by `ElasticsearchClientUri`
   * @return an unsigned elastic4s [[HttpClient]] for that URI
   */
  override def getClient(elasticsearchUri: String): HttpClient = {
    val parsedUri = ElasticsearchClientUri(elasticsearchUri)
    HttpClient(parsedUri)
  }
}

/** Shared, stateless instance of [[BaseEsClientFactory]]. */
object BaseEsClientFactory extends BaseEsClientFactory
/** Entry point that builds clients with whichever [[EsClientFactory]] the caller supplies implicitly. */
object EsClientProvider {

  /**
   * @param elasticsearchUri connection string, passed unchanged to the factory
   * @param esClientFactory  factory resolved from the caller's implicit scope
   * @return the client produced by `esClientFactory`
   */
  def client(elasticsearchUri: String)(implicit esClientFactory: EsClientFactory): HttpClient = {
    val factory: EsClientFactory = esClientFactory
    factory.getClient(elasticsearchUri)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment