Skip to content

Instantly share code, notes, and snippets.

@tanopwan
Created December 2, 2020 12:23
Show Gist options
  • Save tanopwan/b756e15c80a7b4524d741effb98dce1d to your computer and use it in GitHub Desktop.
Save tanopwan/b756e15c80a7b4524d741effb98dce1d to your computer and use it in GitHub Desktop.
/**
 * A [KeyProvider] that registers named, versioned master-key entries in memory and
 * delegates data-key generation and decryption to AWS KMS through [kmsClient].
 *
 * Key entries are registered with [addKey]; the highest version registered for a name
 * is treated as that key's current version.
 *
 * NOTE(review): the backing maps are plain, unsynchronized collections — confirm that
 * instances are not mutated concurrently.
 *
 * @param kmsClient client used for the `generateDataKey` and `decrypt` calls.
 * @param awsKmsKeyId KMS key id sent with every `GenerateDataKeyRequest`.
 * @param awsKmsKeySpec key spec sent with every `GenerateDataKeyRequest`.
 */
class AwsKeystore(
    private val kmsClient: AWSKMS,
    private val awsKmsKeyId: String,
    private val awsKmsKeySpec: String
) : KeyProvider {

    /** Registered key entries, indexed by the `name@version` string from [buildVersionName]. */
    private val keys: TreeMap<String, KeyVersion> = TreeMap()

    /** Latest registered version per key name; insertion-ordered. */
    private val currentVersion: MutableMap<String, Int> = mutableMapOf()

    /** Returns a fresh, caller-owned list of every registered key name. */
    override fun getKeyNames(): MutableList<String> = currentVersion.keys.toMutableList()

    /**
     * Returns the metadata of the latest registered version of [keyName].
     *
     * @throws IllegalArgumentException if no key with that name has been registered.
     */
    override fun getCurrentKeyVersion(keyName: String): HadoopShims.KeyMetadata {
        val version = currentVersion[keyName]
            ?: throw IllegalArgumentException("Unknown key $keyName")
        return keys[buildVersionName(keyName, version)]
            ?: throw IllegalArgumentException("Unknown key $keyName")
    }

    /**
     * Requests a new data key from KMS and wraps it as a [LocalKey] that carries both
     * the plaintext and the KMS-encrypted form of the key.
     *
     * @param key metadata identifying a previously registered key name and version.
     * @throws IllegalArgumentException if [key] has not been registered.
     * @throws IllegalStateException if the KMS call or the key construction fails; the
     *   original exception is attached as the cause.
     */
    override fun createLocalKey(key: HadoopShims.KeyMetadata): LocalKey {
        val secret = keys[buildVersionName(key.keyName, key.version)]
            ?: throw IllegalArgumentException("Unknown key $key")
        val algorithm = secret.algorithm
        return try {
            val request = GenerateDataKeyRequest().apply {
                keyId = awsKmsKeyId
                keySpec = awsKmsKeySpec
            }
            val result: GenerateDataKeyResult = kmsClient.generateDataKey(request)
            val decryptedKey = drainToByteArray(result.plaintext)
            val encryptedKey = drainToByteArray(result.ciphertextBlob)
            LocalKey(algorithm, decryptedKey, encryptedKey)
        } catch (ex: Exception) {
            // Keep the cause so the underlying KMS failure stays visible in stack traces.
            throw IllegalStateException("GenerateDataKeyRequest error: ${ex.message}", ex)
        }
    }

    /**
     * Asks KMS to decrypt [encryptedKey] and returns the result as a [SecretKeySpec]
     * for the algorithm of the registered key entry.
     *
     * @param key metadata identifying a previously registered key name and version.
     * @param encryptedKey ciphertext blob of the data key.
     * @return the decrypted key, or `null` when [key] has not been registered.
     * @throws IllegalStateException if the KMS call fails; the original exception is
     *   attached as the cause.
     */
    override fun decryptLocalKey(key: HadoopShims.KeyMetadata, encryptedKey: ByteArray): Key? {
        val secret = keys[buildVersionName(key.keyName, key.version)] ?: return null
        val algorithm = secret.algorithm
        return try {
            val request = DecryptRequest().withCiphertextBlob(ByteBuffer.wrap(encryptedKey))
            val decryptedKey = drainToByteArray(kmsClient.decrypt(request).plaintext)
            SecretKeySpec(decryptedKey, algorithm.algorithm)
        } catch (ex: Exception) {
            // Keep the cause so the underlying KMS failure stays visible in stack traces.
            throw IllegalStateException("DecryptRequest error: ${ex.message}", ex)
        }
    }

    /** Identifies this provider as the AWS kind. */
    override fun getKind(): HadoopShims.KeyProviderKind = HadoopShims.KeyProviderKind.AWS

    /** Builds the `name@version` string used to index [keys]. */
    private fun buildVersionName(name: String, version: Int): String = "$name@$version"

    /** Copies the remaining bytes of [buffer] into a new array, advancing the buffer's position. */
    private fun drainToByteArray(buffer: ByteBuffer): ByteArray =
        ByteArray(buffer.remaining()).also { buffer.get(it) }

    /** Key metadata entry stored by this keystore. */
    internal class KeyVersion(keyName: String, version: Int, algorithm: EncryptionAlgorithm) :
        HadoopShims.KeyMetadata(keyName, version, algorithm)

    /**
     * Registers [keyName] at version 0.
     *
     * @return this keystore, so calls can be chained.
     * @throws IllegalArgumentException if any version of [keyName] is already registered.
     */
    @Throws(IOException::class)
    fun addKey(keyName: String, algorithm: EncryptionAlgorithm): AwsKeystore {
        return addKey(keyName, 0, algorithm)
    }

    /**
     * Registers [keyName] at [version] and makes it the current version of that key.
     *
     * @return this keystore, so calls can be chained.
     * @throws IllegalArgumentException if a version greater than or equal to [version]
     *   is already registered for [keyName].
     */
    @Throws(IOException::class)
    fun addKey(keyName: String, version: Int, algorithm: EncryptionAlgorithm): AwsKeystore {
        val existing = currentVersion[keyName]
        require(existing == null || existing < version) {
            "Key $keyName with equal or higher version $version already exists"
        }
        keys[buildVersionName(keyName, version)] = KeyVersion(keyName, version, algorithm)
        currentVersion[keyName] = version
        return this
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment