Last active
September 13, 2015 15:17
-
-
Save sakapon/0ec5372e83e72ff24bd3 to your computer and use it in GitHub Desktop.
MLSample / ClusteringConsole
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using System.Linq; | |
namespace ClusteringConsole | |
{ | |
/// <summary>
/// Groups records into a fixed number of clusters by repeatedly assigning every record
/// to a cluster and then recomputing each cluster's centroid.
/// </summary>
/// <typeparam name="T">Type of the element carried by each record.</typeparam>
[DebuggerDisplay(@"\{Clusters: {ClustersNumber}, Iterations: {IterationsNumber}\}")]
public class KMeans<T>
{
    /// <summary>Number of clusters requested (at most this many are created).</summary>
    public int ClustersNumber { get; private set; }

    /// <summary>Upper bound on the number of assign/recompute passes run by <see cref="Train"/>.</summary>
    public int IterationsNumber { get; private set; }

    /// <summary>Stores the clustering parameters; no work is done until <see cref="Train"/> is called.</summary>
    /// <param name="clustersNumber">Number of clusters to seed.</param>
    /// <param name="iterationsNumber">Maximum number of training passes.</param>
    public KMeans(int clustersNumber, int iterationsNumber)
    {
        ClustersNumber = clustersNumber;
        IterationsNumber = iterationsNumber;
    }

    /// <summary>
    /// Clusters the given records and returns, for each cluster id, the records assigned to it.
    /// </summary>
    /// <param name="records">Records to cluster.</param>
    /// <returns>A dictionary keyed by cluster id whose values are the records of that cluster.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="records"/> is null.</exception>
    public Dictionary<int, Record<T>[]> Train(Record<T>[] records)
    {
        if (records == null)
        {
            throw new ArgumentNullException("records");
        }

        var clusters = InitializeClusters(ClustersNumber, records);
        for (var i = 0; i < IterationsNumber; i++)
        {
            // When a pass leaves every centroid unchanged, the next pass would start from the
            // same centroids and the same records, so the remaining passes are skipped.
            // NOTE(review): this assumes FirstToMin picks deterministically - confirm.
            if (!TrainOnce(clusters, records))
            {
                break;
            }
        }
        return clusters.ToDictionary(c => c.Id, c => c.Records.ToArray());
    }

    // Seeds one cluster per picked record, using that record's features as the initial centroid.
    // RandomUtility.ShuffleRange is defined elsewhere; presumably it yields the indices
    // 0..records.Length-1 in random order - confirm against its definition.
    static Cluster<T>[] InitializeClusters(int clustersNumber, Record<T>[] records)
    {
        return RandomUtility.ShuffleRange(records.Length)
            .Take(clustersNumber)
            .Select(i => records[i])
            .Select((r, i) => new Cluster<T>(i, r.Features))
            .ToArray();
    }

    // Runs one pass: clear the clusters, reassign all records, recompute centroids.
    // Returns true when at least one centroid differs from its value before the pass.
    static bool TrainOnce(Cluster<T>[] clusters, Record<T>[] records)
    {
        // Only the array references are kept. This is a valid snapshot as long as
        // Cluster<T>.TuneCentroid assigns a new array instead of overwriting the old one.
        var previousCentroids = Array.ConvertAll(clusters, c => c.Centroid);

        Array.ForEach(clusters, c => c.Records.Clear());
        AssignRecords(clusters, records);
        Array.ForEach(clusters, c => c.TuneCentroid());

        for (var i = 0; i < clusters.Length; i++)
        {
            if (!clusters[i].Centroid.SequenceEqual(previousCentroids[i]))
            {
                return true;
            }
        }
        return false;
    }

    // Adds each record to the cluster selected by FirstToMin over the centroid-to-record distance.
    // FirstToMin is defined elsewhere; presumably it returns the first cluster with the smallest distance.
    static void AssignRecords(Cluster<T>[] clusters, IEnumerable<Record<T>> records)
    {
        foreach (var record in records)
        {
            var cluster = clusters.FirstToMin(c => FeaturesHelper.GetDistance(c.Centroid, record.Features));
            cluster.Records.Add(record);
        }
    }
}
/// <summary>
/// Pairs an element with the numeric feature vector used to cluster it.
/// </summary>
/// <typeparam name="T">Type of the element.</typeparam>
[DebuggerDisplay(@"\{{ToDebugString()}\}")]
public struct Record<T>
{
    /// <summary>The item this record describes.</summary>
    public T Element { get; set; }

    /// <summary>Feature values of the element.</summary>
    public double[] Features { get; set; }

    // Referenced by the DebuggerDisplay attribute on this type.
    private string ToDebugString()
    {
        var featuresText = FeaturesHelper.ToString(Features);
        return string.Format("{0}: {1}", Element, featuresText);
    }
}
/// <summary>
/// A cluster: an id, the current centroid, and the records currently assigned to it.
/// </summary>
/// <typeparam name="T">Type of the element carried by each record.</typeparam>
[DebuggerDisplay(@"\{{ToDebugString()}\}")]
class Cluster<T>
{
    /// <summary>Identifier given at construction.</summary>
    public int Id { get; private set; }

    /// <summary>Current centroid; replaced by <see cref="TuneCentroid"/>.</summary>
    public double[] Centroid { get; private set; }

    /// <summary>Records currently assigned to this cluster.</summary>
    public List<Record<T>> Records { get; private set; }

    /// <summary>Creates a cluster with no records.</summary>
    /// <param name="id">Identifier of the cluster.</param>
    /// <param name="centroid">Initial centroid (stored by reference, not copied).</param>
    public Cluster(int id, double[] centroid)
    {
        Records = new List<Record<T>>();
        Centroid = centroid;
        Id = id;
    }

    /// <summary>
    /// Sets the centroid to the per-component average of the assigned records' features.
    /// Does nothing when the cluster has no records.
    /// </summary>
    public void TuneCentroid()
    {
        if (Records.Count > 0)
        {
            var means = new double[Centroid.Length];
            for (var axis = 0; axis < means.Length; axis++)
            {
                var component = axis;
                means[axis] = Records.Average(r => r.Features[component]);
            }

            // Assign only after every component is computed, so a failure leaves the old centroid in place.
            Centroid = means;
        }
    }

    // Referenced by the DebuggerDisplay attribute on this type.
    private string ToDebugString()
    {
        var centroidText = FeaturesHelper.ToString(Centroid);
        return string.Format("{0}: {1}: {2} records", Id, centroidText, Records.Count);
    }
}
/// <summary>
/// Helper functions for feature vectors stored as <c>double[]</c>.
/// </summary>
public static class FeaturesHelper
{
    /// <summary>
    /// Returns the square root of the sum of squared component differences between two vectors.
    /// Components are paired with <c>Zip</c>, so trailing components of the longer vector are ignored.
    /// </summary>
    /// <param name="p1">First vector.</param>
    /// <param name="p2">Second vector.</param>
    /// <returns>The Euclidean distance over the paired components.</returns>
    public static double GetDistance(double[] p1, double[] p2)
    {
        var squaredDeltas = p1.Zip(p2, (a, b) =>
        {
            var delta = a - b;
            return delta * delta;
        });
        return Math.Sqrt(squaredDeltas.Sum());
    }

    /// <summary>Returns the square root of the sum of squared components of a vector.</summary>
    /// <param name="p">The vector.</param>
    /// <returns>The Euclidean norm; 0 for an empty vector.</returns>
    public static double GetNorm(double[] p)
    {
        var sumOfSquares = p.Aggregate(0.0, (total, x) => total + x * x);
        return Math.Sqrt(sumOfSquares);
    }

    /// <summary>Formats each component with the "F3" format and joins them with ", ".</summary>
    /// <param name="p">The vector to format.</param>
    /// <returns>The joined text; an empty string for an empty vector.</returns>
    public static string ToString(double[] p)
    {
        var formatted = from x in p select x.ToString("F3");
        return string.Join(", ", formatted);
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Drawing; | |
using System.Linq; | |
using System.Reflection; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace ClusteringConsole | |
{ | |
/// <summary>
/// Console entry point: clusters the static <see cref="Color"/> properties by their R, G, B values
/// and prints each cluster id followed by the names of its colors.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        ClusterColors();
    }

    // Collects the opaque colors exposed as public static Color properties,
    // trains a KMeans<Color>(20, 50) on them, and writes the clusters to the console.
    static void ClusterColors()
    {
        var namedColors =
            (from property in typeof(Color).GetProperties(BindingFlags.Public | BindingFlags.Static)
             where property.PropertyType == typeof(Color)
             let color = (Color)property.GetValue(null)
             where color.A == 255 // Exclude Transparent.
             select color).ToArray();

        var records = Array.ConvertAll(
            namedColors,
            color => new Record<Color> { Element = color, Features = new double[] { color.R, color.G, color.B } });

        var trained = new KMeans<Color>(20, 50).Train(records);

        foreach (var pair in trained)
        {
            Console.WriteLine(pair.Key);

            var names = pair.Value.Select(r => r.Element.Name);
            Console.WriteLine(string.Join(", ", names));
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment