Skip to content

Instantly share code, notes, and snippets.

@forax
Created September 14, 2016 21:43
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save forax/b5257dfac85e74335e02b5a6b95c9182 to your computer and use it in GitHub Desktop.
package fr.umlv.stream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.invoke.MethodType;
import java.lang.reflect.UndeclaredThrowableException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
/**
 * A minimal, sequential stream whose pipeline is compiled into a single {@link MethodHandle}.
 *
 * <p>Intermediate operations ({@link #filter}, {@link #map}) do not touch the data. Each one
 * records a transformation of the terminal "body" handle. A terminal operation ({@link #reduce},
 * {@link #reduce2}, {@link #forEach}) builds a body of type {@code (Object acc, Object element)
 * Object}, lets the recorded transformations wrap it, and runs the result with
 * {@link MethodHandles#iteratedLoop}. Requires Java 9+.
 *
 * @param <T> the element type seen by the next operation in the pipeline
 */
public class Stream<T> {
  /** Source of the elements; the element type is erased because the pipeline is untyped. */
  private final Iterable<?> iterable;

  /**
   * Wraps a terminal body handle of type {@code (Object acc, Object element)Object} with every
   * intermediate operation recorded so far (the most recently added one is applied innermost).
   */
  private final UnaryOperator<MethodHandle> op;

  Stream(Iterable<?> iterable, UnaryOperator<MethodHandle> op) {
    this.iterable = iterable;
    this.op = op;
  }

  /** {@code Function.apply} as {@code (Function, Object)Object}. */
  private static final MethodHandle FUN;
  /** {@code BiFunction.apply} as {@code (BiFunction, Object, Object)Object}. */
  private static final MethodHandle BI_FUN;
  /** {@code Predicate.test} as {@code (Predicate, Object)boolean}. */
  private static final MethodHandle PRED;
  /** {@code (Object acc, Object element) -> acc}: the "skip this element" branch of filter. */
  private static final MethodHandle IDENTITY;

  static {
    Lookup lookup = MethodHandles.publicLookup();
    try {
      FUN = lookup.findVirtual(Function.class, "apply",
          MethodType.methodType(Object.class, Object.class));
      BI_FUN = lookup.findVirtual(BiFunction.class, "apply",
          MethodType.methodType(Object.class, Object.class, Object.class));
      PRED = lookup.findVirtual(Predicate.class, "test",
          MethodType.methodType(boolean.class, Object.class));
    } catch (NoSuchMethodException | IllegalAccessException e) {
      // these are public methods of public JDK interfaces; failure means a broken runtime
      throw new AssertionError(e);
    }
    IDENTITY = MethodHandles.dropArguments(MethodHandles.identity(Object.class), 1, Object.class);
  }

  /**
   * Returns a stream that only passes elements matching {@code predicate} downstream.
   * For a rejected element the accumulator is returned unchanged.
   *
   * @param predicate the test applied to each element
   * @return the filtered stream
   */
  public Stream<T> filter(Predicate<? super T> predicate) {
    return new Stream<>(iterable, mh -> op.apply(MethodHandles.guardWithTest(
        // the test ignores the accumulator (argument 0) and looks only at the element
        MethodHandles.dropArguments(PRED.bindTo(predicate), 0, Object.class),
        mh,
        IDENTITY)));
  }

  /**
   * Returns a stream whose elements are the results of applying {@code mapper}.
   *
   * @param mapper the function applied to each element
   * @param <U> the new element type
   * @return the mapped stream
   */
  public <U> Stream<U> map(Function<? super T, ? extends U> mapper) {
    // argument 1 of the body is the element; argument 0 (the accumulator) is left alone
    return new Stream<>(iterable, mh -> op.apply(MethodHandles.filterArguments(mh, 1, FUN.bindTo(mapper))));
  }

  /**
   * Passes every element that reaches the end of the pipeline to {@code consumer}, in order.
   *
   * @param consumer the action applied to each element
   */
  public void forEach(Consumer<? super T> consumer) {
    reduce(null, (__, e) -> { consumer.accept(e); return null; });
  }

  /**
   * Folds the elements from left to right, starting from {@code initial}.
   *
   * @param initial the initial accumulator value (may be {@code null})
   * @param function combines the current accumulator with the next element
   * @param <U> the accumulator type
   * @return the final accumulator, or {@code initial} if no element reaches the end
   */
  @SuppressWarnings("unchecked")
  public <U> U reduce(U initial, BiFunction<? super U, ? super T, ? extends U> function) {
    MethodHandle body = loopBody(function);
    MethodHandle init = MethodHandles.constant(Object.class, initial);
    // a null iterator handle makes the loop call Iterable.iterator() on its leading argument
    MethodHandle loop = MethodHandles.iteratedLoop(null, init, body);
    try {
      return (U) loop.invokeExact(iterable);
    } catch (Throwable e) {
      throw rethrowIfUnchecked(e);
    }
  }

  /**
   * Same contract as {@link #reduce}, but runs the first iteration outside the method-handle
   * loop ("loop peeling") before handing the rest of the iterator to the loop.
   *
   * @param initial the initial accumulator value (may be {@code null})
   * @param function combines the current accumulator with the next element
   * @param <U> the accumulator type
   * @return the final accumulator, or {@code initial} if the source is empty
   */
  @SuppressWarnings("unchecked")
  public <U> U reduce2(U initial, BiFunction<? super U, ? super T, ? extends U> function) {
    Iterator<?> iterator = iterable.iterator();
    if (!iterator.hasNext()) {
      return initial;
    }
    MethodHandle body = loopBody(function);
    U value = initial;
    // peel the loop once: the body takes (acc, element, source), the same order the loop uses
    try {
      value = (U) body.invokeExact(value, iterator.next(), iterable);
    } catch (Throwable e) {
      throw rethrowIfUnchecked(e);
    }
    if (!iterator.hasNext()) {
      return value;
    }
    // the remaining iterations continue on the already-advanced iterator
    MethodHandle it = MethodHandles.dropArguments(
        MethodHandles.constant(Iterator.class, iterator),
        0, Iterable.class);
    MethodHandle init = MethodHandles.constant(Object.class, value);
    MethodHandle loop = MethodHandles.iteratedLoop(it, init, body);
    try {
      return (U) loop.invokeExact(iterable);
    } catch (Throwable e) {
      throw rethrowIfUnchecked(e);
    }
  }

  /**
   * Builds the loop body {@code (Object acc, Object element, Iterable source)Object}, which is
   * {@code function} wrapped by every intermediate operation of this stream.
   * {@link MethodHandles#iteratedLoop} requires its body in {@code (V T A...)V} order, so the
   * accumulator must stay first and the element second.
   */
  private MethodHandle loopBody(BiFunction<?, ?, ?> function) {
    MethodHandle body = op.apply(BI_FUN.bindTo(function));
    // the trailing Iterable is the loop's external parameter (the source); the body ignores it
    return MethodHandles.dropArguments(body, 2, Iterable.class);
  }

  /**
   * Rethrows {@code e} unchanged if it is unchecked, otherwise returns it wrapped.
   * Callers write {@code throw rethrowIfUnchecked(e)} so the compiler knows control stops there.
   *
   * @param e the throwable raised by a method-handle invocation
   * @return an {@link UndeclaredThrowableException} wrapping a checked {@code e}
   */
  private static UndeclaredThrowableException rethrowIfUnchecked(Throwable e) {
    if (e instanceof RuntimeException) {
      throw (RuntimeException) e;
    }
    if (e instanceof Error) {
      // must be cast to Error: casting an Error to RuntimeException throws ClassCastException
      throw (Error) e;
    }
    return new UndeclaredThrowableException(e);
  }

  /**
   * Creates a stream over the elements of {@code iterable}, with no intermediate operations.
   *
   * @param iterable the source
   * @param <T> the element type
   * @return a new stream
   */
  public static <T> Stream<T> of(Iterable<T> iterable) {
    return new Stream<>(iterable, UnaryOperator.identity());
  }

  /**
   * Micro-benchmark comparing {@link #reduce}, {@link #reduce2} and {@code java.util.stream}
   * on the same map/filter/sum pipeline.
   *
   * <p>NOTE: 100 million boxed Integers need a multi-gigabyte heap. The int sums overflow, but
   * all three pipelines wrap identically, so their printed results remain comparable.
   */
  public static void main(String[] args) {
    ArrayList<Integer> list = new ArrayList<>();
    for (int i = 0; i < 100_000_000; i++) {
      list.add(i);
    }
    long start = System.nanoTime();
    int sum = Stream.of(list)
        .map(x -> x * 2)
        .filter(x -> x % 2 == 0)
        .reduce(0, (v, e) -> v + e);
    long stop = System.nanoTime();
    System.out.println(sum + " " + (stop - start));
    long start2 = System.nanoTime();
    int sum2 = Stream.of(list)
        .map(x -> x * 2)
        .filter(x -> x % 2 == 0)
        .reduce2(0, (v, e) -> v + e);
    long stop2 = System.nanoTime();
    System.out.println(sum2 + " " + (stop2 - start2));
    long start3 = System.nanoTime();
    int sum3 = list.stream()
        .map(x -> x * 2)
        .filter(x -> x % 2 == 0)
        .reduce(0, (v, e) -> v + e);
    long stop3 = System.nanoTime();
    System.out.println(sum3 + " " + (stop3 - start3));
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment