Skip to content

Instantly share code, notes, and snippets.

@Nanguage
Last active July 20, 2023 09:40
Show Gist options
  • Save Nanguage/27a003bec90ee8fdd29859cb446c608d to your computer and use it in GitHub Desktop.
Save Nanguage/27a003bec90ee8fdd29859cb446c608d to your computer and use it in GitHub Desktop.
<config lang="json">
{
"name": "mi-core",
"type": "web-python",
"tags": [],
"flags": [],
"ui": "",
"version": "0.1.0",
"cover": "",
"description": "Connect to the bioengine server, and execute operations.",
"icon": "extension",
"inputs": null,
"outputs": null,
"api_version": "0.1.8",
"env": "",
"permissions": [],
"requirements": ["imageio", "xarray"],
"dependencies": []
}
</config>
<script lang="python">
import io
import imageio
import numpy as np
from xarray import DataArray
from imjoy import api
from imjoy_rpc.hypha import connect_to_server
# Record whether the ``pyodide`` package can be imported; the flag is used
# later to choose between the browser ``fetch`` and ``requests``.
try:
    import pyodide
except ImportError:
    is_pyodide = False
else:
    is_pyodide = True
async def fetch_file_content(url) -> bytes:
    """Fetch file content from url, return bytes.

    Compatible with both pyodide and native-python.

    Reference:
    https://github.com/imjoy-team/kaibu-utils/blob/ecc25337adb0c94e6345f09bba80aa0637ce9af0/kaibu_utils/__init__.py#L403-L425

    Args:
        url: Address of the file to download.

    Returns:
        The downloaded content as a bytes-like object.
    """
    await api.log("Fetch URL: " + url)
    if is_pyodide:
        # In pyodide, go through the JavaScript ``fetch`` function.
        from js import fetch
        response = await fetch(url)
        buffer = await response.arrayBuffer()
        bytes_ = buffer.to_py()
    else:
        import requests
        # ``requests.get`` returns a Response object; the payload is in
        # ``.content``. Returning the Response itself would break ``len()``
        # below and every caller that expects bytes.
        bytes_ = requests.get(url).content
    await api.log("Fetched bytes: " + str(len(bytes_)))
    return bytes_
def is_channel_first(shape):
    """Guess whether the channel axis precedes the other axes.

    The position of the smallest dimension decides the answer. For a
    5-dimensional shape the first entry is treated as a batch dimension and
    is ignored.

    Args:
        shape: The shape of the image.

    Returns:
        False only when the smallest dimension is the last one (and not also
        the first); True when it is the first one, and also True as a guess
        when it sits somewhere in the middle.
    """
    dims = list(shape[1:]) if len(shape) == 5 else list(shape)
    smallest = np.argmin(dims)
    # "Channel last" is the only situation that yields False; the ambiguous
    # middle positions fall back to channel first.
    channel_last = smallest != 0 and smallest == len(dims) - 1
    return not channel_last
def get_default_image_axes(shape, input_tensor_axes):
    """Guess the axes string of an image from its shape.

    Args:
        shape: The shape of the image (2 to 5 dimensions).
        input_tensor_axes: The axes of the model input tensor; only used to
            check whether the model expects a "z" axis.

    Returns:
        An axes string such as "yx", "cyx", "zyxc" or "bczyx".

    Raises:
        ValueError: If the image has fewer than 2 or more than 5 dimensions.
    """
    ndim = len(shape)
    has_z_axis = "z" in input_tensor_axes
    if ndim == 2:
        axes = "yx"
    elif ndim == 3 and has_z_axis:
        axes = "zyx"
    elif ndim == 3:
        axes = "cyx" if is_channel_first(shape) else "yxc"
    elif ndim == 4 and has_z_axis:
        axes = "czyx" if is_channel_first(shape) else "zyxc"
    elif ndim == 4:
        # Both candidate layouts start with the batch axis, so leave it out
        # of the channel guess. Otherwise a batch of size 1, e.g.
        # (1, H, W, 3), is always the smallest leading dimension and the
        # image is misreported as channel first.
        axes = "bcyx" if is_channel_first(shape[1:]) else "byxc"
    elif ndim == 5:
        axes = "bczyx" if is_channel_first(shape) else "bzyxc"
    else:
        raise ValueError(f"Invalid number of image dimensions: {ndim}")
    return axes
def map_axes(
    input_array,
    input_axes,
    output_axes,
    # spatial axes are cut at their middle coordinate, all other axes (channel or batch) at coordinate 0
    drop_function=lambda ax_name, ax_len: ax_len // 2 if ax_name in "zyx" else 0
):
    """Rearrange ``input_array`` from ``input_axes`` to ``output_axes``.

    Axes that only exist in the input are removed by selecting one coordinate
    (chosen by ``drop_function``), axes that only exist in the output are
    added, and the result is transposed into the order of ``output_axes``.

    Args:
        input_array: The array to convert.
        input_axes: One axis name per dimension of ``input_array``.
        output_axes: The axis names, in order, of the returned array.
        drop_function: Called with (axis name, axis length); returns the
            coordinate to keep on an axis that gets dropped.

    Returns:
        The values of the rearranged array.
    """
    assert len(input_axes) == input_array.ndim, f"Number of axes {len(input_axes)} and dimension of input {input_array.ndim} don't match"
    axis_lengths = dict(zip(input_axes, input_array.shape))
    labelled = DataArray(input_array, dims=tuple(input_axes))

    # Select a single coordinate on every axis the output does not have.
    surplus = tuple(set(input_axes) - set(output_axes))
    selection = {name: drop_function(name, axis_lengths[name]) for name in surplus}
    labelled = labelled[selection]

    # Add the axes the input does not have.
    absent = tuple(set(output_axes) - set(input_axes))
    labelled = labelled.expand_dims(dim=absent)

    # Bring the axes into the requested order and return the raw values.
    return labelled.transpose(*tuple(output_axes)).values
def transform_input(image: np.ndarray, image_axes: str, output_axes: str):
    """Transform the input image into a tensor laid out as ``output_axes``.

    Args:
        image: The input image.
        image_axes: The axes of the input image, given as a plain string.
        output_axes: The axes of the tensor that is returned.

    Returns:
        The result of ``map_axes`` for the given image and axes.
    """
    tensor = map_axes(image, image_axes, output_axes)
    return tensor
class Plugin():
    """Core plugin: keeps the current image and runs models via the server.

    The instance stores one image (``self.image``), can display it in a
    viewer window and sends it to the "triton-client" service for model
    execution.
    """

    async def setup(self):
        """Connect to the server and look up the "triton-client" service."""
        # Initialise the state first so the attributes exist even when the
        # connection below fails.
        self.image = None
        self.vtk_viewer = None
        await api.log("Connector plugin is ready.")
        server = await connect_to_server(
            {"name": "client", "server_url": "https://ai.imjoy.io", "method_timeout": 30}
        )
        self.triton = await server.get_service("triton-client")

    def set_current_image(self, image):
        """Store ``image`` (a numpy array) as the current image."""
        assert isinstance(image, np.ndarray)
        self.image = image

    async def get_current_image_shape(self):
        """Return the shape of the current image; an image must be loaded."""
        assert self.image is not None
        return self.image.shape

    async def show_image_vtk(self):
        """Open a viewer window and show the current image in it."""
        if self.image is None:
            await api.alert("Please load an image first.")
            return
        self.vtk_viewer = await api.createWindow(
            src="https://oeway.github.io/itk-vtk-viewer/",
            fullscreen=False
        )
        # Await the call like every other remote call in this class, so a
        # failure surfaces here instead of being dropped.
        await self.vtk_viewer.setImage(self.image)

    async def bioengine_execute(self, model_id, inputs=None, return_rdf=False, weight_format=None):
        """Call the "bioengine-model-runner" model on the triton service.

        Args:
            model_id: Identifier of the model to run.
            inputs: Input tensors for the model, or None.
            return_rdf: Passed through to the model runner.
            weight_format: Passed through to the model runner.

        Returns:
            The "result" entry of the service response.
        """
        kwargs = {"model_id": model_id, "inputs": inputs, "return_rdf": return_rdf, "weight_format": weight_format}
        ret = await self.triton.execute(
            inputs=[kwargs],
            model_name="bioengine-model-runner",
            serialization="imjoy"
        )
        return ret["result"]

    async def get_model_rdf(self, model_id):
        """Return the "rdf" entry of the model runner result for ``model_id``."""
        ret = await self.bioengine_execute(model_id, return_rdf=True)
        return ret["rdf"]

    async def load_image_from_bytes(self, file_name, img_bytes):
        """Decode ``img_bytes`` with imageio and make it the current image.

        Args:
            file_name: Name of the file; its suffix selects the reader.
            img_bytes: The encoded file content.
        """
        _file = io.BytesIO(img_bytes)
        _file.name = file_name
        # Compare the suffix case-insensitively so "IMAGE.TIF" is also read
        # as a volume.
        if file_name.lower().endswith((".tif", ".tiff")):
            image = imageio.volread(_file)
        else:
            image = imageio.imread(_file)
        await api.log(
            "Image loaded with shape: " + str(image.shape) +
            " and dtype: " + str(image.dtype)
        )
        self.set_current_image(image)

    async def load_image_from_url(self, url):
        """Download the file at ``url`` and make it the current image."""
        # The file name is the last path segment, without the query string.
        file_name = url.split("?")[0].rstrip('/').split("/")[-1]
        await api.log(file_name)
        bytes_ = await fetch_file_content(url)
        await self.load_image_from_bytes(file_name, bytes_)

    async def run_model(
            self, model_id, rdf,
            image_axes=None, weight_format=None):
        """Run a model on the current image and store the output as new image.

        Args:
            model_id: Identifier of the model to run.
            rdf: The model description; its first input and first output
                provide the tensor axes.
            image_axes: Axes of the current image. Guessed from the image
                shape when None.
            weight_format: Passed through to the model runner.

        Returns:
            True on success, False when no image is loaded, the axes do not
            fit the image, or the execution failed.
        """
        if self.image is None:
            await api.alert("Please load an image first.")
            return False
        img = self.image
        input_spec = rdf['inputs'][0]
        input_tensor_axes = input_spec['axes']
        await api.log("input_tensor_axes", input_tensor_axes)
        if image_axes is None:
            shape = img.shape
            image_axes = get_default_image_axes(shape, input_tensor_axes)
            await api.log(f"Image axes were not provided. They were automatically determined to be {image_axes}")
        else:
            await api.log(f"Image axes were provided as {image_axes}")
            # The axes come from user input: report a mismatch through the
            # same alert + False path as the other failures instead of an
            # assertion.
            if len(image_axes) != img.ndim:
                await api.alert(
                    f"The provided image axes ({image_axes}) do not match "
                    f"the number of image dimensions ({img.ndim})."
                )
                return False
        await api.log("Transforming input image...")
        img = transform_input(img, image_axes, input_tensor_axes)
        await api.log(f"Input image was transformed into shape {img.shape} to fit the model")
        await api.log("Data loaded, running model...")
        try:
            result = await self.bioengine_execute(
                model_id, inputs=[img], weight_format=weight_format)
        except Exception as exp:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {exp}")
            return False
        if not result['success']:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {result['error']}")
            return False
        output = result['outputs'][0]
        await api.showMessage(f"🎉Model execution completed! Got output tensor of shape {output.shape}")
        # Map the output back to the axes of the image that was submitted.
        output_tensor_axes = rdf['outputs'][0]['axes']
        transformed_output = map_axes(output, output_tensor_axes, image_axes)
        self.set_current_image(transformed_output)
        return True

    async def run(self, ctx):
        """Nothing to do when the plugin is run."""
        pass
# Create the plugin instance and hand it to ImJoy through api.export.
api.export(Plugin())
</script>
<docs lang="markdown">
A panel for uploading an image file.
</docs>
<config lang="json">
{
"name": "mi-panel",
"type": "window",
"tags": [],
"ui": "",
"version": "0.1.0",
"cover": "",
"description": "A panel for loading an image and running a bioengine model on it.",
"icon": "extension",
"inputs": null,
"outputs": null,
"api_version": "0.1.8",
"env": "",
"permissions": [],
"requirements": [
"https://cdn.tailwindcss.com",
"https://cdn.jsdelivr.net/npm/imjoy-rpc@0.5.6/dist/hypha-rpc-websocket.min.js"
],
"dependencies": [
"https://gist.githubusercontent.com/Nanguage/27a003bec90ee8fdd29859cb446c608d/raw/d8c8c44405c9b1e3b4c90a44fa369a89e5591ba3/mi_core.imjoy.html"
]
}
</config>
<script lang="javascript">
/**
 * Window plugin: lets the user load an image (file, URL or model sample),
 * submits it to the "mi-core" plugin and shows the result.
 */
class ImJoyPlugin {
  /** Get the core plugin and wire up the UI event handlers. */
  async setup() {
    console.log("setup")
    this.core_plugin = await api.getPlugin('mi-core')
    this.setupEvents()
  }

  /** Register the handlers of all inputs and buttons of the window. */
  setupEvents() {
    const setInfoPanel = this.setInfoPanel.bind(this)
    this.listenFileSelection(setInfoPanel)
    this.listenURLInput(setInfoPanel)
    this.listenSampleInputButton(setInfoPanel)
    this.listenSampleOutputButton(setInfoPanel)
    this.listenSubmitButton(setInfoPanel)
  }

  /** Load and show the file picked in the file input. */
  listenFileSelection(setInfoPanel) {
    const fileInput = document.getElementById("file-input")
    const fileChosen = document.getElementById('file-chosen')
    const viewImageByBytes = this.viewImageByBytes.bind(this)
    fileInput.addEventListener("input", async () => {
      const file = fileInput.files[0]
      // Check for a file before reading file.name, which would throw on
      // undefined.
      if (!file) {
        await api.alert("No file selected")
        setInfoPanel("")
        return
      }
      setInfoPanel("Loading file...")
      fileChosen.textContent = file.name
      const reader = new FileReader()
      reader.onload = async function() {
        const content = this.result
        try {
          await viewImageByBytes(file.name, content)
        } finally {
          // set to empty so that the same file can be loaded again
          fileInput.value = ''
          // clear the message even when loading failed
          setInfoPanel("")
        }
      }
      reader.readAsArrayBuffer(file)
    })
  }

  /** Load and show the image behind the URL typed into the URL input. */
  listenURLInput(setInfoPanel) {
    const viewImageByURL = this.viewImageByURL.bind(this)
    const urlInput = document.getElementById("url-input")
    const urlLoadButton = document.getElementById("url-load-button")
    urlLoadButton.addEventListener("click", async () => {
      const url = urlInput.value
      if (!url) {
        await api.alert("No URL provided")
        return
      }
      setInfoPanel("Loading image...")
      try {
        await viewImageByURL(url)
      } finally {
        setInfoPanel("")
      }
    })
  }

  /** Load and show the first sample input of the current model. */
  listenSampleInputButton(setInfoPanel) {
    const sampleInputButton = document.getElementById("sample-input-button")
    sampleInputButton.addEventListener("click", async () => {
      // model_rdf is only set by run(); the list may also be missing/empty.
      const samples = this.model_rdf && this.model_rdf.sample_inputs
      if (!samples || samples.length === 0) {
        await api.alert("No sample input available.")
        return
      }
      setInfoPanel("Loading sample input...")
      try {
        await this.viewImageByURL(samples[0])
      } finally {
        setInfoPanel("")
      }
    })
  }

  /** Load and show the first sample output of the current model. */
  listenSampleOutputButton(setInfoPanel) {
    const sampleOutputButton = document.getElementById("sample-output-button")
    sampleOutputButton.addEventListener("click", async () => {
      // model_rdf is only set by run(); the list may also be missing/empty.
      const samples = this.model_rdf && this.model_rdf.sample_outputs
      if (!samples || samples.length === 0) {
        await api.alert("No sample output available.")
        return
      }
      setInfoPanel("Loading sample output...")
      try {
        await this.viewImageByURL(samples[0])
      } finally {
        setInfoPanel("")
      }
    })
  }

  /** Run the current model on the current image and show the result. */
  listenSubmitButton(setInfoPanel) {
    const submitButton = document.getElementById("submit-button")
    const core = this.core_plugin
    const self = this
    const inputAxes = document.getElementById("input-axes-input")
    submitButton.addEventListener("click", async () => {
      if (!self.model_rdf) {
        await api.alert("The model is not loaded yet.")
        return
      }
      const input_axes = inputAxes.value || null
      console.log("Input axes:", input_axes)
      setInfoPanel("Running model...")
      let ret
      try {
        ret = await core.run_model(
          self.model_id, self.model_rdf, input_axes
        )
      } catch (err) {
        // Treat a rejected call like a reported failure so the panel does
        // not stay on "Running model...".
        console.error(err)
        ret = false
      }
      if (ret === false) {
        setInfoPanel("Failed to run the model.")
      } else {
        await self.updateImageShapePanel()
        await core.show_image_vtk()
        setInfoPanel("")
      }
    })
  }

  /** Send file content to the core plugin, then refresh shape and viewer. */
  async viewImageByBytes(fileName, bytes) {
    console.log(bytes)
    this.bytes_data = bytes
    await this.core_plugin.load_image_from_bytes(fileName, bytes)
    await this.updateImageShapePanel()
    await this.core_plugin.show_image_vtk()
  }

  /** Let the core plugin load a URL, then refresh shape and viewer. */
  async viewImageByURL(url) {
    await this.core_plugin.load_image_from_url(url)
    await this.updateImageShapePanel()
    await this.core_plugin.show_image_vtk()
  }

  /** Pre-fill the URL input. */
  set_default_url(sampleInput) {
    const urlInput = document.getElementById("url-input")
    urlInput.value = sampleInput
  }

  /** Show a status message in the info panel. */
  setInfoPanel(content) {
    const info = document.getElementById("info-panel")
    info.textContent = content
  }

  /** Show the shape of the core plugin's current image. */
  async updateImageShapePanel() {
    const shape = await this.core_plugin.get_current_image_shape()
    const imgShapePanel = document.getElementById("img-shape-panel")
    imgShapePanel.textContent = `Current image shape: ${shape}`
  }

  /**
   * Select the model (ctx.data.id, or a default one), fetch its RDF and
   * pre-fill the URL input with the first sample input.
   */
  async run(ctx) {
    console.log("run", ctx)
    const defaultModel = "10.5281/zenodo.5874741"
    // ctx or ctx.data may be missing when the plugin is run without data.
    const model_id = (ctx && ctx.data && ctx.data.id) || defaultModel
    this.model_id = model_id
    // Load the model RDF
    const rdf = await this.core_plugin.get_model_rdf(model_id)
    console.log(rdf)
    this.model_rdf = rdf
    const sampleInputs = rdf.sample_inputs || []
    // Only pre-fill when a sample exists; assigning undefined would put the
    // text "undefined" into the input.
    if (sampleInputs.length > 0) {
      this.set_default_url(sampleInputs[0])
    }
  }
}
// Create the panel plugin instance and hand it to ImJoy through api.export.
api.export(new ImJoyPlugin())
</script>
<window lang="html">
<div class="m-1">
<p class="mt-2">Load a PNG, JPG or TIFF file:</p>
<div class="flex">
<span id="file-chosen" class="w-1/2 px-4 py-2 border border-gray-300 rounded-l-md text-gray-500">No file selected.</span>
<label for="file-input" class="py-2 px-3 bg-blue-500 text-white rounded-r-md cursor-pointer w-20 text-center hover:bg-blue-600">
Browse
</label>
<input type="file" id="file-input" class="hidden" />
</div>
<p class="mt-2">Or load from a URL. URLs to OME-Zarr are supported:</p>
<div class="flex">
<input type="text" class="w-1/2 px-4 py-2 border border-gray-300 rounded-l-md text-gray-500"
placeholder="Target URL"
id="url-input" />
<button
class="px-4 py-2 bg-blue-500 text-white rounded-r-md w-20 hover:bg-blue-600"
id="url-load-button">Load</button>
</div>
<p class="mt-2">Input axes (optional, for example yxzc):</p>
<input type="text" class="w-2/3 px-4 py-2 border border-gray-300 rounded-md text-gray-500"
placeholder="Input axes"
id="input-axes-input" />
<p class="mt-2 text-gray-500 text-sm" id="info-panel"></p>
<p class="mt-2 text-gray-500 text-sm" id="img-shape-panel"></p>
<div class="flex flex-row gap-4 mt-2">
<button
class="px-4 py-2 bg-blue-500 text-white rounded-md w-25 hover:bg-blue-600"
id="submit-button">Submit</button>
<button
class="px-4 py-2 bg-cyan-300 text-white rounded-md w-30 hover:bg-cyan-400"
id="sample-input-button">Sample input</button>
<button
class="px-4 py-2 bg-cyan-300 text-white rounded-md w-30 hover:bg-cyan-400"
id="sample-output-button">Sample output</button>
</div>
</div>
</window>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment