Skip to content

Instantly share code, notes, and snippets.

@polvalente
Created November 27, 2021 01:41
Show Gist options
  • Save polvalente/32c642c6a47fc75d65a381ea7a09f19b to your computer and use it in GitHub Desktop.
Save polvalente/32c642c6a47fc75d65a381ea7a09f19b to your computer and use it in GitHub Desktop.
K-means Livebook example

K-means

Initialization

Mix.install([
  {:torchx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "torchx"},
  {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
])
# Reads one whole number from standard input: prints `prompt`, strips the
# surrounding whitespace/newline from the reply and parses it as an integer
# (String.to_integer/1 raises ArgumentError on non-numeric input).
read_integer = fn prompt ->
  prompt
  |> IO.gets()
  |> String.trim()
  |> String.to_integer()
end

# Number of clusters, number of points and number of coordinates per point.
k = read_integer.("K (clusters)")
num_points = read_integer.("num_points")
dimensions = read_integer.("dimensions")

# Demo data set: nine 2-D points forming three well-separated groups.
# To use random data of the requested size instead, replace it with:
#   points = Nx.random_uniform({num_points, dimensions}) |> IO.inspect(label: "points")
points =
  [
    # group around (1, 1)
    [1, 1],
    [1.1, 1.1],
    [0.9, 0.9],
    # group around (10, 10)
    [10, 10],
    [11, 10],
    [10, 11],
    # group around (100, 100)
    [100, 100],
    [101, 100],
    [100, 101]
  ]
  |> Nx.tensor()

# Initial centroids: k consecutive rows of `points`, starting at a random row.
# Sizes are read from `points` itself (not from the typed-in num_points /
# dimensions) so the slice always stays in bounds of the actual data.
{point_count, point_dims} = Nx.shape(points)

if k < 1 or k > point_count do
  raise ArgumentError,
        "K (clusters) must be between 1 and #{point_count} (the number of points), got: #{k}"
end

# Every start row in 0..(point_count - k) leaves room for k rows; the range
# starts at 0 so the first row can be chosen, and it is never empty here.
slice_start = Enum.random(0..(point_count - k)//1)

centroids = Nx.slice(points, [slice_start, 0], [k, point_dims])

# For a reproducible run on the demo data (k = 3), seed one centroid per group:
# centroids = Nx.tensor([[1, 1], [10, 10], [100, 100]])
# Run algorithm
# Upper bound on the number of k-means iterations, read from standard input.
max_iter =
  "max_iter"
  |> IO.gets()
  |> String.trim()
  |> String.to_integer()

# Lloyd's k-means loop. Each iteration:
#   1. assigns every point to its nearest centroid (smallest squared
#      Euclidean distance), then
#   2. moves each centroid to the mean of the points assigned to it; a
#      centroid with no points keeps its previous position.
# The loop stops as soon as an iteration reproduces the previous assignments
# (the centroids are then a fixed point), or after max_iter iterations.
# The value is {assignments, centroids}: assignments is a {num_points} tensor
# of cluster indices (all -1 if max_iter is 0), centroids is {k, dimensions}.
#
# Sizes are taken from the tensors themselves so the loop always matches the
# actual `points` / `centroids`, even if they differ from the typed-in values.
{num_points, dimensions} = Nx.shape(points)
{k, ^dimensions} = Nx.shape(centroids)

# -1 is never a valid cluster index, so the first iteration cannot look converged.
initial = {Nx.broadcast(-1, {num_points}), centroids}

# `//1` keeps the range ascending: a bare `1..0` would iterate over [1, 0].
Enum.reduce_while(1..max_iter//1, initial, fn _, {previous, centroids} ->
  IO.inspect(centroids, label: "centroids")

  # Broadcasting {num_points, 1, dimensions} against {1, k, dimensions} gives
  # every difference p_i - c_j; summing the squares over the last axis yields a
  # {num_points, k} distance matrix. The square root is skipped because it does
  # not change which centroid is closest.
  diff =
    Nx.subtract(
      Nx.reshape(points, {num_points, 1, dimensions}),
      Nx.reshape(centroids, {1, k, dimensions})
    )

  assignments =
    diff
    |> Nx.multiply(diff)
    |> Nx.sum(axes: [2])
    |> Nx.argmin(axis: 1)

  new_centroids =
    0..(k - 1)//1
    |> Enum.map(fn i ->
      # {num_points, 1} 0/1 mask of the points in cluster i; it broadcasts
      # across the dimensions axis when multiplied with `points`.
      mask = assignments |> Nx.equal(i) |> Nx.reshape({num_points, 1})
      # Number of points in the cluster (one per row, not per coordinate).
      count = Nx.sum(mask)

      if Nx.to_scalar(count) == 0 do
        # Empty cluster: keep the previous centroid instead of dividing by zero.
        Nx.take(centroids, i)
      else
        points |> Nx.multiply(mask) |> Nx.sum(axes: [0]) |> Nx.divide(count)
      end
    end)
    |> Nx.stack()

  # Identical assignments imply new_centroids == centroids, i.e. convergence.
  converged =
    (assignments |> Nx.equal(previous) |> Nx.sum() |> Nx.to_scalar()) == num_points

  if converged do
    {:halt, {assignments, new_centroids}}
  else
    {:cont, {assignments, new_centroids}}
  end
end)
@dantswain
Copy link

@polvalente Thanks for the example! Here is my version: https://gist.github.com/dantswain/38f56db677db21d8335d8a29fc73c81b — I would love your feedback.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment