Skip to content

Instantly share code, notes, and snippets.

@noczero
Created February 6, 2024 05:43
Show Gist options
  • Save noczero/9708f62baedbaeb2360a823e8eb819b8 to your computer and use it in GitHub Desktop.
Save noczero/9708f62baedbaeb2360a823e8eb819b8 to your computer and use it in GitHub Desktop.
Clustering using Gaussian Mixture Model with Multi Label Output
require 'rumale'
require 'rumale/dataset'
# Announce which demo is running before any model work starts.
p('Clustering using Gaussian Mixture Model with Multi Label Output')
# Gaussian Mixture Model clustering with multi-label output.
#
# Fitting is inherited from Rumale::Clustering::GaussianMixture. This subclass
# additionally exposes the per-cluster membership probabilities and can turn
# them into a list of cluster indices per sample, so a single sample may be
# assigned to several clusters at once.
class MultiLabelGMM < Rumale::Clustering::GaussianMixture
  # Create a new model. All keyword arguments are stored in @params for the
  # parent class's fitting code; see Rumale::Clustering::GaussianMixture for
  # their exact semantics.
  #
  # @param n_clusters [Integer] number of mixture components.
  # @param init [String] 'random' selects random initialization; any other
  #   value is normalized to 'k-means++'.
  # @param covariance_type [String] 'full' selects full covariances; any other
  #   value is normalized to 'diag'.
  # @param max_iter [Integer] iteration limit for the parent's fitting loop.
  # @param tol [Float] tolerance used by the parent's fitting loop.
  # @param reg_covar [Float] covariance regularization used by the parent.
  # @param random_seed [Integer, nil] seed for initialization. When nil, the
  #   value returned by Kernel#srand is stored instead; note that calling
  #   srand also reseeds Ruby's global random number generator.
  def initialize(n_clusters: 8, init: 'k-means++', covariance_type: 'diag',
                 max_iter: 50, tol: 1.0e-4, reg_covar: 1.0e-6, random_seed: nil)
    super()
    @params = {
      n_clusters: n_clusters,
      init: (init == 'random' ? 'random' : 'k-means++'),
      covariance_type: (covariance_type == 'full' ? 'full' : 'diag'),
      max_iter: max_iter,
      tol: tol,
      reg_covar: reg_covar,
      random_seed: random_seed || srand
    }
    # Filled in by #predict_probability and used as the default input of
    # #multi_labels.
    @memberships = nil
  end

  # Fit the model to +x+ and return the membership probabilities of +x+.
  #
  # @param x [Numo::DFloat] training samples, one row per sample.
  # @return [Numo::DFloat] membership probabilities, one row per sample and
  #   one column per cluster (also cached for #multi_labels).
  # @raise [RuntimeError] if covariance_type is 'full' and Numo::Linalg is
  #   not loaded (raised by the parent's check_enable_linalg).
  def fit_predict_probability(x)
    check_enable_linalg('fit_predict_probability')
    x = ::Rumale::Validation.check_convert_sample_array(x)
    fit(x)
    predict_probability(x)
  end

  # Compute the membership probabilities of +x+ under the fitted model.
  #
  # The result is cached in @memberships so #multi_labels can threshold it
  # without recomputing.
  #
  # @param x [Numo::DFloat] samples, one row per sample.
  # @return [Numo::DFloat] membership probabilities, one row per sample and
  #   one column per cluster.
  # @raise [RuntimeError] if the model has not been fitted yet, or if
  #   covariance_type is 'full' and Numo::Linalg is not loaded.
  def predict_probability(x)
    check_enable_linalg('predict_probability')
    # Without this guard an unfitted model fails deep inside calc_memberships
    # with an unhelpful NoMethodError on nil.
    if @means.nil?
      raise 'MultiLabelGMM#predict_probability requires a fitted model; ' \
            'call #fit or #fit_predict_probability first.'
    end
    x = ::Rumale::Validation.check_convert_sample_array(x)
    @memberships = calc_memberships(x, @weights, @means, @covariances, @params[:covariance_type])
  end

  # Convert membership probabilities into per-sample lists of cluster indices.
  #
  # A cluster index is included for a sample when its membership probability
  # is strictly greater than +threshold+, so a sample can receive zero, one,
  # or several labels.
  #
  # @param threshold [Float] exclusive lower bound on membership probability.
  # @param memberships [Numo::DFloat, nil] probability matrix to threshold,
  #   one row per sample; defaults to the result of the most recent
  #   #predict_probability call.
  # @return [Array<Array<Integer>>] cluster indices for each sample, in row order.
  # @raise [RuntimeError] if no membership probabilities are available.
  def multi_labels(threshold: 0.5, memberships: @memberships)
    if memberships.nil?
      raise 'MultiLabelGMM#multi_labels requires membership probabilities; ' \
            'call #predict_probability or #fit_predict_probability first, or pass memberships:.'
    end
    above = memberships.gt(threshold)
    Array.new(above.shape[0]) { |row| above[row, true].where.to_a }
  end
end
# Build a small synthetic 2-D dataset drawn around three fixed centers:
# (1, -1), (1, 1) and (0, 1).
#
# The defaults reproduce the original demo data (20 points, std 1.0, seed 1).
#
# @param n_samples [Integer] total number of points to generate.
# @param cluster_std [Float] spread of each blob, passed to make_blobs.
# @param random_seed [Integer] seed passed to make_blobs for reproducible data.
# @return [Array] the result of Rumale::Dataset.make_blobs; presumably
#   [samples, labels] (verify against the Rumale docs). Element 0 holds the
#   samples.
def three_clusters_dataset(n_samples: 20, cluster_std: 1.0, random_seed: 1)
  centers = Numo::DFloat[[1, -1], [1, 1], [0, 1]]
  Rumale::Dataset.make_blobs(n_samples, centers: centers, cluster_std: cluster_std, random_seed: random_seed)
end
# Generate the demo data. make_blobs also returns ground-truth labels, which
# this demo does not use.
samples, _true_labels = three_clusters_dataset

# Fit a 3-component GMM and keep the per-sample membership probabilities.
analyzer = MultiLabelGMM.new(n_clusters: 3, max_iter: 50)
membership_probability = analyzer.fit_predict_probability(samples)

# Give each sample every cluster whose membership probability exceeds 0.01.
labels = analyzer.multi_labels(threshold: 0.01)
p "Predict with Multiple Labels: #{labels}"
p "Predict probability: #{membership_probability.to_a}"
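# Example output. The GMM above is created without a random_seed, so cluster numbering and probabilities may differ between runs.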
# "Clustering using Gaussian Mixture Model with Multi Label Output"
# "Predict with Multiple Labels: [[0], [0], [0, 2], [0, 1], [2], [0], [0], [1], [1], [1], [0], [0, 2], [0, 2], [2], [0], [0, 2], [0], [0, 2], [0], [0, 2]]"
# "Predict probability: [[0.9999300578542513, 6.994209164666061e-05, 5.410205012670883e-11], [0.999999902557336, 9.74426640572274e-08, 3.550966462139485e-24], [0.9648774261105081, 2.5175367380437922e-05, 0.03509739852211158], [0.0870845586267286, 0.9129154413732714, 2.764361458641714e-60], [7.399786459326608e-05, 0.00028824722729036106, 0.9996377549081163], [0.9999999983440929, 1.6559071223563116e-09, 4.499261874956609e-40], [0.9999999999987949, 1.2052216872414945e-12, 1.8678536160401503e-79], [0.0009811012716997099, 0.9990188987283003, 5.978741885682614e-22], [0.0090431166759968, 0.9909568833240032, 1.9338083568360927e-19], [0.0009959081369024156, 0.9990040918630976, 1.0584495703556302e-52], [0.9996007892020022, 7.054129216690488e-16, 0.000399210797997221], [0.8789222240661732, 1.2729189520984938e-14, 0.12107777593381416], [0.054773691389011274, 3.683030907264873e-07, 0.945225940307898], [0.0021937520207057633, 0.00717041854029634, 0.990635829438998], [1.0, 1.9790293999282198e-20, 2.5566206326303784e-61], [0.07888761826423746, 6.779761648761619e-31, 0.9211123817357626], [0.9999577847322032, 4.2215267796756885e-05, 1.2249438474291322e-66], [0.19317347420287334, 7.610608522876177e-13, 0.8068265257963656], [0.9999999998728679, 2.3536888671005802e-18, 1.2713204820920964e-10], [0.03217656138888003, 5.476684853388209e-06, 0.9678179619262666]]"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment