@paulo-raca
Spark is awesome for synchronous work (usually CPU-bound). But once in a while, you need to throw an asynchronous operation in the middle (usually IO-bound, e.g., calling a web server).
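Without something like asyncMap, the usual workaround is to block on each future inside a plain map, which holds the task thread and waits for one element at a time. A minimal sketch of that baseline, assuming a hypothetical async call fetchAsync:

JavaRDD<String> bodies = urls.map(url -> fetchAsync(url).get()); // at most one request in flight per task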
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import com.google.common.collect.AbstractIterator;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
public class AsyncMapSample {
    /**
     * Spark is awesome for synchronous work (usually CPU-bound). But once in a while, you need to
     * throw an asynchronous operation in the middle (usually IO-bound, e.g., calling a web server).
     *
     * asyncMap allows Spark to execute several asynchronous operations concurrently, increasing throughput.
     *
     * @param rdd Input RDD
     * @param transform Async transformation to be applied to the data
     * @param maxConcurrency Maximum number of concurrent operations in execution
     * @param lookAhead How many elements from the input should be pre-fetched and ready for processing
     * @param preserveOrder If true, the elements are returned in the same order as the input.
     *                      Otherwise, they are returned in the order they finish processing.
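     *
     * A minimal usage sketch ({@code httpClient.fetchAsync} is a hypothetical async call, not part of this gist):
     * <pre>{@code
     * JavaRDD<String> bodies = asyncMap(urls, url -> httpClient.fetchAsync(url), 10, 20, true);
     * }</pre>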
     */
    public static <A, B> JavaRDD<B> asyncMap(JavaRDD<A> rdd, final Function<A, ListenableFuture<B>> transform, final int maxConcurrency, final int lookAhead, boolean preserveOrder) {
        return rdd.mapPartitions((inIterator) -> {
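            // Returning an Iterable here matches the Spark 1.x Java API; Spark 2.x's
            // mapPartitions expects an Iterator, in which case the AbstractIterator
            // below would be returned directly.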
            return () -> new AbstractIterator<B>() {
                final Object lock = new Object();
                int concurrency = 0;
                final Queue<A> inputQueue = new LinkedList<>();
                final Queue<ListenableFuture<B>> outputQueue = new LinkedList<>();
                boolean withinExecuteAsync = false; // prevents reentrant calls to scheduleAsync()
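
                // Pipeline: fetchQueue() pulls elements from inIterator into inputQueue,
                // scheduleAsync() starts up to maxConcurrency futures from it, and finished
                // (or, with preserveOrder, merely started) futures land in outputQueue,
                // which computeNext() drains.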
                /** Invokes the transform, turning synchronous exceptions into failed futures. */
                private ListenableFuture<B> transformChecked(A v) {
                    try {
                        return transform.call(v);
                    } catch (Throwable t) {
                        return Futures.immediateFailedFuture(t);
                    }
                }
                /** Starts queued elements until inputQueue drains or maxConcurrency is reached. Must be called while holding {@code lock}. */
                private void scheduleAsync() {
                    withinExecuteAsync = true;
                    try {
                        while (!inputQueue.isEmpty() && concurrency < maxConcurrency) {
                            ListenableFuture<B> out = transformChecked(inputQueue.poll());
                            concurrency++;
                            if (preserveOrder) {
                                outputQueue.add(out);
                                lock.notify();
                            }
                            out.addListener(() -> {
                                synchronized (lock) {
                                    concurrency--;
                                    if (!preserveOrder) {
                                        outputQueue.add(out);
                                        lock.notify();
                                    }
                                    if (!withinExecuteAsync) {
                                        scheduleAsync();
                                    }
                                }
                            }, MoreExecutors.sameThreadExecutor()); // directExecutor() on Guava 18+
                        }
                    } finally {
                        withinExecuteAsync = false;
                    }
                }
                /** Pre-fetches input elements until lookAhead items are buffered or in flight. */
                private void fetchQueue() {
                    while (inputQueue.size() + outputQueue.size() + (preserveOrder ? 0 : concurrency) < lookAhead && inIterator.hasNext()) {
                        A in = inIterator.next();
                        System.out.println("Fetched " + in);
                        inputQueue.add(in);
                        // Fetching the next element might be sloooow, so it's better if we get started with each element ASAP
                        scheduleAsync();
                    }
                }
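
                // AbstractIterator contract: computeNext() either returns the next element or
                // calls endOfData(). Blocking on lock.wait() here applies backpressure to the consumer.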
                @Override
                protected B computeNext() {
                    ListenableFuture<B> out;
                    synchronized (lock) {
                        fetchQueue();
                        if (inputQueue.isEmpty() && outputQueue.isEmpty() && concurrency == 0) {
                            return endOfData();
                        }
                        while (outputQueue.isEmpty()) {
                            try {
                                lock.wait();
                            } catch (InterruptedException e) {
                                throw new RuntimeException(e);
                            }
                        }
                        out = outputQueue.poll();
                    }
                    return Futures.getUnchecked(out);
                }
            };
        });
    }
    // ============================ Quick and Dirty test code ============================

    private static ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(5);

    /** Returns a future that completes with {@code data} after a random delay of up to 1 second. */
    public static <T> ListenableFuture<T> lazy(T data) {
        System.out.println("Started processing " + data);
        SettableFuture<T> ret = SettableFuture.create();
        scheduler.schedule(() -> {
            System.out.println("Completed processing " + data);
            ret.set(data);
        }, (int) (Math.random() * 1000), TimeUnit.MILLISECONDS);
        return ret;
    }
    static int N = 0;

    public static void main(String[] csvFiles) {
        SparkConf sparkConfig = new SparkConf()
                .setAppName("Spark test")
                .setMaster("local[*]");
        try (JavaSparkContext sparkContext = new JavaSparkContext(sparkConfig)) {
            List<Integer> numberList = new ArrayList<>();
            for (int i = 0; i < 1000; i++) {
                numberList.add(i);
            }
            JavaRDD<Integer> numbers = sparkContext
                    .parallelize(numberList, 1)
                    .setName("Numbers");
            numbers = asyncMap(numbers, AsyncMapSample::lazy, 50, 100, false);
            numbers.foreach((v) -> System.out.println(N++ + ": " + v));
        }
    }
}