Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
package sample;
import org.junit.Test;
import rx.Observable;
import rx.Scheduler;
import rx.schedulers.Schedulers;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
* Example of selecting schedulers based on current event (which doesn't work).
public class StripedObserverTest {
public void multiple() throws Exception {
for (int i = 0; i < 128; i++) {
public void single() throws Exception {
final List<Scheduler> schedulers = Stream.generate(() -> Schedulers.from(Executors.newSingleThreadExecutor())).limit(10).collect(Collectors.toList());
// given a key selects a scheduler ("partition" or "shard")
final Function<String, Scheduler> fn = str -> schedulers.get(Math.abs(str.hashCode()) % schedulers.size());
final List<String> elements = Arrays.asList("one", "two", "three", "one", "two", "four");
final Observable<String> observable = Observable.from(elements);
final CountDownLatch latch = new CountDownLatch(elements.size());
final ConcurrentMap<String, Set<String>> state = new ConcurrentHashMap<>();
observable.groupBy(e -> e)
.flatMap(o -> o.subscribeOn(fn.apply(o.getKey())).map(i -> i))
.subscribe(e -> {
state.computeIfAbsent(e, ignore -> new CopyOnWriteArraySet<>()).add(Thread.currentThread().getName());
if (!latch.await(500, TimeUnit.MILLISECONDS)) {
throw new AssertionError("Timeout");
assertFalse("no subscriptions observed", state.isEmpty());
assertEquals("all elements processed", new HashSet<>(elements), state.keySet());
Set<String> threads = state.values().stream().flatMap(Set::stream).collect(Collectors.toSet());
assertTrue("Should be more than one thread " + threads, threads.size() > 1);
for (Map.Entry<String, Set<String>> entry: state.entrySet()) {
assertEquals(String.format("Multiple threads (%s) for key=%s\n\nCurrent state: %s", entry.getValue(), entry.getKey(), state),
1, entry.getValue().size());
public void example() throws Exception {
final List<Scheduler> schedulers = Stream.generate(() -> Schedulers.from(Executors.newSingleThreadExecutor())).limit(10).collect(Collectors.toList());
final Function<String, String> keyFn = s -> s;
// select scheduler for each element
final Function<String, Scheduler> schedulerFn = key -> schedulers.get(Math.abs(key.hashCode()) % schedulers.size());
Observable.just("one", "two", "three", "one", "two", "four")
.groupBy(i -> i) // this is value -> key function
.flatMap(g -> g.subscribeOn(schedulerFn.apply(g.getKey())))
.subscribe(e -> System.out.printf("key=%s value=%s thread=%s\n", e, e, Thread.currentThread().getName()));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment