%matplotlib inline | |
import csv | |
import numpy as np | |
import scipy.cluster.vq as vq | |
import matplotlib.pyplot as plt | |
# Read the tab-delimited dataset into a 2-D float array, skipping incomplete
# rows. The file is opened in a ``with`` block so the handle is closed as soon
# as parsing finishes, and with newline='' as the csv module recommends. The
# old 'rU' mode is not used: Python 3.11+ rejects it with ValueError, and
# universal newlines are already the default for text mode.
data = []
with open('seeds_dataset.txt', newline='') as dataset_file:
    bank_csv = csv.reader(dataset_file, delimiter="\t")
    for row in bank_csv:
        # Skip blank lines (the reader yields []) and rows with any empty
        # cell: the former would make the array ragged, the latter is
        # treated as missing data.
        if not row or not all(row):
            continue
        # Convert each cell to float
        data.append([float(cell) for cell in row])
data = np.array(data)
# Rescale each feature column to unit variance before clustering
whitened = vq.whiten(obs=data)
# Run k-means over all feature columns, asking for 3 centroids
centroids, _ = vq.kmeans(obs=whitened, k_or_guess=3)
# Label every observation with the index of its closest centroid
cls, _ = vq.vq(obs=whitened, code_book=centroids)
# Scatter the first two feature columns (area vs perimeter in this case),
# drawing each of the three clusters with its own colour.
for cluster_id, marker in enumerate(('ob', 'or', 'og')):
    in_cluster = cls == cluster_id
    plt.plot(data[in_cluster, 0], data[in_cluster, 1], marker)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment