Skip to content

Instantly share code, notes, and snippets.

@tfmorris
Created March 29, 2016 18:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tfmorris/0f3878a6fa1c91dc6787aee06f55cac8 to your computer and use it in GitHub Desktop.
Save tfmorris/0f3878a6fa1c91dc6787aee06f55cac8 to your computer and use it in GitHub Desktop.
Online variance and standard deviation using Welford's algorithm and Java 8 Streams - just a sketch! Only lightly tested!
import java.util.Collections;
import java.util.EnumSet;
import java.util.IntSummaryStatistics;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToIntFunction;
import java.util.stream.Collector;
import java.util.stream.Collectors;
/**
 * Online accumulator which extends the Java 8 {@link IntSummaryStatistics} class to
 * also track variance and standard deviation using Welford's algorithm.
 *
 * <p>In addition to the count/sum/min/max kept by the superclass, this class keeps a
 * running mean and a running sum of squared deviations from that mean, both updated
 * one value at a time. The class performs no synchronization of its own.
 *
 * @author Tom Morris &lt;tfmorris@gmail.com&gt;
 */
public class IntAccumulator extends IntSummaryStatistics {

    /** Running (online) estimate of the mean of all recorded values. */
    private double mean = 0.0;

    /** Running sum of squared deviations from the mean (the "M2" term of Welford's algorithm). */
    private double m2 = 0.0;

    /**
     * Records a new value, updating both the superclass statistics and the
     * running mean / squared-deviation sum.
     *
     * @param value the value to record
     */
    @Override
    public void accept(int value) {
        // The superclass increments the count first, so getCount() below is at least 1.
        super.accept(value);
        double delta = value - mean;
        mean += delta / getCount();
        // Uses the deviation from the old mean times the deviation from the new mean.
        m2 += delta * (value - mean);
    }

    /**
     * Merges another set of statistics into this one.
     *
     * <p>Variance cannot be reconstructed from a plain {@link IntSummaryStatistics},
     * so {@code other} must itself be an {@code IntAccumulator}.
     *
     * @param other the accumulator to merge in; must be an {@code IntAccumulator}
     * @throws ClassCastException if {@code other} is not an {@code IntAccumulator}
     * @throws NullPointerException if {@code other} is null
     */
    @Override
    public void combine(IntSummaryStatistics other) {
        combine((IntAccumulator) other);
    }

    /**
     * Merges the state of another accumulator into this one using the pairwise
     * (parallel) variance update.
     *
     * <p>If either side is empty the other side's mean and squared-deviation sum are
     * taken as-is, so that merging empty accumulators never produces {@code NaN}.
     *
     * @param other the accumulator whose values should be folded into this one
     * @throws NullPointerException if {@code other} is null
     * @see <a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm">
     *      Algorithms for calculating variance: parallel algorithm</a>
     */
    public void combine(IntAccumulator other) {
        // Snapshot everything before the superclass mutates the count
        // (also keeps the result well-defined if other == this).
        long count = getCount();
        long otherCount = other.getCount();
        double otherMean = other.getMeanEstimate();
        double otherM2 = other.getSquareSum();

        super.combine(other);

        if (otherCount == 0) {
            // Nothing to merge; avoids 0/0 when both sides are empty.
            return;
        }
        if (count == 0) {
            // This side was empty: adopt the other side's running state unchanged.
            mean = otherMean;
            m2 = otherM2;
            return;
        }

        double totalCount = (double) count + otherCount;
        double delta = otherMean - mean;
        mean = (mean * count + otherMean * otherCount) / totalCount;
        m2 += otherM2 + ((delta * delta) * count * otherCount / totalCount);
    }

    /**
     * Returns the running sum of squared deviations from the mean.
     *
     * @return the current M2 term
     */
    private double getSquareSum() {
        return m2;
    }

    /**
     * Returns the online version of the mean, which is maintained incrementally
     * rather than derived from the running sum used by {@link #getAverage()}.
     * It may be less accurate than {@link #getAverage()}.
     *
     * @return the running mean estimate, or {@code 0.0} if no values have been recorded
     */
    public double getMeanEstimate() {
        return mean;
    }

    /**
     * Returns the sample variance (divisor {@code N - 1}) of the recorded values.
     *
     * @return the sample variance, or {@code 0.0} if fewer than two values have been recorded
     */
    public double getSampleVariance() {
        long count = getCount();
        if (count < 2) {
            return 0.0;
        }
        return m2 / (count - 1); // sample variance: N - 1
    }

    /**
     * Returns the sample standard deviation of the recorded values.
     *
     * @return the square root of {@link #getSampleVariance()}
     */
    public double getSampleStdDev() {
        return Math.sqrt(getSampleVariance());
    }

    /**
     * Returns a {@link Collector} that applies an int-producing mapping function to each
     * input element and accumulates the results into an {@code IntAccumulator}.
     *
     * @param <T> the type of the input elements
     * @param mapper a function extracting the int value to be summarized
     * @return a collector with the {@code IDENTITY_FINISH} characteristic
     */
    public static <T> Collector<T, ?, IntAccumulator> summarizingIntStdDev(ToIntFunction<? super T> mapper) {
        // Collector.of(...) with no extra characteristics yields an IDENTITY_FINISH collector.
        return Collector.of(
                IntAccumulator::new,
                (IntAccumulator acc, T element) -> acc.accept(mapper.applyAsInt(element)),
                (IntAccumulator left, IntAccumulator right) -> {
                    left.combine(right);
                    return left;
                });
    }
}
/**
 * Simple package-private {@link Collector} implementation that just holds the
 * functions and characteristics it is given. It mirrors the private
 * {@code java.util.stream.Collectors.CollectorImpl} class so that it can be
 * used from this package despite that class's visibility restrictions.
 *
 * @param <T> the type of input elements
 * @param <A> the mutable accumulation type
 * @param <R> the result type
 */
class CollectorImpl<T, A, R> implements Collector<T, A, R> {

    /** Unmodifiable characteristic set containing only {@code IDENTITY_FINISH}. */
    static final Set<Collector.Characteristics> CH_ID;

    static {
        EnumSet<Collector.Characteristics> identityOnly =
                EnumSet.of(Collector.Characteristics.IDENTITY_FINISH);
        CH_ID = Collections.unmodifiableSet(identityOnly);
    }

    private final Supplier<A> supplier;
    private final BiConsumer<A, T> accumulator;
    private final BinaryOperator<A> combiner;
    private final Function<A, R> finisher;
    private final Set<Characteristics> characteristics;

    /**
     * Creates a collector whose finisher simply casts the accumulation
     * container to the result type.
     *
     * @param supplier creates a new accumulation container
     * @param accumulator folds one input element into a container
     * @param combiner merges two containers into one
     * @param characteristics the characteristics to report
     */
    CollectorImpl(Supplier<A> supplier, BiConsumer<A, T> accumulator, BinaryOperator<A> combiner,
            Set<Characteristics> characteristics) {
        this(supplier, accumulator, combiner, CollectorImpl.<A, R>castingIdentity(), characteristics);
    }

    /**
     * Creates a collector from all five of its components.
     *
     * @param supplier creates a new accumulation container
     * @param accumulator folds one input element into a container
     * @param combiner merges two containers into one
     * @param finisher converts the container into the final result
     * @param characteristics the characteristics to report
     */
    CollectorImpl(Supplier<A> supplier, BiConsumer<A, T> accumulator, BinaryOperator<A> combiner,
            Function<A, R> finisher, Set<Characteristics> characteristics) {
        this.characteristics = characteristics;
        this.finisher = finisher;
        this.combiner = combiner;
        this.accumulator = accumulator;
        this.supplier = supplier;
    }

    /** {@inheritDoc} */
    @Override
    public Supplier<A> supplier() {
        return this.supplier;
    }

    /** {@inheritDoc} */
    @Override
    public BiConsumer<A, T> accumulator() {
        return this.accumulator;
    }

    /** {@inheritDoc} */
    @Override
    public BinaryOperator<A> combiner() {
        return this.combiner;
    }

    /** {@inheritDoc} */
    @Override
    public Function<A, R> finisher() {
        return this.finisher;
    }

    /** {@inheritDoc} */
    @Override
    public Set<Characteristics> characteristics() {
        return this.characteristics;
    }

    /**
     * Returns a function that hands back its argument unchanged, cast
     * (unchecked) to the requested result type.
     */
    @SuppressWarnings("unchecked")
    private static <I, R> Function<I, R> castingIdentity() {
        return input -> (R) input;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment