Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Presto - IamAuthenticator
/*
* Copyright (c) 2018 Schibsted Media Group. All rights reserved
*/
package com.saas.presto.access.authentication
import java.security.Principal
import java.util.concurrent.TimeUnit.MILLISECONDS
import com.facebook.presto.spi.security.{AccessDeniedException, BasicPrincipal, PasswordAuthenticator}
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.saas.presto.access.authentication.IamAuthenticator._
import com.saas.presto.helpers.Logging
import com.saas.presto.metrics.TagNames.{Allowed, User}
import com.saas.presto.metrics.{DatadogMetricsClient, DatadogReporter}
import io.airlift.units.Duration
import scala.util.matching.Regex
import scala.util.{Failure, Success, Try}
/**
 * Constants and helpers shared with the `IamAuthenticator` class.
 */
object IamAuthenticator {

  /** Name of this authenticator. */
  def name: String = "saas-iam-authenticator"

  /** Name of the capture group in [[UserArnPattern]] that holds the AWS account id. */
  val AwsAccount: String = "awsAccount"

  /**
   * Matches the `arn:aws:iam::<digits>:user/` part of an IAM user ARN.
   * The digits are captured as the group named by [[AwsAccount]].
   */
  val UserArnPattern: Regex = "arn:aws:iam::(\\d+):user/".r(AwsAccount)

  /**
   * Extracts the AWS account id from an IAM user ARN.
   *
   * @param userArn ARN to inspect; may be `null`
   * @return the account id captured by [[UserArnPattern]], or `None` when `userArn` is `null`
   *         or does not contain the pattern
   */
  def extractAccountFromUserArn(userArn: String): Option[String] =
    // Ask for the first match directly instead of relying on an exception to signal "no match".
    Option(userArn)
      .flatMap(arn => UserArnPattern.findFirstMatchIn(arn))
      .map(_.group(AwsAccount))

  /**
   * User/password pair used as the authentication cache key.
   *
   * Note: the compiler-generated `toString` of a case class prints every field, including
   * `password`, so instances must not be written to logs or error messages.
   *
   * @param user     user name as supplied by the client
   * @param password password as supplied by the client
   */
  case class Credentials(user: String, password: String)
}
/**
 * Presto [[PasswordAuthenticator]] that resolves a user/password pair to an IAM user through
 * `iamClient` and accepts it only when that user belongs to the configured AWS account.
 *
 * Successfully created principals are cached per [[Credentials]] and expire `cacheTtl` after
 * they were written.
 *
 * @param statsd     metrics client used by [[DatadogReporter]]
 * @param cacheTtl   lifetime of a cached principal
 * @param iamClient  client used to look up the IAM user for the supplied credentials
 * @param awsAccount the only AWS account id whose users are allowed
 */
class IamAuthenticator(val statsd: DatadogMetricsClient, cacheTtl: Duration, iamClient: IamClient, awsAccount: String)
    extends PasswordAuthenticator with DatadogReporter with Logging {

  /** Cache of authenticated principals keyed by the credentials that produced them. */
  val authenticationCache: LoadingCache[Credentials, Principal] = CacheBuilder
    .newBuilder()
    .expireAfterWrite(cacheTtl.toMillis, MILLISECONDS)
    .build(new CacheLoader[Credentials, Principal] {
      override def load(key: Credentials): Principal = authenticate(key)
    })

  /**
   * Authenticates the given credentials, going through the cache, and reports the outcome
   * as an "authenticate" metric.
   *
   * @param user     user name supplied by the client
   * @param password password supplied by the client
   * @return the principal produced by [[checkUser]] (possibly a cached one)
   * @throws Throwable the cause of the cache lookup failure (the exception raised while loading
   *                   the entry); when the failure carries no cause, the failure itself
   */
  override def createAuthenticatedPrincipal(user: String, password: String): Principal = {
    Try(authenticationCache.getUnchecked(Credentials(user, password))) match {
      case Success(principal) =>
        report("authenticate", Map(User -> principal.getName, Allowed -> String.valueOf(true)))
        principal
      case Failure(exception) =>
        logger.error(s"Authentication failed for user [$user]: ", exception)
        report("authenticate", Map(User -> user, Allowed -> String.valueOf(false)))
        // The cache wraps loader failures, so unwrap to surface the original exception.
        // Fall back to the failure itself: `throw null` would raise a NullPointerException
        // and hide the real error.
        throw Option(exception.getCause).getOrElse(exception)
    }
  }

  /** Cache loader entry point: delegates to [[checkUser]]. */
  private def authenticate(credentials: Credentials): Principal =
    checkUser(credentials.user, credentials.password)

  /**
   * Looks up the IAM user for the given credentials and verifies that it belongs to `awsAccount`.
   *
   * @param user     user name forwarded to `iamClient`
   * @param password password forwarded to `iamClient`
   * @return a [[BasicPrincipal]] named after the user's ARN with `:user/` turned into `:role/`
   * @throws AccessDeniedException when the lookup fails, the ARN does not match
   *                               [[IamAuthenticator.UserArnPattern]], or the account differs
   */
  def checkUser(user: String, password: String): Principal = {
    Try {
      val userArn = iamClient.requestUser(user, password).getArn
      extractAccountFromUserArn(userArn) match {
        case None =>
          throw new RuntimeException(s"The Aws Keys provided doesn't match with ${UserArnPattern.pattern}")
        case Some(extractedAccount) if extractedAccount != awsAccount =>
          throw new RuntimeException(s"Aws Account [$extractedAccount] not allowed")
        case Some(_) =>
          userArn
      }
    } match {
      case Success(userArn) => new BasicPrincipal(transformUserArnToRole(userArn))
      case Failure(exception) =>
        // Every non-fatal failure is reported to the caller as an access denial.
        throw new AccessDeniedException(s"Authentication failed for user [$user]: ${exception.getMessage}")
    }
  }

  /**
   * Simple replacement of a user ARN with a role ARN carrying the same info. This is needed as the
   * Principal Name in Presto will be the role to use on Authorization for S3.
   *
   * Only the first `:user/` is rewritten so that the same text appearing later in the ARN
   * (presumably possible inside an IAM path; verify against the IAM identifier rules) is kept.
   *
   * @param userArn ARN of the authenticated IAM user
   * @return `userArn` with its first `:user/` replaced by `:role/`
   */
  private def transformUserArnToRole(userArn: String): String =
    // Neither argument contains regex or replacement metacharacters, so this is a literal swap.
    userArn.replaceFirst(":user/", ":role/")
}
/*
* Copyright (c) 2018 Schibsted Media Group. All rights reserved
*/
package com.saas.presto.access.authentication
import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicAWSCredentials}
import com.amazonaws.regions.Regions
import com.amazonaws.services.identitymanagement.AmazonIdentityManagementClientBuilder
import com.amazonaws.services.identitymanagement.model.{GetUserRequest, User}
/**
 * Thin wrapper around the AWS IAM client that resolves the IAM user owning a set of credentials.
 *
 * @param awsRegion region the underlying IAM client is configured with
 */
class IamClient(awsRegion: Regions) {

  /**
   * Resolves the IAM user for a user/password pair by using the pair as static AWS credentials.
   *
   * @param user     value used as the AWS access key
   * @param password value used as the AWS secret key
   * @return the IAM user returned by the `GetUser` call
   */
  def requestUser(user: String, password: String): User =
    requestUser(new AWSStaticCredentialsProvider(new BasicAWSCredentials(user, password)))

  /**
   * Resolves the IAM user for the credentials supplied by `credentialsProvider`.
   *
   * A dedicated client is built for every call because the credentials differ per caller;
   * it is shut down afterwards so the resources it holds are released.
   *
   * @param credentialsProvider source of the credentials used to sign the request
   * @return the IAM user returned by the `GetUser` call
   */
  def requestUser(credentialsProvider: AWSCredentialsProvider): User = {
    val iamClient = AmazonIdentityManagementClientBuilder
      .standard()
      .withCredentials(credentialsProvider)
      .withRegion(awsRegion)
      .build()
    try {
      // No user name is set on the request, so the user is determined from the signing credentials.
      iamClient.getUser(new GetUserRequest()).getUser
    } finally {
      iamClient.shutdown()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.