Created
May 3, 2023 09:39
-
-
Save ndrscodes/0e8105a2600a89615e470f5084823ced to your computer and use it in GitHub Desktop.
A simple Java program demonstrating a pitfall of using stream reduction to count the prime numbers in a list.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package primes; | |
import java.time.Duration; | |
import java.time.Instant; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.concurrent.ForkJoinPool; | |
import java.util.concurrent.RecursiveTask; | |
import java.util.concurrent.ThreadLocalRandom; | |
import java.util.function.Function; | |
/**
 * Compares several ways of counting the primes in a large random list of integers, and
 * demonstrates why the single-argument {@code Stream.reduce} must not be used for counting.
 */
public class Primes {

    /**
     * Runs each counting strategy once over the same randomly generated list and prints the
     * result together with the elapsed time.
     *
     * <p>Each strategy runs exactly once with no warm-up, so the timings include JIT compilation
     * and are only a rough comparison.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // 10,000,000 values drawn from [10, 100) (upper bound exclusive).
        var l = generateList(10000000, 10, 100);

        // Sum using parallel mapping and consequent reduction. Every element is first mapped to
        // 1 (prime) or 0, then the 0/1 values are added; addition is associative, so the
        // parallel result is correct.
        var m = measureTime(lst -> lst
                .stream()
                .parallel()
                .map(x -> isPrime(x) ? 1 : 0)
                .reduce((acc, i) -> acc + i)
                .orElseThrow(), l);
        System.out.println("parallel mapping + reduction: " + m);

        // Sum using parallel reduction only. INTENTIONALLY WRONG -- this is the pitfall the
        // program demonstrates:
        //  - the one-argument reduce seeds the accumulator with an element of the stream (a value
        //    in [10, 100)), not with 0;
        //  - in parallel, partial results are merged with the same lambda, which then tests one
        //    partial result for primality instead of adding the two partial counts.
        // A correct single-pass alternative is
        //   reduce(0, (acc, i) -> isPrime(i) ? acc + 1 : acc, Integer::sum).
        var m2 = measureTime(lst -> lst
                .stream()
                .parallel()
                .reduce((acc, i) -> isPrime(i) ? acc + 1 : acc)
                .orElseThrow(), l);
        System.out.println("parallel reduction: " + m2);

        // Sum using a ForkJoinPool (divide and conquer, see PrimeTask).
        var m3 = measureTime(Primes::countPrimes, l);
        System.out.println("ForkJoinPool: " + m3);

        // Sum using non-parallel mapping and consequent reduction (correct).
        var m4 = measureTime(lst -> lst
                .stream()
                .map(x -> isPrime(x) ? 1 : 0)
                .reduce((acc, i) -> acc + i)
                .orElseThrow(), l);
        System.out.println("non-parallel mapping + reduction: " + m4);

        // Sum using non-parallel reduction only. Also wrong, even sequentially: the result is the
        // value of the first element plus the number of primes among the remaining elements.
        var m5 = measureTime(lst -> lst
                .stream()
                .reduce((acc, i) -> isPrime(i) ? acc + 1 : acc)
                .orElseThrow(), l);
        System.out.println("non-parallel reduction: " + m5);
    }

    /**
     * Result of {@link #measureTime}: the computed value plus how long the computation took.
     *
     * <p>Public because the public {@link #measureTime} method returns it.
     *
     * @param <Tin>  element type of the measured input list (not stored; kept for the existing
     *               signature of {@link #measureTime})
     * @param <Tout> type of the computed value
     * @param data   the value returned by the measured function
     * @param time   elapsed time of the measured call
     */
    public record Measure<Tin, Tout>(Tout data, Duration time) {
    }

    /** Fork/join task that counts the primes in the index range {@code [lo, hi)} of a list. */
    private static final class PrimeTask extends RecursiveTask<Integer> {
        private static final long serialVersionUID = 1L;

        private final List<Integer> data;
        private final int lo;
        private final int hi;

        /**
         * @param data the full list being examined (shared, read-only, by all sub-tasks)
         * @param lo   first index of this task's range, inclusive
         * @param hi   end index of this task's range, exclusive
         */
        PrimeTask(List<Integer> data, int lo, int hi) {
            this.data = data;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        protected Integer compute() {
            // Ranges shorter than a quarter of the list are counted sequentially. The cutoff is
            // never below 2: a range of 0 or 1 elements cannot be split into two strictly
            // smaller halves, and splitting it would recurse forever.
            int threshold = Math.max(2, data.size() / 4);
            if (hi - lo < threshold) {
                return countSequentially();
            }
            int mid = lo + ((hi - lo) / 2);
            var left = new PrimeTask(data, lo, mid);
            var right = new PrimeTask(data, mid, hi);
            // Hand the left half to the pool and process the right half in this thread,
            // instead of forking both and leaving this worker idle.
            left.fork();
            int rightCount = right.compute();
            return left.join() + rightCount;
        }

        /** Counts the primes in {@code [lo, hi)} with a plain loop. */
        private int countSequentially() {
            int count = 0;
            for (int i = lo; i < hi; i++) {
                if (isPrime(data.get(i))) {
                    count++;
                }
            }
            return count;
        }
    }

    /**
     * Counts the primes in {@code data} with a divide-and-conquer {@link PrimeTask} run on a
     * dedicated {@link ForkJoinPool}.
     *
     * @param data the numbers to examine; a {@code null} element causes a
     *             {@link NullPointerException} when it is unboxed
     * @return the number of elements for which {@link #isPrime(int)} is {@code true}
     */
    private static int countPrimes(List<Integer> data) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new PrimeTask(data, 0, data.size()));
        } finally {
            // The pool is created per call; shut it down so its worker threads do not linger.
            pool.shutdown();
        }
    }

    /**
     * Tests whether {@code candidate} is prime by trial division.
     *
     * @param candidate the number to test
     * @return {@code true} if {@code candidate} is a prime; {@code false} for composites and for
     *         every value below 2 (including 0, 1 and negative numbers)
     */
    public static boolean isPrime(int candidate) {
        if (candidate < 2) {
            return false;
        }
        // Any composite has a divisor no larger than its square root. Writing the bound as
        // i <= candidate / i avoids overflowing i * i near Integer.MAX_VALUE.
        for (int i = 2; i <= candidate / i; i++) {
            if (candidate % i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applies {@code function} to {@code data} once and records how long the call took.
     *
     * <p>Uses {@link System#nanoTime()}, which is monotonic and intended for measuring elapsed
     * time, rather than the wall clock.
     *
     * @param function the computation to time
     * @param data     the input passed to {@code function}
     * @param <Tin>    element type of the input list
     * @param <Tout>   result type of the computation
     * @return the computation's result together with its elapsed time
     */
    public static <Tin, Tout> Measure<Tin, Tout> measureTime(Function<List<Tin>, Tout> function, List<Tin> data) {
        long startNanos = System.nanoTime();
        Tout result = function.apply(data);
        long elapsedNanos = System.nanoTime() - startNanos;
        return new Measure<>(result, Duration.ofNanos(elapsedNanos));
    }

    /**
     * Returns a pseudo-random integer in {@code [min, max)}.
     *
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     * @return a value {@code v} with {@code min <= v < max}
     * @throws IllegalArgumentException if {@code min >= max}
     */
    public static int generateNumber(int min, int max) {
        return ThreadLocalRandom.current().nextInt(min, max);
    }

    /**
     * Builds a mutable list of {@code count} pseudo-random integers, each in {@code [min, max)}.
     *
     * @param count number of elements to generate
     * @param min   inclusive lower bound of each element
     * @param max   exclusive upper bound of each element
     * @return a new {@link ArrayList} holding the generated values
     * @throws IllegalArgumentException if {@code count} is negative or {@code min >= max}
     */
    public static List<Integer> generateList(int count, int min, int max) {
        var lst = new ArrayList<Integer>(count);
        for (int i = 0; i < count; i++) {
            lst.add(generateNumber(min, max));
        }
        return lst;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment