Last active
August 29, 2015 13:56
-
-
Save adrianseeley/9221018 to your computer and use it in GitHub Desktop.
GATO.MEDIUM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
using System.Threading; | |
namespace GATO.MEDIUM | |
{ | |
class Program
{
    /// <summary>
    /// Entry point: builds a medium with 2 input cells, 1 output cell and 10 hidden cells,
    /// then trains it on the four rows of the XOR truth table.
    /// </summary>
    static void Main(string[] args)
    {
        var medium = new Medium(2, 1, 10);

        // Each row is { input A, input B, expected output }.
        double[][] xorTable =
        {
            new double[] { 0, 0, 0 },
            new double[] { 0, 1, 1 },
            new double[] { 1, 0, 1 },
            new double[] { 1, 1, 0 },
        };

        var trainingCases = new List<TrainingCase>();
        foreach (double[] row in xorTable)
        {
            var inputs = new List<double>() { row[0], row[1] };
            var outputs = new List<double>() { row[2] };
            trainingCases.Add(new TrainingCase(inputs, outputs));
        }

        medium.Train(trainingCases);
    }
}
class Medium
{
    // Safety bound on the number of synchronous resolve/transfer passes in one propagation.
    // Cell.Mutate can clamp a cell's Impedance down to 0, and every cell starts out transferring
    // to itself. A zero-impedance cell passes all of its stimulation on, so stimulation can
    // circulate without ever being absorbed (e.g. such a cell whose only target is itself).
    // Without a bound, the resolve loop would then never terminate.
    private const int MaxResolutionPasses = 100000;

    /// <summary>All cells of the medium; a cell's position in this list is the index other cells use to target it.</summary>
    public List<Cell> Cells = new List<Cell>();
    /// <summary>Indexes into <see cref="Cells"/> of the cells stimulated with the training-case inputs.</summary>
    public List<int> InputCellIndexes = new List<int>();
    /// <summary>Indexes into <see cref="Cells"/> of the cells whose accumulation is read back as the outputs.</summary>
    public List<int> OutputCellIndexes = new List<int>();

    /// <summary>
    /// Creates a medium whose cells are laid out as input cells first, then output cells,
    /// then hidden cells. Non-positive counts simply create no cells of that kind.
    /// </summary>
    public Medium(int NumberOfInputCells, int NumberOfOutputCells, int NumberOfHiddenCells)
    {
        for (int c = 0; c < NumberOfInputCells; c++)
        {
            InputCellIndexes.Add(AddCell());
        }
        for (int c = 0; c < NumberOfOutputCells; c++)
        {
            OutputCellIndexes.Add(AddCell());
        }
        for (int c = 0; c < NumberOfHiddenCells; c++)
        {
            AddCell();
        }
    }

    /// <summary>
    /// Trains the medium. Each iteration runs every training case through a forward pass
    /// (measure the output error) and an error pass (inject that error at the inputs,
    /// let it spread, then mutate every cell). After each iteration, the per-case errors
    /// and the medium's structure are written to the console.
    /// </summary>
    /// <param name="TrainingCases">Cases to train on; each must match the medium's input and output cell counts.</param>
    /// <param name="Iterations">Number of passes over all training cases (default 10000).</param>
    /// <param name="DelayMilliseconds">Pause after each iteration so the console output can be followed (default 100; 0 disables).</param>
    /// <exception cref="ArgumentNullException"><paramref name="TrainingCases"/> is null.</exception>
    /// <exception cref="ArgumentException">A case is null or its input/output count does not match the medium.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="Iterations"/> or <paramref name="DelayMilliseconds"/> is negative.</exception>
    public void Train(List<TrainingCase> TrainingCases, int Iterations = 10000, int DelayMilliseconds = 100)
    {
        if (TrainingCases == null) throw new ArgumentNullException("TrainingCases");
        if (Iterations < 0) throw new ArgumentOutOfRangeException("Iterations", "Iterations must not be negative");
        // Thread.Sleep treats -1 as "sleep forever", so negative delays are rejected outright.
        if (DelayMilliseconds < 0) throw new ArgumentOutOfRangeException("DelayMilliseconds", "DelayMilliseconds must not be negative");

        // Validate every case up front so a malformed case fails before any cell has been mutated.
        foreach (TrainingCase trainingCase in TrainingCases)
        {
            if (trainingCase == null) throw new ArgumentException("Training cases must not be null", "TrainingCases");
            if (trainingCase.Inputs.Count != InputCellIndexes.Count) throw new ArgumentException("Number of inputs for training case must equal the number of input cell indexes", "TrainingCases");
            if (trainingCase.Outputs.Count != OutputCellIndexes.Count) throw new ArgumentException("Number of outputs for training case must equal the number of output cell indexes", "TrainingCases");
        }

        for (int iter = 0; iter < Iterations; iter++)
        {
            List<double> errorRates = new List<double>();
            foreach (TrainingCase trainingCase in TrainingCases)
            {
                // Forward pass: propagate the inputs and measure how far the outputs are from the targets.
                ResetCells();
                StimulateInputs(trainingCase);
                ResolveStimulation();
                double outputError = MeasureOutputError(trainingCase);
                errorRates.Add(outputError);

                // Error pass: replay the same inputs, inject the measured error at every input cell,
                // then let the error travel along the transfer graph and mutate the cells.
                ResetCells();
                StimulateInputs(trainingCase);
                foreach (int inputIndex in InputCellIndexes)
                {
                    Cells[inputIndex].StimulateError(outputError);
                }
                ResolveErrorStimulation();
            }
            foreach (double errorRate in errorRates)
            {
                Console.Write(errorRate + ",");
            }
            Console.WriteLine(this);
            if (DelayMilliseconds > 0)
            {
                Thread.Sleep(DelayMilliseconds);
            }
        }
    }

    // Appends a new cell whose default self-transfer points at its own index; returns that index.
    private int AddCell()
    {
        int index = Cells.Count;
        Cells.Add(new Cell(index));
        return index;
    }

    // Clears the transient stimulation/error state of every cell.
    private void ResetCells()
    {
        foreach (Cell cell in Cells)
        {
            cell.Reset();
        }
    }

    // Stimulates each input cell with the matching input value of the case.
    private void StimulateInputs(TrainingCase trainingCase)
    {
        for (int c = 0; c < InputCellIndexes.Count; c++)
        {
            Cells[InputCellIndexes[c]].Stimulate(trainingCase.Inputs[c]);
        }
    }

    // Sum over outputs of |target - output cell accumulation|. Absolute values keep errors on
    // different outputs from cancelling each other out. They also keep the injected error
    // non-negative, because Cell.Mutate scales its structural-mutation chance by the error.
    private double MeasureOutputError(TrainingCase trainingCase)
    {
        double total = 0.0;
        for (int c = 0; c < OutputCellIndexes.Count; c++)
        {
            total += Math.Abs(trainingCase.Outputs[c] - Cells[OutputCellIndexes[c]].Accumulation);
        }
        return total;
    }

    // Runs synchronous resolve/transfer passes until no cell fires (or the pass cap is hit).
    private void ResolveStimulation()
    {
        RunUntilResolved(cell => cell.Resolve(), cell => cell.Transfer(Cells));
    }

    // Same as ResolveStimulation but also carries error along, then mutates every cell.
    private void ResolveErrorStimulation()
    {
        RunUntilResolved(cell => cell.ResolveError(), cell => cell.TransferError(Cells));
        foreach (Cell cell in Cells)
        {
            cell.Mutate(Cells);
        }
    }

    // Each pass first resolves every cell, then, if any cell fired, lets every cell transfer.
    // Stops after MaxResolutionPasses passes even if cells are still firing (see the constant).
    private void RunUntilResolved(Func<Cell, bool> resolve, Action<Cell> transfer)
    {
        for (int pass = 0; pass < MaxResolutionPasses; pass++)
        {
            bool resolved = true;
            foreach (Cell cell in Cells)
            {
                // Non-short-circuit &= so every cell resolves in this pass; && would skip
                // every cell after the first one that fires.
                resolved &= resolve(cell);
            }
            if (resolved)
            {
                return;
            }
            foreach (Cell cell in Cells)
            {
                transfer(cell);
            }
        }
    }

    /// <summary>Describes each cell as "(index: impedance->target,target,...)".</summary>
    public override string ToString()
    {
        StringBuilder ret = new StringBuilder();
        for (int c = 0; c < Cells.Count; c++)
        {
            ret.Append("(" + c + ": " + Cells[c].Impedance + "->");
            ret.Append(string.Join(",", Cells[c].TransferCellIndexes));
            ret.Append(")");
        }
        return ret.ToString();
    }
}
class Cell
{
    /// <summary>
    /// Shared random source used by <see cref="Mutate"/>. System.Random is not thread-safe,
    /// so Mutate should not be called concurrently.
    /// </summary>
    public static Random R = new Random();
    // Threshold that stimulation/accumulation must exceed before the cell fires; Mutate keeps it in [0, 1].
    public double Impedance = 1.0;
    // Stimulation absorbed (not passed on) since the last Reset or firing.
    public double Accumulation = 0.0;
    // Incoming stimulation waiting to be resolved in the next pass.
    public double Stimulation = 0.0;
    // Overflow computed when the cell fires; split among the targets by Transfer/TransferError.
    public double TransferStimulation = 0.0;
    // Total error received since the last Reset; drives Mutate.
    public double ErrorAccumulation = 0.0;
    // Error received but not yet passed on; carried along when the cell fires in ResolveError.
    public double ErrorStimulation = 0.0;
    // Error handed to the targets by TransferError.
    public double ErrorTransferStimulation = 0.0;
    // Indexes of the cells that each receive an equal share of this cell's transfers (duplicates allowed).
    public List<int> TransferCellIndexes = new List<int>();

    /// <summary>Creates a cell whose only initial transfer target is itself.</summary>
    public Cell(int CellIndex)
    {
        TransferCellIndexes.Add(CellIndex);
    }

    /// <summary>Clears all transient stimulation and error state (impedance and transfers are kept).</summary>
    public void Reset()
    {
        ErrorStimulation = 0.0;
        ErrorAccumulation = 0.0;
        ErrorTransferStimulation = 0.0;
        Stimulation = 0.0;
        Accumulation = 0.0;
        TransferStimulation = 0.0;
    }

    /// <summary>Adds incoming stimulation to be resolved in the next pass.</summary>
    public void Stimulate(double Stimulation)
    {
        this.Stimulation += Stimulation;
    }

    /// <summary>Adds incoming error, both as pending error and to the running error total.</summary>
    public void StimulateError(double ErrorStimulation)
    {
        this.ErrorStimulation += ErrorStimulation;
        this.ErrorAccumulation += ErrorStimulation;
    }

    /// <summary>
    /// Resolves pending stimulation. If it exceeds the impedance, or pushes the accumulation past
    /// the impedance, the cell fires: the overflow becomes <see cref="TransferStimulation"/> and the
    /// accumulation is burned out. Otherwise the stimulation is absorbed into the accumulation.
    /// </summary>
    /// <returns>true if the cell absorbed everything; false if it fired and needs a transfer.</returns>
    public bool Resolve()
    {
        if (Stimulation > Impedance)
        {
            // Direct overflow: the input alone exceeds the impedance; the prior accumulation is lost.
            return Fire(Stimulation - Impedance);
        }
        Accumulation += Stimulation;
        if (Accumulation > Impedance)
        {
            // Accumulated overflow: this stimulation pushed the stored amount past the impedance.
            return Fire(Accumulation - Impedance);
        }
        Stimulation = 0.0;
        return true;
    }

    /// <summary>
    /// Same as <see cref="Resolve"/>, but a cell that fires also hands its pending error on via
    /// <see cref="ErrorTransferStimulation"/>, so the error follows the stimulation through the medium.
    /// </summary>
    /// <returns>true if the cell absorbed everything; false if it fired and needs a transfer.</returns>
    public bool ResolveError()
    {
        bool resolved = Resolve();
        if (!resolved)
        {
            ErrorTransferStimulation = ErrorStimulation;
            ErrorStimulation = 0.0;
        }
        return resolved;
    }

    // Records the overflow to transfer, burns out the accumulation, and reports "not resolved".
    private bool Fire(double overflow)
    {
        TransferStimulation = overflow;
        Accumulation = 0.0;
        Stimulation = 0.0;
        return false;
    }

    /// <summary>Splits <see cref="TransferStimulation"/> equally among the transfer targets, then clears it.</summary>
    /// <exception cref="InvalidOperationException">The cell has no transfer targets.</exception>
    public void Transfer(List<Cell> Cells)
    {
        EnsureHasTransferTargets();
        double share = TransferStimulation / TransferCellIndexes.Count;
        foreach (int target in TransferCellIndexes)
        {
            Cells[target].Stimulate(share);
        }
        TransferStimulation = 0.0;
    }

    /// <summary>
    /// Splits both <see cref="TransferStimulation"/> and <see cref="ErrorTransferStimulation"/>
    /// equally among the transfer targets, then clears both.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cell has no transfer targets.</exception>
    public void TransferError(List<Cell> Cells)
    {
        EnsureHasTransferTargets();
        double share = TransferStimulation / TransferCellIndexes.Count;
        double errorShare = ErrorTransferStimulation / TransferCellIndexes.Count;
        foreach (int target in TransferCellIndexes)
        {
            Cells[target].Stimulate(share);
            Cells[target].StimulateError(errorShare);
        }
        TransferStimulation = 0.0;
        ErrorTransferStimulation = 0.0;
    }

    private void EnsureHasTransferTargets()
    {
        if (TransferCellIndexes.Count == 0) throw new InvalidOperationException("Cells must have at least one transfer cell index");
    }

    /// <summary>
    /// Learning step. Moves the impedance by the accumulated error in a random direction, clamped
    /// to [0, 1]. Then, with a chance that grows with the error magnitude, either adds/removes a
    /// transfer target (large mutation) or retargets one (small mutation).
    /// </summary>
    /// <param name="Cells">The medium's cells; only its count is used, to pick new target indexes.</param>
    public void Mutate(List<Cell> Cells)
    {
        const double LearningRate = 0.10;
        const double SmallMutationRate = 0.05;
        const double LargeMutationRate = 0.03;

        if (R.NextDouble() > 0.5)
        {
            Impedance += ErrorAccumulation * LearningRate;
        }
        else
        {
            Impedance -= ErrorAccumulation * LearningRate;
        }
        Impedance = Math.Min(1, Math.Max(0, Impedance));

        // Uses the error magnitude: a negative ErrorAccumulation would push shouldMutate above 1
        // and silently disable structural mutation. Since shouldMutate = 1 - u*|error| with u in
        // [0, 1), nothing mutates unless |error| > 0.95; [0.03, 0.05) is the small-mutation band.
        double shouldMutate = 1 - (R.NextDouble() * Math.Abs(ErrorAccumulation));
        if (shouldMutate < LargeMutationRate)
        {
            double shouldRemove = R.NextDouble();
            // Never remove the last remaining transfer target (or "remove" from an empty list).
            if (TransferCellIndexes.Count <= 1)
            {
                shouldRemove = 0.0;
            }
            if (shouldRemove > 0.5)
            {
                TransferCellIndexes.RemoveAt(R.Next(0, TransferCellIndexes.Count));
            }
            else
            {
                TransferCellIndexes.Add(R.Next(0, Cells.Count));
            }
        }
        else if (shouldMutate < SmallMutationRate && TransferCellIndexes.Count > 0)
        {
            // Retarget one randomly chosen transfer.
            TransferCellIndexes[R.Next(0, TransferCellIndexes.Count)] = R.Next(0, Cells.Count);
        }
    }

    /// <summary>Debug dump of every field followed by the transfer targets.</summary>
    public override string ToString()
    {
        return "i: " + Impedance + ", a: " + Accumulation + ", s: " + Stimulation + ", ts: " + TransferStimulation + ", ea: " + ErrorAccumulation + ", es: " + ErrorStimulation + " ets: " + ErrorTransferStimulation + " (->"
            + string.Join(",", TransferCellIndexes)
            + ")";
    }
}
class TrainingCase
{
    /// <summary>Values used to stimulate the medium's input cells, one per input cell.</summary>
    public List<double> Inputs;
    /// <summary>Target values compared against the accumulation of the medium's output cells, one per output cell.</summary>
    public List<double> Outputs;

    /// <summary>Wraps the given lists (they are stored by reference, not copied).</summary>
    /// <exception cref="ArgumentNullException">Either list is null.</exception>
    public TrainingCase(List<double> Inputs, List<double> Outputs)
    {
        // Fail here rather than with a NullReferenceException deep inside training.
        if (Inputs == null) throw new ArgumentNullException("Inputs");
        if (Outputs == null) throw new ArgumentNullException("Outputs");
        this.Inputs = Inputs;
        this.Outputs = Outputs;
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment