Skip to content

Instantly share code, notes, and snippets.

@dantswain
Created November 28, 2021 03:23
Show Gist options
  • Save dantswain/38f56db677db21d8335d8a29fc73c81b to your computer and use it in GitHub Desktop.
Save dantswain/38f56db677db21d8335d8a29fc73c81b to your computer and use it in GitHub Desktop.
K-Means Clustering with Elixir NX

K-Means Clustering

The data

# Notebook dependencies, installed by Mix.install/1 when this cell runs:
#   * :nx        - tensor library used for all numerical work below; pulled
#                  from the elixir-nx/nx GitHub repository rather than Hex
#   * :vega_lite - builds the chart specifications for every plot below
#   * :kino      - not called directly anywhere in this notebook; presumably
#                  needed by Livebook to render the VegaLite output
# NOTE(review): the :nx entry names no ref/tag, so the code below
# (Nx.random_normal, Nx.to_scalar, ...) is tied to whatever revision is
# fetched -- confirm it still matches the API used here.
Mix.install([
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:vega_lite, "~> 0.1"},
  {:kino, "~> 0.3"}
])
# Synthetic data set: two equally sized normally-distributed blobs, each row
# tagged with the index (0 or 1) of the distribution it was drawn from.
n_points = 64
n_dims = 2

# Bounding box used further down when seeding the initial centroids.
x_min = -4
x_max = 4
y_min = -4
y_max = 4

n_per_init_cluster = div(n_points, 2)

# Blob 0: normal(0, 1) samples shifted by (+1, +1). The axis names given here
# are what later cells rely on for `[y: ...]` column access.
r1 =
  {n_per_init_cluster, n_dims}
  |> Nx.random_normal(0.0, 1.0, names: [:x, :y])
  |> Nx.add(Nx.tensor([1.0, 1.0]))

label1 = Nx.broadcast(0, {n_per_init_cluster, 1})
c1 = Nx.concatenate([r1, label1], axis: 1)

# Blob 1: normal(0, 1) samples shifted by (-1, -1).
r2 =
  {n_per_init_cluster, n_dims}
  |> Nx.random_normal(0.0, 1.0)
  |> Nx.add(Nx.tensor([-1.0, -1.0]))

label2 = Nx.broadcast(1, {n_per_init_cluster, 1})
c2 = Nx.concatenate([r2, label2], axis: 1)

# All points in one tensor; the columns are the two coordinates followed by
# the true label.
labeled = Nx.concatenate([c1, c2])
alias VegaLite, as: Vl

# Builds a VegaLite point layer from a tensor whose columns 0, 1 and 2 (along
# the axis named :y) hold the x coordinate, the y coordinate and the label.
# Points are coloured by label.
mk_data_layer = fn labeled_data ->
  series = [
    x: Nx.to_flat_list(labeled_data[y: 0]),
    y: Nx.to_flat_list(labeled_data[y: 1]),
    label: Nx.to_flat_list(labeled_data[y: 2])
  ]

  Vl.new()
  |> Vl.data_from_series(series)
  |> Vl.mark(:point)
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)
end

# Plot the generated data coloured by the distribution each point came from.
Vl.layers(
  Vl.new(title: "Raw Data w/ True Labels", width: 700, height: 700),
  [mk_data_layer.(labeled)]
)

Clustering - Initialization

# number of clusters to fit
k = 2

# the unlabeled data (coordinate columns only, label column dropped)
data = labeled[y: 0..(n_dims - 1)]

# Draw k initial centroids uniformly at random inside the box
# [x_min, x_max] x [y_min, y_max]. Each row is [x, y, cluster_index].
#
# The loop runs over the k clusters. It previously ran over
# 0..(n_dims - 1), which only produced the right number of centroids
# because k and n_dims are both 2.
#
# Note: the position itself is still built from exactly two coordinates, so
# this cell assumes 2-D data.
initial_centroids =
  0..(k - 1)
  |> Enum.map(fn ix ->
    [
      x_min + (x_max - x_min) * :rand.uniform(),
      y_min + (y_max - y_min) * :rand.uniform(),
      ix
    ]
  end)
  |> Nx.tensor(names: [:x, :y])
# Builds a VegaLite layer that draws each centroid as a large square coloured
# by its label. Expects the same column layout as the data layer: columns
# 0, 1 and 2 along the :y axis are x, y and label.
mk_centroid_layer = fn labeled_centroids ->
  series = [
    x: Nx.to_flat_list(labeled_centroids[y: 0]),
    y: Nx.to_flat_list(labeled_centroids[y: 1]),
    label: Nx.to_flat_list(labeled_centroids[y: 2])
  ]

  Vl.new()
  |> Vl.data_from_series(series)
  |> Vl.mark(:square, size: 400)
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)
end

# Overlay the randomly seeded centroids on the true-labelled data.
Vl.layers(
  Vl.new(title: "Location of Initial Centroids w/ True Labels", width: 700, height: 700),
  [
    mk_data_layer.(labeled),
    mk_centroid_layer.(initial_centroids)
  ]
)
# Euclidean distance between the rows of `d` and the rows of `centroids`
# (both without a label column). A new axis is inserted at position 1 of
# `centroids` so the subtraction broadcasts each centroid against every row
# of `d`; the squared differences are then summed over the last axis.
dist_fn = fn d, centroids ->
  offsets = Nx.subtract(d, Nx.new_axis(centroids, 1))
  squared_distances = Nx.sum(Nx.power(offsets, 2), axes: [2])
  Nx.sqrt(squared_distances)
end

# helper function to find labels: for each data point, the index of the
# centroid with the smallest distance (argmin over the centroid axis)
find_labels = fn d, centroids ->
  Nx.argmin(dist_fn.(d, centroids), axis: 0)
end

# Label every point using the coordinate columns of the seeded centroids.
new_labels = find_labels.(data, initial_centroids[y: 0..(n_dims - 1)])

# Append the labels as an extra column so the plotting helper can colour by them.
alg_labeled = Nx.concatenate([data, Nx.new_axis(new_labels, 1)], axis: 1)

Vl.layers(
  Vl.new(title: "Initial Labeling", width: 700, height: 700),
  [mk_data_layer.(alg_labeled)]
)

Clustering - First Iteration

# Recomputes the centroid of every cluster.
#
#   data          - rank-2 tensor of coordinates, one row per point
#   labels        - cluster index for each row of `data`
#   old_centroids - the previous centroids as a labelled tensor
#                   (coordinate columns followed by a label column)
#
# Returns a map %{cluster_index => centroid}, where each centroid holds one
# value per coordinate column. A cluster with no members keeps its previous
# position.
#
# Fixes relative to the earlier version:
#   * The member count is taken from the per-point mask. It used to be summed
#     over the selector after tiling it across the coordinate columns, which
#     inflated the count by the number of columns and shrank every centroid
#     by that factor.
#   * For an empty cluster only the coordinate columns of the old centroid are
#     kept. The full row also carries the label, which made it wider than the
#     centroids computed in the other branch.
#   * Row and column counts come from `data` itself rather than from the
#     notebook-level n_points / n_dims.
calc_centroids_map = fn data, labels, old_centroids ->
  {n_rows, n_cols} = Nx.shape(data)

  Enum.reduce(0..(k - 1), %{}, fn el, acc ->
    # 1 where the point belongs to cluster `el`, 0 elsewhere
    membership = Nx.equal(labels, el)

    num_in_cluster = membership |> Nx.sum() |> Nx.to_scalar()

    centroid =
      if num_in_cluster == 0 do
        # nothing assigned: reuse the old position, dropping the label column
        old_centroids[el][0..(n_cols - 1)]
      else
        # widen the mask to one column per coordinate so it lines up with `data`
        selector =
          membership
          |> Nx.reshape({n_rows, 1})
          |> Nx.tile([1, n_cols])

        # zero out non-members, then sum each coordinate column
        # NOTE(review): the axis names are overwritten directly on the struct,
        # as before; presumably this keeps the :y name that later
        # `[y: ...]` lookups need -- confirm before removing.
        selector
        |> Nx.select(data, Nx.tensor([0]))
        |> Nx.sum(axes: [0])
        |> Map.put(:names, [:x, :y])
        |> Nx.divide(num_in_cluster)
      end

    Map.put(acc, el, centroid)
  end)
end

# First update step: move each centroid to the mean of the points that were
# assigned to it by the initial labelling.
new_centroids = calc_centroids_map.(data, new_labels, initial_centroids)

# Converts the %{cluster_index => centroid} map into one tensor with a row per
# cluster and the cluster index appended as the last column.
#
# Rows are fetched by explicit key 0..k-1 so that row i always belongs to
# cluster i, matching the labels produced by Nx.iota. Map.values/1 was used
# before, which depends on the map's internal ordering.
label_centroids = fn centroids ->
  rows = Enum.map(0..(k - 1), &Map.fetch!(centroids, &1))

  Nx.concatenate(
    [
      Nx.stack(rows),
      Nx.iota({k, 1})
    ],
    axis: 1
  )
end

new_centroids = label_centroids.(new_centroids)

# Re-assign every point to its nearest updated centroid.
new_labels = find_labels.(data, new_centroids[y: 0..(n_dims - 1)])
alg_labeled = Nx.concatenate([data, Nx.new_axis(new_labels, 1)], axis: 1)

Vl.new(title: "Result of First Iteration", width: 700, height: 700)
|> Vl.layers([
  mk_data_layer.(alg_labeled),
  mk_centroid_layer.(new_centroids)
])

Clustering - N Iterations

n_iters = 10

# Carry the state of the first iteration forward under shorter names.
centroids = new_centroids
labels = new_labels

# Repeat the update / re-label cycle a fixed number of times. Every pass
# recomputes the centroids from the current labels and then re-labels the
# points against the new centroids.
{final_centroids, final_labels} =
  Enum.reduce(1..n_iters, {centroids, labels}, fn _iteration, {prev_centroids, prev_labels} ->
    next_centroids =
      data
      |> calc_centroids_map.(prev_labels, prev_centroids)
      |> label_centroids.()

    next_labels = find_labels.(data, next_centroids[y: 0..(n_dims - 1)])

    {next_centroids, next_labels}
  end)

alg_labeled = Nx.concatenate([data, Nx.new_axis(final_labels, 1)], axis: 1)

# Same encoding as the data layer but with larger markers, so the true label
# shows as an outer ring around the algorithm's label.
true_series = [
  x: Nx.to_flat_list(labeled[y: 0]),
  y: Nx.to_flat_list(labeled[y: 1]),
  label: Nx.to_flat_list(labeled[y: 2])
]

true_labels_layer =
  Vl.new()
  |> Vl.data_from_series(true_series)
  |> Vl.mark(:point, size: 200)
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "X")
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "Y")
  |> Vl.encode_field(:color, "label", type: :nominal)

Vl.layers(
  Vl.new(title: "Result of N Iterations", width: 700, height: 700),
  [
    mk_data_layer.(alg_labeled),
    true_labels_layer,
    mk_centroid_layer.(final_centroids)
  ]
)
@dantswain
Copy link
Author

Thanks @polvalente ! This all makes sense. Re: (3), the outer circle is actually a "true" label since I generated the data at the outset from two distributions and the color corresponds to which distribution. It's a little contrived, but it was a helpful comparison for me to see if the algorithm was doing what I thought it should.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment