Segment Anything Model (SAM) using Ortex

Mix.install([
  {:ortex, "~> 0.1.9"},
  {:image, "~> 0.37"},
  {:nx_image, "~> 0.1.2"},
  {:exla, "~> 0.7.2"},
  {:kino, "~> 0.12.3"},
  {:stb_image, "~> 0.6.8"}
])

Recap

This is my attempt at running the Segment Anything Model (SAM) from Facebook Research using the Ortex library to execute ONNX models.

Please note that this is my first time using Nx/Ortex/ONNX.

I'm more or less trying to port this Jupyter notebook example to Livebook: https://colab.research.google.com/drive/1wmjHHcrZ_s8iFuVFh9iHo6GbUS_xH5xq

I'm using the "mobile" ONNX models found here, as they run fast enough on my Mac.

Issues

The final mask is distorted. I don't know where this comes from; maybe the way I use tensor reshapes? SAM requires a 1024x1024 image as input, so to keep things simple for now I'm only feeding it pre-resized images.

Note that I have tried running different ONNX model exports and it did not change the problem.
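
For what it's worth, the reference SAM preprocessing does not stretch the image to 1024x1024: it resizes the longest side to 1024, zero-pads the rest, and relies on the decoder's original-size input to map the mask back. A minimal sketch of that transform (my own helper, not from the notebook):

# resize the longest side to `target`, then zero-pad bottom/right to target x target
resize_longest_side = fn image_tensor, target ->
  {h, w, _c} = Nx.shape(image_tensor)
  scale = target / max(h, w)
  {new_h, new_w} = {round(h * scale), round(w * scale)}

  image_tensor
  |> NxImage.resize({new_h, new_w})
  |> Nx.pad(0, [{0, target - new_h, 0}, {0, target - new_w, 0}, {0, 0, 0}])
end

# usage: padded = resize_longest_side.(image_tensor, 1024)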

Nx.global_default_backend(EXLA.Backend)
Nx.default_backend()

Loading the models

# encoder: image -> embeddings; decoder: embeddings + prompts -> masks
model =
  Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_sam_encoder.onnx")

decoder =
  Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_decoder.onnx")

Getting the image

image_input = Kino.Input.image("Uploaded Image")
# https://ibb.co/Hr8588Y the image I'm using
# Code mostly copied from bumblebee examples

%{file_ref: file_ref, format: :rgb, height: height, width: width} = Kino.Input.read(image_input)

content = file_ref |> Kino.Input.file_path() |> File.read!()

image_tensor =
  Nx.from_binary(content, :u8)
  |> Nx.reshape({height, width, 3})

# does not need resizing, image is already 1024x1024
resized_tensor = image_tensor
# resized_tensor = NxImage.resize(image_tensor, {1024, 1024} )

original_image = Kino.Image.new(image_tensor)
original_label = Kino.Markdown.new("**Original image**")

resized_image = Kino.Image.new(resized_tensor)
resized_label = Kino.Markdown.new("**Resized image**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([original_image, original_label], boxed: true),
    Kino.Layout.grid([resized_image, resized_label], boxed: true)
  ],
  columns: 2
)
(Screenshot: original and resized 1024x1024 image side by side)
tensor =
  resized_tensor
  |> Nx.as_type(:f32)

# Mean and std values copied from transformers.js
mean = Nx.tensor([123.675, 116.28, 103.53])
std = Nx.tensor([58.395, 57.12, 57.375])

normalized_tensor =
  tensor
  |> NxImage.normalize(mean, std)
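
As a quick sanity check (my own addition, computed from the constants above), a pure white pixel should map to roughly (255 - mean) / std per channel:

# e.g. (255 - 123.675) / 58.395 ≈ 2.249
Nx.tensor([[[255.0, 255.0, 255.0]]])
|> NxImage.normalize(mean, std)
# => approximately [[[2.249, 2.428, 2.640]]]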

# Running image encoder
{image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
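
A quick shape check (my addition): the decoder call below broadcasts the embeddings to {1, 256, 64, 64}, so they should already come out of the encoder in that shape:

# expected: {1, 256, 64, 64} for SAM-style encoders
Nx.shape(image_embeddings)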

Prompt encoding & mask generation

# prepare inputs 

# xy box coordinates (top-left and bottom-right corners) of the object we want to cut out
input_point = Nx.tensor([[345, 272], [640, 760]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# input_point = Nx.tensor([[514, 514], [0, 0]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})

# labels 2 and 3 mark the box's starting / ending points
input_label = Nx.tensor([2, 3]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)

# Filled with 0, not used here
mask_input = Nx.broadcast(0, {1, 1, 256, 256}) |> Nx.as_type(:f32)

# not using mask_input
has_mask = Nx.broadcast(0, {1}) |> Nx.as_type(:f32)

original_image_dim = Nx.tensor([height, width]) |> Nx.as_type(:f32)
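
For reference (my own example, not part of the notebook), point prompts use different labels: 1 marks a foreground point, 0 a background point, and SAM's ONNX export pads with a dummy point labeled -1 when no box is given:

# Alternative prompt: one foreground point (label 1) plus the
# padding point (label -1) that takes the place of the box
# input_point = Nx.tensor([[[500.0, 375.0], [0.0, 0.0]]])
# input_label = Nx.tensor([[1.0, -1.0]])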

# decoder outputs: upscaled mask logits, IoU score predictions, and the
# 256x256 low-res mask logits (reusable as mask_input on a follow-up run)
{mask, _iou_predictions, low_res} =
  Ortex.run(decoder, {
    Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
    Nx.broadcast(input_point, {1, 2, 2}),
    Nx.broadcast(input_label, {1, 2}),
    Nx.broadcast(mask_input, {1, 1, 256, 256}),
    Nx.broadcast(has_mask, {1}),
    Nx.broadcast(original_image_dim, {2})
  })

Transform the output into a black and white image

# threshold the logits: values >= 0 become white (255), everything else black (0)
mask =
  mask
  |> Nx.backend_transfer()
  |> Nx.greater_equal(0)
  |> Nx.multiply(255)

low_res_masks =
  low_res
  |> Nx.backend_transfer()
  |> Nx.greater_equal(0)
  |> Nx.multiply(255)

# transform mask into an image shape
mask =
  mask[0][0] |> Nx.as_type(:u8) |> Nx.reshape({1024, 1024, 1}) |> NxImage.resize({1024, 1024})

mask_small = low_res_masks[0][0] |> Nx.as_type(:u8) |> Nx.reshape({256, 256, 1})

Showing the masks

As you can see, the high-res mask is distorted / badly placed. However, the low-res mask fits correctly??

mask_image = Kino.Image.new(mask)
mask_label = Kino.Markdown.new("**Image mask**")

low_res = NxImage.resize(mask_small, {1024, 1024})
low_res_image = Kino.Image.new(low_res)
low_res_label = Kino.Markdown.new("**Low res mask**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([original_image, original_label], boxed: true),
    Kino.Layout.grid([mask_image, mask_label], boxed: true),
    Kino.Layout.grid([low_res_image, low_res_label], boxed: true)
  ],
  columns: 3
)
(Screenshot: original image, distorted high-res mask, and correctly fitting low-res mask)
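
Once the mask lines up, it can be applied back to the image to cut the object out; a minimal sketch (my own addition, reusing resized_tensor and the thresholded mask from above):

# zero out pixels outside the mask ({1024, 1024, 3} * {1024, 1024, 1} broadcast)
cutout =
  resized_tensor
  |> Nx.multiply(Nx.divide(Nx.as_type(mask, :f32), 255))
  |> Nx.as_type(:u8)

Kino.Image.new(cutout)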