public
Last active

A Netty-based WebSocket client and server in Scala

  • Download Gist
WebSocketClient.scala
Scala
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
package mojolly.io
 
import org.jboss.netty.bootstrap.ClientBootstrap
import org.jboss.netty.channel._
import socket.nio.NioClientSocketChannelFactory
import java.util.concurrent.Executors
import org.jboss.netty.handler.codec.http._
import collection.JavaConversions._
import websocketx._
import java.net.{InetSocketAddress, URI}
import java.nio.charset.Charset
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.util.CharsetUtil
import akka.actor.ActorRef
import mojolly.LibraryConstants
 
 
/**
 * Usage of the simple websocket client:
 * <pre>
 * WebSocketClient(new URI("ws://localhost:8080/thesocket")) {
 *   case Connected(client) => println("Connection has been established to: " + client.url.toASCIIString)
 *   case Disconnected(client, _) => println("The websocket to " + client.url.toASCIIString + " disconnected.")
 *   case TextMessage(client, message) => {
 *     println("RECV: " + message)
 *     client send ("ECHO: " + message)
 *   }
 * }
 * </pre>
 *
 * Thread usage: unless a `channelFactory` is passed to `apply`, every client creates its own
 * `NioClientSocketChannelFactory` backed by two cached thread pools, and nothing here releases them
 * (`disconnect` only sends a close frame). When creating many clients, pass one shared factory and call
 * its `releaseExternalResources` yourself once all clients are done.
 *
 * Note: a `wss` URL is accepted and defaults to port 443, but no SSL handler is added to the client
 * pipeline, so this client does not itself perform TLS.
 */
object WebSocketClient {

  object Messages {
    sealed trait WebSocketClientMessage
    case object Connecting extends WebSocketClientMessage
    case class ConnectionFailed(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Connected(client: WebSocketClient) extends WebSocketClientMessage
    case class TextMessage(client: WebSocketClient, text: String) extends WebSocketClientMessage
    case class WriteFailed(client: WebSocketClient, message: String, reason: Option[Throwable]) extends WebSocketClientMessage
    case object Disconnecting extends WebSocketClientMessage
    case class Disconnected(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Error(client: WebSocketClient, th: Throwable) extends WebSocketClientMessage
  }

  /** Callback that receives the client's lifecycle and message events. */
  type Handler = PartialFunction[Messages.WebSocketClientMessage, Unit]

  /** Converts an incoming data frame into the text delivered as a `TextMessage`. */
  type FrameReader = WebSocketFrame => String

  /** Reads text frames; throws UnsupportedOperationException for every other frame type. */
  val defaultFrameReader = (_: WebSocketFrame) match {
    case f: TextWebSocketFrame => f.getText
    case _ => throw new UnsupportedOperationException("Only single text frames are supported for now")
  }

  /**
   * Creates (but does not connect) a websocket client.
   *
   * @param url            the `ws://` or `wss://` endpoint; the port defaults to 80 (ws) or 443 (wss)
   * @param version        websocket protocol version used for the handshake
   * @param reader         converts incoming data frames to text
   * @param channelFactory a channel factory shared between clients; when `None` a dedicated NIO factory
   *                       with its own thread pools is created for this client and never released
   * @param handle         receives the client's events; unhandled events fall through to a default
   *                       handler that prints `Error` stack traces and ignores everything else
   * @throws IllegalArgumentException if the url scheme is not ws or wss
   */
  def apply(url: URI, version: WebSocketVersion = WebSocketVersion.V13, reader: FrameReader = defaultFrameReader,
      channelFactory: Option[ChannelFactory] = None)(handle: Handler): WebSocketClient = {
    requireWebSocketScheme(url)
    new DefaultWebSocketClient(url, version, handle, reader, channelFactory)
  }

  /** Creates a client that forwards every event to `handle`. */
  def apply(url: URI, handle: ActorRef): WebSocketClient = {
    requireWebSocketScheme(url)
    WebSocketClient(url) { case x => handle ! x }
  }

  /** Fails unless the scheme is `ws` or `wss` (case-insensitive); a missing scheme also fails. */
  private def requireWebSocketScheme(url: URI): Unit = {
    val scheme = Option(url.getScheme) getOrElse ""
    require(scheme.equalsIgnoreCase("ws") || scheme.equalsIgnoreCase("wss"), "The scheme of the url should be 'ws' or 'wss'")
  }

  /** The explicit port of `uri`, or the scheme default (443 for wss, 80 otherwise) when none is given. */
  private def effectivePort(uri: URI): Int =
    if (uri.getPort != -1) uri.getPort
    else if ("wss".equalsIgnoreCase(uri.getScheme)) 443
    else 80

  /** Completes the handshake, then translates incoming frames into client events. */
  private class WebSocketClientHandler(handshaker: WebSocketClientHandshaker, client: WebSocketClient) extends SimpleChannelUpstreamHandler {

    import Messages._

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      client.handler(Disconnected(client))
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      e.getMessage match {
        case resp: HttpResponse if handshaker.isHandshakeComplete =>
          throw new WebSocketException("Unexpected HttpResponse (status=" + resp.getStatus + ", content="
            + resp.getContent.toString(CharsetUtil.UTF_8) + ")")
        case resp: HttpResponse =>
          handshaker.finishHandshake(ctx.getChannel, resp)
          client.handler(Connected(client))
        case f: TextWebSocketFrame => client.handler(TextMessage(client, client.reader(f)))
        // RFC 6455 requires a Pong carrying the same application data as the Ping.
        case ping: PingWebSocketFrame => ctx.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
        case _: PongWebSocketFrame =>
        case _: CloseWebSocketFrame => ctx.getChannel.close()
        // Binary and continuation frames go through the configured reader; the default reader throws,
        // which ends up in exceptionCaught (Error event + channel close).
        case f: WebSocketFrame => client.handler(TextMessage(client, client.reader(f)))
        case _ =>
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
      client.handler(Error(client, e.getCause))
      e.getChannel.close()
    }

  }

  private class DefaultWebSocketClient(
      val url: URI,
      version: WebSocketVersion,
      private[this] val _handler: Handler,
      val reader: FrameReader = defaultFrameReader,
      channelFactory: Option[ChannelFactory] = None) extends WebSocketClient {

    val normalized = url.normalize()
    // The handshake needs a request path; an empty path becomes "/".
    val tgt = if (normalized.getPath == null || normalized.getPath.trim().isEmpty) {
      new URI(normalized.getScheme, normalized.getAuthority, "/", normalized.getQuery, normalized.getFragment)
    } else normalized
    val bootstrap = new ClientBootstrap(channelFactory.getOrElse(
      new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)))
    val handshaker = new WebSocketClientHandshakerFactory().newHandshaker(tgt, version, null, false, Map.empty[String, String])
    val self = this
    // Assigned from a Netty I/O thread in the connect listener, read from caller threads.
    @volatile var channel: Channel = _

    import Messages._
    val handler = _handler orElse defaultHandler

    private def defaultHandler: Handler = {
      case Error(_, ex) => ex.printStackTrace()
      case _: WebSocketClientMessage =>
    }

    bootstrap.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = {
        val pipeline = Channels.pipeline()
        if (version == WebSocketVersion.V00)
          pipeline.addLast("decoder", new WebSocketHttpResponseDecoder)
        else
          pipeline.addLast("decoder", new HttpResponseDecoder)
        pipeline.addLast("encoder", new HttpRequestEncoder)
        pipeline.addLast("ws-handler", new WebSocketClientHandler(handshaker, self))
        pipeline
      }
    })

    /** Connects and starts the handshake unless already connected; waits up to 5 seconds for the TCP connect. */
    def connect: Unit = {
      if (channel == null || !channel.isConnected) {
        val listener = futureListener { future =>
          if (future.isSuccess) {
            channel = future.getChannel
            handshaker.handshake(channel)
          } else {
            handler(ConnectionFailed(self, Option(future.getCause)))
          }
        }
        handler(Connecting)
        val fut = bootstrap.connect(new InetSocketAddress(url.getHost, effectivePort(url)))
        fut.addListener(listener)
        fut.await(5000L)
      }
    }

    /** Sends a close frame if connected; the Disconnected event fires when the channel closes. */
    def disconnect: Unit = {
      if (channel != null && channel.isConnected) {
        handler(Disconnecting)
        channel.write(new CloseWebSocketFrame())
      }
    }

    /** Sends a text frame; failures (including "never connected") are reported as WriteFailed. */
    def send(message: String, charset: Charset = CharsetUtil.UTF_8) = {
      val ch = channel
      if (ch == null) {
        handler(WriteFailed(self, message, Some(new IllegalStateException("Not connected to " + url.toASCIIString))))
      } else {
        ch.write(new TextWebSocketFrame(ChannelBuffers.copiedBuffer(message, charset))).addListener(futureListener { fut =>
          if (!fut.isSuccess) {
            handler(WriteFailed(self, message, Option(fut.getCause)))
          }
        })
      }
    }

    def futureListener(handleWith: ChannelFuture => Unit) = new ChannelFutureListener {
      def operationComplete(future: ChannelFuture) { handleWith(future) }
    }
  }

  /**
   * Fix bug in standard HttpResponseDecoder for web socket clients. When status 101 is received for Hybi00, there are 16
   * bytes of contents expected
   */
  class WebSocketHttpResponseDecoder extends HttpResponseDecoder {

    val codes = List(101, 200, 204, 205, 304)

    protected override def isContentAlwaysEmpty(msg: HttpMessage) = {
      msg match {
        case res: HttpResponse => codes contains res.getStatus.getCode
        case _ => false
      }
    }
  }

  /**
   * A WebSocket related exception
   *
   * Copied from https://github.com/cgbystrom/netty-tools
   */
  class WebSocketException(s: String, th: Throwable) extends java.io.IOException(s, th) {
    def this(s: String) = this(s, null)
  }
}
/**
 * A websocket client connection; instances come from the companion object's `apply` methods.
 */
trait WebSocketClient {

  /** Endpoint this client talks to. */
  def url: URI

  /** Turns incoming data frames into text. */
  def reader: WebSocketClient.FrameReader

  /** Receives this client's lifecycle and message events. */
  def handler: WebSocketClient.Handler

  /** Opens the connection. */
  def connect: Unit

  /** Requests that the connection be closed. */
  def disconnect: Unit

  /** Sends `message` as a text frame encoded with `charset`. */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit
}
WebSocketServer.scala
Scala
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
package io.backchat.minutes.river
 
import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
import java.util.concurrent.{ TimeUnit, Executors }
import java.net.{ InetSocketAddress }
import org.jboss.netty.channel._
import group.{ChannelGroup, DefaultChannelGroup}
import org.elasticsearch.common.logging.{ESLogger, ESLoggerFactory}
import org.jboss.netty.handler.codec.http.{HttpRequest, HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}
import org.jboss.netty.handler.codec.http.websocketx._
import org.jboss.netty.handler.codec.http.HttpHeaders.Values
import org.jboss.netty.handler.codec.http.HttpHeaders.Names
import java.util.Locale.ENGLISH
 
/** Network settings used when a WebSocketServer binds. */
trait WebSocketServerConfig {
  /** TCP port to listen on. */
  def port: Int

  /** Host or interface to bind; null or blank means the wildcard address. */
  def listenOn: String
}
 
 
/**
 * Netty based WebSocketServer
 * requires netty 3.3.x or later
 *
 * Fragmented text/binary messages are reassembled and delivered as a single TextMessage/BinaryMessage.
 *
 * Usage:
 * <pre>
 * val conf = new WebSocketServerConfig {
 *   val port = 14567
 *   val listenOn = "0.0.0.0"
 * }
 *
 * val server = WebSocketServer(conf) {
 *   case Connect(_) => println("got a client connection")
 *   case TextMessage(cl, text) => cl.write(new TextWebSocketFrame("ECHO: " + text))
 *   case Disconnected(_) => println("client disconnected")
 * }
 * server.start
 * // time passes......
 * server.stop
 * </pre>
 */
object WebSocketServer {
  import org.jboss.netty.buffer.{ ChannelBuffer, ChannelBuffers }
  import org.jboss.netty.util.CharsetUtil

  /** Callback that receives connection lifecycle events and messages. */
  type WebSocketHandler = PartialFunction[WebSocketMessage, Unit]
  sealed trait WebSocketMessage
  case class Connect(client: Channel) extends WebSocketMessage
  case class TextMessage(client: Channel, content: String) extends WebSocketMessage
  case class BinaryMessage(client: Channel, content: Array[Byte]) extends WebSocketMessage
  case class Error(client: Channel, cause: Option[Throwable]) extends WebSocketMessage
  case class Disconnected(client: Channel) extends WebSocketMessage

  def apply(config: WebSocketServerConfig)(handler: WebSocketServer.WebSocketHandler): WebSocketServer =
    new WebSocketServer(config, handler)

  /** Keeps the channel group in sync with open connections. */
  private class ConnectionTracker(channels: ChannelGroup) extends SimpleChannelUpstreamHandler {
    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels add e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

  }

  /** Performs the upgrade handshake and translates websocket frames into WebSocketMessages. Holds per-connection state. */
  private class WebSocketPartialFunctionHandler(handler: WebSocketHandler, logger: ESLogger) extends SimpleChannelUpstreamHandler {

    // Payload collected so far for a fragmented message; None when no fragmented message is in progress.
    private[this] var pendingFragments: Option[ChannelBuffer] = None
    // Whether the pending fragmented message started with a text frame (otherwise binary).
    private[this] var pendingIsText = false

    private[this] var handshaker: WebSocketServerHandshaker = _

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      e.getMessage match {
        case httpRequest: HttpRequest if isWebSocketUpgrade(httpRequest) => handleUpgrade(ctx, httpRequest)
        case m: TextWebSocketFrame if m.isFinalFragment => handler lift TextMessage(e.getChannel, m.getText)
        case m: TextWebSocketFrame => beginFragmentedMessage(m, isText = true)
        case m: BinaryWebSocketFrame if m.isFinalFragment =>
          handler lift BinaryMessage(e.getChannel, readableBytes(m.getBinaryData))
        case m: BinaryWebSocketFrame => beginFragmentedMessage(m, isText = false)
        case m: ContinuationWebSocketFrame => continueFragmentedMessage(e.getChannel, m)
        case f: CloseWebSocketFrame =>
          if (handshaker != null) handshaker.close(ctx.getChannel, f)
          handler lift Disconnected(e.getChannel)
        // RFC 6455 requires a Pong carrying the same application data as the Ping.
        case ping: PingWebSocketFrame => e.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
        case _ => ctx.sendUpstream(e)
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
      handler lift Error(e.getChannel, Option(e.getCause))
    }

    /** Starts buffering a message whose first frame is not final. */
    private def beginFragmentedMessage(first: WebSocketFrame, isText: Boolean): Unit = {
      val buffer = ChannelBuffers.dynamicBuffer()
      buffer.writeBytes(first.getBinaryData)
      pendingFragments = Some(buffer)
      pendingIsText = isText
    }

    /**
     * Appends a continuation frame; on the final fragment dispatches the whole message.
     * Continuations without a preceding non-final text/binary frame are ignored.
     */
    private def continueFragmentedMessage(channel: Channel, frame: ContinuationWebSocketFrame): Unit = {
      pendingFragments foreach { buffer =>
        buffer.writeBytes(frame.getBinaryData)
        if (frame.isFinalFragment) {
          pendingFragments = None
          // Decode only once all bytes are present so multi-byte UTF-8 sequences split across frames survive.
          if (pendingIsText) handler lift TextMessage(channel, buffer.toString(CharsetUtil.UTF_8))
          else handler lift BinaryMessage(channel, readableBytes(buffer))
        }
      }
    }

    /** Copies the readable region of `buffer` (its backing array may be larger or absent). */
    private def readableBytes(buffer: ChannelBuffer): Array[Byte] = {
      val bytes = new Array[Byte](buffer.readableBytes)
      buffer.getBytes(buffer.readerIndex, bytes)
      bytes
    }

    private def isWebSocketUpgrade(httpRequest: HttpRequest): Boolean = {
      val connHdr = httpRequest.getHeader(Names.CONNECTION)
      val upgrHdr = httpRequest.getHeader(Names.UPGRADE)
      (connHdr != null && connHdr.equalsIgnoreCase(Values.UPGRADE)) &&
        (upgrHdr != null && upgrHdr.equalsIgnoreCase(Values.WEBSOCKET))
    }

    private def handleUpgrade(ctx: ChannelHandlerContext, httpRequest: HttpRequest) {
      val handshakerFactory = new WebSocketServerHandshakerFactory(websocketLocation(httpRequest), null, false)
      handshaker = handshakerFactory.newHandshaker(httpRequest)
      if (handshaker == null) handshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
      else {
        handshaker.handshake(ctx.getChannel, httpRequest)
        handler.lift(Connect(ctx.getChannel))
      }
    }

    /**
     * True when the request looks like it arrived over HTTPS: either a REQUEST_URI header starting with
     * "https" or an X-Forwarded-Proto header (set by TLS-terminating proxies) of "https".
     */
    private def isHttps(req: HttpRequest): Boolean =
      Seq("REQUEST_URI", "X-Forwarded-Proto") exists { name =>
        Option(req.getHeader(name)) exists (_.trim.toUpperCase(ENGLISH).startsWith("HTTPS"))
      }

    // NOTE(review): the location always uses path "/" regardless of req.getUri — confirm this is
    // acceptable for clients that validate the advertised location.
    private def websocketLocation(req: HttpRequest) = {
      if (isHttps(req))
        "wss://" + req.getHeader(Names.HOST) + "/"
      else
        "ws://" + req.getHeader(Names.HOST) + "/"
    }
  }

}
 
/**
 * Netty based websocket server listening on `config.listenOn:config.port`.
 *
 * @param config  bind address and port
 * @param handler receives connection events and messages; events it does not handle go to a fallback
 *                that prints `Error` causes to stderr and ignores everything else
 */
class WebSocketServer(val config: WebSocketServerConfig, val handler: WebSocketServer.WebSocketHandler) {

  import WebSocketServer._

  // Declared before realHandler on purpose: vals initialise in declaration order, so referencing devNull
  // earlier would capture null and every event the user handler ignores would throw a NullPointerException.
  private[this] val devNull: WebSocketHandler = {
    case WebSocketServer.Error(_, Some(ex)) =>
      System.err.println(ex.getMessage)
      ex.printStackTrace()
    case _ =>
  }
  private[this] val realHandler = handler orElse devNull
  protected val logger = ESLoggerFactory.getLogger(getClass.getName)
  private[this] val boss = Executors.newCachedThreadPool()
  private[this] val worker = Executors.newCachedThreadPool()
  private[this] val server = {
    val bs = new ServerBootstrap(new NioServerSocketChannelFactory(boss, worker))
    bs.setOption("soLinger", 0)
    bs.setOption("reuseAddress", true)
    bs.setOption("child.tcpNoDelay", true)
    bs
  }

  private[this] val allChannels = new DefaultChannelGroup

  /** Builds the handler chain for one connection; invoked by the pipeline factory for each accepted channel. */
  protected def getPipeline = {
    val pipe = Channels.pipeline()
    pipe.addLast("connection-tracker", new ConnectionTracker(allChannels))
    pipe.addLast("decoder", new HttpRequestDecoder(4096, 8192, 8192))
    pipe.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    pipe.addLast("encoder", new HttpResponseEncoder)
    pipe.addLast("websocketmessages", new WebSocketPartialFunctionHandler(realHandler, logger))
    pipe
  }

  private[this] val servName = getClass.getSimpleName

  /** Binds the server socket and starts accepting connections. */
  def start = synchronized {
    // setPipeline would share one set of handler instances between every accepted channel, but the HTTP
    // decoder, aggregator and websocket handler keep per-connection state; build a fresh pipeline per channel.
    server.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = WebSocketServer.this.getPipeline
    })
    val addr = if (config.listenOn == null || config.listenOn.trim.isEmpty) new InetSocketAddress(config.port)
    else new InetSocketAddress(config.listenOn, config.port)
    val sc = server.bind(addr)
    allChannels add sc
    logger info "Started %s on [%s:%d]".format(servName, config.listenOn, config.port)
  }

  /** Closes all channels and releases the thread pools, blocking until done. */
  def stop = synchronized {
    allChannels.close().awaitUninterruptibly()
    val thread = new Thread {
      override def run = {
        server.releaseExternalResources()
        boss.awaitTermination(5, TimeUnit.SECONDS)
        worker.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
    thread.setDaemon(false)
    thread.start()
    thread.join()
    logger info "Stopped %s".format(servName)
  }
}

The client implementation creates a new Netty channel factory per WebSocket connection. This basically means you have one thread per connection (more likely three: one boss thread and up to two worker threads), which defeats the point of using NIO. The channel factory should be passed in to DefaultWebSocketClient so it can be reused between client connections. That change would also fix the thread leak caused by DefaultWebSocketClient.disconnect not shutting down the channel factory's thread pools.

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.