Skip to content

Instantly share code, notes, and snippets.

@ndrscodes
Created May 3, 2023 09:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ndrscodes/0e8105a2600a89615e470f5084823ced to your computer and use it in GitHub Desktop.
Save ndrscodes/0e8105a2600a89615e470f5084823ced to your computer and use it in GitHub Desktop.
A simple Java program demonstrating a pitfall of using stream reduction to count the prime numbers in a list.
package primes;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
/**
 * Counts the primes in a large list of random integers using several strategies (parallel and
 * sequential streams, a fork/join task) and prints each result with its elapsed time.
 *
 * <p>The "reduction-only" stream variants are intentionally wrong: they demonstrate what happens
 * when {@code Stream.reduce(BinaryOperator)} is given an accumulator that is not associative.
 */
public class Primes {

    /**
     * Builds 10,000,000 random integers in {@code [10, 100)} and counts the primes among them with
     * each strategy, printing the {@link Measure} (result + elapsed time) of every run.
     *
     * <p>NOTE: no JIT warm-up is performed, so earlier runs may be penalized relative to later ones.
     */
    public static void main(String[] args) {
        var l = generateList(10_000_000, 10, 100);

        // sum using parallel mapping and consequent reduction.
        // Correct: every element is first mapped to 0/1, and + is associative.
        var m = measureTime(lst -> lst
                .stream()
                .parallel()
                .map(x -> isPrime(x) ? 1 : 0)
                .reduce((acc, i) -> acc + i)
                .orElseThrow(), l);
        System.out.println("parallel mapping + reduction: " + m);

        // sum using parallel reduction-only. INTENTIONALLY WRONG (this is what the demo shows):
        // reduce(BinaryOperator) seeds each segment with its first element's raw value (10..99)
        // instead of a 0/1 count, and combines partial counts with the same lambda, so a partial
        // count on the right-hand side is treated as a list element (+1 if that count is prime,
        // otherwise dropped). The result depends on how the stream happens to be split.
        var m2 = measureTime(lst -> lst
                .stream()
                .parallel()
                .reduce((acc, i) -> isPrime(i) ? acc + 1 : acc)
                .orElseThrow(), l);
        System.out.println("parallel reduction: " + m2);

        // reduction-only done right: an identity (0) plus a separate combiner that sums the
        // partial counts instead of re-interpreting them as elements.
        var m2b = measureTime(lst -> lst
                .stream()
                .parallel()
                .reduce(0, (acc, i) -> isPrime(i) ? acc + 1 : acc, Integer::sum), l);
        System.out.println("parallel reduction (identity + combiner): " + m2b);

        // sum using a ForkJoinPool
        var m3 = measureTime(Primes::countPrimes, l);
        System.out.println("ForkJoinPool: " + m3);

        // sum using non-parallel mapping and consequent reduction (correct)
        var m4 = measureTime(lst -> lst
                .stream()
                .map(x -> isPrime(x) ? 1 : 0)
                .reduce((acc, i) -> acc + i)
                .orElseThrow(), l);
        System.out.println("non-parallel mapping + reduction: " + m4);

        // sum using non-parallel reduction-only. INTENTIONALLY WRONG: sequentially the result is
        // the first element's value plus the number of primes among the remaining elements.
        var m5 = measureTime(lst -> lst
                .stream()
                .reduce((acc, i) -> isPrime(i) ? acc + 1 : acc)
                .orElseThrow(), l);
        System.out.println("non-parallel reduction: " + m5);
    }

    /**
     * Result of a timed computation.
     *
     * @param data the value returned by the timed function
     * @param time the elapsed time of the call
     * @param <Tin> element type of the input list (unused; kept for {@link #measureTime}'s signature)
     * @param <Tout> type of the computed result
     */
    private static record Measure<Tin, Tout>(Tout data, Duration time) {
    }

    /**
     * Fork/join task that counts the primes in {@code data[lo, hi)} by halving the range until it
     * is small enough to scan sequentially.
     */
    private static class PrimeTask extends RecursiveTask<Integer> {
        private static final long serialVersionUID = 1L;
        private final int lo;
        private final int hi;
        private final List<Integer> data;

        /**
         * @param data the full list being counted (shared, not copied)
         * @param lo first index of this task's range, inclusive
         * @param hi last index of this task's range, exclusive
         */
        public PrimeTask(List<Integer> data, int lo, int hi) {
            this.data = data;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        protected Integer compute() {
            var size = hi - lo;
            // A range of 0 or 1 elements must be a leaf: splitting it yields an empty half plus an
            // identical copy of itself and would recurse forever. This matters for lists with
            // fewer than 4 elements, where data.size() / 4 == 0.
            if (size <= 1 || size < data.size() / 4) {
                var cnt = 0;
                for (int i = lo; i < hi; i++) {
                    if (isPrime(data.get(i))) {
                        cnt++;
                    }
                }
                return cnt;
            }
            var mid = lo + (size / 2);
            var left = new PrimeTask(data, lo, mid);
            var right = new PrimeTask(data, mid, hi);
            // Fork one half and compute the other in this worker, rather than forking both and
            // immediately blocking on join.
            left.fork();
            var rightCount = right.compute();
            return left.join() + rightCount;
        }
    }

    /**
     * Counts the primes in {@code data} with a {@link PrimeTask} on a dedicated fork/join pool.
     * The pool is shut down afterwards so repeated calls do not leak worker pools.
     */
    private static int countPrimes(List<Integer> data) {
        var pool = new ForkJoinPool();
        try {
            return pool.invoke(new PrimeTask(data, 0, data.size()));
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Returns whether {@code candidate} is prime.
     *
     * <p>Values below 2 (0, 1 and negatives) are not prime. Trial division stops at
     * {@code floor(sqrt(candidate))}. The bound is written as {@code i <= candidate / i} so it
     * cannot overflow for values near {@link Integer#MAX_VALUE}.
     */
    public static boolean isPrime(int candidate) {
        if (candidate < 2) {
            return false;
        }
        for (int i = 2; i <= candidate / i; i++) {
            if (candidate % i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applies {@code function} to {@code data} and measures how long the call takes.
     *
     * @return the function's result together with the elapsed time
     */
    public static <Tin, Tout> Measure<Tin, Tout> measureTime(Function<List<Tin>, Tout> function, List<Tin> data) {
        // System.nanoTime() is monotonic. Instant.now() follows the wall clock, which can jump
        // (NTP or manual adjustments) and corrupt elapsed-time readings.
        var start = System.nanoTime();
        var res = function.apply(data);
        var elapsed = Duration.ofNanos(System.nanoTime() - start);
        return new Measure<Tin, Tout>(res, elapsed);
    }

    /** Returns a random int in {@code [min, max)} drawn from the current thread's random source. */
    public static int generateNumber(int min, int max) {
        return ThreadLocalRandom.current().nextInt(min, max);
    }

    /** Returns a new mutable list of {@code count} random ints, each in {@code [min, max)}. */
    public static List<Integer> generateList(int count, int min, int max) {
        var lst = new ArrayList<Integer>(count);
        for (int i = 0; i < count; i++) {
            lst.add(generateNumber(min, max));
        }
        return lst;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment