Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
import static org.apache.kafka.clients.consumer.ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
import org.apache.kafka.clients.consumer.RangeAssignor;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.ByteBufferInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.instana.backend.common.config.KafkaConfig;
public class AzAwarePartitionAssignor implements ConsumerPartitionAssignor, Configurable {
private static final Logger LOGGER = LoggerFactory.getLogger(AzAwarePartitionAssignor.class);
public static final String CONFIG_KEY_AZ = "instana.az";
private static final ObjectMapper MAPPER = new ObjectMapper();
private String az;
public static KafkaConfig createAzAwareKafkaConfig(KafkaConfig kafkaConfig) {
if (kafkaConfig == null) {
return kafkaConfig;
}
final Map<String, String> kafkaConsumerConfig = new HashMap<>(kafkaConfig.getConsumerConfig());
final String instanaAZ = System.getenv().get("INSTANA_AZ");
if (instanaAZ != null) {
kafkaConsumerConfig.put(CONFIG_KEY_AZ, instanaAZ);
}
kafkaConsumerConfig.put(PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
AzAwarePartitionAssignor.class.getName() + "," + RangeAssignor.class.getName());
return new KafkaConfig(kafkaConfig.getTopics(), kafkaConsumerConfig, kafkaConfig.getProducerConfig(),
kafkaConfig.getConsecutiveFailedSendOperationsHealthCheckThreshold());
}
@Override
public ByteBuffer subscriptionUserData(Set<String> topics) {
if (az == null) {
return null;
}
try {
return ByteBuffer.wrap(MAPPER.writeValueAsBytes(Collections.singletonMap(CONFIG_KEY_AZ, az)));
} catch (JsonProcessingException e) {
return null;
}
}
@Override
public GroupAssignment assign(Cluster metadata, GroupSubscription groupSubscription) {
LOGGER.info("Starting az aware partition assignment...");
Map<String, Subscription> subscriptions = groupSubscription.groupSubscription();
Map<String, Set<MemberInfo>> subscribersPerTopic = new HashMap<>();
for (Map.Entry<String, Subscription> subscriptionEntry : subscriptions.entrySet()) {
final MemberInfo memberInfo = new MemberInfo(subscriptionEntry);
for (String topic : subscriptionEntry.getValue().topics()) {
subscribersPerTopic.computeIfAbsent(topic, k -> new HashSet<>()).add(memberInfo);
}
}
LOGGER.info("Calculating az aware assignment for partitions {}", subscribersPerTopic.keySet());
Map<String, List<TopicPartition>> assignments = new HashMap<>();
for (Map.Entry<String, Set<MemberInfo>> topicEntry : subscribersPerTopic.entrySet()) {
final List<PartitionInfo> partitionInfos = metadata.partitionsForTopic(topicEntry.getKey());
if (partitionInfos != null && partitionInfos.size() > 0) {
final Map<String, List<TopicPartition>> topicAssignments = assignForTopic(topicEntry.getKey(),
topicEntry.getValue(), partitionInfos);
for (Map.Entry<String, List<TopicPartition>> assignmentEntry : topicAssignments.entrySet()) {
assignments.computeIfAbsent(assignmentEntry.getKey(), k -> new ArrayList<>())
.addAll(assignmentEntry.getValue());
}
} else {
LOGGER.error("Skipping assignment for topic {} since no metadata is available", topicEntry.getKey());
}
}
final HashMap<String, Assignment> result = new HashMap<>();
for (Map.Entry<String, List<TopicPartition>> entry : assignments.entrySet()) {
result.put(entry.getKey(), new Assignment(entry.getValue()));
}
LOGGER.info("Done with az aware partition assignment");
return new GroupAssignment(result);
}
private Map<String, List<TopicPartition>> assignForTopic(String topic, Collection<MemberInfo> memberInfos,
Collection<PartitionInfo> partitions) {
final ArrayList<SubscriberAssignment> sortedSubscribers = new ArrayList<>();
for (MemberInfo memberInfo : memberInfos) {
sortedSubscribers.add(new SubscriberAssignment(memberInfo));
}
Collections.sort(sortedSubscribers, Comparator.comparing(SubscriberAssignment::getMemberId));
final ArrayList<PartitionInfo> sortedPartitions = new ArrayList<>(partitions);
Collections.sort(sortedPartitions, Comparator.comparing(PartitionInfo::partition));
// Calculate partition count for each subscriber
final int numPartitions = partitions.size();
final int numSubscribers = memberInfos.size();
final int numPartitionsPerSubscriber = numPartitions / numSubscribers;
final int numSubscribersWithExtraPartition = numPartitions % numSubscribers;
for (int i = 0; i < sortedSubscribers.size(); i++) {
if (i < numSubscribersWithExtraPartition) {
sortedSubscribers.get(i).setNumPartitions(numPartitionsPerSubscriber + 1);
} else {
sortedSubscribers.get(i).setNumPartitions(numPartitionsPerSubscriber);
}
}
// Assign only subscribers matching AZ
{
final Iterator<PartitionInfo> iterator = sortedPartitions.iterator();
int numAssigned = 0;
while (iterator.hasNext()) {
final PartitionInfo partition = iterator.next();
// in our case leader and followers are in the same AZ anyway
final String rack = partition.leader().rack();
if (rack == null) {
continue;
}
for (SubscriberAssignment sortedSubscriber : sortedSubscribers) {
if (sortedSubscriber.canAcceptMorePartitions() && rack.equals(sortedSubscriber.getMemberAz())) {
sortedSubscriber.assignPartition(partition);
iterator.remove();
numAssigned++;
break;
}
}
}
LOGGER.info("Assigned {} partitions for topic {} with matching AZ", numAssigned, topic);
}
// Assign all remaining partitions to subscribers similar to RangeAssignor
{
final Iterator<PartitionInfo> iterator = sortedPartitions.iterator();
int numAssigned = 0;
while (iterator.hasNext()) {
final PartitionInfo partition = iterator.next();
for (SubscriberAssignment sortedSubscriber : sortedSubscribers) {
if (sortedSubscriber.canAcceptMorePartitions()) {
sortedSubscriber.assignPartition(partition);
iterator.remove();
numAssigned++;
break;
}
}
}
LOGGER.info("Assigned {} partitions for topic {} in range-order", numAssigned, topic);
}
Map<String, List<TopicPartition>> assignment = new HashMap<>();
for (SubscriberAssignment subscriber : sortedSubscribers) {
final ArrayList<TopicPartition> assignedPartitions = new ArrayList<>();
for (PartitionInfo partitionInfo : subscriber.getAssignedPartitions()) {
assignedPartitions.add(new TopicPartition(partitionInfo.topic(), partitionInfo.partition()));
}
assignment.put(subscriber.getMemberId(), assignedPartitions);
}
return assignment;
}
@Override
public String name() {
return "az-aware";
}
@Override
public void configure(Map<String, ?> configs) {
az = (String) configs.get(CONFIG_KEY_AZ);
}
public static class SubscriberAssignment {
private final MemberInfo memberInfo;
private int numPartitions;
private List<PartitionInfo> assignedPartitions = new ArrayList<>();
public SubscriberAssignment(MemberInfo memberInfo) {
this.memberInfo = memberInfo;
}
public void setNumPartitions(int numPartitions) {
this.numPartitions = numPartitions;
}
public boolean canAcceptMorePartitions() {
return assignedPartitions.size() < numPartitions;
}
public void assignPartition(PartitionInfo partitionInfo) {
this.assignedPartitions.add(partitionInfo);
Collections.sort(this.assignedPartitions, Comparator.comparing(PartitionInfo::partition));
}
public List<PartitionInfo> getAssignedPartitions() {
return assignedPartitions;
}
public String getMemberId() {
return memberInfo.getMemberId();
}
public String getMemberAz() {
return memberInfo.getAz();
}
}
public static class MemberInfo {
private final String memberId;
private final Map<String, Object> userData;
public MemberInfo(Map.Entry<String, Subscription> subscriptionEntry) {
this.memberId = subscriptionEntry.getKey();
this.userData = deserializeUserData(subscriptionEntry.getValue().userData());
}
private Map<String, Object> deserializeUserData(ByteBuffer userData) {
if (userData != null) {
try (InputStream is = new ByteBufferInputStream(userData)) {
return MAPPER.readValue(is, Map.class);
} catch (Exception e) {
LOGGER.error("Unable to deserialize user data", e);
}
}
return null;
}
public String getMemberId() {
return memberId;
}
public Map<String, Object> getUserData() {
return userData;
}
public String getAz() {
if (userData != null) {
return (String) userData.getOrDefault(CONFIG_KEY_AZ, "");
} else {
return "";
}
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
MemberInfo that = (MemberInfo) o;
return Objects.equals(memberId, that.memberId) && Objects.equals(userData, that.userData);
}
@Override
public int hashCode() {
return Objects.hash(memberId, userData);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment