Skip to content

Instantly share code, notes, and snippets.

@casualjim
Created February 13, 2012 19:42
Show Gist options
  • Star 24 You must be signed in to star a gist
  • Fork 6 You must be signed in to fork a gist
  • Save casualjim/1819496 to your computer and use it in GitHub Desktop.
Save casualjim/1819496 to your computer and use it in GitHub Desktop.
A Netty-based WebSocket client and server in Scala
package mojolly.io
import org.jboss.netty.bootstrap.ClientBootstrap
import org.jboss.netty.channel._
import socket.nio.NioClientSocketChannelFactory
import java.util.concurrent.Executors
import org.jboss.netty.handler.codec.http._
import collection.JavaConversions._
import websocketx._
import java.net.{InetSocketAddress, URI}
import java.nio.charset.Charset
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.util.CharsetUtil
import akka.actor.ActorRef
import mojolly.LibraryConstants
/**
 * A minimal Netty 3 based WebSocket client.
 *
 * Usage of the simple websocket client (the message types live in `WebSocketClient.Messages`):
 * <pre>
 * import WebSocketClient.Messages._
 * WebSocketClient(new URI("ws://localhost:8080/thesocket")) {
 *   case Connected(client) => println("Connection has been established to: " + client.url.toASCIIString)
 *   case Disconnected(client, _) => println("The websocket to " + client.url.toASCIIString + " disconnected.")
 *   case TextMessage(client, message) => {
 *     println("RECV: " + message)
 *     client send ("ECHO: " + message)
 *   }
 * }
 * </pre>
 */
object WebSocketClient {

  /** Events delivered to a client's [[WebSocketClient.Handler]]. */
  object Messages {
    sealed trait WebSocketClientMessage
    case object Connecting extends WebSocketClientMessage
    case class ConnectionFailed(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Connected(client: WebSocketClient) extends WebSocketClientMessage
    case class TextMessage(client: WebSocketClient, text: String) extends WebSocketClientMessage
    case class WriteFailed(client: WebSocketClient, message: String, reason: Option[Throwable]) extends WebSocketClientMessage
    case object Disconnecting extends WebSocketClientMessage
    case class Disconnected(client: WebSocketClient, reason: Option[Throwable] = None) extends WebSocketClientMessage
    case class Error(client: WebSocketClient, th: Throwable) extends WebSocketClientMessage
  }

  /** Callback for client events; the default client falls back to printing `Error`s and ignoring the rest. */
  type Handler = PartialFunction[Messages.WebSocketClientMessage, Unit]

  /** Converts a received data frame into the text delivered as a `TextMessage`. */
  type FrameReader = WebSocketFrame => String

  /** Returns the text of a `TextWebSocketFrame`; any other frame type throws `UnsupportedOperationException`. */
  val defaultFrameReader: FrameReader = (_: WebSocketFrame) match {
    case f: TextWebSocketFrame => f.getText
    case _ => throw new UnsupportedOperationException("Only single text frames are supported for now")
  }

  /**
   * Creates (but does not connect) a client for `url`.
   *
   * @param url     a `ws://` or `wss://` endpoint
   * @param version the WebSocket protocol version used for the handshake
   * @param reader  converts received data frames to text
   * @param handle  receives client events
   * @throws IllegalArgumentException if the scheme is not `ws` or `wss`
   */
  def apply(url: URI, version: WebSocketVersion = WebSocketVersion.V13, reader: FrameReader = defaultFrameReader)(handle: Handler): WebSocketClient = {
    requireWebSocketScheme(url)
    new DefaultWebSocketClient(url, version, handle, reader)
  }

  /** Creates a client that forwards every event to the given actor. */
  def apply(url: URI, handle: ActorRef): WebSocketClient = {
    requireWebSocketScheme(url)
    WebSocketClient(url) { case x => handle ! x }
  }

  /** Rejects null schemes as well as look-alikes such as "wsx" that a plain `startsWith("ws")` accepted. */
  private def requireWebSocketScheme(url: URI) {
    val scheme = Option(url.getScheme).map(_.toLowerCase(java.util.Locale.ENGLISH))
    require(scheme.exists(s => s == "ws" || s == "wss"), "The scheme of the url should be 'ws' or 'wss'")
  }

  /** The URL's explicit port, or the scheme default (443 for wss, 80 otherwise) when the URL has none (-1). */
  private def portOf(url: URI): Int =
    if (url.getPort != -1) url.getPort
    else if ("wss".equalsIgnoreCase(url.getScheme)) 443
    else 80

  /** Completes the handshake and turns incoming frames into client events. */
  private class WebSocketClientHandler(handshaker: WebSocketClientHandshaker, client: WebSocketClient) extends SimpleChannelUpstreamHandler {
    import Messages._

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      client.handler(Disconnected(client))
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      e.getMessage match {
        case resp: HttpResponse if handshaker.isHandshakeComplete =>
          throw new WebSocketException("Unexpected HttpResponse (status=" + resp.getStatus + ", content="
            + resp.getContent.toString(CharsetUtil.UTF_8) + ")")
        case resp: HttpResponse =>
          handshaker.finishHandshake(ctx.getChannel, resp)
          client.handler(Connected(client))
        // Answer pings (echoing their payload) instead of failing with a MatchError and dropping the connection.
        case ping: PingWebSocketFrame => ctx.getChannel.write(new PongWebSocketFrame(ping.getBinaryData))
        case _: PongWebSocketFrame =>
        case _: CloseWebSocketFrame => ctx.getChannel.close()
        // Data frames go through the client's configurable reader; the default one rejects non-text frames.
        case frame: WebSocketFrame => client.handler(TextMessage(client, client.reader(frame)))
        case _ => ctx.sendUpstream(e)
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
      client.handler(Error(client, e.getCause))
      e.getChannel.close()
    }
  }

  private class DefaultWebSocketClient(
    val url: URI,
    version: WebSocketVersion,
    private[this] val _handler: Handler,
    val reader: FrameReader = defaultFrameReader) extends WebSocketClient {

    import Messages._

    val normalized = url.normalize()
    // Ensure a non-empty path (ws://host -> ws://host/) before building the handshake.
    val tgt = if (normalized.getPath == null || normalized.getPath.trim().isEmpty) {
      new URI(normalized.getScheme, normalized.getAuthority, "/", normalized.getQuery, normalized.getFragment)
    } else normalized

    // NOTE(review): each client creates its own channel factory backed by two cached thread pools, and nothing in
    // this class calls bootstrap.releaseExternalResources(), so those threads outlive `disconnect`. Consider
    // sharing one factory between clients.
    val bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool))
    val handshaker = new WebSocketClientHandshakerFactory().newHandshaker(tgt, version, null, false, Map.empty[String, String])
    val self = this

    // Assigned in the connect future's listener (possibly on a Netty I/O thread), read by connect/send/disconnect.
    @volatile var channel: Channel = _

    val handler = _handler orElse defaultHandler

    /** Fallback for events the user handler does not match: print errors, ignore everything else. */
    private def defaultHandler: Handler = {
      case Error(_, ex) => ex.printStackTrace()
      case _: WebSocketClientMessage =>
    }

    // NOTE(review): no SslHandler is installed, so 'wss' URLs are accepted but the traffic is not encrypted.
    bootstrap.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = {
        val pipeline = Channels.pipeline()
        if (version == WebSocketVersion.V00)
          pipeline.addLast("decoder", new WebSocketHttpResponseDecoder)
        else
          pipeline.addLast("decoder", new HttpResponseDecoder)
        pipeline.addLast("encoder", new HttpRequestEncoder)
        pipeline.addLast("ws-handler", new WebSocketClientHandler(handshaker, self))
        pipeline
      }
    })

    /** Connects and starts the handshake unless already connected; waits up to 5 seconds for the TCP connect. */
    def connect = {
      if (channel == null || !channel.isConnected) {
        val listener = futureListener { future =>
          if (future.isSuccess) {
            channel = future.getChannel
            handshaker.handshake(future.getChannel)
          } else {
            handler(ConnectionFailed(this, Option(future.getCause)))
          }
        }
        handler(Connecting)
        val fut = bootstrap.connect(new InetSocketAddress(url.getHost, portOf(url)))
        fut.addListener(listener)
        fut.await(5000L)
      }
    }

    /** Sends a close frame if connected; the channel is closed when the peer's close frame arrives. */
    def disconnect = {
      if (channel != null && channel.isConnected) {
        handler(Disconnecting)
        channel.write(new CloseWebSocketFrame())
      }
    }

    /** Sends `message` as a text frame; failures (including "not connected yet") are reported as `WriteFailed`. */
    def send(message: String, charset: Charset = CharsetUtil.UTF_8) = {
      val ch = channel
      if (ch == null) {
        handler(WriteFailed(this, message, Some(new WebSocketException("Not connected to " + url.toASCIIString))))
      } else {
        ch.write(new TextWebSocketFrame(ChannelBuffers.copiedBuffer(message, charset))).addListener(futureListener { fut =>
          if (!fut.isSuccess) {
            handler(WriteFailed(this, message, Option(fut.getCause)))
          }
        })
      }
    }

    def futureListener(handleWith: ChannelFuture => Unit) = new ChannelFutureListener {
      def operationComplete(future: ChannelFuture) { handleWith(future) }
    }
  }

  /**
   * Fix bug in standard HttpResponseDecoder for web socket clients. When status 101 is received for Hybi00, there are 16
   * bytes of contents expected, so a 101 response must NOT be treated as having an always-empty body.
   */
  class WebSocketHttpResponseDecoder extends HttpResponseDecoder {
    /** Status codes at or above 200 whose responses never carry a body. */
    val codes = List(204, 205, 304)

    protected override def isContentAlwaysEmpty(msg: HttpMessage) = msg match {
      case res: HttpResponse =>
        val code = res.getStatus.getCode
        // Force the Hixie-76 challenge response into the content buffer; other 1xx responses have no body.
        if (code == 101) false
        else code < 200 || (codes contains code)
      case _ => false
    }
  }

  /**
   * A WebSocket related exception
   *
   * Copied from https://github.com/cgbystrom/netty-tools
   */
  class WebSocketException(s: String, th: Throwable) extends java.io.IOException(s, th) {
    def this(s: String) = this(s, null)
  }
}
/**
 * Public contract of a WebSocket client connection, created via `WebSocketClient.apply`.
 */
trait WebSocketClient {
  /** The endpoint this client talks to. */
  def url: URI

  /** Frame-to-text conversion function associated with this client. */
  def reader: WebSocketClient.FrameReader

  /** Receives the client's lifecycle and message events. */
  def handler: WebSocketClient.Handler

  /** Opens the connection. */
  def connect: Unit

  /** Closes the connection. */
  def disconnect: Unit

  /** Sends a text message encoded with `charset` (UTF-8 by default). */
  def send(message: String, charset: Charset = CharsetUtil.UTF_8): Unit
}
package io.backchat.minutes.river
import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
import java.util.concurrent.{ TimeUnit, Executors }
import java.net.{ InetSocketAddress }
import org.jboss.netty.channel._
import group.{ChannelGroup, DefaultChannelGroup}
import org.elasticsearch.common.logging.{ESLogger, ESLoggerFactory}
import org.jboss.netty.handler.codec.http.{HttpRequest, HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}
import org.jboss.netty.handler.codec.http.websocketx._
import org.jboss.netty.handler.codec.http.HttpHeaders.Values
import org.jboss.netty.handler.codec.http.HttpHeaders.Names
import java.util.Locale.ENGLISH
/**
 * Bind settings for a [[WebSocketServer]].
 */
trait WebSocketServerConfig {
  /** TCP port to listen on. */
  def port: Int

  /** Host/interface to bind; `WebSocketServer.start` binds the wildcard address when this is null or blank. */
  def listenOn: String
}
/**
 * Netty based WebSocketServer
 * requires netty 3.3.x or later
 *
 * Handlers receive [[WebSocketServer.WebSocketMessage]] events. A fragmented message (a non-final text/binary
 * frame followed by continuation frames) is reassembled and delivered as a single `TextMessage`/`BinaryMessage`.
 *
 * Usage:
 * <pre>
 * val conf = new WebSocketServerConfig {
 *   val port = 14567
 *   val listenOn = "0.0.0.0"
 * }
 *
 * val server = WebSocketServer(conf) {
 *   case Connect(_) => println("got a client connection")
 *   case TextMessage(cl, text) => cl.write(new TextWebSocketFrame("ECHO: " + text))
 *   case Disconnected(_) => println("client disconnected")
 * }
 * server.start
 * // time passes......
 * server.stop
 * </pre>
 */
object WebSocketServer {
  /** User callback for server events. */
  type WebSocketHandler = PartialFunction[WebSocketMessage, Unit]

  sealed trait WebSocketMessage
  case class Connect(client: Channel) extends WebSocketMessage
  case class TextMessage(client: Channel, content: String) extends WebSocketMessage
  case class BinaryMessage(client: Channel, content: Array[Byte]) extends WebSocketMessage
  case class Error(client: Channel, cause: Option[Throwable]) extends WebSocketMessage
  case class Disconnected(client: Channel) extends WebSocketMessage

  def apply(config: WebSocketServerConfig)(handler: WebSocketServer.WebSocketHandler): WebSocketServer =
    new WebSocketServer(config, handler)

  /** Adds channels to `channels` when they connect and removes them when they disconnect or close. */
  private class ConnectionTracker(channels: ChannelGroup) extends SimpleChannelUpstreamHandler {
    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels add e.getChannel
      ctx.sendUpstream(e)
    }

    override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
      channels remove e.getChannel
      ctx.sendUpstream(e)
    }
  }

  /** Performs the upgrade handshake and translates frames into [[WebSocketMessage]]s for `handler`. */
  private class WebSocketPartialFunctionHandler(handler: WebSocketHandler, logger: ESLogger) extends SimpleChannelUpstreamHandler {
    // Frames of the fragmented message currently being received, opening frame first; empty between messages.
    private[this] var pendingFragments: Vector[WebSocketFrame] = Vector.empty
    private[this] var handshaker: WebSocketServerHandshaker = _

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      e.getMessage match {
        case httpRequest: HttpRequest if isWebSocketUpgrade(httpRequest) => handleUpgrade(ctx, httpRequest)
        // A non-final text/binary frame opens a fragmented message: buffer it until the final continuation.
        case m: TextWebSocketFrame if !m.isFinalFragment => pendingFragments = Vector(m)
        case m: BinaryWebSocketFrame if !m.isFinalFragment => pendingFragments = Vector(m)
        case m: TextWebSocketFrame => handler lift TextMessage(e.getChannel, m.getText)
        case m: BinaryWebSocketFrame => handler lift BinaryMessage(e.getChannel, bytesOf(m))
        case m: ContinuationWebSocketFrame =>
          pendingFragments :+= m
          if (m.isFinalFragment) deliverFragments(e.getChannel)
        case f: CloseWebSocketFrame =>
          if (handshaker != null) handshaker.close(ctx.getChannel, f)
          handler lift Disconnected(e.getChannel)
        // A Pong must carry the same application data as the Ping it answers.
        case p: PingWebSocketFrame => e.getChannel.write(new PongWebSocketFrame(p.getBinaryData))
        case _ => ctx.sendUpstream(e)
      }
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
      handler lift Error(e.getChannel, Option(e.getCause))
    }

    /** Joins the buffered fragments (never empty here: the final continuation was just appended) into one message. */
    private def deliverFragments(channel: Channel) {
      val fragments = pendingFragments
      pendingFragments = Vector.empty
      val out = new java.io.ByteArrayOutputStream
      fragments foreach { f => out.write(bytesOf(f)) }
      val payload = out.toByteArray
      fragments.head match {
        case _: BinaryWebSocketFrame => handler lift BinaryMessage(channel, payload)
        // Decode only after joining: a multi-byte UTF-8 character may span two fragments.
        case _ => handler lift TextMessage(channel, new String(payload, "UTF-8"))
      }
    }

    /** Copies the readable payload bytes; `ChannelBuffer.array` would expose the whole backing array instead. */
    private def bytesOf(frame: WebSocketFrame): Array[Byte] = {
      val buf = frame.getBinaryData
      val bytes = new Array[Byte](buf.readableBytes)
      buf.getBytes(buf.readerIndex, bytes)
      bytes
    }

    /** True for `Connection: ...upgrade...` + `Upgrade: websocket`; Connection may list several tokens. */
    private def isWebSocketUpgrade(httpRequest: HttpRequest): Boolean = {
      val connectionTokens = Option(httpRequest.getHeader(Names.CONNECTION)).toSeq.flatMap(_.split(",").toSeq).map(_.trim)
      val upgradeHeader = Option(httpRequest.getHeader(Names.UPGRADE)).map(_.trim)
      connectionTokens.exists(_.equalsIgnoreCase(Values.UPGRADE)) &&
        upgradeHeader.exists(_.equalsIgnoreCase(Values.WEBSOCKET))
    }

    private def handleUpgrade(ctx: ChannelHandlerContext, httpRequest: HttpRequest) {
      val handshakerFactory = new WebSocketServerHandshakerFactory(websocketLocation(httpRequest), null, false)
      handshaker = handshakerFactory.newHandshaker(httpRequest)
      if (handshaker == null) handshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel)
      else {
        handshaker.handshake(ctx.getChannel, httpRequest)
        handler.lift(Connect(ctx.getChannel))
      }
    }

    /** True when REQUEST_URI or the load balancer's X-Forwarded-Proto header starts with "https". */
    private def isHttps(req: HttpRequest) = {
      def startsWithHttps(name: String) =
        Option(req.getHeader(name)).map(_.trim).filter(_.nonEmpty).exists(_.toUpperCase(ENGLISH).startsWith("HTTPS"))
      startsWithHttps("REQUEST_URI") || startsWithHttps("X-Forwarded-Proto")
    }

    /** The ws/wss URL of this request, including its path (Hixie-76 clients compare it to the URL they opened). */
    private def websocketLocation(req: HttpRequest) = {
      val scheme = if (isHttps(req)) "wss://" else "ws://"
      val path = Option(req.getUri).filter(_.startsWith("/")).getOrElse("/")
      scheme + req.getHeader(Names.HOST) + path
    }
  }
}
/**
 * A Netty 3 WebSocket server bound according to `config`, dispatching events to `handler`.
 */
class WebSocketServer(val config: WebSocketServerConfig, val handler: WebSocketServer.WebSocketHandler) {
  import WebSocketServer._

  /**
   * Fallback for events the user handler does not match: prints errors that carry a cause, ignores the rest.
   *
   * Declared BEFORE `realHandler` on purpose: vals initialise in declaration order, so declaring it afterwards made
   * `handler orElse devNull` capture `null` and throw a NullPointerException on every unmatched event.
   */
  private[this] val devNull: WebSocketHandler = {
    case WebSocketServer.Error(_, Some(ex)) =>
      System.err.println(ex.getMessage)
      ex.printStackTrace()
    case _ =>
  }
  private[this] val realHandler = handler orElse devNull

  protected val logger = ESLoggerFactory.getLogger(getClass.getName)

  private[this] val boss = Executors.newCachedThreadPool()
  private[this] val worker = Executors.newCachedThreadPool()
  private[this] val server = {
    val bs = new ServerBootstrap(new NioServerSocketChannelFactory(boss, worker))
    bs.setOption("soLinger", 0)
    bs.setOption("reuseAddress", true)
    bs.setOption("child.tcpNoDelay", true)
    bs
  }

  private[this] val allChannels = new DefaultChannelGroup

  /** Builds the handler chain for ONE connection; the decoder and websocket handler keep per-connection state. */
  protected def getPipeline = {
    val pipe = Channels.pipeline()
    pipe.addLast("connection-tracker", new ConnectionTracker(allChannels))
    pipe.addLast("decoder", new HttpRequestDecoder(4096, 8192, 8192))
    pipe.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    pipe.addLast("encoder", new HttpResponseEncoder)
    pipe.addLast("websocketmessages", new WebSocketPartialFunctionHandler(realHandler, logger))
    pipe
  }

  private[this] val servName = getClass.getSimpleName

  /** Binds the server socket; a null or blank `listenOn` binds the wildcard address. */
  def start = synchronized {
    // Bootstrap.setPipeline reuses the same handler instances for every accepted connection, so all clients would
    // share one HTTP decoder and one handshake/fragment state. A factory builds a fresh pipeline per connection.
    server.setPipelineFactory(new ChannelPipelineFactory {
      def getPipeline = WebSocketServer.this.getPipeline
    })
    val addr = if (config.listenOn == null || config.listenOn.trim.isEmpty) new InetSocketAddress(config.port)
    else new InetSocketAddress(config.listenOn, config.port)
    val sc = server.bind(addr)
    allChannels add sc
    logger info "Started %s on [%s:%d]".format(servName, config.listenOn, config.port)
  }

  /** Closes all tracked channels, then releases the thread pools (waiting up to 5s each) on a helper thread. */
  def stop = synchronized {
    allChannels.close().awaitUninterruptibly()
    val thread = new Thread {
      override def run = {
        server.releaseExternalResources()
        boss.awaitTermination(5, TimeUnit.SECONDS)
        worker.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
    thread.setDaemon(false)
    thread.start()
    thread.join()
    logger info "Stopped %s".format(servName)
  }
}
@jroper
Copy link

jroper commented Feb 16, 2013

The client implementation creates a new Netty channel factory per WebSocket connection, which means basically you have one thread (actually, more likely 3, 1 boss thread, up to 2 worker threads) per connection, thus defeating the point of using NIO. The channel factory should be passed in to the DefaultWebSocketClient so it can be reused between client connections. If this change is done, then that will also solve the thread leak caused by DefaultWebSocketClient.disconnect not shutting down the thread pools in the channel factory.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment