Skip to content

Instantly share code, notes, and snippets.

@timothyklim
Last active March 21, 2021 19:18
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save timothyklim/ec5889aa23400529fd5e to your computer and use it in GitHub Desktop.
Save timothyklim/ec5889aa23400529fd5e to your computer and use it in GitHub Desktop.
AES encryption/decryption with Akka Streams
package utils
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage._
import akka.util.{ByteString, ByteStringBuilder}
import scala.annotation.tailrec
import java.security.SecureRandom
import java.security.{Key, KeyFactory, PublicKey, PrivateKey}
import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec}
import javax.crypto._
import javax.crypto.spec.{SecretKeySpec, IvParameterSpec}
import java.io.{File, FileInputStream, FileOutputStream, InputStream, OutputStream}
/**
 * Flow stage that runs every incoming [[ByteString]] chunk through the supplied
 * `cipher` and emits the bytes the cipher produces.
 *
 * Each chunk is fed to `Cipher.update`; when upstream finishes, `Cipher.doFinal`
 * flushes whatever the cipher still holds before the outlet is completed.
 *
 * NOTE: the `cipher` instance is stateful and is captured by the stage, so every
 * materialization of the same stage instance would share it. Materialize a given
 * instance at most once.
 *
 * @param cipher an already initialised cipher (mode, key and IV are set by the caller)
 */
private[this] class AesStage(cipher: Cipher) extends GraphStage[FlowShape[ByteString, ByteString]] {
  val in = Inlet[ByteString]("in")
  val out = Outlet[ByteString]("out")
  override val shape = FlowShape.of(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      setHandler(in, new InHandler {
        override def onPush(): Unit = {
          val chunk = grab(in)
          // Cipher.update is documented to return null when a block cipher has not
          // yet received enough input for a full block, so guard with Option
          // instead of handing a possible null to ByteString.
          val produced =
            if (chunk.isEmpty) ByteString.empty
            else Option(cipher.update(chunk.toArray)).fold(ByteString.empty)(bytes => ByteString(bytes))
          // Nothing to emit for this chunk: ask upstream for more rather than
          // pushing an empty element downstream.
          if (produced.nonEmpty) push(out, produced)
          else pull(in)
        }

        override def onUpstreamFinish(): Unit = {
          // Flush the bytes still buffered inside the cipher (final block / padding).
          val tail = Option(cipher.doFinal()).fold(ByteString.empty)(bytes => ByteString(bytes))
          if (tail.nonEmpty) emit(out, tail)
          complete(out)
        }
      })

      setHandler(out, new OutHandler {
        // Downstream demand is forwarded straight to upstream.
        override def onPull(): Unit = pull(in)
      })
    }
}
/**
 * Helpers for AES encryption/decryption of `ByteString` sources
 * ("AES/CBC/PKCS5Padding") and for single-shot RSA encryption/decryption.
 */
object Crypto {
  /** AES key length, in bits, used by [[generateAesKey]]. */
  val aesKeySize: Int = 128

  /** Number of bytes produced by [[generateIv]]. */
  private val ivSizeBytes = 16

  /** Generates a fresh random AES key of [[aesKeySize]] bits. */
  def generateAesKey(): SecretKeySpec = {
    val gen = KeyGenerator.getInstance("AES")
    gen.init(aesKeySize)
    aesKeySpec(gen.generateKey().getEncoded)
  }

  /**
   * Wraps raw key bytes as an AES key specification.
   *
   * @param key the raw key material
   */
  def aesKeySpec(key: Array[Byte]): SecretKeySpec =
    new SecretKeySpec(key, "AES")

  /** Random source shared by the generators in this object. */
  val rand: SecureRandom = new SecureRandom()

  /**
   * Generates a fresh 16-byte random initialisation vector.
   *
   * Uses `nextBytes` rather than `generateSeed`: `generateSeed` is intended for
   * seeding other generators and may block while gathering entropy.
   */
  def generateIv(): Array[Byte] = {
    val iv = new Array[Byte](ivSizeBytes)
    rand.nextBytes(iv)
    iv
  }

  /** Builds an "AES/CBC/PKCS5Padding" cipher initialised with the given mode, key and IV. */
  private def aesCipher(mode: Int, keySpec: SecretKeySpec, ivBytes: Array[Byte]) = {
    val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
    cipher.init(mode, keySpec, new IvParameterSpec(ivBytes))
    cipher
  }

  /** Pipes `source` through an [[AesStage]] driven by a cipher in the given mode. */
  private def aesStream(
    mode: Int,
    source: Source[ByteString, Any],
    keySpec: SecretKeySpec,
    ivBytes: Array[Byte]
  ): Source[ByteString, Any] =
    source.via(new AesStage(aesCipher(mode, keySpec, ivBytes)))

  /**
   * Encrypts the bytes of `source` with AES.
   *
   * The cipher is created once, when this method is called, and is captured by
   * the returned source.
   * NOTE(review): because that single cipher is stateful, the returned source
   * should presumably be run at most once.
   *
   * @param source  plaintext chunks
   * @param keySpec AES key
   * @param ivBytes initialisation vector
   * @return a source of ciphertext chunks
   */
  def encryptAes(
    source: Source[ByteString, Any],
    keySpec: SecretKeySpec,
    ivBytes: Array[Byte]
  ): Source[ByteString, Any] =
    aesStream(Cipher.ENCRYPT_MODE, source, keySpec, ivBytes)

  /**
   * Decrypts the bytes of `source` with AES.
   *
   * The cipher is created once, when this method is called, and is captured by
   * the returned source.
   * NOTE(review): because that single cipher is stateful, the returned source
   * should presumably be run at most once.
   *
   * @param source  ciphertext chunks
   * @param keySpec AES key
   * @param ivBytes initialisation vector
   * @return a source of plaintext chunks
   */
  def decryptAes(
    source: Source[ByteString, Any],
    keySpec: SecretKeySpec,
    ivBytes: Array[Byte]
  ): Source[ByteString, Any] =
    aesStream(Cipher.DECRYPT_MODE, source, keySpec, ivBytes)

  /** Looks up the JCA key factory for RSA. */
  def getRsaKeyFactory(): KeyFactory =
    KeyFactory.getInstance("RSA")

  /**
   * Loads an RSA private key.
   *
   * @param key the key bytes, interpreted as a PKCS#8 encoded key specification
   */
  def loadRsaPrivateKey(key: Array[Byte]): PrivateKey =
    getRsaKeyFactory().generatePrivate(new PKCS8EncodedKeySpec(key))

  /**
   * Loads an RSA public key.
   *
   * @param key the key bytes, interpreted as an X.509 encoded key specification
   */
  def loadRsaPublicKey(key: Array[Byte]): PublicKey =
    getRsaKeyFactory().generatePublic(new X509EncodedKeySpec(key))

  /** Builds an RSA cipher initialised with the given mode and key. */
  private def rsaCipher(mode: Int, key: Key) = {
    // NOTE(review): the bare "RSA" transformation does not name a mode or padding,
    // so the choice is left to the provider — confirm what peers expect before
    // pinning an explicit transformation here.
    val cipher = Cipher.getInstance("RSA")
    cipher.init(mode, key)
    cipher
  }

  /** Encrypts `bytes` in one shot with the given RSA public key. */
  def encryptRsa(bytes: Array[Byte], key: PublicKey): Array[Byte] =
    rsaCipher(Cipher.ENCRYPT_MODE, key).doFinal(bytes)

  /** Decrypts `bytes` in one shot with the given RSA private key. */
  def decryptRsa(bytes: Array[Byte], key: PrivateKey): Array[Byte] =
    rsaCipher(Cipher.DECRYPT_MODE, key).doFinal(bytes)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment