Skip to content

Instantly share code, notes, and snippets.

@daschl
Last active August 29, 2015 14:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daschl/70a82ee72fd3070af917 to your computer and use it in GitHub Desktop.
Save daschl/70a82ee72fd3070af917 to your computer and use it in GitHub Desktop.
final ChannelHandler endpointHandler = new GenericEndpointHandler(this, responseBuffer);
bootstrap = new BootstrapAdapter(new Bootstrap()
.remoteAddress(hostname, port())
.group(environment.ioPool())
.channel(NioSocketChannel.class)
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
.option(ChannelOption.TCP_NODELAY, false)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel channel) throws Exception {
ChannelPipeline pipeline = channel.pipeline();
if (LOGGER.isTraceEnabled()) {
pipeline.addLast(LOGGING_HANDLER_INSTANCE);
}
customEndpointHandlers(pipeline);
pipeline.addLast(endpointHandler);
}
}));
/**
* Copyright (C) 2014 Couchbase, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALING
* IN THE SOFTWARE.
*/
package com.couchbase.client.core.endpoint.binary;
import com.couchbase.client.core.endpoint.AbstractEndpoint;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponseStatus;
import io.netty.handler.codec.memcache.binary.DefaultBinaryMemcacheRequest;
import io.netty.handler.codec.memcache.binary.DefaultFullBinaryMemcacheRequest;
import io.netty.handler.codec.memcache.binary.FullBinaryMemcacheRequest;
import io.netty.handler.codec.memcache.binary.FullBinaryMemcacheResponse;
import io.netty.util.CharsetUtil;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import java.io.IOException;
/**
* A SASL Client which communicates through the memcache binary protocol.
*
* @author Michael Nitschinger
* @since 1.0
*/
public class BinarySaslClient
extends SimpleChannelInboundHandler<FullBinaryMemcacheResponse>
implements CallbackHandler {
/**
* The memcache opcode for the SASL mechs list.
*/
private static final byte SASL_LIST_MECHS_OPCODE = 0x20;
/**
* The memcache opcode for the SASL initial auth.
*/
private static final byte SASL_AUTH_OPCODE = 0x21;
/**
* The memcache opcode for the SASL consecutive steps.
*/
private static final byte SASL_STEP_OPCODE = 0x22;
/**
* The server response indicating SASL auth success.
*/
private static final byte SASL_AUTH_SUCCESS = BinaryMemcacheResponseStatus.SUCCESS;
/**
* The server response indicating SASL auth failure.
*/
private static final byte SASL_AUTH_FAILURE = BinaryMemcacheResponseStatus.AUTH_ERROR;
/**
* The username to auth against.
*/
private final String username;
/**
* The password of the user.
*/
private final String password;
private final AbstractEndpoint endpoint;
/**
* The handler context.
*/
private ChannelHandlerContext ctx;
/**
* The JVM {@link SaslClient} that handles the actual authentication process.
*/
private SaslClient saslClient;
/**
* Contains the selected SASL auth mechanism once decided.
*/
private String selectedMechanism;
/**
* Creates a new {@link BinarySaslClient}.
*
* @param username the name of the user/bucket.
* @param password the password associated with the user/bucket.
*/
public BinarySaslClient(String username, String password, AbstractEndpoint endpoint) {
this.username = username;
this.password = password == null ? "" : password;
this.endpoint = endpoint;
}
/**
* Once the channel is marked as active, the SASL negotiation is started.
*
* @param ctx the handler context.
* @throws Exception if something goes wrong during negotiation.
*/
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
negotiate();
}
/**
* Helper method to kick off the negotiation process.
*
* The first request against the server asks for a list of supported mechanisms.
*/
private void negotiate() {
ctx.writeAndFlush(new DefaultBinaryMemcacheRequest().setOpcode(SASL_LIST_MECHS_OPCODE));
}
/**
* Callback handler needed for the {@link SaslClient} which supplies username and password.
*
* @param callbacks the possible callbacks.
* @throws IOException
* @throws UnsupportedCallbackException
*/
@Override
public void handle(final Callback[] callbacks) throws IOException, UnsupportedCallbackException {
for (Callback callback : callbacks) {
if (callback instanceof NameCallback) {
((NameCallback) callback).setName(username);
} else if (callback instanceof PasswordCallback) {
((PasswordCallback) callback).setPassword(password.toCharArray());
} else {
throw new IllegalStateException("SASLClient requested unsupported callback: " + callback);
}
}
}
/**
* Dispatches incoming SASL responses to the appropriate handler methods.
*
* @param ctx the handler context.
* @param msg the incoming message to investigate.
* @throws Exception
*/
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullBinaryMemcacheResponse msg) throws Exception {
if (msg.getOpcode() == SASL_LIST_MECHS_OPCODE) {
handleListMechsResponse(ctx, msg);
} else if (msg.getOpcode() == SASL_AUTH_OPCODE) {
handleAuthResponse(ctx, msg);
} else if (msg.getOpcode() == SASL_STEP_OPCODE) {
checkIsAuthed(msg);
}
}
/**
* Handles an incoming SASL list mechanisms response and dispatches the next SASL AUTH step.
*
* @param ctx the handler context.
* @param msg the incoming message to investigate.
* @throws Exception
*/
private void handleListMechsResponse(ChannelHandlerContext ctx, FullBinaryMemcacheResponse msg) throws Exception {
String remote = ctx.channel().remoteAddress().toString();
String[] supportedMechanisms = msg.content().toString(CharsetUtil.UTF_8).split(" ");
if (supportedMechanisms.length == 0) {
throw new IllegalStateException("Received empty SASL mechanisms list from server: " + remote);
}
saslClient = Sasl.createSaslClient(supportedMechanisms, null, "couchbase", remote, null, this);
selectedMechanism = saslClient.getMechanismName();
int mechanismLength = selectedMechanism.length();
byte[] bytePayload = saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[]{}) : null;
ByteBuf payload = bytePayload != null ? ctx.alloc().buffer().writeBytes(bytePayload) : Unpooled.EMPTY_BUFFER;
FullBinaryMemcacheRequest initialRequest = new DefaultFullBinaryMemcacheRequest(
selectedMechanism,
Unpooled.EMPTY_BUFFER,
payload
);
initialRequest
.setOpcode(SASL_AUTH_OPCODE)
.setKeyLength((short) mechanismLength)
.setTotalBodyLength(mechanismLength + payload.readableBytes());
ctx.writeAndFlush(initialRequest);
}
/**
* Handles an incoming SASL AUTH response and - if needed - dispatches the SASL STEPs.
*
* @param ctx the handler context.
* @param msg the incoming message to investigate.
* @throws Exception
*/
private void handleAuthResponse(ChannelHandlerContext ctx, FullBinaryMemcacheResponse msg) throws Exception {
if (saslClient.isComplete()) {
checkIsAuthed(msg);
return;
}
byte[] response = new byte[msg.content().readableBytes()];
msg.content().readBytes(response);
byte[] evaluatedBytes = saslClient.evaluateChallenge(response);
if (evaluatedBytes != null) {
String[] evaluated = new String(evaluatedBytes).split(" ");
ByteBuf content = Unpooled.copiedBuffer(username + "\0" + evaluated[1], CharsetUtil.UTF_8);
FullBinaryMemcacheRequest stepRequest = new DefaultFullBinaryMemcacheRequest(
selectedMechanism,
Unpooled.EMPTY_BUFFER,
content
);
stepRequest
.setOpcode(SASL_STEP_OPCODE)
.setKeyLength((short) selectedMechanism.length())
.setTotalBodyLength(content.readableBytes() + selectedMechanism.length());
ctx.writeAndFlush(stepRequest);
} else {
throw new IllegalStateException("SASL Challenge evaluation returned null.");
}
}
/**
* Once authentication is completed, check the response and react appropriately to the upper layers.
*
* @param msg the incoming message to investigate.
*/
private void checkIsAuthed(final FullBinaryMemcacheResponse msg) {
switch (msg.getStatus()) {
case SASL_AUTH_SUCCESS:
endpoint.notifyChannelAuthSuccess();
ctx.pipeline().remove(this);
break;
case SASL_AUTH_FAILURE:
endpoint.notifyChannelAuthFailure();
break;
default:
throw new IllegalStateException("Unhandled SASL auth status: " + msg.getStatus());
}
}
}
package com.couchbase.client.core.endpoint;
import com.couchbase.client.core.cluster.ResponseEvent;
import com.couchbase.client.core.env.Environment;
import com.couchbase.client.core.message.CouchbaseRequest;
import com.couchbase.client.core.message.internal.SignalFlush;
import com.couchbase.client.core.state.AbstractStateMachine;
import com.couchbase.client.core.state.LifecycleState;
import com.couchbase.client.core.state.NotConnectedException;
import com.lmax.disruptor.RingBuffer;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.subjects.AsyncSubject;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
public abstract class AbstractEndpoint extends AbstractStateMachine<LifecycleState> implements Endpoint {
/**
* The logger to use for all endpoints.
*/
private static final Logger LOGGER = LoggerFactory.getLogger(Endpoint.class);
/**
* A shared logging handler for all endpoints.
*/
private static final ChannelHandler LOGGING_HANDLER_INSTANCE = new DebugLoggingHandler(LogLevel.TRACE);
/**
* Precreated not connected exception for performance reasons.
*/
private static final NotConnectedException NOT_CONNECTED_EXCEPTION = new NotConnectedException();
/**
* The netty bootstrap adapter.
*/
private final BootstrapAdapter bootstrap;
/**
* The reconnect delay that increases with backoff.
*/
private final AtomicLong reconnectDelay = new AtomicLong(0);
/**
* The underlying IO (netty) channel.
*/
private volatile Channel channel;
private volatile boolean hasWritten;
/**
* Preset the stack trace for the static exceptions.
*/
static {
NOT_CONNECTED_EXCEPTION.setStackTrace(new StackTraceElement[0]);
}
protected AbstractEndpoint(BootstrapAdapter adapter) {
super(LifecycleState.DISCONNECTED);
bootstrap = adapter;
}
protected AbstractEndpoint(String hostname, Environment environment, final RingBuffer<ResponseEvent> responseBuffer) {
super(LifecycleState.DISCONNECTED);
final ChannelHandler endpointHandler = new GenericEndpointHandler(this, responseBuffer);
bootstrap = new BootstrapAdapter(new Bootstrap()
.remoteAddress(hostname, port())
.group(environment.ioPool())
.channel(NioSocketChannel.class)
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
.option(ChannelOption.TCP_NODELAY, false)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel channel) throws Exception {
ChannelPipeline pipeline = channel.pipeline();
if (LOGGER.isTraceEnabled()) {
pipeline.addLast(LOGGING_HANDLER_INSTANCE);
}
customEndpointHandlers(pipeline);
pipeline.addLast(endpointHandler);
}
}));
}
/**
* Returns the port of the endpoint.
*
* @return the port of the endpoint.
*/
protected abstract int port();
/**
* Add custom endpoint handlers to the {@link ChannelPipeline}.
*
* @param pipeline the pipeline where to add handlers.
*/
protected abstract void customEndpointHandlers(ChannelPipeline pipeline);
@Override
public Observable<LifecycleState> connect() {
if (state() != LifecycleState.DISCONNECTED) {
return Observable.from(state());
}
final AsyncSubject<LifecycleState> observable = AsyncSubject.create();
transitionState(LifecycleState.CONNECTING);
bootstrap.connect().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(final ChannelFuture future) throws Exception {
if (state() == LifecycleState.DISCONNECTING || state() == LifecycleState.DISCONNECTED) {
LOGGER.debug("Endpoint connect completed, but got instructed to disconnect in the meantime.");
transitionState(LifecycleState.DISCONNECTED);
channel = null;
} else {
if (future.isSuccess()) {
channel = future.channel();
LOGGER.debug("Successfully connected to Endpoint " + channel.remoteAddress());
} else {
long delay = reconnectDelay();
LOGGER.warn("Could not connect to endpoint, retrying with delay " + delay + "ms: "
+ future.channel().remoteAddress(), future.cause());
transitionState(LifecycleState.CONNECTING);
future.channel().eventLoop().schedule(new Runnable() {
@Override
public void run() {
connect();
}
}, delay, TimeUnit.MILLISECONDS);
}
}
observable.onNext(state());
observable.onCompleted();
}
});
return observable;
}
@Override
public Observable<LifecycleState> disconnect() {
if (state() == LifecycleState.DISCONNECTED || state() == LifecycleState.DISCONNECTING) {
return Observable.from(state());
}
if (state() == LifecycleState.CONNECTING) {
transitionState(LifecycleState.DISCONNECTED);
return Observable.from(state());
}
transitionState(LifecycleState.DISCONNECTING);
final AsyncSubject<LifecycleState> observable = AsyncSubject.create();
channel.disconnect().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(final ChannelFuture future) throws Exception {
if (future.isSuccess()) {
LOGGER.debug("Successfully disconnected from Endpoint " + channel.remoteAddress());
} else {
LOGGER.warn("Received an error during disconnect.", future.cause());
}
transitionState(LifecycleState.DISCONNECTED);
observable.onNext(state());
observable.onCompleted();
channel = null;
}
});
return observable;
}
@Override
public void send(final CouchbaseRequest request) {
if (state() == LifecycleState.CONNECTED) {
if (request instanceof SignalFlush) {
if (hasWritten) {
channel.flush();
hasWritten = false;
}
} else {
channel.write(request).addListener(new GenericFutureListener<Future<Void>>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
if (!future.isSuccess()) {
request.observable().onError(future.cause());
}
}
});
hasWritten = true;
}
} else {
if (request instanceof SignalFlush) {
return;
}
request.observable().onError(NOT_CONNECTED_EXCEPTION);
}
}
/**
* Helper method that is called inside from the event loop to notify the upper endpoint of a disconnect.
*
* Subsequent reconnect attempts are triggered from here.
*/
public void notifyChannelInactive() {
transitionState(LifecycleState.DISCONNECTED);
connect();
}
public void notifyChannelAuthSuccess() {
transitionState(LifecycleState.CONNECTED);
LOGGER.debug("Successfully authenticated to Endpoint " + channel.remoteAddress());
}
public void notifyChannelAuthFailure() {
transitionState(LifecycleState.DISCONNECTED);
connect();
}
/**
* Returns the reconnect retry delay in Miliseconds.
*
* Currently, it uses linear backoff.
*
* @return the retry delay.
*/
private long reconnectDelay() {
return reconnectDelay.getAndIncrement();
}
}
@daschl
Copy link
Author

daschl commented May 9, 2014

io.netty.channel.ChannelPipelineException: com.couchbase.client.core.endpoint.GenericEndpointHandler is not a @sharable handler, so can't be added or removed multiple times.
at io.netty.channel.DefaultChannelPipeline.checkMultiplicity(DefaultChannelPipeline.java:560) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelPipeline.addLast0(DefaultChannelPipeline.java:153) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelPipeline.addLast(DefaultChannelPipeline.java:147) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelPipeline.addLast(DefaultChannelPipeline.java:329) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelPipeline.addLast(DefaultChannelPipeline.java:300) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at com.couchbase.client.core.endpoint.AbstractEndpoint$1.initChannel(AbstractEndpoint.java:96) ~[couchbase-jvm-core/:na]
at io.netty.channel.ChannelInitializer.channelRegistered(ChannelInitializer.java:69) ~[netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.ChannelHandlerInvokerUtil.invokeChannelRegisteredNow(ChannelHandlerInvokerUtil.java:32) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelHandlerInvoker.invokeChannelRegistered(DefaultChannelHandlerInvoker.java:49) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelHandlerContext.fireChannelRegistered(DefaultChannelHandlerContext.java:143) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.DefaultChannelPipeline.fireChannelRegistered(DefaultChannelPipeline.java:829) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.AbstractChannel$AbstractUnsafe.register0(AbstractChannel.java:458) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.AbstractChannel$AbstractUnsafe.access$100(AbstractChannel.java:384) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.AbstractChannel$AbstractUnsafe$1.run(AbstractChannel.java:434) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:333) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:353) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at io.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:824) [netty-all-4.1.0.Alpha1-SNAPSHOT.jar:4.1.0.Alpha1-SNAPSHOT]
at java.lang.Thread.run(Thread.java:744) [na:1.7.0_45]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment