# Ortex Sentence Transformer

```elixir
Mix.install([
  {:ortex, github: "elixir-nx/ortex", ref: "9e384971d1904ba91e5bfa49594d742a1d06cb4c"},
  {:tokenizers,
   github: "elixir-nx/tokenizers", override: true, ref: "20295cfdf9b6342d723b405481791ec87afa203c"},
  {:exla,
   github: "elixir-nx/nx",
   sparse: "exla",
   override: true,
   ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
  {:nx,
   github: "elixir-nx/nx",
   sparse: "nx",
   override: true,
   ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
  {:bumblebee, github: "elixir-nx/bumblebee", ref: "8ec547243a4a1a61e45b25780a994014dc986099"},
  {:kino, "~> 0.10"},
  {:vega_lite, "~> 0.1.7"},
  {:kino_vega_lite, "~> 0.1.9"}
])
```

```elixir
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
```

```elixir
alias VegaLite, as: Vl
```

## Exporting the Model

This section requires Python with venv support to be installed.

See https://huggingface.co/docs/transformers/serialization?highlight=onnx.

```elixir
# Path.join ensures a proper path separator between the tmp dir and our folder
tmp_dir = Path.join(System.tmp_dir!(), "livebook_ortex_sentence_embeddings")
File.mkdir_p!(tmp_dir)

System.shell("python3 -m venv .venv", cd: tmp_dir, into: IO.binstream())

System.shell(
  "pip3 install optimum[exporters]",
  cd: tmp_dir,
  env: [
    {"VIRTUAL_ENV", Path.join([tmp_dir, ".venv"])},
    {"PATH", "#{Path.join([tmp_dir, ".venv", "bin"])}:#{System.get_env("PATH")}"}
  ],
  into: IO.binstream()
)

System.shell(
  "optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 minilm/",
  cd: tmp_dir,
  env: [
    {"VIRTUAL_ENV", Path.join([tmp_dir, ".venv"])},
    {"PATH", "#{Path.join([tmp_dir, ".venv", "bin"])}:#{System.get_env("PATH")}"}
  ],
  into: IO.binstream()
)
```
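
If the export succeeds, `minilm/` inside the temporary directory should contain the ONNX model alongside its config and tokenizer files. A quick sanity check (the exact file list depends on the optimum version):

```elixir
# List the exported files and fail early if the model itself is missing.
exported = File.ls!(Path.join(tmp_dir, "minilm"))
IO.inspect(exported, label: "exported files")

unless "model.onnx" in exported do
  raise "ONNX export failed: no model.onnx in #{Path.join(tmp_dir, "minilm")}"
end
```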

## Creating the Serving

```elixir
model = Ortex.load(Path.join([tmp_dir, "minilm", "model.onnx"]))
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
```

```elixir
defmodule SentenceTransformerOnnx do
  import Nx.Defn

  # Mean pooling: average the token embeddings, weighting by the attention
  # mask so that padding tokens do not contribute to the sentence embedding.
  defn mean_pooling(model_output, attention_mask) do
    input_mask_expanded = Nx.new_axis(attention_mask, -1)

    model_output
    |> Nx.multiply(input_mask_expanded)
    |> Nx.sum(axes: [1])
    |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
  end

  def serving(model, tokenizer) do
    Nx.Serving.new(Ortex.Serving, model)
    |> Nx.Serving.client_preprocessing(fn inputs ->
      {:ok, encodings} = Tokenizers.Tokenizer.encode_batch(tokenizer, inputs)

      # get the maximum sequence length from the input by looking at the attention mask
      max_length =
        encodings
        |> Enum.map(&Tokenizers.Encoding.get_attention_mask/1)
        |> Enum.map(&Enum.sum/1)
        |> Enum.max(fn -> nil end)

      encodings =
        if max_length do
          for e <- encodings, do: Tokenizers.Encoding.truncate(e, max_length)
        else
          encodings
        end

      input_ids = for i <- encodings, do: Tokenizers.Encoding.get_ids(i)
      input_mask = for i <- encodings, do: Tokenizers.Encoding.get_attention_mask(i)
      token_type_ids = for i <- encodings, do: Tokenizers.Encoding.get_type_ids(i)

      inputs =
        Enum.zip_with([input_ids, input_mask, token_type_ids], fn [a, b, c] ->
          {Nx.tensor(a), Nx.tensor(b), Nx.tensor(c)}
        end)
        |> Nx.Batch.stack()

      {inputs, %{attention_mask: Nx.tensor(input_mask)}}
    end)
    |> Nx.Serving.client_postprocessing(fn {{output}, _meta}, client_info ->
      mean_pooling(output, client_info.attention_mask)
    end)
  end
end
```

```elixir
serving = SentenceTransformerOnnx.serving(model, tokenizer)
```
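
As a quick smoke test (assuming the export and serving setup above succeeded): all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, so two inputs should yield a `{2, 384}` tensor.

```elixir
# One pooled embedding per input sentence.
embeddings = Nx.Serving.run(serving, ["Hello world", "Good morning"])
Nx.shape(embeddings)
```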

```elixir
input = Kino.Input.text("Source sentence:") |> Kino.render()
comparison = Kino.Input.textarea("Sentences to compare to:") |> Kino.render()
```

```elixir
input_text = Kino.Input.read(input)

if byte_size(input_text) == 0 do
  Kino.interrupt!(:normal, "Please enter a source sentence")
end

comparison_texts =
  Kino.Input.read(comparison)
  |> String.split("\n", trim: true)

if comparison_texts == [] do
  Kino.interrupt!(:normal, "Please enter comparison sentences (one per line)")
end

# bind the embeddings to new names to avoid shadowing the Kino inputs above
input_embedding = Nx.Serving.run(serving, [input_text])
comparison_embeddings = Nx.Serving.run(serving, comparison_texts)

sim = Bumblebee.Utils.Nx.cosine_similarity(input_embedding, comparison_embeddings)

for {v, i} <- Enum.with_index(comparison_texts) do
  {v, Nx.to_number(sim[0][i])}
end
|> Enum.sort_by(&elem(&1, 1), :desc)
```
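
`Bumblebee.Utils.Nx` is an internal Bumblebee module, so `cosine_similarity/2` may change or move between releases. If you prefer not to rely on it, the same scores can be computed with plain Nx, roughly like this (a sketch, not Bumblebee's exact implementation):

```elixir
defmodule Similarity do
  import Nx.Defn

  # Cosine similarity between every row of `x` and every row of `y`:
  # normalize both to unit length, then take pairwise dot products.
  defn cosine(x, y) do
    x = x / Nx.LinAlg.norm(x, axes: [-1], keep_axes: true)
    y = y / Nx.LinAlg.norm(y, axes: [-1], keep_axes: true)
    Nx.dot(x, [1], y, [1])
  end
end
```

`Similarity.cosine(input_embedding, comparison_embeddings)` should then agree with `sim` above up to floating-point error.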

## Benchmarking

```elixir
defmodule ConcurrentBench do
  def run(fun, concurrency \\ System.schedulers_online(), timeout \\ 10_000) do
    # use an erlang counter to count the number of function invocations
    counter = :counters.new(1, [:write_concurrency])

    # :timer.tc returns the elapsed wall-clock time in microseconds
    {taken, _} =
      :timer.tc(fn ->
        tasks =
          for _i <- 1..concurrency do
            Task.async(fn ->
              Stream.repeatedly(fn ->
                fun.()
                # only count after the function ran successfully
                :counters.add(counter, 1, 1)
              end)
              |> Stream.run()
            end)
          end

        results = Task.yield_many(tasks, timeout)

        # brutally kill any tasks that are still running after the timeout
        Enum.each(results, fn {task, res} ->
          res || Task.shutdown(task, :brutal_kill)
        end)
      end)

    runs = :counters.get(counter, 1)
    ips = runs / (taken / 1_000_000)

    %{runs: runs, ips: ips}
  end
end
```
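
To see what the harness reports before wiring it up to the serving, here is a trivial usage sketch (the workload and numbers are purely illustrative):

```elixir
# Run a no-op workload on 4 processes for one second; returns a map with
# the total number of runs and iterations per second (ips).
ConcurrentBench.run(fn -> :ok end, 4, 1_000)
```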

```elixir
text =
  "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr"

words = String.split(text, " ", trim: true)

# build prefixes of increasing word count to vary the sequence length
texts =
  for i <- 1..length(words) do
    words
    |> Enum.take(i)
    |> Enum.join(" ")
  end
```

```elixir
sequence_lengths =
  Enum.map(texts, fn text ->
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)
    Tokenizers.Encoding.get_attention_mask(encoding) |> Enum.sum()
  end)
```

```elixir
defmodule BenchTest do
  def run(serving, text, batch_size, batch_timeout, concurrency, timeout \\ 10_000) do
    {:ok, pid} =
      Kino.start_child(
        {Nx.Serving,
         serving: serving, name: MyServing, batch_size: batch_size, batch_timeout: batch_timeout}
      )

    # define a uniquely named throwaway module for this run;
    # it is purged again once the benchmark finishes
    mod = Module.concat(Bench, "Test#{System.unique_integer()}")

    defmodule mod do
      def run(text, concurrency, timeout) do
        ConcurrentBench.run(
          fn ->
            Nx.Serving.batched_run(MyServing, [text])
          end,
          concurrency,
          timeout
        )
      end
    end

    result =
      mod.run(text, concurrency, timeout)
      |> tap(fn _ ->
        :code.purge(mod)
        :code.delete(mod)
      end)
      |> IO.inspect(
        label:
          "batch_size: #{batch_size}, concurrency: #{concurrency}, batch_timeout: #{batch_timeout}"
      )

    Kino.terminate_child(pid)

    result
  end
end
```

```elixir
chart =
  Vl.new(width: 1280, height: 720)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "sequence_length", type: :quantitative)
  |> Vl.encode_field(:y, "ips", type: :quantitative)
  |> Kino.VegaLite.new()
```

```elixir
for {text, sequence_length} <- Enum.zip(texts, sequence_lengths) do
  batch_size = 64
  batch_timeout = 50
  concurrency = 64

  %{ips: ips} = BenchTest.run(serving, text, batch_size, batch_timeout, concurrency, 5_000)

  IO.puts("sequence_length: #{sequence_length}, ips: #{ips}")

  Kino.VegaLite.push(chart, %{ips: ips, sequence_length: sequence_length})
end
```