package sample;

import org.junit.Test;
import rx.Observable;
import rx.Scheduler;
import rx.schedulers.Schedulers;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/** | |
* Example of selecting schedulers based on current event (which doesn't work). | |
**/ | |
public class StripedObserverTest { | |
@Test | |
public void multiple() throws Exception { | |
for (int i = 0; i < 128; i++) { | |
single(); | |
} | |
} | |
@Test | |
public void single() throws Exception { | |
final List<Scheduler> schedulers = Stream.generate(() -> Schedulers.from(Executors.newSingleThreadExecutor())).limit(10).collect(Collectors.toList()); | |
// given a key selects a scheduler ("partition" or "shard") | |
final Function<String, Scheduler> fn = str -> schedulers.get(Math.abs(str.hashCode()) % schedulers.size()); | |
final List<String> elements = Arrays.asList("one", "two", "three", "one", "two", "four"); | |
final Observable<String> observable = Observable.from(elements); | |
final CountDownLatch latch = new CountDownLatch(elements.size()); | |
final ConcurrentMap<String, Set<String>> state = new ConcurrentHashMap<>(); | |
observable.groupBy(e -> e) | |
.flatMap(o -> o.subscribeOn(fn.apply(o.getKey())).map(i -> i)) | |
.subscribe(e -> { | |
state.computeIfAbsent(e, ignore -> new CopyOnWriteArraySet<>()).add(Thread.currentThread().getName()); | |
latch.countDown(); | |
}); | |
if (!latch.await(500, TimeUnit.MILLISECONDS)) { | |
throw new AssertionError("Timeout"); | |
} | |
assertFalse("no subscriptions observed", state.isEmpty()); | |
assertEquals("all elements processed", new HashSet<>(elements), state.keySet()); | |
Set<String> threads = state.values().stream().flatMap(Set::stream).collect(Collectors.toSet()); | |
assertTrue("Should be more than one thread " + threads, threads.size() > 1); | |
for (Map.Entry<String, Set<String>> entry: state.entrySet()) { | |
assertEquals(String.format("Multiple threads (%s) for key=%s\n\nCurrent state: %s", entry.getValue(), entry.getKey(), state), | |
1, entry.getValue().size()); | |
} | |
} | |
@Test | |
public void example() throws Exception { | |
final List<Scheduler> schedulers = Stream.generate(() -> Schedulers.from(Executors.newSingleThreadExecutor())).limit(10).collect(Collectors.toList()); | |
final Function<String, String> keyFn = s -> s; | |
// select scheduler for each element | |
final Function<String, Scheduler> schedulerFn = key -> schedulers.get(Math.abs(key.hashCode()) % schedulers.size()); | |
Observable.just("one", "two", "three", "one", "two", "four") | |
.groupBy(i -> i) // this is value -> key function | |
.flatMap(g -> g.subscribeOn(schedulerFn.apply(g.getKey()))) | |
.subscribe(e -> System.out.printf("key=%s value=%s thread=%s\n", e, e, Thread.currentThread().getName())); | |
Thread.sleep(1_000); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment