Skip to content

Instantly share code, notes, and snippets.

@anuraaga
Last active April 21, 2016 17:21
Show Gist options
  • Save anuraaga/a1ee7f2489718a5462a90c31afbfe01b to your computer and use it in GitHub Desktop.
Save anuraaga/a1ee7f2489718a5462a90c31afbfe01b to your computer and use it in GitHub Desktop.
armeria client-side balancer example
/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.line.loadbalancer;
import static com.linecorp.armeria.client.metrics.MetricCollectingClient.newDropwizardDecorator;
import static java.util.Objects.requireNonNull;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.StreamSupport;

import javax.annotation.concurrent.NotThreadSafe;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.codahale.metrics.MetricRegistry;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.UnsignedLongs;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.linecorp.armeria.client.ClientBuilder;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.client.RemoteInvokerFactory;
import com.linecorp.armeria.client.http.SimpleHttpClient;
import com.linecorp.armeria.client.http.SimpleHttpRequest;
import com.linecorp.armeria.client.http.SimpleHttpRequestBuilder;
import com.linecorp.armeria.client.http.SimpleHttpResponse;

import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.concurrent.Future;
import jp.skypencil.guava.stream.GuavaCollectors;
import lombok.Value;
/**
* An armeria client that does round-robin load balancing across a set of servers.
*/
// Some write operations happen while old references are read. This is fine since
// it is intended for old state to be valid for several seconds, and the updates
// will be read on the next operation, well before that time.
@NotThreadSafe
public class LoadBalancingClient<T> {
private static final Logger logger = LoggerFactory.getLogger(LoadBalancingClient.class);
@Value
static class ServerConnection<T> {
String address;
T client;
SimpleHttpClient healthCheckClient;
}
private final RemoteInvokerFactory remoteInvokerFactory;
private final Class<T> interfaceClass;
private final String addressScheme;
private final String serviceName;
private final MetricRegistry metricRegistry;
private final String servicePath;
private final SimpleHttpRequest healthCheckRequest;
private final LoadBalancerStatistics loadBalancerStatistics;
volatile List<ServerConnection<T>> allServers = ImmutableList.of();
volatile List<ServerConnection<T>> upServers = ImmutableList.of();
private final AtomicLong counter = new AtomicLong();
private volatile Instant lastHealthCheckTime;
private volatile Instant lastServerListUpdateTime;
/**
* Constructs an armeria client that uses client-side load balancing.
*
* @param remoteInvokerFactory the client event loop to use.
* @param interfaceClass the interface class of the rpc client to construct.
* @param addressScheme the rpc scheme to use (e.g., tbinary+http, none+http).
* @param serviceName the service name to use in monitoring.
* @param metricRegistry the {@link MetricRegistry} to export stats to.
* @param servicePath the url path the service is hosted on (e.g., /thrift).
* @param healthCheckPath the url path health checking is hosted on (e.g., /monitor/l7check).
* @param serverAddresses hosts which this client should load balance over.
*/
public LoadBalancingClient(RemoteInvokerFactory remoteInvokerFactory,
Class<T> interfaceClass,
String addressScheme,
String serviceName,
MetricRegistry metricRegistry,
String servicePath,
String healthCheckPath,
Iterable<String> serverAddresses) {
requireNonNull(healthCheckPath, "healthCheckPath");
requireNonNull(serverAddresses, "serverAddresses");
this.remoteInvokerFactory = requireNonNull(remoteInvokerFactory, "remoteInvokerFactory");
this.interfaceClass = requireNonNull(interfaceClass, "interfaceClass");
this.addressScheme = requireNonNull(addressScheme, "addressScheme");
this.serviceName = requireNonNull(serviceName, "serviceName");
this.metricRegistry = requireNonNull(metricRegistry, "metricRegistry");
this.servicePath = requireNonNull(servicePath, "servicePath");
healthCheckRequest = SimpleHttpRequestBuilder.forGet(healthCheckPath).build();
updateServerList(serverAddresses);
try {
checkAndUpdateUpServers().get();
} catch (InterruptedException | ExecutionException e) {
throw new IllegalStateException("Could not succeed with initial health check for " + serviceName,
e);
}
remoteInvokerFactory.eventLoopGroup().scheduleWithFixedDelay(this::checkAndUpdateUpServers,
/* initialDelay */ 3,
/* delay */ 3,
TimeUnit.SECONDS);
loadBalancerStatistics = new LoadBalancerStatistics(this);
}
/**
* Returns the client interface backed by this load balancer.
*/
public T get() {
InvocationHandler handler = (proxy, method, args) -> {
T client = getNextConnection().getClient();
try {
return method.invoke(client, args);
} catch (InvocationTargetException e) {
throw MoreObjects.firstNonNull(e.getCause(), e);
}
};
@SuppressWarnings("unchecked")
T proxy = (T) Proxy.newProxyInstance(
interfaceClass.getClassLoader(), new Class[] { interfaceClass }, handler);
return proxy;
}
/**
* Update the servers this load balancing client talks to. It is useful to call this from a file
* watcher (e.g., zookeeper watcher) that allows updating the server list without restarting the
* binary.
*/
public void updateServerList(Iterable<String> serverAddresses) {
lastServerListUpdateTime = Instant.now();
Map<String, ServerConnection<T>> allServersByAddress = allServers.stream().collect(
GuavaCollectors.toImmutableMap(ServerConnection::getAddress, Function.identity()));
List<ServerConnection<T>> newAllServers =
StreamSupport.stream(serverAddresses.spliterator(), false)
.map(server -> allServersByAddress.getOrDefault(server, createConnection(server)))
.collect(GuavaCollectors.toImmutableList());
allServers = newAllServers;
}
private ServerConnection<T> getNextConnection() {
List<ServerConnection<T>> up = upServers;
if (up.isEmpty()) {
throw new IllegalStateException("No servers are up, cannot issue request.");
}
long next = counter.incrementAndGet();
int serverIndex = (int) UnsignedLongs.remainder(next, up.size());
return up.get(serverIndex);
}
private ListenableFuture<List<SimpleHttpResponse>> checkAndUpdateUpServers() {
lastHealthCheckTime = Instant.now();
List<ServerConnection<T>> checkedServers = allServers;
ListenableFuture<List<SimpleHttpResponse>> healthCheckResults = Futures.successfulAsList(
checkedServers.stream()
.map(this::checkServerHealth)
.collect(GuavaCollectors.toImmutableList()));
Futures.addCallback(healthCheckResults, new FutureCallback<List<SimpleHttpResponse>>() {
@Override
public void onSuccess(List<SimpleHttpResponse> result) {
ImmutableList.Builder<ServerConnection<T>> healthyServers = ImmutableList.builder();
for (int i = 0; i < result.size(); i++) {
if (result.get(i) != null && result.get(i).status().equals(HttpResponseStatus.OK)) {
healthyServers.add(checkedServers.get(i));
}
}
upServers = healthyServers.build();
}
@Override
public void onFailure(Throwable t) {
// Shouldn't happen since we use Futures.successfulAsList.
throw new IllegalStateException("Exception when checking health.", t);
}
});
return healthCheckResults;
}
private ListenableFuture<SimpleHttpResponse> checkServerHealth(ServerConnection<T> serverConnection) {
return toListenableFuture(serverConnection.getHealthCheckClient().execute(healthCheckRequest));
}
private ServerConnection<T> createConnection(String server) {
T client = new ClientBuilder(addressScheme + "://" + server + servicePath)
.remoteInvokerFactory(remoteInvokerFactory)
.decorator(newDropwizardDecorator(serviceName, metricRegistry))
.build(interfaceClass);
SimpleHttpClient healthCheckClient = client instanceof SimpleHttpClient ? (SimpleHttpClient) client :
Clients.newClient(remoteInvokerFactory,
"none+http://" + server,
SimpleHttpClient.class);
return new ServerConnection<>(server, client, healthCheckClient);
}
private static <T> ListenableFuture<T> toListenableFuture(Future<T> nettyFuture) {
SettableFuture<T> future = SettableFuture.create();
nettyFuture.addListener(f -> {
if (f.isSuccess()) {
future.set((T) f.getNow());
} else {
future.setException(f.cause());
}
});
return future;
}
@Override
public String toString() {
StringBuilder buf = new StringBuilder("LoadBalancingClient<" + serviceName + "> [\n")
.append(" ").append(upServers.size()).append(" servers up out of ").append(allServers.size())
.append(".\n")
.append(" Last health check: ").append(lastHealthCheckTime).append('\n')
.append(" Last server list update: ").append(lastServerListUpdateTime).append('\n')
.append("Server status: ").append(loadBalancerStatistics.getAttributeMap());
return buf.toString();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment