Created
August 20, 2016 15:44
-
-
Save billyzs/23f7cf4c5b0a9f0884a7edc6ca65ac37 to your computer and use it in GitHub Desktop.
OpenCV Kmeans algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import numpy as np | |
import cv2 | |
# Default k-means termination criteria: stop after 5 iterations or once the
# requested accuracy (epsilon) of 1.0 is reached, whichever comes first.
CRITERIA = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_MAX_ITER, 5, 1.0)
# A single black pixel in the Y/Cr/Cb colour space (zero luma, neutral chroma),
# shaped (1, 3) so it can be repeated into a per-cluster lookup table.
YCBCR_BLACK = np.array([[0, 128, 128]], dtype=np.uint8)
def k_means(img, k, criteria=CRITERIA, show=False, attempts=5):
    """
    Cluster the pixels of an image into k colour groups with cv2.kmeans.

    The image is converted with cv2.COLOR_RGB2YCR_CB, flattened to one
    3-component float32 row per pixel, and clustered in that colour space.

    :param img: RGB array (uint8)
    :param k: # of clusters
    :param criteria: OpenCV termination criteria tuple (type, max_iter, epsilon)
    :param show: True displays one window per cluster, in which every pixel
        that does not belong to that cluster is painted black, then waits for
        a key press
    :param attempts: number of k-means runs requested from cv2.kmeans
    :return: (center, mask) -- center is the array of k cluster centres cast
        to uint8 (still in the converted colour space); mask is the label
        array returned by cv2.kmeans (one cluster index per pixel)
    """
    img = cv2.cvtColor(img, cv2.COLOR_RGB2YCR_CB)
    original_shape = img.shape
    # cv2.kmeans wants float32 samples, one row per pixel.
    samples = np.float32(img.reshape(-1, 3))
    flags = cv2.KMEANS_RANDOM_CENTERS

    opencv_version = int(cv2.__version__.split(".")[0])
    if opencv_version == 2:
        _, mask, center = cv2.kmeans(samples, k, criteria, attempts, flags)
    else:
        # OpenCV 3 introduced the extra bestLabels positional argument and
        # later major versions kept it. Treating every non-2.x release this
        # way avoids leaving `mask`/`center` unbound on OpenCV 4+.
        _, mask, center = cv2.kmeans(samples, k, None, criteria, attempts, flags)
    center = np.uint8(center)

    # visualization code
    if show:
        labels = mask.flatten()  # loop-invariant, so computed once
        for i in range(k):
            # Lookup table: black for every cluster except cluster i.
            bases = np.repeat(YCBCR_BLACK, k, axis=0)
            bases[i] = center[i]
            res = bases[labels].reshape(original_shape)
            # NOTE(review): this mirrors the RGB conversion applied on input,
            # so the window shows the channel order the caller passed in;
            # cv2.imshow itself interprets 3-channel data as BGR -- confirm
            # what callers actually supply.
            res = cv2.cvtColor(res, cv2.COLOR_YCR_CB2RGB)
            cv2.imshow(str(bases[i]) + ' w/ k = ' + str(k) + ' A = ' + str(attempts), res)
        # Block until a key is pressed so the windows stay visible.
        cv2.waitKey(0)
    return center, mask
if __name__ == "__main__":
    # Demo: cluster a sample image into 5 colour groups and show each cluster.
    criteria = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, 5, 1.0)
    image_path = "img/protourney_chess_set_black_camel_profile_900.jpg"
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message rather than
        # deep inside cv2.cvtColor.
        raise IOError("could not read image: " + image_path)
    # Pass show/attempts by keyword: the signature order is
    # (img, k, criteria, show, attempts), so the former positional call
    # `k_means(img, 5, criteria, 5, True)` ran with show=5 and attempts=True.
    center, mask = k_means(img, 5, criteria, show=True, attempts=5)
    # Parenthesised single-argument print works on both Python 2 and 3.
    print(type(mask))
    print(len(mask))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment