Created
October 23, 2020 17:27
-
-
Save tomwhoiscontrary/ee554d3321e6e4c8439322c2105735f9 to your computer and use it in GitHub Desktop.
I wasted most of a day trying to copy Clojure's transducers in Java. It works, as far as it goes, and it's more extensible than native streams. No short-circuiting operations though!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.jupiter.api.Test; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
import java.util.Comparator; | |
import java.util.HashSet; | |
import java.util.Iterator; | |
import java.util.List; | |
import java.util.Optional; | |
import java.util.Set; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.function.BiConsumer; | |
import java.util.function.BinaryOperator; | |
import java.util.function.Consumer; | |
import java.util.function.Function; | |
import java.util.function.IntFunction; | |
import java.util.function.Predicate; | |
import java.util.function.Supplier; | |
import java.util.stream.Collector; | |
import java.util.stream.Collectors; | |
import java.util.stream.Stream; | |
import static org.hamcrest.MatcherAssert.assertThat; | |
import static org.hamcrest.Matchers.arrayContaining; | |
import static org.hamcrest.Matchers.equalTo; | |
public class Transducers { | |
interface Reducer<T, R, A extends Reducer<T, R, A>> extends Collector<T, A, R>, Consumer<T> { | |
@Override | |
default Supplier<A> supplier() { | |
@SuppressWarnings("unchecked") | |
A self = (A) this; | |
return () -> self; | |
} | |
@Override | |
default BiConsumer<A, T> accumulator() { | |
return Reducer::accept; | |
} | |
@Override | |
default BinaryOperator<A> combiner() { | |
return (a, b) -> { throw new UnsupportedOperationException(); }; | |
} | |
@Override | |
default Function<A, R> finisher() { | |
return Reducer::finish; | |
} | |
@Override | |
default Set<Characteristics> characteristics() { | |
return Collections.emptySet(); | |
} | |
@Override | |
void accept(T value); | |
R finish(); | |
} | |
private static <T> Reducer<T, List<T>, ?> toList() { | |
class ToList implements Reducer<T, List<T>, ToList> { | |
final List<T> list = new ArrayList<>(); | |
@Override | |
public void accept(T value) { list.add(value); } | |
@Override | |
public List<T> finish() { return list; } | |
} | |
return new ToList(); | |
} | |
private static Reducer<Integer, Integer, ?> sum() { | |
class Sum implements Reducer<Integer, Integer, Sum> { | |
int sum = 0; | |
@Override | |
public void accept(Integer value) { sum += value; } | |
@Override | |
public Integer finish() { return sum; } | |
} | |
return new Sum(); | |
} | |
private static <T> Reducer<T, Void, ?> forEach(Consumer<T> consumer) { | |
class ForEach implements Reducer<T, Void, ForEach> { | |
@Override | |
public void accept(T value) { consumer.accept(value); } | |
@Override | |
public Void finish() { | |
return null; | |
} | |
} | |
return new ForEach(); | |
} | |
private static <T> Reducer<T, Integer, ?> count() { | |
class Count implements Reducer<T, Integer, Count> { | |
int count = 0; | |
@Override | |
public void accept(T value) { count += 1; } | |
@Override | |
public Integer finish() { return count; } | |
} | |
return new Count(); | |
} | |
private static <T> Reducer<T, Optional<T>, ?> min(Comparator<T> comparator) { | |
class Min implements Reducer<T, Optional<T>, Min> { | |
T min; | |
@Override | |
public void accept(T value) { if (min == null || comparator.compare(value, min) < 0) min = value; } | |
@Override | |
public Optional<T> finish() { return Optional.ofNullable(min); } | |
} | |
return new Min(); | |
} | |
private static <T extends Comparable<T>> Reducer<T, Optional<T>, ?> min() { | |
return min(Comparator.<T>naturalOrder()); | |
} | |
private static <T> Reducer<T, Optional<T>, ?> max(Comparator<T> comparator) { | |
return min(comparator.reversed()); | |
} | |
private static <T extends Comparable<T>> Reducer<T, Optional<T>, ?> max() { | |
return max(Comparator.<T>naturalOrder()); | |
} | |
static class BlockBuffer<T> { | |
private final IntFunction<T[]> factory; | |
private int size = 0; | |
private final List<T[]> blocks = new ArrayList<>(); | |
public BlockBuffer(IntFunction<T[]> factory) { | |
this.factory = factory; | |
} | |
public void add(T element) { | |
size++; | |
int blockIndex = Integer.SIZE - Integer.numberOfLeadingZeros(size) - 1; | |
int blockOffset = size & ~(-1 << blockIndex); | |
if (blockOffset == 0) blocks.add(factory.apply(1 << blockIndex)); | |
blocks.get(blocks.size() - 1)[blockOffset] = element; | |
} | |
public T[] toArray() { | |
T[] array = factory.apply(size); | |
int offset = 0; | |
Iterator<T[]> iterator = blocks.iterator(); | |
while (offset < size) { | |
T[] block = iterator.next(); | |
int blockSize = Math.min(size - offset, block.length); | |
System.arraycopy(block, 0, array, offset, blockSize); | |
offset += blockSize; | |
} | |
return array; | |
} | |
} | |
private static <T> Reducer<T, T[], ?> toArray(IntFunction<T[]> factory) { | |
class ToArray implements Reducer<T, T[], ToArray> { | |
final BlockBuffer<T> buffer = new BlockBuffer<>(factory); | |
@Override | |
public void accept(T value) { buffer.add(value); } | |
@Override | |
public T[] finish() { | |
return buffer.toArray(); | |
} | |
} | |
return new ToArray(); | |
} | |
private static <T, A, R> Reducer<T, R, ?> collect(Collector<T, A, R> collector) { | |
Supplier<A> supplier = collector.supplier(); | |
BiConsumer<A, T> accumulator = collector.accumulator(); | |
Function<A, R> finisher = collector.finisher(); | |
class Collect implements Reducer<T, R, Collect> { | |
final A result = supplier.get(); | |
@Override | |
public void accept(T value) { accumulator.accept(result, value); } | |
@Override | |
public R finish() { | |
return finisher.apply(result); | |
} | |
} | |
return new Collect(); | |
} | |
private static <T> Reducer<T, T, ?> reduce(T initialValue, BinaryOperator<T> operator) { | |
class Reduce implements Reducer<T, T, Reduce> { | |
T reduction = initialValue; | |
@Override | |
public void accept(T value) { reduction = reduction == null ? value : operator.apply(reduction, value); } | |
@Override | |
public T finish() { | |
return reduction; | |
} | |
} | |
return new Reduce(); | |
} | |
private static <T> Reducer<T, T, ?> reduce(BinaryOperator<T> operator) { | |
return reduce(null, operator); | |
} | |
@FunctionalInterface | |
interface Transducer<T, U> { | |
<R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer); | |
default <V> Transducer<T, V> then(Transducer<U, V> that) { | |
return compose(this, that); | |
} | |
} | |
private static <T, U, V> Transducer<T, V> compose(Transducer<T, U> first, Transducer<U, V> second) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<V, R, ?> reducer) { | |
Reducer<U, R, ?> inner = second.then(reducer); | |
Reducer<T, R, ?> outer = first.then(inner); | |
return outer; | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> filter(Predicate<T> predicate) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class Filter implements Reducer<T, R, Filter> { | |
@Override | |
public void accept(T value) { if (predicate.test(value)) reducer.accept(value); } | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new Filter(); | |
} | |
}; | |
} | |
private static <T, U> Transducer<T, U> map(Function<T, U> function) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer) { | |
class Map implements Reducer<T, R, Map> { | |
@Override | |
public void accept(T value) { reducer.accept(function.apply(value)); } | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new Map(); | |
} | |
}; | |
} | |
private static <T, U> Transducer<T, U> flatMap(Function<T, Stream<U>> function) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer) { | |
class FlatMap implements Reducer<T, R, FlatMap> { | |
@Override | |
public void accept(T value) { | |
function.apply(value).forEach(reducer); | |
} | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new FlatMap(); | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> skip(int n) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class Skip implements Reducer<T, R, Skip> { | |
int nRemaining = n; | |
@Override | |
public void accept(T value) { | |
if (nRemaining > 0) nRemaining--; | |
else reducer.accept(value); | |
} | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new Skip(); | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> limit(int n) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class Skip implements Reducer<T, R, Skip> { | |
int nRemaining = n; | |
@Override | |
public void accept(T value) { | |
if (nRemaining > 0) { | |
nRemaining--; | |
reducer.accept(value); | |
} | |
} | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new Skip(); | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> dropWhile(Predicate<T> predicate) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class DropWhile implements Reducer<T, R, DropWhile> { | |
Predicate<T> prefixPredicate = predicate; | |
@Override | |
public void accept(T value) { | |
if (prefixPredicate == null) { | |
reducer.accept(value); | |
} else if (!prefixPredicate.test(value)) { | |
prefixPredicate = null; | |
reducer.accept(value); | |
} | |
} | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new DropWhile(); | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> distinct() { | |
return filter(new HashSet<>()::add); | |
} | |
private static <T> Transducer<T, T> peek(Consumer<T> consumer) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class Filter implements Reducer<T, R, Filter> { | |
@Override | |
public void accept(T value) { | |
consumer.accept(value); | |
reducer.accept(value); | |
} | |
@Override | |
public R finish() { return reducer.finish(); } | |
} | |
return new Filter(); | |
} | |
}; | |
} | |
private static <T> Transducer<T, T> sorted(Comparator<? super T> comparator) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) { | |
class Sorted implements Reducer<T, R, Sorted> { | |
final List<T> elements = new ArrayList<>(); | |
@Override | |
public void accept(T value) { elements.add(value); } | |
@Override | |
public R finish() { | |
elements.sort(comparator); | |
elements.forEach(reducer); | |
return reducer.finish(); | |
} | |
} | |
return new Sorted(); | |
} | |
}; | |
} | |
private static <T extends Comparable<T>> Transducer<T, T> sorted() { | |
return sorted(Comparator.naturalOrder()); | |
} | |
private static <T> Transducer<T, List<T>> batch(int size) { | |
return new Transducer<>() { | |
@Override | |
public <R> Reducer<T, R, ?> then(Reducer<List<T>, R, ?> reducer) { | |
class Batch implements Reducer<T, R, Batch> { | |
List<T> batch; | |
@Override | |
public void accept(T value) { | |
if (batch == null) batch = new ArrayList<>(); | |
batch.add(value); | |
if (batch.size() == size) flush(); | |
} | |
@Override | |
public R finish() { | |
if (batch != null) flush(); | |
return reducer.finish(); | |
} | |
public void flush() { | |
reducer.accept(batch); | |
batch = null; | |
} | |
} | |
return new Batch(); | |
} | |
}; | |
} | |
@Test | |
void reduces() { | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(toList()), | |
equalTo(List.of(1, 2, 3, 4, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(sum()), | |
equalTo(15)); | |
List<Integer> numbers = new ArrayList<>(); | |
Stream.of(1, 2, 3, 4, 5).collect(forEach(numbers::add)); | |
assertThat(numbers, equalTo(List.of(1, 2, 3, 4, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(count()), | |
equalTo(5)); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(min()), | |
equalTo(Optional.of(1))); | |
assertThat(Stream.of(5, 4, 3, 2, 1).collect(min()), | |
equalTo(Optional.of(1))); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(max()), | |
equalTo(Optional.of(5))); | |
assertThat(Stream.of(5, 4, 3, 2, 1).collect(max()), | |
equalTo(Optional.of(5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).collect(toArray(Integer[]::new)), | |
arrayContaining(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(collect(Collectors.averagingInt((Integer n) -> n))), | |
equalTo(3.0)); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(reduce(Integer::sum)), | |
equalTo(15)); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(reduce(100, Integer::sum)), | |
equalTo(115)); | |
} | |
@Test | |
void reducesWithoutAReturnValue() { | |
List<Integer> numbers = new ArrayList<>(); | |
Stream.of(1, 2, 3, 4, 5).forEach(forEach(numbers::add)); // this example looks silly, sure | |
assertThat(numbers, equalTo(List.of(1, 2, 3, 4, 5))); | |
} | |
@Test | |
void transforms() { | |
AtomicInteger sum = new AtomicInteger(); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(peek(sum::addAndGet).then(toList())), | |
equalTo(List.of(1, 2, 3, 4, 5))); | |
assertThat(sum.get(), equalTo(15)); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(filter(this::isEven).then(toList())), | |
equalTo(List.of(2, 4))); | |
assertThat(Stream.of(1, 2, 3, 4, 5, 1, 2, 3, 4, 5) | |
.collect(distinct().then(toList())), | |
equalTo(List.of(1, 2, 3, 4, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(map(this::toLetter).then(toList())), | |
equalTo(List.of("A", "B", "C", "D", "E"))); | |
assertThat(Stream.of(1, 2, 3, 4, 5).collect(flatMap(this::repeat).then(toList())), | |
equalTo(List.of(1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(skip(3).then(toList())), | |
equalTo(List.of(4, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(limit(3).then(toList())), | |
equalTo(List.of(1, 2, 3))); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(dropWhile((Integer i) -> i <= 3).then(toList())), | |
equalTo(List.of(4, 5))); | |
assertThat(Stream.of(1, 5, 2, 4, 3) | |
.collect(Transducers.<Integer>sorted().then(toList())), | |
equalTo(List.of(1, 2, 3, 4, 5))); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(batch(3).then(toList())), | |
equalTo(List.of(List.of(1, 2, 3), List.of(4, 5)))); | |
assertThat(Stream.of(1, 2, 3, 4, 5) | |
.collect(filter(this::isEven) | |
.then(map(this::toLetter)) | |
.then(toList())), | |
equalTo(List.of("B", "D"))); | |
} | |
private boolean isEven(int n) { | |
return n % 2 == 0; | |
} | |
private String toLetter(int n) { | |
return Character.toString('A' + n - 1); | |
} | |
private Stream<Integer> repeat(Integer o) { | |
return Collections.nCopies(o, o).stream(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment