@muety
Last active August 24, 2017 12:30
Simple Neural Network in Go
/* Simple neural net consisting of a single sigmoid neuron (no hidden layer) */
/* Inspired by https://medium.com/technology-invention-and-more/how-to-build-a-simple-neural-network-in-9-lines-of-python-code-cc8f23647ca1 */
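/* Learning rule (delta rule for a single sigmoid unit):
     w_j <- w_j + rate * x_j * (y - out) * out * (1 - out)
   where out = sigmoid(w · x) is the neuron's prediction for input x. */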
package main

import (
	"fmt"
	"math"
	"math/rand"
)
type activation func(float64) float64
// sigmoid is the logistic activation function.
func sigmoid(x float64) float64 {
	return 1 / (1 + math.Exp(-x))
}
// sigmoid_d is the sigmoid's derivative, expressed in terms of the sigmoid's
// output: for out = sigmoid(x), sigmoid'(x) = out * (1 - out).
func sigmoid_d(out float64) float64 {
	return out * (1 - out)
}
// activate computes the weighted sum of the inputs x and applies the given
// activation function to it.
func activate(x, w []float64, fn activation) float64 {
	raw := 0.0
	for j := 0; j < len(x); j++ {
		raw += x[j] * w[j]
	}
	return fn(raw)
}
func totalMseLoss(X [][]float64, y, w []float64) float64 {
	loss := 0.0
	for i := 0; i < len(X); i++ {
		loss += math.Pow(y[i]-activate(X[i], w, sigmoid), 2.0)
	}
	return loss / float64(len(X))
}
func info(X [][]float64, y, w0 []float64, epoch int) {
	fmt.Println("Iteration", epoch+1)
	fmt.Println("Loss: ", totalMseLoss(X, y, w0))
	for i := 0; i < len(X); i++ {
		fmt.Printf("%.4f ", activate(X[i], w0, sigmoid))
		if i == len(X)-1 {
			fmt.Println()
		}
	}
	for i := 0; i < len(w0); i++ {
		fmt.Printf("w%d %.4f ", i, w0[i])
		if i == len(w0)-1 {
			fmt.Println()
		}
	}
	fmt.Println("------")
}
func main() {
	rate := 0.1

	// Training data: the label equals the first input column.
	X := [][]float64{
		{0.0, 0.0, 1.0},
		{0.0, 1.0, 1.0},
		{1.0, 0.0, 1.0},
		{1.0, 1.0, 1.0},
	}
	y := []float64{0.0, 0.0, 1.0, 1.0}

	// Initialize the weights randomly in [-1, 1).
	w0 := make([]float64, len(X[0]))
	for i := 0; i < len(w0); i++ {
		w0[i] = 2*rand.Float64() - 1
	}

	// Train with stochastic gradient descent, updating the weights once per sample.
	for epoch := 0; epoch < 1000; epoch++ {
		for i := 0; i < len(X); i++ {
			out := activate(X[i], w0, sigmoid)
			diff := y[i] - out
			for j := 0; j < len(X[0]); j++ {
				// Delta rule, see (22) in http://www.idi.ntnu.no/~keithd/classes/advai/lectures/backprop.pdf
				w0[j] += rate * X[i][j] * diff * sigmoid_d(out)
			}
		}
		if epoch%100 == 0 {
			info(X, y, w0, epoch)
		}
	}
}
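To check what the trained neuron has learned, one could query it with an input that does not appear in the training set, for example {1.0, 0.0, 0.0}. Below is a minimal sketch, assuming the lines are appended to the end of main() after the training loop; since the label follows the first input column, the prediction should come out close to 1.

	// Hypothetical usage: evaluate the trained weights on an unseen input.
	newSample := []float64{1.0, 0.0, 0.0}
	fmt.Printf("Prediction for %v: %.4f\n", newSample, activate(newSample, w0, sigmoid))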