Skip to content

Instantly share code, notes, and snippets.

@eggsbenjamin
Created October 19, 2019 23:06
Show Gist options
  • Save eggsbenjamin/a3609d30ebf111b3d1f64aabbc029711 to your computer and use it in GitHub Desktop.
Save eggsbenjamin/a3609d30ebf111b3d1f64aabbc029711 to your computer and use it in GitHub Desktop.
Golang Machine Learning - K Nearest Neighbours Algorithm
sepal_length sepal_width petal_length petal_width species
5.1 3.5 1.4 0.2 setosa
4.9 3.0 1.4 0.2 setosa
4.7 3.2 1.3 0.2 setosa
4.6 3.1 1.5 0.2 setosa
5.0 3.6 1.4 0.2 setosa
5.4 3.9 1.7 0.4 setosa
4.6 3.4 1.4 0.3 setosa
5.0 3.4 1.5 0.2 setosa
4.4 2.9 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.4 3.7 1.5 0.2 setosa
4.8 3.4 1.6 0.2 setosa
4.8 3.0 1.4 0.1 setosa
4.3 3.0 1.1 0.1 setosa
5.8 4.0 1.2 0.2 setosa
5.7 4.4 1.5 0.4 setosa
5.4 3.9 1.3 0.4 setosa
5.1 3.5 1.4 0.3 setosa
5.7 3.8 1.7 0.3 setosa
5.1 3.8 1.5 0.3 setosa
5.4 3.4 1.7 0.2 setosa
5.1 3.7 1.5 0.4 setosa
4.6 3.6 1.0 0.2 setosa
5.1 3.3 1.7 0.5 setosa
4.8 3.4 1.9 0.2 setosa
5.0 3.0 1.6 0.2 setosa
5.0 3.4 1.6 0.4 setosa
5.2 3.5 1.5 0.2 setosa
5.2 3.4 1.4 0.2 setosa
4.7 3.2 1.6 0.2 setosa
4.8 3.1 1.6 0.2 setosa
5.4 3.4 1.5 0.4 setosa
5.2 4.1 1.5 0.1 setosa
5.5 4.2 1.4 0.2 setosa
4.9 3.1 1.5 0.1 setosa
5.0 3.2 1.2 0.2 setosa
5.5 3.5 1.3 0.2 setosa
4.9 3.1 1.5 0.1 setosa
4.4 3.0 1.3 0.2 setosa
5.1 3.4 1.5 0.2 setosa
5.0 3.5 1.3 0.3 setosa
4.5 2.3 1.3 0.3 setosa
4.4 3.2 1.3 0.2 setosa
5.0 3.5 1.6 0.6 setosa
5.1 3.8 1.9 0.4 setosa
4.8 3.0 1.4 0.3 setosa
5.1 3.8 1.6 0.2 setosa
4.6 3.2 1.4 0.2 setosa
5.3 3.7 1.5 0.2 setosa
5.0 3.3 1.4 0.2 setosa
7.0 3.2 4.7 1.4 versicolor
6.4 3.2 4.5 1.5 versicolor
6.9 3.1 4.9 1.5 versicolor
5.5 2.3 4.0 1.3 versicolor
6.5 2.8 4.6 1.5 versicolor
5.7 2.8 4.5 1.3 versicolor
6.3 3.3 4.7 1.6 versicolor
4.9 2.4 3.3 1.0 versicolor
6.6 2.9 4.6 1.3 versicolor
5.2 2.7 3.9 1.4 versicolor
5.0 2.0 3.5 1.0 versicolor
5.9 3.0 4.2 1.5 versicolor
6.0 2.2 4.0 1.0 versicolor
6.1 2.9 4.7 1.4 versicolor
5.6 2.9 3.6 1.3 versicolor
6.7 3.1 4.4 1.4 versicolor
5.6 3.0 4.5 1.5 versicolor
5.8 2.7 4.1 1.0 versicolor
6.2 2.2 4.5 1.5 versicolor
5.6 2.5 3.9 1.1 versicolor
5.9 3.2 4.8 1.8 versicolor
6.1 2.8 4.0 1.3 versicolor
6.3 2.5 4.9 1.5 versicolor
6.1 2.8 4.7 1.2 versicolor
6.4 2.9 4.3 1.3 versicolor
6.6 3.0 4.4 1.4 versicolor
6.8 2.8 4.8 1.4 versicolor
6.7 3.0 5.0 1.7 versicolor
6.0 2.9 4.5 1.5 versicolor
5.7 2.6 3.5 1.0 versicolor
5.5 2.4 3.8 1.1 versicolor
5.5 2.4 3.7 1.0 versicolor
5.8 2.7 3.9 1.2 versicolor
6.0 2.7 5.1 1.6 versicolor
5.4 3.0 4.5 1.5 versicolor
6.0 3.4 4.5 1.6 versicolor
6.7 3.1 4.7 1.5 versicolor
6.3 2.3 4.4 1.3 versicolor
5.6 3.0 4.1 1.3 versicolor
5.5 2.5 4.0 1.3 versicolor
5.5 2.6 4.4 1.2 versicolor
6.1 3.0 4.6 1.4 versicolor
5.8 2.6 4.0 1.2 versicolor
5.0 2.3 3.3 1.0 versicolor
5.6 2.7 4.2 1.3 versicolor
5.7 3.0 4.2 1.2 versicolor
5.7 2.9 4.2 1.3 versicolor
6.2 2.9 4.3 1.3 versicolor
5.1 2.5 3.0 1.1 versicolor
5.7 2.8 4.1 1.3 versicolor
6.3 3.3 6.0 2.5 virginica
5.8 2.7 5.1 1.9 virginica
7.1 3.0 5.9 2.1 virginica
6.3 2.9 5.6 1.8 virginica
6.5 3.0 5.8 2.2 virginica
7.6 3.0 6.6 2.1 virginica
4.9 2.5 4.5 1.7 virginica
7.3 2.9 6.3 1.8 virginica
6.7 2.5 5.8 1.8 virginica
7.2 3.6 6.1 2.5 virginica
6.5 3.2 5.1 2.0 virginica
6.4 2.7 5.3 1.9 virginica
6.8 3.0 5.5 2.1 virginica
5.7 2.5 5.0 2.0 virginica
5.8 2.8 5.1 2.4 virginica
6.4 3.2 5.3 2.3 virginica
6.5 3.0 5.5 1.8 virginica
7.7 3.8 6.7 2.2 virginica
7.7 2.6 6.9 2.3 virginica
6.0 2.2 5.0 1.5 virginica
6.9 3.2 5.7 2.3 virginica
5.6 2.8 4.9 2.0 virginica
7.7 2.8 6.7 2.0 virginica
6.3 2.7 4.9 1.8 virginica
6.7 3.3 5.7 2.1 virginica
7.2 3.2 6.0 1.8 virginica
6.2 2.8 4.8 1.8 virginica
6.1 3.0 4.9 1.8 virginica
6.4 2.8 5.6 2.1 virginica
7.2 3.0 5.8 1.6 virginica
7.4 2.8 6.1 1.9 virginica
7.9 3.8 6.4 2.0 virginica
6.4 2.8 5.6 2.2 virginica
6.3 2.8 5.1 1.5 virginica
6.1 2.6 5.6 1.4 virginica
7.7 3.0 6.1 2.3 virginica
6.3 3.4 5.6 2.4 virginica
6.4 3.1 5.5 1.8 virginica
6.0 3.0 4.8 1.8 virginica
6.9 3.1 5.4 2.1 virginica
6.7 3.1 5.6 2.4 virginica
6.9 3.1 5.1 2.3 virginica
5.8 2.7 5.1 1.9 virginica
6.8 3.2 5.9 2.3 virginica
6.7 3.3 5.7 2.5 virginica
6.7 3.0 5.2 2.3 virginica
6.3 2.5 5.0 1.9 virginica
6.5 3.0 5.2 2.0 virginica
6.2 3.4 5.4 2.3 virginica
5.9 3.0 5.1 1.8 virginica
package main
import (
"math"
"sort"
)
// KNNDistanceFn is the signature of the distance metric a KNNClassifier uses to
// rank its labelled examples: it returns the distance between two feature
// vectors v1 and v2, where a smaller value means the vectors are closer.
type KNNDistanceFn func(v1, v2 []float64) float64
// KNNClassifierLabelledExample is a labelled example for training a KNN
// classifier: a feature vector paired with the integer label of its class.
type KNNClassifierLabelledExample struct {
// FeatureVector holds the example's numeric features.
FeatureVector []float64
// Label identifies the example's class; KNNClassifier.Classify votes over
// these values.
Label int
}
// KNNClassifier is a k-nearest-neighbours classifier. It keeps its labelled
// examples in memory and, for each query, lets the k examples closest to the
// query (according to distanceFn) vote on the label.
type KNNClassifier struct {
	// k is the number of nearest neighbours consulted per query.
	k int
	// labelledExamples is the training set searched by Classify.
	labelledExamples []KNNClassifierLabelledExample
	// distanceFn ranks training examples by closeness to the query.
	distanceFn KNNDistanceFn
}

// NewKNNClassifier returns a classifier that consults the k nearest of
// labelledExamples, ranked by distanceFn. The slice is stored as-is (not
// copied), so later changes the caller makes to it are visible to the
// classifier.
func NewKNNClassifier(k int, labelledExamples []KNNClassifierLabelledExample, distanceFn KNNDistanceFn) *KNNClassifier {
	c := new(KNNClassifier)
	c.k = k
	c.labelledExamples = labelledExamples
	c.distanceFn = distanceFn
	return c
}
// knnClassifierNeighbour pairs a training example's label with its distance
// from the vector being classified.
type knnClassifierNeighbour struct {
	distance float64
	label    int
}

// Classify returns the label predicted for feature vector v by a majority
// vote among the k labelled examples nearest to v, as measured by the
// classifier's distance function.
//
// Ties in the vote go to the label that reached the winning count first,
// i.e. the one whose votes came from nearer neighbours. Examples at equal
// distance keep their training-set order (stable sort).
//
// If k exceeds the number of labelled examples, every example votes. If
// k <= 0 or there are no labelled examples, no vote takes place and 0 is
// returned.
func (c *KNNClassifier) Classify(v []float64) int {
	neighbours := make([]knnClassifierNeighbour, 0, len(c.labelledExamples))
	for _, example := range c.labelledExamples {
		neighbours = append(neighbours, knnClassifierNeighbour{
			distance: c.distanceFn(v, example.FeatureVector),
			label:    example.Label,
		})
	}
	sort.SliceStable(neighbours, func(i, j int) bool {
		return neighbours[i].distance < neighbours[j].distance
	})

	// Clamp k so that asking for more neighbours than there are examples
	// degrades to "use every example" instead of indexing past the end.
	k := c.k
	if k > len(neighbours) {
		k = len(neighbours)
	}

	var result, bestCount int
	labelCounts := make(map[int]int)
	for i := 0; i < k; i++ {
		label := neighbours[i].label
		labelCounts[label]++
		// Strictly greater: on a tie, the label that got there first (with
		// nearer neighbours) keeps the lead.
		if count := labelCounts[label]; count > bestCount {
			result, bestCount = label, count
		}
	}
	return result
}
// EuclideanDistance returns the Euclidean distance between two n-dimensional
// vectors. Implements KNNDistanceFn.
//
// If the vectors differ in length, the shorter one is treated as if padded
// with trailing zeros, so each extra component of the longer vector
// contributes its own square. Neither input slice is modified.
func EuclideanDistance(v1 []float64, v2 []float64) float64 {
	shorter, longer := v1, v2
	if len(shorter) > len(longer) {
		shorter, longer = longer, shorter
	}

	var sum float64
	// Components present in both vectors.
	for i, a := range shorter {
		d := longer[i] - a
		sum += d * d
	}
	// Tail of the longer vector; its counterpart is an implicit 0. Handling it
	// here avoids appending to (and possibly overwriting) the caller's slice.
	for _, x := range longer[len(shorter):] {
		sum += x * x
	}
	return math.Sqrt(sum)
}
package main
import (
"encoding/csv"
"errors"
"io"
"log"
"math"
"math/rand"
"os"
"strconv"
"strings"
"testing"
"time"
)
// TestKNNClassifier trains a KNNClassifier on a random 80% of the iris dataset
// (./iris.csv) and checks that its error rate on the remaining 20% stays
// within an acceptable bound. The shuffle seed is logged so that a failing
// split can be reproduced.
func TestKNNClassifier(t *testing.T) {
	f, err := os.Open("./iris.csv")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()
	labelledExamples, err := prepareIrisDataset(f)
	if err != nil {
		t.Fatal(err)
	}

	// Shuffle with a local, explicitly seeded source instead of reseeding the
	// global one; logging the seed makes a particular split reproducible.
	seed := time.Now().UnixNano()
	t.Logf("shuffle seed %d", seed)
	rng := rand.New(rand.NewSource(seed))
	rng.Shuffle(len(labelledExamples), func(i, j int) {
		labelledExamples[i], labelledExamples[j] = labelledExamples[j], labelledExamples[i]
	})

	// The error rate is measured over the test split only. With the 150-row
	// iris set that is 30 examples, so each misclassification costs ~3.3
	// percentage points; 0.1 tolerates up to 3 misclassifications.
	acceptableMarginOfError := 0.1

	// split data between training and test data 80/20
	splitIndex := int(0.8 * float64(len(labelledExamples)))
	trainingData := labelledExamples[:splitIndex]
	testData := labelledExamples[splitIndex:]
	if len(trainingData) == 0 || len(testData) == 0 {
		t.Fatalf("dataset too small to split: %d examples", len(labelledExamples))
	}

	type incorrectlyClassifiedResult struct {
		labelledExample KNNClassifierLabelledExample
		actual          int
	}
	errs := []incorrectlyClassifiedResult{}

	// Common rule of thumb: k ≈ sqrt(number of training examples).
	k := int(math.Sqrt(float64(len(trainingData))))
	knnClassifier := NewKNNClassifier(k, trainingData, EuclideanDistance)
	for _, test := range testData {
		if actual := knnClassifier.Classify(test.FeatureVector); actual != test.Label {
			errs = append(errs, incorrectlyClassifiedResult{
				labelledExample: test,
				actual:          actual,
			})
		}
	}

	actualMarginOfError := float64(len(errs)) / float64(len(testData))
	if actualMarginOfError > acceptableMarginOfError {
		// t.Fatalf (not log.Fatalf) so the test fails normally and deferred
		// cleanup such as f.Close runs.
		t.Fatalf("margin of error exceeds acceptable level. actual %f acceptable %f\n%v", actualMarginOfError, acceptableMarginOfError, errs)
	}
	log.Printf("actual margin of error %f", actualMarginOfError)
}
// TestEuclideandistance checks EuclideanDistance against hand-computed
// distances between vectors of equal length.
func TestEuclideandistance(t *testing.T) {
	tests := []struct {
		title    string
		v1, v2   []float64
		expected float64
	}{
		{
			"3d",
			[]float64{3, 6, 5},
			[]float64{7, -5, 1},
			// sqrt(4² + 11² + 4²) = sqrt(153) ≈ 12.3693. The exact value is used
			// because approx compares with an absolute tolerance of 1e-8.
			math.Sqrt(153),
		},
		{
			"2d 3-4-5 triangle",
			[]float64{0, 0},
			[]float64{3, 4},
			5,
		},
		{
			"identical vectors",
			[]float64{1.5, -2, 7},
			[]float64{1.5, -2, 7},
			0,
		},
		{
			"empty vectors",
			nil,
			nil,
			0,
		},
	}
	for _, test := range tests {
		t.Run(test.title, func(t *testing.T) {
			if actual := EuclideanDistance(test.v1, test.v2); !approx(actual, test.expected) {
				t.Fatalf("expected %f got %f", test.expected, actual)
			}
		})
	}
}
// prepareIrisDataset parses the iris CSV read from r into labelled examples.
//
// The first record is treated as a header and skipped. Every other record must
// have at least five fields: four numeric features (parsed as float64),
// followed by the species name. The species is matched case-insensitively and
// mapped to a label: setosa=0, versicolor=1, virginica=2.
//
// A CSV read error, a short record, an unparsable feature, or an unknown
// species returns a nil slice and an error.
func prepareIrisDataset(r io.Reader) ([]KNNClassifierLabelledExample, error) {
	const numFeatures = 4 // the species column follows the features

	records, err := csv.NewReader(r).ReadAll()
	if err != nil {
		return nil, err
	}
	labelledExamples := make([]KNNClassifierLabelledExample, 0, len(records))
	for row, record := range records {
		if row == 0 { // skip header row
			continue
		}
		// Guard the fixed-index accesses below; a short record would otherwise panic.
		if len(record) <= numFeatures {
			return nil, errors.New("record " + strconv.Itoa(row+1) + ": expected at least " +
				strconv.Itoa(numFeatures+1) + " fields, got " + strconv.Itoa(len(record)))
		}
		labelledExample := KNNClassifierLabelledExample{
			FeatureVector: make([]float64, 0, numFeatures),
		}
		for col := 0; col < numFeatures; col++ {
			v, err := strconv.ParseFloat(record[col], 64)
			if err != nil {
				return nil, err
			}
			labelledExample.FeatureVector = append(labelledExample.FeatureVector, v)
		}
		switch strings.ToLower(record[numFeatures]) {
		case "setosa":
			labelledExample.Label = 0
		case "versicolor":
			labelledExample.Label = 1
		case "virginica":
			labelledExample.Label = 2
		default:
			return nil, errors.New("unexpected species: " + record[numFeatures])
		}
		labelledExamples = append(labelledExamples, labelledExample)
	}
	return labelledExamples, nil
}
// approx reports whether v1 and v2 differ by strictly less than an absolute
// tolerance of 1e-8. NaN is never approximately equal to anything.
func approx(v1, v2 float64) bool {
	const tolerance = 1e-8
	delta := math.Abs(v2 - v1)
	return delta < tolerance
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment