Skip to content

Instantly share code, notes, and snippets.

@fasiha
Created November 11, 2018 04:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fasiha/e9c59ceca412c13d6e7e84b5ba806b0b to your computer and use it in GitHub Desktop.
Save fasiha/e9c59ceca412c13d6e7e84b5ba806b0b to your computer and use it in GitHub Desktop.
A very bare-bones Java implementation of Ebisu (https://fasiha.github.io/ebisu/).
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9870985778947343\n"
]
}
],
"source": [
"%maven org.apache.commons:commons-math3:3.6.1\n",
"import org.apache.commons.math3.special.Gamma;\n",
"\n",
"// Smoke test: confirm commons-math3 resolved by printing ln(Gamma(3.3)).\n",
"System.out.println(Double.toString(Gamma.logGamma(3.3)));"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import java.lang.Math;\n",
"\n",
"/** Read-only view of an Ebisu model: a Beta(alpha, beta) prior on recall probability at time t. */\n",
"public interface EbisuInterface {\n",
"    /** Returns the alpha parameter of the Beta prior. */\n",
"    double getAlpha();\n",
"\n",
"    /** Returns the beta parameter of the Beta prior. */\n",
"    double getBeta();\n",
"\n",
"    /** Returns the time at which the Beta prior applies; elapsed times are divided by it. */\n",
"    double getT();\n",
"}\n",
"\n",
"/** Immutable (alpha, beta, t) Ebisu model. */\n",
"public class EbisuModel implements EbisuInterface {\n",
"    // final: a model is a value; updates produce a new EbisuModel instead of mutating this one.\n",
"    private final double a;\n",
"    private final double b;\n",
"    private final double t;\n",
"\n",
"    /**\n",
"     * @param alpha alpha parameter of the Beta prior\n",
"     * @param beta  beta parameter of the Beta prior\n",
"     * @param time  time at which the Beta prior applies\n",
"     */\n",
"    public EbisuModel(double alpha, double beta, double time) {\n",
"        this.a = alpha;\n",
"        this.b = beta;\n",
"        this.t = time;\n",
"    }\n",
"\n",
"    @Override\n",
"    public double getAlpha() { return this.a; }\n",
"\n",
"    @Override\n",
"    public double getBeta() { return this.b; }\n",
"\n",
"    @Override\n",
"    public double getT() { return this.t; }\n",
"\n",
"    @Override\n",
"    public String toString() { return \"Model(\" + this.a + \", \" + this.b + \", \" + this.t + \")\"; }\n",
"}\n",
"\n",
"/**\n",
" * Returns the expected recall probability at time tnow.\n",
" *\n",
" * With dt = tnow / t and p ~ Beta(alpha, beta), this is E[p^dt] = B(alpha + dt, beta) / B(alpha, beta).\n",
" * It is evaluated as a difference of log-Gamma terms, then exponentiated, to avoid overflow.\n",
" *\n",
" * @param prior the model to query\n",
" * @param tnow  time of the query, in the same units as prior.getT()\n",
" * @return expected recall probability at tnow\n",
" */\n",
"double predictRecall(EbisuInterface prior, double tnow) {\n",
"    double alpha = prior.getAlpha();\n",
"    double beta = prior.getBeta();\n",
"    double dt = tnow / prior.getT();\n",
"    // The Gamma(beta) terms of the two Beta functions cancel, so they are omitted.\n",
"    return Math.exp(Gamma.logGamma(alpha + dt) - Gamma.logGamma(alpha + beta + dt) - (Gamma.logGamma(alpha) - Gamma.logGamma(alpha + beta)));\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[11.1, 22.2, 33.3]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import java.util.stream.Stream;\n",
"import java.util.stream.StreamSupport;\n",
"import java.util.function.BiFunction;\n",
"\n",
"// Via https://stackoverflow.com/a/32342172/500207\n",
"/**\n",
" * Wraps an iterator as a Stream of unknown size. The iterator is consumed lazily,\n",
" * so the resulting stream can be traversed only once.\n",
" */\n",
"<T> Stream<T> iteratorToFiniteStream(Iterator<T> iterator, boolean parallel) {\n",
"    Spliterator<T> source = Spliterators.spliteratorUnknownSize(iterator, 0);\n",
"    return StreamSupport.stream(source, parallel);\n",
"}\n",
"\n",
"/**\n",
" * Combines two streams element-by-element with zipper, stopping when either stream runs out.\n",
" * The result is parallel if either input stream is parallel.\n",
" */\n",
"<A, B, C> Stream<C> zip(Stream<A> streamA, Stream<B> streamB, BiFunction<A, B, C> zipper) {\n",
"    Iterator<A> left = streamA.iterator();\n",
"    Iterator<B> right = streamB.iterator();\n",
"    Iterator<C> zipped = new Iterator<C>() {\n",
"        @Override\n",
"        public boolean hasNext() {\n",
"            return left.hasNext() && right.hasNext();\n",
"        }\n",
"\n",
"        @Override\n",
"        public C next() {\n",
"            A fromLeft = left.next();\n",
"            B fromRight = right.next();\n",
"            return zipper.apply(fromLeft, fromRight);\n",
"        }\n",
"    };\n",
"    return iteratorToFiniteStream(zipped, streamA.isParallel() || streamB.isParallel());\n",
"}\n",
"\n",
"Arrays.toString(zip(List.of(1.1, 2.2, 3.3).stream(), List.of(10, 20, 30).stream(), (x, y) -> x + y).toArray())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-136.12842467859434\n",
"4.91359873927344\n",
"[4.91359873927344, -1.0]\n"
]
}
],
"source": [
"/**\n",
" * Computes log(|sum_i b_i * exp(a_i)|) without overflow, together with the sign of the sum.\n",
" * Modeled on scipy.special.logsumexp(a, b=b, return_sign=True).\n",
" *\n",
" * @param a exponents a_i (must be non-empty)\n",
" * @param b weights b_i, possibly negative; must have the same length as a\n",
" * @return {log|sum|, sign} where sign is -1, 0 or 1; log|sum| is -Infinity when the sum is 0\n",
" * @throws IllegalArgumentException if a and b differ in length\n",
" * @throws NoSuchElementException if a is empty\n",
" */\n",
"double[] logSumExp(List<Double> a, List<Double> b) {\n",
"    if (a.size() != b.size()) {\n",
"        throw new IllegalArgumentException(\"a and b must have the same length: \" + a.size() + \" vs \" + b.size());\n",
"    }\n",
"    double amax = Collections.max(a);\n",
"    // Shift by the largest exponent for stability. Like scipy, fall back to 0 when it is not finite:\n",
"    // otherwise e.g. -Infinity - (-Infinity) = NaN would poison an all-(-Infinity) input.\n",
"    if (!Double.isFinite(amax)) {\n",
"        amax = 0;\n",
"    }\n",
"    double sum = 0.;\n",
"    for (int i = 0; i < a.size(); i++) {\n",
"        sum += Math.exp(a.get(i) - amax) * b.get(i);\n",
"    }\n",
"    double sign = Math.signum(sum);\n",
"    sum *= sign; // |sum|; a zero sum stays zero so the log below is -Infinity\n",
"    return new double[] {Math.log(sum) + amax, sign};\n",
"}\n",
"\n",
"System.out.println(Math.exp(3.3) + Math.exp(4.4) - Math.exp(5.5));\n",
"System.out.println(Math.log(Math.abs(Math.exp(3.3) + Math.exp(4.4) - Math.exp(5.5))));\n",
"System.out.println(Arrays.toString(logSumExp(List.of(3.3, 4.4, 5.5), List.of(1., 1., -1.))));"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-93.81500906943236"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"/**\n",
" * Returns exp(x) - exp(y). The common factor exp(max(x, y)) is pulled out so that,\n",
" * for finite inputs, the two inner exponentials lie in [0, 1].\n",
" */\n",
"double subtractexp(double x, double y) {\n",
"    double m = Math.max(x, y);\n",
"    double scale = Math.exp(m);\n",
"    double diff = Math.exp(x - m) - Math.exp(y - m);\n",
"    return scale * diff;\n",
"}\n",
"subtractexp(4, 5)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"/**\n",
" * Method-of-moments fit: returns [alpha, beta] of the Beta distribution whose mean is\n",
" * mean and whose variance is v.\n",
" */\n",
"List<Double> meanVarToBeta(double mean, double v) {\n",
"    // alpha + beta = mean * (1 - mean) / v - 1 for a Beta distribution.\n",
"    double total = mean * (1 - mean) / v - 1;\n",
"    return List.of(mean * total, (1 - mean) * total);\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import java.util.stream.IntStream;\n",
"\n",
"/**\n",
" * Updates an Ebisu model with a quiz result observed at time tnow.\n",
" *\n",
" * The posterior on recall at tnow is moment-matched to a Beta distribution (mean mu,\n",
" * variance v), and the returned model applies at time tnow. Gamma-function ratios are\n",
" * computed in log space.\n",
" *\n",
" * @param prior  the model before the quiz\n",
" * @param result quiz outcome (true = success, per Ebisu's convention)\n",
" * @param tnow   time of the quiz, in the same units as prior.getT()\n",
" * @return a new model (alpha', beta', tnow)\n",
" */\n",
"EbisuInterface updateRecall(EbisuInterface prior, boolean result, double tnow) {\n",
"    double alpha = prior.getAlpha();\n",
"    double beta = prior.getBeta();\n",
"    double dt = tnow / prior.getT();\n",
"    double mu;\n",
"    double v;\n",
"    if (result) {\n",
"        double same = Gamma.logGamma(alpha + beta + dt) - Gamma.logGamma(alpha + dt);\n",
"        double muln = Gamma.logGamma(alpha + 2 * dt) - Gamma.logGamma(alpha + beta + 2 * dt) + same;\n",
"        mu = Math.exp(muln);\n",
"        // Variance = E[X^2] - mu^2, both terms given as logs; subtractexp keeps this stable.\n",
"        v = subtractexp(same + Gamma.logGamma(alpha + 3 * dt) - Gamma.logGamma(alpha + beta + 3 * dt),\n",
"                2 * muln);\n",
"    } else {\n",
"        // s[n] = log(Gamma(alpha + n*dt) / Gamma(alpha + beta + n*dt)) for n = 0..3.\n",
"        double[] s = IntStream.range(0, 4).mapToDouble(n -> Gamma.logGamma(alpha + n * dt) - Gamma.logGamma(alpha + beta + n * dt)).toArray();\n",
"        mu = Math.expm1(s[2] - s[1]) / -Math.expm1(s[0] - s[1]);\n",
"\n",
"        // Variance terms are held as {log|value|, sign} pairs from logSumExp.\n",
"        double[] n1 = logSumExp(List.of(s[1], s[0]), List.of(1., -1.));\n",
"        n1[0] += s[3];\n",
"\n",
"        double[] n2 = logSumExp(List.of(s[0], s[1], s[2]), List.of(1., 1., -1.));\n",
"        n2[0] += s[2];\n",
"\n",
"        double[] n3 = {s[1] * 2, 1.};\n",
"\n",
"        // Denominator: (exp(s[1]) - exp(s[0]))^2, as a log.\n",
"        double[] d = logSumExp(List.of(s[1], s[0]), List.of(1., -1.));\n",
"        d[0] *= 2;\n",
"\n",
"        // Numerator: n1 + n2 - n3 with their signs.\n",
"        double[] n = logSumExp(List.of(n1[0], n2[0], n3[0]), List.of(n1[1], n2[1], -n3[1]));\n",
"\n",
"        // NOTE(review): only the magnitude n[0] is used; the numerator's sign n[1] is not checked.\n",
"        v = Math.exp(n[0] - d[0]);\n",
"    }\n",
"    List<Double> newAlphaBeta = meanVarToBeta(mu, v);\n",
"    return new EbisuModel(newAlphaBeta.get(0), newAlphaBeta.get(1), tnow);\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.49999999999999994\n",
"Model(3.0000000000000133, 2.000000000000008, 2.0)\n",
"Model(1.9999999999999585, 2.99999999999994, 2.0)\n"
]
}
],
"source": [
"// Demo: a Beta(2, 2) prior at t = 2, queried and then updated at tnow = 2.\n",
"var m = new EbisuModel(2, 2, 2);\n",
"System.out.println(predictRecall(m, 2));\n",
"for (boolean passed : new boolean[] {true, false}) {\n",
"    System.out.println(updateRecall(m, passed, 2.0));\n",
"}"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Java",
"language": "java",
"name": "java"
},
"language_info": {
"codemirror_mode": "java",
"file_extension": ".java",
"mimetype": "text/x-java-source",
"name": "Java",
"pygments_lexer": "java",
"version": "11.0.1+13-LTS"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment