Skip to content

Instantly share code, notes, and snippets.

@a-h
Created April 2, 2018 14:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a-h/e8600dadea283c8643b5b027506f8e33 to your computer and use it in GitHub Desktop.
Save a-h/e8600dadea283c8643b5b027506f8e33 to your computer and use it in GitHub Desktop.
KMeans on random 2D data
package main
import (
"fmt"
"math/rand"
"os"
"strconv"
"time"
"github.com/a-h/ml/clustering"
"github.com/a-h/ml/distance"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)
// init seeds the global math/rand source so each run clusters a different
// random data set. UnixNano is used because seeding with Unix() (whole
// seconds) gave identical data to runs started within the same second.
//
// NOTE: rand.Seed is deprecated from Go 1.20, where the global source is
// already randomly seeded at startup; the call is kept for older toolchains.
func init() {
	rand.Seed(time.Now().UnixNano())
}
// main generates 50 random 2D points, partitions them into 3 clusters with
// clustering.KMeans (Euclidean distance) and saves a scatter plot of the
// clusters, one colour/shape per cluster, to points.png.
//
// Every failure takes the same path: the error is printed to stderr and the
// process exits with status 1. Previously some failures printed to stdout and
// called os.Exit(-1), which is status 255 on Unix, while others panicked.
func main() {
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run performs the work of main and returns the first error encountered,
// prefixed with a description of the step that failed.
func run() error {
	const (
		numPoints   = 50
		numClusters = 3
		sizeCm      = 15
		outputFile  = "points.png"
	)

	p, err := plot.New()
	if err != nil {
		return fmt.Errorf("creating plot: %v", err)
	}
	p.Title.Text = "KMeans"
	// NOTE(review): the data spans [-10, 10) but both axis minimums are set
	// to 0; presumably p.Add widens the range to fit the data - confirm
	// against the gonum/plot version in use.
	p.X.Min = 0
	p.X.Padding = 0
	p.X.Label.Text = "X"
	p.Y.Min = 0
	p.Y.Padding = 0
	p.Y.Label.Text = "Y"

	// Create some random data and assign it to numClusters clusters.
	data := random2DVectors(numPoints)
	assignment, err := clustering.KMeans(data, numClusters, distance.Euclidean)
	if err != nil {
		return fmt.Errorf("clustering data: %v", err)
	}

	// Group the points by their assigned cluster.
	clusters, err := clustering.Assign(data, assignment)
	if err != nil {
		return fmt.Errorf("assigning data to clusters: %v", err)
	}

	// Add one scatter series per cluster; the legend labels are 1-based.
	for i, cluster := range clusters {
		pts := convert2DVectorToPlotterXY(cluster)
		if err := addScatters(p, i, strconv.Itoa(i+1), pts); err != nil {
			return fmt.Errorf("adding cluster %d to plot: %v", i+1, err)
		}
	}

	// Save the plot to a PNG file.
	if err := p.Save(sizeCm*vg.Centimeter, sizeCm*vg.Centimeter, outputFile); err != nil {
		return fmt.Errorf("saving plot to %s: %v", outputFile, err)
	}
	return nil
}
// convert2DVectorToPlotterXY converts vectors into plotter points, taking
// component 0 of each vector as X and component 1 as Y. Any further
// components are ignored. Every vector must have at least two components,
// otherwise the indexing below panics.
//
// The fields are assigned directly rather than via a struct literal of a
// local type. That way this compiles whether the plotter.XYs element type is
// an unnamed struct{X, Y float64} (older gonum/plot) or the named plotter.XY
// (newer releases).
func convert2DVectorToPlotterXY(v []clustering.Vector) plotter.XYs {
	pts := make(plotter.XYs, len(v))
	for i, vec := range v {
		pts[i].X = vec[0]
		pts[i].Y = vec[1]
	}
	return pts
}
// xy is a single 2D point. It has the same field layout as
// struct{X, Y float64}, which lets values of this type be assigned into
// plotter.XYs elements by convert2DVectorToPlotterXY.
type xy struct {
	X float64
	Y float64
}
// random2DVectors builds n fresh two-component vectors and fills each one
// via randomise with whole numbers in the half-open range [-10, 10).
func random2DVectors(n int) []clustering.Vector {
	vectors := make([]clustering.Vector, n)
	for i := range vectors {
		vectors[i] = make(clustering.Vector, 2)
		randomise(vectors[i], -10, 10)
	}
	return vectors
}
// randomise overwrites every element of v with a pseudo-random whole number
// in the half-open range [lo, hi), drawn from the global math/rand source.
// It panics if v is non-empty and hi <= lo, because rand.Intn requires a
// positive argument. An empty v is left alone and consumes no randomness.
func randomise(v []float64, lo, hi int) {
	span := hi - lo
	for i := range v {
		v[i] = float64(lo + rand.Intn(span))
	}
}
// addScatters adds one scatter series built from xyers to plt. The series is
// styled with the colour and glyph shape plotutil assigns to index, and a
// legend entry labelled name is added for it.
//
// The scatter is passed to Legend.Add as the entry's thumbnail so the legend
// shows the series' glyph next to its name. Without it the entry is bare
// text, and clusters cannot be matched to their labels.
//
// xyers accepts any plotter.XYer (plotter.XYs still works). The error from
// plotter.NewScatter is returned unchanged.
func addScatters(plt *plot.Plot, index int, name string, xyers plotter.XYer) error {
	s, err := plotter.NewScatter(xyers)
	if err != nil {
		return err
	}
	s.Color = plotutil.Color(index)
	s.Shape = plotutil.Shape(index)
	plt.Add(s)
	plt.Legend.Add(name, s)
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment