
# blorsch/single_neuron_ann.swift

Last active May 17, 2020
A single-neuron neural network (perceptron) written entirely in Swift, using only Foundation
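In matrix form, the neuron computes the following. This is a sketch, assuming the training loop follows the common version of the referenced neuralnet.py (gradient descent on squared error, with an implicit learning rate of 1). Here $X$ is the 4×3 `inputs` matrix, $y$ the 4×1 `outputs`, $w$ the 3×1 `weights`, and $\odot$ is element-wise multiplication.

$$
\hat{y} = \sigma(Xw), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad \sigma'(z) = \sigma(z)\,\bigl(1 - \sigma(z)\bigr)
$$

$$
w \leftarrow w + X^{\top}\bigl((y - \hat{y}) \odot \hat{y} \odot (1 - \hat{y})\bigr)
$$

Because $\sigma'$ is written in terms of the sigmoid's output, `sigGrad` can be applied directly to $\hat{y}$ instead of to $Xw$.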
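```swift
import Foundation

// Swift version of https://github.com/gitshanks/simpleneuralnet/blob/master/neuralnet.py
// In Swift you can't easily add or multiply arrays/matrices, so I had to write those
// operations manually, and some of them are kinda ugly.
// Used for my 2020 Marin Academy Math Night presentation:
// "Neural Networks and Their Applications in Text Generation"

// Pattern in the data: output = column 2 of the input (counting columns from 1)
var inputs = [[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]] as [[Double]]
var outputs = [[0], [1], [0], [1]] as [[Double]]

// 3x1 matrix of weights, randomly initialized between -1 and 1
var weights = ((0..<3).map({ _ in Double.random(in: -1...1) })).map({ [$0] })

// Train for `i` iterations, updating the global `weights` in place
func trainNeuron(trainingInputs: [[Double]], trainingOutputs: [[Double]], i: Int) {
    for _ in 0..<i {
        // Forward pass: current prediction for each training example (Nx1)
        let output = getOutput(inputs: trainingInputs, weights: weights)
        let gradient = sigGrad(input: output)
        // (expected - predicted), scaled element by element by the sigmoid derivative
        let delta = (0..<output.count).map { row in
            [(trainingOutputs[row][0] - output[row][0]) * gradient[row][0]]
        }
        // inputs^T (3xN) dot delta (Nx1) gives the 3x1 weight adjustment
        let adjustment = dot(m1: transpose(input: trainingInputs), m2: delta)
        for row in 0..<weights.count {
            weights[row][0] += adjustment[row][0]
        }
    }
}

// Forward pass: sigmoid(inputs · weights)
func getOutput(inputs: [[Double]], weights: [[Double]]) -> [[Double]] {
    return sigmoidArray(input: dot(m1: inputs, m2: weights))
}

// Swift doesn't really have a dot product function, so let's write our own. m1 = AxN, m2 = Nx1
func dot(m1: [[Double]], m2: [[Double]]) -> [[Double]] {
    var dotted = [[Double]]()
    for i in 0..<m1.count {
        var sum = 0.0
        for j in 0..<m2.count {
            sum += m1[i][j] * m2[j][0]
        }
        dotted.append([sum])
    }
    return dotted
}

// Sigmoid activation: 1 / (1 + e^(-x)); Foundation's exp works on both macOS and Linux
func sigmoid(input: Double) -> Double {
    return 1 / (1 + exp(-input))
}

// Derivative of the sigmoid in terms of its output: x * (1 - x)
func sigGrad(input: [[Double]]) -> [[Double]] {
    return input.map({ [$0[0] * (1 - $0[0])] })
}

// Applies sigmoid to every element of an Nx1 array
func sigmoidArray(input: [[Double]]) -> [[Double]] {
    return input.map({ [sigmoid(input: $0[0])] })
}

// Transpose an array - https://stackoverflow.com/a/32922962/4777497
func transpose<T>(input: [[T]]) -> [[T]] {
    if input.isEmpty { return [[T]]() }
    let count = input[0].count
    var out = [[T]](repeating: [T](), count: count)
    for outer in input {
        for (index, inner) in outer.enumerated() {
            out[index].append(inner)
        }
    }
    return out
}

// Starting weights
print("Starting weights: \(weights)")

// Train
trainNeuron(trainingInputs: inputs, trainingOutputs: outputs, i: 10000)

// Trained weights
print("Trained weights: \(weights)")

// Validate
let validatingOutput = getOutput(inputs: [[1, 1, 0]], weights: weights)
print(validatingOutput) // should be close to 1 (column 2 is 1, and that was the pattern)
```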
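The header comment notes that Swift has no built-in element-wise array math, which is why the helpers above are written by hand. One way to tidy them is a small wrapper type with operator overloads. The sketch below is hypothetical: `Matrix`, `zipWith`, `transposed`, and `dot(_:)` are illustrative names, not part of the gist, Swift, or Foundation, and the data and starting weights in the usage lines are made up.

```swift
import Foundation

// Hypothetical wrapper (not from the gist or any library) adding element-wise
// operators to [[Double]], so a training step reads like the NumPy original.
struct Matrix: Equatable {
    var rows: [[Double]]

    // Combines two same-shaped matrices element by element.
    // Only the row count is checked; zip silently truncates ragged rows.
    static func zipWith(_ a: Matrix, _ b: Matrix, _ op: (Double, Double) -> Double) -> Matrix {
        precondition(a.rows.count == b.rows.count, "row counts differ")
        return Matrix(rows: zip(a.rows, b.rows).map { pair in
            zip(pair.0, pair.1).map { op($0.0, $0.1) }
        })
    }

    static func + (a: Matrix, b: Matrix) -> Matrix { zipWith(a, b, +) }
    static func - (a: Matrix, b: Matrix) -> Matrix { zipWith(a, b, -) }
    // Element-wise (Hadamard) product, like `*` on NumPy arrays.
    static func * (a: Matrix, b: Matrix) -> Matrix { zipWith(a, b, *) }

    // Applies a scalar function (e.g. the sigmoid) to every element.
    func map(_ f: (Double) -> Double) -> Matrix {
        Matrix(rows: rows.map { $0.map(f) })
    }

    // Rows become columns; an empty matrix stays empty.
    var transposed: Matrix {
        guard let first = rows.first else { return Matrix(rows: []) }
        return Matrix(rows: first.indices.map { c in rows.map { $0[c] } })
    }

    // Ordinary matrix product: (A x N) · (N x B) -> (A x B).
    func dot(_ other: Matrix) -> Matrix {
        let cols = other.transposed.rows
        return Matrix(rows: rows.map { row in
            cols.map { col in zip(row, col).reduce(0.0) { $0 + $1.0 * $1.1 } }
        })
    }
}

// One training step in the gist's style, using illustrative data and weights.
let x = Matrix(rows: [[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]])
let y = Matrix(rows: [[0], [1], [0], [1]])
var w = Matrix(rows: [[0.1], [-0.2], [0.3]])
let yHat = x.dot(w).map { 1 / (1 + exp(-$0)) }
let delta = (y - yHat) * yHat.map { $0 * (1 - $0) }
w = w + x.transposed.dot(delta)
print(w.rows)
```

Wrapping the arrays in a struct keeps the standard library's `+` on `Array` (concatenation) untouched, which is why the operators are defined on `Matrix` rather than directly on `[[Double]]`.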