Create a gist now

Instantly share code, notes, and snippets.

%matplotlib inline
import csv
import numpy as np
import scipy.cluster.vq as vq
import matplotlib.pyplot as plt
# Path of the tab-separated input file.
DATA_PATH = 'seeds_dataset.txt'


def load_numeric_rows(path, delimiter="\t"):
    """Read a delimited text file and return its complete rows as floats.

    Blank lines and rows containing any empty cell are skipped, so every
    returned row has a value in each of its columns.

    Parameters
    ----------
    path : str or os.PathLike
        File to read.
    delimiter : str
        Field separator handed to ``csv.reader`` (tab by default).

    Returns
    -------
    list of list of float
        One list of floats per retained row, in file order.

    Raises
    ------
    ValueError
        If a non-empty cell cannot be parsed by ``float``.
    """
    # newline="" is what the csv module requires; the former 'rU' mode was
    # removed in Python 3.11 (it raises ValueError there). The with-block
    # also guarantees the file handle is closed.
    with open(path, newline="") as handle:
        # Skip blank lines (empty list) and rows with any empty cell instead
        # of guessing a value; blank lines would otherwise produce a ragged
        # array in np.array below.
        return [
            [float(cell) for cell in row]
            for row in csv.reader(handle, delimiter=delimiter)
            if row and all(row)
        ]


# Read data: 2-D array, one row per complete line of the file.
data = np.array(load_numeric_rows(DATA_PATH))
if data.size == 0:
    # Fail early with a clear message instead of deep inside whiten/kmeans.
    raise ValueError("no complete rows were read from %r" % DATA_PATH)
# Number of clusters requested from k-means.
NUM_CLUSTERS = 3
# One matplotlib colour code per cluster (first three match the original
# blue/red/green); cycled if more clusters are plotted than colours listed.
PLOT_COLORS = "brgcmyk"

# whiten() divides each column (feature) by its standard deviation so that
# features measured on larger scales do not dominate k-means' distances.
whitened = vq.whiten(data)
# Perform k-means on all columns. kmeans may return fewer centroids than
# requested, so code below uses len(centroids) rather than NUM_CLUSTERS.
# NOTE(review): if the file's last column is a class label rather than a
# measurement, it is being used as a feature here — confirm.
centroids, _ = vq.kmeans(whitened, NUM_CLUSTERS)
# Assign each observation to its nearest centroid; cls[i] is that centroid's index.
cls, _ = vq.vq(whitened, centroids)

# Plot the first two features (area vs. perimeter in this case) in their
# original, un-whitened units, one colour per cluster.
for label in range(len(centroids)):
    members = data[cls == label]
    plt.plot(members[:, 0], members[:, 1], 'o',
             color=PLOT_COLORS[label % len(PLOT_COLORS)],
             label="cluster {}".format(label))
plt.xlabel("area")
plt.ylabel("perimeter")
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment