Skip to content

Instantly share code, notes, and snippets.

@brimworks
Last active July 2, 2021 14:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brimworks/409e6b847a969896ace387385447c2c6 to your computer and use it in GitHub Desktop.
Save brimworks/409e6b847a969896ace387385447c2c6 to your computer and use it in GitHub Desktop.
Implements the ByteChannel and GatheringByteChannel interfaces over TLS (using a javax.net.ssl.SSLEngine) in a blocking way.
package io.nats.client.support;
import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
 * Static helpers for working with {@link ByteBuffer}s.
 *
 * <p>Note that a ByteBuffer is always in one of two modes:
 *
 * <ul>
 * <li>Get mode: Valid data is between position and limit. Some documentation may refer to
 * this as "read" mode, since the buffer is ready for various <code>.get()</code> method calls,
 * however this can be confusing, since you would NEVER call the
 * {@link java.nio.channels.ReadableByteChannel#read(ByteBuffer)} method when operating in
 * this mode. This may also be referred to as "flush" mode, but nothing in the Buffer
 * documentation refers to the term "flush", and thus we use the term "get" mode here.
 * <li>Put mode: Valid data is between 0 and position. Some documentation may refer to this as
 * "write" mode, since the buffer is ready for various <code>.put()</code> method calls,
 * however this can be confusing since you would NEVER call the
 * {@link java.nio.channels.WritableByteChannel#write(ByteBuffer)} method when operating in
 * this mode. This may also be referred to as "fill" mode, but the term fill is only used
 * in one place in the Buffer.clear() documentation.
 * </ul>
 *
 * All documentation will be using the terms "get mode" or "put mode".
 */
public interface BufferUtils {
    /** Upper-case hexadecimal digits, indexed by nibble value. Used by {@link #hexdump}. */
    char[] HEX_ARRAY = "0123456789ABCDEF".toCharArray();

    /** U+2423 (OPEN BOX), printed by {@link #hexdump} in place of non-printable bytes. */
    char SUBSTITUTE_CHAR = 0x2423;

    /**
     * It is not too uncommon to buffer data into a temporary buffer
     * and then append this temporary buffer into a destination buffer.
     * This method makes this operation easy since your temporary buffer
     * is typically in "put" mode and your destination is always in "put"
     * mode, and you need to take into account that the temporary buffer
     * may contain more bytes than your destination buffer, but you don't
     * want a BufferOverflowException to occur, instead you just want to
     * fulfill as many bytes from your temporary buffer as is possible.
     *
     * See <code>org.eclipse.jetty.util.BufferUtil.append(ByteBuffer, ByteBuffer)</code>
     * for a similar method.
     *
     * @param src is a buffer in "put" mode which will be flip'ed
     *     and then "safely" put into dst followed by a compact call,
     *     so it is left in "put" mode holding only the bytes not transferred.
     * @param dst is a buffer in "put" mode which will be populated
     *     from src.
     * @param max is the max bytes to transfer; must be non-negative.
     * @return min(src.position(), dst.remaining(), max), the number of bytes transferred.
     * @throws IllegalArgumentException if max is negative (neither buffer is modified).
     */
    static int append(ByteBuffer src, ByteBuffer dst, int max) {
        if (max < 0) {
            // Validate up front: failing inside the try/finally below would leave src flipped.
            throw new IllegalArgumentException("max must be non-negative, got " + max);
        }
        int count = Math.min(max, Math.min(src.position(), dst.remaining()));
        src.flip();
        try {
            ByteBuffer slice = src.slice();
            slice.limit(count);
            dst.put(slice);
        } finally {
            // Drop the transferred bytes and return src to "put" mode.
            src.position(count);
            src.compact();
        }
        return count;
    }

    /**
     * Delegates to {@link #append(ByteBuffer,ByteBuffer,int)}, with
     * max set to Integer.MAX_VALUE.
     *
     * @param src is a buffer in "put" mode which will be flip'ed
     *     and then "safely" put into dst followed by a compact call.
     * @param dst is a buffer in "put" mode which will be populated
     *     from src.
     * @return min(src.position(), dst.remaining())
     */
    static int append(ByteBuffer src, ByteBuffer dst) {
        return append(src, dst, Integer.MAX_VALUE);
    }

    /**
     * Reads one line from {@code readBuffer}, refilling it from {@code reader} as needed.
     *
     * <p>A line is terminated by {@code "\n"}, {@code "\r\n"} or a lone {@code "\r"} (legacy
     * Mac). A lone {@code "\r"} is only recognized once the following byte is available, so
     * this may read once more from {@code reader}. End of stream (or a null reader) terminates
     * a non-empty partial line, which is then returned as a line. The consumed bytes, including
     * the terminator, are removed from the buffer by compacting it.
     *
     * <p>This method loops until a terminator or end of stream is seen. With a reader that
     * keeps returning 0 (for example a non-blocking channel with no data), it spins.
     *
     * @param readBuffer is in "put" mode (0 - position are valid)
     * @param reader is a reader used to populate the buffer if insufficient remaining
     *     bytes exist in buffer. May be null if buffer should not be populated.
     * @throws NullPointerException if readBuffer is null.
     * @throws BufferUnderflowException if the buffer has insufficient capacity
     *     to hold a full line.
     * @throws IOException if reader.read() throws this exception.
     * @return a line without line terminators or null if end of channel.
     */
    static String readLine(ByteBuffer readBuffer, ReadableByteChannel reader) throws IOException {
        if (null == readBuffer) {
            throw new NullPointerException("Expected non-null readBuffer");
        }
        int end = 0;
        boolean foundCR = false;
        int newlineLength = 1;
        FIND_END:
        while (true) {
            // Refill until there is at least one unscanned valid byte. This must be a loop:
            // a read() returning 0 must not let us scan a stale byte beyond position.
            while (end >= readBuffer.position()) {
                if (readBuffer.position() == readBuffer.limit()) {
                    // Insufficient capacity in ByteBuffer to read a full line!
                    throw new BufferUnderflowException();
                }
                if (null == reader || reader.read(readBuffer) < 0) {
                    if (end > 0) {
                        if (!foundCR) {
                            newlineLength = 0;
                        }
                        break FIND_END;
                    }
                    return null;
                }
            }
            switch (readBuffer.get(end++)) {
                case '\r':
                    if (foundCR) {
                        --end;
                        break FIND_END; // Legacy MAC end of line
                    }
                    foundCR = true;
                    break;
                case '\n':
                    if (foundCR) {
                        newlineLength++;
                    }
                    break FIND_END;
                default:
                    if (foundCR) {
                        --end;
                        break FIND_END; // Legacy MAC end of line
                    }
            }
        }
        String result;
        readBuffer.flip();
        try {
            ByteBuffer slice = readBuffer.slice();
            slice.limit(end - newlineLength);
            result = UTF_8.decode(slice).toString();
        } finally {
            // Consume the line and its terminator; return the buffer to "put" mode.
            readBuffer.position(end);
            readBuffer.compact();
        }
        return result;
    }

    /**
     * Sums {@code remaining()} over {@code buffers[offset, offset + length)}.
     *
     * @param buffers the buffers to inspect.
     * @param offset index of the first buffer to include.
     * @param length number of buffers to include.
     * @return the total number of remaining bytes, accumulated as a long so that large
     *     gathering writes do not overflow.
     */
    static long remaining(ByteBuffer[] buffers, int offset, int length) {
        long total = 0;
        int end = offset + length;
        for (int i = offset; i < end; i++) {
            total += buffers[i].remaining();
        }
        return total;
    }

    /**
     * Utility method for stringifying a bytebuffer (use ByteBuffer.wrap(byte[])
     * if you want to stringify a byte array). Mostly useful for debugging or
     * tracing.
     *
     * <p>Each output row holds up to 16 bytes: a 4-digit hex offset, the hex bytes (with an
     * extra gap before the 9th byte), then the printable form. Partial rows are padded so
     * that the printable column lines up with full rows. Absolute {@code get(int)} is used,
     * so the buffer's position is not changed.
     *
     * @param bytes is the byte buffer.
     * @param off is the offset within the byte buffer to begin.
     * @param len is the number of bytes to print.
     * @return a "hexdump" of the bytes
     */
    static String hexdump(ByteBuffer bytes, int off, int len) {
        int end = off + len;
        StringBuilder sb = new StringBuilder();
        for (int i = off; i < end;) {
            sb.append(String.format("%04x ", i));
            int start = i;
            do {
                int ch = bytes.get(i) & 0xFF;
                sb.append(" ");
                if (i % 16 == 8) {
                    sb.append(" ");
                }
                sb.append(HEX_ARRAY[ch >>> 4]);
                sb.append(HEX_ARRAY[ch & 0x0F]);
            } while (++i % 16 != 0 && i < end);
            int column = i % 16;
            if (column != 0) {
                // Each missing byte occupies 3 chars (" XX"); the mid-row gap (before the byte
                // at column 8) is also missing if that byte was not printed.
                int pad = 3 * (16 - column) + (column <= 8 ? 1 : 0);
                for (int p = 0; p < pad; p++) {
                    sb.append(' ');
                }
            }
            sb.append(" ");
            i = start;
            do {
                char ch = (char) bytes.get(i);
                if (ch < 0x21) {
                    // Control chars:
                    switch (ch) {
                        case ' ':
                            sb.append((char) 0x2420);
                            break;
                        case '\t':
                            sb.append((char) 0x2409);
                            break;
                        case '\r':
                            sb.append((char) 0x240D);
                            break;
                        case '\n':
                            sb.append((char) 0x2424);
                            break;
                        default:
                            sb.append(SUBSTITUTE_CHAR);
                    }
                } else if (ch < 0x7F) {
                    sb.append(ch);
                } else {
                    // DEL and non-ASCII bytes (negative bytes sign-extend to >= 0xFF80):
                    sb.append(SUBSTITUTE_CHAR);
                }
            } while (++i % 16 != 0 && i < end);
            sb.append("\n");
        }
        return sb.toString();
    }
}
package io.nats.client.channels;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.GatheringByteChannel;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import static io.nats.client.support.BufferUtils.append;
import static io.nats.client.support.BufferUtils.remaining;
/**
 * A {@link ByteChannel} / {@link GatheringByteChannel} that layers TLS on top of another
 * {@link ByteChannel} by means of an {@link SSLEngine}. Plaintext written to this channel is
 * encrypted and written to the wrapped channel. Encrypted bytes read from the wrapped channel
 * are decrypted and handed back by {@link #read(ByteBuffer)}. The JDK does not provide such an
 * adapter out of the box.
 *
 * <p>Operations block until they make progress. The wrapped channel is assumed to be in
 * blocking mode; with a non-blocking channel the internal loops may spin.
 *
 * <p>Reads are serialized by {@code readLock} and writes by {@code writeLock}. Handshake and
 * close state, as well as all {@code SSLEngine.wrap/unwrap} calls, are guarded by
 * {@code stateLock}.
 */
public class TLSByteChannel implements ByteChannel, GatheringByteChannel {
    /** Zero-capacity placeholder used when only handshake/close records are to be exchanged. */
    private static final ByteBuffer[] EMPTY = new ByteBuffer[]{ByteBuffer.allocate(0)};

    private final ByteChannel wrap;
    private final SSLEngine engine;

    // NOTE: Locks should always be acquired in this order:
    // readLock > writeLock > stateLock
    private final Lock readLock = new ReentrantLock();
    private final Lock writeLock = new ReentrantLock();

    // All of the below state is controlled with this lock:
    private final Object stateLock = new Object();
    private Thread readThread = null;
    private Thread writeThread = null;
    private State state;
    // end state protected by the stateLock

    private final ByteBuffer outNetBuffer; // in "get" mode, protected by writeLock
    private final ByteBuffer inNetBuffer; // in "put" mode, protected by readLock
    private final ByteBuffer inAppBuffer; // in "put" mode, protected by readLock

    private enum State {
        HANDSHAKING_READ,
        HANDSHAKING_WRITE,
        HANDSHAKING_TASK,
        OPEN,
        CLOSING,
        CLOSED;
    }

    /**
     * Wrap {@code wrap} with TLS using {@code engine}. The handshake is started here, but
     * network I/O for it only happens in {@link #handshake()} or the first read/write.
     *
     * @param wrap the channel carrying the encrypted bytes.
     * @param engine an engine already configured for client or server mode.
     * @throws IOException if {@link SSLEngine#beginHandshake()} fails.
     */
    public TLSByteChannel(ByteChannel wrap, SSLEngine engine) throws IOException {
        this.wrap = wrap;
        this.engine = engine;

        int netBufferSize = engine.getSession().getPacketBufferSize();
        int appBufferSize = engine.getSession().getApplicationBufferSize();

        outNetBuffer = ByteBuffer.allocate(netBufferSize);
        outNetBuffer.flip();
        inNetBuffer = ByteBuffer.allocate(netBufferSize);
        inAppBuffer = ByteBuffer.allocate(appBufferSize);

        engine.beginHandshake();
        state = toState(engine.getHandshakeStatus());
    }

    /**
     * Translate an SSLEngine.HandshakeStatus into internal state.
     */
    private static State toState(HandshakeStatus status) {
        switch (status) {
            case NEED_TASK:
                return State.HANDSHAKING_TASK;
            case NEED_UNWRAP:
                return State.HANDSHAKING_READ;
            case NEED_WRAP:
                return State.HANDSHAKING_WRITE;
            case FINISHED:
            case NOT_HANDSHAKING:
                return State.OPEN;
            default:
                throw new IllegalStateException("Unexpected SSLEngine.HandshakeStatus=" + status);
        }
    }

    /**
     * Record the handshake status reported by the engine. CLOSING and CLOSED are sticky:
     * once a close has begun, engine results never move the channel back to another state.
     *
     * Precondition: stateLock acquired
     */
    private void updateState(HandshakeStatus status) {
        State newState = toState(status);
        if (State.CLOSING != state && State.CLOSED != state) {
            state = newState;
        }
    }

    /**
     * Force a TLS handshake to take place if it has not already happened.
     * @return false if end of file is observed.
     * @throws IOException if any underlying read or write call throws.
     */
    public boolean handshake() throws IOException {
        while (true) {
            boolean needsRead = false;
            boolean needsWrite = false;
            synchronized (stateLock) {
                switch (state) {
                    case HANDSHAKING_TASK:
                        executeTasks();
                        state = toState(engine.getHandshakeStatus());
                        break;
                    case HANDSHAKING_READ:
                        needsRead = true;
                        break;
                    case HANDSHAKING_WRITE:
                        needsWrite = true;
                        break;
                    default:
                        return true;
                }
            }
            // The actual I/O happens outside stateLock so other threads can observe progress.
            if (needsRead) {
                if (readImpl(EMPTY[0]) < 0) {
                    return false;
                }
            } else if (needsWrite) {
                if (writeImpl(EMPTY, 0, 1) < 0) {
                    return false;
                }
            }
        }
    }

    /**
     * Gracefully close the TLS session and the underlying wrap'ed socket.
     *
     * <p>If the read/write locks cannot be obtained quickly (another thread is stuck in I/O),
     * the wrapped channel is closed abruptly without exchanging close_notify.
     */
    @Override
    public void close() throws IOException {
        if (!wrap.isOpen()) {
            return;
        }

        // [1] Make sure any handshake has happened:
        handshake();

        // [2] Set state to closing, or return if another thread
        // is already closing.
        synchronized (stateLock) {
            if (State.CLOSED == state || State.CLOSING == state) {
                return;
            }
            state = State.CLOSING;

            // [3] Interrupt any reading/writing threads so they release their locks:
            if (null != readThread) {
                readThread.interrupt();
            }
            if (null != writeThread) {
                writeThread.interrupt();
            }
        }

        // [4] Try to acquire readLock:
        try {
            if (!readLock.tryLock(100, TimeUnit.MICROSECONDS)) {
                forceClose();
                return;
            }
            try {
                // [5] Try to acquire writeLock:
                if (!writeLock.tryLock(100, TimeUnit.MICROSECONDS)) {
                    forceClose();
                    return;
                }
                try {
                    // [6] Finally, implement close sequence.
                    closeImpl();
                } finally {
                    writeLock.unlock();
                }
            } finally {
                readLock.unlock();
            }
        } catch (InterruptedException ex) {
            // Non-graceful close!
            Thread.currentThread().interrupt();
            forceClose();
        }
    }

    /**
     * Close the wrapped channel without a TLS close_notify exchange. The state always ends
     * up CLOSED, even if closing the wrapped channel throws.
     */
    private void forceClose() throws IOException {
        try {
            wrap.close();
        } finally {
            synchronized (stateLock) {
                state = State.CLOSED;
            }
        }
    }

    @Override
    public boolean isOpen() {
        return wrap.isOpen();
    }

    /**
     * Implement the close procedure.
     *
     * Precondition: read & write locks are acquired
     *
     * Postcondition: state is CLOSED
     */
    private void closeImpl() throws IOException {
        synchronized (stateLock) {
            if (State.CLOSED == state) {
                return;
            }
            state = State.CLOSING;
        }
        try {
            // NOTE: unread data may be lost. However, we assume this is desired
            // since we are transitioning to closing:
            inAppBuffer.clear();
            if (outNetBuffer.hasRemaining()) {
                wrap.write(outNetBuffer);
            }
            engine.closeOutbound();
            try {
                // Send our close_notify...
                while (!engine.isOutboundDone()) {
                    if (writeImpl(EMPTY, 0, 1) < 0) {
                        throw new ClosedChannelException();
                    }
                }
                // ...and wait for the peer's.
                while (!engine.isInboundDone()) {
                    if (readImpl(EMPTY[0]) < 0) {
                        throw new ClosedChannelException();
                    }
                }
                engine.closeInbound();
            } catch (ClosedChannelException ignored) {
                // The peer or the wrapped channel is already closed; nothing more to exchange.
            }
        } finally {
            // No matter what happens, close the wrapped channel and end in the CLOSED state.
            forceClose();
        }
    }

    /**
     * Read plaintext by decrypting the underlying wrap'ed sockets encrypted bytes.
     *
     * @param dst is the buffer to populate between position and limit.
     * @return the number of bytes populated or -1 to indicate end of stream,
     * and the dst position will also be incremented appropriately.
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        int result = 0;
        while (0 == result) {
            if (!handshake()) {
                return -1;
            }
            if (!dst.hasRemaining()) {
                return 0;
            }
            result = readImpl(dst);
        }
        return result;
    }

    /**
     * Precondition: handshake() was called, or this code was called
     * by the handshake() implementation.
     */
    private int readImpl(ByteBuffer dst) throws IOException {
        readLock.lock();
        try {
            // [1] Check if this is a read for a handshake:
            synchronized (stateLock) {
                if (isHandshaking(state)) {
                    if (state != State.HANDSHAKING_READ) {
                        return 0;
                    }
                    dst = EMPTY[0];
                }
                readThread = Thread.currentThread();
            }

            // [2] Satisfy read via inAppBuffer:
            int count = append(inAppBuffer, dst);
            if (count > 0) {
                return count;
            }

            // [3] Read & decrypt loop:
            return readAndDecryptLoop(dst);
        } finally {
            // Deregister while still holding readLock (and under stateLock) so we can never
            // clear the registration of the next reader.
            synchronized (stateLock) {
                readThread = null;
            }
            readLock.unlock();
        }
    }

    /**
     * Return true if we are handshaking.
     */
    private static boolean isHandshaking(State state) {
        switch (state) {
            case HANDSHAKING_READ:
            case HANDSHAKING_WRITE:
            case HANDSHAKING_TASK:
                return true;
            case CLOSED:
            case OPEN:
            case CLOSING:
        }
        return false;
    }

    /**
     * Read from the network and unwrap into {@code dst} until a record is decrypted, end of
     * stream is observed, or the engine reports CLOSED.
     *
     * Precondition: readLock acquired
     */
    private int readAndDecryptLoop(ByteBuffer dst) throws IOException {
        boolean networkRead = inNetBuffer.position() == 0;
        while (true) {
            // Read from network:
            if (networkRead) {
                synchronized (stateLock) {
                    if (State.OPEN == state && !dst.hasRemaining()) {
                        return 0;
                    }
                }
                if (wrap.read(inNetBuffer) < 0) {
                    return -1;
                }
            }

            SSLEngineResult result;
            synchronized (stateLock) {
                // Decrypt:
                inNetBuffer.flip();
                try {
                    result = engine.unwrap(inNetBuffer, dst);
                } finally {
                    inNetBuffer.compact();
                }
                updateState(result.getHandshakeStatus());
            }

            SSLEngineResult.Status status = result.getStatus();
            switch (status) {
                case BUFFER_OVERFLOW:
                    if (dst == inAppBuffer) {
                        throw new IllegalStateException(
                            "SSLEngine indicated app buffer size=" + inAppBuffer.capacity() +
                            ", but unwrap() returned BUFFER_OVERFLOW with an empty buffer");
                    }
                    // Not enough space in dst, so buffer it into inAppBuffer:
                    readAndDecryptLoop(inAppBuffer);
                    return append(inAppBuffer, dst);
                case BUFFER_UNDERFLOW:
                    if (!inNetBuffer.hasRemaining()) {
                        throw new IllegalStateException(
                            "SSLEngine indicated net buffer size=" + inNetBuffer.capacity() +
                            ", but unwrap() returned BUFFER_UNDERFLOW with a full buffer");
                    }
                    networkRead = true;
                    break; // retry network read
                case CLOSED:
                    forceClose();
                    return -1;
                case OK:
                    return result.bytesProduced();
                default:
                    throw new IllegalStateException("Unexpected status=" + status);
            }
        }
    }

    /**
     * Write plaintext by encrypting and writing this to the underlying wrap'ed socket.
     *
     * @param srcs are the buffers of plaintext to encrypt.
     * @param offset is the offset within the array to begin writing.
     * @param length is the number of buffers within the srcs array that should be written.
     * @return the number of bytes that got written or -1 to indicate end of
     * stream and the src position will also be incremented appropriately.
     * @throws IndexOutOfBoundsException if offset/length do not describe a range of srcs.
     */
    @Override
    public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
        if (offset < 0 || length < 0 || offset > srcs.length - length) {
            throw new IndexOutOfBoundsException(
                "offset=" + offset + ", length=" + length + ", srcs.length=" + srcs.length);
        }
        long result = 0;
        while (0 == result) {
            if (!handshake()) {
                return -1;
            }
            if (0 == remaining(srcs, offset, length)) {
                return 0;
            }
            result = writeImpl(srcs, offset, length);
        }
        return result;
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        return Math.toIntExact(write(new ByteBuffer[]{src}, 0, 1));
    }

    @Override
    public long write(ByteBuffer[] srcs) throws IOException {
        return write(srcs, 0, srcs.length);
    }

    /**
     * While there are delegatedTasks to run, run them.
     *
     * Precondition: stateLock acquired
     */
    private void executeTasks() {
        while (true) {
            Runnable runnable = engine.getDelegatedTask();
            if (null == runnable) {
                break;
            }
            runnable.run();
        }
    }

    /**
     * Implement a write operation. Acquires the write lock for the duration of the call.
     *
     * <p>If a handshake step needs to write, only handshake data is produced (the caller's
     * buffers are not touched) and 0 is returned. If another handshake step is pending,
     * 0 is returned without doing anything.
     *
     * @param srcs are the buffers of plaintext to encrypt.
     * @param offset is the index of the first buffer to write.
     * @param length is the number of buffers to write.
     * @return the number of plaintext bytes consumed or -1 if end of stream.
     */
    private long writeImpl(ByteBuffer[] srcs, int offset, int length) throws IOException {
        writeLock.lock();
        try {
            // [1] Wait until handshake is complete in other thread.
            synchronized (stateLock) {
                if (isHandshaking(state)) {
                    if (state != State.HANDSHAKING_WRITE) {
                        return 0;
                    }
                    // Only handshake records may be produced; the range must describe EMPTY too.
                    srcs = EMPTY;
                    offset = 0;
                    length = EMPTY.length;
                }
                writeThread = Thread.currentThread();
            }

            // [2] Encrypt & write loop:
            return writeAndEncryptLoop(srcs, offset, length);
        } finally {
            // Deregister while still holding writeLock (and under stateLock) so we can never
            // clear the registration of the next writer.
            synchronized (stateLock) {
                writeThread = null;
            }
            writeLock.unlock();
        }
    }

    /**
     * Encrypt {@code srcs[offset, offset + length)} in order and write the resulting records
     * to the network. The loop stops when everything is encrypted, the engine is closed, or
     * the engine makes no progress (for example it needs a handshake step, which the caller
     * drives through {@link #handshake()}).
     *
     * Precondition: writeLock acquired
     *
     * @return the number of plaintext bytes consumed, or -1 if the network write hit EOF.
     */
    private long writeAndEncryptLoop(ByteBuffer[] srcs, int offset, int length) throws IOException {
        if (length <= 0) {
            return 0;
        }
        final int srcsEnd = offset + length;
        long count = 0;
        while (true) {
            SSLEngineResult result;
            synchronized (stateLock) {
                // Encrypt. The gathering wrap() consumes the buffers strictly in order, so
                // bytes are never reordered when a buffer is only partially consumed.
                outNetBuffer.compact();
                try {
                    result = engine.wrap(srcs, offset, srcsEnd - offset, outNetBuffer);
                } finally {
                    outNetBuffer.flip();
                }
                updateState(result.getHandshakeStatus());
            }
            count += result.bytesConsumed();

            // Skip past the buffers that have been fully encrypted.
            while (offset < srcsEnd && !srcs[offset].hasRemaining()) {
                offset++;
            }

            boolean finalNetFlush;
            SSLEngineResult.Status status = result.getStatus();
            switch (status) {
                case BUFFER_OVERFLOW:
                    if (!outNetBuffer.hasRemaining()) {
                        // Nothing pending to flush, so retrying could never succeed.
                        throw new IllegalStateException(
                            "SSLEngine indicated net buffer size=" + outNetBuffer.capacity() +
                            ", but wrap() returned BUFFER_OVERFLOW with an empty buffer");
                    }
                    finalNetFlush = false; // flush pending records, then retry
                    break;
                case BUFFER_UNDERFLOW:
                    throw new IllegalStateException("SSLEngine.wrap() should never return BUFFER_UNDERFLOW");
                case CLOSED:
                    finalNetFlush = true;
                    break;
                case OK:
                    finalNetFlush = offset >= srcsEnd
                        || (0 == result.bytesConsumed() && 0 == result.bytesProduced());
                    break;
                default:
                    throw new IllegalStateException("Unexpected status=" + status);
            }

            // Write to network:
            if (outNetBuffer.hasRemaining() && wrap.write(outNetBuffer) < 0) {
                return -1;
            }
            if (finalNetFlush) {
                return count;
            }
        }
    }
}
package io.nats.client.channels;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.time.Duration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.Test;
import io.nats.client.NatsTestServer;
import io.nats.client.TestSSLUtils;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
/**
 * Integration tests for {@link TLSByteChannel}. Each test starts a TLS-enabled nats-server
 * from {@code src/test/resources/tls.conf}, reads the plain-text INFO greeting, and then
 * layers TLS over the raw socket.
 */
public class TLSByteChannelTests {
    private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

    @Test
    public void testShortAppRead() throws Exception {
        // Scenario: a decrypted TLS record is larger than the application's read buffer,
        // so the surplus must be served from TLSByteChannel's internal buffer.
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = connectAndReadInfo(uri);
            TLSByteChannel tls = new TLSByteChannel(socket, createSSLEngine(uri));

            write(tls, "CONNECT {}\r\n");

            ByteBuffer oneByte = ByteBuffer.allocate(1);
            for (char expected : "+OK".toCharArray()) {
                oneByte.clear();
                assertEquals(1, tls.read(oneByte));
                assertEquals(1, oneByte.position());
                assertEquals((byte) expected, oneByte.get(0));
            }

            // Follow up with a larger buffer read to ensure that the
            // remaining bytes are returned without blocking on a net read:
            ByteBuffer rest = ByteBuffer.allocate(1024);
            int result = tls.read(rest);
            assertEquals(2, result);
            assertEquals(2, rest.position());
            assertEquals((byte) '\r', rest.get(0));
            assertEquals((byte) '\n', rest.get(1));

            assertTrue(tls.isOpen());
            assertTrue(socket.isOpen());
            tls.close();
            assertFalse(tls.isOpen());
            assertFalse(socket.isOpen());
        }
    }

    @Test
    public void testImmediateClose() throws Exception {
        // Scenario: close a channel on which no application data was exchanged;
        // close() itself must drive the handshake and then shut down.
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = connectAndReadInfo(uri);
            TLSByteChannel tls = new TLSByteChannel(socket, createSSLEngine(uri));

            assertTrue(tls.isOpen());
            assertTrue(socket.isOpen());
            tls.close();
            assertFalse(tls.isOpen());
            assertFalse(socket.isOpen());
        }
    }

    @Test
    public void testRenegotiation() throws Exception {
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = connectAndReadInfo(uri);
            try {
                SSLEngine sslEngine = createSSLEngine(uri);
                TLSByteChannel tls = new TLSByteChannel(socket, sslEngine);

                write(tls, "CONNECT {}\r\n");

                ByteBuffer readBuffer = ByteBuffer.allocate(1024 * 1024);
                tls.read(readBuffer);
                readBuffer.flip();
                assertEquals(ByteBuffer.wrap("+OK\r\n".getBytes(UTF_8)), readBuffer);

                // Now force a renegotiation:
                sslEngine.getSession().invalidate();
                sslEngine.beginHandshake();

                // nats-server doesn't support renegotiation, we just get this error:
                // javax.net.ssl.SSLException: Received fatal alert: unexpected_message
                assertThrows(SSLException.class,
                    () -> tls.write(new ByteBuffer[]{ByteBuffer.wrap("PING\r\n".getBytes(UTF_8))}));
            } finally {
                // The TLS session is broken at this point, so close the raw socket directly.
                socket.close();
            }
        }
    }

    @Test
    public void testConcurrentHandshake() throws Exception {
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = connectAndReadInfo(uri);
            int numThreads = 10;
            SSLEngine sslEngine = createSSLEngine(uri);
            TLSByteChannel tls = new TLSByteChannel(socket, sslEngine);

            CountDownLatch threadsReady = new CountDownLatch(numThreads);
            CountDownLatch startLatch = new CountDownLatch(1);
            ExecutorService executor = Executors.newFixedThreadPool(numThreads);
            try {
                // Half the threads trigger the handshake via read(), the other half via write().
                Future<?>[] futures = new Future<?>[numThreads];
                for (int i = 0; i < numThreads; i++) {
                    boolean isRead = i % 2 == 0;
                    futures[i] = executor.submit(() -> {
                        threadsReady.countDown();
                        startLatch.await();
                        if (isRead) {
                            tls.read(EMPTY);
                        } else {
                            tls.write(EMPTY);
                        }
                        return null;
                    });
                }
                threadsReady.await();
                startLatch.countDown();

                // Make sure no exception happened on any thread:
                for (Future<?> future : futures) {
                    future.get();
                }
            } finally {
                executor.shutdownNow();
            }

            write(tls, "CONNECT {}\r\n");

            ByteBuffer readBuffer = ByteBuffer.allocate(1024 * 1024);
            tls.read(readBuffer);
            readBuffer.flip();
            assertEquals(ByteBuffer.wrap("+OK\r\n".getBytes(UTF_8)), readBuffer);

            tls.close();
        }
    }

    @Test
    public void testConcurrentClose() throws Exception {
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = connectAndReadInfo(uri);
            int numThreads = 10;
            SSLEngine sslEngine = createSSLEngine(uri);
            TLSByteChannel tls = new TLSByteChannel(socket, sslEngine);

            tls.handshake();

            CountDownLatch threadsReady = new CountDownLatch(numThreads);
            CountDownLatch startLatch = new CountDownLatch(1);
            ExecutorService executor = Executors.newFixedThreadPool(numThreads);
            try {
                Future<?>[] futures = new Future<?>[numThreads];
                for (int i = 0; i < numThreads; i++) {
                    futures[i] = executor.submit(() -> {
                        threadsReady.countDown();
                        startLatch.await();
                        tls.close();
                        return null;
                    });
                }
                threadsReady.await();
                startLatch.countDown();

                // Make sure no exception happened on any thread:
                for (Future<?> future : futures) {
                    future.get();
                }
            } finally {
                executor.shutdownNow();
            }
            assertFalse(tls.isOpen());
        }
    }

    @Test
    public void testShortNetRead() throws Exception {
        // Scenario: TLS frames are received from the network one byte at a time,
        // so unwrap() repeatedly reports BUFFER_UNDERFLOW.
        try (NatsTestServer ts = new NatsTestServer("src/test/resources/tls.conf", false)) {
            URI uri = new URI(ts.getURI());
            NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2));
            AtomicBoolean readOneByteAtATime = new AtomicBoolean(true);
            NatsChannel wrapper = new AbstractNatsChannel(socket) {
                ByteBuffer readBuffer = ByteBuffer.allocate(1);

                @Override
                public int read(ByteBuffer dst) throws IOException {
                    if (!readOneByteAtATime.get()) {
                        return socket.read(dst);
                    }
                    readBuffer.clear();
                    int result = socket.read(readBuffer);
                    if (result <= 0) {
                        return result;
                    }
                    readBuffer.flip();
                    dst.put(readBuffer);
                    return result;
                }

                @Override
                public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
                    return socket.write(srcs, offset, length);
                }

                @Override
                public boolean isSecure() {
                    return false;
                }

                @Override
                public String transformConnectUrl(String connectUrl) {
                    return connectUrl;
                }
            };
            ByteBuffer info = ByteBuffer.allocate(1024 * 1024);
            socket.read(info);

            TLSByteChannel tls = new TLSByteChannel(wrapper, createSSLEngine(uri));

            // Perform handshake:
            tls.read(ByteBuffer.allocate(0));

            // Send connect & ping, but turn off one-byte at a time for reading PONG:
            readOneByteAtATime.set(false);
            write(tls, "CONNECT {}\r\nPING\r\n");

            info.clear();
            tls.read(info);
            info.flip();

            assertEquals(
                ByteBuffer.wrap(
                    "+OK\r\nPONG\r\n"
                        .getBytes(UTF_8)),
                info);

            tls.close();
        }
    }

    /**
     * Opens a plain-text connection to {@code uri} and consumes the server's INFO greeting
     * with a single read, leaving the socket ready for the TLS handshake.
     */
    private static NatsChannel connectAndReadInfo(URI uri) throws Exception {
        NatsChannel socket = SocketNatsChannel.factory().connect(uri, Duration.ofSeconds(2));
        socket.read(ByteBuffer.allocate(1024 * 1024));
        return socket;
    }

    /** Creates a client-mode engine from the test SSL context for {@code uri}'s host/port. */
    private static SSLEngine createSSLEngine(URI uri) throws Exception {
        SSLContext ctx = TestSSLUtils.createTestSSLContext();

        SSLEngine engine = ctx.createSSLEngine(uri.getHost(), uri.getPort());
        engine.setUseClientMode(true);

        return engine;
    }

    /** Writes {@code str} as UTF-8 with a single write() call. */
    private static void write(ByteChannel channel, String str) throws IOException {
        channel.write(ByteBuffer.wrap(str.getBytes(UTF_8)));
    }
}
@brimworks
Copy link
Author

Note that the tests depend on io.nats:jnats-server-runner:1.0.7 and a few other things not shown.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment