Skip to content

Instantly share code, notes, and snippets.

@Squiry
Created May 13, 2020 22:10
Show Gist options
  • Save Squiry/3b88bffa0a502e771555d8317a589cc7 to your computer and use it in GitHub Desktop.
Save Squiry/3b88bffa0a502e771555d8317a589cc7 to your computer and use it in GitHub Desktop.
package com.example.r2dbc.example;
import io.r2dbc.postgresql.PostgresqlConnectionConfiguration;
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
import io.r2dbc.postgresql.api.PostgresqlConnection;
import io.r2dbc.postgresql.api.PostgresqlResult;
import io.r2dbc.postgresql.client.ReactorNettyClient;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.shaded.com.google.common.util.concurrent.Uninterruptibles;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.retry.Jitter;
import reactor.retry.Repeat;
import reactor.test.StepVerifier;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
/**
 * Reproduction harness for paging through a large result set with r2dbc-postgresql while the
 * subscriber applies backpressure and the query stream is repeatedly re-subscribed.
 *
 * <p>A PostgreSQL container is started once for the class. The test seeds {@code topic_message}
 * with consecutive {@code consensus_timestamp} values, then starts several concurrent paging
 * queries (1000 rows per page). It checks that each query received every timestamp exactly once,
 * in order.
 */
class R2DbcExampleApplicationTests {
    private static final Logger log = LoggerFactory.getLogger(R2DbcExampleApplicationTests.class);

    /** Rows each paging query must receive; timestamps {@code 0..MESSAGE_COUNT} are inserted. */
    private static final int MESSAGE_COUNT = 50_000;

    /** Number of connections (each in its own transaction) used to insert rows concurrently. */
    private static final int INSERT_CONNECTIONS = 10;

    /** Number of concurrent paging queries; their ids run from 1 to this value inclusive. */
    private static final int QUERY_COUNT = 9;

    private static final String INSERT_SQL = "INSERT INTO topic_message(consensus_timestamp, realm_num, topic_num, message, running_hash, sequence_number) VALUES ($1, $2, $3, $4, $5, $6)";

    private static final PostgreSQLContainer<?> postgresql = new PostgreSQLContainer<>("postgres:9.6-alpine");

    @BeforeAll
    static void beforeAll() {
        postgresql.start();
    }

    @AfterAll
    static void afterAll() {
        postgresql.stop();
    }

    /**
     * Connection factory pointing at the container. It reads the container's mapped port, so it
     * must be built after {@link #beforeAll()}. NOTE(review): this relies on JUnit's default
     * per-method test-instance lifecycle; confirm that no {@code @TestInstance(PER_CLASS)} is
     * configured globally.
     */
    private final PostgresqlConnectionFactory connectionFactory = new PostgresqlConnectionFactory(PostgresqlConnectionConfiguration.builder()
            .applicationName("test")
            .host(postgresql.getContainerIpAddress())
            .port(postgresql.getMappedPort(5432))
            .database(postgresql.getDatabaseName())
            .username(postgresql.getUsername())
            .password(postgresql.getPassword())
            .build()
    );

    /**
     * Runs {@code txCallback} inside a transaction on a fresh connection.
     *
     * <p>The transaction is committed when the callback's publisher completes. It is rolled back
     * on error or cancellation. The connection is then closed by {@link #withConnection}.
     *
     * @param txCallback receives the connection paired with a constant {@code "tx"} marker
     * @return the callback's emissions
     */
    private <T> Flux<T> inTx(Function<Tuple2<PostgresqlConnection, String>, Publisher<T>> txCallback) {
        var connectionCallback = (Function<PostgresqlConnection, Flux<T>>) con -> {
            var tx = Mono.from(con.beginTransaction()).thenReturn(Tuples.of(con, "tx"));
            return Flux.usingWhen(tx, txCallback, txp -> con.commitTransaction(),
                    (txp, e) -> con.rollbackTransaction(),
                    txp -> con.rollbackTransaction());
        };
        return withConnection(connectionCallback);
    }

    /**
     * Opens a connection, applies {@code connectionCallback} to it, and closes the connection
     * when the resulting {@link Flux} terminates or is cancelled.
     */
    private <T> Flux<T> withConnection(Function<PostgresqlConnection, Flux<T>> connectionCallback) {
        return Flux.usingWhen(connectionFactory.create(), connectionCallback, Connection::close);
    }

    /** Drops and recreates {@code topic_message} and its index, blocking until done. */
    private void initDatabase() {
        var createsql = "drop table if exists topic_message;\n" +
                "create table if not exists topic_message\n" +
                "(\n" +
                " consensus_timestamp bigint primary key not null,\n" +
                " realm_num int not null,\n" +
                " topic_num int not null,\n" +
                " message bytea null,\n" +
                " running_hash bytea null,\n" +
                " sequence_number bigint null\n" +
                ");\n" +
                "\n" +
                "create index if not exists topic_message__realm_num_timestamp\n" +
                " on topic_message (realm_num, topic_num, consensus_timestamp);\n";
        // Reuse the shared connection-lifecycle helper instead of duplicating usingWhen/close.
        withConnection(con -> Flux.from(con.createStatement(createsql).execute()).flatMap(Result::getRowsUpdated))
                .blockLast();
        log.info("Database init complete");
    }

    /**
     * Inserts rows with timestamps {@code 0..MESSAGE_COUNT} (inclusive). The work is split across
     * {@link #INSERT_CONNECTIONS} concurrent transactions. Blocks until all inserts finish.
     */
    private void initRows() {
        int chunkSize = MESSAGE_COUNT / INSERT_CONNECTIONS;
        var inserts = new ArrayList<Flux<Integer>>();
        for (int j = 0; j < INSERT_CONNECTIONS; j++) {
            int startFrom = j * chunkSize;
            // The last chunk absorbs any division remainder plus the extra row at MESSAGE_COUNT,
            // so every timestamp 0..MESSAGE_COUNT is inserted exactly once.
            int count = j == INSERT_CONNECTIONS - 1 ? MESSAGE_COUNT + 1 - startFrom : chunkSize;
            inserts.add(insertRange(startFrom, count));
        }
        Flux.merge(inserts).blockLast();
        log.info("Row insertion complete");
    }

    /**
     * Inserts rows with timestamps {@code [startFrom, startFrom + count)} sequentially in a single
     * transaction. All rows use realm/topic 0, null payloads and sequence number 0.
     *
     * @return the per-statement update counts (lazy; nothing runs until subscribed)
     */
    private Flux<Integer> insertRange(int startFrom, int count) {
        var messages = Flux.range(startFrom, count);
        return inTx(tx -> messages
                .concatMap(message -> tx.getT1().createStatement(INSERT_SQL)
                        .bind("$1", message)
                        .bind("$2", 0)
                        .bind("$3", 0)
                        .bindNull("$4", byte[].class)
                        .bindNull("$5", byte[].class)
                        .bind("$6", 0L)
                        .execute()
                        .flatMap(PostgresqlResult::getRowsUpdated)
                ));
    }

    @Test
    void test() {
        initDatabase();
        initRows();

        // Build every verifier first: verifyLater() subscribes immediately, so all queries run
        // concurrently before any of them is awaited below.
        var testCases = new ArrayList<TestCase>();
        for (int id = 1; id <= QUERY_COUNT; id++) {
            testCases.add(queryAndCancel(id));
        }

        var unverified = new ArrayList<TestCase>();
        for (var testCase : testCases) {
            try {
                log.info("Verifying test case: {}", testCase);
                testCase.stepVerifier.verify();
                Assertions.assertIterableEquals(
                        LongStream.rangeClosed(1, testCase.limit).boxed().collect(Collectors.toList()),
                        testCase.values);
            } catch (AssertionError | RuntimeException e) {
                // Keep going so every failing case is reported. Pass the throwable itself so
                // SLF4J logs the full stack trace, not just the message.
                log.warn("Error for {}", testCase, e);
                unverified.add(testCase);
            }
        }
        log.info("Unverified test cases: {}", unverified.stream().map(TestCase::toString).collect(Collectors.joining("\n")));
        Assertions.assertTrue(unverified.isEmpty(), "Unverified test cases: " + unverified);
    }

    /**
     * Starts, but does not await, a paging read of the whole table for query {@code id}.
     *
     * <p>Each subscription opens a new connection. It runs {@code SELECT <id>}, then selects up to
     * 1000 rows whose timestamp exceeds the number of rows received so far. Because the seeded
     * timestamps are consecutive from 0, that count equals the last timestamp seen. On completion
     * the stream re-subscribes after a ~1s backoff (±20% jitter). Once {@code limit} rows have
     * arrived, {@code limitRequest} cancels the repeating upstream, possibly before the final
     * query has signalled completion. Each row is delayed by 1 ms to simulate a slow consumer. The
     * stream errors if no row arrives for 5 seconds.
     *
     * @param id tag for the query, used in logs and in the initial {@code SELECT}
     * @return a test case whose {@link StepVerifier} is already subscribed
     */
    private TestCase queryAndCancel(int id) {
        int limit = MESSAGE_COUNT;
        var counter = new AtomicLong();
        var values = new ConcurrentLinkedQueue<Long>();
        var processId = new AtomicLong();
        var r2dbc = Flux.defer(() -> withConnection(connection -> {
            processId.set(this.getProcessId(connection));
            return connection.createStatement("SELECT " + id)
                    .execute()
                    .flatMap(PostgresqlResult::getRowsUpdated)
                    .thenMany(connection.createStatement(
                            "SELECT * FROM topic_message WHERE realm_num = 0 AND topic_num = 0 AND consensus_timestamp > $1 ORDER " +
                                    "BY consensus_timestamp LIMIT 1000")
                            // Read at connection time, i.e. once per (re-)subscription.
                            .bind("$1", counter.get())
                            .execute()
                            .flatMap(r -> r.map((row, meta) -> row.get("consensus_timestamp", Long.class))));
        }));
        var stepVerifier = r2dbc
                .doOnComplete(() -> log.debug("[{}, {}] Complete", processId.get(), id))
                .doOnSubscribe(s -> log.debug("[{}, {}] Executing query with {}/{} messages", processId.get(), id, counter.get(), limit))
                .repeatWhen(Repeat.times(Long.MAX_VALUE)
                        .fixedBackoff(Duration.ofSeconds(1))
                        .jitter(Jitter.random(0.2)))
                .limitRequest(limit)
                .doOnSubscribe(s -> log.info("[{}, {}] Executing query with limit {}", processId.get(), id, limit))
                .doOnCancel(() -> log.info("[{}, {}] Cancelled query with {}/{} messages", processId.get(), id, counter.get(), limit))
                .doOnComplete(() -> log.info("[{}, {}] Completed query with {}/{} messages", processId.get(), id, counter.get(), limit))
                .doOnNext(t -> log.trace("[{}, {}] onNext: {}", processId.get(), id, t))
                .doOnNext(t -> counter.incrementAndGet())
                .doOnNext(t -> Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS)) // Simulate client backpressure
                .doOnNext(values::add)
                .timeout(Duration.ofSeconds(5))
                .publishOn(Schedulers.boundedElastic())
                .subscribeOn(Schedulers.elastic())
                .doOnError(e -> log.error("[{}, {}] Error {}/{} messages", processId.get(), id, counter.get(), limit, e))
                .as(StepVerifier::create)
                .expectNextCount(limit)
                .expectComplete()
                .verifyLater();
        return new TestCase(stepVerifier, id, limit, counter, values);
    }

    // Reflective handles into r2dbc-postgresql internals, used only to log the backend process id.
    private static final Field clientField;
    private static final Field processIdField;

    static {
        try {
            clientField = Class.forName("io.r2dbc.postgresql.PostgresqlConnection").getDeclaredField("client");
            processIdField = ReactorNettyClient.class.getDeclaredField("processId");
            clientField.setAccessible(true);
            processIdField.setAccessible(true);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new IllegalStateException("Cannot access r2dbc-postgresql internals via reflection", e);
        }
    }

    /**
     * Reads the backend process id from the connection's private {@code client} field.
     * NOTE(review): this assumes the {@code processId} field holds a non-null {@code Integer};
     * if it is null, unboxing throws. Confirm against the r2dbc-postgresql version in use.
     */
    private int getProcessId(PostgresqlConnection connection) {
        try {
            var client = clientField.get(connection);
            return (Integer) processIdField.get(client);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /** One in-flight paging query: its verifier plus the state captured for post-checks and logs. */
    private static final class TestCase {
        private final StepVerifier stepVerifier;
        private final int id;
        private final int limit;
        private final AtomicLong count;
        private final Collection<Long> values;

        private TestCase(StepVerifier stepVerifier, int id, int limit, AtomicLong count, Collection<Long> values) {
            this.stepVerifier = stepVerifier;
            this.id = id;
            this.limit = limit;
            this.count = count;
            this.values = values;
        }

        // Deliberately omits `values`, which may hold tens of thousands of entries.
        @Override
        public String toString() {
            return "TestCase{" +
                    "stepVerifier=" + stepVerifier +
                    ", id=" + id +
                    ", limit=" + limit +
                    ", count=" + count +
                    '}';
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment