Skip to content

Instantly share code, notes, and snippets.

@dkohlsdorf
Last active March 13, 2019 21:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dkohlsdorf/e566ddee790388e8a042 to your computer and use it in GitHub Desktop.
Save dkohlsdorf/e566ddee790388e8a042 to your computer and use it in GitHub Desktop.
Matrix Factorization
package prediction;
import java.util.Random;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
/**
 * Completes a partially observed matrix by low-rank matrix factorization.
 *
 * <p>The observed N x M matrix R is approximated as P * Q, with P of size N x K and Q of size
 * K x M. Training uses stochastic gradient descent on the squared reconstruction error of the
 * observed entries plus an L2 penalty on the factor entries. Unobserved entries are marked with
 * {@link #LATENT} ({@code Double.NaN}) and are skipped during training.
 *
 * @author Daniel Kohlsdorf
 */
public class MatrixFactorization implements Predictor {
    /**
     * Default learning rate (SGD step size).
     *
     * <p>NOTE(review): this is very small. With the 5000 iterations used in {@link #main} the
     * factors move little from their random initialization. Consider tuning it via
     * {@link #itemInterpolation(int, int, double, double, Random)}.
     */
    public static final double ALPHA = 0.000002;
    /**
     * Default L2 regularization weight applied to the factor entries.
     */
    public static final double BETA = 0.002;
    /**
     * Marker value for a missing (unobserved) entry in the input matrix.
     */
    public static final double LATENT = Double.NaN;

    /** Number of SGD passes between two evaluations of the objective. */
    private static final int CHECK_INTERVAL = 100;
    /** Training stops once the objective changes by less than this between two checks. */
    private static final double TOLERANCE = 1e-6;
    /** Upper bound (exclusive) of the uniform distribution used to initialize the factors. */
    private static final double INIT_SCALE = 0.1;

    /** Observed matrix as given to the constructor; read-only, so training can be repeated. */
    private final Array2DRowRealMatrix ratings;
    /** The observed matrix before training; the reconstruction P * Q after training. */
    private Array2DRowRealMatrix items;
    /** Left factor P (N x K); {@code null} until trained. */
    private Array2DRowRealMatrix pFactors;
    /** Right factor Q (K x M); {@code null} until trained. */
    private Array2DRowRealMatrix qFactors;

    /**
     * Creates a model for the given matrix.
     *
     * @param items matrix to complete; unobserved entries must be {@link #LATENT}. The array is
     *     copied.
     */
    public MatrixFactorization(double[][] items) {
        this.ratings = new Array2DRowRealMatrix(items);
        // Before training, getItems() exposes the observed matrix itself, as before.
        this.items = ratings;
    }

    /**
     * Factorizes the observed matrix with the default {@link #ALPHA}/{@link #BETA} and
     * unseeded random initialization. Then replaces {@link #getItems()} with P * Q.
     *
     * @param iterations maximum number of SGD passes over the observed entries
     * @param K number of latent factors (rank of the approximation); must be positive
     * @throws IllegalArgumentException if {@code K < 1}
     */
    public void itemInterpolation(int iterations, int K) {
        itemInterpolation(iterations, K, ALPHA, BETA, new Random());
    }

    /**
     * Factorizes the observed matrix and replaces {@link #getItems()} with P * Q.
     *
     * <p>Training always starts from the original observations, so calling this method again
     * retrains from scratch rather than fitting the previous reconstruction. Every
     * {@value #CHECK_INTERVAL} passes the regularized objective is evaluated. Training stops
     * early when it changes by less than {@value #TOLERANCE}.
     *
     * @param iterations maximum number of SGD passes over the observed entries
     * @param K number of latent factors (rank of the approximation); must be positive
     * @param alpha learning rate
     * @param beta L2 regularization weight
     * @param rng source of the random initial factor values; seed it for reproducible results
     * @throws IllegalArgumentException if {@code K < 1}
     * @throws NullPointerException if {@code rng} is null
     */
    public void itemInterpolation(int iterations, int K, double alpha, double beta, Random rng) {
        if (K < 1) {
            throw new IllegalArgumentException("K must be positive: " + K);
        }
        Array2DRowRealMatrix p = randomMatrix(ratings.getRowDimension(), K, rng);
        Array2DRowRealMatrix q = randomMatrix(K, ratings.getColumnDimension(), rng);

        double lastError = 0;
        for (int iter = 0; iter < iterations; iter++) {
            sgdPass(p, q, alpha, beta);
            if (iter % CHECK_INTERVAL == 0) {
                double error = objective(p, q, beta);
                if (Math.abs(lastError - error) < TOLERANCE) {
                    break;
                }
                lastError = error;
            }
        }
        pFactors = p;
        qFactors = q;
        items = p.multiply(q);
    }

    /** Returns a rows x cols matrix with entries drawn uniformly from [0, INIT_SCALE). */
    private static Array2DRowRealMatrix randomMatrix(int rows, int cols, Random rng) {
        Array2DRowRealMatrix m = new Array2DRowRealMatrix(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                m.setEntry(r, c, rng.nextDouble() * INIT_SCALE);
            }
        }
        return m;
    }

    /** Returns entry (i, j) of p * q, i.e. row i of p dotted with column j of q. */
    private static double predict(Array2DRowRealMatrix p, Array2DRowRealMatrix q, int i, int j) {
        double sum = 0;
        for (int f = 0; f < p.getColumnDimension(); f++) {
            sum += p.getEntry(i, f) * q.getEntry(f, j);
        }
        return sum;
    }

    /** Performs one SGD pass over every observed entry, updating p and q in place. */
    private void sgdPass(Array2DRowRealMatrix p, Array2DRowRealMatrix q,
            double alpha, double beta) {
        int rank = p.getColumnDimension();
        for (int i = 0; i < ratings.getRowDimension(); i++) {
            for (int j = 0; j < ratings.getColumnDimension(); j++) {
                double observed = ratings.getEntry(i, j);
                if (Double.isNaN(observed)) {
                    continue;
                }
                double err = observed - predict(p, q, i, j);
                for (int f = 0; f < rank; f++) {
                    // Capture both values first so the two updates form one simultaneous
                    // gradient step instead of Q seeing the freshly updated P.
                    double pif = p.getEntry(i, f);
                    double qfj = q.getEntry(f, j);
                    p.setEntry(i, f, pif + alpha * (2 * err * qfj - beta * pif));
                    q.setEntry(f, j, qfj + alpha * (2 * err * pif - beta * qfj));
                }
            }
        }
    }

    /**
     * Regularized training objective over the observed entries.
     *
     * <p>It sums, over each observed (i, j), the value
     * {@code (r_ij - p_i . q_j)^2 + (beta / 2) * sum_f (p_if^2 + q_fj^2)}.
     */
    private double objective(Array2DRowRealMatrix p, Array2DRowRealMatrix q, double beta) {
        int rank = p.getColumnDimension();
        double e = 0;
        for (int i = 0; i < ratings.getRowDimension(); i++) {
            for (int j = 0; j < ratings.getColumnDimension(); j++) {
                double observed = ratings.getEntry(i, j);
                if (Double.isNaN(observed)) {
                    continue;
                }
                double residual = observed - predict(p, q, i, j);
                e += residual * residual;
                for (int f = 0; f < rank; f++) {
                    double pif = p.getEntry(i, f);
                    double qfj = q.getEntry(f, j);
                    e += (beta / 2.0) * (pif * pif + qfj * qfj);
                }
            }
        }
        return e;
    }

    /**
     * Returns the observed matrix before training, or the completed matrix P * Q after
     * {@link #itemInterpolation(int, int)} has run.
     */
    public Array2DRowRealMatrix getItems() {
        return items;
    }

    /** Returns the left factor P (N x K), or {@code null} if the model has not been trained. */
    public Array2DRowRealMatrix getP() {
        return pFactors;
    }

    /** Returns the right factor Q (K x M), or {@code null} if the model has not been trained. */
    public Array2DRowRealMatrix getQ() {
        return qFactors;
    }

    /** Small demo: completes a 5 x 4 matrix with missing entries using two latent factors. */
    public static void main(String[] args) {
        MatrixFactorization rec = new MatrixFactorization(new double[][]{
            {5, 3, LATENT, 1},
            {4, LATENT, LATENT, 1},
            {1, 1, LATENT, 5},
            {1, LATENT, LATENT, 4},
            {LATENT, 1, 5, 4},
        });
        rec.itemInterpolation(5000, 2);
        System.out.println(rec.getItems());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment