Skip to content

Instantly share code, notes, and snippets.

@adrianseeley
Last active August 29, 2015 13:56
Show Gist options
  • Save adrianseeley/9221018 to your computer and use it in GitHub Desktop.
Save adrianseeley/9221018 to your computer and use it in GitHub Desktop.
GATO.MEDIUM
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading;
namespace GATO.MEDIUM
{
class Program
{
    /// <summary>
    /// Entry point: builds a medium with 2 input cells, 1 output cell and 10 hidden cells,
    /// then trains it on the XOR truth table.
    /// </summary>
    static void Main(string[] args)
    {
        var medium = new Medium(2, 1, 10);

        // XOR truth table: the single output is 1 exactly when the two inputs differ.
        var xorTable = new List<TrainingCase>
        {
            new TrainingCase(new List<double> { 0, 0 }, new List<double> { 0 }),
            new TrainingCase(new List<double> { 0, 1 }, new List<double> { 1 }),
            new TrainingCase(new List<double> { 1, 0 }, new List<double> { 1 }),
            new TrainingCase(new List<double> { 1, 1 }, new List<double> { 0 }),
        };

        medium.Train(xorTable);
    }
}
class Medium
{
    // All cells of the medium; a cell's position in this list is the index other cells transfer to.
    public List<Cell> Cells = new List<Cell>();
    // Indexes into Cells of the cells stimulated with a training case's inputs.
    public List<int> InputCellIndexes = new List<int>();
    // Indexes into Cells of the cells whose Accumulation is read as the medium's outputs.
    public List<int> OutputCellIndexes = new List<int>();
    // Error rates recorded by Train: one list per training iteration, one signed error per training case.
    // Entries accumulate across calls to Train.
    public List<List<double>> HistoricalErrorRates = new List<List<double>>();
    // Upper bound on resolve/transfer passes per resolution. Without it a cell whose impedance has been
    // clamped to 0 and that transfers to itself would re-fire the same stimulation forever.
    public int MaxResolutionPasses = 10000;

    /// <summary>
    /// Creates a medium whose cells are laid out as input cells first, then output cells, then hidden cells.
    /// Each cell is constructed with its own index as its only transfer target.
    /// </summary>
    public Medium(int NumberOfInputCells, int NumberOfOutputCells, int NumberOfHiddenCells)
    {
        for (int c = 0; c < NumberOfInputCells; c++)
        {
            InputCellIndexes.Add(AddCell());
        }
        for (int c = 0; c < NumberOfOutputCells; c++)
        {
            OutputCellIndexes.Add(AddCell());
        }
        for (int c = 0; c < NumberOfHiddenCells; c++)
        {
            AddCell();
        }
    }

    // Appends a new cell constructed with its own index and returns that index.
    private int AddCell()
    {
        int index = Cells.Count;
        Cells.Add(new Cell(index));
        return index;
    }

    /// <summary>
    /// Repeatedly runs every training case through the medium. A forward pass measures the signed
    /// output error; an error pass then replays the inputs together with that error at the input
    /// cells and lets every cell mutate. Each iteration's error rates are printed and appended to
    /// HistoricalErrorRates.
    /// </summary>
    /// <param name="TrainingCases">Cases whose Inputs/Outputs counts match the input/output cell counts.</param>
    /// <param name="Iterations">Number of passes over all training cases.</param>
    /// <param name="DelayMilliseconds">Pause after each iteration so console output can be watched; 0 or less disables it.</param>
    /// <exception cref="ArgumentNullException">TrainingCases is null.</exception>
    /// <exception cref="ArgumentException">A case is null or its input/output count does not match the medium.</exception>
    public void Train(List<TrainingCase> TrainingCases, int Iterations = 10000, int DelayMilliseconds = 100)
    {
        if (TrainingCases == null)
        {
            throw new ArgumentNullException("TrainingCases");
        }
        // Validate everything up front so a bad case cannot abort training half-way through
        foreach (TrainingCase candidate in TrainingCases)
        {
            ValidateTrainingCase(candidate);
        }

        for (int iter = 0; iter < Iterations; iter++)
        {
            List<double> ErrorRates = new List<double>();
            foreach (TrainingCase trainingCase in TrainingCases)
            {
                // Forward pass: stimulate inputs, let stimulation settle, measure the output error
                ResetCells();
                StimulateInputs(trainingCase);
                ResolveStimulation();
                double OutputError = MeasureOutputError(trainingCase);
                ErrorRates.Add(OutputError);

                // Error pass: replay the same inputs and inject the error at every input cell
                ResetCells();
                StimulateInputs(trainingCase);
                foreach (int inputIndex in InputCellIndexes)
                {
                    Cells[inputIndex].StimulateError(OutputError);
                }
                ResolveErrorStimulation();
            }

            foreach (double rate in ErrorRates)
            {
                Console.Write(rate + ",");
            }
            Console.WriteLine(this);
            HistoricalErrorRates.Add(ErrorRates);

            if (DelayMilliseconds > 0)
            {
                Thread.Sleep(DelayMilliseconds);
            }
        }
    }

    // Throws ArgumentException unless the case is non-null and sized for this medium.
    private void ValidateTrainingCase(TrainingCase trainingCase)
    {
        if (trainingCase == null)
        {
            throw new ArgumentException("Training cases must not be null");
        }
        if (trainingCase.Inputs.Count != InputCellIndexes.Count)
        {
            throw new ArgumentException("Number of inputs for training case must equal the number of input cell indexes");
        }
        if (trainingCase.Outputs.Count != OutputCellIndexes.Count)
        {
            throw new ArgumentException("Number of outputs for training case must equal the number of output cell indexes");
        }
    }

    // Clears all transient state in every cell (impedance and transfer targets are kept).
    private void ResetCells()
    {
        foreach (Cell cell in Cells)
        {
            cell.Reset();
        }
    }

    // Stimulates each input cell with the matching value from the training case.
    private void StimulateInputs(TrainingCase trainingCase)
    {
        for (int c = 0; c < InputCellIndexes.Count; c++)
        {
            Cells[InputCellIndexes[c]].Stimulate(trainingCase.Inputs[c]);
        }
    }

    // Sums (expected - measured accumulation) over the output cells; the result is signed.
    private double MeasureOutputError(TrainingCase trainingCase)
    {
        double outputError = 0;
        for (int c = 0; c < OutputCellIndexes.Count; c++)
        {
            outputError += trainingCase.Outputs[c] - Cells[OutputCellIndexes[c]].Accumulation;
        }
        return outputError;
    }

    // Lets forward stimulation settle across the medium.
    private void ResolveStimulation()
    {
        ResolveUntilStable(cell => cell.Resolve(), (cell, cells) => cell.Transfer(cells));
    }

    // Lets stimulation and error settle across the medium, then mutates every cell.
    private void ResolveErrorStimulation()
    {
        ResolveUntilStable(cell => cell.ResolveError(), (cell, cells) => cell.TransferError(cells));
        foreach (Cell cell in Cells)
        {
            cell.Mutate(Cells);
        }
    }

    // Alternates "resolve every cell" and "transfer from every cell" until a pass in which no cell
    // overflows, or until MaxResolutionPasses passes have run. When the cap is hit, stimulation still
    // in flight is left in the cells; the next Reset clears it.
    private void ResolveUntilStable(Func<Cell, bool> resolve, Action<Cell, List<Cell>> transfer)
    {
        for (int pass = 0; pass < MaxResolutionPasses; pass++)
        {
            bool resolved = true;
            foreach (Cell cell in Cells)
            {
                // Evaluate every cell each pass; a short-circuiting && would skip the rest after the first overflow
                if (!resolve(cell))
                {
                    resolved = false;
                }
            }
            if (resolved)
            {
                return;
            }
            foreach (Cell cell in Cells)
            {
                transfer(cell, Cells);
            }
        }
    }

    /// <summary>
    /// Renders every cell as "(index: impedance->target,target,...)".
    /// </summary>
    public override string ToString()
    {
        StringBuilder ret = new StringBuilder();
        for (int c = 0; c < Cells.Count; c++)
        {
            ret.Append("(" + c + ": " + Cells[c].Impedance + "->");
            ret.Append(string.Join(",", Cells[c].TransferCellIndexes));
            ret.Append(")");
        }
        return ret.ToString();
    }
}
class Cell
{
    // Shared random source used by Mutate.
    public static Random R = new Random();
    // Stimulation the cell can absorb before it overflows; Mutate keeps it within [0, 1].
    public double Impedance = 1.0;
    // Stimulation absorbed so far; zeroed ("burned out") whenever the cell overflows.
    public double Accumulation = 0.0;
    // Incoming stimulation waiting to be handled by Resolve/ResolveError.
    public double Stimulation = 0.0;
    // Overflow waiting to be split across TransferCellIndexes by Transfer/TransferError.
    public double TransferStimulation = 0.0;
    // Total error received through StimulateError; drives Mutate.
    public double ErrorAccumulation = 0.0;
    // Received error not yet forwarded to other cells.
    public double ErrorStimulation = 0.0;
    // Error waiting to be split across TransferCellIndexes by TransferError.
    public double ErrorTransferStimulation = 0.0;
    // Indexes (into the cell list passed to Transfer) that this cell transfers to; may contain duplicates and itself.
    public List<int> TransferCellIndexes = new List<int>();

    /// <summary>Creates a cell whose only transfer target is itself.</summary>
    public Cell(int CellIndex)
    {
        TransferCellIndexes.Add(CellIndex);
    }

    /// <summary>Clears all transient stimulation and error state; impedance and transfer targets are kept.</summary>
    public void Reset()
    {
        ErrorStimulation = 0.0;
        ErrorAccumulation = 0.0;
        ErrorTransferStimulation = 0.0;
        Stimulation = 0.0;
        Accumulation = 0.0;
        TransferStimulation = 0.0;
    }

    /// <summary>Adds incoming stimulation to be handled by the next Resolve.</summary>
    public void Stimulate(double Stimulation)
    {
        this.Stimulation += Stimulation;
    }

    /// <summary>Adds incoming error: it is both recorded (ErrorAccumulation) and queued for forwarding (ErrorStimulation).</summary>
    public void StimulateError(double ErrorStimulation)
    {
        this.ErrorStimulation += ErrorStimulation;
        this.ErrorAccumulation += ErrorStimulation;
    }

    /// <summary>
    /// Absorbs pending stimulation. If either the incoming stimulation alone, or the accumulation after
    /// absorbing it, exceeds Impedance, the excess becomes TransferStimulation and Accumulation burns out.
    /// </summary>
    /// <returns>true when everything was absorbed; false when the cell overflowed and needs a Transfer.</returns>
    public bool Resolve()
    {
        if (Stimulation > Impedance)
        {
            // Input alone overflows: pass on the excess and burn out anything stored
            TransferStimulation = Stimulation - Impedance;
            Accumulation = 0.0;
            Stimulation = 0.0;
            return false;
        }

        Accumulation += Stimulation;
        Stimulation = 0.0;
        if (Accumulation > Impedance)
        {
            // Stored total overflows: pass on the excess and burn out
            TransferStimulation = Accumulation - Impedance;
            Accumulation = 0.0;
            return false;
        }
        return true;
    }

    /// <summary>
    /// Same as Resolve, and additionally, when the cell overflows, queues its pending error for
    /// forwarding so TransferError carries it to the same targets as the stimulation.
    /// </summary>
    /// <returns>true when everything was absorbed; false when the cell overflowed and needs a TransferError.</returns>
    public bool ResolveError()
    {
        bool resolved = Resolve();
        if (!resolved)
        {
            ErrorTransferStimulation += ErrorStimulation;
            ErrorStimulation = 0.0;
        }
        return resolved;
    }

    /// <summary>Splits TransferStimulation equally across TransferCellIndexes, then clears it.</summary>
    /// <exception cref="InvalidOperationException">The cell has no transfer targets.</exception>
    public void Transfer(List<Cell> Cells)
    {
        if (TransferCellIndexes.Count == 0)
        {
            throw new InvalidOperationException("Cells must have at least one transfer cell index");
        }
        double share = TransferStimulation / TransferCellIndexes.Count;
        foreach (int target in TransferCellIndexes)
        {
            Cells[target].Stimulate(share);
        }
        TransferStimulation = 0.0;
    }

    /// <summary>
    /// Transfers stimulation exactly like Transfer, then splits ErrorTransferStimulation equally
    /// across the same targets and clears it.
    /// </summary>
    /// <exception cref="InvalidOperationException">The cell has no transfer targets.</exception>
    public void TransferError(List<Cell> Cells)
    {
        Transfer(Cells);
        double share = ErrorTransferStimulation / TransferCellIndexes.Count;
        foreach (int target in TransferCellIndexes)
        {
            Cells[target].StimulateError(share);
        }
        ErrorTransferStimulation = 0.0;
    }

    /// <summary>
    /// Nudges Impedance by ErrorAccumulation * learning rate in a random direction (clamped to [0, 1]),
    /// then, with a chance that grows with the error's magnitude, adds/removes a transfer target
    /// (large mutation) or redirects one (small mutation).
    /// </summary>
    public void Mutate(List<Cell> Cells)
    {
        const double LearningRate = 0.10;
        const double SmallMutationRate = 0.05;
        const double LargeMutationRate = 0.03;

        double step = ErrorAccumulation * LearningRate;
        if (R.NextDouble() > 0.5)
        {
            Impedance += step;
        }
        else
        {
            Impedance -= step;
        }
        Impedance = Math.Min(1, Math.Max(0, Impedance));

        // Use the error magnitude: an overshoot (negative error) is just as wrong as an undershoot
        double ShouldMutate = 1 - (R.NextDouble() * Math.Abs(ErrorAccumulation));
        if (ShouldMutate < LargeMutationRate)
        {
            double ShouldRemove = R.NextDouble();
            // The only remaining transfer can never be removed
            if (TransferCellIndexes.Count == 1)
            {
                ShouldRemove = 0.0;
            }
            if (ShouldRemove > 0.5)
            {
                TransferCellIndexes.RemoveAt(R.Next(0, TransferCellIndexes.Count));
            }
            else
            {
                TransferCellIndexes.Add(R.Next(0, Cells.Count));
            }
        }
        else if (ShouldMutate < SmallMutationRate)
        {
            TransferCellIndexes[R.Next(0, TransferCellIndexes.Count)] = R.Next(0, Cells.Count);
        }
    }

    /// <summary>Debug dump of every field followed by the transfer targets.</summary>
    public override string ToString()
    {
        return "i: " + Impedance + ", a: " + Accumulation + ", s: " + Stimulation + ", ts: " + TransferStimulation + ", ea: " + ErrorAccumulation + ", es: " + ErrorStimulation + " ets: " + ErrorTransferStimulation + " (->"
            + string.Join(",", TransferCellIndexes)
            + ")";
    }
}
class TrainingCase
{
    // Values applied, in order, to a medium's input cells (stored by reference, not copied).
    public List<double> Inputs;
    // Expected values, in order, of a medium's output cells (stored by reference, not copied).
    public List<double> Outputs;

    /// <summary>Pairs a set of input values with the output values expected for them.</summary>
    /// <exception cref="ArgumentNullException">Inputs or Outputs is null.</exception>
    public TrainingCase(List<double> Inputs, List<double> Outputs)
    {
        // Fail here rather than with a NullReferenceException deep inside training
        if (Inputs == null)
        {
            throw new ArgumentNullException("Inputs");
        }
        if (Outputs == null)
        {
            throw new ArgumentNullException("Outputs");
        }
        this.Inputs = Inputs;
        this.Outputs = Outputs;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment