Skip to content

Instantly share code, notes, and snippets.

@bamthomas
Created January 13, 2020 11:53
Show Gist options
  • Save bamthomas/069f563b2d5216ee9ee016d3f8443d8b to your computer and use it in GitHub Desktop.
Save bamthomas/069f563b2d5216ee9ee016d3f8443d8b to your computer and use it in GitHub Desktop.
A tiny multithreaded memory databus
package org.icij.datashare;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.ofNullable;
/**
 * A tiny in-memory, multithreaded data bus.
 *
 * <p>Messages are published on a {@link Channel} and delivered synchronously, on the publishing
 * thread, to every subscriber registered for that channel. Each {@link #subscribe} call blocks its
 * calling thread until that subscriber is shut down, either by a shutdown message published on one
 * of its channels or by {@link #unsubscribe}.
 *
 * @param <T> the message type
 */
public class MemoryDataBus<T> {
    /** The channels messages can be published on. */
    public enum Channel {CHANNEL1, CHANNEL2}

    /** Active listeners, keyed by the subscriber instance they wrap (via its equals/hashCode). */
    private final Map<Consumer<T>, Listener<T>> subscribers = new ConcurrentHashMap<>();

    /**
     * Publishes {@code message} to every subscriber registered on {@code channel}.
     *
     * <p>Delivery happens on the calling thread. If a subscriber throws, the exception propagates to
     * the caller and subscribers not yet reached are skipped.
     *
     * @param channel the channel to publish on
     * @param message the message to deliver
     * @throws NullPointerException if {@code message} is null
     */
    public void publish(final Channel channel, final T message) {
        T nonNullMessage = requireNonNull(message, "cannot publish a null message");
        subscribers.values().stream()
                .filter(l -> l.hasSubscribedTo(channel))
                .forEach(l -> l.accept(nonNullMessage));
    }

    /**
     * Registers {@code subscriber} on {@code channels} and blocks until it is shut down.
     *
     * <p>The subscriber is registered before {@code subscriptionCallback} runs, so messages published
     * once the callback has run are delivered. The call returns when the subscriber receives a
     * message equal to {@code shutdownSupplier.get()} on one of its channels, or when
     * {@link #unsubscribe} is called for it. The subscriber is deregistered when this method returns
     * or throws, so it stops receiving messages once nobody is waiting on it.
     *
     * <p>NOTE: subscribing the same subscriber instance twice concurrently replaces the first
     * registration in the map.
     *
     * @param subscriber           receives every message published on {@code channels}
     * @param subscriptionCallback run once the subscriber is registered
     * @param shutdownSupplier     supplies the message that ends the subscription
     * @param channels             the channels to listen on
     * @return the number of messages delivered to the subscriber, including the shutdown message
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public int subscribe(final Consumer<T> subscriber, final Runnable subscriptionCallback, Supplier<T> shutdownSupplier, final Channel... channels) throws InterruptedException {
        Listener<T> listener = new Listener<>(subscriber, shutdownSupplier, channels);
        subscribers.put(subscriber, listener);
        try {
            subscriptionCallback.run();
            return listener.loopUntilShutdown();
        } finally {
            // Two-arg remove: never drop a newer registration of the same subscriber.
            subscribers.remove(subscriber, listener);
        }
    }

    /**
     * Deregisters {@code subscriber} and wakes up its {@link #subscribe} call.
     *
     * <p>The subscriber receives the value of its shutdown supplier as a final message. Does nothing
     * if the subscriber is not registered.
     *
     * @param subscriber the subscriber to remove
     */
    public void unsubscribe(final Consumer<T> subscriber) {
        ofNullable(subscribers.remove(subscriber)).ifPresent(Listener::shutdown);
    }

    /** Wraps one subscriber: filters channels, counts deliveries, and signals shutdown. */
    private static class Listener<T> implements Consumer<T> {
        private final Consumer<T> subscriber;
        private final Supplier<T> shutdownSupplier;
        private final LinkedHashSet<Channel> channels;
        private final Object lock = new Object();
        /** Messages delivered so far; guarded by {@code lock}. */
        private int nbMessage;
        /**
         * Latched once a shutdown message arrives; guarded by {@code lock}. A flag, unlike a
         * "latest message" slot, cannot be overwritten by a message published right after shutdown.
         */
        private boolean shutdownAsked;

        public Listener(Consumer<T> subscriber, Supplier<T> shutdownSupplier, Channel[] channels) {
            this.subscriber = subscriber;
            this.shutdownSupplier = shutdownSupplier;
            this.channels = new LinkedHashSet<>(asList(channels));
        }

        boolean hasSubscribedTo(Channel channel) {
            return channels.contains(channel);
        }

        /** Delivers a published message; it ends the subscription if it equals the shutdown message. */
        @Override
        public void accept(T message) {
            deliver(message, Objects.equals(shutdownSupplier.get(), message));
        }

        /** Delivers the shutdown message and always ends the subscription, even if it is null. */
        void shutdown() {
            deliver(shutdownSupplier.get(), true);
        }

        /** Forwards to the subscriber (outside the lock), then records the delivery atomically. */
        private void deliver(T message, boolean isShutdown) {
            subscriber.accept(message);
            synchronized (lock) {
                // Counted before waking the waiter, so the shutdown message is always in the total.
                nbMessage++;
                if (isShutdown) {
                    shutdownAsked = true;
                    lock.notifyAll();
                }
            }
        }

        /** Blocks until shutdown, then returns the number of messages delivered. */
        int loopUntilShutdown() throws InterruptedException {
            synchronized (lock) {
                // Loop guards against spurious wake-ups.
                while (!shutdownAsked) {
                    lock.wait();
                }
                return nbMessage;
            }
        }
    }

    /** Demo: two subscribers on overlapping channels, shut down by publishing "" on CHANNEL1. */
    public static void main(String[] args) throws Exception {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            MemoryDataBus<String> dataBus = new MemoryDataBus<>();
            CountDownLatch subscribed = new CountDownLatch(2);
            Future<Integer> future1 = executor.submit(() -> dataBus.subscribe(System.out::println, subscribed::countDown, () -> "", Channel.CHANNEL1));
            Future<Integer> future2 = executor.submit(() -> dataBus.subscribe(System.out::println, subscribed::countDown, () -> "", Channel.CHANNEL1, Channel.CHANNEL2));
            // Messages published before a subscriber is registered are never delivered to it.
            if (!subscribed.await(1, TimeUnit.SECONDS)) {
                throw new IllegalStateException("subscribers were not registered within 1 second");
            }
            dataBus.publish(Channel.CHANNEL1, "this is a message for channel 1");
            dataBus.publish(Channel.CHANNEL2, "this is a message for channel 2");
            dataBus.publish(Channel.CHANNEL1, "");
            // Timed gets, so a stuck subscriber fails the demo instead of hanging it.
            System.out.println("subscriber 1 received " + future1.get(1, TimeUnit.SECONDS) + " message(s)") ;
            System.out.println("subscriber 2 received " + future2.get(1, TimeUnit.SECONDS) + " message(s)") ;
        } finally {
            executor.shutdownNow();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment