Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save arleighdickerson/9997645111ef8a5a4b188f19b7e9b334 to your computer and use it in GitHub Desktop.
Save arleighdickerson/9997645111ef8a5a4b188f19b7e9b334 to your computer and use it in GitHub Desktop.
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * {@link WebSocketHandlerDecoratorFactory} that substitutes the {@link WebSocketSession} seen by the
 * decorated handler with a session produced by a user-supplied {@link WebSocketSessionDecoratorFactory}.
 *
 * <p>For each connection the decorated session is created once, in {@code afterConnectionEstablished},
 * and registered under the innermost (fully unwrapped) session. Every later callback for that
 * connection ({@code handleMessage}, {@code handleTransportError}, {@code afterConnectionClosed}) is
 * routed to the delegate with that same decorated instance. The registration is removed in
 * {@code afterConnectionClosed}.
 *
 * <p>Registrations are kept in a {@link ConcurrentHashMap} shared by every handler this factory decorates.
 */
@Slf4j
@RequiredArgsConstructor
public class SessionWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory {

    /** Live connections, keyed by the innermost session as returned by {@link WebSocketSessionDecorator#unwrap}. */
    private final Map<WebSocketSession, HandlerWrapper> handlers = new ConcurrentHashMap<>();

    /**
     * Strategy that wraps (or otherwise replaces) the session handed to the delegate handler.
     *
     * <p>The returned session must be non-null and must unwrap, via
     * {@link WebSocketSessionDecorator#unwrap(WebSocketSession)}, to the same underlying session as
     * its argument; otherwise connection establishment fails with {@link IllegalStateException}.
     */
    @FunctionalInterface
    public interface WebSocketSessionDecoratorFactory {
        WebSocketSession decorateWebSocketSession(WebSocketSession webSocketSession);
    }

    /** Produces the per-connection decorated session; required (Lombok {@code @NonNull}). */
    @NonNull
    private final WebSocketSessionDecoratorFactory decorator;

    @Override
    public WebSocketHandler decorate(WebSocketHandler handler) {
        return new WebSocketHandlerDecoratorImpl(handler);
    }

    /**
     * Returns the wrapper registered for {@code session}.
     *
     * @param session the session passed to a handler callback (decorated or not)
     * @return the wrapper registered in {@code afterConnectionEstablished}
     * @throws IllegalStateException if nothing is registered for the session (establishment did not
     *     complete registration, or the connection was already closed)
     */
    private HandlerWrapper getWrapper(WebSocketSession session) {
        WebSocketSession key = WebSocketSessionDecorator.unwrap(session);
        HandlerWrapper handlerWrapper = handlers.get(key);
        Assert.state(
                handlerWrapper != null,
                "handlerWrapper must not be null"
        );
        // Defensive re-check; the invariant is validated once when the wrapper is registered.
        Assert.state(
                key == WebSocketSessionDecorator.unwrap(handlerWrapper.session),
                "key must be unwrapped value"
        );
        return handlerWrapper;
    }

    /** Handler decorator that routes every callback through the connection's {@link HandlerWrapper}. */
    private class WebSocketHandlerDecoratorImpl extends WebSocketHandlerDecorator {

        public WebSocketHandlerDecoratorImpl(WebSocketHandler delegate) {
            super(delegate);
        }

        /**
         * Decorates the session, validates the result, registers it, then notifies the delegate.
         *
         * @throws IllegalStateException if the factory returns {@code null} or a session that does not
         *     unwrap to the same underlying session as {@code session}
         */
        @Override
        public void afterConnectionEstablished(WebSocketSession session) throws Exception {
            WebSocketSession key = WebSocketSessionDecorator.unwrap(session);
            WebSocketSession decorated = decorator.decorateWebSocketSession(session);
            // Validate before the delegate ever sees the session: previously a null or unrelated
            // session reached the delegate here and was only caught by getWrapper on the first message.
            Assert.state(decorated != null, "decorated session must not be null");
            Assert.state(
                    key == WebSocketSessionDecorator.unwrap(decorated),
                    "decorated session must unwrap to the original session"
            );
            HandlerWrapper wrapper = new HandlerWrapper(getDelegate(), decorated);
            // Registered before delegating, so a later afterConnectionClosed still finds the wrapper
            // (and notifies the delegate) even if the delegate's afterConnectionEstablished throws.
            handlers.put(key, wrapper);
            wrapper.afterConnectionEstablished();
        }

        @Override
        public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
            getWrapper(session).handleMessage(message);
        }

        @Override
        public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
            getWrapper(session).handleTransportError(exception);
        }

        /**
         * Unregisters the connection and forwards the close to the delegate with the decorated session.
         * If nothing was registered (e.g. establishment failed before registration), the delegate is
         * not notified and an error is logged.
         */
        @Override
        public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
            HandlerWrapper handler = handlers.remove(WebSocketSessionDecorator.unwrap(session));
            if (handler != null) {
                handler.afterConnectionClosed(closeStatus);
            } else {
                log.error("could not find handler for session {}", session);
            }
        }
    }

    /** Binds a delegate handler to the decorated session it must always be invoked with. */
    @RequiredArgsConstructor
    private static class HandlerWrapper {
        /** The handler being decorated (the {@code getDelegate()} of the outer decorator). */
        final WebSocketHandler handler;
        /** The session produced by {@link WebSocketSessionDecoratorFactory} for this connection. */
        final WebSocketSession session;

        void afterConnectionEstablished() throws Exception {
            handler.afterConnectionEstablished(session);
        }

        void handleMessage(WebSocketMessage<?> message) throws Exception {
            handler.handleMessage(session, message);
        }

        void handleTransportError(Throwable exception) throws Exception {
            handler.handleTransportError(session, exception);
        }

        void afterConnectionClosed(CloseStatus closeStatus) throws Exception {
            handler.afterConnectionClosed(session, closeStatus);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment