Skip to content

Instantly share code, notes, and snippets.

@nfisher
Created March 27, 2023 13:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nfisher/f31bdcb3fbc99686f2383315ced8bd4c to your computer and use it in GitHub Desktop.
Save nfisher/f31bdcb3fbc99686f2383315ced8bd4c to your computer and use it in GitHub Desktop.
Perceptrons
sepal_length sepal_width petal_length petal_width species
5.1 3.5 1.4 0.2 setosa
4.9 3.0 1.4 0.2 setosa
4.7 3.2 1.3 0.2 setosa
4.6 3.1 1.5 0.2 setosa
5.0 3.6 1.4 0.2 setosa
5.4 3.9 1.7 0.4 setosa
4.6 3.4 1.4 0.3 setosa
5.0 3.4 1.5 0.2 setosa
4.4 2.9 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.4 3.7 1.5 0.2 setosa
4.8 3.4 1.6 0.2 setosa
4.8 3.0 1.4 0.1 setosa
4.3 3.0 1.1 0.1 setosa
5.8 4.0 1.2 0.2 setosa
5.7 4.4 1.5 0.4 setosa
5.4 3.9 1.3 0.4 setosa
5.1 3.5 1.4 0.3 setosa
5.7 3.8 1.7 0.3 setosa
5.1 3.8 1.5 0.3 setosa
5.4 3.4 1.7 0.2 setosa
5.1 3.7 1.5 0.4 setosa
4.6 3.6 1.0 0.2 setosa
5.1 3.3 1.7 0.5 setosa
4.8 3.4 1.9 0.2 setosa
5.0 3.0 1.6 0.2 setosa
5.0 3.4 1.6 0.4 setosa
5.2 3.5 1.5 0.2 setosa
5.2 3.4 1.4 0.2 setosa
4.7 3.2 1.6 0.2 setosa
4.8 3.1 1.6 0.2 setosa
5.4 3.4 1.5 0.4 setosa
5.2 4.1 1.5 0.1 setosa
5.5 4.2 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.0 3.2 1.2 0.2 setosa
5.5 3.5 1.3 0.2 setosa
4.9 3.1 1.5 0.1 setosa
4.4 3.0 1.3 0.2 setosa
5.1 3.4 1.5 0.2 setosa
5.0 3.5 1.3 0.3 setosa
4.5 2.3 1.3 0.3 setosa
4.4 3.2 1.3 0.2 setosa
5.0 3.5 1.6 0.6 setosa
5.1 3.8 1.9 0.4 setosa
4.8 3.0 1.4 0.3 setosa
5.1 3.8 1.6 0.2 setosa
4.6 3.2 1.4 0.2 setosa
5.3 3.7 1.5 0.2 setosa
5.0 3.3 1.4 0.2 setosa
7.0 3.2 4.7 1.4 versicolor
6.4 3.2 4.5 1.5 versicolor
6.9 3.1 4.9 1.5 versicolor
5.5 2.3 4.0 1.3 versicolor
6.5 2.8 4.6 1.5 versicolor
5.7 2.8 4.5 1.3 versicolor
6.3 3.3 4.7 1.6 versicolor
4.9 2.4 3.3 1.0 versicolor
6.6 2.9 4.6 1.3 versicolor
5.2 2.7 3.9 1.4 versicolor
5.0 2.0 3.5 1.0 versicolor
5.9 3.0 4.2 1.5 versicolor
6.0 2.2 4.0 1.0 versicolor
6.1 2.9 4.7 1.4 versicolor
5.6 2.9 3.6 1.3 versicolor
6.7 3.1 4.4 1.4 versicolor
5.6 3.0 4.5 1.5 versicolor
5.8 2.7 4.1 1.0 versicolor
6.2 2.2 4.5 1.5 versicolor
5.6 2.5 3.9 1.1 versicolor
5.9 3.2 4.8 1.8 versicolor
6.1 2.8 4.0 1.3 versicolor
6.3 2.5 4.9 1.5 versicolor
6.1 2.8 4.7 1.2 versicolor
6.4 2.9 4.3 1.3 versicolor
6.6 3.0 4.4 1.4 versicolor
6.8 2.8 4.8 1.4 versicolor
6.7 3.0 5.0 1.7 versicolor
6.0 2.9 4.5 1.5 versicolor
5.7 2.6 3.5 1.0 versicolor
5.5 2.4 3.8 1.1 versicolor
5.5 2.4 3.7 1.0 versicolor
5.8 2.7 3.9 1.2 versicolor
6.0 2.7 5.1 1.6 versicolor
5.4 3.0 4.5 1.5 versicolor
6.0 3.4 4.5 1.6 versicolor
6.7 3.1 4.7 1.5 versicolor
6.3 2.3 4.4 1.3 versicolor
5.6 3.0 4.1 1.3 versicolor
5.5 2.5 4.0 1.3 versicolor
5.5 2.6 4.4 1.2 versicolor
6.1 3.0 4.6 1.4 versicolor
5.8 2.6 4.0 1.2 versicolor
5.0 2.3 3.3 1.0 versicolor
5.6 2.7 4.2 1.3 versicolor
5.7 3.0 4.2 1.2 versicolor
5.7 2.9 4.2 1.3 versicolor
6.2 2.9 4.3 1.3 versicolor
5.1 2.5 3.0 1.1 versicolor
5.7 2.8 4.1 1.3 versicolor
6.3 3.3 6.0 2.5 virginica
5.8 2.7 5.1 1.9 virginica
7.1 3.0 5.9 2.1 virginica
6.3 2.9 5.6 1.8 virginica
6.5 3.0 5.8 2.2 virginica
7.6 3.0 6.6 2.1 virginica
4.9 2.5 4.5 1.7 virginica
7.3 2.9 6.3 1.8 virginica
6.7 2.5 5.8 1.8 virginica
7.2 3.6 6.1 2.5 virginica
6.5 3.2 5.1 2.0 virginica
6.4 2.7 5.3 1.9 virginica
6.8 3.0 5.5 2.1 virginica
5.7 2.5 5.0 2.0 virginica
5.8 2.8 5.1 2.4 virginica
6.4 3.2 5.3 2.3 virginica
6.5 3.0 5.5 1.8 virginica
7.7 3.8 6.7 2.2 virginica
7.7 2.6 6.9 2.3 virginica
6.0 2.2 5.0 1.5 virginica
6.9 3.2 5.7 2.3 virginica
5.6 2.8 4.9 2.0 virginica
7.7 2.8 6.7 2.0 virginica
6.3 2.7 4.9 1.8 virginica
6.7 3.3 5.7 2.1 virginica
7.2 3.2 6.0 1.8 virginica
6.2 2.8 4.8 1.8 virginica
6.1 3.0 4.9 1.8 virginica
6.4 2.8 5.6 2.1 virginica
7.2 3.0 5.8 1.6 virginica
7.4 2.8 6.1 1.9 virginica
7.9 3.8 6.4 2.0 virginica
6.4 2.8 5.6 2.2 virginica
6.3 2.8 5.1 1.5 virginica
6.1 2.6 5.6 1.4 virginica
7.7 3.0 6.1 2.3 virginica
6.3 3.4 5.6 2.4 virginica
6.4 3.1 5.5 1.8 virginica
6.0 3.0 4.8 1.8 virginica
6.9 3.1 5.4 2.1 virginica
6.7 3.1 5.6 2.4 virginica
6.9 3.1 5.1 2.3 virginica
5.8 2.7 5.1 1.9 virginica
6.8 3.2 5.9 2.3 virginica
6.7 3.3 5.7 2.5 virginica
6.7 3.0 5.2 2.3 virginica
6.3 2.5 5.0 1.9 virginica
6.5 3.0 5.2 2.0 virginica
6.2 3.4 5.4 2.3 virginica
5.9 3.0 5.1 1.8 virginica
package main
import (
"encoding/csv"
"errors"
"fmt"
"log"
"math"
"math/rand"
"os"
"strconv"
)
// main loads the Iris data set (with a bias column), holds out 30% of the rows
// for evaluation, trains a perceptron on the remainder and logs the share of
// held-out rows that were predicted correctly.
func main() {
	log.SetFlags(0)

	input := &Data[float64]{}
	if err := LoadIris(input, true); err != nil {
		log.Fatalln(err)
	}
	input.Split(0.70)
	fmt.Println(input.TargetName)

	_, features := input.Shape()
	p := New(features, step[float64])
	if err := Fit(p, 1000, 0.001, input); err != nil {
		log.Fatalln("fit", err)
	}

	// The rows already carry the bias column added by LoadIris, so Predict
	// must not append another one.
	result, err := input.Test(func(X []float64) (float64, error) {
		return Predict(p, X, false)
	})
	if err != nil {
		// A failed evaluation has no meaningful score; stop instead of
		// reporting a bogus 0%.
		log.Fatalln("test", err)
	}
	// The figure is labelled "precision" but is the percentage of test rows
	// whose prediction equals the target.
	log.Printf("precision=%0.2f%%\n", result)
}
// stepFn is the activation function of a Perceptron. It takes the (value,
// error) pair returned by Dot so it can be called directly on Dot's result,
// and returns the prediction or an error.
type stepFn[T Numeric] func(T, error) (T, error)
// Fit trains the perceptron p on the training rows of d, making iters full
// passes over them. Each row is scored with p.Step applied to the dot product
// of the weights and the row; when the prediction differs from the target,
// every weight is moved against the error, scaled by the learning rate r and
// the matching input value. The first error from p.Step (including one
// forwarded from Dot) aborts training and is returned.
func Fit[T Numeric](p *Perceptron[T], iters int, r T, d *Data[T]) error {
	// update handles a single training row and adjusts p.Weights in place.
	update := func(X []T, target T) error {
		predicted, err := p.Step(Dot(p.Weights, X))
		if err != nil {
			return err
		}
		if predicted != target {
			delta := predicted - target
			for j := range p.Weights {
				p.Weights[j] -= r * delta * X[j]
			}
		}
		return nil
	}

	for epoch := 0; epoch < iters; epoch++ {
		if err := d.Train(update); err != nil {
			return err
		}
	}
	return nil
}
// step is the default activation function. It forwards a non-nil err
// unchanged (with a zero value); otherwise it rounds x to the nearest whole
// number and returns that number when it is positive, or zero when it is not.
func step[T Numeric](x T, err error) (T, error) {
	var zero T
	if err != nil {
		return zero, err
	}
	if rounded := math.Round(float64(x)); rounded > 0 {
		return T(rounded), nil
	}
	return zero, nil
}
// Predict returns the perceptron's output for the feature row x: p.Step
// applied to the dot product of p.Weights and x. When addBiasCol is true a
// constant 1 is appended to the row first, for callers whose rows do not
// already carry the bias column. Any error from Dot or p.Step is returned.
//
// x is never modified, even when the bias column is added.
func Predict[T Numeric](p *Perceptron[T], x []T, addBiasCol bool) (T, error) {
	if addBiasCol {
		// Capping the capacity forces append to copy, so the 1 is never
		// written into spare capacity of the caller's backing array.
		x = append(x[:len(x):len(x)], 1)
	}
	return p.Step(Dot(p.Weights, x))
}
// New builds a Perceptron with one zero-initialised weight per feature and fn
// as its step (activation) function.
func New[T Numeric](features int, fn stepFn[T]) *Perceptron[T] {
	p := new(Perceptron[T])
	p.Step = fn
	p.Weights = make([]T, features)
	return p
}
// Perceptron pairs a weight vector with the step function used to turn a
// weighted sum into a prediction.
type Perceptron[T Numeric] struct {
// Weights holds one coefficient per input column; New allocates it zeroed
// and Fit updates it in place.
Weights []T
// Step is applied to the result of Dot(Weights, row) to produce the output.
Step stepFn[T]
}
// Data is an in-memory data set: feature rows, their numeric labels, and the
// train/test partition produced by Split.
type Data[T Numeric] struct {
// Values holds one feature row per sample.
Values [][]T
// Target holds the numeric label of the row at the same index in Values.
Target []T
// TargetName maps each label's name to the numeric value stored in Target.
TargetName map[string]T
// Headers holds the column names of the header row, minus its last column.
Headers []string
// features is the second value reported by Shape; LoadIris sets it to the
// length of the last row it loaded.
features int
// train and test hold indices into Values/Target, populated by Split.
train []int
test []int
}
// Shape reports the dimensions of the data set: the number of rows in
// d.Values followed by the recorded number of columns per row.
func (d *Data[T]) Shape() (int, int) {
	rows, cols := len(d.Values), d.features
	return rows, cols
}
// Split randomly partitions the row indices of d.Values into a training set
// and a test set, replacing any previous partition. pct is the fraction of
// rows assigned to training (truncated to a whole number of rows); the
// remaining rows form the test set. A pct below 0 or NaN is treated as 0 and
// a pct above 1 as 1, so an out-of-range fraction can no longer cause a
// slice-bounds panic.
func (d *Data[T]) Split(pct float64) {
	switch {
	case math.IsNaN(pct) || pct < 0:
		pct = 0
	case pct > 1:
		pct = 1
	}

	// rand.Perm yields the indices 0..n-1 in random order.
	order := rand.Perm(len(d.Values))
	trainSz := int(pct * float64(len(order)))

	// Both halves share one backing array. Capping the capacity of the
	// training half means a later append to it reallocates rather than
	// overwriting the test indices.
	d.train = order[:trainSz:trainSz]
	d.test = order[trainSz:]
}
// Train calls fn once for every row in the training partition, in the order
// stored by Split, passing the row's features and its target. It stops at the
// first error returned by fn and returns that error; otherwise it returns nil.
func (d *Data[T]) Train(fn func(X []T, target T) error) error {
	indices := d.train
	for k := 0; k < len(indices); k++ {
		row := indices[k]
		if err := fn(d.Values[row], d.Target[row]); err != nil {
			return err
		}
	}
	return nil
}
// Test evaluates fn against the test partition. fn receives the features of
// each test row and returns a prediction; the result is the percentage
// (0-100) of test rows whose prediction equals the row's target.
//
// It returns an error when the test partition is empty, since the percentage
// would otherwise be 0/0 (NaN), and returns the first error produced by fn.
func (d *Data[T]) Test(fn func(X []T) (T, error)) (float64, error) {
	if len(d.test) == 0 {
		return 0, errors.New("empty test set")
	}

	var correct int
	for _, i := range d.test {
		p, err := fn(d.Values[i])
		if err != nil {
			return 0, err
		}
		if p == d.Target[i] {
			correct++
		}
	}
	return float64(correct) / float64(len(d.test)) * 100.0, nil
}
// LoadIris reads "iris.csv" from the working directory into d. The first
// record is treated as the header row; in every record the last column is the
// label name and the preceding columns are parsed as numeric features.
//
// Label names are mapped to numeric values 1, 2, 3, ... in order of first
// appearance and recorded in d.TargetName. When addBiasColumn is true a
// constant 1 is appended to every feature row. Rows are appended to d.Values
// and d.Target; d.Headers and d.TargetName are replaced.
//
// It returns an error if the file cannot be opened or parsed, has no header
// row, or contains a feature that is not a valid number. A row that fails to
// parse is not added, so d.Values and d.Target always stay the same length.
func LoadIris[T Numeric](d *Data[T], addBiasColumn bool) error {
	file, err := os.Open("iris.csv")
	if err != nil {
		return err
	}
	defer file.Close()

	records, err := csv.NewReader(file).ReadAll()
	if err != nil {
		return err
	}
	if len(records) == 0 {
		return errors.New("iris.csv: missing header row")
	}

	header := records[0]
	d.Headers = header[:len(header)-1]
	d.TargetName = map[string]T{}

	width := 0
	for _, row := range records[1:] { // skip the header row
		fields, name := row[:len(row)-1], row[len(row)-1]

		// Parse the whole row before recording anything, so a bad value
		// cannot leave a target without its matching feature row.
		values := make([]T, 0, len(fields)+1)
		for _, s := range fields {
			f, err := strconv.ParseFloat(s, 64)
			if err != nil {
				return err
			}
			values = append(values, T(f))
		}
		if addBiasColumn {
			values = append(values, 1)
		}

		label, ok := d.TargetName[name]
		if !ok {
			label = T(len(d.TargetName)) + 1
			d.TargetName[name] = label
		}

		d.Target = append(d.Target, label)
		d.Values = append(d.Values, values)
		width = len(values)
	}
	d.features = width
	return nil
}
// Dot returns the dot product of a and b: the sum of the element-wise
// products. It returns an error when the vectors differ in length or when
// both are empty.
func Dot[T Numeric](a, b []T) (T, error) {
	var total T
	switch n := len(a); {
	case n != len(b):
		return total, fmt.Errorf("unaligned vectors a=%d, b=%d", n, len(b))
	case n == 0:
		return total, errors.New("empty vectors")
	}
	for i, av := range a {
		total += av * b[i]
	}
	return total, nil
}
// Numeric is the type constraint shared by the generic types and functions in
// this file. It admits any type whose underlying type is int, float32 or
// float64.
type Numeric interface {
~int | ~float32 | ~float64
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment