Last active
April 21, 2016 17:21
-
-
Save anuraaga/a1ee7f2489718a5462a90c31afbfe01b to your computer and use it in GitHub Desktop.
armeria client-side balancer example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Copyright 2016 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.linecorp.armeria.line.loadbalancer; | |
import static com.linecorp.armeria.client.metrics.MetricCollectingClient.newDropwizardDecorator; | |
import static java.util.Objects.requireNonNull; | |
import java.lang.reflect.InvocationHandler; | |
import java.lang.reflect.InvocationTargetException; | |
import java.lang.reflect.Proxy; | |
import java.time.Instant; | |
import java.util.List; | |
import java.util.Map; | |
import java.util.concurrent.ExecutionException; | |
import java.util.concurrent.TimeUnit; | |
import java.util.concurrent.atomic.AtomicLong; | |
import java.util.function.Function; | |
import java.util.stream.StreamSupport; | |
import javax.annotation.concurrent.NotThreadSafe; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import com.codahale.metrics.MetricRegistry; | |
import com.google.common.base.MoreObjects; | |
import com.google.common.collect.ImmutableList; | |
import com.google.common.primitives.UnsignedLongs; | |
import com.google.common.util.concurrent.FutureCallback; | |
import com.google.common.util.concurrent.Futures; | |
import com.google.common.util.concurrent.ListenableFuture; | |
import com.google.common.util.concurrent.SettableFuture; | |
import com.linecorp.armeria.client.ClientBuilder; | |
import com.linecorp.armeria.client.Clients; | |
import com.linecorp.armeria.client.RemoteInvokerFactory; | |
import com.linecorp.armeria.client.http.SimpleHttpClient; | |
import com.linecorp.armeria.client.http.SimpleHttpRequest; | |
import com.linecorp.armeria.client.http.SimpleHttpRequestBuilder; | |
import com.linecorp.armeria.client.http.SimpleHttpResponse; | |
import io.netty.handler.codec.http.HttpResponseStatus; | |
import io.netty.util.concurrent.Future; | |
import jp.skypencil.guava.stream.GuavaCollectors; | |
import lombok.Value; | |
/** | |
* An armeria client that does round-robin load balancing across a set of servers. | |
*/ | |
// Some write operations happen while old references are read. This is fine since | |
// it is intended for old state to be valid for several seconds, and the updates | |
// will be read on the next operation, well before that time. | |
@NotThreadSafe | |
public class LoadBalancingClient<T> { | |
private static final Logger logger = LoggerFactory.getLogger(LoadBalancingClient.class); | |
@Value | |
static class ServerConnection<T> { | |
String address; | |
T client; | |
SimpleHttpClient healthCheckClient; | |
} | |
private final RemoteInvokerFactory remoteInvokerFactory; | |
private final Class<T> interfaceClass; | |
private final String addressScheme; | |
private final String serviceName; | |
private final MetricRegistry metricRegistry; | |
private final String servicePath; | |
private final SimpleHttpRequest healthCheckRequest; | |
private final LoadBalancerStatistics loadBalancerStatistics; | |
volatile List<ServerConnection<T>> allServers = ImmutableList.of(); | |
volatile List<ServerConnection<T>> upServers = ImmutableList.of(); | |
private final AtomicLong counter = new AtomicLong(); | |
private volatile Instant lastHealthCheckTime; | |
private volatile Instant lastServerListUpdateTime; | |
/** | |
* Constructs an armeria client that uses client-side load balancing. | |
* | |
* @param remoteInvokerFactory the client event loop to use. | |
* @param interfaceClass the interface class of the rpc client to construct. | |
* @param addressScheme the rpc scheme to use (e.g., tbinary+http, none+http). | |
* @param serviceName the service name to use in monitoring. | |
* @param metricRegistry the {@link MetricRegistry} to export stats to. | |
* @param servicePath the url path the service is hosted on (e.g., /thrift). | |
* @param healthCheckPath the url path health checking is hosted on (e.g., /monitor/l7check). | |
* @param serverAddresses hosts which this client should load balance over. | |
*/ | |
public LoadBalancingClient(RemoteInvokerFactory remoteInvokerFactory, | |
Class<T> interfaceClass, | |
String addressScheme, | |
String serviceName, | |
MetricRegistry metricRegistry, | |
String servicePath, | |
String healthCheckPath, | |
Iterable<String> serverAddresses) { | |
requireNonNull(healthCheckPath, "healthCheckPath"); | |
requireNonNull(serverAddresses, "serverAddresses"); | |
this.remoteInvokerFactory = requireNonNull(remoteInvokerFactory, "remoteInvokerFactory"); | |
this.interfaceClass = requireNonNull(interfaceClass, "interfaceClass"); | |
this.addressScheme = requireNonNull(addressScheme, "addressScheme"); | |
this.serviceName = requireNonNull(serviceName, "serviceName"); | |
this.metricRegistry = requireNonNull(metricRegistry, "metricRegistry"); | |
this.servicePath = requireNonNull(servicePath, "servicePath"); | |
healthCheckRequest = SimpleHttpRequestBuilder.forGet(healthCheckPath).build(); | |
updateServerList(serverAddresses); | |
try { | |
checkAndUpdateUpServers().get(); | |
} catch (InterruptedException | ExecutionException e) { | |
throw new IllegalStateException("Could not succeed with initial health check for " + serviceName, | |
e); | |
} | |
remoteInvokerFactory.eventLoopGroup().scheduleWithFixedDelay(this::checkAndUpdateUpServers, | |
/* initialDelay */ 3, | |
/* delay */ 3, | |
TimeUnit.SECONDS); | |
loadBalancerStatistics = new LoadBalancerStatistics(this); | |
} | |
/** | |
* Returns the client interface backed by this load balancer. | |
*/ | |
public T get() { | |
InvocationHandler handler = (proxy, method, args) -> { | |
T client = getNextConnection().getClient(); | |
try { | |
return method.invoke(client, args); | |
} catch (InvocationTargetException e) { | |
throw MoreObjects.firstNonNull(e.getCause(), e); | |
} | |
}; | |
@SuppressWarnings("unchecked") | |
T proxy = (T) Proxy.newProxyInstance( | |
interfaceClass.getClassLoader(), new Class[] { interfaceClass }, handler); | |
return proxy; | |
} | |
/** | |
* Update the servers this load balancing client talks to. It is useful to call this from a file | |
* watcher (e.g., zookeeper watcher) that allows updating the server list without restarting the | |
* binary. | |
*/ | |
public void updateServerList(Iterable<String> serverAddresses) { | |
lastServerListUpdateTime = Instant.now(); | |
Map<String, ServerConnection<T>> allServersByAddress = allServers.stream().collect( | |
GuavaCollectors.toImmutableMap(ServerConnection::getAddress, Function.identity())); | |
List<ServerConnection<T>> newAllServers = | |
StreamSupport.stream(serverAddresses.spliterator(), false) | |
.map(server -> allServersByAddress.getOrDefault(server, createConnection(server))) | |
.collect(GuavaCollectors.toImmutableList()); | |
allServers = newAllServers; | |
} | |
private ServerConnection<T> getNextConnection() { | |
List<ServerConnection<T>> up = upServers; | |
if (up.isEmpty()) { | |
throw new IllegalStateException("No servers are up, cannot issue request."); | |
} | |
long next = counter.incrementAndGet(); | |
int serverIndex = (int) UnsignedLongs.remainder(next, up.size()); | |
return up.get(serverIndex); | |
} | |
private ListenableFuture<List<SimpleHttpResponse>> checkAndUpdateUpServers() { | |
lastHealthCheckTime = Instant.now(); | |
List<ServerConnection<T>> checkedServers = allServers; | |
ListenableFuture<List<SimpleHttpResponse>> healthCheckResults = Futures.successfulAsList( | |
checkedServers.stream() | |
.map(this::checkServerHealth) | |
.collect(GuavaCollectors.toImmutableList())); | |
Futures.addCallback(healthCheckResults, new FutureCallback<List<SimpleHttpResponse>>() { | |
@Override | |
public void onSuccess(List<SimpleHttpResponse> result) { | |
ImmutableList.Builder<ServerConnection<T>> healthyServers = ImmutableList.builder(); | |
for (int i = 0; i < result.size(); i++) { | |
if (result.get(i) != null && result.get(i).status().equals(HttpResponseStatus.OK)) { | |
healthyServers.add(checkedServers.get(i)); | |
} | |
} | |
upServers = healthyServers.build(); | |
} | |
@Override | |
public void onFailure(Throwable t) { | |
// Shouldn't happen since we use Futures.successfulAsList. | |
throw new IllegalStateException("Exception when checking health.", t); | |
} | |
}); | |
return healthCheckResults; | |
} | |
private ListenableFuture<SimpleHttpResponse> checkServerHealth(ServerConnection<T> serverConnection) { | |
return toListenableFuture(serverConnection.getHealthCheckClient().execute(healthCheckRequest)); | |
} | |
private ServerConnection<T> createConnection(String server) { | |
T client = new ClientBuilder(addressScheme + "://" + server + servicePath) | |
.remoteInvokerFactory(remoteInvokerFactory) | |
.decorator(newDropwizardDecorator(serviceName, metricRegistry)) | |
.build(interfaceClass); | |
SimpleHttpClient healthCheckClient = client instanceof SimpleHttpClient ? (SimpleHttpClient) client : | |
Clients.newClient(remoteInvokerFactory, | |
"none+http://" + server, | |
SimpleHttpClient.class); | |
return new ServerConnection<>(server, client, healthCheckClient); | |
} | |
private static <T> ListenableFuture<T> toListenableFuture(Future<T> nettyFuture) { | |
SettableFuture<T> future = SettableFuture.create(); | |
nettyFuture.addListener(f -> { | |
if (f.isSuccess()) { | |
future.set((T) f.getNow()); | |
} else { | |
future.setException(f.cause()); | |
} | |
}); | |
return future; | |
} | |
@Override | |
public String toString() { | |
StringBuilder buf = new StringBuilder("LoadBalancingClient<" + serviceName + "> [\n") | |
.append(" ").append(upServers.size()).append(" servers up out of ").append(allServers.size()) | |
.append(".\n") | |
.append(" Last health check: ").append(lastHealthCheckTime).append('\n') | |
.append(" Last server list update: ").append(lastServerListUpdateTime).append('\n') | |
.append("Server status: ").append(loadBalancerStatistics.getAttributeMap()); | |
return buf.toString(); | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment