Skip to content

Instantly share code, notes, and snippets.

@rrcook
Last active September 16, 2023 03:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rrcook/449c4e5b07a8cc03da3e041e6271741b to your computer and use it in GitHub Desktop.
Save rrcook/449c4e5b07a8cc03da3e041e6271741b to your computer and use it in GitHub Desktop.

Summarizing news headlines with CPU and GPU

Section

# Install the script's dependencies at runtime. In the visible code these are
# used for: HTTP download (req), feed parsing (elixir_feed_parser), article
# extraction (readability), and model loading/serving (bumblebee, nx, exla).
# axon is not referenced directly in the visible code.
Mix.install(
  [
    {:req, "~> 0.3.0"},
    {:elixir_feed_parser, "~> 2.1.0"},
    {:readability, "~> 0.10.0"},
    {:bumblebee, "~> 0.1.0"},
    {:axon, "~> 0.3"},
    {:exla, "~> 0.4"},
    {:nx, "~> 0.4"}
  ],
  # Environment variables applied for the install.
  # NOTE(review): "cuda111" presumably selects an XLA build for CUDA 11.1 —
  # confirm it matches the CUDA toolkit installed on the host.
  system_env: %{
    "XLA_TARGET" => "cuda111",
    # NOTE(review): this libdevice directory looks machine-specific — verify
    # the path exists on the machine running the script.
    "XLA_FLAGS" => "--xla_gpu_cuda_data_dir=/usr/lib/nvidia-cuda-toolkit/libdevice"
  }
)
# Download the RSS feed and parse it.
# (Alternative source that was tried: "https://news.google.com/rss".)
feed_url = "http://memeorandum.com/feed.xml"
response = Req.get!(feed_url)
myreq_body = response.body

# Print the raw response body so the feed markup can be inspected by eye.
IO.inspect(myreq_body)

# The match asserts a successful parse; any other result raises MatchError.
{:ok, feed} = ElixirFeedParser.parse(myreq_body)

# Captures the target of the first HREF attribute found in a string.
# The attribute name is matched case-sensitively (upper-case "HREF" only);
# the value may be quoted with ' or ", or left unquoted.
first_link_re = ~r/HREF=[\'"]?([^\'" >]+)/

# First link from the description of each of the first 8 feed entries.
#
# Regex.run/2 returns nil when nothing matches, and an entry may carry no
# description at all; such entries are skipped instead of crashing the whole
# pipeline, so the resulting list can hold fewer than 8 links.
article_links =
  feed.entries
  |> Enum.take(8)
  |> Enum.flat_map(fn entry ->
    with description when is_binary(description) <- entry.description,
         [_whole_match, link | _rest] <- Regex.run(first_link_re, description) do
      [link]
    else
      _no_link -> []
    end
  end)
# Pick the third extracted link (index 2) as the article to summarize.
# Enum.at/2 would silently yield nil for a short list and pass it on to
# Readability; fail here instead with a message that names the real problem.
art_link =
  case Enum.fetch(article_links, 2) do
    {:ok, link} ->
      link

    :error ->
      raise "expected at least 3 article links, got #{length(article_links)}"
  end

# Fetch the article and extract its summary; later steps read the `title` and
# `article_text` fields of the result.
summary = Readability.summarize(art_link)
# --- CPU run ---------------------------------------------------------------
# Make EXLA's :host client the default Nx backend for this process.
cpu_backend = {EXLA.Backend, client: :host}
Nx.default_backend(cpu_backend)
# Nx.global_default_backend(EXLA.Backend)

# Read the backend back (useful as a visible check when run interactively).
Nx.default_backend()

# Hugging Face repository providing both the model and its tokenizer.
model_name = "facebook/bart-large-cnn"
hf_repo = {:hf, model_name}

{:ok, model} = Bumblebee.load_model(hf_repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(hf_repo)

# Generation length bounds: a short output for the headline, a longer one
# for the article body.
title_opts = [min_length: 10, max_length: 20]
article_opts = [min_length: 200, max_length: 225]

title_serving = Bumblebee.Text.Generation.generation(model, tokenizer, title_opts)
article_serving = Bumblebee.Text.Generation.generation(model, tokenizer, article_opts)

# Run both servings on the extracted article.
small_title = title_serving |> Nx.Serving.run(summary.title)
small_text = article_serving |> Nx.Serving.run(summary.article_text)
# Now let's try the GPU
# Same steps as the CPU run, with EXLA's :cuda client as the default backend.
Nx.default_backend({EXLA.Backend, client: :cuda})

# Read the backend back (useful as a visible check when run interactively).
Nx.default_backend()

# Load the model and tokenizer again now that the backend has been switched.
g_repo = {:hf, model_name}
{:ok, g_model} = Bumblebee.load_model(g_repo)
{:ok, g_tokenizer} = Bumblebee.load_tokenizer(g_repo)

# Builds a generation serving for the GPU-loaded model with the given
# output length bounds.
build_g_serving = fn min_len, max_len ->
  Bumblebee.Text.Generation.generation(g_model, g_tokenizer,
    min_length: min_len,
    max_length: max_len
  )
end

g_title_serving = build_g_serving.(10, 20)
g_article_serving = build_g_serving.(200, 225)

# Summarize the same title and article text as in the CPU run.
g_small_title = g_title_serving |> Nx.Serving.run(summary.title)
g_small_text = g_article_serving |> Nx.Serving.run(summary.article_text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment