Skip to content

Instantly share code, notes, and snippets.

@mattn
Created October 15, 2019 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattn/a0a877e583e469f70b98e71312eb37f2 to your computer and use it in GitHub Desktop.
Save mattn/a0a877e583e469f70b98e71312eb37f2 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import (\n",
" \"encoding/csv\"\n",
" \"os\"\n",
" \"strconv\"\n",
" \"sort\"\n",
" \"math\"\n",
" \"fmt\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"func loadData() ([][]float64, []string) {\n",
" f, err := os.Open(\"datasets/iris.csv\")\n",
" if err != nil {\n",
" panic(err)\n",
" }\n",
" defer f.Close()\n",
" \n",
" r := csv.NewReader(f)\n",
" r.Comma = ','\n",
" r.LazyQuotes = true\n",
" _, err = r.Read()\n",
" if err != nil {\n",
" panic(err)\n",
" }\n",
" rows, err := r.ReadAll()\n",
" if err != nil {\n",
" panic(err)\n",
" }\n",
"\n",
" X := [][]float64{}\n",
" Y := []string{}\n",
" for i, cols := range rows {\n",
" x := make([]float64, 4)\n",
" y := cols[4]\n",
" for j, s := range cols[:4] {\n",
" v, err := strconv.ParseFloat(s, 64)\n",
" if err != nil {\n",
" panic(err)\n",
" }\n",
" x[j] = v\n",
" }\n",
" X = append(X, x)\n",
" Y = append(Y, y)\n",
" }\n",
" return X, Y\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"// Load the feature rows (X) and their labels (Y) via loadData.\n",
"X, Y := loadData()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"var trainX, testX [][]float64\n",
"var trainY, testY []string\n",
"for i, _ := range X {\n",
" if i%2 == 0 {\n",
" trainX = append(trainX, X[i])\n",
" trainY = append(trainY, Y[i])\n",
" } else {\n",
" testX = append(testX, X[i])\n",
" testY = append(testY, Y[i])\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"type KNN struct {\n",
" k int\n",
" XX [][]float64\n",
" Y []string\n",
"}\n",
"\n",
"func distance(lhs, rhs []float64) float64 {\n",
" val := 0.0\n",
" for i, _ := range lhs {\n",
" val += math.Pow(lhs[i] - rhs[i], 2)\n",
" }\n",
" return math.Sqrt(val)\n",
"}\n",
"\n",
"func (knn *KNN) predict(X [][]float64) []string {\n",
" results := []string{}\n",
" for _, x := range X {\n",
" var (\n",
" nearLabels []string\n",
" )\n",
" type item struct {\n",
" i int\n",
" f float64\n",
" }\n",
" var items []item\n",
" for i, xx := range knn.XX {\n",
" items = append(items, item {\n",
" i: i,\n",
" f: distance(x, xx),\n",
" })\n",
" }\n",
" sort.Slice(items, func(i, j int) bool {\n",
" return items[i].f < items[j].f\n",
" })\n",
"\n",
" var indexes []int\n",
" var labels []string\n",
" for i := 0; i < knn.k; i++ {\n",
" labels = append(labels, knn.Y[items[i].i])\n",
" }\n",
"\n",
" founds := map[string]int{}\n",
" for _, label := range labels {\n",
" founds[label] += 1\n",
" }\n",
"\n",
" type rank struct {\n",
" i int\n",
" s string\n",
" }\n",
" var ranks []rank\n",
" for k, v := range founds {\n",
" ranks = append(ranks, rank {\n",
" i: v,\n",
" s: k,\n",
" })\n",
" }\n",
" sort.Slice(ranks, func(i, j int) bool {\n",
" return ranks[i].i > ranks[j].i\n",
" })\n",
" results = append(results, ranks[0].s)\n",
" }\n",
" return results\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"// Build the classifier from the training split; each prediction is decided\n",
"// by the labels of the k = 8 nearest training samples.\n",
"knn := KNN {\n",
" k: 8,\n",
" XX: trainX,\n",
" Y: trainY,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"// Classify every test sample; predicted[i] is the label chosen for testX[i].\n",
"predicted := knn.predict(testX)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"correct := 0\n",
"for i, _ := range predicted {\n",
" if predicted[i] == testY[i] {\n",
" correct += 1\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9866666666666667\n"
]
},
{
"data": {
"text/plain": [
"19 <nil>"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"// Print the accuracy: the fraction of test samples whose predicted label matched.\n",
"fmt.Println(float64(correct) / float64(len(predicted)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Go",
"language": "go",
"name": "gophernotes"
},
"language_info": {
"codemirror_mode": "",
"file_extension": ".go",
"mimetype": "",
"name": "go",
"nbconvert_exporter": "",
"pygments_lexer": "",
"version": "go1.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment