Skip to content

Instantly share code, notes, and snippets.

@ivan3bx
Last active March 20, 2024 19:57
Show Gist options
  • Save ivan3bx/3a5e0c2c09b1516d4ef2e7e8a67229b5 to your computer and use it in GitHub Desktop.
Save ivan3bx/3a5e0c2c09b1516d4ef2e7e8a67229b5 to your computer and use it in GitHub Desktop.
K-Means clustering & calculating some stats for nutrition data
package main
import (
"fmt"
"math"
"github.com/montanaflynn/stats"
"github.com/muesli/clusters"
"github.com/muesli/kmeans"
"gonum.org/v1/gonum/stat"
)
// main clusters a hard-coded set of nutrition samples with k-means and, for
// every resulting cluster, prints summary statistics of the samples' cosine
// distances together with a similarity-weighted mean and standard deviation of
// their calorie counts.
func main() {
	// data in the following format
	// [0] cosine distance of the item represented by these values
	// [1] number of calories for this sample
	// [2..] nutrient amounts in grams (carbs, protein etc..)
	//
	// NOTE(review): several rows carry -1 in a nutrient column, which looks
	// like a "value missing" marker; it is passed to k-means as a real
	// amount — confirm that this is intended.
	bbqChicken := [][]float64{
		{0.30936217308044434, 339, 41.73, 11.28, 1.2, 15.795},
		{0.31012654304504395, 396, 50.4, 14.4, 2.4, 16.8},
		{0.31477395804083474, 357, 49.665, 8.94, 1.95, 16.89},
		{0.32190563213262624, 357, 49.665, 8.94, 1.95, 16.89},
		{0.3601418137550354, 396, 50.4, 14.4, 2.4, 16.8},
		{0.3631199598312378, 340.5, 40.95, 13.05, -1, 14.1},
		{0.3631199598312378, 375, 48, 11.1, -1, 19.5},
		{0.36419129371643066, 369, 53.55, 10.35, -1, 15.45},
		{0.47245991230010986, 620, 19, 33, 0, 57},
		{0.47404394966494867, 365, 8, 28, 1, 19},
		{0.3238506317138672, 313.5, 35.60999999999999, 11.865, 1.05, 17.265},
		{0.32428465259923467, 301.5, 36.69, 9.705, 1.05, 17.265},
		{0.365716814994812, 366, 49.65, 10.95, 3, 16.05},
		{0.3693286588587954, 333, 46.5, 8.4, 3.75, 16.5},
		{0.2780816389273607, 310.5, 42, 9, -1, 13.8},
		{0.296245813369751, 294, 30.6, 10.95, 2.7, 18.3},
		{0.37020546197891235, 349.5, 44.7, 10.5, -1, 17.4},
		{0.41507232189178467, 538, 33, 28, 2, 37},
		{0.4200516939163208, 998, 37, 64, 0, 65},
		{0.43077021837234497, 561, 63, 21, 3, 29},
		{0.432941112939052, 422, 62, 11, 3, 18},
		{0.4425056576728821, 877, 94, 34, 11, 50},
		{0.45972079038619995, 263, 11, 14, 1, 23},
		{0.4625535011291504, 457, 54, 11, 6, 35},
		{0.4642601449567121, 292, -1, 6, 1, 35},
		{0.3245284356036313, 313.5, 35.60999999999999, 11.865, 1.05, 17.265},
		{0.3247811794281006, 301.5, 36.69, 9.705, 1.05, 17.265},
		{0.3260763883590698, 301.5, 36.69, 9.705, 1.05, 17.265},
		{0.32911187410354614, 313.5, 35.60999999999999, 11.865, 1.05, 17.265},
		{0.3515464663505554, 366, 41.7, 12.45, 4.5, 18.9},
	}

	// similar sample, sampling nutrients of a chick-fil-a chicken sandwich
	// chickFila := [][]float64{
	// {0.17526186557184698, 373.5, 31.335, 16.785, 2.1, 24.42},
	// {0.31898687240442225, 421.5, 41.22, 20.19, 2.55, 18.255},
	// {0.32166337966918945, 342, 15.585, 16.8, 1.35, 32.1},
	// {0.33655810356140137, 436.5, 36.96, 23.1, 2.25, 19.86},
	// {0.33884588844502517, 339, 29.1, 9.27, 1.05, 32.25},
	// {0.3511131808493486, 364.5, 26.28, 13.665, 0.9, 32.1},
	// {0.3525458574295044, 450, 36.585, 25.109999999999996, 3, 19.08},
	// {0.3615299272993182, 421.5, 41.22, 20.19, 2.55, 18.255},
	// {0.3642024614931608, 379.5, 25.92, 15.72, 1.8, 31.305},
	// {0.38329021284849063, 375, 31.335, 16.785, 2.1, 24.42},
	// {0.40265870094299316, 253.5, 35.25, 6.9, -1, 15.45},
	// {0.41297308129232535, 330, 48, 7.05, -1, 15},
	// {0.4183231245471899, 357, 34.2, 12.6, -1, 26.85},
	// {0.4253319501876831, 312, 27.9, 15.75, -1, 13.65},
	// {0.4301462769508362, 333, 37.35, 10.545, 3.24, 19.5},
	// {0.4358830451965332, 286.5, 26.85, 11.4, -1, 16.5},
	// {0.43745875358581543, 285, 29.25, 10.95, 3.45, 15.6},
	// {0.44068199396133423, 318, 21, 12.6, -1, 36},
	// {0.44141894578933716, 253.5, 36, 4.65, 5.7, 13.95},
	// {0.44735704205504156, 214.29000000000005, 0, 10.71, -1, 29.46},
	// {0.47878752585612294, 3028, 37, 324, 4, 4},
	// {0.47987449169158936, 688, 76, 27, 6, 35},
	// {0.4989275782980791, 1478, 165, 61, 9, 64},
	// {0.5153046250343323, 442, 42, 14, 2, 35},
	// }

	rawData := bbqChicken

	// Every row becomes one observation; all columns take part in clustering.
	observations := make(clusters.Observations, 0, len(rawData))
	for _, row := range rawData {
		observations = append(observations, clusters.Coordinates(row))
	}

	// Rule-of-thumb cluster count: floor(sqrt(number of samples)).
	k := int(math.Sqrt(float64(len(rawData))))
	fmt.Println("Using k:", k)

	km := kmeans.New()
	// Named "groups" so the imported clusters package is not shadowed.
	groups, err := km.Partition(observations, k)
	if err != nil {
		fmt.Println(err)
		return
	}

	for _, c := range groups {
		if len(c.Observations) == 0 {
			// Nothing to summarise; the stats helpers used below report an
			// error for empty input, which would otherwise end in a panic.
			continue
		}

		// fmt.Printf("Centered at x: %.2f y: %.2f\n", c.Center[0], c.Center[1])
		fmt.Printf("Matching data points: %+v\n\n", c.Observations)

		dist := make([]float64, 0, len(c.Observations))
		cals := make([]float64, 0, len(c.Observations))
		for _, ob := range c.Observations {
			coords := ob.Coordinates()
			dist = append(dist, coords[0])
			cals = append(cals, coords[1])
		}

		// Unweighted statistics of the distances within this cluster.
		md, mdstd := stat.MeanStdDev(dist, nil)

		medianDist, err := stats.Median(dist)
		if err != nil {
			panic(fmt.Errorf("median distance: %w", err))
		}
		iqr, err := stats.InterQuartileRange(dist)
		if err != nil {
			panic(fmt.Errorf("distance interquartile range: %w", err))
		}
		lowest, err := stats.Min(dist)
		if err != nil {
			panic(fmt.Errorf("minimum distance: %w", err))
		}

		// Calories are weighted by similarity so that closer matches count more.
		calMean, calStd := stat.MeanStdDev(cals, similarityWeights(dist))

		fmt.Printf("Distance Min: %.2f Mean: %.2f Std: %.2f Median: %.2f IQR: %.2f\n", lowest, md, mdstd, medianDist, iqr)
		fmt.Printf("Mean cals: %.2f Stdev: %.3f\n\n", calMean, calStd)
	}
}

// similarityWeights turns distances into similarity weights (1 - distance)
// and rescales them so that they sum to len(dist).
//
// The rescaling matters for the standard deviation: gonum's weighted sample
// variance divides by sum(weights) - 1, so raw similarity weights (each below
// 1) shrink that denominator and inflate the result, and for a single sample
// the denominator even becomes negative. Scaling all weights by a common
// factor leaves the weighted mean unchanged.
//
// It returns nil, which gonum treats as "all weights equal to 1", when the raw
// weights do not have a positive sum and therefore cannot be normalised.
func similarityWeights(dist []float64) []float64 {
	weights := make([]float64, len(dist))
	var total float64
	for i, d := range dist {
		weights[i] = 1 - d
		total += weights[i]
	}
	if total <= 0 {
		return nil
	}
	scale := float64(len(dist)) / total
	for i := range weights {
		weights[i] *= scale
	}
	return weights
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment