Skip to content

Instantly share code, notes, and snippets.

@btd
Last active June 5, 2017 10:50
Show Gist options
  • Save btd/695d0aac6fa1b8977fd3de65621d09f3 to your computer and use it in GitHub Desktop.
Save btd/695d0aac6fa1b8977fd3de65621d09f3 to your computer and use it in GitHub Desktop.
Filter that creates cookie-serialized sessions in a servlet container
package ingo.webapp
import javax.servlet._
import javax.servlet.http._
import java.util.Base64
import java.io._
import java.security.SecureRandom
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import ingo.commons.util.KryoSerialization
/**
 * Settings for the cookie-backed session filter.
 *
 * @param cookieName   name of the cookie carrying the serialized session; the
 *                     signature cookie uses this name with a `_SIG` suffix
 * @param cookieMaxAge value applied as the session's max inactive interval, which
 *                     is later passed to `Cookie.setMaxAge` for both cookies
 * @param signatureKey secret key used to compute the HMAC signature of the session cookie
 */
case class CookieSessionFilterConfig(
  cookieName: String,
  cookieMaxAge: Int,
  signatureKey: String
)
/**
 * Helpers for keeping a servlet session inside a pair of cookies.
 *
 * The session attribute map is serialized with `KryoSerialization`, Base64-encoded
 * and stored in the value cookie. A second cookie, named after the first with a
 * `_SIG` suffix, carries an HMAC-SHA256 signature of `name=value` so that tampered
 * cookies are rejected before they are deserialized.
 */
object CookieSessionFilter {
  // Fixed charset so that signatures do not depend on the JVM's default encoding.
  private val Utf8 = java.nio.charset.StandardCharsets.UTF_8

  val base64decoder: Base64.Decoder = Base64.getDecoder
  val base64encoder: Base64.Encoder = Base64.getEncoder.withoutPadding

  /** Attribute key under which the session id is stored inside the attribute map. */
  val ID = "$$$"

  /**
   * Computes a MAC over `data` and returns it Base64-encoded (no padding).
   *
   * @param algorithm JCA MAC algorithm name, e.g. `HmacSHA256`
   * @param key       secret key; encoded as UTF-8 bytes
   * @param data      bytes to sign
   */
  private def signatureMethod(algorithm: String)(key: String, data: Array[Byte]): String = {
    val mac = Mac.getInstance(algorithm)
    mac.init(new SecretKeySpec(key.getBytes(Utf8), algorithm))
    base64encoder.encodeToString(mac.doFinal(data))
  }

  /** HMAC-SHA256 signing function: `(key, data) => Base64 signature`. */
  val signature_HmacSHA256: (String, Array[Byte]) => String = signatureMethod("HmacSHA256") _

  /**
   * Decodes a cookie value back into the session attribute map.
   *
   * NOTE(review): this deserializes client-supplied bytes; callers must verify the
   * signature first, as [[getSessionAttributes]] does.
   *
   * @throws IllegalArgumentException if `value` is not valid Base64
   */
  def deserializeCookieValue(value: String): java.util.HashMap[String, AnyRef] = {
    val bytes = base64decoder.decode(value)
    KryoSerialization.deserialize[java.util.HashMap[String, AnyRef]](bytes)
  }

  /** Serializes the attribute map and returns it as an unpadded Base64 string. */
  def serializeCookieValue(value: java.util.HashMap[String, AnyRef]): String = {
    val bytes = KryoSerialization.serialize[java.util.HashMap[String, AnyRef]](value)
    base64encoder.encodeToString(bytes)
  }

  /** Builds the HTTP-only, path `/` cookie that carries the serialized attributes. */
  def valueCookie(cookieName: String, attributes: java.util.HashMap[String, AnyRef]): Cookie = {
    val cookie = new Cookie(cookieName, serializeCookieValue(attributes))
    cookie.setPath("/")
    cookie.setHttpOnly(true)
    cookie
  }

  /** Signs `name=value` of the given cookie with HMAC-SHA256. */
  def valueCookieSignature(key: String, cookie: Cookie): String =
    signature_HmacSHA256(key, s"${cookie.getName}=${cookie.getValue}".getBytes(Utf8))

  /** Builds the HTTP-only, path `/` companion cookie (`<name>_SIG`) holding the signature. */
  def signatureCookie(key: String, _cookie: Cookie): Cookie = {
    val cookie = new Cookie(_cookie.getName + "_SIG", valueCookieSignature(key, _cookie))
    cookie.setPath("/")
    cookie.setHttpOnly(true)
    cookie
  }

  /**
   * Checks that `signature` carries the expected signature of `value`.
   *
   * The comparison goes through `MessageDigest.isEqual` rather than `==` so that the
   * time taken does not reveal how many leading characters of a forged signature
   * were correct. A missing signature value never matches.
   */
  def cookieSignatureMatch(key: String, value: Cookie, signature: Cookie): Boolean =
    Option(signature.getValue).exists { presented =>
      val expected = valueCookieSignature(key, value)
      java.security.MessageDigest.isEqual(presented.getBytes(Utf8), expected.getBytes(Utf8))
    }

  /**
   * Extracts the session attributes from the request cookies.
   *
   * Returns a fresh empty map when the request has no cookies, when either cookie
   * is missing, when the signature does not match, or when the value cannot be
   * decoded/deserialized.
   */
  def getSessionAttributes(
      req: HttpServletRequest,
      cookieName: String,
      key: String): java.util.HashMap[String, AnyRef] = {
    val verified = for {
      cookies <- Option(req.getCookies) // getCookies returns null when there are none
      payload <- cookies.find(_.getName == cookieName)
      sig <- cookies.find(_.getName == cookieName + "_SIG")
      if cookieSignatureMatch(key, payload, sig)
      // Try only traps non-fatal errors; a null result is treated as "no session".
      attributes <- scala.util.Try(deserializeCookieValue(payload.getValue)).toOption.flatMap(Option(_))
    } yield attributes

    verified.getOrElse(new java.util.HashMap[String, AnyRef]())
  }
}
/**
 * `HttpSession` implementation whose whole state lives in the supplied attribute map.
 *
 * The session id is kept inside that same map under `CookieSessionFilter.ID`; when
 * the map has no id yet, a random UUID is generated and the session is flagged as new.
 *
 * @param servletContext context returned from `getServletContext`
 * @param attributes     backing attribute map (also exposed to callers)
 */
class InGoHttpSession(servletContext: ServletContext, val attributes: java.util.HashMap[String, AnyRef])
    extends HttpSession {

  private val createdAt = System.currentTimeMillis

  // Declared before `sessionId` on purpose: the initializer below calls updateId(),
  // which may set this flag, and a later declaration would reset it to false.
  private var fresh = false
  private var sessionId: String = updateId()
  private var inactiveInterval = 60

  /**
   * Synchronizes the session id with the attribute map: reuses the stored id when
   * there is one, otherwise generates a UUID, stores it and marks the session new.
   *
   * @return the id now in effect
   */
  def updateId(): String = {
    Option(attributes.get(CookieSessionFilter.ID)) match {
      case Some(stored) =>
        sessionId = stored.asInstanceOf[String]
      case None =>
        sessionId = java.util.UUID.randomUUID.toString
        fresh = true
        attributes.put(CookieSessionFilter.ID, sessionId)
    }
    sessionId
  }

  def getId(): String = sessionId

  def isNew(): Boolean = fresh

  def getCreationTime(): Long = createdAt

  // No separate access tracking: the creation time is reported here as well.
  def getLastAccessedTime(): Long = createdAt

  def getMaxInactiveInterval(): Int = inactiveInterval

  def setMaxInactiveInterval(value: Int): Unit = inactiveInterval = value

  def getServletContext(): ServletContext = servletContext

  def getSessionContext(): HttpSessionContext = null

  def getAttribute(name: String): Object = attributes.get(name)

  def getValue(name: String): Object = getAttribute(name)

  def getAttributeNames(): java.util.Enumeration[String] =
    java.util.Collections.enumeration(attributes.keySet())

  def getValueNames(): Array[String] = {
    val names = attributes.keySet
    names.toArray(new Array[String](names.size))
  }

  /** Stores `value` under `name`; a null value removes the attribute instead. */
  def setAttribute(name: String, value: AnyRef): Unit = {
    if (value == null) {
      attributes.remove(name)
    } else {
      attributes.put(name, value)
    }
  }

  def putValue(name: String, value: AnyRef): Unit = setAttribute(name, value)

  def removeAttribute(name: String): Unit = {
    attributes.remove(name)
  }

  def removeValue(name: String): Unit = removeAttribute(name)

  /** Drops every attribute, then assigns a newly generated id to this same object. */
  def invalidate(): Unit = {
    attributes.clear()
    updateId()
  }
}
/**
 * Request wrapper that serves the session from the signed session cookies instead
 * of the container's own session store.
 *
 * The session is built lazily from the cookies the first time it is needed.
 *
 * @param req    the wrapped request
 * @param config cookie name, max age and signature key to use
 */
class SessionCookieRequestWrapper(req: HttpServletRequest, config: CookieSessionFilterConfig)
    extends HttpServletRequestWrapper(req) {

  /** The cookie-backed session; null until first loaded. */
  var session: InGoHttpSession = null

  /** (Re)builds the session from the request cookies, replacing any current one. */
  def fillSessionFromCookie(): Unit = {
    val attributes = CookieSessionFilter.getSessionAttributes(req, config.cookieName, config.signatureKey)
    session = new InGoHttpSession(req.getServletContext, attributes)
    session.setMaxInactiveInterval(config.cookieMaxAge)
  }

  // Returns the session, loading it from the cookies on first use so that no
  // accessor can hit a null `session`.
  private def currentSession(): InGoHttpSession = {
    if (session == null) {
      fillSessionFromCookie()
    }
    session
  }

  /**
   * Assigns a new id to the current session and returns it.
   *
   * The stored id is removed first; otherwise `updateId()` would reuse it and the
   * id would never change. As a side effect of `updateId()`, the session then
   * reports `isNew() == true`.
   */
  override def changeSessionId(): String = {
    val current = currentSession()
    current.attributes.remove(CookieSessionFilter.ID)
    current.updateId()
  }

  override def getRequestedSessionId(): String = currentSession().getId

  override def getSession(): HttpSession = getSession(true)

  // `create` is deliberately ignored: a session is always available because it is
  // reconstructed from the cookies (an empty one when there are none), so even
  // `getSession(false)` has to read and return it.
  override def getSession(create: Boolean): HttpSession = currentSession()

  override def isRequestedSessionIdFromCookie(): Boolean = false
  override def isRequestedSessionIdFromUrl(): Boolean = false
  override def isRequestedSessionIdFromURL(): Boolean = false
  override def isRequestedSessionIdValid(): Boolean = true
}
/**
 * Response wrapper that buffers the body in memory so the session cookies can be
 * added before any content is forwarded to the real response.
 *
 * @param req    request whose session is serialized; its `getSession()` must return
 *               an [[InGoHttpSession]]
 * @param res    the wrapped response
 * @param config cookie name and signature key to use
 */
class SessionCookieResponseWrapper(req: HttpServletRequest,
                                   res: HttpServletResponse,
                                   config: CookieSessionFilterConfig)
    extends HttpServletResponseWrapper(res) {

  /** In-memory buffer standing in front of the real output stream. */
  val output = new WrappedServletOutputStream(res.getOutputStream())

  /**
   * Writer over the buffered stream.
   *
   * Created lazily so that it encodes with the character encoding in effect when it
   * is first requested, instead of the JVM default charset. Falls back to
   * ISO-8859-1 when the response reports no encoding.
   */
  lazy val writer: PrintWriter = {
    val charset = Option(getCharacterEncoding).getOrElse("ISO-8859-1")
    new PrintWriter(new OutputStreamWriter(output, charset), true)
  }

  override def getOutputStream() = output

  override def getWriter() = writer

  /**
   * Adds the session cookie and its signature cookie to the real response, then
   * forwards the buffered body.
   *
   * Both cookies get the session's max inactive interval as their max age.
   *
   * @throws IllegalStateException if the request's session is not an [[InGoHttpSession]]
   */
  override def flushBuffer(): Unit = {
    req.getSession() match {
      case session: InGoHttpSession =>
        val sessionCookie = CookieSessionFilter.valueCookie(config.cookieName, session.attributes)
        val sigCookie = CookieSessionFilter.signatureCookie(config.signatureKey, sessionCookie)
        val maxAge = session.getMaxInactiveInterval()
        sessionCookie.setMaxAge(maxAge)
        sigCookie.setMaxAge(maxAge)
        // Cookies are headers, so they must be added before the body is forwarded.
        res.addCookie(sessionCookie)
        res.addCookie(sigCookie)
      case other =>
        throw new IllegalStateException(
          s"Expected a cookie-backed InGoHttpSession but the request returned: $other")
    }
    writer.flush()
    output.forwardBufferContent()
  }
}
/**
 * `ServletOutputStream` that collects everything written to it in memory and only
 * hands it to the real stream when [[forwardBufferContent]] is called.
 *
 * @param _output the real servlet output stream that eventually receives the bytes
 */
class WrappedServletOutputStream(_output: ServletOutputStream) extends ServletOutputStream {

  /** In-memory buffer; starts at 100 KiB and grows as needed. */
  val output = new ByteArrayOutputStream(100 * 1024)

  /**
   * Listener registered through [[setWriteListener]]; null when none is set.
   * It is only stored: writes go to memory and this class never calls back into it.
   */
  var writeListener: WriteListener = null

  // The previous implementation called `writeListener.notify()` here. That is
  // `Object.notify()`, not a WriteListener callback, and it throws
  // IllegalMonitorStateException because the caller does not hold the monitor.
  def write(n: Int): Unit = output.write(n)

  /** Bulk write straight into the buffer instead of the inherited byte-by-byte loop. */
  override def write(bytes: Array[Byte], off: Int, len: Int): Unit =
    output.write(bytes, off, len)

  /**
   * Copies the buffered bytes to the real stream and empties the buffer, so that a
   * later call forwards only what was written in between and does not repeat the
   * body. The real stream is not flushed here.
   */
  def forwardBufferContent(): Unit = {
    output.writeTo(_output)
    output.reset()
  }

  def setWriteListener(l: WriteListener): Unit = writeListener = l

  /** Always ready: writes only touch the in-memory buffer. */
  def isReady(): Boolean = true
}
/**
 * Servlet filter that replaces the container session with one stored in signed
 * cookies.
 *
 * Supported init parameters, each optional: `cookieMaxAge` (integer),
 * `cookieName` and `signatureKey`.
 */
class CookieSessionFilter extends Filter {

  // NOTE(review): the default signing key is hard-coded in source, so anyone who can
  // read this file can forge sessions for a deployment that does not override it via
  // the `signatureKey` init parameter. Left unchanged to keep existing cookies valid.
  var filterConfig: CookieSessionFilterConfig =
    CookieSessionFilterConfig("SESS", 60, "EH4v4pLHeiQSplsWv3w6")

  def destroy(): Unit = {}

  /**
   * Overrides the defaults with whichever init parameters are present.
   *
   * @throws ServletException if `cookieMaxAge` is present but not an integer
   */
  def init(config: FilterConfig): Unit = {
    Option(config.getInitParameter("cookieMaxAge")).foreach { raw =>
      val maxAge =
        try raw.trim.toInt
        catch {
          case e: NumberFormatException =>
            throw new ServletException(s"Invalid cookieMaxAge init parameter: '$raw'", e)
        }
      filterConfig = filterConfig.copy(cookieMaxAge = maxAge)
    }
    Option(config.getInitParameter("cookieName")).foreach { name =>
      filterConfig = filterConfig.copy(cookieName = name)
    }
    Option(config.getInitParameter("signatureKey")).foreach { key =>
      filterConfig = filterConfig.copy(signatureKey = key)
    }
  }

  /**
   * Wraps HTTP requests/responses so the session is read from and written back to
   * cookies; anything that is not HTTP is passed down the chain untouched.
   */
  def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit =
    (req, res) match {
      case (httpReq: HttpServletRequest, httpRes: HttpServletResponse) =>
        val wrapperReq = new SessionCookieRequestWrapper(httpReq, filterConfig)
        val wrapperRes = new SessionCookieResponseWrapper(wrapperReq, httpRes, filterConfig)
        chain.doFilter(wrapperReq, wrapperRes)
        // Emits the session cookies and forwards the buffered body.
        wrapperRes.flushBuffer()
      case _ =>
        chain.doFilter(req, res)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment