Skip to content

Instantly share code, notes, and snippets.

@mstefaniuk
Created February 11, 2016 14:23
Show Gist options
Save mstefaniuk/18c31ed661dd46478834 to your computer and use it in GitHub Desktop.
RabbitMQ STOMP over WebSocket load test
package stomp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.springframework.http.HttpStatus;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.stomp.*;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
import org.springframework.web.util.UriComponentsBuilder;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.fail;
/**
 * Load test for a STOMP-over-WebSocket broker reached at {@code localhost:15674/stomp}
 * (the RabbitMQ Web-STOMP port used below).
 *
 * <p>The test opens {@link #NUMBER_OF_USERS} consumer sessions, each subscribing to
 * {@code /topic/greeting}. It then opens one producer session that sends
 * {@link #BROADCAST_MESSAGE_COUNT} messages to that destination. Finally it waits until every
 * consumer has received all of them and disconnected.
 *
 * <p>Elapsed time for the connect/subscribe phase and the broadcast phase is logged at debug
 * level. The process exits with status 0 on success and 1 on any failure.
 */
public class LoadTest {

    private static final Log logger = LogFactory.getLog(LoadTest.class);

    private static final int NUMBER_OF_USERS = 20;

    private static final int BROADCAST_MESSAGE_COUNT = 2000;

    /**
     * Entry point: warms up the broker over HTTP, runs the load test, then waits for Enter.
     *
     * @param args ignored
     * @throws Exception if the warm-up request fails (everything after it is caught and turned
     *                   into a non-zero exit status)
     */
    public static void main(String[] args) throws Exception {
        String host = "localhost";
        int port = 15674;

        String homeUrl = "http://{host}:{port}/stomp";
        logger.debug("Sending warm-up HTTP request to " + homeUrl);
        HttpStatus status = new RestTemplate().getForEntity(homeUrl, Void.class, host, port).getStatusCode();
        Assert.state(status == HttpStatus.OK, "Warm-up request to " + homeUrl + " returned " + status);

        // Only a WebSocket transport is configured; no XHR/HTTP transport (and so no extra HTTP
        // client) is needed.
        List<Transport> transports = new ArrayList<>();
        transports.add(new WebSocketTransport(new StandardWebSocketClient()));
        SockJsClient sockJsClient = new SockJsClient(transports);

        ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler();
        taskScheduler.afterPropertiesSet();

        int exitCode = 0;
        try {
            runLoadTest(host, port, sockJsClient, taskScheduler);
            // System.in.read() on a line-buffered console returns only after Enter is pressed.
            System.out.println("\nPress Enter to exit...");
            System.in.read();
        }
        catch (Throwable t) {
            // Report the failure and make it visible to scripts/CI through the exit status.
            logger.error("Load test failed", t);
            exitCode = 1;
        }
        finally {
            taskScheduler.shutdown();
        }
        logger.debug("Exiting");
        System.exit(exitCode);
    }

    /**
     * Runs the connect/subscribe phase followed by the broadcast phase.
     *
     * @throws AssertionError       if a latch times out or any handler recorded a failure
     * @throws InterruptedException if interrupted while waiting on a latch
     */
    private static void runLoadTest(String host, int port, SockJsClient sockJsClient,
            ThreadPoolTaskScheduler taskScheduler) throws InterruptedException {

        final CountDownLatch connectLatch = new CountDownLatch(NUMBER_OF_USERS);
        final CountDownLatch subscribeLatch = new CountDownLatch(NUMBER_OF_USERS);
        final CountDownLatch messageLatch = new CountDownLatch(NUMBER_OF_USERS);
        final CountDownLatch disconnectLatch = new CountDownLatch(NUMBER_OF_USERS);
        final AtomicReference<Throwable> failure = new AtomicReference<>();

        String stompUrl = "ws://{host}:{port}/stomp";
        URI uri = UriComponentsBuilder.fromUriString(stompUrl).buildAndExpand(host, port).encode().toUri();

        WebSocketStompClient stompClient = new WebSocketStompClient(sockJsClient);
        stompClient.setMessageConverter(new StringMessageConverter());
        stompClient.setTaskScheduler(taskScheduler);
        stompClient.setDefaultHeartbeat(new long[] {0, 0});

        StompHeaders headers = new StompHeaders();
        headers.add(StompHeaders.HOST, "shotgun");
        headers.add(StompHeaders.LOGIN, "shotgun");
        headers.add(StompHeaders.PASSCODE, "shotgun");

        logger.debug("Connecting and subscribing " + NUMBER_OF_USERS + " users ");
        StopWatch stopWatch = new StopWatch("STOMP Broker WebSocket Load Tests");
        stopWatch.start();

        List<ConsumerStompSessionHandler> consumers = new ArrayList<>();
        for (int i = 0; i < NUMBER_OF_USERS; i++) {
            ConsumerStompSessionHandler consumer = new ConsumerStompSessionHandler(BROADCAST_MESSAGE_COUNT,
                    connectLatch, subscribeLatch, messageLatch, disconnectLatch, failure);
            consumers.add(consumer);
            stompClient.connect(uri, null, headers, consumer);
        }

        // connect() is asynchronous: handler failures can only be observed after waiting.
        awaitOrFail(connectLatch, 50000, "Not all users connected", failure);
        awaitOrFail(subscribeLatch, 50000, "Not all users subscribed", failure);

        stopWatch.stop();
        logger.debug("Finished: " + stopWatch.getLastTaskTimeMillis() + " millis");

        logger.debug("Broadcasting " + BROADCAST_MESSAGE_COUNT + " messages to " + NUMBER_OF_USERS + " users ");
        stopWatch.start();

        ProducerStompSessionHandler producer = new ProducerStompSessionHandler(BROADCAST_MESSAGE_COUNT, failure);
        stompClient.connect(uri, null, headers, producer);

        if (!messageLatch.await(60 * 1000, TimeUnit.MILLISECONDS)) {
            // First window expired: log the lagging consumers, then allow one more window below.
            for (ConsumerStompSessionHandler consumer : consumers) {
                if (consumer.messageCount.get() < consumer.expectedMessageCount) {
                    logger.debug(consumer);
                }
            }
        }
        awaitOrFail(messageLatch, 60 * 1000, "Not all handlers received every message", failure);

        // Read the volatile field once; it is set on a client thread in afterConnected.
        StompSession producerSession = producer.session;
        if (producerSession != null) {
            producerSession.disconnect();
        }
        awaitOrFail(disconnectLatch, 5000, "Not all disconnects completed", failure);

        stopWatch.stop();
        logger.debug("Finished: " + stopWatch.getLastTaskTimeMillis() + " millis");
    }

    /**
     * Waits for {@code latch}, then verifies that no handler has recorded a failure.
     *
     * @param latch         latch to wait on
     * @param timeoutMillis maximum wait in milliseconds
     * @param message       prefix of the timeout message; ", remaining: N" is appended
     * @param failure       failure slot shared with the session handlers
     * @throws AssertionError if the latch times out (the recorded failure, if any, is the cause)
     *                        or if a failure was recorded
     */
    private static void awaitOrFail(CountDownLatch latch, long timeoutMillis, String message,
            AtomicReference<Throwable> failure) throws InterruptedException {
        if (!latch.await(timeoutMillis, TimeUnit.MILLISECONDS)) {
            throw new AssertionError(message + ", remaining: " + latch.getCount(), failure.get());
        }
        Throwable cause = failure.get();
        if (cause != null) {
            throw new AssertionError("Test failed", cause);
        }
    }

    /**
     * Consumer session handler.
     *
     * <p>On connect it subscribes to {@code /topic/greeting} with auto-receipt. It counts received
     * frames and disconnects once {@code expectedMessageCount} have arrived. Errors are recorded
     * into the shared {@code failure} reference.
     */
    private static class ConsumerStompSessionHandler extends StompSessionHandlerAdapter {

        private final int expectedMessageCount;

        private final CountDownLatch connectLatch;

        private final CountDownLatch subscribeLatch;

        private final CountDownLatch messageLatch;

        private final CountDownLatch disconnectLatch;

        private final AtomicReference<Throwable> failure;

        private final AtomicInteger messageCount = new AtomicInteger(0);

        public ConsumerStompSessionHandler(int expectedMessageCount, CountDownLatch connectLatch,
                CountDownLatch subscribeLatch, CountDownLatch messageLatch, CountDownLatch disconnectLatch,
                AtomicReference<Throwable> failure) {
            this.expectedMessageCount = expectedMessageCount;
            this.connectLatch = connectLatch;
            this.subscribeLatch = subscribeLatch;
            this.messageLatch = messageLatch;
            this.disconnectLatch = disconnectLatch;
            this.failure = failure;
        }

        @Override
        public void afterConnected(final StompSession session, StompHeaders connectedHeaders) {
            this.connectLatch.countDown();
            session.setAutoReceipt(true);
            session.subscribe("/topic/greeting", new StompFrameHandler() {
                @Override
                public Type getPayloadType(StompHeaders headers) {
                    return String.class;
                }

                @Override
                public void handleFrame(StompHeaders headers, Object payload) {
                    if (messageCount.incrementAndGet() == expectedMessageCount) {
                        messageLatch.countDown();
                        disconnectLatch.countDown();
                        session.disconnect();
                    }
                }
            }).addReceiptTask(new Runnable() {
                @Override
                public void run() {
                    // The RECEIPT for SUBSCRIBE confirms the subscription is active.
                    subscribeLatch.countDown();
                }
            });
        }

        @Override
        public void handleTransportError(StompSession session, Throwable exception) {
            logger.error("Transport error", exception);
            this.failure.set(exception);
            // A lost connection counts as "disconnected" only if handleFrame has not already
            // counted this consumer down after receiving all messages.
            if (exception instanceof ConnectionLostException
                    && this.messageCount.get() < this.expectedMessageCount) {
                this.disconnectLatch.countDown();
            }
        }

        @Override
        public void handleException(StompSession s, StompCommand c, StompHeaders h, byte[] p, Throwable ex) {
            logger.error("Handling exception", ex);
            this.failure.set(ex);
        }

        @Override
        public void handleFrame(StompHeaders headers, Object payload) {
            Exception ex = new Exception(headers.toString());
            logger.error("STOMP ERROR frame", ex);
            this.failure.set(ex);
        }

        @Override
        public String toString() {
            return "ConsumerStompSessionHandler[messageCount=" + this.messageCount + "]";
        }
    }

    /**
     * Producer session handler.
     *
     * <p>On connect it sends {@code numberOfMessagesToBroadcast} "hello" messages to
     * {@code /topic/greeting}. Errors are recorded into the shared {@code failure} reference.
     */
    private static class ProducerStompSessionHandler extends StompSessionHandlerAdapter {

        private final int numberOfMessagesToBroadcast;

        private final AtomicReference<Throwable> failure;

        /** Set on a client thread in afterConnected and read by the main thread, hence volatile. */
        private volatile StompSession session;

        public ProducerStompSessionHandler(int numberOfMessagesToBroadcast, AtomicReference<Throwable> failure) {
            this.numberOfMessagesToBroadcast = numberOfMessagesToBroadcast;
            this.failure = failure;
        }

        @Override
        public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
            this.session = session;
            boolean debug = logger.isDebugEnabled();
            int i = 0;
            try {
                for ( ; i < this.numberOfMessagesToBroadcast; i++) {
                    session.send("/topic/greeting", "hello");
                    if (debug) {
                        logger.debug("Sending " + i);
                    }
                }
            }
            catch (Throwable t) {
                logger.error("Message sending failed at " + i, t);
                failure.set(t);
            }
        }

        @Override
        public void handleTransportError(StompSession session, Throwable exception) {
            logger.error("Transport error", exception);
            this.failure.set(exception);
        }

        @Override
        public void handleException(StompSession s, StompCommand c, StompHeaders h, byte[] p, Throwable ex) {
            logger.error("Handling exception", ex);
            this.failure.set(ex);
        }

        @Override
        public void handleFrame(StompHeaders headers, Object payload) {
            Exception ex = new Exception(headers.toString());
            logger.error("STOMP ERROR frame", ex);
            this.failure.set(ex);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment