Mix.install([
{:ortex, "~> 0.1.9"},
{:image, "~> 0.37"},
{:nx_image, "~> 0.1.2"},
{:exla, "~> 0.7.2"},
{:kino, "~> 0.12.3"},
{:stb_image, "~> 0.6.8"}
])
This is my attempt at running the Segment Anything Model (SAM) from Facebook Research, using the Ortex library to run ONNX models.
Please note that this is my first time using Nx/Ortex/ONNX.
I'm more or less trying to port this Jupyter notebook example to Livebook: https://colab.research.google.com/drive/1wmjHHcrZ_s8iFuVFh9iHo6GbUS_xH5xq
I'm using the "mobile" ONNX models found here, as they run fast enough on my Mac.
The final mask is distorted. I don't know where this comes from — maybe the way I use tensor reshapes? SAM requires a 1024x1024 image as input; to keep things simple, for now I'm only giving it pre-resized images.
Note that I have tried running different ONNX model exports, and it did not change the problem.
# Use EXLA (XLA-backed) tensors as the default Nx backend for faster ops.
Nx.global_default_backend(EXLA.Backend)
# Evaluates to the current default backend so the cell output confirms it.
Nx.default_backend()
# SAM is shipped as two ONNX models: an image encoder that produces image
# embeddings, and a prompt decoder that turns embeddings + prompts into masks.
# NOTE(review): absolute local paths — adjust these for your machine.
model =
Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_sam_encoder.onnx")
decoder =
Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_decoder.onnx")
# Upload widget; Kino hands the image back as raw RGB bytes plus dimensions.
image_input = Kino.Input.image("Uploaded Image")
# https://ibb.co/Hr8588Y the image I'm using
# Code mostly copied from bumblebee examples
# Rebuild the raw bytes as an {height, width, 3} u8 (HWC) tensor.
%{file_ref: file_ref, format: :rgb, height: height, width: width} = Kino.Input.read(image_input)
content = file_ref |> Kino.Input.file_path() |> File.read!()
image_tensor =
Nx.from_binary(content, :u8)
|> Nx.reshape({height, width, 3})
# does not need resizing, image is already 1024x1024
# (SAM expects 1024x1024 input; the real resize call is kept below, commented out)
resized_tensor = image_tensor
# resized_tensor = NxImage.resize(image_tensor, {1024, 1024} )
# Kino display cells for the original and the (pass-through) resized image.
original_image = Kino.Image.new(image_tensor)
original_label = Kino.Markdown.new("**Original image**")
resized_image = Kino.Image.new(resized_tensor)
resized_label = Kino.Markdown.new("**Resized image**")
# Side-by-side preview of the original and "resized" images.
# There are only two cells here, so use columns: 2 — the original said
# columns: 3, which left a dangling empty column (the 3-cell grid at the
# bottom of the notebook is the one that needs 3).
Kino.Layout.grid(
  [
    Kino.Layout.grid([original_image, original_label], boxed: true),
    Kino.Layout.grid([resized_image, resized_label], boxed: true)
  ],
  columns: 2
)
# Per-channel RGB normalization constants (copied from transformers.js).
channel_mean = Nx.tensor([123.675, 116.28, 103.53])
channel_std = Nx.tensor([58.395, 57.12, 57.375])

# Cast the image to f32 and apply per-channel mean/std normalization.
normalized_tensor =
  resized_tensor
  |> Nx.as_type(:f32)
  |> NxImage.normalize(channel_mean, channel_std)

# Run the SAM image encoder. The broadcast to {1024, 1024, 3} is a shape
# guarantee for the ONNX graph (a no-op when the image is already that size);
# the encoder returns a one-element tuple holding the image embeddings.
{image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
# prepare inputs
# Box prompt: top-left and bottom-right (x, y) corners, in image coordinates,
# of the object we want to segment. Shape {1, 2, 2} = {batch, points, xy}.
input_point = Nx.tensor([[345, 272], [640, 760]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# input_point = Nx.tensor([[514, 514], [0, 0]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# Labels 2 and 3 mark the two points as a box's starting / ending corners.
input_label = Nx.tensor([2, 3]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)
# Zero-filled low-res mask input — unused, but the decoder requires the input.
mask_input = Nx.broadcast(0, {1, 1, 256, 256}) |> Nx.as_type(:f32)
# 0.0 flag telling the decoder to ignore mask_input.
has_mask = Nx.broadcast(0, 1) |> Nx.as_type(:f32)
# Original image size as {h, w}; the decoder scales its output mask to this.
original_image_dim = Nx.tensor([height, width]) |> Nx.as_type(:f32)
# Run the prompt decoder; outputs are (masks, iou_predictions, low_res_masks).
# NOTE(review): the Nx.broadcast calls below look like no-ops when the tensors
# already have these shapes — presumably here to pin the exact input shapes the
# ONNX graph expects. TODO confirm against the model's input signature.
{mask, _io, low_res} =
Ortex.run(decoder, {
Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
Nx.broadcast(input_point, {1, 2, 2}),
Nx.broadcast(input_label, {1, 2}),
Nx.broadcast(mask_input, {1, 1, 256, 256}),
Nx.broadcast(has_mask, {1}),
Nx.broadcast(original_image_dim, {2})
})
# Binarize the mask: logits >= 0 become 255 (white), everything else 0.
# The original used Nx.map/2 + Nx.to_number/1, which is deprecated and runs
# element-by-element on the BEAM (~1M elements for a 1024x1024 mask); the
# vectorized comparison below produces the same 0/255 values in one op.
# (Result dtype becomes u8 instead of f32, which is fine — the value is cast
# with Nx.as_type(:u8) before display anyway.)
mask =
  mask
  |> Nx.backend_transfer()
  |> Nx.greater_equal(0)
  |> Nx.multiply(255)
# Binarize the low-res mask the same way: logits >= 0 -> 255, else 0.
# Replaces the per-element Nx.map/Nx.to_number loop (deprecated, very slow)
# with a single vectorized comparison; downstream Nx.as_type(:u8) is
# unaffected by the dtype change (u8 instead of f32).
low_res_masks =
  low_res
  |> Nx.backend_transfer()
  |> Nx.greater_equal(0)
  |> Nx.multiply(255)
# Turn the decoder outputs into HWC u8 tensors Kino can display.
#
# NOTE(review): the original used Nx.reshape(mask[0][0], {1024, 1024, 1}).
# Reshape only reinterprets the buffer — if the decoder's mask spatial dims
# are not exactly 1024x1024 (they track original_image_dim / the model's
# internal padding), the pixels get re-wrapped at the wrong row width, which
# looks exactly like the "distorted / badly placed" high-res mask described
# above. Nx.new_axis/2 merely appends the channel axis, preserving the
# spatial layout, and NxImage.resize/2 then performs a real resize if the
# dims differ (and is a no-op when they are already 1024x1024).
mask =
  mask[0][0] |> Nx.as_type(:u8) |> Nx.new_axis(-1) |> NxImage.resize({1024, 1024})

# Low-res mask is {256, 256}; just add the channel axis.
mask_small = low_res_masks[0][0] |> Nx.as_type(:u8) |> Nx.new_axis(-1)
As you can see, the high-res mask is distorted / badly placed. However, the low-res mask fits correctly — why?
# Display cells for the two masks.
mask_image = Kino.Image.new(mask)
mask_label = Kino.Markdown.new("**Image mask**")
# Upscale the 256x256 low-res mask to 1024x1024 for a side-by-side comparison.
# NOTE(review): the default resize is bilinear, which blurs the binary mask's
# edges — consider method: :nearest to keep it strictly black and white.
low_res = NxImage.resize(mask_small, {1024, 1024})
low_res_image = Kino.Image.new(low_res)
low_res_label = Kino.Markdown.new("**Low res mask**")
# Final comparison: original image, full-res mask, and upscaled low-res mask.
panels = [
  Kino.Layout.grid([original_image, original_label], boxed: true),
  Kino.Layout.grid([mask_image, mask_label], boxed: true),
  Kino.Layout.grid([low_res_image, low_res_label], boxed: true)
]

Kino.Layout.grid(panels, columns: 3)