Created
November 11, 2018 04:21
-
-
Save fasiha/e9c59ceca412c13d6e7e84b5ba806b0b to your computer and use it in GitHub Desktop.
A very bare-bones Java implementation of Ebisu (https://fasiha.github.io/ebisu/), written as a Jupyter notebook.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.9870985778947343\n" | |
] | |
} | |
], | |
"source": [ | |
"%maven org.apache.commons:commons-math3:3.6.1\n", | |
"import org.apache.commons.math3.special.Gamma;\n", | |
"System.out.println(Gamma.logGamma(3.3));" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import java.lang.Math;\n", | |
"\n", | |
"public interface EbisuInterface {\n", | |
" public double getAlpha();\n", | |
" public double getBeta();\n", | |
" public double getT();\n", | |
"}\n", | |
"public class EbisuModel implements EbisuInterface {\n", | |
" private double a;\n", | |
" private double b;\n", | |
" private double t;\n", | |
" public EbisuModel(double alpha, double beta, double time) {\n", | |
" this.a = alpha;\n", | |
" this.b = beta;\n", | |
" this.t = time;\n", | |
" }\n", | |
" @Override\n", | |
" public double getAlpha() { return this.a; }\n", | |
" @Override\n", | |
" public double getBeta() { return this.b; }\n", | |
" @Override\n", | |
" public double getT() { return this.t; }\n", | |
" @Override\n", | |
" public String toString() { return \"Model(\" + this.a + \", \" + this.b + \", \" + this.t + \")\"; }\n", | |
"}\n", | |
"\n", | |
"double predictRecall(EbisuInterface prior, double tnow) {\n", | |
" double alpha = prior.getAlpha();\n", | |
" double beta = prior.getBeta();\n", | |
" double dt = tnow / prior.getT();\n", | |
" return Math.exp(Gamma.logGamma(alpha + dt) - Gamma.logGamma(alpha + beta + dt) - (Gamma.logGamma(alpha) - Gamma.logGamma(alpha + beta)));\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[11.1, 22.2, 33.3]" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import java.util.stream.Stream;\n", | |
"import java.util.stream.StreamSupport;\n", | |
"import java.util.function.BiFunction;\n", | |
"\n", | |
"// Stream helpers below (iteratorToFiniteStream, zip) via https://stackoverflow.com/a/32342172/500207\n", | |
"<T> Stream<T> iteratorToFiniteStream(Iterator<T> iterator, boolean parallel) {\n", | |
" final Iterable<T> iterable = () -> iterator;\n", | |
" return StreamSupport.stream(iterable.spliterator(), parallel);\n", | |
"}\n", | |
"\n", | |
"<A, B, C> Stream<C> zip(Stream<A> streamA, Stream<B> streamB, BiFunction<A, B, C> zipper) {\n", | |
" final Iterator<A> iteratorA = streamA.iterator();\n", | |
" final Iterator<B> iteratorB = streamB.iterator();\n", | |
" final Iterator<C> iteratorC = new Iterator<C>() {\n", | |
" @Override\n", | |
" public boolean hasNext() {\n", | |
" return iteratorA.hasNext() && iteratorB.hasNext();\n", | |
" }\n", | |
"\n", | |
" @Override\n", | |
" public C next() {\n", | |
" return zipper.apply(iteratorA.next(), iteratorB.next());\n", | |
" }\n", | |
" };\n", | |
" final boolean parallel = streamA.isParallel() || streamB.isParallel();\n", | |
" return iteratorToFiniteStream(iteratorC, parallel);\n", | |
"}\n", | |
"\n", | |
"Arrays.toString(zip(List.of(1.1, 2.2, 3.3).stream(), List.of(10, 20, 30).stream(), (a, b)->a+b).toArray())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"-136.12842467859434\n", | |
"4.91359873927344\n", | |
"[4.91359873927344, -1.0]\n" | |
] | |
} | |
], | |
"source": [ | |
"double[] logSumExp(List<Double> a, List<Double> b) {\n", | |
" double amax = Collections.max(a);\n", | |
" double sum = zip(a.stream(), b.stream(), (ai, bi) -> Math.exp(ai - amax) * bi).reduce(0., Double::sum);\n", | |
" double sign = Math.signum(sum);\n", | |
" sum *= sign;\n", | |
" double out = Math.log(sum) + amax;\n", | |
" double[] ret = {out, sign};\n", | |
" return ret;\n", | |
"}\n", | |
"\n", | |
"System.out.println(Math.exp(3.3) + Math.exp(4.4) - Math.exp(5.5));\n", | |
"System.out.println(Math.log(Math.abs(Math.exp(3.3) + Math.exp(4.4) - Math.exp(5.5))));\n", | |
"System.out.println(Arrays.toString(logSumExp(List.of(3.3, 4.4, 5.5), List.of(1., 1., -1.))));" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-93.81500906943236" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"double subtractexp(double x, double y) {\n", | |
" var maxval = Math.max(x, y);\n", | |
" return Math.exp(maxval) * (Math.exp(x - maxval) - Math.exp(y - maxval));\n", | |
"}\n", | |
"subtractexp(4,5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"List<Double> meanVarToBeta(double mean, double v) {\n", | |
" double tmp = mean * (1 - mean) / v - 1;\n", | |
" double alpha = mean * tmp;\n", | |
" double beta = (1 - mean) * tmp;\n", | |
" return List.of( alpha, beta );\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import java.util.stream.IntStream;\n", | |
"EbisuInterface updateRecall(EbisuInterface prior, boolean result, double tnow) {\n", | |
" double alpha = prior.getAlpha();\n", | |
" double beta = prior.getBeta();\n", | |
" double dt = tnow / prior.getT();\n", | |
" double mu = 0;\n", | |
" double v = 0;\n", | |
" if (result) {\n", | |
" double same = Gamma.logGamma(alpha + beta + dt) - Gamma.logGamma(alpha + dt);\n", | |
" double muln = Gamma.logGamma(alpha + 2 * dt) - Gamma.logGamma(alpha + beta + 2 * dt) + same;\n", | |
" mu = Math.exp(muln);\n", | |
" v = subtractexp(same + Gamma.logGamma(alpha + 3 * dt) - Gamma.logGamma(alpha + beta + 3 * dt),\n", | |
" 2 * muln);\n", | |
" } else {\n", | |
" double[] s = IntStream.range(0, 4).mapToDouble(n -> Gamma.logGamma(alpha + n * dt) - Gamma.logGamma(alpha + beta + n * dt)).toArray();\n", | |
" mu = Math.expm1(s[2] - s[1]) / -Math.expm1(s[0] - s[1]);\n", | |
" \n", | |
" double[] n1 = logSumExp(List.of( s[1], s[0] ), List.of( 1., -1. ));\n", | |
" n1[0] += s[3];\n", | |
"\n", | |
" var n2 = logSumExp(List.of(s[0], s[1], s[2]), List.of(1., 1., -1.));\n", | |
" n2[0] += s[2];\n", | |
"\n", | |
" double[] n3 = {s[1] * 2, 1.};\n", | |
"\n", | |
" var d = logSumExp(List.of(s[1], s[0]), List.of(1., -1.));\n", | |
" d[0] *= 2;\n", | |
"\n", | |
" var n = logSumExp(List.of(n1[0], n2[0], n3[0]), List.of(n1[1], n2[1], -n3[1]));\n", | |
"\n", | |
" v = Math.exp(n[0] - d[0]);\n", | |
" }\n", | |
" List<Double> newAlphaBeta = meanVarToBeta(mu, v);\n", | |
" return new EbisuModel(newAlphaBeta.get(0), newAlphaBeta.get(1), tnow);\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.49999999999999994\n", | |
"Model(3.0000000000000133, 2.000000000000008, 2.0)\n", | |
"Model(1.9999999999999585, 2.99999999999994, 2.0)\n" | |
] | |
} | |
], | |
"source": [ | |
"EbisuModel m = new EbisuModel(2,2,2);\n", | |
"System.out.println(predictRecall(m, 2));\n", | |
"System.out.println(updateRecall(m, true, 2e-0));\n", | |
"System.out.println(updateRecall(m, false, 2e-0));" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Java", | |
"language": "java", | |
"name": "java" | |
}, | |
"language_info": { | |
"codemirror_mode": "java", | |
"file_extension": ".java", | |
"mimetype": "text/x-java-source", | |
"name": "Java", | |
"pygments_lexer": "java", | |
"version": "11.0.1+13-LTS" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment