Skip to content

Instantly share code, notes, and snippets.

@RaccoonDev
Created January 31, 2020 20:18
Show Gist options
  • Save RaccoonDev/1624b1bf6691356a5bc941f8a3cbeea9 to your computer and use it in GitHub Desktop.
Save RaccoonDev/1624b1bf6691356a5bc941f8a3cbeea9 to your computer and use it in GitHub Desktop.
A walkthrough of unit testing for streaming applications built with Apache Beam.
package org.apache.beam.examples.streams;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.AfterFirst;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Demonstrates unit testing of streaming Beam pipelines with {@link TestStream}.
 *
 * <p>Events are fed into the pipeline as JSON strings and parsed into {@link TestRequest}
 * objects. The second test additionally keys them by {@code aId} and groups them under a
 * count / processing-time trigger in the global window.
 */
@RunWith(JUnit4.class)
public class TestStreamTest {
  /** Event-time origin used as the initial watermark of every {@link TestStream}. */
  private static final Instant BASE_TIME = new Instant(0);

  @Rule public TestPipeline p = TestPipeline.create();

  /** Checks that every JSON event emitted by the stream is parsed into an equal TestRequest. */
  @Test
  public void testStream() {
    TestStream<String> createEvents =
        TestStream.create(StringUtf8Coder.of())
            .advanceWatermarkTo(BASE_TIME)
            .addElements(getTestJson(1, 1, 200))
            .addElements(getTestJson(1, 2, 300))
            .addElements(getTestJson(1, 3, 350))
            .addElements(getTestJson(2, 4, 350))
            .addElements(getTestJson(2, 5, 350))
            .advanceWatermarkToInfinity();

    PCollection<TestRequest> events = p.apply(createEvents).apply(ParDo.of(new ParseTestJson()));

    PAssert.that(events)
        .inWindow(GlobalWindow.INSTANCE)
        .containsInAnyOrder(
            new TestRequest(1, 1, 200),
            new TestRequest(1, 2, 300),
            new TestRequest(1, 3, 350),
            new TestRequest(2, 4, 350),
            new TestRequest(2, 5, 350));

    p.run().waitUntilFinish();
  }

  /**
   * Groups parsed requests by {@code aId} in the global window.
   *
   * <p>A pane fires either after two elements or one minute of processing time after the
   * pane's first element, whichever comes first. Fired panes are discarded.
   */
  @Test
  public void testGroupedByAId() {
    TestStream<String> createEvents =
        TestStream.create(StringUtf8Coder.of())
            .advanceWatermarkTo(BASE_TIME)
            .addElements(getTestJson(1, 1, 200))
            .addElements(getTestJson(1, 2, 301))
            .addElements(getTestJson(1, 3, 352))
            .addElements(getTestJson(2, 4, 353))
            .addElements(getTestJson(2, 5, 354))
            .addElements(getTestJson(1, 1, 205))
            .advanceProcessingTime(Duration.standardMinutes(2))
            .addElements(getTestJson(1, 2, 306))
            .addElements(getTestJson(1, 3, 357))
            .addElements(getTestJson(2, 4, 358))
            .addElements(getTestJson(2, 5, 359))
            .advanceProcessingTime(Duration.standardMinutes(2))
            .addElements(getTestJson(1, 1, 2010))
            .addElements(getTestJson(1, 2, 3011))
            .addElements(getTestJson(1, 3, 3512))
            .addElements(getTestJson(2, 4, 3513))
            .addElements(getTestJson(2, 5, 3514))
            .advanceProcessingTime(Duration.standardMinutes(2))
            .addElements(getTestJson(3, 6, 3515))
            .advanceProcessingTime(Duration.standardMinutes(1))
            .advanceWatermarkToInfinity();

    PCollection<KV<Long, Iterable<TestRequest>>> events =
        p.apply(createEvents)
            .apply(
                Window.<String>configure()
                    .triggering(
                        Repeatedly.forever(
                            AfterFirst.of(
                                AfterPane.elementCountAtLeast(2),
                                AfterProcessingTime.pastFirstElementInPane()
                                    .plusDelayOf(Duration.standardMinutes(1)))))
                    .discardingFiredPanes()
                    .withAllowedLateness(Duration.ZERO))
            .apply(ParDo.of(new ParseTestJson()))
            .apply(ParDo.of(new RequestsToKV()))
            .apply(GroupByKey.create())
            .apply(ParDo.of(new PassThroughLoggingFn<>()));

    // NOTE(review): containsInAnyOrder must match *all* early panes. The elements added after
    // the first processing-time advance also trigger early panes, yet only three panes are
    // listed here. Confirm this expectation is intended.
    // NOTE(review): GroupByKey does not guarantee value order within a group, so comparing
    // against ordered lists may be brittle. Confirm.
    PAssert.that(events)
        .inEarlyGlobalWindowPanes()
        .containsInAnyOrder(
            KV.of(1L, Arrays.asList(new TestRequest(1, 1, 200), new TestRequest(1, 2, 301))),
            KV.of(1L, Arrays.asList(new TestRequest(1, 3, 352), new TestRequest(1, 1, 205))),
            KV.of(2L, Arrays.asList(new TestRequest(2, 4, 353), new TestRequest(2, 5, 354))));

    p.run().waitUntilFinish();
  }

  /**
   * Parses each JSON string into a {@link TestRequest}.
   *
   * <p>Inputs that fail to parse are logged at error level and dropped; nothing is emitted
   * for them.
   */
  public static class ParseTestJson extends DoFn<String, TestRequest> {
    private static final Logger logger = LoggerFactory.getLogger(ParseTestJson.class);

    // ObjectMapper is expensive to construct and safe to share for reads once configured.
    // One static instance replaces the previous per-element allocation. Being static, it is
    // not part of the serialized DoFn.
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @ProcessElement
    public void processElement(ProcessContext c) {
      try {
        c.output(MAPPER.readValue(c.element(), TestRequest.class));
      } catch (IOException e) {
        logger.error("Unable to parse given JSON to TestRequest", e);
      }
    }
  }

  /** Keys each {@link TestRequest} by its {@code aId}. */
  public static class RequestsToKV extends DoFn<TestRequest, KV<Long, TestRequest>> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      TestRequest request = c.element();
      c.output(KV.of(request.getaId(), request));
    }
  }

  /** Logs every element at info level and emits it unchanged. */
  private static class PassThroughLoggingFn<T> extends DoFn<T, T> {
    private static final Logger logger = LoggerFactory.getLogger(PassThroughLoggingFn.class);

    @ProcessElement
    public void processElement(ProcessContext c) {
      T element = c.element();
      logger.info("Logging element: {}", element);
      c.output(element);
    }
  }

  /**
   * Mutable POJO for a test event.
   *
   * <p>The no-arg constructor and the {@code getaId}/{@code setaId} style accessors let
   * Jackson bind the JSON keys {@code aId}, {@code kId} and {@code value}. The class is
   * {@link Serializable}, presumably so Beam can infer a serializable coder for it
   * (TODO confirm).
   */
  public static class TestRequest implements Serializable {
    private long aId;
    private long kId;
    private long value;

    /** No-arg constructor used during JSON deserialization. */
    public TestRequest() {}

    public TestRequest(long aId, long kId, long value) {
      this.aId = aId;
      this.kId = kId;
      this.value = value;
    }

    public long getaId() {
      return aId;
    }

    public void setaId(long aId) {
      this.aId = aId;
    }

    public long getkId() {
      return kId;
    }

    public void setkId(long kId) {
      this.kId = kId;
    }

    public long getValue() {
      return value;
    }

    public void setValue(long value) {
      this.value = value;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      TestRequest that = (TestRequest) o;
      return getaId() == that.getaId()
          && getkId() == that.getkId()
          && getValue() == that.getValue();
    }

    @Override
    public int hashCode() {
      return Objects.hash(getaId(), getkId(), getValue());
    }

    @Override
    public String toString() {
      return "TestRequest{" + "aId=" + aId + ", kId=" + kId + ", value=" + value + '}';
    }
  }

  /** Builds the JSON representation of a {@link TestRequest} as fed into the test streams. */
  private static String getTestJson(long aId, long kId, long value) {
    return String.format("{ \"aId\": %d, \"kId\": %d, \"value\": %d }", aId, kId, value);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment