@mfagerlund
Last active December 13, 2017 19:11
C# Logistic regression + Factorization machines + SGD
using System;
using System.Globalization;
using System.Linq;
using NUnit.Framework;

namespace Boltzman.Tests.FactorizationMachines
{
    [TestFixture]
    public class FactorizationMachineTests
    {
        // This is a fairly straight translation of this:
        // https://gist.github.com/kalaidin/9ea737ad771fcf073e57
        /*
        SGD for logistic loss + factorization machines
        The code follows this paper:
        [1] http://www.ics.uci.edu/~smyth/courses/cs277/papers/factorization_machines_with_libFM.pdf
        */
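        // The degree-2 FM model used here predicts
        //   y(x) = w0 + sum_i w_i * x_i + sum_{i<j} <v_i, v_j> * x_i * x_j
        // where each feature i has a k-dimensional factor vector v_i (k = FactorCount).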
        [Test]
        public void Predict()
        {
            CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
            float[][] trainingSamples =
            {
                new float[] {2, 1, 1, 0, 1, 1, 0, 0, 0},
                new float[] {2, 1, 1, 1, 0, 0, 1, 0, 0},
                new float[] {2, 1, 1, 0, 0, 0, 0, 1, 0},
                new float[] {2, 1, 0, 0, 0, 0, 0, 0, 1},
                new float[] {4, 2, 0, 0, 0, 0, 0, 0, 1},
                new float[] {4, 2, 0, 0, 0, 0, 0, 0, 1},
                new float[] {4, 2, 0, 0, 0, 0, 0, 0, 1}
            };
            float[] results = { 1, 1, -1, -1, 1, 1, 1 };
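            // Seven samples with nine features each; the targets are +/-1 to
            // match the logistic loss used in training.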
            SimpleFactorizationMachine fm =
                new SimpleFactorizationMachine(
                    epochs: 2000,
                    factorCount: 3);
            fm.OnEpoch =
                () =>
                {
                    //Console.WriteLine($"{fm.CurrentEpoch}: {SimpleFactorizationMachine.LogLoss(fm.Predict(trainingSamples), results).SJoin()}");
                    //Console.WriteLine(fm.W0);
                    if (fm.CurrentEpoch % 100 == 0 || fm.IsLastEpoch)
                    {
                        Console.WriteLine($"{fm.CurrentEpoch,4}: {fm.Rmse:0.00000}");
                    }
                };
            fm.Fit(trainingSamples, results);
            float[] scaled = SimpleFactorizationMachine.Scale(fm.Predict(trainingSamples));
            Console.WriteLine($"Pred: {string.Join(",", scaled)}");
            float sum = 0;
            for (int index = 0; index < scaled.Length; index++)
            {
                float x = scaled[index];
                float y = results[index];
                float v = x - y;
                sum += v * v;
            }
            Console.WriteLine($"RMSE: {Math.Sqrt(sum / results.Length):0.00000}");
        }
        public class SimpleFactorizationMachine
        {
            private readonly Random _random;

            public SimpleFactorizationMachine(
                int epochs = 1000,
                float learningRate = 0.01f,
                int factorCount = 3,
                float initialWeightSpread = 0.001f)
            {
                _random = new Random(99);
                LearningRate = learningRate;
                L2Regularization = 0.01f;
                InitialWeightSpread = initialWeightSpread;
                Epochs = epochs;
                FactorCount = factorCount;
            }

            // Number of features
            public int FeatureCount { get; set; }
            // Number of factors
            public int FactorCount { get; set; }
            public int Epochs { get; set; }
            public float InitialWeightSpread { get; set; }
            public float L2Regularization { get; set; }
            public float LearningRate { get; set; }
            public float W0 { get; set; }
            public float[] W { get; set; }
            public float[][] V { get; set; }
            public Action OnEpoch { get; set; }
            public int CurrentEpoch { get; set; }
            public float Rmse { get; set; }
            public bool IsLastEpoch => CurrentEpoch >= Epochs - 1;

            public void Fit(float[][] samples, float[] values)
            {
                FeatureCount = samples.First().Length;
                W = new float[FeatureCount];
                V = new float[FeatureCount][];
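                // Initialize the factor vectors with small random values; an
                // all-zero start would leave every pairwise-interaction gradient
                // at zero, so the factors could never move.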
                for (int p = 0; p < FeatureCount; p++)
                {
                    float[] arr = new float[FactorCount];
                    for (int k = 0; k < FactorCount; k++)
                    {
                        arr[k] = InitialWeightSpread * ((float)_random.NextDouble() - 0.5f) * 2;
                    }
                    V[p] = arr;
                }
                for (int epoch = 0; epoch < Epochs; epoch++)
                {
                    CurrentEpoch = epoch;
                    Rmse = 0;
                    for (int sampleIndex = 0; sampleIndex < samples.Length; sampleIndex++)
                    {
                        float[] x = samples[sampleIndex];
                        float p = Predict(x);
                        float y = values[sampleIndex];
                        // delta is the derivative of the logistic loss
                        // L = log(1 + exp(-y * p)) with respect to the raw score p:
                        // dL/dp = -y * Sigmoid(-y * p) = y * (Sigmoid(y * p) - 1).
                        float delta = y * (Sigmoid(y * p) - 1);
                        // Track RMSE of the scaled prediction (in (-1, 1)) against the +/-1 target.
                        float error = Sigmoid(p) * 2 - 1 - y;
                        Rmse += error * error;
                        W0 -= LearningRate * (delta + 2 * L2Regularization * W0);
                        for (int featureIndex = 0; featureIndex < FeatureCount; featureIndex++)
                        {
                            W[featureIndex] -= LearningRate * (delta * x[featureIndex] + 2 * L2Regularization * W[featureIndex]);
                            var v = V[featureIndex];
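                            // Note: recomputing the dot product below for every feature
                            // makes this update O(n^2 * k) per sample; [1] caches
                            // sum_i v_{i,f} * x_i once per sample to get O(n * k).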
                            for (int j = 0; j < FactorCount; j++)
                            {
                                float dot = 0;
                                for (int pi = 0; pi < FeatureCount; pi++)
                                {
                                    dot += V[pi][j] * x[pi];
                                }
                                float h = x[featureIndex] * (dot - x[featureIndex] * v[j]);
                                v[j] -= LearningRate * (delta * h + 2 * L2Regularization * v[j]);
                            }
                        }
                    }
                    Rmse = (float)Math.Sqrt(Rmse / samples.Length);
                    OnEpoch?.Invoke();
                }
            }
            public float[] Predict(float[][] x)
            {
                return x.Select(Predict).ToArray();
            }

            public float Predict(float[] x)
            {
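                // Pairwise interactions in O(k * n) via the identity from Lemma 3.1 of [1]:
                // sum_{i<j} <v_i, v_j> x_i x_j
                //   = 0.5 * sum_f ((sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2)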
                float res = W0 + DotProduct(W, x);
                float fRes = 0;
                for (int f = 0; f < FactorCount; f++)
                {
                    float s = 0;
                    float sSquared = 0;
                    for (int j = 0; j < FeatureCount; j++)
                    {
                        float el = V[j][f] * x[j];
                        s += el;
                        sSquared += el * el;
                    }
                    fRes += 0.5f * (s * s - sSquared);
                }
                res += fRes;
                return res;
            }
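            // Maps raw scores onto (-1, 1) via the sigmoid so they are directly
            // comparable with the +/-1 targets.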
            public static float[] Scale(float[] x)
            {
                return x.Select(y => Sigmoid(y) * 2 - 1).ToArray();
            }

            public static float[] Sigmoid(float[] x)
            {
                return x.Select(Sigmoid).ToArray();
            }

            public static float[] LogLoss(float[] x, float[] y)
            {
                float[] results = new float[y.Length];
                for (int index = 0; index < y.Length; index++)
                {
                    results[index] = LogLoss(x[index], y[index]);
                }
                return results;
            }
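            // Short-circuits far out in the tails, where the float sigmoid has
            // effectively saturated to 0 or 1.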
            public static float Sigmoid(float x)
            {
                if (x > 50)
                {
                    return 1;
                }
                if (x < -50)
                {
                    return 0;
                }
                return 1 / (1 + (float)Math.Exp(-x));
            }
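            // Numerically stable logistic loss: -log(Sigmoid(z)) = log(1 + exp(-z))
            // for z = y * p, approximated by exp(-z) for large positive z and by -z
            // for large negative z.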
            public static float LogLoss(float y, float p)
            {
                float z = y * p;
                if (z > 18)
                {
                    return (float)Math.Exp(-z);
                }
                if (z < -18)
                {
                    return -z;
                }
                return -(float)Math.Log(Sigmoid(z));
            }

            private static float DotProduct(float[] a, float[] b)
            {
                if (a.Length != b.Length)
                {
                    throw new InvalidOperationException("Arrays are of different sizes");
                }
                float total = 0;
                for (int index = 0; index < a.Length; index++)
                {
                    total += a[index] * b[index];
                }
                return total;
            }
        }
    }
}
mfagerlund commented Jun 17, 2016

This is a minimal implementation of a Factorization Machine in C#. It's an almost straight translation of https://gist.github.com/kalaidin/9ea737ad771fcf073e57.

0: 0.98412
100: 0.72338
200: 0.62091
300: 0.41326
400: 0.23960
500: 0.16100
600: 0.12571
700: 0.10831
800: 0.09896
900: 0.09357
1000: 0.09028
1100: 0.08816
1200: 0.08673
1300: 0.08572
1400: 0.08497
1500: 0.08440
1600: 0.08395
1700: 0.08358
1800: 0.08328
1900: 0.08303
1999: 0.08282
Pred: 0.9796363,0.9795824,-0.9317688,-0.7985968,0.9767648,0.9767648,0.9767648
RMSE: 0.08252
