Created
March 11, 2014 09:53
-
-
Save adrianseeley/9482657 to your computer and use it in GitHub Desktop.
GATO.KNNSATURATION - very inefficiently explores KNN saturation rates on the MNIST OCR set, and draws pretty KNN waveform graphs too.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
using System.IO; | |
using System.Drawing; | |
using System.Diagnostics; | |
namespace GATO.KNNSATURATION | |
{ | |
class Program
{
    /// <summary>
    /// Sweeps training-set size, saturation passes, saturation sway and K. For every
    /// combination it scores a KNN classifier against the held-out rows, saves a
    /// waveform image to ./img/ and prints the result. All results are written to
    /// Results.txt at the end.
    /// </summary>
    static void Main(string[] args)
    {
        Stopwatch TotalRegressionTime = new Stopwatch();
        TotalRegressionTime.Start();
        int MaxSamples = 1001;
        int MaxSaturationPasses = 10;
        double MaxSaturationSway = 0.2;
        int MaxKNN = 20;
        // Sway is swept on an integer grid of 1/SwayResolution steps. Repeatedly adding
        // 0.01 to a double drifts (0.060000000000000005, ...), and that noise leaked into
        // the loop bound and into every generated filename.
        int SwayResolution = 100;
        int SwaySteps = (int)Math.Round(MaxSaturationSway * SwayResolution);
        // Every row after the first MaxSamples is held out for validation. Training sets
        // below are prefixes of at most MaxSamples rows, so the two never overlap.
        List<TrainingCase> ValidationCases = TrainingCase.ReadFromDisc();
        ValidationCases.RemoveRange(0, MaxSamples);
        List<Result> Results = new List<Result>();
        int WaveformID = 0;
        Directory.CreateDirectory("./img/");
        for (int CurrentSamples = 1; CurrentSamples <= MaxSamples; CurrentSamples += 100)
        {
            for (int CurrentSaturationPasses = 0; CurrentSaturationPasses < MaxSaturationPasses; CurrentSaturationPasses++)
            {
                for (int SwayStep = 0; SwayStep < SwaySteps; SwayStep++)
                {
                    // With zero passes no clones are generated, so sway has no effect:
                    // only the sway == 0 configuration is run.
                    if (CurrentSaturationPasses == 0 && SwayStep > 0) break;
                    double CurrentSaturationSway = (double)SwayStep / SwayResolution;
                    for (int CurrentKNN = 1; CurrentKNN < MaxKNN; CurrentKNN++)
                    {
                        Stopwatch CurrentRegressionTime = new Stopwatch();
                        CurrentRegressionTime.Start();
                        List<TrainingCase> TrainingCases = TrainingCase.ReadFromDisc(CurrentSamples);
                        TrainingCases = Saturate(TrainingCases, CurrentSaturationPasses, CurrentSaturationSway);
                        int Errors = Fitness(TrainingCases, ValidationCases, CurrentKNN);
                        String Filename = "./img/KNNWaveform" + WaveformID + ".Samples" + CurrentSamples + ".SatPasses" + CurrentSaturationPasses + ".SatSway" + CurrentSaturationSway + ".KNN" + CurrentKNN + ".Errors" + Errors + "of" + (ValidationCases.Count) + ".png";
                        // NOTE(review): CurrentSaturationSway is a fraction (0.01 steps), yet the label says "%".
                        String Annotation = "KNNWaveform " + WaveformID + " on " + CurrentSamples + " Samples, with " + CurrentSaturationPasses + " Saturation Passes at " + CurrentSaturationSway + "% Sway, Scored " + Errors + " Errors out of " + ValidationCases.Count + " Validation Cases (" + (((float)Errors / (float)ValidationCases.Count) * 100) + "%) Using K=" + CurrentKNN;
                        DrawKNNWaveform(TrainingCases, Filename, Annotation);
                        CurrentRegressionTime.Stop();
                        Results.Add(new Result(WaveformID, CurrentSamples, CurrentSaturationPasses, CurrentSaturationSway, CurrentKNN, Errors, ValidationCases.Count, CurrentRegressionTime.ElapsedMilliseconds));
                        Console.WriteLine(Results[Results.Count - 1].ToHumanString());
                        WaveformID++;
                    }
                }
            }
        }
        TotalRegressionTime.Stop();
        Result.Write("Results.txt", Results, TotalRegressionTime.ElapsedMilliseconds);
    }

    /// <summary>
    /// Appends <paramref name="Passes"/> jittered clones (TrainingCase.SaturationClone)
    /// of every case to <paramref name="TrainingCases"/>.
    /// </summary>
    /// <returns>The same list instance, now mutated.</returns>
    static List<TrainingCase> Saturate(List<TrainingCase> TrainingCases, int Passes, double Sway)
    {
        // Walk backwards from the last *original* case. Clones are appended after it,
        // so they are never themselves cloned.
        for (int t = TrainingCases.Count - 1; t >= 0; t--)
        {
            for (int p = 0; p < Passes; p++)
            {
                TrainingCases.Add(TrainingCases[t].SaturationClone(Sway));
            }
        }
        return TrainingCases;
    }

    /// <summary>
    /// Renders every case's inputs as translucent marks coloured by class. The x axis is
    /// the input index and the y axis is the value normalised over the whole set. The
    /// image is saved, with <paramref name="Annotation"/> in the top-left corner, to
    /// <paramref name="Filename"/>.
    /// Side effect: overwrites NormalizedInputs on every case passed in.
    /// </summary>
    static void DrawKNNWaveform(List<TrainingCase> TrainingCases, String Filename, String Annotation)
    {
        // Global input range plus the distinct classes. Classes are sorted so each one
        // maps to a stable palette slot.
        double HighestInput = double.NegativeInfinity;
        double LowestInput = double.PositiveInfinity;
        List<String> Classes = new List<String>();
        foreach (TrainingCase Sample in TrainingCases)
        {
            // A case with no inputs draws nothing, so it does not claim a colour.
            if (Sample.Inputs.Length > 0 && !Classes.Contains(Sample.Class)) Classes.Add(Sample.Class);
            foreach (double Input in Sample.Inputs)
            {
                if (Input > HighestInput) HighestInput = Input;
                if (Input < LowestInput) LowestInput = Input;
            }
        }
        Classes.Sort();
        // A flat (or empty) data set has no range to divide by. Substitute [-1, 1] BEFORE
        // normalising; otherwise Normalize divides by zero and every mark becomes NaN.
        if (!(HighestInput > LowestInput))
        {
            HighestInput = 1;
            LowestInput = -1;
        }
        double InputRange = HighestInput - LowestInput;
        foreach (TrainingCase Sample in TrainingCases)
        {
            Sample.Normalize(LowestInput, InputRange);
        }

        int alpha = 122;
        Color[] Palette =
        {
            Color.Red, Color.OrangeRed, Color.Orange, Color.Yellow, Color.GreenYellow,
            Color.Green, Color.LightBlue, Color.Blue, Color.Purple, Color.DarkViolet
        };
        float width = 3000;
        float height = 1000;
        float border = 60;
        float mark_width = 10;
        float mark_height = 10;

        SolidBrush[] BrushSet = new SolidBrush[Palette.Length];
        for (int c = 0; c < Palette.Length; c++)
        {
            BrushSet[c] = new SolidBrush(Color.FromArgb(alpha, Palette[c]));
        }
        try
        {
            // Bitmap, Graphics and Font all hold GDI resources and are disposed
            // deterministically, since this method runs once per sweep configuration.
            using (Bitmap b = new Bitmap((int)width, (int)height))
            {
                using (Graphics g = Graphics.FromImage(b))
                using (Font AnnotationFont = new Font(FontFamily.GenericMonospace, 24, FontStyle.Bold))
                {
                    g.Clear(Color.Black);
                    foreach (TrainingCase Sample in TrainingCases)
                    {
                        int InputCount = Sample.NormalizedInputs.Length;
                        if (InputCount == 0) continue;
                        // Looked up once per case. Extra classes wrap around the palette
                        // instead of indexing past it.
                        SolidBrush CaseBrush = BrushSet[Classes.IndexOf(Sample.Class) % BrushSet.Length];
                        for (int n = 0; n < InputCount; n++)
                        {
                            g.FillRectangle(CaseBrush,
                                (((float)n / (float)InputCount) * (width - (border * 2))) + border,
                                ((float)Sample.NormalizedInputs[n] * (height - (border * 2))) + border,
                                mark_width,
                                mark_height);
                        }
                    }
                    g.DrawString(Annotation, AnnotationFont, Brushes.White, 10, 10);
                }
                b.Save(Filename);
            }
        }
        finally
        {
            foreach (SolidBrush PaletteBrush in BrushSet)
            {
                PaletteBrush.Dispose();
            }
        }
    }

    /// <summary>
    /// Classifies every validation case by majority vote of its <paramref name="KNN"/>
    /// nearest training cases (Euclidean distance over Inputs).
    /// </summary>
    /// <returns>How many validation cases were misclassified. A case with no
    /// neighbours to vote (empty training set or KNN &lt;= 0) counts as an error.</returns>
    static int Fitness(List<TrainingCase> TrainingCases, List<TrainingCase> ValidationCases, int KNN)
    {
        int Errors = 0;
        List<KeyValuePair<String, double>> ClassDistances = new List<KeyValuePair<String, double>>(TrainingCases.Count);
        foreach (TrainingCase Validation in ValidationCases)
        {
            ClassDistances.Clear();
            foreach (TrainingCase Training in TrainingCases)
            {
                // Sqrt is monotonic, so ranking by squared distance picks the same neighbours.
                double SquaredDistance = 0;
                for (int i = 0; i < Training.Inputs.Length; i++)
                {
                    double Delta = Training.Inputs[i] - Validation.Inputs[i];
                    SquaredDistance += Delta * Delta;
                }
                ClassDistances.Add(new KeyValuePair<String, double>(Training.Class, SquaredDistance));
            }
            ClassDistances.Sort((a, b) => a.Value.CompareTo(b.Value));
            // A null prediction (no neighbours) never equals a parsed class label.
            if (MajorityVote(ClassDistances, KNN) != Validation.Class)
            {
                Errors++;
            }
        }
        return Errors;
    }

    /// <summary>
    /// Returns the most common class among the first <paramref name="KNN"/> entries of
    /// <paramref name="SortedDistances"/> (nearest first). On a tie, the class whose
    /// closest neighbour ranks earliest wins. Returns null when no entries are voted on.
    /// </summary>
    static String MajorityVote(List<KeyValuePair<String, double>> SortedDistances, int KNN)
    {
        Dictionary<String, int> Votes = new Dictionary<String, int>();
        List<String> FirstSeenOrder = new List<String>();
        for (int k = 0; k < KNN && k < SortedDistances.Count; k++)
        {
            String Label = SortedDistances[k].Key;
            int Count;
            if (Votes.TryGetValue(Label, out Count))
            {
                Votes[Label] = Count + 1;
            }
            else
            {
                Votes.Add(Label, 1);
                FirstSeenOrder.Add(Label);
            }
        }
        String Winner = null;
        int WinnerVotes = 0;
        // Strictly-greater comparison in first-seen order makes the nearest class win ties.
        foreach (String Label in FirstSeenOrder)
        {
            if (Votes[Label] > WinnerVotes)
            {
                Winner = Label;
                WinnerVotes = Votes[Label];
            }
        }
        return Winner;
    }
}
class TrainingCase
{
    /// <summary>Shared random source used by <see cref="SaturationClone"/>.</summary>
    public static Random R = new Random();
    /// <summary>Raw feature values (every CSV column after the label).</summary>
    public double[] Inputs;
    /// <summary>Output buffer of <see cref="Normalize"/>; same length as Inputs.</summary>
    public double[] NormalizedInputs;
    /// <summary>Class label (the first CSV column).</summary>
    public String Class;
    /// <summary>Every row of ocr_train.data, loaded on the first successful ReadFromDisc call.</summary>
    public static List<TrainingCase> BufferedFromDisc;

    /// <summary>
    /// Returns the first <paramref name="Partial"/> cases of "ocr_train.data" (in the working
    /// directory), or all of them when <paramref name="Partial"/> is negative (default -1).
    /// A <paramref name="Partial"/> larger than the row count is clamped.
    /// The file is read once and cached. The returned list is new, but the TrainingCase
    /// objects in it are shared with the cache.
    /// </summary>
    public static List<TrainingCase> ReadFromDisc(int Partial = -1)
    {
        if (BufferedFromDisc == null)
        {
            // Load into a local list and publish only after the whole file parsed. A failed
            // read would otherwise leave a half-filled cache that later calls silently reuse.
            List<TrainingCase> Loaded = new List<TrainingCase>();
            using (StreamReader SR = new StreamReader(File.OpenRead("ocr_train.data")))
            {
                SR.ReadLine(); // skip the CSV header row
                String Line;
                while ((Line = SR.ReadLine()) != null)
                {
                    // A blank line would become a case with no inputs and crash distance code.
                    if (String.IsNullOrWhiteSpace(Line)) continue;
                    String[] Parts = Line.Split(',');
                    double[] Inputs = new double[Parts.Length - 1];
                    for (int p = 1; p < Parts.Length; p++)
                    {
                        // Machine-readable data: parse independently of the user's locale.
                        Inputs[p - 1] = Double.Parse(Parts[p], System.Globalization.CultureInfo.InvariantCulture);
                    }
                    Loaded.Add(new TrainingCase(Inputs, Parts[0]));
                }
            }
            BufferedFromDisc = Loaded;
        }
        int Take = Partial < 0 ? BufferedFromDisc.Count : Math.Min(Partial, BufferedFromDisc.Count);
        return BufferedFromDisc.GetRange(0, Take);
    }

    /// <summary>Wraps <paramref name="Inputs"/> (not copied) and allocates a matching NormalizedInputs buffer.</summary>
    public TrainingCase(double[] Inputs, String Class)
    {
        this.Inputs = Inputs;
        this.NormalizedInputs = new double[Inputs.Length];
        this.Class = Class;
    }

    /// <summary>
    /// Returns a copy of this case, with the same class, in which every non-zero input is
    /// scaled by an independent random factor in [1 - Sway, 1 + Sway). Zero inputs stay zero.
    /// </summary>
    public TrainingCase SaturationClone(double Sway)
    {
        double[] CloneInputs = (double[])Inputs.Clone();
        for (int i = 0; i < CloneInputs.Length; i++)
        {
            // Skipping zeros keeps them at zero and avoids drawing a random number for them.
            if (CloneInputs[i] != 0)
            {
                // Parenthesised so the offset is symmetric and relative to the value.
                // The previous form, x*(2*Sway)*r - Sway, only ever pushed values upward
                // and subtracted a negligible absolute constant.
                CloneInputs[i] += CloneInputs[i] * (((Sway * 2) * R.NextDouble()) - Sway);
            }
        }
        return new TrainingCase(CloneInputs, Class);
    }

    /// <summary>
    /// Writes (Inputs[i] - Min) / Range into NormalizedInputs and returns that array.
    /// A Range of 0 yields NaN/Infinity (double division by zero); callers must guard.
    /// </summary>
    public double[] Normalize(double Min, double Range)
    {
        for (int i = 0; i < Inputs.Length; i++)
        {
            NormalizedInputs[i] = (Inputs[i] - Min) / Range;
        }
        return NormalizedInputs;
    }
}
class Result
{
    /// <summary>
    /// Overwrites <paramref name="Filename"/> with a one-line run summary, followed by a
    /// CSV header and one CSV row per result (see <see cref="ToString"/>).
    /// </summary>
    public static void Write(String Filename, List<Result> Results, long TotalRegressionTime)
    {
        StringBuilder Output = new StringBuilder();
        Output.AppendLine("Regression ran in " + TotalRegressionTime + "ms, and generated " + Results.Count + " results");
        Output.AppendLine("WaveformID,Samples,SaturationPasses,SaturationSway,KNN,Errors,TotalValidationCases,Error%,RegressionTime(ms)");
        foreach (Result Row in Results)
        {
            Output.AppendLine(Row.ToString());
        }
        using (FileStream fs = File.Create(Filename))
        using (StreamWriter sw = new StreamWriter(fs))
        {
            sw.Write(Output.ToString());
        }
    }

    public int WaveformID;
    public int Samples;
    public int SaturationPasses;
    public double SaturationSway;
    public int KNN;
    public int Errors;
    /// <summary>Number of validation cases the Errors were counted over.</summary>
    public int ValidationCases;
    /// <summary>Wall-clock milliseconds for this configuration.</summary>
    public long RegressionTime;

    public Result(int WaveformID, int Samples, int SaturationPasses, double SaturationSway, int KNN, int Errors, int ValidationCases, long RegressionTime)
    {
        this.WaveformID = WaveformID;
        this.Samples = Samples;
        this.SaturationPasses = SaturationPasses;
        this.SaturationSway = SaturationSway;
        this.KNN = KNN;
        this.Errors = Errors;
        this.ValidationCases = ValidationCases;
        this.RegressionTime = RegressionTime;
    }

    /// <summary>
    /// One CSV row matching the header written by <see cref="Write"/>. Numbers are
    /// formatted with the invariant culture, because a comma decimal separator (e.g. de-DE)
    /// would split a value across columns. The "Error%" column holds the fraction
    /// Errors / ValidationCases (0..1), not a percentage.
    /// </summary>
    public override string ToString()
    {
        return String.Format(System.Globalization.CultureInfo.InvariantCulture,
            "{0},{1},{2},{3},{4},{5},{6},{7},{8}",
            WaveformID, Samples, SaturationPasses, SaturationSway, KNN, Errors, ValidationCases,
            (float)Errors / (float)ValidationCases, RegressionTime);
    }

    /// <summary>Console-friendly summary; uses the current culture since it is shown to the user.</summary>
    public string ToHumanString()
    {
        return "WaveformID: " + WaveformID + " Samples: " + Samples + " Passes: " + SaturationPasses + " Sway: " + SaturationSway + " KNN: " + KNN + " Errors: " + Errors + " Of: " + ValidationCases + " %: " + ((float)Errors / (float)ValidationCases) + " MS: " + RegressionTime;
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment