Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@tnine
Forked from akarnokd/OrderedMerge.java
Last active August 1, 2016 14:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save tnine/11224341 to your computer and use it in GitHub Desktop.
Save tnine/11224341 to your computer and use it in GitHub Desktop.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.usergrid.persistence.graph.serialization.impl.parse;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Subscriber;
import rx.Subscription;
import rx.observers.SerializedSubscriber;
import rx.subscriptions.CompositeSubscription;
/**
 * Produces a single Observable from multiple ordered source observables. The same as the "merge" step in a merge
 * sort.
 *
 * On every step the head element of each source is compared and the one the comparator ranks greatest is emitted,
 * so every source must already emit its elements from greatest to least according to the supplied comparator.
 * Ensure that your comparator matches the ordering of your inputs, or you may get strange results.
 *
 * Each source is buffered in its own queue of at most {@code maxBufferSize} elements; when a source would exceed
 * that bound the merged Observable terminates with an error.
 *
 * NOTE(review): a {@code null} head element is interpreted as "this source has nothing buffered yet", so sources
 * are presumably expected never to emit {@code null} - confirm against callers.
 */
public final class OrderedMerge<T> implements Observable.OnSubscribe<T> {

    private static final Logger log = LoggerFactory.getLogger( OrderedMerge.class );

    //the comparator to compare items
    private final Comparator<T> comparator;

    //the ordered sources to merge
    private final Observable<? extends T>[] observables;

    //The max amount to buffer (per source) before blowing up
    private final int maxBufferSize;


    private OrderedMerge( final Comparator<T> comparator, final int maxBufferSize,
                          Observable<? extends T>... observables ) {
        this.comparator = comparator;
        this.maxBufferSize = maxBufferSize;
        this.observables = observables;
    }


    /**
     * Subscribes to every source observable and routes their elements through a shared coordinator that emits
     * them, in order, to the supplied subscriber.
     *
     * @param outerOperation the downstream subscriber receiving the merged sequence
     */
    @Override
    public void call( final Subscriber<? super T> outerOperation ) {

        final CompositeSubscription csub = new CompositeSubscription();

        //register the composite exactly once, and before any source is subscribed, so that an unsubscribe from
        //downstream reaches every source subscription that gets added to it
        outerOperation.add( csub );

        //merging nothing yields an empty sequence; without this no inner observer would ever signal completion
        if ( observables.length == 0 ) {
            outerOperation.onCompleted();
            return;
        }

        final SubscriberCoordinator<T> coordinator = new SubscriberCoordinator<T>( comparator, outerOperation );

        final List<InnerObserver<T>> innerObservers = new ArrayList<InnerObserver<T>>( observables.length );

        //we have to do this in 2 steps to get the synchronization correct. We must set up our total inner observers
        //before starting subscriptions otherwise our assertions for completion or starting won't work properly
        for ( int i = 0; i < observables.length; i++ ) {
            final InnerObserver<T> inner = new InnerObserver<T>( coordinator, maxBufferSize );
            coordinator.add( inner );
            innerObservers.add( inner );
        }

        //Once we're set up, begin the subscription to sub observables
        for ( int i = 0; i < observables.length; i++ ) {

            //downstream is gone (for instance after an error from an earlier, synchronous source), there is no
            //point in starting the remaining sources
            if ( outerOperation.isUnsubscribed() ) {
                break;
            }

            //add our subscription to the composite for future cancellation
            csub.add( observables[i].subscribe( innerObservers.get( i ) ) );
        }
    }


    /**
     * Our coordinator. It collects the inner observers of all sources, decides when enough data is buffered to
     * emit, and emits the buffered elements in order to the downstream subscriber.
     *
     * The coordinator's monitor guards the queues and the drained state of all inner observers, as well as the
     * emission loop itself.
     */
    private static final class SubscriberCoordinator<T> {

        private final AtomicInteger completedCount = new AtomicInteger();
        private volatile boolean readyToProduce = false;
        private final Comparator<T> comparator;
        private final Subscriber<? super T> subscriber;
        private final List<InnerObserver<T>> innerSubscribers;


        private SubscriberCoordinator( final Comparator<T> comparator, final Subscriber<? super T> subscriber ) {
            //we only want to emit events serially
            this.subscriber = new SerializedSubscriber<T>( subscriber );
            this.innerSubscribers = new ArrayList<InnerObserver<T>>();
            this.comparator = comparator;
        }


        /**
         * Invoked by an inner observer when its source has completed. Once every source has completed, the
         * remaining buffered elements are emitted and completion is signalled downstream.
         */
        public void onCompleted() {

            final int completed = completedCount.incrementAndGet();

            //we're done, just drain the queue since there are no more running producers
            if ( completed == innerSubscribers.size() ) {

                log.trace( "Completing Observable. Draining elements from the {} subscribers",
                        innerSubscribers.size() );

                //Drain the queues
                while ( !subscriber.isUnsubscribed() && !drained() ) {
                    next();
                }

                //signal completion
                subscriber.onCompleted();
            }
        }


        /**
         * Registers an inner observer. Must be called for every source before any source is subscribed.
         */
        public void add( InnerObserver<T> inner ) {
            this.innerSubscribers.add( inner );
        }


        /**
         * Forwards an error to the downstream subscriber.
         */
        public void onError( Throwable e ) {
            subscriber.onError( e );
        }


        /**
         * Emits as many buffered elements as can currently be ordered. Emission stops as soon as a source that has
         * not been drained has nothing buffered, because its next element might have to be emitted first.
         */
        public void next() {

            //nothing to do, we haven't started emitting values yet
            if ( !readyToProduce() ) {
                return;
            }

            //we want to emit items in order, so we synchronize our next
            synchronized ( this ) {

                //take as many elements as we can until we hit the completed case
                while ( true ) {

                    //nobody is listening any more, leave the remaining elements alone
                    if ( subscriber.isUnsubscribed() ) {
                        return;
                    }

                    InnerObserver<T> maxObserver = null;

                    for ( InnerObserver<T> inner : innerSubscribers ) {

                        //this inner is done, skip it
                        if ( inner.drained ) {
                            continue;
                        }

                        final T current = inner.peek();

                        /**
                         * Our current is null but we're not drained (I.E we haven't finished and completed
                         * consuming). This means the producer is slow, and we don't have a complete set to
                         * compare, we can't produce
                         */
                        if ( current == null ) {
                            return;
                        }

                        if ( maxObserver == null || comparator.compare( current, maxObserver.peek() ) > 0 ) {
                            maxObserver = inner;
                        }
                    }

                    //No max observer was ever assigned, meaning all our inners are drained, break from loop
                    if ( maxObserver == null ) {
                        return;
                    }

                    subscriber.onNext( maxObserver.pop() );
                }
            }
        }


        /**
         * Return true if we're ready to produce, i.e. every source has emitted an element or completed.
         */
        private boolean readyToProduce() {

            if ( readyToProduce ) {
                return true;
            }

            //perform an audit
            for ( InnerObserver<T> inner : innerSubscribers ) {
                if ( !inner.started ) {
                    return false;
                }
            }

            //every source has started: remember it and allow the current caller to emit right away instead of
            //stalling the output until another event arrives
            readyToProduce = true;

            return true;
        }


        /**
         * Return true if every inner observer has been drained
         */
        private boolean drained() {

            //perform an audit
            for ( InnerObserver<T> inner : innerSubscribers ) {
                if ( !inner.drained ) {
                    return false;
                }
            }

            return true;
        }
    }


    /**
     * Buffers the elements of a single source until the coordinator is able to emit them.
     *
     * The queue and the drained state are only touched while holding the coordinator's monitor, because elements
     * are added on the thread of the source while they may be removed on the thread of any other source.
     */
    private static final class InnerObserver<T> extends Subscriber<T> {

        private final SubscriberCoordinator<T> coordinator;
        private final Deque<T> items = new LinkedList<T>();
        private final int maxQueueSize;

        /**
         * Flags for synchronization with coordinator. Multiple threads may be used, so volatile is required;
         * the coordinator reads them without holding its monitor
         */
        private volatile boolean started = false;
        private volatile boolean completed = false;
        private volatile boolean drained = false;


        public InnerObserver( final SubscriberCoordinator<T> coordinator, final int maxQueueSize ) {
            this.coordinator = coordinator;
            this.maxQueueSize = maxQueueSize;
        }


        @Override
        public void onCompleted() {

            //an empty source counts as started, otherwise the coordinator would wait for it forever
            synchronized ( coordinator ) {
                started = true;
                completed = true;
                checkDrained();
            }

            coordinator.onCompleted();
        }


        @Override
        public void onError( Throwable e ) {
            coordinator.onError( e );
        }


        @Override
        public void onNext( T a ) {

            log.trace( "Received {}", a );

            final boolean overflow;

            synchronized ( coordinator ) {
                overflow = items.size() >= maxQueueSize;

                if ( !overflow ) {
                    items.add( a );
                    started = true;
                }
            }

            if ( overflow ) {
                //the merge has failed: stop this source and do not buffer or emit anything further
                unsubscribe();
                onError( new RuntimeException( "The maximum queue size of " + maxQueueSize + " has been reached" ) );
                return;
            }

            //for each subscriber, emit to the parent wrapper then evaluate calling on next
            coordinator.next();
        }


        /**
         * Returns the oldest buffered element without removing it, or null if nothing is buffered. The caller
         * must hold the coordinator's monitor.
         */
        public T peek() {
            return items.peekFirst();
        }


        /**
         * Removes and returns the oldest buffered element, or null if nothing is buffered. The caller must hold
         * the coordinator's monitor.
         */
        public T pop() {
            T item = items.pollFirst();
            checkDrained();
            return item;
        }


        /**
         * if we've started and finished, and this is the last element, we want to mark ourselves as completely
         * drained. The caller must hold the coordinator's monitor.
         */
        private void checkDrained() {
            drained = started && completed && items.isEmpty();
        }
    }


    /**
     * Create our ordered merge
     *
     * @param comparator ranks the elements; the element ranked greatest among the heads of all sources is
     * emitted first
     * @param maxBufferSize the maximum number of elements buffered per source before the merged Observable fails
     * @param observables the sources to merge, each emitting from greatest to least according to the comparator
     *
     * @return an Observable emitting the elements of all sources in comparator order
     */
    public static <T> Observable<T> orderedMerge( Comparator<T> comparator, int maxBufferSize,
                                                  Observable<? extends T>... observables ) {
        return Observable.create( new OrderedMerge<T>( comparator, maxBufferSize, observables ) );
    }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.usergrid.persistence.graph.serialization.impl.parse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Subscriber;
import rx.schedulers.Schedulers;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class OrderedMergeTest {
private static final Logger log = LoggerFactory.getLogger( OrderedMergeTest.class );
@Test
public void singleOperator() throws InterruptedException {
List<Integer> expected = Arrays.asList( 0, 1, 2, 3, 4, 5 );
Observable<Integer> ints = Observable.from( expected );
Observable<Integer> ordered = OrderedMerge.orderedMerge( new IntegerComparator(), 10, ints );
final CountDownLatch latch = new CountDownLatch( 1 );
final List<Integer> results = new ArrayList();
ordered.subscribe( new Subscriber<Integer>() {
@Override
public void onCompleted() {
latch.countDown();
}
@Override
public void onError( final Throwable e ) {
e.printStackTrace();
fail( "An error was thrown " );
}
@Override
public void onNext( final Integer integer ) {
log.info( "onNext invoked with {}", integer );
results.add( integer );
}
} );
latch.await();
assertEquals( expected.size(), results.size() );
for ( int i = 0; i < expected.size(); i++ ) {
assertEquals( "Same element expected", expected.get( i ), results.get( i ) );
}
}
@Test
public void multipleOperatorSameThread() throws InterruptedException {
List<Integer> expected1List = Arrays.asList( 5, 3, 2, 0 );
Observable<Integer> expected1 = Observable.from( expected1List );
List<Integer> expected2List = Arrays.asList( 10, 7, 6, 4 );
Observable<Integer> expected2 = Observable.from( expected2List );
List<Integer> expected3List = Arrays.asList( 9, 8, 1 );
Observable<Integer> expected3 = Observable.from( expected3List );
Observable<Integer> ordered =
OrderedMerge.orderedMerge( new IntegerComparator(), 10, expected1, expected2, expected3 );
final CountDownLatch latch = new CountDownLatch( 1 );
final List<Integer> results = new ArrayList();
ordered.subscribe( new Subscriber<Integer>() {
@Override
public void onCompleted() {
latch.countDown();
}
@Override
public void onError( final Throwable e ) {
e.printStackTrace();
fail( "An error was thrown " );
}
@Override
public void onNext( final Integer integer ) {
log.info( "onNext invoked with {}", integer );
results.add( integer );
}
} );
latch.await();
List<Integer> expected = Arrays.asList( 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 );
assertEquals( expected.size(), results.size() );
for ( int i = 0; i < expected.size(); i++ ) {
assertEquals( "Same element expected", expected.get( i ), results.get( i ) );
}
}
@Test
public void multipleOperatorSingleThreadSizeException() throws InterruptedException {
List<Integer> expected1List = Arrays.asList( 5, 3, 2, 0 );
Observable<Integer> expected1 = Observable.from( expected1List );
List<Integer> expected2List = Arrays.asList( 10, 7, 6, 4 );
Observable<Integer> expected2 = Observable.from( expected2List );
List<Integer> expected3List = Arrays.asList( 9, 8, 1 );
Observable<Integer> expected3 = Observable.from( expected3List );
//set our buffer size to 2. We should easily exceed this since every observable has more than 2 elements
Observable<Integer> ordered =
OrderedMerge.orderedMerge( new IntegerComparator(), 2, expected1, expected2, expected3 );
final CountDownLatch latch = new CountDownLatch( 1 );
final List<Integer> results = new ArrayList();
final boolean[] errorThrown = new boolean[1];
ordered.subscribe( new Subscriber<Integer>() {
@Override
public void onCompleted() {
latch.countDown();
}
@Override
public void onError( final Throwable e ) {
log.error( "Expected error thrown", e );
if(e.getMessage().contains( "The maximum queue size of 2 has been reached" )){
errorThrown[0] = true;
}
latch.countDown();
}
@Override
public void onNext( final Integer integer ) {
log.info( "onNext invoked with {}", integer );
results.add( integer );
}
} );
latch.await();
/**
* Since we're on the same thread, we should blow up before we begin producing elements our size
*/
assertEquals( 0, results.size() );
assertTrue("An exception was thrown", errorThrown[0]);
}
@Test
public void multipleOperatorThreads() throws InterruptedException {
List<Integer> expected1List = Arrays.asList( 5, 3, 2, 0 );
Observable<Integer> expected1 = Observable.from( expected1List ).subscribeOn( Schedulers.io() );
List<Integer> expected2List = Arrays.asList( 10, 7, 6, 4 );
Observable<Integer> expected2 = Observable.from( expected2List ).subscribeOn( Schedulers.io() );
List<Integer> expected3List = Arrays.asList( 9, 8, 1 );
Observable<Integer> expected3 = Observable.from( expected3List ).subscribeOn( Schedulers.io() );
Observable<Integer> ordered =
OrderedMerge.orderedMerge( new IntegerComparator(), 10, expected1, expected2, expected3 );
final CountDownLatch latch = new CountDownLatch( 1 );
final List<Integer> results = new ArrayList();
ordered.subscribe( new Subscriber<Integer>() {
@Override
public void onCompleted() {
latch.countDown();
}
@Override
public void onError( final Throwable e ) {
e.printStackTrace();
fail( "An error was thrown " );
}
@Override
public void onNext( final Integer integer ) {
log.info( "onNext invoked with {}", integer );
results.add( integer );
}
} );
latch.await();
List<Integer> expected = Arrays.asList( 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 );
assertEquals( expected.size(), results.size() );
for ( int i = 0; i < expected.size(); i++ ) {
assertEquals( "Same element expected", expected.get( i ), results.get( i ) );
}
}
@Test
public void multipleOperatorMultipleThreadSizeException() throws InterruptedException {
List<Integer> expected1List = Arrays.asList( 10, 4, 3, 2, 1 );
Observable<Integer> expected1 = Observable.from( expected1List ).subscribeOn( Schedulers.io() );
List<Integer> expected2List = Arrays.asList( 9, 8, 7 );
Observable<Integer> expected2 = Observable.from( expected2List ).subscribeOn( Schedulers.io() );
List<Integer> expected3List = Arrays.asList( 6, 5, 0 );
Observable<Integer> expected3 = Observable.from( expected3List ).subscribeOn( Schedulers.io() );
/**
* Fails because our first observable will have to buffer the last 4 elements while waiting for the others to proceed
*/
Observable<Integer> ordered =
OrderedMerge.orderedMerge( new IntegerComparator(), 2, expected1, expected2, expected3 );
final CountDownLatch latch = new CountDownLatch( 1 );
final boolean[] errorThrown = new boolean[1];
ordered.subscribe( new Subscriber<Integer>() {
@Override
public void onCompleted() {
latch.countDown();
}
@Override
public void onError( final Throwable e ) {
log.error( "Expected error thrown", e );
if(e.getMessage().contains( "The maximum queue size of 2 has been reached" )){
errorThrown[0] = true;
}
latch.countDown();
}
@Override
public void onNext( final Integer integer ) {
log.info( "onNext invoked with {}", integer );
}
} );
latch.await();
assertTrue("An exception was thrown", errorThrown[0]);
}
private static class IntegerComparator implements Comparator<Integer> {
@Override
public int compare( final Integer o1, final Integer o2 ) {
return Integer.compare( o1, o2 );
}
}
private static class ReverseIntegerComparator implements Comparator<Integer> {
@Override
public int compare( final Integer o1, final Integer o2 ) {
return Integer.compare( o1, o2 ) * -1;
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment