Skip to content

Instantly share code, notes, and snippets.

Last active August 11, 2017 04:40
Show Gist options
  • Save itdxer/c75ea4df8a085ae6035d5818a1d7396e to your computer and use it in GitHub Desktop.
Save itdxer/c75ea4df8a085ae6035d5818a1d7396e to your computer and use it in GitHub Desktop.
Fuzzy C-means in Theano
import theano
import theano.tensor as T
import numpy as np
def asfloat(value):
    """Convert a value to the float type configured by Theano's ``floatX``.

    Parameters
    ----------
    value : matrix, ndarray or scalar
        Value that could be converted to float type.

    Returns
    -------
    matrix, ndarray or scalar
        Input value converted to the float type configured by the
        Theano ``floatX`` variable.
    """
    float_type = theano.config.floatX

    # ``np.matrix`` is a subclass of ``np.ndarray``, so a single check
    # covers both; ``astype`` keeps the container type of the input.
    if isinstance(value, np.ndarray):
        return value.astype(float_type)

    # ``np.cast`` was removed in NumPy 2.0. ``np.asarray`` with an explicit
    # dtype is the documented replacement and works on old versions too.
    return np.asarray(value, dtype=float_type)
def clone(instance):
    """Create a new object of the same class from ``get_params()``.

    The copy is built by calling the class of ``instance`` with the
    keyword arguments returned by ``instance.get_params()``; no other
    state of ``instance`` is carried over.
    """
    constructor = instance.__class__
    return constructor(**instance.get_params())
class FuzzyCMeans(object):
    """Fuzzy c-means clustering implemented with Theano.

    Parameters
    ----------
    n_clusters : int
        Number of clusters. Must be at least 2.
    m : float
        Fuzzifier exponent. Must be strictly greater than 1, because the
        membership formula uses the exponent ``2 / (m - 1)``.

    Attributes
    ----------
    centers_ : theano shared variable or None
        Cluster centers; ``None`` until ``fit`` has been called.
    is_initialized : bool
        Whether the Theano functions have been compiled.
    """

    def __init__(self, n_clusters, m=2):
        if n_clusters < 2:
            raise ValueError("Number of clusters should be greater than "
                             "or equal to 2")

        # ``m == 1`` must be rejected as well: it would lead to a division
        # by zero when the exponent ``2 / (m - 1)`` is computed.
        if m <= 1:
            raise ValueError("Parameter `m` should be greater than 1")

        self.n_clusters = n_clusters
        self.m = m
        self.centers_ = None
        self.is_initialized = False

    def init_methods(self):
        """Compile the ``predict_proba`` and ``train_iteration`` functions.

        Raises
        ------
        AttributeError
            If the functions have already been compiled, or if the
            centers have not been created yet.
        """
        if self.is_initialized:
            raise AttributeError("Methods have already been initialized.")

        if self.centers_ is None:
            raise AttributeError("Centers have not been initialized.")

        x = T.matrix('x')
        centers = self.centers_

        # Pairwise Euclidean distances, shape (n_samples, n_clusters).
        distance_to_centers = (
            x.reshape((x.shape[0], 1, x.shape[1])) -
            centers.reshape((1, centers.shape[0], centers.shape[1]))
        ).norm(L=2, axis=2)

        # Membership u_ij = d_ij^(-p) / sum_k d_ik^(-p), p = 2 / (m - 1).
        # This equals 1 / sum_k (d_ij / d_ik)^p without materialising an
        # (n_samples, n_clusters, n_clusters) tensor.
        # NOTE(review): a sample that coincides exactly with a center has
        # zero distance and produces a non-finite membership.
        weights = 1 / distance_to_centers ** asfloat(2. / (self.m - 1))
        proba = weights / T.sum(weights, axis=1).reshape((-1, 1))

        # Center update: c_j = sum_i u_ij^m * x_i / sum_i u_ij^m.
        proba_power_m = proba ** self.m
        new_centers = (
            proba_power_m.T.dot(x) /
            T.sum(proba_power_m, axis=0).reshape((-1, 1))
        )

        self.predict_proba = theano.function([x], proba)
        # Returns the memberships for the current centers and then
        # replaces the centers with the updated ones.
        self.train_iteration = theano.function([x], proba, updates=[
            (centers, new_centers),
        ])
        self.is_initialized = True

    def get_params(self, deep=False):
        """Return the constructor arguments as a dictionary.

        ``deep`` is accepted but not used.
        """
        return dict(n_clusters=self.n_clusters, m=self.m)

    @property
    def centers(self):
        """Current cluster centers as a NumPy array."""
        return self.centers_.get_value()

    def fit(self, data, maxiter=100, epsilon=1e-5, verbose=False):
        """Run training iterations until the memberships stop changing.

        Parameters
        ----------
        data : ndarray
            Two-dimensional array with one sample per row.
        maxiter : int
            Maximum number of training iterations.
        epsilon : float
            Training stops once the norm of the difference between two
            consecutive membership matrices is not greater than this.
        verbose : bool
            When true, print the membership change after each iteration.

        Raises
        ------
        ValueError
            If ``data`` has a different number of features than the
            existing centers.
        """
        n_features = data.shape[1]

        if self.centers_ is None:
            # Draw the initial centers uniformly inside the bounding box
            # of the data.
            data_min = data.min(axis=0)
            data_max = data.max(axis=0)
            random_centers = np.random.random((self.n_clusters, n_features))
            scaled_centers = (data_max - data_min) * random_centers + data_min
            self.centers_ = theano.shared(
                asfloat(scaled_centers),
                name='centers',
            )

        n_expected_features = self.centers.shape[1]
        if n_expected_features != n_features:
            raise ValueError("Input data must contain {} features, "
                             "found {}".format(n_expected_features,
                                               n_features))

        if not self.is_initialized:
            self.init_methods()

        # The compiled functions take a ``floatX`` matrix as input.
        data = asfloat(data)

        prev_proba = None
        for iteration in range(1, maxiter + 1):
            proba = self.train_iteration(data)

            if prev_proba is not None:
                proba_update = np.linalg.norm(prev_proba - proba)

                if verbose:
                    print("Iteration #{}: membership change = {}".format(
                        iteration, proba_update))

                if proba_update <= epsilon:
                    break

            prev_proba = proba

    def predict(self, data):
        """Return, for each row of ``data``, the index of the cluster
        with the highest membership."""
        proba = self.predict_proba(asfloat(data))
        return proba.argmax(axis=1)
def select_best_clustering(algorithm, n_trials, data, **fit_kwargs):
    """Select best clusters using SSE.

    Parameters
    ----------
    algorithm : object
        Clustering algorithm used as a template. A fresh copy is made
        with ``clone`` for every trial; each copy must provide ``fit``,
        ``predict`` and a ``centers`` attribute.
    n_trials : int
        Number of independent trials. Must be at least 1.
    data : matrix
        Training data passed to ``fit`` and ``predict``.
    **fit_kwargs
        Extra keyword arguments forwarded to ``fit``.

    Raises
    ------
    ValueError
        Exception will raise in case input parameter values
        are invalid.

    Returns
    -------
    object
        Pretrained clustering algorithm that give smallest
        SSE (sum of squared error) score.
    """
    # Check the type first so that a non-numeric value is reported as a
    # ValueError instead of failing inside the comparison below.
    if not isinstance(n_trials, int):
        raise ValueError("Number of trials should be an integer number")

    if n_trials < 1:
        raise ValueError("Number of trials should be greater than "
                         "or equal to 1")

    best_score = None
    best_algorithm = None

    for _ in range(n_trials):
        candidate = clone(algorithm)
        candidate.fit(data, **fit_kwargs)

        clusters = candidate.predict(data)
        # Center assigned to every sample, aligned row by row with data.
        centers = candidate.centers[clusters, :]
        sse_score = np.sum((data - centers) ** 2)

        # Strict comparison keeps the earliest trial when scores tie and
        # never compares the algorithm objects themselves.
        if best_score is None or sse_score < best_score:
            best_score = sse_score
            best_algorithm = candidate

    return best_algorithm
if __name__ == '__main__':
    # Small demonstration: cluster points into two groups.
    # NOTE(review): the input used by the original demo is not visible
    # here; random 2-D points are a stand-in. Replace with the intended
    # dataset.
    data = np.random.random((100, 2))

    fcm = FuzzyCMeans(n_clusters=2, m=2)
    fcm.fit(data, maxiter=100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment