SAM for the BioEngine

Inference

Here are some files for converting SAM models to make them available in the BioEngine (local deployment: https://github.com/oeway/bioengine/).

The tree structure of the model repository looks like this:

├── sam-backbone
│   ├── 1
│   │   └── model.pt
│   └── config.pbtxt
└── sam-decoder
    ├── 1
    │   └── model.onnx
    └── config.pbtxt

(The folder named 1 is the version number; it contains the model file. There can be multiple versions, e.g. a folder named 2.)
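For example, this layout could be created like so (a minimal sketch; the repository path is illustrative, and the model files and configs come from the scripts below):

import shutil
from pathlib import Path

# create the Triton model repository layout shown above
repo = Path("model-repository")
for name in ["sam-backbone", "sam-decoder"]:
    (repo / name / "1").mkdir(parents=True, exist_ok=True)  # "1" is the version folder
# then place model.pt / model.onnx inside the version folders,
# e.g. shutil.copy("sam_vit_h_4b8939.pt", repo / "sam-backbone" / "1" / "model.pt"),
# and put a config.pbtxt next to each version folder (see the configs below)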

In the longer term, it would be nice to integrate the SAM model into https://github.com/bioimage-io/bioengine-model-runner and upload it to our model repository for the BioEngine.

To connect to the BioEngine and run the model, take a look at the file test_bioengine_sam.py.

You can also run it via https://jupyter.imjoy.io/ (make sure you run %pip install imjoy-rpc numpy first).
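For a quick start in the notebook, connecting and getting the triton client looks like this (a minimal sketch using the public server URL mentioned in the scripts below; notebooks allow top-level await):

from imjoy_rpc.hypha import connect_to_server

# connect to the BioEngine server and get the triton client service
server = await connect_to_server({"name": "test client", "server_url": "https://ai.imjoy.io"})
triton = await server.get_service("triton-client")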

Synchronous API

In test_bioengine_sam.py we use the async API of the BioEngine (it works both with JupyterLite in the browser and with native Python). If you run it in native Python, you can also make synchronous calls by replacing from imjoy_rpc.hypha import connect_to_server with from imjoy_rpc.hypha.sync import connect_to_server (you can then remove all the async/await keywords), as illustrated below. More about the imjoy-rpc API with the synchronous wrapper can be found here.
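A minimal sketch of the swap (only the import and the await keywords change):

SERVER_URL = "https://ai.imjoy.io"

# async variant (works in JupyterLite and native Python)
from imjoy_rpc.hypha import connect_to_server
server = await connect_to_server({"server_url": SERVER_URL})

# sync variant (native Python only): the same calls without async/await
from imjoy_rpc.hypha.sync import connect_to_server
server = connect_to_server({"server_url": SERVER_URL})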

The full code for native Python using the synchronous API is in test_bioengine_sam_sync.py.

You can find an example for napari here: https://github.com/bioimage-io/napari-bioimageio/blob/main/examples/bioengine-app-demo/bioengine_app_demo/_bioengine_app.py#L7C1-L7C58

Training

To support training, we would like to have a Python class following this interactive training interface.

Here is an example class for training cellpose.
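As a rough illustration only (this is not the actual interface; all method names below are hypothetical placeholders), such a class might be shaped like this:

import numpy as np

class InteractiveTrainer:
    # Hypothetical sketch of an interactive training class; the real
    # interface is defined by the link above and may differ.

    def __init__(self, model_dir):
        self.model_dir = model_dir  # where checkpoints would be stored
        self._training = False

    def start_training(self):
        # start a background loop that consumes newly annotated samples
        self._training = True

    def stop_training(self):
        self._training = False

    def add_annotation(self, image: np.ndarray, labels: np.ndarray):
        # queue an (image, labels) pair for the next training iteration
        raise NotImplementedError

    def predict(self, image: np.ndarray) -> np.ndarray:
        # run the current model on an image and return a mask
        raise NotImplementedError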

name: "sam-decoder"
backend: "onnxruntime"
platform: "onnxruntime_onnx"
parameters: {
key: "INFERENCE_MODE"
value: {
string_value: "true"
}
}
instance_group {
count: 1
kind: KIND_CPU
}
name: "sam-backbone"
backend: "pytorch"
platform: "pytorch_libtorch"
max_batch_size : 1
input [
{
name: "input0__0"
data_type: TYPE_FP32
dims: [3, -1, -1]
}
]
output [
{
name: "output0__0"
data_type: TYPE_FP32
dims: [256, 64, 64]
}
]
parameters: {
key: "INFERENCE_MODE"
value: {
string_value: "true"
}
}
Script for tracing the SAM image encoder and saving it as TorchScript (the resulting file goes into sam-backbone/1/model.pt):

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

# helper functions for visualizing masks, prompt points and boxes
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

from segment_anything import build_sam, SamPredictor

image = cv2.imread('733_D4_2_blue_red_green.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

predictor = SamPredictor(build_sam(checkpoint="sam_vit_h_4b8939.pth"))
model = predictor.model.image_encoder
# predictor.set_image(image)

# Switch the model to eval mode
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 1024, 1024)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("sam_vit_h_4b8939.pt")

# predictor.set_image(<your_image>)
# masks, _, _ = predictor.predict(<input_prompts>)

# Sanity check: reload the traced model and verify the embedding shape
# (the encoder expects 1024x1024 inputs, so use the same size here)
# model = torch.jit.load("sam_vit_h_4b8939.pt")
# example = torch.rand(1, 3, 1024, 1024)
# model.eval()
# result = model(example)
# # shape [1, 256, 64, 64]
# print(result.shape)
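The decoder's model.onnx for sam-decoder/1/ can be produced with the ONNX export script shipped with segment_anything (scripts/export_onnx_model.py); here is a sketch along those lines (checkpoint path and opset version are illustrative):

import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# dummy inputs matching the decoder's input signature
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 256, 256, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1024, 1024], dtype=torch.float),
}

with open("model.onnx", "wb") as f:
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        f,
        export_params=True,
        opset_version=17,
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        # the number of prompt points is dynamic
        dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    )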
test_bioengine_sam.py (async API):

from imjoy_rpc.hypha import connect_to_server
import numpy as np
import time

SERVER_URL = 'http://127.0.0.1:9520'  # "https://ai.imjoy.io"

async def test_backbone(triton):
    config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
    print(config)
    image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
        "float32"
    )
    start_time = time.time()
    result = await triton.execute(
        inputs=[image],
        model_name="micro-sam-vit-b-backbone",
    )
    print("Backbone", result)
    embedding = result['output0__0']
    print("Time taken: ", time.time() - start_time)
    print("Test passed", embedding.shape)

async def test_decoder(triton):
    start_time = time.time()
    config = await triton.get_config(model_name="micro-sam-vit-b-decoder")
    print("Decoder", config)
    # decoder inputs:
    # {'name': 'orig_im_size', 'dims': [2]}
    # {'name': 'has_mask_input', 'dims': [1]}
    # {'name': 'mask_input', 'dims': [1, 1, 256, 256]}
    # {'name': 'point_labels', 'dims': [1, -1]}
    # {'name': 'point_coords', 'dims': [1, -1, 2]}
    # {'name': 'image_embeddings', 'dims': [1, 256, 64, 64]}
    orig_im_size = np.array([1024, 1024], dtype=np.float32)
    has_mask_input = np.array([0], dtype=np.float32)
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    point_labels = np.array([[0, 1, 2]], dtype=np.float32)
    point_coords = np.array([[[100, 200], [300, 400], [500, 600]]], dtype=np.float32)
    image_embeddings = np.random.rand(1, 256, 64, 64).astype(np.float32)
    result = await triton.execute(
        inputs=[orig_im_size, has_mask_input, mask_input, point_labels, point_coords, image_embeddings],
        model_name="micro-sam-vit-b-decoder",
    )
    # the output keys are ['iou_predictions', 'low_res_masks', 'masks', '__info__']
    print(result)
    print("Time taken: ", time.time() - start_time)
    print("Test passed", result['masks'].shape)

async def run():
    server = await connect_to_server(
        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
    )
    triton = await server.get_service("triton-client")
    await test_backbone(triton)
    await test_decoder(triton)

if __name__ == "__main__":
    import asyncio
    asyncio.run(run())
test_bioengine_sam_sync.py (synchronous API):

from imjoy_rpc.hypha.sync import connect_to_server
import numpy as np
import time

SERVER_URL = 'http://127.0.0.1:9520'  # "https://ai.imjoy.io"

def test_backbone(triton):
    config = triton.get_config(model_name="micro-sam-vit-b-backbone")
    print(config)
    image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
        "float32"
    )
    start_time = time.time()
    result = triton.execute(
        inputs=[image],
        model_name="micro-sam-vit-b-backbone",
    )
    print("Backbone", result)
    embedding = result['output0__0']
    print("Time taken: ", time.time() - start_time)
    print("Test passed", embedding.shape)

def test_decoder(triton):
    start_time = time.time()
    config = triton.get_config(model_name="micro-sam-vit-b-decoder")
    print("Decoder", config)
    # decoder inputs:
    # {'name': 'orig_im_size', 'dims': [2]}
    # {'name': 'has_mask_input', 'dims': [1]}
    # {'name': 'mask_input', 'dims': [1, 1, 256, 256]}
    # {'name': 'point_labels', 'dims': [1, -1]}
    # {'name': 'point_coords', 'dims': [1, -1, 2]}
    # {'name': 'image_embeddings', 'dims': [1, 256, 64, 64]}
    orig_im_size = np.array([1024, 1024], dtype=np.float32)
    has_mask_input = np.array([0], dtype=np.float32)
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    point_labels = np.array([[0, 1, 2]], dtype=np.float32)
    point_coords = np.array([[[100, 200], [300, 400], [500, 600]]], dtype=np.float32)
    image_embeddings = np.random.rand(1, 256, 64, 64).astype(np.float32)
    result = triton.execute(
        inputs=[orig_im_size, has_mask_input, mask_input, point_labels, point_coords, image_embeddings],
        model_name="micro-sam-vit-b-decoder",
    )
    # the output keys are ['iou_predictions', 'low_res_masks', 'masks', '__info__']
    print(result)
    print("Time taken: ", time.time() - start_time)
    print("Test passed", result['masks'].shape)

def run():
    server = connect_to_server(
        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
    )
    triton = server.get_service("triton-client")
    test_backbone(triton)
    test_decoder(triton)

if __name__ == "__main__":
    run()
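To go beyond random inputs, the two models can be chained: feed the backbone's embedding into the decoder. Below is a sketch using the synchronous client from above (the helper name segment_image is hypothetical; it assumes the embedding comes back with shape (1, 256, 64, 64), matching the decoder config):

import numpy as np

def segment_image(triton, image, point_coords, point_labels):
    # image: float32 array of shape (1, 3, 1024, 1024)
    result = triton.execute(inputs=[image], model_name="micro-sam-vit-b-backbone")
    embedding = result['output0__0']  # expected shape (1, 256, 64, 64)
    result = triton.execute(
        inputs=[
            np.array([1024, 1024], dtype=np.float32),      # orig_im_size
            np.array([0], dtype=np.float32),               # has_mask_input
            np.zeros((1, 1, 256, 256), dtype=np.float32),  # mask_input (unused here)
            point_labels.astype(np.float32),               # shape (1, N)
            point_coords.astype(np.float32),               # shape (1, N, 2)
            embedding.astype(np.float32),
        ],
        model_name="micro-sam-vit-b-decoder",
    )
    return result['masks']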