Skip to content

Instantly share code, notes, and snippets.

@SteffenDE
Created July 25, 2023 16:05
Show Gist options
  • Save SteffenDE/893a9f00b4b95a2d0df3331a665b67ba to your computer and use it in GitHub Desktop.
Save SteffenDE/893a9f00b4b95a2d0df3331a665b67ba to your computer and use it in GitHub Desktop.
Running the all-mpnet-base-v2 sentence transformer in Elixir using Ortex

Ortex MPNet Sentence Transformer

# Every dependency is pinned to an exact git ref so the notebook stays
# reproducible. `override: true` makes these refs win over whatever versions
# other dependencies (ortex, bumblebee) request for the same packages.
Mix.install([
  {:ortex, github: "elixir-nx/ortex", ref: "9e384971d1904ba91e5bfa49594d742a1d06cb4c"},
  {:tokenizers,
   github: "elixir-nx/tokenizers", override: true, ref: "20295cfdf9b6342d723b405481791ec87afa203c"},
  {:exla,
   github: "elixir-nx/nx",
   sparse: "exla",
   override: true,
   ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
  {:nx,
   github: "elixir-nx/nx",
   sparse: "nx",
   override: true,
   ref: "9a68cf06fef98a42f9a9c5a8d4745685a5b9fe64"},
  # Bumblebee is only needed for `Bumblebee.Utils.Nx.cosine_similarity/2`.
  {:bumblebee, github: "elixir-nx/bumblebee", ref: "8ec547243a4a1a61e45b25780a994014dc986099"},
  {:kino, "~> 0.10"}
])

# Run tensor operations and compile `defn` functions with EXLA on the host (CPU)
# client. The unused `alias VegaLite, as: Vl` was removed: VegaLite is not
# installed above and nothing in the notebook references it.
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)

Exporting the Model

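This section requires Python 3 with the `venv` module installed. It creates a virtual environment in a temporary directory, installs Hugging Face Optimum there, and uses `optimum-cli` to export the model to ONNX.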

See https://huggingface.co/docs/transformers/serialization?highlight=onnx.

# Scratch directory for the Python venv and the exported ONNX model.
# Path.join/2 inserts the separator that plain `<>` concatenation omitted:
# System.tmp_dir!/0 may return a path without a trailing slash (e.g. "/tmp").
tmp_dir = Path.join(System.tmp_dir!(), "livebook_ortex_mpnet")
# mkdir_p! succeeds if the directory already exists (re-running the cell) and
# raises on real failures instead of silently ignoring them.
File.mkdir_p!(tmp_dir)

venv_dir = Path.join(tmp_dir, ".venv")

# "Activate" the venv for child processes: putting its bin/ first on PATH makes
# `pip3` and `optimum-cli` resolve to the venv's copies.
venv_env = [
  {"VIRTUAL_ENV", venv_dir},
  {"PATH", "#{Path.join(venv_dir, "bin")}:#{System.get_env("PATH")}"}
]

# Each step streams its output into the notebook. Matching on exit status 0
# stops evaluation at the first failing step instead of continuing blindly.
{_, 0} = System.shell("python3 -m venv .venv", cd: tmp_dir, into: IO.binstream())

# Quoted so the shell does not treat `[exporters]` as a glob pattern.
{_, 0} =
  System.shell("pip3 install 'optimum[exporters]'",
    cd: tmp_dir,
    env: venv_env,
    into: IO.binstream()
  )

# Writes the exported model into `<tmp_dir>/mpnet/`.
{_, 0} =
  System.shell(
    "optimum-cli export onnx --model sentence-transformers/all-mpnet-base-v2 mpnet/",
    cd: tmp_dir,
    env: venv_env,
    into: IO.binstream()
  )

Creating the Serving

# Load the exported ONNX graph, then the tokenizer for the same Hugging Face model.
onnx_path = Path.join([tmp_dir, "mpnet", "model.onnx"])
model = Ortex.load(onnx_path)

hf_model_id = "sentence-transformers/all-mpnet-base-v2"
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained(hf_model_id)
defmodule SentenceTransformerOnnx do
  @moduledoc """
  Builds an `Nx.Serving` that turns a list of sentences into sentence embeddings
  by running an ONNX export of a sentence-transformers model through Ortex and
  mean-pooling the token embeddings.
  """

  import Nx.Defn

  @doc """
  Mean-pools per-token embeddings into one vector per batch entry, ignoring padding.

  `attention_mask` is 1 for real tokens and 0 for padding. It is expanded on a
  trailing axis and multiplied into `model_output`, so this presumably expects
  `model_output` as `{batch, seq, hidden}` and `attention_mask` as
  `{batch, seq}` (NOTE(review): layout of the ONNX output is assumed, not
  checked here).
  """
  defn mean_pooling(model_output, attention_mask) do
    input_mask_expanded = Nx.new_axis(attention_mask, -1)

    summed =
      model_output
      |> Nx.multiply(input_mask_expanded)
      |> Nx.sum(axes: [1])

    # Clamp the token count the same way sentence-transformers does
    # (`torch.clamp(..., min=1e-9)`), so an all-padding row yields zeros
    # instead of NaN from 0 / 0. Rows with at least one token are unaffected.
    token_counts =
      input_mask_expanded
      |> Nx.sum(axes: [1])
      |> Nx.max(1.0e-9)

    Nx.divide(summed, token_counts)
  end

  @doc """
  Returns an `Nx.Serving` that embeds a list of strings.

  Client preprocessing tokenizes the batch and stacks one
  `{input_ids, attention_mask}` pair per sentence for `Ortex.Serving`.
  Client postprocessing takes the model's single output and mean-pools it with
  the same attention mask, giving one embedding row per input sentence.
  """
  def serving(model, tokenizer) do
    Nx.Serving.new(Ortex.Serving, model)
    |> Nx.Serving.client_preprocessing(fn inputs ->
      {:ok, encodings} = Tokenizers.Tokenizer.encode_batch(tokenizer, inputs)

      input_ids = for encoding <- encodings, do: Tokenizers.Encoding.get_ids(encoding)
      input_mask = for encoding <- encodings, do: Tokenizers.Encoding.get_attention_mask(encoding)

      # NOTE(review): Nx.Batch.stack/1 and Nx.tensor/1 below require every
      # encoding to have the same length. This assumes the tokenizer's own
      # padding config pads the batch; confirm before reusing with other models.
      batch =
        input_ids
        |> Enum.zip_with(input_mask, fn ids, mask -> {Nx.tensor(ids), Nx.tensor(mask)} end)
        |> Nx.Batch.stack()

      # The mask travels as client_info so postprocessing can exclude padding
      # tokens from the mean.
      {batch, %{attention_mask: Nx.tensor(input_mask)}}
    end)
    |> Nx.Serving.client_postprocessing(fn {{output}, _meta}, client_info ->
      mean_pooling(output, client_info.attention_mask)
    end)
  end
end

Playground

A playground similar to the interactive widget on the model's Hugging Face page: https://huggingface.co/sentence-transformers/all-mpnet-base-v2

serving = SentenceTransformerOnnx.serving(model, tokenizer)

# Interactive inputs: a single source sentence, and one comparison sentence per line.
default_comparisons = """
That is a happy dog
That is a very happy person
Today is a sunny day
"""

input =
  "Source sentence:"
  |> Kino.Input.text(default: "That is a happy person")
  |> Kino.render()

comparison =
  "Sentences to compare to:"
  |> Kino.Input.textarea(default: default_comparisons)
  |> Kino.render()

# Read the source sentence. Trimming means a whitespace-only entry counts as
# empty and triggers the interrupt below.
input_text =
  input
  |> Kino.Input.read()
  |> String.trim()

if input_text == "" do
  Kino.interrupt!(:normal, "Please enter source sentence")
end

# One comparison sentence per line. Split on any line break (so "\r\n" does not
# leave a trailing "\r"), then drop blank or whitespace-only lines.
comparison_texts =
  comparison
  |> Kino.Input.read()
  |> String.split(~r/\R/)
  |> Enum.map(&String.trim/1)
  |> Enum.reject(&(&1 == ""))

if comparison_texts == [] do
  Kino.interrupt!(:normal, "Please enter comparison sentences (one per line)")
end

# Embed the source sentence (as a batch of one) and all comparison sentences.
# Distinct names avoid rebinding the `input` / `comparison` widget variables
# to tensors.
input_embedding = Nx.Serving.run(serving, [input_text])
comparison_embeddings = Nx.Serving.run(serving, comparison_texts)

sim = Bumblebee.Utils.Nx.cosine_similarity(input_embedding, comparison_embeddings)

# Row 0 holds the source sentence's similarity to each comparison sentence.
# Read it back in one transfer instead of one Nx.to_number/1 call per sentence.
Enum.zip(comparison_texts, Nx.to_flat_list(sim[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment