Skip to content

Instantly share code, notes, and snippets.

@NathanHowell
Last active October 31, 2018 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NathanHowell/acd01e002103e08979d9517ca79e60f0 to your computer and use it in GitHub Desktop.
KinesisIO using Splittable DoFn (SDF) and the V2 Kinesis API (HTTP/2 push)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.kinesisv2;
import org.apache.beam.sdk.io.kinesisv2.ShardPoller.Request;
import org.apache.beam.sdk.io.kinesisv2.KinesisIO.Query;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.joda.time.Duration;
import software.amazon.awssdk.services.kinesis.model.Record;
import java.io.Serializable;
/**
 * Reads {@link Record}s from the Kinesis streams named by the input {@link Query} elements.
 *
 * <p>Each {@link Query} is mapped to a {@link ShardPoller.Request}, {@link ShardPoller} emits the
 * stream's shards, and {@link RecordReader} reads every shard through the enhanced fan-out
 * consumer identified by {@code consumerARN}.
 */
public final class KinesisIO extends PTransform<PCollection<Query>, PCollection<Record>> {
  /** Placeholder ARN used by the single-argument constructor; it is not a usable consumer ARN. */
  private static final String PLACEHOLDER_CONSUMER_ARN = "arn:TODO";

  /** Delay between shard listings when no interval is supplied. */
  private static final Duration DEFAULT_POLL_INTERVAL = Duration.standardSeconds(1L);

  private final int bufferSize;
  private final String consumerARN;
  private final Duration pollInterval;

  /**
   * Creates a transform that subscribes through the placeholder consumer ARN {@code "arn:TODO"}
   * and lists shards once per second.
   *
   * <p>NOTE(review): the placeholder cannot reach a real consumer; prefer
   * {@link #KinesisIO(int, String)}.
   *
   * @param bufferSize queue slots handed to {@link RecordReader}; must be non-negative
   */
  public KinesisIO(int bufferSize) {
    this(bufferSize, PLACEHOLDER_CONSUMER_ARN);
  }

  /**
   * Creates a transform that lists shards once per second.
   *
   * @param bufferSize queue slots handed to {@link RecordReader}; must be non-negative
   * @param consumerARN ARN of the registered enhanced fan-out consumer
   */
  public KinesisIO(int bufferSize, String consumerARN) {
    this(bufferSize, consumerARN, DEFAULT_POLL_INTERVAL);
  }

  /**
   * @param bufferSize queue slots handed to {@link RecordReader}; must be non-negative
   * @param consumerARN ARN of the registered enhanced fan-out consumer
   * @param pollInterval delay between shard listings
   * @throws IllegalArgumentException if {@code bufferSize} is negative, {@code consumerARN} is
   *     null or empty, or {@code pollInterval} is null
   */
  public KinesisIO(int bufferSize, String consumerARN, Duration pollInterval) {
    if (bufferSize < 0) {
      throw new IllegalArgumentException("bufferSize must be non-negative, got " + bufferSize);
    }
    if (consumerARN == null || consumerARN.isEmpty()) {
      throw new IllegalArgumentException("consumerARN must be non-empty");
    }
    if (pollInterval == null) {
      throw new IllegalArgumentException("pollInterval must be non-null");
    }
    this.bufferSize = bufferSize;
    this.consumerARN = consumerARN;
    this.pollInterval = pollInterval;
  }

  /** Names one Kinesis stream to read. */
  public static class Query implements Serializable {
    String streamName;

    /** Creates a query whose stream name is set later by same-package code. */
    public Query() {}

    /** Creates a query for the stream called {@code streamName}. */
    public Query(String streamName) {
      this.streamName = streamName;
    }
  }

  @Override
  public PCollection<Record> expand(PCollection<Query> input) {
    // Copy into a local so the serialized lambda captures only a Duration, not this transform.
    final Duration interval = pollInterval;
    return input
        .apply(MapElements
            .into(TypeDescriptor.of(ShardPoller.Request.class))
            .via(query -> new Request(query.streamName, interval)))
        .apply(ParDo.of(new ShardPoller()))
        .apply(ParDo.of(new RecordReader(bufferSize, consumerARN)));
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.kinesisv2;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.Uninterruptibles;
import org.apache.beam.sdk.io.range.ByteKey;
import org.apache.beam.sdk.io.range.ByteKeyRange;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
import org.apache.beam.sdk.transforms.splittabledofn.ByteKeyRangeTracker;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
import software.amazon.awssdk.services.kinesis.model.StartingPosition;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler;
import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler.Visitor;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import javax.annotation.Nullable;
/**
 * Splittable DoFn that reads the records of one Kinesis shard through an enhanced fan-out
 * consumer ({@code SubscribeToShard}).
 *
 * <p>The restriction is a {@link ByteKeyRange} over sequence numbers encoded as fixed-width,
 * big-endian unsigned integers, so byte order matches numeric order. Read progress lives in the
 * restriction itself. Every emitted record is claimed through the tracker, and each call derives
 * its {@link StartingPosition} from the start of the current (possibly residual) restriction.
 *
 * <p>Each {@link #processElement} call opens a subscription, waits for its first signal,
 * cancels the subscription, emits that batch, and asks to be resumed one second later.
 */
@UnboundedPerElement
final class RecordReader extends DoFn<Shard, Record> {
  /** Width in bytes of an encoded sequence number: enough for 128 decimal digits (54 bytes). */
  private static final int SEQUENCE_NUMBER_BYTES = bytesNeeded(128);

  private final String consumerARN;
  private final int bufferSize;
  private transient KinesisAsyncClient kinesisAsyncClient;

  /**
   * @param bufferSize extra slots in the per-call queue between the SDK callback thread and the
   *     processing thread; must be non-negative
   * @param consumerARN ARN of the registered enhanced fan-out consumer to subscribe through
   * @throws IllegalArgumentException if {@code bufferSize} is negative
   */
  RecordReader(final int bufferSize, final String consumerARN) {
    if (bufferSize < 0) {
      throw new IllegalArgumentException("bufferSize must be non-negative, got " + bufferSize);
    }
    this.bufferSize = bufferSize;
    this.consumerARN = consumerARN;
  }

  @Setup
  @SuppressWarnings("unused")
  public void setup() {
    this.kinesisAsyncClient = KinesisAsyncClient.create();
  }

  @Teardown
  @SuppressWarnings("unused")
  public void teardown() {
    if (kinesisAsyncClient != null) {
      kinesisAsyncClient.close();
      kinesisAsyncClient = null;
    }
  }

  /**
   * One signal from a subscription. It is a batch ({@code value} set), an error ({@code throwable}
   * set), or completion (both null).
   */
  private static final class SubscriptionEvent<T> {
    @Nullable final Throwable throwable;
    @Nullable final T value;

    SubscriptionEvent(@Nullable Throwable throwable, @Nullable T value) {
      this.throwable = throwable;
      this.value = value;
    }
  }

  /**
   * Reactive-streams subscriber that hands signals to a queue without ever blocking the SDK's
   * callback thread.
   *
   * <p>Outstanding demand is kept at one, because the reader consumes only the first signal of
   * each subscription and then cancels. {@link #cancel()} is idempotent and also covers an
   * {@code onSubscribe} that arrives after cancellation was requested.
   */
  private static final class QueueingSubscriber implements Subscriber<SubscribeToShardEventStream> {
    private final BlockingQueue<SubscriptionEvent<SubscribeToShardEvent>> queue;
    @Nullable private Subscription subscription;
    private boolean cancelled;

    QueueingSubscriber(BlockingQueue<SubscriptionEvent<SubscribeToShardEvent>> queue) {
      this.queue = queue;
    }

    @Override
    public void onSubscribe(Subscription s) {
      final boolean alreadyCancelled;
      synchronized (this) {
        subscription = s;
        alreadyCancelled = cancelled;
      }
      if (alreadyCancelled) {
        s.cancel();
      } else {
        s.request(1);
      }
    }

    @Override
    public void onNext(SubscribeToShardEventStream stream) {
      final boolean[] delivered = {false};
      stream.accept(new Visitor() {
        @Override
        public void visit(SubscribeToShardEvent event) {
          delivered[0] = true;
          enqueue(new SubscriptionEvent<>(null, event));
        }
      });
      if (!delivered[0]) {
        // An event type we do not consume used up the single unit of demand; replace it so the
        // subscription keeps flowing until a batch of records arrives.
        final Subscription s = activeSubscription();
        if (s != null) {
          s.request(1);
        }
      }
    }

    @Override
    public void onError(Throwable t) {
      enqueue(new SubscriptionEvent<>(t, null));
    }

    @Override
    public void onComplete() {
      enqueue(new SubscriptionEvent<>(null, null));
    }

    /** Stops the subscription, now or as soon as {@code onSubscribe} delivers it. */
    void cancel() {
      final Subscription s;
      synchronized (this) {
        cancelled = true;
        s = subscription;
      }
      if (s != null) {
        s.cancel();
      }
    }

    @Nullable
    private synchronized Subscription activeSubscription() {
      return cancelled ? null : subscription;
    }

    private void enqueue(SubscriptionEvent<SubscribeToShardEvent> event) {
      // The fresh queue always has room for the first signal, which is the only one the reader
      // takes. A later signal may be dropped rather than blocking the SDK thread.
      queue.offer(event);
    }
  }

  /**
   * Emits the next batch of records for the element's shard, starting where the restriction
   * begins.
   *
   * @return {@code stop()} when a claim fails or the subscription completed without a batch,
   *     otherwise {@code resume()} after one second
   * @throws Throwable the error reported by the subscription, rethrown as-is
   */
  @ProcessElement
  @SuppressWarnings("unused")
  public ProcessContinuation processElement(
      ProcessContext context,
      ByteKeyRangeTracker tracker) throws Throwable {
    final SubscriptionEvent<SubscribeToShardEvent> event =
        nextEvent(context.element().shardId(), startingPosition(tracker.currentRestriction()));
    if (event.throwable != null) {
      throw event.throwable;
    }
    if (event.value == null) {
      // Subscription completed before delivering a batch.
      return ProcessContinuation.stop();
    }
    for (Record record : event.value.records()) {
      if (!tracker.tryClaim(byteKey(record.sequenceNumber()))) {
        return ProcessContinuation.stop();
      }
      final Instant instant = Instant.ofEpochMilli(record.approximateArrivalTimestamp().toEpochMilli());
      context.outputWithTimestamp(record, instant);
    }
    context.updateWatermark(Instant.now().minus(Duration.millis(event.value.millisBehindLatest())));
    return ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1));
  }

  /**
   * Subscribes to {@code shardId} at {@code position} and waits for the first signal. The
   * subscription is always cancelled before returning.
   */
  private SubscriptionEvent<SubscribeToShardEvent> nextEvent(
      String shardId, StartingPosition position) {
    final BlockingQueue<SubscriptionEvent<SubscribeToShardEvent>> queue =
        new ArrayBlockingQueue<>(Math.addExact(1, bufferSize));
    final QueueingSubscriber subscriber = new QueueingSubscriber(queue);
    try {
      kinesisAsyncClient
          .subscribeToShard(
              SubscribeToShardRequest
                  .builder()
                  .shardId(shardId)
                  .consumerARN(consumerARN)
                  .startingPosition(position)
                  .build(),
              SubscribeToShardResponseHandler
                  .builder()
                  .subscriber(() -> subscriber)
                  .build())
          // A failure before or outside the event stream would otherwise leave take() waiting.
          .whenComplete((ignored, error) -> {
            if (error != null) {
              subscriber.onError(error);
            }
          });
      return Uninterruptibles.takeUninterruptibly(queue);
    } finally {
      subscriber.cancel();
    }
  }

  /**
   * Translates the start of {@code range} into a Kinesis starting position.
   *
   * <p>The start key is handled in three cases:
   * <ul>
   *   <li>An empty start key reads from the oldest available record ({@code TRIM_HORIZON}).</li>
   *   <li>A key of at most {@link #SEQUENCE_NUMBER_BYTES} bytes is right-padded with zeros. That
   *       gives the smallest full-width key not below it, and reading starts AT that sequence
   *       number.</li>
   *   <li>A longer key sorts strictly after its full-width prefix. One example is a checkpoint
   *       residual that starts just past the last claimed key. Reading resumes AFTER the
   *       prefix.</li>
   * </ul>
   */
  private static StartingPosition startingPosition(ByteKeyRange range) {
    final ByteKey start = range.getStartKey();
    if (start.isEmpty()) {
      return StartingPosition.builder().type(ShardIteratorType.TRIM_HORIZON).build();
    }
    final byte[] bytes = start.getBytes();
    final ShardIteratorType type = bytes.length > SEQUENCE_NUMBER_BYTES
        ? ShardIteratorType.AFTER_SEQUENCE_NUMBER
        : ShardIteratorType.AT_SEQUENCE_NUMBER;
    final byte[] fullWidth = Arrays.copyOf(bytes, SEQUENCE_NUMBER_BYTES);
    return StartingPosition
        .builder()
        .type(type)
        .sequenceNumber(new BigInteger(1, fullWidth).toString())
        .build();
  }

  /**
   * Returns the number of bytes needed for any non-negative integer of {@code decimalLength}
   * decimal digits, i.e. {@code ceil(decimalLength * log_256(10))}. For 128 digits that is 54.
   */
  private static int bytesNeeded(int decimalLength) {
    return (int) Math.ceil(decimalLength * (Math.log(10) / Math.log(256)));
  }

  /**
   * Encodes a decimal sequence number as a fixed-width big-endian key. A null or empty input maps
   * to {@link ByteKey#EMPTY}.
   *
   * @throws NumberFormatException if {@code key} is not a decimal integer
   * @throws IllegalArgumentException if {@code key} is negative or does not fit in
   *     {@link #SEQUENCE_NUMBER_BYTES} bytes
   */
  private static ByteKey byteKey(@Nullable String key) {
    if (Strings.isNullOrEmpty(key)) {
      return ByteKey.EMPTY;
    }
    final BigInteger value = new BigInteger(key);
    if (value.signum() < 0 || value.bitLength() > SEQUENCE_NUMBER_BYTES * Byte.SIZE) {
      throw new IllegalArgumentException("Sequence number out of range: " + key);
    }
    final byte[] magnitude = value.toByteArray();
    // toByteArray() may prepend a zero sign byte; keep only the low-order bytes that matter.
    final int srcPos = Math.max(0, magnitude.length - SEQUENCE_NUMBER_BYTES);
    final int length = magnitude.length - srcPos;
    final byte[] bytes = new byte[SEQUENCE_NUMBER_BYTES];
    System.arraycopy(magnitude, srcPos, bytes, SEQUENCE_NUMBER_BYTES - length, length);
    return ByteKey.copyFrom(bytes);
  }

  /**
   * Returns the key range covering the shard's sequence numbers. An open shard (no ending
   * sequence number) gets an unbounded end. A closed shard's ending sequence number is made
   * inclusive by ending the range at its immediate successor, i.e. the key with one zero byte
   * appended.
   */
  @GetInitialRestriction
  @SuppressWarnings("unused")
  public ByteKeyRange getInitialRestriction(Shard shard) {
    final SequenceNumberRange sequenceNumberRange = shard.sequenceNumberRange();
    final ByteKey starting = byteKey(sequenceNumberRange.startingSequenceNumber());
    final String endingSequenceNumber = sequenceNumberRange.endingSequenceNumber();
    final ByteKey ending = Strings.isNullOrEmpty(endingSequenceNumber)
        ? ByteKey.EMPTY
        : ByteKey.copyFrom(
            Arrays.copyOf(byteKey(endingSequenceNumber).getBytes(), SEQUENCE_NUMBER_BYTES + 1));
    return ByteKeyRange.of(starting, ending);
  }

  @NewTracker
  @SuppressWarnings("unused")
  public ByteKeyRangeTracker newTracker(ByteKeyRange range) {
    return ByteKeyRangeTracker.of(range);
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.kinesisv2;
import org.apache.beam.sdk.io.kinesisv2.ShardPoller.Request;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.Futures;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement;
import org.joda.time.Duration;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
import software.amazon.awssdk.services.kinesis.model.ListShardsResponse;
import software.amazon.awssdk.services.kinesis.model.Shard;
import java.io.Serializable;
import java.util.stream.Stream;
/**
* Poll Kinesis looking for new ShardIds to process
*/
@BoundedPerElement
final class ShardPoller extends DoFn<Request, Shard> {
private transient KinesisAsyncClient kinesisAsyncClient;
static class Request implements Serializable {
final String streamName;
final Duration pollInterval;
Request(String streamName, Duration pollInterval) {
this.streamName = streamName;
this.pollInterval = pollInterval;
}
}
@Setup
@SuppressWarnings("unused")
void setup() {
this.kinesisAsyncClient = KinesisAsyncClient.create();
}
@Teardown
@SuppressWarnings("unused")
void teardown() {
if (kinesisAsyncClient != null) {
kinesisAsyncClient.close();
kinesisAsyncClient = null;
}
}
/**
* Store the last shard seen to allow us to resume listing shards from a reasonable place.
*/
@StateId("lastShard")
@SuppressWarnings("unused")
private final StateSpec<ValueState<String>> lastShardSpec = StateSpecs.value();
private Stream<Shard> getAllShards(ListShardsRequest listShardsRequest) {
return Stream.of(listShardsRequest)
.flatMap(request -> {
final ListShardsResponse response = Futures.getUnchecked(kinesisAsyncClient.listShards(request));
final Stream<Shard> shards = response.shards().stream();
if (Strings.isNullOrEmpty(response.nextToken())) {
return shards;
} else {
final ListShardsRequest nextRequest = ListShardsRequest
.builder()
.nextToken(response.nextToken())
.build();
return Stream.concat(shards, getAllShards(nextRequest));
}
});
}
@ProcessElement
@SuppressWarnings("unused")
ProcessContinuation processElement(
final ProcessContext context,
final ValueState<String> lastShardState) {
final String lastShard = lastShardState.read();
final String streamName = context.element().streamName;
final ListShardsRequest listShardsRequest = Strings.isNullOrEmpty(lastShard)
? ListShardsRequest.builder().streamName(streamName).build()
: ListShardsRequest.builder().streamName(streamName).exclusiveStartShardId(lastShard).build();
getAllShards(listShardsRequest)
.forEach(shard -> {
context.output(shard);
lastShardState.write(shard.shardId());
});
return ProcessContinuation.resume().withResumeDelay(context.element().pollInterval);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment