Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
require "json"
require "t_learn"
require 'pycall/import'
include PyCall::Import
# Load the sample points and run k-means over them.
# Presumably the JSON file holds an array of [x, y] pairs (the plotting code below reads
# v[0] and v[1]) — TODO confirm against sample_2dim.json.
# File.read closes the file itself (File.open was never closed), and JSON.parse does not
# apply JSON.load's object-creation hooks to file input.
data_list = JSON.parse(File.read("./sample_2dim.json"))

# Number of clusters to fit. Passed positionally, exactly as the original `c=3` did
# (that form also created a stray top-level local named `c`).
num_clusters = 3
k_means = TLearn::K_Means.new
history = k_means.fit(data_list, num_clusters)
# Each entry is read below as a hash with :v_list (member points) and :vec (centre).
cluster_list = history[:result]
pyimport 'numpy', as: 'np'
pyimport 'matplotlib.mlab', as: 'mlab'
pyimport 'matplotlib.pyplot', as: 'plt'
# Single-panel figure; every scatter call in the loop below draws onto `ax`.
fig = plt.figure.()
ax = fig.add_subplot.(1, 1, 1)
# Palette indexed by cluster position.
colors = %w[red blue green]
# Draw each cluster's member points, then its centre as a larger marker in the same colour.
cluster_list.each_with_index do |cluster, i|
  # Cycle the palette: with more clusters than colours, colors[i] would be nil and
  # matplotlib would fall back to its default colour, making those clusters indistinguishable.
  color = colors[i % colors.size]

  points = cluster[:v_list]
  x_list = points.map { |v| v[0] }
  y_list = points.map { |v| v[1] }
  c_x, c_y = cluster[:vec][0], cluster[:vec][1]

  ax.scatter.(x_list, y_list, c: color)
  ax.scatter.(c_x, c_y, c: color, s: 200)
end
# Write the figure to disk, then try to display it.
output_path = "./test.png"
plt.savefig.(output_path)

# Kernel#open only returns a read handle on the file, which was never closed; it does not
# display anything. Launch the desktop viewer instead: `open` on macOS, `xdg-open` elsewhere.
# NOTE(review): assumes a desktop macOS/Linux session; Windows would need `cmd /c start`.
viewer = RUBY_PLATFORM.include?("darwin") ? "open" : "xdg-open"
warn("Saved #{output_path}; could not launch #{viewer} to display it.") unless system(viewer, output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment