Created
May 13, 2020 22:10
-
-
Save Squiry/3b88bffa0a502e771555d8317a589cc7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.example.r2dbc.example; | |
import io.r2dbc.postgresql.PostgresqlConnectionConfiguration; | |
import io.r2dbc.postgresql.PostgresqlConnectionFactory; | |
import io.r2dbc.postgresql.api.PostgresqlConnection; | |
import io.r2dbc.postgresql.api.PostgresqlResult; | |
import io.r2dbc.postgresql.client.ReactorNettyClient; | |
import io.r2dbc.spi.Connection; | |
import io.r2dbc.spi.Result; | |
import org.junit.jupiter.api.AfterAll; | |
import org.junit.jupiter.api.Assertions; | |
import org.junit.jupiter.api.BeforeAll; | |
import org.junit.jupiter.api.Test; | |
import org.reactivestreams.Publisher; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.testcontainers.containers.PostgreSQLContainer; | |
import org.testcontainers.shaded.com.google.common.util.concurrent.Uninterruptibles; | |
import reactor.core.publisher.Flux; | |
import reactor.core.publisher.Mono; | |
import reactor.core.scheduler.Schedulers; | |
import reactor.retry.Jitter; | |
import reactor.retry.Repeat; | |
import reactor.test.StepVerifier; | |
import reactor.util.function.Tuple2; | |
import reactor.util.function.Tuples; | |
import java.lang.reflect.Field; | |
import java.time.Duration; | |
import java.util.ArrayList; | |
import java.util.Collection; | |
import java.util.concurrent.ConcurrentLinkedQueue; | |
import java.util.concurrent.TimeUnit; | |
import java.util.concurrent.atomic.AtomicLong; | |
import java.util.function.Function; | |
import java.util.stream.Collectors; | |
import java.util.stream.LongStream; | |
class R2DbcExampleApplicationTests { | |
private static final Logger log = LoggerFactory.getLogger(R2DbcExampleApplicationTests.class); | |
private static final int MESSAGE_COUNT = 50_000; | |
private static final PostgreSQLContainer<?> postgresql = new PostgreSQLContainer<>("postgres:9.6-alpine"); | |
@BeforeAll | |
static void beforeAll() { | |
postgresql.start(); | |
} | |
@AfterAll | |
static void afterAll() { | |
postgresql.stop(); | |
} | |
private final PostgresqlConnectionFactory connectionFactory = new PostgresqlConnectionFactory(PostgresqlConnectionConfiguration.builder() | |
.applicationName("test") | |
.host(postgresql.getContainerIpAddress()) | |
.port(postgresql.getMappedPort(5432)) | |
.database(postgresql.getDatabaseName()) | |
.username(postgresql.getUsername()) | |
.password(postgresql.getPassword()) | |
.build() | |
); | |
private <T> Flux<T> inTx(Function<Tuple2<PostgresqlConnection, String>, Publisher<T>> txCallback) { | |
var connectionCallback = (Function<PostgresqlConnection, Flux<T>>) con -> { | |
var tx = Mono.from(con.beginTransaction()).thenReturn(Tuples.of(con, "tx")); | |
return Flux.usingWhen(tx, txCallback, txp -> con.commitTransaction(), | |
(txp, e) -> con.rollbackTransaction(), | |
txp -> con.rollbackTransaction()); | |
}; | |
return withConnection(connectionCallback); | |
} | |
private <T> Flux<T> withConnection(Function<PostgresqlConnection, Flux<T>> connectionCallback) { | |
return Flux.usingWhen(connectionFactory.create(), connectionCallback, Connection::close); | |
} | |
private void initDatabase() { | |
var createsql = "drop table if exists topic_message;\n" + | |
"create table if not exists topic_message\n" + | |
"(\n" + | |
" consensus_timestamp bigint primary key not null,\n" + | |
" realm_num int not null,\n" + | |
" topic_num int not null,\n" + | |
" message bytea null,\n" + | |
" running_hash bytea null,\n" + | |
" sequence_number bigint null\n" + | |
");\n" + | |
"\n" + | |
"create index if not exists topic_message__realm_num_timestamp\n" + | |
" on topic_message (realm_num, topic_num, consensus_timestamp);\n"; | |
Flux.usingWhen(connectionFactory.create(), con -> Flux.from(con.createStatement(createsql).execute()).flatMap(Result::getRowsUpdated), Connection::close).blockLast(); | |
log.info("Database init complete"); | |
} | |
private void initRows() { | |
var f = Flux.just(1); | |
int connections = 10; | |
for (int j = 0; j < connections; j++) { | |
var startFrom = j * (MESSAGE_COUNT / connections); | |
var messages = Flux.range(startFrom, j == connections - 1 ? MESSAGE_COUNT / connections + 1 : MESSAGE_COUNT / connections); | |
var insert = inTx(tx -> messages | |
.concatMap(message -> tx.getT1().createStatement("INSERT INTO topic_message(consensus_timestamp, realm_num, topic_num, message, running_hash, sequence_number) VALUES ($1, $2, $3, " + | |
"$4, $5, " + | |
"$6)") | |
.bind("$1", message) | |
.bind("$2", 0) | |
.bind("$3", 0) | |
.bindNull("$4", byte[].class) | |
.bindNull("$5", byte[].class) | |
.bind("$6", 0L) | |
.execute() | |
.flatMap(PostgresqlResult::getRowsUpdated) | |
)); | |
f = f.mergeWith(insert); | |
} | |
f.blockLast(); | |
log.info("Row insertion complete"); | |
} | |
@Test | |
void test() { | |
initDatabase(); | |
initRows(); | |
var testCases = new ArrayList<TestCase>(); | |
var unverified = new ArrayList<TestCase>(); | |
for (int i = 1; i < 10; i++) { | |
testCases.add(queryAndCancel(i)); | |
} | |
for (var testCase : testCases) { | |
try { | |
log.info("Verifying test case: {}", testCase); | |
testCase.stepVerifier.verify(); | |
Assertions.assertIterableEquals(LongStream.range(1, testCase.limit + 1).boxed().collect(Collectors.toList()), testCase.values); | |
} catch (Throwable e) { | |
log.warn("Error for {}: {}", testCase, e.getMessage()); | |
unverified.add(testCase); | |
} | |
} | |
log.info("Unverified test cases: {}", unverified.stream().map(TestCase::toString).collect(Collectors.joining("\n"))); | |
Assertions.assertTrue(unverified.isEmpty()); | |
} | |
private TestCase queryAndCancel(int id) { | |
int limit = MESSAGE_COUNT;//ThreadLocalRandom.current().nextInt(MESSAGE_COUNT); | |
var counter = new AtomicLong(); | |
var values = new ConcurrentLinkedQueue<Long>(); | |
var processId = new AtomicLong(); | |
var r2dbc = Flux.defer(() -> withConnection(connection -> { | |
processId.set(this.getProcessId(connection)); | |
return connection.createStatement("SELECT " + id) | |
.execute() | |
.flatMap(PostgresqlResult::getRowsUpdated) | |
.thenMany(connection.createStatement( | |
"SELECT * FROM topic_message WHERE realm_num = 0 AND topic_num = 0 AND consensus_timestamp > $1 ORDER " + | |
"BY consensus_timestamp LIMIT 1000") | |
.bind("$1", counter.get()) | |
.execute() | |
.flatMap(r -> r.map((row, meta) -> row.get("consensus_timestamp", Long.class)))); | |
})); | |
var stepVerifier = r2dbc | |
.doOnComplete(() -> log.debug("[{}, {}] Complete", processId.get(), id)) | |
.doOnSubscribe(s -> log.debug("[{}, {}] Executing query with {}/{} messages", processId.get(), id, counter, limit)) | |
.repeatWhen(Repeat.times(Long.MAX_VALUE) | |
.fixedBackoff(Duration.ofSeconds(1)) | |
.jitter(Jitter.random(0.2))) | |
.as(t -> t.limitRequest(limit)) | |
.doOnSubscribe(s -> log.info("[{}, {}] Executing query with limit {}", processId.get(), id, limit)) | |
.doOnCancel(() -> log.info("[{}, {}] Cancelled query with {}/{} messages", processId.get(), id, counter.get(), limit)) | |
.doOnComplete(() -> log.info("[{}, {}] Completed query with {}/{} messages", processId.get(), id, counter.get(), limit)) | |
.doOnNext(t -> log.trace("[{}, {}] onNext: {}", processId.get(), id, t)) | |
.doOnNext(t -> counter.incrementAndGet()) | |
.doOnNext(t -> Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS)) // Simulate client backpressure | |
.doOnNext(values::add) | |
.timeout(Duration.ofSeconds(5)) | |
.publishOn(Schedulers.boundedElastic()) | |
.subscribeOn(Schedulers.elastic()) | |
.doOnError(e -> log.error("[{}, {}] Error {}/{} messages", processId.get(), id, counter.get(), limit, e)) | |
.as(StepVerifier::create) | |
.expectNextCount(limit) | |
.expectComplete() | |
.verifyLater(); | |
return new TestCase(stepVerifier, id, limit, counter, values); | |
} | |
private static final Field clientField; | |
private static final Field processIdField; | |
static { | |
try { | |
clientField = Class.forName("io.r2dbc.postgresql.PostgresqlConnection").getDeclaredField("client"); | |
processIdField = ReactorNettyClient.class.getDeclaredField("processId"); | |
clientField.setAccessible(true); | |
processIdField.setAccessible(true); | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
} | |
private int getProcessId(PostgresqlConnection connection) { | |
try { | |
var client = clientField.get(connection); | |
return (Integer) processIdField.get(client); | |
} catch (IllegalAccessException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
private static class TestCase { | |
private final StepVerifier stepVerifier; | |
private final int id; | |
private final int limit; | |
private final AtomicLong count; | |
private final Collection<Long> values; | |
private TestCase(StepVerifier stepVerifier, int id, int limit, AtomicLong count, Collection<Long> values) { | |
this.stepVerifier = stepVerifier; | |
this.id = id; | |
this.limit = limit; | |
this.count = count; | |
this.values = values; | |
} | |
@Override | |
public String toString() { | |
return "TestCase{" + | |
"stepVerifier=" + stepVerifier + | |
", id=" + id + | |
", limit=" + limit + | |
", count=" + count + | |
'}'; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment