Skip to content

Instantly share code, notes, and snippets.

@montanaflynn
Created January 16, 2019 09:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save montanaflynn/6f76f4b67f6ed6a98cad45c441b3c67f to your computer and use it in GitHub Desktop.
Save montanaflynn/6f76f4b67f6ed6a98cad45c441b3c67f to your computer and use it in GitHub Desktop.
Simple neural net with one hidden layer consisting of one neuron
// Found at https://play.golang.org/p/sR0vNRAQD1
// Inspired by https://medium.com/technology-invention-and-more/how-to-build-a-simple-neural-network-in-9-lines-of-python-code-cc8f23647ca1
package main
import (
"fmt"
"math/rand"
"math"
)
// activation is the signature of a neuron activation function: it maps a
// single float64 (the weighted input sum computed by activate) to the
// neuron's float64 output.
type activation func(float64) float64
// sigmoid is the logistic function 1 / (1 + e^-x). It squashes any real
// input into the range [0, 1], returning 0.5 at x == 0.
func sigmoid(x float64) float64 {
	return 1 / (1 + math.Exp(-x))
}
// sigmoid_d returns the derivative of the logistic sigmoid with respect to
// its input, e^x / (1 + e^x)^2, evaluated at x.
//
// The derivative is an even function, so it is evaluated at -|x|. That keeps
// the exponential in (0, 1] and avoids the overflow of math.Exp(x) to +Inf
// for large positive x, which would otherwise turn the result into
// Inf/Inf = NaN instead of the correct limit 0.
func sigmoid_d(x float64) float64 {
	e := math.Exp(-math.Abs(x))
	onePlusE := 1 + e
	return e / (onePlusE * onePlusE)
}
// activate computes a single neuron's output: the dot product of the inputs
// x with the weights w, passed through the activation function fn.
//
// Every element of x is paired with the weight at the same index, so w must
// have at least len(x) elements; a shorter w causes an index-out-of-range
// panic. An empty x yields fn(0).
func activate(x, w []float64, fn activation) float64 {
	var sum float64
	for idx, input := range x {
		sum += input * w[idx]
	}
	return fn(sum)
}
// totalMseLoss returns the mean squared error of the sigmoid neuron with
// weights w over the data set: for every sample row of X the prediction
// activate(row, w, sigmoid) is compared against the matching target in y,
// and the squared differences are averaged over len(X).
//
// y must hold at least len(X) targets. For an empty X the division is 0/0,
// so the result is NaN.
func totalMseLoss(X [][]float64, y, w []float64) float64 {
	var sum float64
	for row, sample := range X {
		residual := y[row] - activate(sample, w, sigmoid)
		sum += math.Pow(residual, 2.0)
	}
	return sum / float64(len(X))
}
// info prints a progress report for one training checkpoint to stdout:
//
//   - the 1-based iteration number (epoch + 1),
//   - the mean squared error of weights w0 over (X, y),
//   - one line with the neuron's prediction for every sample in X,
//   - one line with every weight, labelled w0, w1, ...,
//   - a "------" separator.
//
// The prediction and weight lines are terminated with a newline only when
// they contain at least one value.
func info(X [][]float64, y, w0 []float64, epoch int) {
	fmt.Println("Iteration", epoch+1)
	fmt.Println("Loss: ", totalMseLoss(X, y, w0))

	for _, sample := range X {
		fmt.Printf("%.4f ", activate(sample, w0, sigmoid))
	}
	if len(X) > 0 {
		fmt.Println()
	}

	for idx, weight := range w0 {
		fmt.Printf("w%d %.4f ", idx, weight)
	}
	if len(w0) > 0 {
		fmt.Println()
	}

	fmt.Println("------")
}
// main trains a single sigmoid neuron with per-sample gradient descent on a
// four-row toy data set and prints a progress report every 100 epochs.
func main() {
	rate := 0.1

	// Each sample has two features followed by a constant 1.0 input, so the
	// third weight plays the role of a bias term.
	X := [][]float64{
		{0.0, 0.0, 1.0},
		{0.0, 1.0, 1.0},
		{1.0, 0.0, 1.0},
		{1.0, 1.0, 1.0},
	}
	// The target for every sample equals its first feature.
	y := []float64{0.0, 0.0, 1.0, 1.0}

	// One weight per input column, initialised uniformly in [-1, 1).
	w0 := make([]float64, len(X[0]))
	for i := range w0 {
		w0[i] = 2*rand.Float64() - 1
	}

	for epoch := 0; epoch < 1000; epoch++ {
		for i, sample := range X {
			out := activate(sample, w0, sigmoid)
			diff := y[i] - out // named diff to avoid shadowing the builtin error type

			// Delta rule for squared error through a sigmoid: the gradient
			// carries the sigmoid derivative out*(1-out), not just out.
			step := rate * diff * out * (1 - out)
			for j, input := range sample {
				w0[j] += step * input
			}
		}
		if epoch%100 == 0 {
			info(X, y, w0, epoch)
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment