Skip to content

Instantly share code, notes, and snippets.

@billyzs
Created August 20, 2016 15:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save billyzs/23f7cf4c5b0a9f0884a7edc6ca65ac37 to your computer and use it in GitHub Desktop.
Save billyzs/23f7cf4c5b0a9f0884a7edc6ca65ac37 to your computer and use it in GitHub Desktop.
OpenCV k-means colour clustering
#!/usr/bin/env python
import numpy as np
import cv2
# Default termination criteria for cv2.kmeans, as an OpenCV (type, max_iter, epsilon)
# tuple: stop after at most 5 iterations or once the epsilon accuracy of 1.0 is reached.
CRITERIA = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, 1.0)

# One 8-bit YCrCb pixel, shape (1, 3): luma 0 with neutral chroma (128, 128), i.e. black.
YCBCR_BLACK = np.array([(0, 128, 128)], dtype=np.uint8)
def k_means(img, k, criteria=CRITERIA, show=False, attempts=5):
    """Cluster the pixels of an RGB image with OpenCV k-means in YCrCb space.

    :param img: RGB image, uint8 array of shape (H, W, 3)
    :param k: number of clusters
    :param criteria: OpenCV termination criteria tuple (type, max_iter, epsilon)
    :param show: if True, display one window per cluster; each window waits
        for a key press before the next one is shown
    :param attempts: number of k-means runs with different random initial
        centers (OpenCV keeps the best one)
    :return: (center, mask) -- ``center`` is a (k, 3) uint8 array of cluster
        centroids in YCrCb channel order (fractional parts truncated);
        ``mask`` is the label array returned by cv2.kmeans, one cluster
        index per pixel in row-major order
    """
    img = cv2.cvtColor(img, cv2.COLOR_RGB2YCR_CB)
    original_shape = img.shape
    # cv2.kmeans expects one float32 sample per row: (H*W, 3).
    samples = np.float32(img.reshape(-1, 3))

    opencv_major = int(cv2.__version__.split(".")[0])
    if opencv_major == 2:
        # The OpenCV 2.x binding has no bestLabels argument.
        _, mask, center = cv2.kmeans(samples, k, criteria, attempts,
                                     cv2.KMEANS_RANDOM_CENTERS)
    else:
        # OpenCV 3.x and newer take bestLabels (None = let OpenCV allocate)
        # before criteria. Matching only major == 3 would leave mask/center
        # unbound on OpenCV 4+.
        _, mask, center = cv2.kmeans(samples, k, None, criteria, attempts,
                                     cv2.KMEANS_RANDOM_CENTERS)
    center = np.uint8(center)

    # visualization code
    if show:
        labels = mask.flatten()
        # k copies of a black YCrCb pixel; copied per cluster below.
        blank_palette = np.repeat(YCBCR_BLACK, k, axis=0)
        for i in range(k):
            # Every pixel is painted black except those in cluster i, which
            # get that cluster's centroid colour.
            palette = blank_palette.copy()
            palette[i] = center[i]
            res = palette[labels].reshape(original_shape)
            # cv2.imshow interprets 3-channel images as BGR, so convert to
            # BGR (not RGB) to display the true colours.
            res = cv2.cvtColor(res, cv2.COLOR_YCR_CB2BGR)
            cv2.imshow(str(palette[i]) + ' w/ k = ' + str(k) + ' A = ' + str(attempts), res)
            cv2.waitKey(0)
    return center, mask
if __name__ == "__main__":
    image_path = "img/protourney_chess_set_black_camel_profile_900.jpg"
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, which would otherwise fail later in cvtColor.
        raise SystemExit("could not read image: " + image_path)
    # cv2.imread returns channels in BGR order; k_means expects RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # show/attempts are passed by keyword: the previous positional call
    # k_means(img, 5, criteria, 5, True) bound show=5 and attempts=True (== 1).
    center, mask = k_means(rgb, 5, CRITERIA, show=True, attempts=5)
    print(type(mask))
    print(len(mask))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment