Skip to content

Instantly share code, notes, and snippets.

@sturgle
Created September 29, 2014 08:29
Show Gist options
  • Save sturgle/991d8b14626033d3bc09 to your computer and use it in GitHub Desktop.
Save sturgle/991d8b14626033d3bc09 to your computer and use it in GitHub Desktop.
A Java NIO TCP client, kept here as notes.
// TCP Client (Java NIO)
// A minimalistic Java NIO TCP client. It always stays connected: it disconnects and reconnects on any exception thrown by the event handlers (onConnected, onDisconnected, onRead).
package net.bobah.nio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.PostConstruct;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
/**
 * A simple NIO TCP client driven by a single event-loop thread.
 *
 * <p>Assumptions:
 * <ul>
 * <li>the client should always be connected; once it gets disconnected it reconnects
 *     (with an exponentially growing delay capped at 30 seconds);</li>
 * <li>an exception thrown by {@link #onRead(ByteBuffer)} means a protocol error,
 *     so the client disconnects and reconnects;</li>
 * <li>the incoming flow is expected to be higher than the outgoing one, so outgoing
 *     data goes through one fixed-size write buffer (no reallocation, no buffer chain).</li>
 * </ul>
 *
 * <p>Threading: {@link #send(ByteBuffer)} may be called from any thread; the handlers
 * ({@code onConnected}, {@code onRead}, {@code onDisconnected}) are invoked on the
 * event-loop thread. All access to the write buffer is guarded by its own monitor.
 *
 * @author Vladimir Lysyy (mail@bobah.net)
 */
public abstract class TcpClient implements Runnable {
    protected static final Logger LOG = Logger.getLogger(TcpClient.class);
    private static final long INITIAL_RECONNECT_INTERVAL = 500; // 500 ms.
    private static final long MAXIMUM_RECONNECT_INTERVAL = 30000; // 30 sec.
    private static final int READ_BUFFER_SIZE = 0x100000; // 1Mb
    private static final int WRITE_BUFFER_SIZE = 0x100000; // 1Mb

    /** Delay before the next reconnection attempt; touched only by the event-loop thread. */
    private long reconnectInterval = INITIAL_RECONNECT_INTERVAL;
    private final ByteBuffer readBuf = ByteBuffer.allocateDirect(READ_BUFFER_SIZE);
    /** Kept in "write mode" (ready for put) between operations; also used as the lock/monitor. */
    private final ByteBuffer writeBuf = ByteBuffer.allocateDirect(WRITE_BUFFER_SIZE);
    private final Thread thread = new Thread(this);
    private SocketAddress address;
    // volatile: written by the event loop, read by stop() and send() on other threads
    private volatile Selector selector;
    private volatile SocketChannel channel;
    private final AtomicBoolean connected = new AtomicBoolean(false);
    private final AtomicLong bytesOut = new AtomicLong(0L);
    private final AtomicLong bytesIn = new AtomicLong(0L);

    public TcpClient() {
    }

    /** Sanity check run after dependency injection: the server address must have been set. */
    @PostConstruct
    public void init() {
        assert address != null : "server address missing";
    }

    /**
     * Starts the event-loop thread.
     *
     * @throws IOException declared for interface compatibility
     */
    public void start() throws IOException {
        LOG.info("starting event loop");
        thread.start();
    }

    /**
     * Waits for the event-loop thread to terminate; a no-op when called from the event loop itself.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void join() throws InterruptedException {
        if (Thread.currentThread() != thread) {
            thread.join();
        }
    }

    /**
     * Asks the event loop to terminate by interrupting its thread and waking the selector.
     *
     * @throws IOException declared for interface compatibility
     * @throws InterruptedException declared for interface compatibility
     */
    public void stop() throws IOException, InterruptedException {
        LOG.info("stopping event loop");
        thread.interrupt();
        // the selector does not exist until the event loop has opened it
        Selector sel = selector;
        if (sel != null) {
            sel.wakeup();
        }
    }

    /** @return {@code true} while the connection to the server is established */
    public boolean isConnected() {
        return connected.get();
    }

    /**
     * Queues data for sending and tries to write it to the channel straight away.
     * When called from a thread other than the event loop and the write buffer is full,
     * blocks until enough space is freed; on the event-loop thread it fails instead,
     * because blocking there would prevent the buffer from ever being drained.
     *
     * @param buffer data to send, the buffer should be flipped (ready for read)
     * @throws InterruptedException if interrupted while waiting for buffer space
     * @throws IOException if not connected, if the connection is lost while waiting,
     *         if the message can never fit into the write buffer, or on a channel error
     */
    public void send(ByteBuffer buffer) throws InterruptedException, IOException {
        if (!connected.get()) {
            throw new IOException("not connected");
        }
        if (buffer.remaining() > WRITE_BUFFER_SIZE) {
            // would otherwise wait forever for space that can never appear
            throw new IOException("message larger than send buffer");
        }
        synchronized (writeBuf) {
            // try direct write of what's in the buffer to free up space
            if (writeBuf.remaining() < buffer.remaining()) {
                drainWriteBuf(channel);
            }
            // if didn't help, wait till some space appears
            if (writeBuf.remaining() < buffer.remaining()) {
                if (Thread.currentThread() == thread) {
                    throw new IOException("send buffer full"); // TODO: add reallocation or buffers chain
                }
                requestWrite(); // make sure the event loop is going to drain the buffer
                while (writeBuf.remaining() < buffer.remaining()) {
                    writeBuf.wait();
                    if (!connected.get()) {
                        throw new IOException("not connected");
                    }
                }
            }
            writeBuf.put(buffer);
            // try direct write to decrease the latency
            drainWriteBuf(channel);
            // after compact() the position is the number of bytes still pending
            if (writeBuf.position() > 0) {
                requestWrite();
            }
        }
    }

    /**
     * Writes as much of the write buffer as the channel accepts without blocking and
     * accounts the bytes in {@code bytesOut}. The caller must hold the {@code writeBuf}
     * monitor. The buffer is returned to write mode even when the write fails.
     *
     * @return number of bytes written
     */
    private int drainWriteBuf(WritableByteChannel ch) throws IOException {
        int total = 0;
        writeBuf.flip();
        try {
            int n;
            while (writeBuf.hasRemaining() && (n = ch.write(writeBuf)) > 0) {
                total += n;
            }
        } finally {
            writeBuf.compact();
            bytesOut.addAndGet(total);
        }
        return total;
    }

    /** Registers interest in OP_WRITE and wakes the selector so the event loop flushes the buffer. */
    private void requestWrite() throws IOException {
        Selector sel = selector;
        SocketChannel ch = channel;
        SelectionKey key = (sel != null && ch != null) ? ch.keyFor(sel) : null;
        if (key == null || !key.isValid()) {
            throw new IOException("not connected");
        }
        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
        sel.wakeup();
    }

    /**
     * Called on the event-loop thread when data has arrived. Override with something meaningful.
     *
     * @param buf received data, flipped (ready for read); bytes left unconsumed are kept
     *            and presented again, followed by newly received data, on the next call
     * @throws Exception to signal a protocol error; the client disconnects and reconnects
     */
    protected abstract void onRead(ByteBuffer buf) throws Exception;

    /**
     * Called on the event-loop thread once the connection is established.
     * Override with something meaningful.
     *
     * @throws Exception to make the client disconnect and reconnect
     */
    protected abstract void onConnected() throws Exception;

    /**
     * Called on the event-loop thread after every connection attempt ends, whether or
     * not the connection was ever established. Override with something meaningful.
     */
    protected abstract void onDisconnected();

    /** Switches the channel to non-blocking mode and applies the socket options. */
    private void configureChannel(SocketChannel channel) throws IOException {
        channel.configureBlocking(false);
        channel.socket().setSendBufferSize(WRITE_BUFFER_SIZE);
        channel.socket().setReceiveBufferSize(READ_BUFFER_SIZE);
        channel.socket().setKeepAlive(true);
        channel.socket().setReuseAddress(true);
        channel.socket().setSoLinger(false, 0);
        channel.socket().setSoTimeout(0);
        channel.socket().setTcpNoDelay(true);
    }

    /**
     * The event loop: connects, multiplexes channel events, and on any failure cleans up,
     * waits for the current reconnect interval and tries again, until interrupted.
     */
    @Override
    public void run() {
        LOG.info("event loop running");
        try {
            while (!Thread.interrupted()) { // reconnection loop
                try {
                    selector = Selector.open();
                    channel = SocketChannel.open();
                    configureChannel(channel);
                    if (channel.connect(address)) {
                        // a non-blocking connect may complete immediately (e.g. local peer);
                        // OP_CONNECT is not guaranteed to fire in that case
                        channel.register(selector, SelectionKey.OP_READ);
                        markConnected();
                    } else {
                        channel.register(selector, SelectionKey.OP_CONNECT);
                    }
                    while (!thread.isInterrupted() && channel.isOpen()) { // events multiplexing loop
                        if (selector.select() > 0) {
                            processSelectedKeys(selector.selectedKeys());
                        }
                    }
                } catch (Exception e) {
                    LOG.error("exception", e);
                } finally {
                    cleanupConnection();
                }
                try {
                    Thread.sleep(reconnectInterval);
                    // exponential backoff, never exceeding the maximum
                    reconnectInterval = Math.min(reconnectInterval * 2, MAXIMUM_RECONNECT_INTERVAL);
                    LOG.info("reconnecting to " + address);
                } catch (InterruptedException e) {
                    break;
                }
            }
        } catch (Exception e) {
            LOG.error("unrecoverable error", e);
        }
        LOG.info("event loop terminated");
    }

    /**
     * Tears down the state of one connection attempt. Each step is isolated so that a
     * failure in one (handler, channel, selector) cannot leak the remaining resources.
     */
    private void cleanupConnection() {
        connected.set(false);
        try {
            onDisconnected();
        } catch (RuntimeException e) {
            LOG.error("exception", e);
        }
        synchronized (writeBuf) {
            writeBuf.clear();
            writeBuf.notifyAll(); // release senders blocked in send()
        }
        readBuf.clear();
        SocketChannel ch = channel;
        if (ch != null) {
            try {
                ch.close();
            } catch (IOException e) {
                LOG.error("exception", e);
            }
        }
        Selector sel = selector;
        if (sel != null) {
            try {
                sel.close();
            } catch (IOException e) {
                LOG.error("exception", e);
            }
        }
        LOG.info("connection closed");
    }

    /** Dispatches every selected key to its handler(s) and removes it from the selected set. */
    private void processSelectedKeys(Set<SelectionKey> keys) throws Exception {
        Iterator<SelectionKey> itr = keys.iterator();
        while (itr.hasNext()) {
            SelectionKey key = itr.next();
            itr.remove();
            // a handler may close the channel, which cancels the key; querying the ready
            // set of a cancelled key throws, hence the isValid() guards
            if (key.isValid() && key.isReadable()) {
                processRead(key);
            }
            if (key.isValid() && key.isWritable()) {
                processWrite(key);
            }
            if (key.isValid() && key.isConnectable()) {
                processConnect(key);
            }
        }
    }

    /** Completes a pending non-blocking connect and switches the key from OP_CONNECT to OP_READ. */
    private void processConnect(SelectionKey key) throws Exception {
        SocketChannel ch = (SocketChannel) key.channel();
        if (ch.finishConnect()) {
            key.interestOps((key.interestOps() & ~SelectionKey.OP_CONNECT) | SelectionKey.OP_READ);
            markConnected();
        }
    }

    /** Records the established connection, resets the backoff and notifies the subclass. */
    private void markConnected() throws Exception {
        LOG.info("connected to " + address);
        reconnectInterval = INITIAL_RECONNECT_INTERVAL;
        connected.set(true);
        onConnected();
    }

    /**
     * Reads everything currently available into the read buffer, hands it to
     * {@link #onRead(ByteBuffer)} and closes the channel when the peer signalled end of stream.
     */
    private void processRead(SelectionKey key) throws Exception {
        if (!readBuf.hasRemaining()) {
            // onRead() left a full buffer unconsumed: nothing more can be read and the
            // selector would report the channel readable forever (busy loop)
            throw new IOException("read buffer full");
        }
        ReadableByteChannel ch = (ReadableByteChannel) key.channel();
        int bytesOp = 0;
        int bytesTotal = 0;
        while (readBuf.hasRemaining() && (bytesOp = ch.read(readBuf)) > 0) {
            bytesTotal += bytesOp;
        }
        bytesIn.addAndGet(bytesTotal);
        if (bytesTotal > 0) {
            readBuf.flip();
            onRead(readBuf);
            readBuf.compact();
        }
        if (bytesOp == -1) {
            LOG.info("peer closed read channel");
            ch.close();
        }
    }

    /**
     * Flushes pending outgoing data when the channel becomes writable, drops the OP_WRITE
     * interest once the buffer is empty and wakes senders waiting for buffer space.
     */
    private void processWrite(SelectionKey key) throws IOException {
        WritableByteChannel ch = (WritableByteChannel) key.channel();
        synchronized (writeBuf) {
            int written = drainWriteBuf(ch);
            if (writeBuf.position() == 0) { // nothing left to send
                key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
            }
            if (written > 0) {
                writeBuf.notifyAll(); // several senders may be waiting
            }
        }
    }

    /** @return the server address this client connects to */
    public SocketAddress getAddress() {
        return address;
    }

    /** @param address the server address; must be set before {@link #start()} */
    public void setAddress(SocketAddress address) {
        this.address = address;
    }

    /** @return total number of bytes written to the channel so far */
    public long getBytesOut() {
        return bytesOut.get();
    }

    /** @return total number of bytes read from the channel so far */
    public long getBytesIn() {
        return bytesIn.get();
    }

    /**
     * Can be used for testing: connects to 127.0.0.1:20001, discards everything received
     * and floods the server with random length-prefixed messages, logging byte counters
     * every five seconds.
     */
    public static void main(String[] args) throws Exception {
        BasicConfigurator.configure(new ConsoleAppender(new PatternLayout("%d{yyyyMMdd-HH:mm:ss} %-10t %-5p %-20C{1} - %m%n")));
        Logger.getRootLogger().setLevel(Level.INFO);
        final TcpClient client = new TcpClient() {
            @Override protected void onRead(ByteBuffer buf) throws Exception { buf.position(buf.limit()); }
            @Override protected void onDisconnected() { }
            @Override protected void onConnected() throws Exception { }
        };
        client.setAddress(new InetSocketAddress("127.0.0.1", 20001));
        try {
            client.start();
        } catch (IOException e) {
            LOG.error("exception", e);
        }
        Timer timer = new Timer();
        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                LOG.info("out bytes: " + client.bytesOut.get());
                LOG.info("in bytes: " + client.bytesIn.get());
            }
        }, 5000, 5000);
        while (!client.isConnected()) {
            Thread.sleep(500);
        }
        LOG.info("starting server flood");
        ByteBuffer buf = ByteBuffer.allocate(65535);
        Random rnd = new Random();
        while (true) {
            short len = (short) rnd.nextInt(Short.MAX_VALUE - 2);
            byte[] bytes = new byte[len];
            rnd.nextBytes(bytes);
            buf.putShort(len);
            buf.put(bytes);
            buf.flip();
            try {
                client.send(buf);
            } catch (Exception e) {
                LOG.error("exception: " + e.getMessage());
                while (!client.isConnected()) {
                    Thread.sleep(1000);
                }
            }
            buf.clear();
            Thread.sleep(10);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment