This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A high-level encapsulation of a neural net: an input layer, zero or more hidden
/// layers and an output layer, wired together in order by Connect().
/// Raises ArgumentException if numInputs or numOutputs is not positive.
and NeuralNet(numInputs: int, numOutputs: int) =
    do
        if numInputs <= 0 then invalidArg "numInputs" "There must be at least one neuron in the input layer"
        if numOutputs <= 0 then invalidArg "numOutputs" "There must be at least one neuron in the output layer"

    let inputLayer: NeuralNetLayer = new NeuralNetLayer(numInputs)
    let outputLayer: NeuralNetLayer = new NeuralNetLayer(numOutputs)

    // Hidden layers in insertion order. An immutable list keeps each snapshot flat
    // (repeated Seq.append builds an ever-deeper chain of nested enumerators), and
    // replacing it wholesale means an in-flight enumeration is never disturbed.
    let mutable hiddenLayers: NeuralNetLayer list = []
    let mutable isConnected: bool = false

    /// Every layer that receives input from a previous layer: the hidden layers, then the output layer.
    /// Lazy, so each enumeration reflects the hidden layers present at that moment.
    let layersMinusInput: NeuralNetLayer seq =
        seq {
            yield! hiddenLayers
            yield outputLayer
        }

    /// Yields all connections to nodes inside of the network (the inputs of every
    /// neuron outside the input layer), in layer order.
    let connections =
        layersMinusInput
        |> Seq.collect (fun (l: NeuralNetLayer) -> l.Neurons)
        |> Seq.collect (fun (n: Neuron) -> n.Inputs)

    /// Gets the layers of the neural network, in sequential order (input, hidden..., output).
    member this.Layers: NeuralNetLayer seq =
        seq {
            yield inputLayer
            yield! layersMinusInput
        }

    /// Represents the input layer for the network, which takes in values from another system.
    member this.InputLayer = inputLayer

    /// Represents the last layer in the network, which holds the values that will be taken out of the network.
    member this.OutputLayer = outputLayer

    /// Connects each layer to its successor (input -> hidden... -> output).
    /// Raises InvalidOperationException if the network is already connected.
    member this.Connect() =
        if isConnected then invalidOp "The Neural Network has already been connected"
        this.Layers
        |> Seq.pairwise
        |> Seq.iter (fun (current: NeuralNetLayer, next: NeuralNetLayer) -> current.Connect(next))
        isConnected <- true

    /// Determines whether or not the network has been connected. After the network is connected, it can no longer be added to.
    member this.IsConnected = isConnected

    /// Adds a hidden layer to the middle of the neural net, after any previously added hidden layers.
    /// Raises InvalidOperationException once the network has been connected.
    member this.AddHiddenLayer(layer: NeuralNetLayer) =
        if isConnected then invalidOp "Hidden layers cannot be added after the network has been connected."
        hiddenLayers <- hiddenLayers @ [ layer ]

    /// Sets the weights on all connections in the neural network, in the order the
    /// connections are enumerated. Connects the network first if necessary.
    /// Weights beyond the number of connections are ignored. Raises ArgumentException
    /// (and assigns nothing) if fewer weights than connections are supplied.
    member this.SetWeights(weights: decimal seq) =
        if not isConnected then this.Connect()
        let targets: NeuronConnection[] = Seq.toArray connections
        // Truncate before materialising so an unbounded weight generator still works.
        let values = weights |> Seq.truncate targets.Length |> Seq.toArray
        // Seq.iter2 would silently stop early and leave trailing connections with stale weights.
        if values.Length < targets.Length then
            invalidArg "weights" (sprintf "Expected %d weights (one per connection) but received only %d" targets.Length values.Length)
        Array.iter2 (fun (c: NeuronConnection) (w: decimal) -> c.Weight <- w) targets values

    /// Evaluates the entire neural network and yields the result of the output layer.
    /// Connects the network first if necessary.
    member this.Evaluate(): decimal seq =
        if not isConnected then this.Connect()
        // Every layer is evaluated in order; only the last (output) layer's result is kept.
        this.Layers
        |> Seq.fold (fun _ (layer: NeuralNetLayer) -> layer.Evaluate()) Seq.empty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment