Last active
July 20, 2023 09:40
-
-
Save Nanguage/27a003bec90ee8fdd29859cb446c608d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<config lang="json"> | |
{ | |
"name": "mi-core", | |
"type": "web-python", | |
"tags": [], | |
"flags": [], | |
"ui": "", | |
"version": "0.1.0", | |
"cover": "", | |
"description": "Connect to the bioengine server, and execute operations.", | |
"icon": "extension", | |
"inputs": null, | |
"outputs": null, | |
"api_version": "0.1.8", | |
"env": "", | |
"permissions": [], | |
"requirements": ["imageio", "xarray"], | |
"dependencies": [] | |
} | |
</config> | |
<script lang="python"> | |
import io | |
import imageio | |
import numpy as np | |
from xarray import DataArray | |
from imjoy import api | |
from imjoy_rpc.hypha import connect_to_server | |
# Runtime detection: pyodide is importable only when running as web-python
# inside the browser; fetch_file_content branches on this flag to choose
# between the browser fetch API and the requests library.
try:
    import pyodide
    is_pyodide = True
except ImportError:
    is_pyodide = False
async def fetch_file_content(url) -> bytes:
    """Fetch the content at *url* and return it as bytes.

    Compatible with both pyodide (browser fetch API) and native Python
    (falls back to the ``requests`` library).

    Reference:
    https://github.com/imjoy-team/kaibu-utils/blob/ecc25337adb0c94e6345f09bba80aa0637ce9af0/kaibu_utils/__init__.py#L403-L425
    """
    await api.log("Fetch URL: " + url)
    if is_pyodide:
        from js import fetch
        response = await fetch(url)
        buffer = await response.arrayBuffer()
        bytes_ = buffer.to_py()
    else:
        import requests
        # BUG FIX: the original assigned the Response object itself, not its
        # body, so len(bytes_) below raised TypeError and callers received a
        # Response instead of bytes. Use .content and fail fast on HTTP errors.
        response = requests.get(url)
        response.raise_for_status()
        bytes_ = response.content
    await api.log("Fetched bytes: " + str(len(bytes_)))
    return bytes_
def is_channel_first(shape):
    """Guess whether *shape* has its channel axis first.

    A 5-d shape is assumed to carry a leading batch dimension, which is
    stripped before guessing. The smallest dimension is taken to be the
    channel axis; if it sits neither first nor last the placement is
    ambiguous and we default to channel-first.
    """
    dims = list(shape[1:]) if len(shape) == 5 else list(shape)
    smallest = dims.index(min(dims))
    if smallest == 0:
        # Smallest axis leads: clearly channel-first.
        return True
    if smallest == len(dims) - 1:
        # Smallest axis trails: clearly channel-last.
        return False
    # Ambiguous middle placement — just guess channel-first.
    return True
def get_default_image_axes(shape, input_tensor_axes):
    """Infer a default axes string for an image of the given *shape*.

    *input_tensor_axes* (the model's input axes) is consulted only to
    decide whether a "z" axis should be assumed for 3-d and 4-d images.

    Raises:
        ValueError: if the number of dimensions is outside 2..5.
    """
    ndim = len(shape)
    volumetric = "z" in input_tensor_axes
    if ndim == 2:
        return "yx"
    if ndim == 3:
        if volumetric:
            return "zyx"
        return "cyx" if is_channel_first(shape) else "yxc"
    if ndim == 4:
        if volumetric:
            return "czyx" if is_channel_first(shape) else "zyxc"
        return "bcyx" if is_channel_first(shape) else "byxc"
    if ndim == 5:
        return "bczyx" if is_channel_first(shape) else "bzyxc"
    raise ValueError(f"Invalid number of image dimensions: {ndim}")
def map_axes(
    input_array,
    input_axes,
    output_axes,
    # spatial axes (z/y/x) are sliced at their middle coordinate; any other
    # axis (channel or batch) is sliced at coordinate 0
    drop_function=lambda ax_name, ax_len: ax_len // 2 if ax_name in "zyx" else 0
):
    """Re-arrange *input_array* from *input_axes* to *output_axes*.

    Axes present in the input but absent from the output are dropped by
    indexing with *drop_function*; axes missing from the input are expanded
    to length 1; finally the result is transposed into the output order and
    returned as a plain numpy array.
    """
    n_axes = len(input_axes)
    assert n_axes == input_array.ndim, f"Number of axes {n_axes} and dimension of input {input_array.ndim} don't match"
    axis_len = dict(zip(input_axes, input_array.shape))
    data = DataArray(input_array, dims=tuple(input_axes))
    # Drop every input axis that the output does not mention.
    surplus = set(input_axes) - set(output_axes)
    data = data[{name: drop_function(name, axis_len[name]) for name in surplus}]
    # Insert length-1 axes for anything the output needs but the input lacks.
    data = data.expand_dims(dim=tuple(set(output_axes) - set(input_axes)))
    # Put the axes into the requested order and hand back a numpy array.
    return data.transpose(*tuple(output_axes)).values
def transform_input(image: np.ndarray, image_axes: str, output_axes: str):
    """Transform the input image into an output tensor with *output_axes*.

    Args:
        image: the input image.
        image_axes: the axes of the input image as a simple string.
        output_axes: the axes of the output tensor that will be returned.
    """
    # Pure delegation: map_axes handles dropping, expanding and transposing.
    return map_axes(image, image_axes, output_axes)
class Plugin():
    """ImJoy plugin that connects to the BioEngine server and runs models
    on an image held as shared state (``self.image``)."""

    async def setup(self):
        # FIX: api.log is a coroutine in imjoy-rpc; the original called it
        # without await, so the log message was never actually delivered.
        await api.log("Connector plugin is ready.")
        server = await connect_to_server(
            {"name": "client", "server_url": "https://ai.imjoy.io", "method_timeout": 30}
        )
        # Remote triton client used for all model executions.
        self.triton = await server.get_service("triton-client")
        # Currently loaded image (numpy array), shared across operations.
        self.image = None

    def set_current_image(self, image):
        """Store *image* (a numpy array) as the current working image."""
        assert isinstance(image, np.ndarray)
        self.image = image

    async def get_current_image_shape(self):
        """Return the shape tuple of the current image (one must be loaded)."""
        assert self.image is not None
        return self.image.shape

    async def show_image_vtk(self):
        """Open an itk-vtk-viewer window displaying the current image."""
        if self.image is None:
            await api.alert("Please load an image first.")
            return
        self.vtk_viewer = await api.createWindow(
            src="https://oeway.github.io/itk-vtk-viewer/",
            fullscreen=False
        )
        self.vtk_viewer.setImage(self.image)

    async def bioengine_execute(self, model_id, inputs=None, return_rdf=False, weight_format=None):
        """Execute *model_id* on the BioEngine via the triton service.

        Returns the ``"result"`` payload of the bioengine-model-runner call.
        """
        kwargs = {
            "model_id": model_id,
            "inputs": inputs,
            "return_rdf": return_rdf,
            "weight_format": weight_format,
        }
        ret = await self.triton.execute(
            inputs=[kwargs],
            model_name="bioengine-model-runner",
            serialization="imjoy"
        )
        return ret["result"]

    async def get_model_rdf(self, model_id):
        """Fetch the RDF (model metadata) for *model_id* from the BioEngine."""
        ret = await self.bioengine_execute(model_id, return_rdf=True)
        return ret["rdf"]

    async def load_image_from_bytes(self, file_name, img_bytes):
        """Decode *img_bytes* (named *file_name*) and make it the current image."""
        _file = io.BytesIO(img_bytes)
        _file.name = file_name
        # TIFF files may be multi-page volumes, so read them with volread.
        if file_name.endswith((".tif", ".tiff")):
            image = imageio.volread(_file)
        else:
            image = imageio.imread(_file)
        await api.log(
            "Image loaded with shape: " + str(image.shape) +
            " and dtype: " + str(image.dtype)
        )
        self.set_current_image(image)

    async def load_image_from_url(self, url):
        """Download the image at *url* and make it the current image."""
        # Strip the query string and trailing slashes to recover the file name.
        file_name = url.split("?")[0].rstrip('/').split("/")[-1]
        await api.log(file_name)
        bytes_ = await fetch_file_content(url)
        await self.load_image_from_bytes(file_name, bytes_)

    async def run_model(
            self, model_id, rdf,
            image_axes=None, weight_format=None):
        """Run a BioEngine model on the current image.

        Transforms the current image into the model's input axes, executes it
        remotely, maps the output back into the image's axis layout and stores
        it as the new current image.

        Returns:
            True on success, False on any failure (the user is alerted).
        """
        if self.image is None:
            await api.alert("Please load an image first.")
            return False
        img = self.image
        input_spec = rdf['inputs'][0]
        input_tensor_axes = input_spec['axes']
        await api.log("input_tensor_axes", input_tensor_axes)
        if image_axes is None:
            image_axes = get_default_image_axes(img.shape, input_tensor_axes)
            await api.log(f"Image axes were not provided. They were automatically determined to be {image_axes}")
        else:
            await api.log(f"Image axes were provided as {image_axes}")
        assert len(image_axes) == img.ndim
        await api.log("Transforming input image...")
        img = transform_input(img, image_axes, input_tensor_axes)
        await api.log(f"Input image was transformed into shape {img.shape} to fit the model")
        await api.log("Data loaded, running model...")
        try:
            result = await self.bioengine_execute(
                model_id, inputs=[img], weight_format=weight_format)
        except Exception as exp:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {exp}")
            return False
        if not result['success']:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {result['error']}")
            return False
        output = result['outputs'][0]
        await api.showMessage(f"🎉Model execution completed! Got output tensor of shape {output.shape}")
        # Map the model output back into the image's axis layout.
        output_tensor_axes = rdf['outputs'][0]['axes']
        transformed_output = map_axes(output, output_tensor_axes, image_axes)
        self.set_current_image(transformed_output)
        return True

    async def run(self, ctx):
        # Nothing to do when the plugin itself is launched.
        pass
# Register the plugin instance with ImJoy so its methods become callable.
api.export(Plugin())
</script> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<docs lang="markdown"> | |
The panel for uploading the image file.
</docs> | |
<config lang="json"> | |
{ | |
"name": "mi-panel", | |
"type": "window", | |
"tags": [], | |
"ui": "", | |
"version": "0.1.0", | |
"cover": "", | |
"description": "[TODO: describe this plugin with one sentence.]", | |
"icon": "extension", | |
"inputs": null, | |
"outputs": null, | |
"api_version": "0.1.8", | |
"env": "", | |
"permissions": [], | |
"requirements": [ | |
"https://cdn.tailwindcss.com", | |
"https://cdn.jsdelivr.net/npm/imjoy-rpc@0.5.6/dist/hypha-rpc-websocket.min.js" | |
], | |
"dependencies": [ | |
"https://gist.githubusercontent.com/Nanguage/27a003bec90ee8fdd29859cb446c608d/raw/d8c8c44405c9b1e3b4c90a44fa369a89e5591ba3/mi_core.imjoy.html" | |
] | |
} | |
</config> | |
<script lang="javascript"> | |
class ImJoyPlugin {
    async setup() {
        console.log("setup")
        // The mi-core (web-python) plugin does the actual image/model work.
        this.core_plugin = await api.getPlugin('mi-core')
        this.setupEvents()
    }

    // Wire up all UI event handlers once the core plugin is available.
    setupEvents() {
        const setInfoPanel = this.setInfoPanel.bind(this)
        this.listenFileSelection(setInfoPanel)
        this.listenURLInput(setInfoPanel)
        this.listenSampleInputButton(setInfoPanel)
        this.listenSampleOutputButton(setInfoPanel)
        this.listenSubmitButton(setInfoPanel)
    }

    // Load a local file chosen through the hidden file input.
    listenFileSelection(setInfoPanel) {
        const fileInput = document.getElementById("file-input")
        const fileChosen = document.getElementById('file-chosen');
        const viewImageByBytes = this.viewImageByBytes.bind(this)
        fileInput.addEventListener("input", async () => {
            setInfoPanel("Loading file...")
            const file = fileInput.files[0]
            // FIX: guard before touching file.name — the original read
            // file.name first, which throws a TypeError when no file is
            // selected, so the "No file selected" alert was unreachable.
            if (!file) {
                await api.alert("No file selected")
                setInfoPanel("")
                return
            }
            fileChosen.textContent = file.name
            const reader = new FileReader()
            reader.onload = async function() {
                const content = this.result
                await viewImageByBytes(file.name, content)
                // set to empty so that the same file can be loaded again
                fileInput.value = ''
                setInfoPanel("")
            }
            reader.readAsArrayBuffer(file)
        })
    }

    // Load an image from the URL typed into the text input.
    listenURLInput(setInfoPanel) {
        const viewImageByURL = this.viewImageByURL.bind(this)
        const urlInput = document.getElementById("url-input")
        const urlLoadButton = document.getElementById("url-load-button")
        urlLoadButton.addEventListener("click", async () => {
            const url = urlInput.value
            if (!url) {
                await api.alert("No URL provided")
                return
            }
            setInfoPanel("Loading image...")
            await viewImageByURL(url)
            setInfoPanel("")
        })
    }

    // Load the model's first sample input image.
    listenSampleInputButton(setInfoPanel) {
        const sampleInputButton = document.getElementById("sample-input-button")
        sampleInputButton.addEventListener("click", async () => {
            const sampleInput = this.model_rdf.sample_inputs[0]
            setInfoPanel("Loading sample input...")
            await this.viewImageByURL(sampleInput)
            setInfoPanel("")
        })
    }

    // Load the model's first sample output image.
    listenSampleOutputButton(setInfoPanel) {
        const sampleOutputButton = document.getElementById("sample-output-button")
        sampleOutputButton.addEventListener("click", async () => {
            const sampleOutput = this.model_rdf.sample_outputs[0]
            setInfoPanel("Loading sample output...")
            await this.viewImageByURL(sampleOutput)
            setInfoPanel("")
        })
    }

    // Run the model on the currently loaded image.
    listenSubmitButton(setInfoPanel) {
        const submitButton = document.getElementById("submit-button")
        const core = this.core_plugin
        const self = this
        const inputAxes = document.getElementById("input-axes-input")
        submitButton.addEventListener("click", async () => {
            // An empty axes field means "let the core auto-detect".
            const input_axes = inputAxes.value || null
            console.log("Input axes:", input_axes)
            setInfoPanel("Running model...")
            const ret = await core.run_model(
                self.model_id, self.model_rdf, input_axes
            )
            if (ret === false) {
                setInfoPanel("Failed to run the model.")
            } else {
                await self.updateImageShapePanel()
                await core.show_image_vtk()
                setInfoPanel("")
            }
        })
    }

    // Send raw bytes to the core plugin and display the decoded image.
    async viewImageByBytes(fileName, bytes) {
        console.log(bytes)
        this.bytes_data = bytes
        await this.core_plugin.load_image_from_bytes(fileName, bytes)
        await this.updateImageShapePanel()
        await this.core_plugin.show_image_vtk()
    }

    // Ask the core plugin to download a URL and display the image.
    async viewImageByURL(url) {
        await this.core_plugin.load_image_from_url(url)
        await this.updateImageShapePanel()
        await this.core_plugin.show_image_vtk()
    }

    // Pre-fill the URL input with the model's sample input.
    set_default_url(sampleInput) {
        const urlInput = document.getElementById("url-input")
        urlInput.value = sampleInput
    }

    // Show a status message in the info panel ('' clears it).
    setInfoPanel(content) {
        const info = document.getElementById("info-panel")
        info.textContent = content
    }

    // Refresh the "Current image shape" line from the core plugin's state.
    async updateImageShapePanel() {
        const shape = await this.core_plugin.get_current_image_shape()
        const imgShapePanel = document.getElementById("img-shape-panel")
        imgShapePanel.textContent = `Current image shape: ${shape}`
    }

    async run(ctx) {
        console.log("run", ctx)
        const defaultModel = "10.5281/zenodo.5874741"
        const model_id = ctx.data.id || defaultModel
        this.model_id = model_id
        // Load the model RDF
        const rdf = await this.core_plugin.get_model_rdf(model_id)
        console.log(rdf)
        this.model_rdf = rdf
        const sampleInput = rdf.sample_inputs[0]
        this.set_default_url(sampleInput)
    }
}
// Register the panel plugin instance with ImJoy.
api.export(new ImJoyPlugin())
</script> | |
<window lang="html"> | |
<div class="m-1"> | |
<p class="mt-2">Load an PNG, JPG or TIFF file:</p> | |
<div class="flex"> | |
<span id="file-chosen" class="w-1/2 px-4 py-2 border border-gray-300 rounded-l-md text-gray-500">No file selected.</span> | |
<label for="file-input" class="py-2 px-3 bg-blue-500 text-white rounded-r-md cursor-pointer w-20 text-center hover:bg-blue-600"> | |
Browse | |
</label> | |
<input type="file" id="file-input" class="hidden" /> | |
</div> | |
<p class="mt-2">Or load from an URL. The URL to OME-Zarr are supported:</p> | |
<div class="flex"> | |
<input type="text" class="w-1/2 px-4 py-2 border border-gray-300 rounded-l-md text-gray-500" | |
placeholder="Target URL" | |
id="url-input" /> | |
<button | |
class="px-4 py-2 bg-blue-500 text-white rounded-r-md w-20 hover:bg-blue-600" | |
id="url-load-button">Load</button> | |
</div> | |
<p class="mt-2">Input axes (optional, for example yxzc):</p> | |
<input type="text" class="w-2/3 px-4 py-2 border border-gray-300 rounded-md text-gray-500" | |
placeholder="Input axes" | |
id="input-axes-input" /> | |
<p class="mt-2 text-gray-500 text-sm" id="info-panel"></p> | |
<p class="mt-2 text-gray-500 text-sm" id="img-shape-panel"></p> | |
<div class="flex flex-row gap-4 mt-2"> | |
<button | |
class="px-4 py-2 bg-blue-500 text-white rounded-md w-25 hover:bg-blue-600" | |
id="submit-button">Submit</button> | |
<button | |
class="px-4 py-2 bg-cyan-300 text-white rounded-md w-30 hover:bg-cyan-400" | |
id="sample-input-button">Sample input</button> | |
<button | |
class="px-4 py-2 bg-cyan-300 text-white rounded-md w-30 hover:bg-cyan-400" | |
id="sample-output-button">Sample output</button> | |
</div> | |
</div> | |
</window> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment