Skip to content

Instantly share code, notes, and snippets.

@Alien2150
Last active August 29, 2015 14:14
Show Gist options
  • Save Alien2150/9cb43a28dfee087bdf2e to your computer and use it in GitHub Desktop.
Save Alien2150/9cb43a28dfee087bdf2e to your computer and use it in GitHub Desktop.
TLSCipherActor with fix for IllegalStateException
package actor
import java.nio.ByteBuffer
import actor.TlsRecord._
import akka.util.ByteString
import javax.net.ssl.SSLEngineResult
import akka.actor.ActorLogging
import akka.stream.ssl.SslTlsCipherActor
import javax.net.ssl.SSLEngineResult.Status._
import javax.net.ssl.SSLEngineResult.HandshakeStatus._
/**
 * Content-type codes that can appear in byte 0 of a TLS record header.
 * Written in hex to mirror the on-the-wire encoding (0x14..0x18 == 20..24).
 */
object TlsRecord {
  /** change_cipher_spec record (20). */
  val CHANGE_CIPHER_SPEC: Byte = 0x14

  /** alert record (21). */
  val ALERT: Byte = 0x15

  /** handshake record (22). */
  val HANDSHAKE: Byte = 0x16

  /** application_data record (23). */
  val APPLICATION_DATA: Byte = 0x17

  /** heartbeat record (24). */
  val HEARTBEAT: Byte = 0x18
}
/**
 * Mutable holder for the fields of one TLS record header.
 * Every field starts at zero and is filled in by whoever parses the header.
 */
class TlsRecord {
  /** Content-type byte (header byte 0). */
  var contentType: Byte = 0

  /** Protocol major version (header byte 1). */
  var majorVersion: Byte = 0

  /** Protocol minor version (header byte 2). */
  var minorVersion: Byte = 0

  /** Record length; SslTlsCipherActor2.readRecord stores payload length plus the 5-byte header. */
  var length: Int = 0
}
/**
 * Cipher actor that re-frames inbound cipher text into whole TLS records before
 * handing it to the SSL engine.
 *
 * Inbound bytes are accumulated in `temporaryReceiveBuffer` until the record whose
 * header was last parsed (`currentTlsRecord`) is complete. Only then is exactly one
 * record passed to `super.doUnwrap`. While a record (or its header) is incomplete,
 * the synthetic `resultOk` (OK, NOT_HANDSHAKING, 0 consumed, 0 produced) is returned.
 *
 * @param requester2           forwarded unchanged to SslTlsCipherActor
 * @param sessionNegotioation2 forwarded unchanged to SslTlsCipherActor
 * @param tracing              when true, framing decisions are logged at debug level
 */
class SslTlsCipherActor2(val requester2 : akka.actor.ActorRef, val sessionNegotioation2 : akka.stream.ssl.SslTlsCipher.SessionNegotiation, tracing : scala.Boolean)
  extends SslTlsCipherActor(requester2, sessionNegotioation2, tracing) with ActorLogging {

  /** Cipher text that has been received but not yet handed to the engine. */
  var temporaryReceiveBuffer: ByteString = ByteString.empty

  /** Synthetic result returned while a record is still incomplete. */
  val resultOk = new SSLEngineResult(OK, NOT_HANDSHAKING, 0, 0)

  /** Size of a TLS record header: content type (1) + version (2) + length (2). */
  val MIN_TLS_RECORD_SIZE = 5

  /** Header of the record currently being assembled, if one has been parsed. */
  var currentTlsRecord: Option[TlsRecord] = None

  /**
   * Append the unread bytes of `cipherTextInboundBytes` to `temporaryReceiveBuffer`.
   */
  def concatBuffer(): Unit = {
    // Measure before copying: the log must report the bytes actually appended, not limit().
    val appended = cipherTextInboundBytes.remaining()
    temporaryReceiveBuffer = temporaryReceiveBuffer ++ ByteString(cipherTextInboundBytes)
    if (tracing) log.debug(s"Appended $appended bytes. Current: ${temporaryReceiveBuffer.size}")
  }

  /**
   * Parse the TLS record header at the buffer's current position.
   *
   * The buffer's position is not modified (absolute reads only).
   *
   * @param buffer The buffer to read the record from
   * @return None when fewer than MIN_TLS_RECORD_SIZE bytes remain. Otherwise the parsed
   *         header, whose `length` includes the 5-byte header. For an unrecognised content
   *         type or a major version other than 3, `length` is left at 0.
   */
  def readRecord(buffer: ByteBuffer): Option[TlsRecord] = {
    if (buffer.remaining() < MIN_TLS_RECORD_SIZE) {
      None
    } else {
      // Good reference: https://github.com/netty/netty/blob/master/handler/src/main/java/io/netty/handler/ssl/SslHandler.java
      val start = buffer.position()
      val tlsRecord = new TlsRecord
      tlsRecord.contentType = buffer.get(start)
      // Also see: http://en.wikipedia.org/wiki/Transport_Layer_Security#TLS_1.2
      // 20 = CHANGE_CIPHER_SPEC, 21 = ALERT, 22 = HANDSHAKE, 23 = APPLICATION_DATA, 24 = HEARTBEAT
      val tls: Boolean = tlsRecord.contentType >= CHANGE_CIPHER_SPEC && tlsRecord.contentType <= HEARTBEAT
      if (tls) {
        tlsRecord.majorVersion = buffer.get(start + 1) // Byte 1
        tlsRecord.minorVersion = buffer.get(start + 2) // Byte 2
        if (tlsRecord.majorVersion == 3) {
          // Bytes 3+4 hold an unsigned 16-bit length; mask it, then add the header size.
          tlsRecord.length = (buffer.getShort(start + 3) & 0xFFFF) + MIN_TLS_RECORD_SIZE
        }
      }
      Some(tlsRecord)
    }
  }

  /**
   * Bytes from which the next record header should be parsed: the buffered bytes first,
   * then the unread inbound bytes. `duplicate()` keeps the inbound buffer's position intact.
   */
  private def headerSource(): ByteBuffer =
    if (temporaryReceiveBuffer.isEmpty) cipherTextInboundBytes
    else (temporaryReceiveBuffer ++ ByteString(cipherTextInboundBytes.duplicate())).toByteBuffer

  /**
   * Run ssl-unwrap on whole TLS records only.
   *
   * @param tempBuf The dest-buffer
   * @return the engine's result, or `resultOk` while more bytes are needed
   */
  override def doUnwrap(tempBuf : ByteBuffer) : SSLEngineResult = {
    val incoming = cipherTextInboundBytes.remaining()
    val length = temporaryReceiveBuffer.size + incoming

    if (length == 0) {
      // Nothing buffered and nothing new: let the engine report its own status.
      currentTlsRecord = None
      super.doUnwrap(tempBuf)
    } else {
      if (currentTlsRecord.isEmpty) {
        if (tracing) log.debug("Fetched new tls record")
        currentTlsRecord = readRecord(headerSource())
      }

      currentTlsRecord match {
        case None =>
          // Not even a complete header yet: keep the bytes and wait for more.
          if (tracing) log.debug(s"Incomplete tls record header: $length bytes available")
          concatBuffer()
          resultOk

        case Some(record) =>
          if (tracing) log.debug(s"TLS Content Type ${record.contentType} and record-length: ${record.length} vs $length = ${temporaryReceiveBuffer.size} + $incoming")

          if (record.length < MIN_TLS_RECORD_SIZE) {
            // Unrecognised header (readRecord leaves length at 0). Hand every byte to the engine
            // so it can reject the data, instead of feeding it empty input forever.
            concatBuffer()
            cipherTextInboundBytes = temporaryReceiveBuffer.toByteBuffer
            temporaryReceiveBuffer = ByteString.empty
            currentTlsRecord = None
            if (tracing) log.debug("Running unwrap")
            super.doUnwrap(tempBuf)
          } else if (record.length > length) {
            // Record still incomplete: buffer the data and return the "fake message".
            // TODO These fake-result should also deal with handshake message. The problem: How to determine the current "handshake" phase?
            concatBuffer()
            resultOk
          } else if (temporaryReceiveBuffer.isEmpty && record.length == length) {
            // The inbound buffer holds exactly one record: pass it through unchanged.
            currentTlsRecord = None
            if (tracing) log.debug("Running unwrap")
            super.doUnwrap(tempBuf)
          } else {
            if (tracing && record.length < length) log.debug(s"More data than we need. Requested: ${record.length} vs available: ${length}")
            // Hand exactly one record to the engine and keep the rest for later calls.
            concatBuffer()
            cipherTextInboundBytes = temporaryReceiveBuffer.take(record.length).toByteBuffer
            temporaryReceiveBuffer = temporaryReceiveBuffer.drop(record.length)
            // NOTE(review): leftover bytes are only examined on a later doUnwrap call; whether the
            // parent triggers one without new input is not visible here — confirm.
            currentTlsRecord = readRecord(temporaryReceiveBuffer.toByteBuffer)
            currentTlsRecord.foreach { next =>
              if (tracing) log.debug(s"Fetched next tls record on buffer: with content-type: ${next.contentType} and length:${next.length}")
            }
            if (tracing) log.debug("Running unwrap")
            super.doUnwrap(tempBuf)
          }
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment