Skip to content

Instantly share code, notes, and snippets.

@tomwhoiscontrary
Created October 23, 2020 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomwhoiscontrary/ee554d3321e6e4c8439322c2105735f9 to your computer and use it in GitHub Desktop.
Save tomwhoiscontrary/ee554d3321e6e4c8439322c2105735f9 to your computer and use it in GitHub Desktop.
I wasted most of a day trying to copy Clojure's transducers in Java. It works, as far as it goes, and it's more extensible than Java's built-in streams. No short-circuiting operations though — even `limit` still receives every element of the source!
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.equalTo;
/**
 * A small re-creation of Clojure-style transducers on top of {@link Collector}.
 *
 * <p>A {@link Reducer} is a stateful sink that is also a {@code Collector}, so it can be handed
 * straight to {@link Stream#collect}. A {@link Transducer} wraps a downstream reducer in a new
 * reducer that transforms the values it receives, and transducers compose with
 * {@link Transducer#then(Transducer)}.
 *
 * <p>There are no short-circuiting operations: even {@link #limit(int)} receives every element of
 * the source and simply ignores the surplus.
 */
public class Transducers {

    /**
     * A stateful, single-use sink that consumes values of type {@code T} and produces a result of
     * type {@code R}.
     *
     * <p>It implements {@link Collector} by acting as its own accumulation container: the supplier
     * always hands back {@code this}. As a consequence a reducer must not be reused for a second
     * reduction, and it only works with sequential streams (the combiner always throws).
     *
     * @param <T> the type of values consumed
     * @param <R> the type of the final result
     * @param <A> the implementing reducer type itself (self-type)
     */
    interface Reducer<T, R, A extends Reducer<T, R, A>> extends Collector<T, A, R>, Consumer<T> {
        @Override
        default Supplier<A> supplier() {
            // Sound as long as implementations declare themselves as A,
            // i.e. `class X implements Reducer<T, R, X>`, as every reducer in this file does.
            @SuppressWarnings("unchecked")
            A self = (A) this;
            return () -> self;
        }

        @Override
        default BiConsumer<A, T> accumulator() {
            return Reducer::accept;
        }

        @Override
        default BinaryOperator<A> combiner() {
            // There is only ever one container (this reducer), so partial results cannot be merged.
            return (a, b) -> {
                throw new UnsupportedOperationException(
                        "reducers cannot be combined; use a sequential stream");
            };
        }

        @Override
        default Function<A, R> finisher() {
            return Reducer::finish;
        }

        @Override
        default Set<Characteristics> characteristics() {
            return Collections.emptySet();
        }

        /** Consumes one value. */
        @Override
        void accept(T value);

        /** Returns the result of everything accepted so far. */
        R finish();
    }

    /** Returns a reducer that appends every value, in order, to a mutable {@link ArrayList}. */
    private static <T> Reducer<T, List<T>, ?> toList() {
        class ToList implements Reducer<T, List<T>, ToList> {
            final List<T> list = new ArrayList<>();

            @Override
            public void accept(T value) {
                list.add(value);
            }

            @Override
            public List<T> finish() {
                return list;
            }
        }
        return new ToList();
    }

    /** Returns a reducer that sums integers; the result is 0 when nothing was accepted. */
    private static Reducer<Integer, Integer, ?> sum() {
        class Sum implements Reducer<Integer, Integer, Sum> {
            int sum = 0;

            @Override
            public void accept(Integer value) {
                sum += value;
            }

            @Override
            public Integer finish() {
                return sum;
            }
        }
        return new Sum();
    }

    /** Returns a reducer that passes every value to {@code consumer} and finishes with {@code null}. */
    private static <T> Reducer<T, Void, ?> forEach(Consumer<T> consumer) {
        class ForEach implements Reducer<T, Void, ForEach> {
            @Override
            public void accept(T value) {
                consumer.accept(value);
            }

            @Override
            public Void finish() {
                return null;
            }
        }
        return new ForEach();
    }

    /** Returns a reducer that counts the values it accepts. */
    private static <T> Reducer<T, Integer, ?> count() {
        class Count implements Reducer<T, Integer, Count> {
            int count = 0;

            @Override
            public void accept(T value) {
                count++;
            }

            @Override
            public Integer finish() {
                return count;
            }
        }
        return new Count();
    }

    /**
     * Returns a reducer that finds the smallest value according to {@code comparator}, or an empty
     * {@code Optional} when nothing was accepted. On ties the first such value is kept.
     * {@code null} is used internally to mean "nothing seen yet", so null elements are not
     * supported.
     */
    private static <T> Reducer<T, Optional<T>, ?> min(Comparator<T> comparator) {
        class Min implements Reducer<T, Optional<T>, Min> {
            T min;

            @Override
            public void accept(T value) {
                if (min == null || comparator.compare(value, min) < 0) {
                    min = value;
                }
            }

            @Override
            public Optional<T> finish() {
                return Optional.ofNullable(min);
            }
        }
        return new Min();
    }

    /** Returns a reducer that finds the smallest value by natural ordering. */
    private static <T extends Comparable<T>> Reducer<T, Optional<T>, ?> min() {
        return min(Comparator.<T>naturalOrder());
    }

    /** Returns a reducer that finds the largest value according to {@code comparator}. */
    private static <T> Reducer<T, Optional<T>, ?> max(Comparator<T> comparator) {
        return min(comparator.reversed());
    }

    /** Returns a reducer that finds the largest value by natural ordering. */
    private static <T extends Comparable<T>> Reducer<T, Optional<T>, ?> max() {
        return max(Comparator.<T>naturalOrder());
    }

    /**
     * An append-only buffer that stores elements in blocks of sizes 1, 2, 4, 8, ..., so growing
     * never copies elements already stored; {@link #toArray()} copies once into an exactly-sized
     * array.
     */
    static class BlockBuffer<T> {
        private final IntFunction<T[]> factory;
        private int size = 0;
        private final List<T[]> blocks = new ArrayList<>();

        public BlockBuffer(IntFunction<T[]> factory) {
            this.factory = factory;
        }

        public void add(T element) {
            size++;
            // With 1-based position `size`, the element belongs in block floor(log2(size)) at
            // offset size - 2^blockIndex; offset 0 means a new (doubled) block is needed.
            int blockIndex = Integer.SIZE - Integer.numberOfLeadingZeros(size) - 1;
            int blockOffset = size & ~(-1 << blockIndex);
            if (blockOffset == 0) {
                blocks.add(factory.apply(1 << blockIndex));
            }
            blocks.get(blocks.size() - 1)[blockOffset] = element;
        }

        public T[] toArray() {
            T[] array = factory.apply(size);
            int offset = 0;
            Iterator<T[]> iterator = blocks.iterator();
            while (offset < size) {
                T[] block = iterator.next();
                // The last block is usually only partly filled.
                int blockSize = Math.min(size - offset, block.length);
                System.arraycopy(block, 0, array, offset, blockSize);
                offset += blockSize;
            }
            return array;
        }
    }

    /** Returns a reducer that collects values into an array created by {@code factory}. */
    private static <T> Reducer<T, T[], ?> toArray(IntFunction<T[]> factory) {
        class ToArray implements Reducer<T, T[], ToArray> {
            final BlockBuffer<T> buffer = new BlockBuffer<>(factory);

            @Override
            public void accept(T value) {
                buffer.add(value);
            }

            @Override
            public T[] finish() {
                return buffer.toArray();
            }
        }
        return new ToArray();
    }

    /**
     * Adapts an ordinary {@link Collector} into a reducer. The collector's supplier is invoked once
     * per returned reducer; its combiner and characteristics are not used.
     */
    private static <T, A, R> Reducer<T, R, ?> collect(Collector<T, A, R> collector) {
        Supplier<A> supplier = collector.supplier();
        BiConsumer<A, T> accumulator = collector.accumulator();
        Function<A, R> finisher = collector.finisher();
        class Collect implements Reducer<T, R, Collect> {
            final A result = supplier.get();

            @Override
            public void accept(T value) {
                accumulator.accept(result, value);
            }

            @Override
            public R finish() {
                return finisher.apply(result);
            }
        }
        return new Collect();
    }

    /**
     * Returns a reducer that folds values with {@code operator}, starting from
     * {@code initialValue}.
     *
     * <p>A {@code null} initial value means "unseeded": the first accepted value becomes the seed,
     * and the result is {@code null} if nothing was accepted. Once seeded, every later value is
     * combined with {@code operator}, even if the operator has returned {@code null}.
     */
    private static <T> Reducer<T, T, ?> reduce(T initialValue, BinaryOperator<T> operator) {
        class Reduce implements Reducer<T, T, Reduce> {
            // Tracked explicitly rather than via `reduction == null`, so that a null result from
            // the operator does not silently discard everything accumulated so far.
            boolean seeded = initialValue != null;
            T reduction = initialValue;

            @Override
            public void accept(T value) {
                if (seeded) {
                    reduction = operator.apply(reduction, value);
                } else {
                    reduction = value;
                    seeded = true;
                }
            }

            @Override
            public T finish() {
                return reduction;
            }
        }
        return new Reduce();
    }

    /** Returns a reducer that folds values with {@code operator}, seeded by the first value. */
    private static <T> Reducer<T, T, ?> reduce(BinaryOperator<T> operator) {
        return reduce(null, operator);
    }

    /**
     * Turns a reducer of {@code U} into a reducer of {@code T}.
     *
     * <p>Implementations should keep all mutable state inside the reducer returned from
     * {@link #then(Reducer)}, so that one transducer can be reused for many reductions.
     */
    @FunctionalInterface
    interface Transducer<T, U> {
        /** Wraps {@code reducer} so that it receives the transformed values. */
        <R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer);

        /** Composes this transducer with {@code that}; this one is applied first. */
        default <V> Transducer<T, V> then(Transducer<U, V> that) {
            return compose(this, that);
        }
    }

    /** Returns a transducer that applies {@code first}, then {@code second}. */
    private static <T, U, V> Transducer<T, V> compose(Transducer<T, U> first, Transducer<U, V> second) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<V, R, ?> reducer) {
                // Build from the inside out: `second` wraps the final reducer, `first` wraps that.
                Reducer<U, R, ?> inner = second.then(reducer);
                return first.then(inner);
            }
        };
    }

    /** Returns a transducer that passes on only the values matching {@code predicate}. */
    private static <T> Transducer<T, T> filter(Predicate<T> predicate) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Filter implements Reducer<T, R, Filter> {
                    @Override
                    public void accept(T value) {
                        if (predicate.test(value)) {
                            reducer.accept(value);
                        }
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Filter();
            }
        };
    }

    /** Returns a transducer that passes on {@code function} applied to each value. */
    private static <T, U> Transducer<T, U> map(Function<T, U> function) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer) {
                class Map implements Reducer<T, R, Map> {
                    @Override
                    public void accept(T value) {
                        reducer.accept(function.apply(value));
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Map();
            }
        };
    }

    /**
     * Returns a transducer that replaces each value with the contents of the stream that
     * {@code function} maps it to. Like {@link Stream#flatMap}, each mapped stream is closed once
     * its contents have been passed on, and a {@code null} mapped stream counts as empty.
     */
    private static <T, U> Transducer<T, U> flatMap(Function<T, Stream<U>> function) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<U, R, ?> reducer) {
                class FlatMap implements Reducer<T, R, FlatMap> {
                    @Override
                    public void accept(T value) {
                        // try-with-resources skips close() for a null resource.
                        try (Stream<U> mapped = function.apply(value)) {
                            if (mapped != null) {
                                mapped.forEach(reducer);
                            }
                        }
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new FlatMap();
            }
        };
    }

    /**
     * Returns a transducer that discards the first {@code n} values.
     *
     * @throws IllegalArgumentException if {@code n} is negative
     */
    private static <T> Transducer<T, T> skip(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative: " + n);
        }
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Skip implements Reducer<T, R, Skip> {
                    int nRemaining = n;

                    @Override
                    public void accept(T value) {
                        if (nRemaining > 0) {
                            nRemaining--;
                        } else {
                            reducer.accept(value);
                        }
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Skip();
            }
        };
    }

    /**
     * Returns a transducer that passes on at most the first {@code n} values. This does not
     * short-circuit: later values are still delivered and simply ignored.
     *
     * @throws IllegalArgumentException if {@code n} is negative
     */
    private static <T> Transducer<T, T> limit(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative: " + n);
        }
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Limit implements Reducer<T, R, Limit> {
                    int nRemaining = n;

                    @Override
                    public void accept(T value) {
                        if (nRemaining > 0) {
                            nRemaining--;
                            reducer.accept(value);
                        }
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Limit();
            }
        };
    }

    /**
     * Returns a transducer that discards values while {@code predicate} holds, then passes on the
     * first failing value and everything after it without testing again.
     */
    private static <T> Transducer<T, T> dropWhile(Predicate<T> predicate) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class DropWhile implements Reducer<T, R, DropWhile> {
                    boolean dropping = true;

                    @Override
                    public void accept(T value) {
                        if (dropping && predicate.test(value)) {
                            return;
                        }
                        dropping = false;
                        reducer.accept(value);
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new DropWhile();
            }
        };
    }

    /**
     * Returns a transducer that passes on each value only the first time it is seen, judged by
     * {@code equals}/{@code hashCode}. Every reducer built from it gets its own set of seen
     * values, so the transducer can be reused.
     */
    private static <T> Transducer<T, T> distinct() {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Distinct implements Reducer<T, R, Distinct> {
                    final Set<T> seen = new HashSet<>();

                    @Override
                    public void accept(T value) {
                        if (seen.add(value)) {
                            reducer.accept(value);
                        }
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Distinct();
            }
        };
    }

    /** Returns a transducer that calls {@code consumer} on each value, then passes it on unchanged. */
    private static <T> Transducer<T, T> peek(Consumer<T> consumer) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Peek implements Reducer<T, R, Peek> {
                    @Override
                    public void accept(T value) {
                        consumer.accept(value);
                        reducer.accept(value);
                    }

                    @Override
                    public R finish() {
                        return reducer.finish();
                    }
                }
                return new Peek();
            }
        };
    }

    /**
     * Returns a transducer that buffers every value and, when finished, sorts them with
     * {@code comparator} ({@link List#sort} is stable) and replays them downstream.
     */
    private static <T> Transducer<T, T> sorted(Comparator<? super T> comparator) {
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<T, R, ?> reducer) {
                class Sorted implements Reducer<T, R, Sorted> {
                    final List<T> elements = new ArrayList<>();

                    @Override
                    public void accept(T value) {
                        elements.add(value);
                    }

                    @Override
                    public R finish() {
                        elements.sort(comparator);
                        elements.forEach(reducer);
                        return reducer.finish();
                    }
                }
                return new Sorted();
            }
        };
    }

    /** Returns a transducer that sorts values by natural ordering. */
    private static <T extends Comparable<T>> Transducer<T, T> sorted() {
        return sorted(Comparator.naturalOrder());
    }

    /**
     * Returns a transducer that groups consecutive values into lists of {@code size} elements.
     * The final list may be shorter; empty input produces no lists at all.
     *
     * @throws IllegalArgumentException if {@code size} is not positive
     */
    private static <T> Transducer<T, List<T>> batch(int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        return new Transducer<>() {
            @Override
            public <R> Reducer<T, R, ?> then(Reducer<List<T>, R, ?> reducer) {
                class Batch implements Reducer<T, R, Batch> {
                    // Null until the first value of a new batch arrives, so no empty trailing batch.
                    List<T> batch;

                    @Override
                    public void accept(T value) {
                        if (batch == null) {
                            batch = new ArrayList<>();
                        }
                        batch.add(value);
                        if (batch.size() == size) {
                            flush();
                        }
                    }

                    @Override
                    public R finish() {
                        if (batch != null) {
                            flush();
                        }
                        return reducer.finish();
                    }

                    private void flush() {
                        reducer.accept(batch);
                        batch = null;
                    }
                }
                return new Batch();
            }
        };
    }

    @Test
    void reduces() {
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(toList()),
                equalTo(List.of(1, 2, 3, 4, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(sum()),
                equalTo(15));
        List<Integer> numbers = new ArrayList<>();
        Stream.of(1, 2, 3, 4, 5).collect(forEach(numbers::add));
        assertThat(numbers, equalTo(List.of(1, 2, 3, 4, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(count()),
                equalTo(5));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(min()),
                equalTo(Optional.of(1)));
        assertThat(Stream.of(5, 4, 3, 2, 1).collect(min()),
                equalTo(Optional.of(1)));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(max()),
                equalTo(Optional.of(5)));
        assertThat(Stream.of(5, 4, 3, 2, 1).collect(max()),
                equalTo(Optional.of(5)));
        assertThat(Stream.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).collect(toArray(Integer[]::new)),
                arrayContaining(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(collect(Collectors.averagingInt((Integer n) -> n))),
                equalTo(3.0));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(reduce(Integer::sum)),
                equalTo(15));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(reduce(100, Integer::sum)),
                equalTo(115));
    }

    @Test
    void reducesWithoutAReturnValue() {
        List<Integer> numbers = new ArrayList<>();
        Stream.of(1, 2, 3, 4, 5).forEach(forEach(numbers::add)); // this example looks silly, sure
        assertThat(numbers, equalTo(List.of(1, 2, 3, 4, 5)));
    }

    @Test
    void transforms() {
        AtomicInteger sum = new AtomicInteger();
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(peek(sum::addAndGet).then(toList())),
                equalTo(List.of(1, 2, 3, 4, 5)));
        assertThat(sum.get(), equalTo(15));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(filter(this::isEven).then(toList())),
                equalTo(List.of(2, 4)));
        assertThat(Stream.of(1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
                        .collect(distinct().then(toList())),
                equalTo(List.of(1, 2, 3, 4, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(map(this::toLetter).then(toList())),
                equalTo(List.of("A", "B", "C", "D", "E")));
        assertThat(Stream.of(1, 2, 3, 4, 5).collect(flatMap(this::repeat).then(toList())),
                equalTo(List.of(1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(skip(3).then(toList())),
                equalTo(List.of(4, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(limit(3).then(toList())),
                equalTo(List.of(1, 2, 3)));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(dropWhile((Integer i) -> i <= 3).then(toList())),
                equalTo(List.of(4, 5)));
        assertThat(Stream.of(1, 5, 2, 4, 3)
                        .collect(Transducers.<Integer>sorted().then(toList())),
                equalTo(List.of(1, 2, 3, 4, 5)));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(batch(3).then(toList())),
                equalTo(List.of(List.of(1, 2, 3), List.of(4, 5))));
        assertThat(Stream.of(1, 2, 3, 4, 5)
                        .collect(filter(this::isEven)
                                .then(map(this::toLetter))
                                .then(toList())),
                equalTo(List.of("B", "D")));
    }

    @Test
    void transducersAreReusable() {
        Transducer<Integer, Integer> dedupe = distinct();
        assertThat(Stream.of(1, 2, 1).collect(dedupe.then(toList())),
                equalTo(List.of(1, 2)));
        // A second reduction must not remember the values seen by the first one.
        assertThat(Stream.of(1, 2, 3).collect(dedupe.then(toList())),
                equalTo(List.of(1, 2, 3)));
    }

    @Test
    void flatMapClosesMappedStreamsAndTreatsNullAsEmpty() {
        AtomicInteger closed = new AtomicInteger();
        List<Integer> doubled = Stream.of(1, 2, 3)
                .collect(Transducers.<Integer, Integer>flatMap(i -> Stream.of(i, i).onClose(closed::incrementAndGet))
                        .then(toList()));
        assertThat(doubled, equalTo(List.of(1, 1, 2, 2, 3, 3)));
        assertThat(closed.get(), equalTo(3));
        List<Integer> withoutTwo = Stream.of(1, 2, 3)
                .collect(Transducers.<Integer, Integer>flatMap(i -> i == 2 ? null : Stream.of(i))
                        .then(toList()));
        assertThat(withoutTwo, equalTo(List.of(1, 3)));
    }

    private boolean isEven(int n) {
        return n % 2 == 0;
    }

    private String toLetter(int n) {
        return Character.toString('A' + n - 1);
    }

    private Stream<Integer> repeat(Integer o) {
        return Collections.nCopies(o, o).stream();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment