Skip to content

Instantly share code, notes, and snippets.

@mrsimpson
Created January 10, 2024 09:45
Show Gist options
  • Save mrsimpson/3eaf028fef28930ab2f40b2b7d909344 to your computer and use it in GitHub Desktop.
Save mrsimpson/3eaf028fef28930ab2f40b2b7d909344 to your computer and use it in GitHub Desktop.
Broadcasting in Flink
package de.fermata.flink.app;
import java.time.Instant;
import java.util.Random;
import java.util.stream.StreamSupport;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReadOnlyBroadcastState;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector;
/**
 * Demonstrates how a {@link BroadcastStream} interacts with event-time processing in Flink.
 *
 * <p>Random integers are averaged over 2-second tumbling event-time windows. The averages are
 * connected with a broadcast "threshold" stream (a single bounded element) and flagged as being
 * below the threshold or not. The flagged averages are then windowed again over 10 seconds to
 * check whether windows downstream of the broadcast connection are closed as well.
 */
public class BroadcastStreamDemoJob {

  /**
   * Returns the descriptor of the broadcast state that holds the current threshold.
   *
   * <p>Broadcast state has to be a {@link MapStateDescriptor}, even if only a single value is
   * stored; the threshold is therefore kept under the {@code null} key.
   */
  static MapStateDescriptor<Void, Integer> getBroadcastStateDescriptor() {
    return new MapStateDescriptor<>("threshold", Void.class, Integer.class);
  }

  /**
   * Builds the demo pipeline and runs it on a local environment with the web UI enabled.
   *
   * @param args unused
   * @throws Exception if the job cannot be executed
   */
  public static void main(String[] args) throws Exception {
    StreamExecutionEnvironment env =
        StreamExecutionEnvironment.createLocalEnvironmentWithWebUI(createLocalConfiguration());

    DataStreamSource<Integer> integerSource = env.addSource(new RandomIntSource());
    // Bounded configuration stream consisting of a single threshold value.
    DataStreamSource<Integer> thresholdStream = env.fromElements(0);

    // Create averages based on event time. This demonstrates that the integerSource stream has
    // proper watermarks to close the windows. The constant key sends all elements to one key.
    SingleOutputStreamOperator<Double> averages =
        integerSource
            .keyBy(c -> 1)
            .window(TumblingEventTimeWindows.of(Time.seconds(2)))
            .<Double>process(
                new ProcessWindowFunction<>() {
                  /** Emits the arithmetic mean of all integers in the window. */
                  @Override
                  public void process(
                      Integer integer,
                      ProcessWindowFunction<Integer, Double, Integer, TimeWindow>.Context context,
                      Iterable<Integer> elements,
                      Collector<Double> out)
                      throws Exception {
                    double average =
                        StreamSupport.stream(elements.spliterator(), false)
                            .mapToInt(Integer::intValue)
                            .average()
                            .orElse(0.0);
                    out.collect(average);
                  }
                });

    // Join with a configuration data source.
    BroadcastStream<Integer> thresholdBroadcast =
        thresholdStream.broadcast(getBroadcastStateDescriptor());

    SingleOutputStreamOperator<Tuple2<Double, Boolean>> averagesBelowThreshold =
        averages
            .connect(thresholdBroadcast)
            .process(
                new BroadcastProcessFunction<Double, Integer, Tuple2<Double, Boolean>>() {
                  /**
                   * Pairs each average with a flag telling whether it is strictly below the
                   * most recently broadcast threshold.
                   */
                  @Override
                  public void processElement(
                      Double value,
                      BroadcastProcessFunction<Double, Integer, Tuple2<Double, Boolean>>
                              .ReadOnlyContext
                          ctx,
                      Collector<Tuple2<Double, Boolean>> out)
                      throws Exception {
                    ReadOnlyBroadcastState<Void, Integer> thresholds =
                        ctx.getBroadcastState(getBroadcastStateDescriptor());
                    Integer threshold = thresholds.get(null);
                    // Flink does not order the broadcast input relative to the regular input, so
                    // the state may still be empty when an average arrives. Rather than failing
                    // the job, such an average is reported as not below the (unknown) threshold.
                    boolean belowThreshold = threshold != null && value < threshold;
                    out.collect(new Tuple2<>(value, belowThreshold));
                  }

                  /** Stores the broadcast threshold under the single {@code null} key. */
                  @Override
                  public void processBroadcastElement(
                      Integer value,
                      BroadcastProcessFunction<Double, Integer, Tuple2<Double, Boolean>>.Context
                          ctx,
                      Collector<Tuple2<Double, Boolean>> out)
                      throws Exception {
                    BroadcastState<Void, Integer> thresholds =
                        ctx.getBroadcastState(getBroadcastStateDescriptor());
                    thresholds.put(null, value);
                  }
                });
    averagesBelowThreshold.print("average below thresholdStream");

    // Now, let's apply a window on it and see whether this is closed too.
    SingleOutputStreamOperator<Double> minAverages =
        averagesBelowThreshold
            .keyBy(c -> 1)
            .window(TumblingEventTimeWindows.of(Time.seconds(10)))
            .process(
                new ProcessWindowFunction<>() {
                  /** Emits the smallest average contained in the window. */
                  @Override
                  public void process(
                      Integer integer,
                      ProcessWindowFunction<Tuple2<Double, Boolean>, Double, Integer, TimeWindow>
                              .Context
                          context,
                      Iterable<Tuple2<Double, Boolean>> elements,
                      Collector<Double> out)
                      throws Exception {
                    double min =
                        StreamSupport.stream(elements.spliterator(), false)
                            .mapToDouble(t -> t.f0)
                            .min()
                            .orElse(Double.NEGATIVE_INFINITY);
                    out.collect(min);
                  }
                });
    minAverages.print("Windowed min. of averages");

    env.execute();
  }

  /**
   * Creates the configuration for the local environment: REST/web UI port, network memory,
   * heartbeat timeouts, flame graphs and default parallelism.
   */
  private static Configuration createLocalConfiguration() {
    Configuration config = new Configuration();
    config.setInteger("rest.port", 8888);
    config.setDouble("taskmanager.memory.network.fraction", 0.2);
    config.setString("taskmanager.memory.network.min", "256 mb");
    // Very long heartbeat timeouts; presumably so the job survives pauses at debugger
    // breakpoints — TODO confirm intent.
    config.setInteger("client.heartbeat.timeout", 60000000);
    config.setInteger("heartbeat.timeout", 60000000);
    config.setBoolean("rest.flamegraph.enabled", true);
    config.setInteger("parallelism.default", 4);
    return config;
  }
}
/**
 * Parallel source that emits random integers until cancelled.
 *
 * <p>Each element is timestamped with the current wall-clock time and followed by a watermark, so
 * downstream event-time windows close as wall-clock time passes their end.
 */
class RandomIntSource extends RichParallelSourceFunction<Integer> {
  private volatile boolean cancelled = false;
  private Random random;

  /** Creates the random number generator for this parallel instance. */
  @Override
  public void open(Configuration parameters) throws Exception {
    super.open(parameters);
    random = new Random();
  }

  /**
   * Emits random integers until {@link #cancel()} is called.
   *
   * <p>Each element and its watermark are emitted while holding the checkpoint lock.
   */
  @Override
  public void run(SourceContext<Integer> ctx) throws Exception {
    while (!cancelled) {
      int next = random.nextInt();
      synchronized (ctx.getCheckpointLock()) {
        long timestamp = Instant.now().toEpochMilli();
        ctx.collectWithTimestamp(next, timestamp);
        // Watermark(t) declares that no further elements with timestamp <= t will arrive. This
        // loop can emit several elements within the same millisecond, so the watermark has to lag
        // by one millisecond; with Watermark(timestamp), later elements carrying the same
        // timestamp would be late and dropped by a window this watermark has just fired.
        ctx.emitWatermark(new Watermark(timestamp - 1));
      }
    }
  }

  /** Signals {@link #run} to stop after the current iteration. */
  @Override
  public void cancel() {
    cancelled = true;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment