Skip to content

Instantly share code, notes, and snippets.

@electroCutie
Last active June 16, 2020 11:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save electroCutie/3ee8d2c44637757bff5d227df402afb2 to your computer and use it in GitHub Desktop.
Save electroCutie/3ee8d2c44637757bff5d227df402afb2 to your computer and use it in GitHub Desktop.
package com.macvalves.util.concurrent;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import java.util.Collection;
import java.util.List;
import java.util.Spliterators;
import java.util.concurrent.CyclicBarrier;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import com.google.common.collect.AbstractIterator;
import com.macvalves.util.Exceptions;
/**
 * Hands out a fixed number of streams that walk a shared list of seeds in lock step.
 *
 * <p>Every stream returned by {@link #takeStream()} is backed by an iterator that waits on a
 * shared {@link CyclicBarrier} before each step. When all parties have arrived, the barrier
 * action moves on to the next seed and expands it once with the expansion function; each
 * stream then emits the elements of that same expanded list. Once the seeds are used up, the
 * barrier action publishes {@code null} and every stream ends.
 *
 * <p>Because each step waits for all parties, a stream only makes progress while the other
 * streams are being consumed as well.
 *
 * <p>NOTE(review): a {@code null} returned by the expansion function is indistinguishable from
 * the end-of-data marker and ends all streams early.
 *
 * @param <S> seed type
 * @param <T> element type produced by expanding a seed
 */
public class StagedStream<S, T> {
    // Both fields below are written only by the barrier action and read by the iterators after
    // await() returns; the barrier orders those accesses, so they are not volatile.

    /** Index of the seed whose expansion is currently published; -1 before the first round. */
    private int idx = -1;

    /** Expansion of the current seed, or null once every seed has been consumed. */
    private List<T> currentTs = null;

    private final CyclicBarrier barrier;
    private final List<S> seeds;
    private final Function<S, List<T>> expansionFunction;

    /** Number of streams handed out so far; not synchronized. */
    private int takenStreams = 0;

    /**
     * @param streams number of streams (barrier parties) that will advance together; must be
     *        positive
     * @param seeds seeds to expand, one per round, in list order
     * @param expansion turns one seed into the list of elements every stream emits for that round
     * @throws NullPointerException if {@code seeds} or {@code expansion} is null
     * @throws IllegalArgumentException if {@code streams} is not positive
     */
    public StagedStream(int streams, List<S> seeds, Function<S, List<T>> expansion) {
        this.seeds = requireNonNull(seeds);
        this.expansionFunction = requireNonNull(expansion);
        checkArgument(streams > 0);
        this.barrier = new CyclicBarrier(streams, () -> {
            // Advance first and bounds-check the advanced index. Checking the old index before
            // incrementing reads one element past the end of the list on the final round.
            // The index is capped at seeds.size() so it cannot keep growing after exhaustion.
            if (idx < seeds.size()) {
                idx++;
            }
            if (idx < seeds.size()) {
                currentTs = expansionFunction.apply(seeds.get(idx));
            } else {
                currentTs = null;
            }
        });
    }

    /**
     * Returns one of the lock-stepped streams. May be called at most as many times as the
     * number of streams given to the constructor.
     *
     * @return a sequential stream over the elements of every expanded seed, in seed order
     * @throws IllegalStateException if all streams have already been taken
     */
    public Stream<T> takeStream() {
        checkState(takenStreams < barrier.getParties(), "Too many streams taken");
        takenStreams++;
        // The iterator yields exactly one list per seed, hence the size estimate.
        Stream<List<T>> rounds = StreamSupport.stream(
                Spliterators.spliterator(new ItrWithAcquire(), seeds.size(), 0), false);
        return rounds.flatMap(Collection::stream);
    }

    /** Iterator that waits on the shared barrier, then yields the list published for the round. */
    private class ItrWithAcquire extends AbstractIterator<List<T>> {
        @Override
        protected List<T> computeNext() {
            // Blocks until every party has arrived and the barrier action has run.
            // NOTE(review): presumably runR rethrows await()'s checked exceptions unchecked -
            // confirm against Exceptions.runR.
            Exceptions.runR(barrier::await);
            if (null == currentTs) {
                assert idx >= seeds.size();
                return this.endOfData();
            }
            assert idx < seeds.size();
            return currentTs;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment