Skip to content

Instantly share code, notes, and snippets.

@adrianseeley
Created March 11, 2014 09:53
Show Gist options
  • Save adrianseeley/9482657 to your computer and use it in GitHub Desktop.
Save adrianseeley/9482657 to your computer and use it in GitHub Desktop.
GATO.KNNSATURATION - very inefficiently explores KNN saturation rates on the MNIST OCR set, and draws pretty KNN waveform graphs too.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.IO;
using System.Drawing;
using System.Diagnostics;
namespace GATO.KNNSATURATION
{
/// <summary>
/// Experiment driver. Sweeps training-set size, saturation passes, saturation sway and K,
/// scores a brute-force KNN classifier against the held-out cases for every combination,
/// renders a "waveform" image of each training set under ./img/ and finally writes every
/// result row to Results.txt.
/// </summary>
class Program
{
    /// <summary>Sway grid resolution: sway advances in steps of 1 / SwayStepsPerUnit (0.01).</summary>
    const int SwayStepsPerUnit = 100;

    /// <summary>
    /// Runs the full parameter sweep. Command-line arguments are ignored.
    /// </summary>
    static void Main(string[] args)
    {
        Stopwatch TotalRegressionTime = Stopwatch.StartNew();
        int MaxSamples = 1001;
        int MaxSaturationPasses = 10;
        double MaxSaturationSway = 0.2;
        int MaxKNN = 20;

        // The first MaxSamples rows are the pool training sets are drawn from;
        // everything after them is used for validation only.
        List<TrainingCase> ValidationCases = TrainingCase.ReadFromDisc();
        ValidationCases.RemoveRange(0, MaxSamples);

        List<Result> Results = new List<Result>();
        int WaveformID = 0;
        Directory.CreateDirectory("./img/");

        // Sway is driven by an integer counter and derived by division. Repeatedly adding
        // 0.01 to a double accumulates rounding error, which showed up in file names and
        // result rows (and makes the loop's final iteration depend on that error).
        int MaxSwaySteps = (int)Math.Round(MaxSaturationSway * SwayStepsPerUnit);

        for (int CurrentSamples = 1; CurrentSamples <= MaxSamples; CurrentSamples += 100)
        {
            for (int CurrentSaturationPasses = 0; CurrentSaturationPasses < MaxSaturationPasses; CurrentSaturationPasses++)
            {
                for (int SwayStep = 0; SwayStep < MaxSwaySteps; SwayStep++)
                {
                    // With zero passes no clones are generated, so sway has no effect:
                    // only the sway == 0 combination is run.
                    if (CurrentSaturationPasses == 0 && SwayStep > 0)
                    {
                        continue;
                    }
                    double CurrentSaturationSway = (double)SwayStep / SwayStepsPerUnit;

                    for (int CurrentKNN = 1; CurrentKNN < MaxKNN; CurrentKNN++)
                    {
                        Stopwatch CurrentRegressionTime = Stopwatch.StartNew();
                        List<TrainingCase> TrainingCases = TrainingCase.ReadFromDisc(CurrentSamples);
                        TrainingCases = Saturate(TrainingCases, CurrentSaturationPasses, CurrentSaturationSway);
                        int Errors = Fitness(TrainingCases, ValidationCases, CurrentKNN);

                        String Filename = "./img/KNNWaveform" + WaveformID + ".Samples" + CurrentSamples + ".SatPasses" + CurrentSaturationPasses + ".SatSway" + CurrentSaturationSway + ".KNN" + CurrentKNN + ".Errors" + Errors + "of" + (ValidationCases.Count) + ".png";
                        String Annotation = "KNNWaveform " + WaveformID + " on " + CurrentSamples + " Samples, with " + CurrentSaturationPasses + " Saturation Passes at " + CurrentSaturationSway + "% Sway, Scored " + Errors + " Errors out of " + ValidationCases.Count + " Validation Cases (" + (((float)Errors / (float)ValidationCases.Count) * 100) + "%) Using K=" + CurrentKNN;
                        DrawKNNWaveform(TrainingCases, Filename, Annotation);

                        CurrentRegressionTime.Stop();
                        Results.Add(new Result(WaveformID, CurrentSamples, CurrentSaturationPasses, CurrentSaturationSway, CurrentKNN, Errors, ValidationCases.Count, CurrentRegressionTime.ElapsedMilliseconds));
                        Console.WriteLine(Results[Results.Count - 1].ToHumanString());
                        WaveformID++;
                    }
                }
            }
        }
        TotalRegressionTime.Stop();
        Result.Write("Results.txt", Results, TotalRegressionTime.ElapsedMilliseconds);
    }

    /// <summary>
    /// Appends <paramref name="Passes"/> perturbed clones of every original case to the list.
    /// </summary>
    /// <param name="TrainingCases">List to extend; it is modified in place.</param>
    /// <param name="Passes">Number of clones to create per original case.</param>
    /// <param name="Sway">Sway amount forwarded to <c>TrainingCase.SaturationClone</c>.</param>
    /// <returns>The same list instance that was passed in.</returns>
    static List<TrainingCase> Saturate (List<TrainingCase> TrainingCases, int Passes, double Sway)
    {
        // Walk the originals from last to first. Clones are appended behind the originals,
        // so the indices being visited never shift and clones are never cloned again.
        for (int t = TrainingCases.Count - 1; t >= 0; t--)
        {
            TrainingCase Original = TrainingCases[t];
            for (int p = 0; p < Passes; p++)
            {
                TrainingCases.Add(Original.SaturationClone(Sway));
            }
        }
        return TrainingCases;
    }

    /// <summary>
    /// Renders every input of every training case as a small class-coloured square
    /// (x = input index, y = normalized input value) on a black 3000x1000 image, stamps
    /// the annotation in the top-left corner and saves the image.
    /// </summary>
    /// <param name="TrainingCases">Cases to plot; their NormalizedInputs are overwritten.</param>
    /// <param name="Filename">Path the image is saved to.</param>
    /// <param name="Annotation">Caption drawn onto the image.</param>
    static void DrawKNNWaveform(List<TrainingCase> TrainingCases, String Filename, String Annotation)
    {
        double HighestInput = Double.NegativeInfinity;
        double LowestInput = Double.PositiveInfinity;
        List<String> Classes = new List<String>();
        foreach (TrainingCase Case in TrainingCases)
        {
            // The class is registered once per case rather than once per input.
            if (!Classes.Contains(Case.Class))
            {
                Classes.Add(Case.Class);
            }
            foreach (double Input in Case.Inputs)
            {
                if (Input > HighestInput) { HighestInput = Input; }
                if (Input < LowestInput) { LowestInput = Input; }
            }
        }
        Classes.Sort();

        if (HighestInput < LowestInput)
        {
            // No inputs were seen at all; fall through to the flat-data handling below.
            HighestInput = 0;
            LowestInput = 0;
        }
        double InputRange = HighestInput - LowestInput;
        if (InputRange == 0)
        {
            // Flat data: widen the window by one unit each way BEFORE normalizing so the
            // division is defined and every point lands on the vertical midpoint.
            // (Previously this guard ran after normalization, which produced NaN.)
            HighestInput += 1;
            LowestInput -= 1;
            InputRange = 2;
        }
        foreach (TrainingCase Case in TrainingCases)
        {
            Case.Normalize(LowestInput, InputRange);
        }

        int alpha = 122;
        Color[] Palette =
        {
            Color.Red, Color.OrangeRed, Color.Orange, Color.Yellow, Color.GreenYellow,
            Color.Green, Color.LightBlue, Color.Blue, Color.Purple, Color.DarkViolet
        };
        float width = 3000;
        float height = 1000;
        float border = 60;
        float mark_width = 10;
        float mark_height = 10;

        // Brushes, bitmap, graphics and font all wrap GDI+ handles. This method runs once
        // per experiment, so they are released deterministically instead of being leaked.
        SolidBrush[] BrushSet = new SolidBrush[Palette.Length];
        try
        {
            for (int c = 0; c < Palette.Length; c++)
            {
                BrushSet[c] = new SolidBrush(Color.FromArgb(alpha, Palette[c]));
            }
            using (Bitmap b = new Bitmap((int)width, (int)height))
            {
                using (Graphics g = Graphics.FromImage(b))
                using (Font AnnotationFont = new Font(FontFamily.GenericMonospace, 24, FontStyle.Bold))
                {
                    g.Clear(Color.Black);
                    foreach (TrainingCase Case in TrainingCases)
                    {
                        // Wrap around the palette so more than ten classes cannot index past it.
                        SolidBrush ClassBrush = BrushSet[Classes.IndexOf(Case.Class) % BrushSet.Length];
                        double[] Normalized = Case.NormalizedInputs;
                        for (int n = 0; n < Normalized.Length; n++)
                        {
                            g.FillRectangle(ClassBrush,
                                (((float)(n) / (float)Normalized.Length) * (width - (border * 2))) + border,
                                ((float)Normalized[n] * (height - (border * 2))) + border,
                                mark_width,
                                mark_height
                            );
                        }
                    }
                    g.DrawString(Annotation, AnnotationFont, Brushes.White, 10, 10);
                }
                b.Save(Filename);
            }
        }
        finally
        {
            foreach (SolidBrush Brush in BrushSet)
            {
                if (Brush != null)
                {
                    Brush.Dispose();
                }
            }
        }
    }

    /// <summary>
    /// Classifies every validation case by majority vote among its K nearest training
    /// cases (Euclidean distance over the raw inputs) and counts the misclassifications.
    /// </summary>
    /// <param name="TrainingCases">Labelled cases that act as neighbours.</param>
    /// <param name="ValidationCases">Labelled cases to classify.</param>
    /// <param name="KNN">Number of nearest neighbours that vote.</param>
    /// <returns>
    /// Number of validation cases whose predicted class differs from their own class.
    /// A case for which no neighbour is available counts as an error.
    /// </returns>
    static int Fitness(List<TrainingCase> TrainingCases, List<TrainingCase> ValidationCases, int KNN)
    {
        int Errors = 0;
        foreach (TrainingCase Validation in ValidationCases)
        {
            List<KeyValuePair<String, double>> ClassDistances = new List<KeyValuePair<String, double>>(TrainingCases.Count);
            foreach (TrainingCase Training in TrainingCases)
            {
                double SquaredSum = 0;
                for (int i = 0; i < Training.Inputs.Length; i++)
                {
                    double Difference = Training.Inputs[i] - Validation.Inputs[i];
                    SquaredSum += Difference * Difference;
                }
                ClassDistances.Add(new KeyValuePair<String, double>(Training.Class, Math.Sqrt(SquaredSum)));
            }

            // OrderBy is a stable sort, so equidistant neighbours keep training order
            // (List.Sort is unstable and ordered them arbitrarily).
            List<KeyValuePair<String, double>> Nearest = ClassDistances.OrderBy(d => d.Value).Take(KNN).ToList();

            Dictionary<String, int> ClassVotes = new Dictionary<String, int>();
            foreach (KeyValuePair<String, double> Neighbour in Nearest)
            {
                int Votes;
                ClassVotes.TryGetValue(Neighbour.Key, out Votes);
                ClassVotes[Neighbour.Key] = Votes + 1;
            }

            // Pick the class with the most votes. Scanning nearest-first and replacing only
            // on a strictly higher count resolves ties in favour of the closest neighbour.
            String Winner = null;
            int WinnerVotes = 0;
            foreach (KeyValuePair<String, double> Neighbour in Nearest)
            {
                int Votes = ClassVotes[Neighbour.Key];
                if (Votes > WinnerVotes)
                {
                    Winner = Neighbour.Key;
                    WinnerVotes = Votes;
                }
            }

            if (WinnerVotes == 0 || Winner != Validation.Class)
            {
                Errors++;
            }
        }
        return Errors;
    }
}
/// <summary>
/// One labelled sample: a vector of numeric inputs plus its class label. Also owns the
/// process-wide cache of the cases parsed from "ocr_train.data".
/// </summary>
class TrainingCase
{
    /// <summary>Shared random source used by <see cref="SaturationClone"/>.</summary>
    public static Random R = new Random();
    /// <summary>Raw input values.</summary>
    public double[] Inputs;
    /// <summary>Scratch buffer filled by <see cref="Normalize"/>; same length as Inputs.</summary>
    public double[] NormalizedInputs;
    /// <summary>Class label.</summary>
    public String Class;
    /// <summary>Cache of every case parsed from disk; null until the first successful read.</summary>
    public static List<TrainingCase> BufferedFromDisc;

    /// <summary>
    /// Returns the first <paramref name="Partial"/> cases from "ocr_train.data", parsing
    /// and caching the file on first use. The first line of the file is skipped as a
    /// header; on every other line the first comma-separated field is the class label
    /// and the remaining fields are the inputs.
    /// </summary>
    /// <param name="Partial">Number of leading cases to return, or -1 for all of them.</param>
    /// <returns>
    /// A new list; the TrainingCase objects inside it are shared with the cache.
    /// </returns>
    public static List<TrainingCase> ReadFromDisc (int Partial = -1)
    {
        if (BufferedFromDisc == null)
        {
            // Parse into a local list and publish it only once the whole file has loaded,
            // so an exception part-way through cannot leave an empty or partial cache
            // that later calls would silently reuse.
            List<TrainingCase> Loaded = new List<TrainingCase>();
            using (FileStream FS = File.OpenRead("ocr_train.data"))
            {
                using (StreamReader SR = new StreamReader(FS))
                {
                    SR.ReadLine(); // skip the header row
                    String Line;
                    while ((Line = SR.ReadLine()) != null)
                    {
                        // A blank line would otherwise become a case with no inputs.
                        if (String.IsNullOrWhiteSpace(Line))
                        {
                            continue;
                        }
                        String[] Parts = Line.Split(',');
                        double[] Values = new double[Parts.Length - 1];
                        for (int p = 1; p < Parts.Length; p++)
                        {
                            // Invariant culture: the data file must parse the same way
                            // regardless of the machine's regional settings.
                            Values[p - 1] = Double.Parse(Parts[p], System.Globalization.CultureInfo.InvariantCulture);
                        }
                        Loaded.Add(new TrainingCase(Values, Parts[0]));
                    }
                }
            }
            BufferedFromDisc = Loaded;
        }
        int Count = Partial == -1 ? BufferedFromDisc.Count : Partial;
        return BufferedFromDisc.GetRange(0, Count);
    }

    /// <summary>
    /// Creates a case that wraps (does not copy) the given input array.
    /// </summary>
    /// <param name="Inputs">Input values; must not be null.</param>
    /// <param name="Class">Class label.</param>
    /// <exception cref="ArgumentNullException"><paramref name="Inputs"/> is null.</exception>
    public TrainingCase(double[] Inputs, String Class)
    {
        if (Inputs == null)
        {
            throw new ArgumentNullException("Inputs");
        }
        this.Inputs = Inputs;
        this.NormalizedInputs = new double[Inputs.Length];
        this.Class = Class;
    }

    /// <summary>
    /// Creates a copy of this case in which every non-zero input is scaled by a random
    /// relative amount drawn uniformly from [-Sway, +Sway). Zero inputs stay zero.
    /// </summary>
    /// <param name="Sway">Maximum relative change, e.g. 0.1 for up to +/-10%.</param>
    /// <returns>A new case with its own input array and the same class label.</returns>
    public TrainingCase SaturationClone(double Sway)
    {
        double[] CloneInputs = new double[Inputs.Length];
        for (int i = 0; i < CloneInputs.Length; i++)
        {
            double Value = Inputs[i];
            if (Value != 0)
            {
                // The random factor is computed first and then applied to the value.
                // The earlier form, value * 2 * Sway * r - Sway, subtracted an absolute
                // Sway, so the change was neither relative nor symmetric.
                double Factor = (Sway * 2 * R.NextDouble()) - Sway;
                Value += Value * Factor;
            }
            CloneInputs[i] = Value;
        }
        return new TrainingCase(CloneInputs, Class);
    }

    /// <summary>
    /// Fills <see cref="NormalizedInputs"/> with (input - Min) / Range for every input.
    /// </summary>
    /// <param name="Min">Value that maps to 0.</param>
    /// <param name="Range">Width of the value window; Min + Range maps to 1.</param>
    /// <returns>The NormalizedInputs array of this instance.</returns>
    public double[] Normalize(double Min, double Range)
    {
        for (int i = 0; i < Inputs.Length; i++)
        {
            // A zero range carries no scale information; use the midpoint rather than
            // dividing by zero and storing NaN/Infinity.
            NormalizedInputs[i] = Range == 0 ? 0.5 : (Inputs[i] - Min) / Range;
        }
        return NormalizedInputs;
    }
}
/// <summary>
/// Outcome of one experiment run (one combination of sample count, saturation passes,
/// saturation sway and K), plus helpers to format it and to write a list of runs to disk.
/// </summary>
class Result
{
    /// <summary>
    /// Writes a summary line, a CSV header and one CSV row per result to
    /// <paramref name="Filename"/>, replacing any existing file.
    /// </summary>
    /// <param name="Filename">Path of the file to create.</param>
    /// <param name="Results">Rows to write, in order.</param>
    /// <param name="TotalRegressionTime">Total run time in milliseconds, for the summary line.</param>
    public static void Write(String Filename, List<Result> Results, long TotalRegressionTime)
    {
        using (FileStream fs = File.Create(Filename))
        {
            using (StreamWriter sw = new StreamWriter(fs))
            {
                sw.WriteLine("Regression ran in " + TotalRegressionTime + "ms, and generated " + Results.Count + " results");
                sw.WriteLine("WaveformID,Samples,SaturationPasses,SaturationSway,KNN,Errors,TotalValidationCases,Error%,RegressionTime(ms)");
                foreach (Result Row in Results)
                {
                    sw.WriteLine(Row.ToString());
                }
            }
        }
    }

    public int WaveformID;
    public int Samples;
    public int SaturationPasses;
    public double SaturationSway;
    public int KNN;
    public int Errors;
    public int ValidationCases;
    /// <summary>Elapsed time of this run in milliseconds.</summary>
    public long RegressionTime;

    /// <summary>Stores the given values in the fields of the same name.</summary>
    public Result(int WaveformID, int Samples, int SaturationPasses, double SaturationSway, int KNN, int Errors, int ValidationCases, long RegressionTime)
    {
        this.WaveformID = WaveformID;
        this.Samples = Samples;
        this.SaturationPasses = SaturationPasses;
        this.SaturationSway = SaturationSway;
        this.KNN = KNN;
        this.Errors = Errors;
        this.ValidationCases = ValidationCases;
        this.RegressionTime = RegressionTime;
    }

    /// <summary>
    /// Errors as a percentage (0-100) of the validation cases. NaN when there are no
    /// validation cases, because the rate is then undefined.
    /// </summary>
    private double ErrorPercent()
    {
        return 100.0 * Errors / ValidationCases;
    }

    /// <summary>
    /// Formats this result as one CSV row matching the header emitted by <see cref="Write"/>.
    /// </summary>
    /// <returns>Comma-separated values, formatted with the invariant culture.</returns>
    public override string ToString()
    {
        // Invariant culture: under a culture that uses ',' as the decimal separator the
        // sway and percentage values would otherwise split into extra CSV columns.
        // The Error% column now really holds a percentage (it used to hold a 0..1 fraction).
        return String.Format(
            System.Globalization.CultureInfo.InvariantCulture,
            "{0},{1},{2},{3},{4},{5},{6},{7},{8}",
            WaveformID, Samples, SaturationPasses, SaturationSway, KNN, Errors, ValidationCases, ErrorPercent(), RegressionTime);
    }

    /// <summary>
    /// Formats this result as a labelled, single-line string for console output.
    /// </summary>
    public string ToHumanString()
    {
        return "WaveformID: " + WaveformID + " Samples: " + Samples + " Passes: " + SaturationPasses + " Sway: " + SaturationSway + " KNN: " + KNN + " Errors: " + Errors + " Of: " + ValidationCases + " %: " + ErrorPercent() + " MS: " + RegressionTime;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment