// Playing with BNNS (Swift version). The "hello world" of neural networks.
/*
The "hello world" of neural networks: a simple 3-layer feed-forward
network that implements an XOR logic gate.
The first layer is the input layer. It has two neurons a and b, which
are the two inputs to the XOR gate.
The middle layer is the hidden layer. This has two neurons h1, h2 that
will learn what it means to be an XOR gate.
Neuron a is connected to h1 and h2. Neuron b is also connected to h1
and h2. Each of these four connections has its own weight. You learn
these weights by training the network (not done in this demo program).
The final layer is the output layer. This has a single neuron. Its
value is either "high" or "low", just like the output of an XOR gate.
Both h1 and h2 are connected to the o neuron.
    +---+          +----+
    | a |          | h1 |
    +---+          +----+          +---+
                                   | o |
    +---+          +----+          +---+
    | b |          | h2 |
    +---+          +----+
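The expected output is (approximately: the output neuron uses a sigmoid
activation, so the printed values are close to, but not exactly, 0 or 1):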
predict(0, 0) should give 0
predict(0, 1) should give 1
predict(1, 0) should give 1
predict(1, 1) should give 0
*/
import Foundation
import Accelerate
// The two BNNS filters that make up the network. createNetwork() assigns
// them; predict() and destroyNetwork() read them.
private var hiddenLayer: BNNSFilter? = nil   // input layer -> hidden layer
private var outputLayer: BNNSFilter? = nil   // hidden layer -> output neuron
/// Wraps a Float32 buffer in a `BNNSLayerData` descriptor.
///
/// The descriptor only borrows `buffer`'s memory. It must not be used after
/// the `withUnsafeBufferPointer` closure that produced `buffer` returns.
private func makeFloat32LayerData(_ buffer: UnsafeBufferPointer<Float>) -> BNNSLayerData {
    return BNNSLayerData(
        data: UnsafeRawPointer(buffer.baseAddress), data_type: BNNSDataTypeFloat32,
        data_scale: 0, data_bias: 0, data_table: nil)
}

/// Builds the two fully connected, sigmoid-activated layers of the XOR network
/// from hard-coded, pre-trained weights.
///
/// The resulting filters are stored in `hiddenLayer` and `outputLayer`.
///
/// - Returns: `true` if both filters were created. Returns `false`, after
///   printing a message, if either creation failed. On failure, no filter
///   created by this call is left alive.
func createNetwork() -> Bool {
    let activation = BNNSActivation(function: BNNSActivationFunctionSigmoid, alpha: 0, beta: 0)

    // These weights and bias values were found by training the network
    // (using a different program). These numbers represent what the net
    // has learned, in this case the proper response of an XOR gate.
    // NOTE(review): the element order of the 2x2 weight matrix follows BNNS's
    // fully-connected layout convention; confirm against BNNS.h before editing.
    let inputToHiddenWeights: [Float] = [ 54, 14, 17, 14 ]
    let inputToHiddenBias: [Float] = [ -8, -20 ]
    let hiddenToOutputWeights: [Float] = [ 92, -98 ]
    let hiddenToOutputBias: [Float] = [ -48 ]

    var inputDesc = BNNSVectorDescriptor(
        size: 2, data_type: BNNSDataTypeFloat32, data_scale: 0, data_bias: 0)
    var hiddenDesc = BNNSVectorDescriptor(
        size: 2, data_type: BNNSDataTypeFloat32, data_scale: 0, data_bias: 0)
    var outputDesc = BNNSVectorDescriptor(
        size: 1, data_type: BNNSDataTypeFloat32, data_scale: 0, data_bias: 0)

    // BNNSLayerData holds a raw pointer. Passing a Swift array directly to its
    // initializer uses an implicit array-to-pointer conversion. That pointer is
    // only valid for the initializer call itself, so it would already dangle
    // when the filters read the weights.
    // Instead, pin every buffer until both filters have been created.
    // filter_params is nil, so BNNS copies the data into the filter at creation
    // time, and the buffers need not outlive this function.
    return inputToHiddenWeights.withUnsafeBufferPointer { (ihWeights) -> Bool in
        inputToHiddenBias.withUnsafeBufferPointer { (ihBias) -> Bool in
            hiddenToOutputWeights.withUnsafeBufferPointer { (hoWeights) -> Bool in
                hiddenToOutputBias.withUnsafeBufferPointer { (hoBias) -> Bool in
                    var inputToHiddenParams = BNNSFullyConnectedLayerParameters(
                        in_size: 2, out_size: 2, weights: makeFloat32LayerData(ihWeights),
                        bias: makeFloat32LayerData(ihBias), activation: activation)
                    var hiddenToOutputParams = BNNSFullyConnectedLayerParameters(
                        in_size: 2, out_size: 1, weights: makeFloat32LayerData(hoWeights),
                        bias: makeFloat32LayerData(hoBias), activation: activation)

                    hiddenLayer = BNNSFilterCreateFullyConnectedLayer(&inputDesc, &hiddenDesc, &inputToHiddenParams, nil)
                    if hiddenLayer == nil {
                        print("BNNSFilterCreateFullyConnectedLayer failed for hidden layer")
                        return false
                    }

                    outputLayer = BNNSFilterCreateFullyConnectedLayer(&hiddenDesc, &outputDesc, &hiddenToOutputParams, nil)
                    if outputLayer == nil {
                        print("BNNSFilterCreateFullyConnectedLayer failed for output layer")
                        // Don't leak the hidden filter that was already created.
                        BNNSFilterDestroy(hiddenLayer)
                        hiddenLayer = nil
                        return false
                    }
                    return true
                }
            }
        }
    }
}
/// Runs one forward pass of the XOR network and prints the result.
///
/// - Parameters:
///   - a: First input to the XOR gate (the network was trained for 0 or 1).
///   - b: Second input to the XOR gate.
/// - Precondition: `createNetwork()` has created both filters.
///
/// If either layer fails to apply, an error is printed and no prediction is
/// reported.
func predict(_ a: Float, _ b: Float) {
    precondition(hiddenLayer != nil)
    precondition(outputLayer != nil)

    // These arrays hold the inputs and outputs to and from the layers. Passing
    // them directly is fine here: each pointer only has to live for the
    // duration of a single BNNSFilterApply call.
    let input = [a, b]
    var hidden: [Float] = [0, 0]
    var output: [Float] = [0]

    // Stop on failure. Otherwise the output layer would run on an unfilled
    // hidden buffer and a meaningless prediction would be printed.
    guard BNNSFilterApply(hiddenLayer, input, &hidden) == 0 else {
        print("BNNSFilterApply failed on hidden layer")
        return
    }
    guard BNNSFilterApply(outputLayer, hidden, &output) == 0 else {
        print("BNNSFilterApply failed on output layer")
        return
    }
    print("Predict \(a), \(b) = \(output[0])")
}
/// Releases the BNNS filters held in `hiddenLayer` and `outputLayer`.
///
/// Each global is reset to `nil` after its filter is destroyed. This makes a
/// repeated call a no-op instead of a double free. It also makes a later
/// `predict` trip its precondition instead of using a destroyed filter.
func destroyNetwork() {
    if let filter = hiddenLayer {
        BNNSFilterDestroy(filter)
        hiddenLayer = nil
    }
    if let filter = outputLayer {
        BNNSFilterDestroy(filter)
        outputLayer = nil
    }
}
/// Demo entry point.
///
/// Builds the XOR network, prints its prediction for all four input
/// combinations, then frees the network. If the network cannot be built,
/// nothing else happens.
func run() {
    guard createNetwork() else { return }

    print("Making predictions for XOR gate:")
    let inputPairs: [(Float, Float)] = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for (first, second) in inputPairs {
        predict(first, second)
    }
    destroyNetwork()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment