Created
October 19, 2019 23:06
-
-
Save eggsbenjamin/a3609d30ebf111b3d1f64aabbc029711 to your computer and use it in GitHub Desktop.
Golang Machine Learning - K Nearest Neighbours Algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sepal_length | sepal_width | petal_length | petal_width | species | |
---|---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | setosa | |
4.9 | 3.0 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.3 | 0.2 | setosa | |
4.6 | 3.1 | 1.5 | 0.2 | setosa | |
5.0 | 3.6 | 1.4 | 0.2 | setosa | |
5.4 | 3.9 | 1.7 | 0.4 | setosa | |
4.6 | 3.4 | 1.4 | 0.3 | setosa | |
5.0 | 3.4 | 1.5 | 0.2 | setosa | |
4.4 | 2.9 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.4 | 3.7 | 1.5 | 0.2 | setosa | |
4.8 | 3.4 | 1.6 | 0.2 | setosa | |
4.8 | 3.0 | 1.4 | 0.1 | setosa | |
4.3 | 3.0 | 1.1 | 0.1 | setosa | |
5.8 | 4.0 | 1.2 | 0.2 | setosa | |
5.7 | 4.4 | 1.5 | 0.4 | setosa | |
5.4 | 3.9 | 1.3 | 0.4 | setosa | |
5.1 | 3.5 | 1.4 | 0.3 | setosa | |
5.7 | 3.8 | 1.7 | 0.3 | setosa | |
5.1 | 3.8 | 1.5 | 0.3 | setosa | |
5.4 | 3.4 | 1.7 | 0.2 | setosa | |
5.1 | 3.7 | 1.5 | 0.4 | setosa | |
4.6 | 3.6 | 1.0 | 0.2 | setosa | |
5.1 | 3.3 | 1.7 | 0.5 | setosa | |
4.8 | 3.4 | 1.9 | 0.2 | setosa | |
5.0 | 3.0 | 1.6 | 0.2 | setosa | |
5.0 | 3.4 | 1.6 | 0.4 | setosa | |
5.2 | 3.5 | 1.5 | 0.2 | setosa | |
5.2 | 3.4 | 1.4 | 0.2 | setosa | |
4.7 | 3.2 | 1.6 | 0.2 | setosa | |
4.8 | 3.1 | 1.6 | 0.2 | setosa | |
5.4 | 3.4 | 1.5 | 0.4 | setosa | |
5.2 | 4.1 | 1.5 | 0.1 | setosa | |
5.5 | 4.2 | 1.4 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
5.0 | 3.2 | 1.2 | 0.2 | setosa | |
5.5 | 3.5 | 1.3 | 0.2 | setosa | |
4.9 | 3.1 | 1.5 | 0.1 | setosa | |
4.4 | 3.0 | 1.3 | 0.2 | setosa | |
5.1 | 3.4 | 1.5 | 0.2 | setosa | |
5.0 | 3.5 | 1.3 | 0.3 | setosa | |
4.5 | 2.3 | 1.3 | 0.3 | setosa | |
4.4 | 3.2 | 1.3 | 0.2 | setosa | |
5.0 | 3.5 | 1.6 | 0.6 | setosa | |
5.1 | 3.8 | 1.9 | 0.4 | setosa | |
4.8 | 3.0 | 1.4 | 0.3 | setosa | |
5.1 | 3.8 | 1.6 | 0.2 | setosa | |
4.6 | 3.2 | 1.4 | 0.2 | setosa | |
5.3 | 3.7 | 1.5 | 0.2 | setosa | |
5.0 | 3.3 | 1.4 | 0.2 | setosa | |
7.0 | 3.2 | 4.7 | 1.4 | versicolor | |
6.4 | 3.2 | 4.5 | 1.5 | versicolor | |
6.9 | 3.1 | 4.9 | 1.5 | versicolor | |
5.5 | 2.3 | 4.0 | 1.3 | versicolor | |
6.5 | 2.8 | 4.6 | 1.5 | versicolor | |
5.7 | 2.8 | 4.5 | 1.3 | versicolor | |
6.3 | 3.3 | 4.7 | 1.6 | versicolor | |
4.9 | 2.4 | 3.3 | 1.0 | versicolor | |
6.6 | 2.9 | 4.6 | 1.3 | versicolor | |
5.2 | 2.7 | 3.9 | 1.4 | versicolor | |
5.0 | 2.0 | 3.5 | 1.0 | versicolor | |
5.9 | 3.0 | 4.2 | 1.5 | versicolor | |
6.0 | 2.2 | 4.0 | 1.0 | versicolor | |
6.1 | 2.9 | 4.7 | 1.4 | versicolor | |
5.6 | 2.9 | 3.6 | 1.3 | versicolor | |
6.7 | 3.1 | 4.4 | 1.4 | versicolor | |
5.6 | 3.0 | 4.5 | 1.5 | versicolor | |
5.8 | 2.7 | 4.1 | 1.0 | versicolor | |
6.2 | 2.2 | 4.5 | 1.5 | versicolor | |
5.6 | 2.5 | 3.9 | 1.1 | versicolor | |
5.9 | 3.2 | 4.8 | 1.8 | versicolor | |
6.1 | 2.8 | 4.0 | 1.3 | versicolor | |
6.3 | 2.5 | 4.9 | 1.5 | versicolor | |
6.1 | 2.8 | 4.7 | 1.2 | versicolor | |
6.4 | 2.9 | 4.3 | 1.3 | versicolor | |
6.6 | 3.0 | 4.4 | 1.4 | versicolor | |
6.8 | 2.8 | 4.8 | 1.4 | versicolor | |
6.7 | 3.0 | 5.0 | 1.7 | versicolor | |
6.0 | 2.9 | 4.5 | 1.5 | versicolor | |
5.7 | 2.6 | 3.5 | 1.0 | versicolor | |
5.5 | 2.4 | 3.8 | 1.1 | versicolor | |
5.5 | 2.4 | 3.7 | 1.0 | versicolor | |
5.8 | 2.7 | 3.9 | 1.2 | versicolor | |
6.0 | 2.7 | 5.1 | 1.6 | versicolor | |
5.4 | 3.0 | 4.5 | 1.5 | versicolor | |
6.0 | 3.4 | 4.5 | 1.6 | versicolor | |
6.7 | 3.1 | 4.7 | 1.5 | versicolor | |
6.3 | 2.3 | 4.4 | 1.3 | versicolor | |
5.6 | 3.0 | 4.1 | 1.3 | versicolor | |
5.5 | 2.5 | 4.0 | 1.3 | versicolor | |
5.5 | 2.6 | 4.4 | 1.2 | versicolor | |
6.1 | 3.0 | 4.6 | 1.4 | versicolor | |
5.8 | 2.6 | 4.0 | 1.2 | versicolor | |
5.0 | 2.3 | 3.3 | 1.0 | versicolor | |
5.6 | 2.7 | 4.2 | 1.3 | versicolor | |
5.7 | 3.0 | 4.2 | 1.2 | versicolor | |
5.7 | 2.9 | 4.2 | 1.3 | versicolor | |
6.2 | 2.9 | 4.3 | 1.3 | versicolor | |
5.1 | 2.5 | 3.0 | 1.1 | versicolor | |
5.7 | 2.8 | 4.1 | 1.3 | versicolor | |
6.3 | 3.3 | 6.0 | 2.5 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
7.1 | 3.0 | 5.9 | 2.1 | virginica | |
6.3 | 2.9 | 5.6 | 1.8 | virginica | |
6.5 | 3.0 | 5.8 | 2.2 | virginica | |
7.6 | 3.0 | 6.6 | 2.1 | virginica | |
4.9 | 2.5 | 4.5 | 1.7 | virginica | |
7.3 | 2.9 | 6.3 | 1.8 | virginica | |
6.7 | 2.5 | 5.8 | 1.8 | virginica | |
7.2 | 3.6 | 6.1 | 2.5 | virginica | |
6.5 | 3.2 | 5.1 | 2.0 | virginica | |
6.4 | 2.7 | 5.3 | 1.9 | virginica | |
6.8 | 3.0 | 5.5 | 2.1 | virginica | |
5.7 | 2.5 | 5.0 | 2.0 | virginica | |
5.8 | 2.8 | 5.1 | 2.4 | virginica | |
6.4 | 3.2 | 5.3 | 2.3 | virginica | |
6.5 | 3.0 | 5.5 | 1.8 | virginica | |
7.7 | 3.8 | 6.7 | 2.2 | virginica | |
7.7 | 2.6 | 6.9 | 2.3 | virginica | |
6.0 | 2.2 | 5.0 | 1.5 | virginica | |
6.9 | 3.2 | 5.7 | 2.3 | virginica | |
5.6 | 2.8 | 4.9 | 2.0 | virginica | |
7.7 | 2.8 | 6.7 | 2.0 | virginica | |
6.3 | 2.7 | 4.9 | 1.8 | virginica | |
6.7 | 3.3 | 5.7 | 2.1 | virginica | |
7.2 | 3.2 | 6.0 | 1.8 | virginica | |
6.2 | 2.8 | 4.8 | 1.8 | virginica | |
6.1 | 3.0 | 4.9 | 1.8 | virginica | |
6.4 | 2.8 | 5.6 | 2.1 | virginica | |
7.2 | 3.0 | 5.8 | 1.6 | virginica | |
7.4 | 2.8 | 6.1 | 1.9 | virginica | |
7.9 | 3.8 | 6.4 | 2.0 | virginica | |
6.4 | 2.8 | 5.6 | 2.2 | virginica | |
6.3 | 2.8 | 5.1 | 1.5 | virginica | |
6.1 | 2.6 | 5.6 | 1.4 | virginica | |
7.7 | 3.0 | 6.1 | 2.3 | virginica | |
6.3 | 3.4 | 5.6 | 2.4 | virginica | |
6.4 | 3.1 | 5.5 | 1.8 | virginica | |
6.0 | 3.0 | 4.8 | 1.8 | virginica | |
6.9 | 3.1 | 5.4 | 2.1 | virginica | |
6.7 | 3.1 | 5.6 | 2.4 | virginica | |
6.9 | 3.1 | 5.1 | 2.3 | virginica | |
5.8 | 2.7 | 5.1 | 1.9 | virginica | |
6.8 | 3.2 | 5.9 | 2.3 | virginica | |
6.7 | 3.3 | 5.7 | 2.5 | virginica | |
6.7 | 3.0 | 5.2 | 2.3 | virginica | |
6.3 | 2.5 | 5.0 | 1.9 | virginica | |
6.5 | 3.0 | 5.2 | 2.0 | virginica | |
6.2 | 3.4 | 5.4 | 2.3 | virginica | |
5.9 | 3.0 | 5.1 | 1.8 | virginica |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"math" | |
"sort" | |
) | |
// KNNDistanceFn is a function that returns the distance between two
// n-dimensional vectors. KNNClassifier calls it once per labelled example,
// passing the vector being classified as v1 and the example's feature vector
// as v2.
type KNNDistanceFn func(v1, v2 []float64) float64
// KNNClassifierLabelledExample is a single labelled example used as training
// data for a KNN classifier.
type KNNClassifierLabelledExample struct {
	// FeatureVector holds the example's numeric features.
	FeatureVector []float64
	// Label is the class the example belongs to.
	Label int
}
// KNNClassifier represents a KNN algorithm | |
type KNNClassifier struct { | |
k int | |
labelledExamples []KNNClassifierLabelledExample | |
distanceFn KNNDistanceFn | |
} | |
func NewKNNClassifier(k int, labelledExamples []KNNClassifierLabelledExample, distanceFn KNNDistanceFn) *KNNClassifier { | |
return &KNNClassifier{ | |
k: k, | |
labelledExamples: labelledExamples, | |
distanceFn: distanceFn, | |
} | |
} | |
// knnClassifierNeighbour pairs a training example's label with its distance
// from the vector being classified.
type knnClassifierNeighbour struct {
	distance float64
	label    int
}
// Classify executes the KNN algorithm | |
func (c *KNNClassifier) Classify(v []float64) int { | |
neighbours := []knnClassifierNeighbour{} | |
for _, example := range c.labelledExamples { | |
neighbours = append(neighbours, knnClassifierNeighbour{ | |
distance: c.distanceFn(v, example.FeatureVector), | |
label: example.Label, | |
}) | |
} | |
sort.SliceStable(neighbours, func(i, j int) bool { | |
return neighbours[i].distance < neighbours[j].distance | |
}) | |
var result int | |
nearestNeightbourLabelCounts := map[int]int{} | |
for i := 0; i < c.k; i++ { | |
nearestNeightbourLabelCounts[neighbours[i].label]++ | |
count := nearestNeightbourLabelCounts[neighbours[i].label] | |
if count > nearestNeightbourLabelCounts[result] { | |
result = neighbours[i].label | |
} | |
} | |
return result | |
} | |
// EuclideanDistance returns the Euclidean distance between two n-dimensional
// vectors. It implements KNNDistanceFn.
//
// If the vectors differ in length, the shorter one is treated as if it were
// padded with zeros. Neither input slice is modified.
func EuclideanDistance(v1 []float64, v2 []float64) float64 {
	n := len(v1)
	if len(v2) > n {
		n = len(v2)
	}

	var sum float64
	for i := 0; i < n; i++ {
		// Components past the end of a vector stay at zero.
		var a, b float64
		if i < len(v1) {
			a = v1[i]
		}
		if i < len(v2) {
			b = v2[i]
		}
		d := b - a
		sum += d * d
	}
	return math.Sqrt(sum)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/csv" | |
"errors" | |
"io" | |
"log" | |
"math" | |
"math/rand" | |
"os" | |
"strconv" | |
"strings" | |
"testing" | |
"time" | |
) | |
func TestKNNClassifier(t *testing.T) { | |
f, err := os.Open("./iris.csv") | |
if err != nil { | |
t.Fatal(err) | |
} | |
defer f.Close() | |
labelledExamples, err := prepareIrisDataset(f) | |
if err != nil { | |
t.Fatal(err) | |
} | |
// shuffle slice | |
rand.Seed(time.Now().UnixNano()) | |
rand.Shuffle(len(labelledExamples), func(i, j int) { | |
labelledExamples[i], labelledExamples[j] = labelledExamples[j], labelledExamples[i] | |
}) | |
acceptableMarginOfError := 0.05 | |
// split data between training and test data 80/20 | |
splitIndex := int(0.8 * float64(len(labelledExamples))) | |
trainingData := labelledExamples[:splitIndex] | |
testData := labelledExamples[splitIndex:] | |
type incorrectlyClassifiedResult struct { | |
labelledExample KNNClassifierLabelledExample | |
actual int | |
} | |
errs := []incorrectlyClassifiedResult{} | |
k := int(math.Sqrt(float64(len(trainingData)))) | |
knnClassifier := NewKNNClassifier(k, trainingData, EuclideanDistance) | |
for _, test := range testData { | |
if actual := knnClassifier.Classify(test.FeatureVector); actual != test.Label { | |
errs = append(errs, incorrectlyClassifiedResult{ | |
labelledExample: test, | |
actual: actual, | |
}) | |
} | |
} | |
actualMarginOfError := float64(len(errs)) / float64(len(labelledExamples)) | |
if actualMarginOfError > acceptableMarginOfError { | |
log.Fatalf("margin of error exceeds acceptable level. actual %f acceptable %f\n%v", actualMarginOfError, acceptableMarginOfError, errs) | |
} | |
log.Printf("actual margin of error %f", actualMarginOfError) | |
} | |
func TestEuclideandistance(t *testing.T) { | |
tests := []struct { | |
title string | |
v1, v2 []float64 | |
expected float64 | |
}{ | |
{ | |
"3d", | |
[]float64{3, 6, 5}, | |
[]float64{7, -5, 1}, | |
12.369, | |
}, | |
} | |
for _, test := range tests { | |
t.Run(test.title, func(t *testing.T) { | |
if actual := EuclideanDistance(test.v1, test.v2); approx(actual, test.expected) { | |
t.Fatalf("expected %f got %f", test.expected, actual) | |
} | |
}) | |
} | |
} | |
func prepareIrisDataset(r io.Reader) ([]KNNClassifierLabelledExample, error) { | |
records, err := csv.NewReader(r).ReadAll() | |
if err != nil { | |
return nil, err | |
} | |
labelledExamples := []KNNClassifierLabelledExample{} | |
for i, record := range records { | |
if i == 0 { // skip header row | |
continue | |
} | |
labelledExample := KNNClassifierLabelledExample{} | |
for i := 0; i < 4; i++ { | |
v, err := strconv.ParseFloat(record[i], 64) | |
if err != nil { | |
return nil, err | |
} | |
labelledExample.FeatureVector = append(labelledExample.FeatureVector, v) | |
} | |
switch strings.ToLower(record[4]) { | |
case "setosa": | |
labelledExample.Label = 0 | |
case "versicolor": | |
labelledExample.Label = 1 | |
case "virginica": | |
labelledExample.Label = 2 | |
default: | |
return nil, errors.New("unexpected species: " + record[4]) | |
} | |
labelledExamples = append(labelledExamples, labelledExample) | |
} | |
return labelledExamples, nil | |
} | |
// approx reports whether v1 and v2 are equal to within an absolute tolerance
// of 1e-8. Exactly equal values, including infinities of the same sign, are
// always approximately equal; NaN is never approximately equal to anything.
func approx(v1, v2 float64) bool {
	const tolerance = 0.00000001

	// Inf - Inf is NaN, so equal infinities need the exact comparison.
	if v1 == v2 {
		return true
	}
	return math.Abs(v1-v2) < tolerance
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment