Skip to content

Instantly share code, notes, and snippets.

@hanslovsky
Last active June 9, 2021 03:01
Show Gist options
  • Save hanslovsky/f52d006118c468d3277d6f240f784949 to your computer and use it in GitHub Desktop.
Save hanslovsky/f52d006118c468d3277d6f240f784949 to your computer and use it in GitHub Desktop.
Example use case for using Jep to generate numpy arrays in CacheLoader
#!/usr/bin/env kscript
// requires kscript: https://github.com/holgerbrandl/kscript
// install jep native libraries with
// python -m pip install jep
// When using a Python interpreter in a non-standard location, set PYTHONHOME appropriately.
@file:MavenRepository("scijava", "https://maven.scijava.org/content/groups/public")
@file:DependsOn("net.imglib2:imglib2-cache:1.0.0-beta-16")
@file:DependsOn("net.imglib2:imglib2:5.12.0")
@file:DependsOn("black.ninia:jep:3.9.1")
@file:DependsOn("sc.fiji:bigdataviewer-vistools:1.0.0-beta-28")
@file:DependsOn("sc.fiji:bigdataviewer-core:10.2.0")
import bdv.util.BdvFunctions
import bdv.util.BdvOptions
import bdv.util.volatiles.SharedQueue
import bdv.util.volatiles.VolatileViews
import java.lang.System
import jep.DirectNDArray
import jep.SharedInterpreter
import net.imglib2.cache.CacheLoader
import net.imglib2.cache.img.CachedCellImg
import net.imglib2.cache.ref.GuardedStrongRefLoaderCache
import net.imglib2.cache.ref.SoftRefLoaderCache
import net.imglib2.img.basictypeaccess.volatiles.VolatileDoubleAccess
import net.imglib2.img.cell.Cell
import net.imglib2.img.cell.CellGrid
import net.imglib2.type.numeric.real.DoubleType
import net.imglib2.util.Intervals
import java.nio.ByteBuffer
import java.nio.DoubleBuffer
import java.util.concurrent.BlockingQueue
import java.util.concurrent.CountDownLatch
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.TimeUnit
/**
 * [VolatileDoubleAccess] backed by a [DoubleBuffer].
 *
 * Reads and writes go straight to the wrapped buffer using absolute indices, and the access
 * always reports itself as valid.
 */
class DoubleBufferAccess(private val buf: DoubleBuffer) : VolatileDoubleAccess {

    /** Returns the element stored at position [index] of the wrapped buffer. */
    override fun getValue(index: Int): Double = buf.get(index)

    /** Stores [value] at position [index] of the wrapped buffer. */
    override fun setValue(index: Int, value: Double) {
        buf.put(index, value)
    }

    /** Always `true`. */
    override fun isValid(): Boolean = true

    companion object {
        /** A fresh access over a zero-capacity buffer; a new instance is created on every read. */
        val empty: DoubleBufferAccess
            get() = DoubleBufferAccess(ByteBuffer.allocate(0).asDoubleBuffer())
    }
}
/**
 * A single unit of work together with a one-shot completion signal.
 *
 * @property buf buffer the task operates on
 * @property index index associated with the task
 * @property min per-dimension minimum coordinates
 * @property max per-dimension maximum coordinates
 * @property dim per-dimension sizes
 * @property code code string to run for this task
 * @property blockName optional name for the block variable; `null` by default
 */
class Task(
    val buf: DoubleBuffer,
    val index: Long,
    val min: LongArray,
    val max: LongArray,
    val dim: IntArray,
    val code: String,
    val blockName: String? = null) {

    // Counted down exactly once by complete(); further calls are no-ops.
    private val done = CountDownLatch(1)

    /** Marks the task as finished and releases every thread blocked in [awaitCompletion]. */
    fun complete() {
        done.countDown()
    }

    /** Blocks the calling thread until [complete] has been called. */
    fun awaitCompletion() {
        done.await()
    }
}
/**
 * Executes [Task]s taken from a shared queue on a dedicated daemon thread that owns one Jep
 * [SharedInterpreter].
 *
 * The constructor starts the thread and blocks until the interpreter has been created (or its
 * creation has failed). For every task the worker exposes the task's buffer as a numpy array,
 * wraps it together with index/min/max/dim into a Python `Block` object bound to the task's
 * block name (default `block`), and then executes the task's code. Array-valued arguments are
 * reversed before being handed to Python.
 *
 * Exceptions raised while processing a task are printed and the task is still marked complete,
 * so a waiting submitter is never left blocked by a failing task.
 *
 * @param queue queue the worker polls for tasks
 * @param init optional Python code executed once on the worker thread before any task
 * @param name optional name for the worker thread
 */
class Worker(
    queue: BlockingQueue<Task>,
    init: String? = null,
    name: String? = null) {

    // Written by close() on the caller's thread and read by the worker loop on its own thread;
    // @Volatile guarantees the worker actually observes the update and terminates.
    @Volatile
    private var closed = false

    private val pythonReady = CountDownLatch(1)

    private val workerThread = Thread {
        val python = try {
            SharedInterpreter().also { it.initialize() }
        } finally {
            // Release the constructor even if the interpreter could not be created.
            pythonReady.countDown()
        }
        try {
            init?.let { python.exec(it) }
            while (!closed) {
                // Poll with a short timeout so that the closed flag is re-checked regularly.
                queue.poll(10, TimeUnit.MILLISECONDS)?.let { task ->
                    try {
                        require(task.buf.isDirect)
                        python.set("_buf", DirectNDArray(task.buf, *task.dim.reversedArray()))
                        python.set("_index", task.index)
                        python.set("_min", task.min.reversedArray())
                        python.set("_max", task.max.reversedArray())
                        python.set("_dim", task.dim.reversedArray())
                        python.exec("${task.blockName ?: "block"} = Block(_buf, _index, _min, _max, _dim)")
                        python.exec(task.code)
                    } catch (e: Exception) {
                        e.printStackTrace()
                    } finally {
                        task.complete()
                    }
                }
            }
        } finally {
            // Jep interpreters are bound to the thread that created them, so the interpreter is
            // released here, on the worker thread, once the loop ends or fails.
            python.close()
        }
    }

    init {
        workerThread.isDaemon = true
        name?.let { workerThread.name = it }
        workerThread.start()
        pythonReady.await()
    }

    /** Asks the worker loop to stop; the thread exits after its current poll or task finishes. */
    fun close() {
        closed = true
    }

    companion object {
        /** Imports numpy as `np` and defines the `Block` dataclass used to hand tasks to Python. */
        fun SharedInterpreter.initialize() {
            exec(
                """
                from dataclasses import dataclass
                import numpy as np
                @dataclass
                class Block:
                    data: np.ndarray
                    index: int
                    min: tuple
                    max: tuple
                    dim: tuple
                """.trimIndent())
        }
    }
}
/**
 * A group of [Worker]s that all take their [Task]s from one shared queue.
 *
 * @param numWorkers number of workers to create; worker `i` is named `Python-i`
 * @param init optional initialization code passed on to every worker
 */
class WorkerQueue(numWorkers: Int, init: String? = null) {

    private val queue = LinkedBlockingDeque<Task>()

    private val workers = Array(numWorkers) { workerId -> Worker(queue, init, "Python-$workerId") }

    /**
     * Wraps the arguments into a [Task], enqueues it, and blocks the calling thread until the
     * task has been marked complete.
     */
    fun submitAndAwaitCompletion(buf: DoubleBuffer, index: Long, min: LongArray, max: LongArray, dim: IntArray, code: String, blockName: String? = null) {
        Task(buf, index, min, max, dim, code, blockName)
            .also { queue.add(it) }
            .awaitCompletion()
    }

    /** Closes every worker of this queue. */
    fun close() {
        for (worker in workers) {
            worker.close()
        }
    }
}
/**
 * [CacheLoader] that produces the cells of a [CellGrid] by running Python code on a pool of
 * workers.
 *
 * @param grid grid describing how the image is divided into cells
 * @param numWorkers number of Python workers to start
 * @param code Python code executed for every requested cell
 * @param init optional Python code passed to each worker for one-time initialization
 * @param blockName optional name under which the block is exposed to [code]
 */
class JepyterCacheLoader(
    private val grid: CellGrid,
    numWorkers: Int,
    private val code: String,
    init: String? = null,
    private val blockName: String? = null) : CacheLoader<Long, Cell<DoubleBufferAccess>> {

    private val workerQueue = WorkerQueue(numWorkers, init)

    /**
     * Loads the cell with index [key]: allocates a direct buffer sized for the cell, has a worker
     * run the Python code on it, and wraps the buffer in a [Cell] once the task is complete.
     *
     * @throws ArithmeticException if the cell's size in bytes does not fit into an `Int`
     */
    override fun get(key: Long): Cell<DoubleBufferAccess> {
        val n = grid.nDim
        val min = LongArray(n)
        val dim = IntArray(n)
        // Fills min and dim for the requested cell; no per-dimension pre-computation is needed.
        grid.getCellDimensions(key, min, dim)
        val max = LongArray(n) { min[it] + dim[it] - 1 }
        // Exact arithmetic: fail loudly instead of allocating a truncated or negative-sized buffer.
        val numBytes = Math.multiplyExact(Math.toIntExact(Intervals.numElements(*dim)), java.lang.Double.BYTES)
        val buf = ByteBuffer.allocateDirect(numBytes).asDoubleBuffer()
        workerQueue.submitAndAwaitCompletion(buf, key, min, max, dim, code, blockName)
        return Cell(dim, min, DoubleBufferAccess(buf))
    }

    companion object {
        private val CellGrid.nDim get() = numDimensions()
    }
}
// Image dimensions and the size of the cells it is divided into.
val imageSize = longArrayOf(300, 400, 500)
val cellSize = intArrayOf(30, 40, 50)
val grid = CellGrid(imageSize, cellSize)

// Python code that is run for every cell that gets loaded.
val fillBlock = """
    block.data[...] = np.mod(np.arange(block.data.size), 255).reshape(block.data.shape)
    # add 50 milliseconds delay to visualize how blocks are generated on demand
    time.sleep(0.05)
""".trimIndent()

val loader = JepyterCacheLoader(
    grid = grid,
    numWorkers = 3,
    code = fillBlock,
    init = "import numpy as np; import time")

// A soft-reference cache does not work here: the cell data lives in native memory, which does
// not count towards the heap. A cache with a hard size limit makes sure unused memory is freed.
val cache = GuardedStrongRefLoaderCache<Long, Cell<DoubleBufferAccess>>(30).withLoader(loader)
val img = CachedCellImg(grid, DoubleType(), cache, DoubleBufferAccess.empty)

val viewerOptions = BdvOptions.options().numRenderingThreads(10)
val bdv = BdvFunctions.show(
    VolatileViews.wrapAsVolatile(img, SharedQueue(10, 1)),
    "numpy",
    viewerOptions)
bdv.setDisplayRange(0.0, 255.0)
#!/usr/bin/env kscript
// requires kscript: https://github.com/holgerbrandl/kscript
// install jep native libraries and tensorflow, stardist dependencies with
// python -m pip install jep stardist tensorflow
// When using a Python interpreter in a non-standard location, set PYTHONHOME appropriately.
@file:MavenRepository("scijava", "https://maven.scijava.org/content/groups/public")
@file:DependsOn("net.imglib2:imglib2-cache:1.0.0-beta-16")
@file:DependsOn("net.imglib2:imglib2:5.12.0")
@file:DependsOn("black.ninia:jep:3.9.1")
@file:DependsOn("sc.fiji:bigdataviewer-vistools:1.0.0-beta-28")
@file:DependsOn("sc.fiji:bigdataviewer-core:10.2.0")
import bdv.util.BdvFunctions
import bdv.util.BdvOptions
import bdv.util.volatiles.SharedQueue
import bdv.util.volatiles.VolatileViews
import java.lang.System
import jep.DirectNDArray
import jep.SharedInterpreter
import net.imglib2.cache.CacheLoader
import net.imglib2.cache.img.CachedCellImg
import net.imglib2.cache.ref.GuardedStrongRefLoaderCache
import net.imglib2.cache.ref.SoftRefLoaderCache
import net.imglib2.img.basictypeaccess.volatiles.VolatileDoubleAccess
import net.imglib2.img.cell.Cell
import net.imglib2.img.cell.CellGrid
import net.imglib2.type.numeric.real.DoubleType
import net.imglib2.util.Intervals
import java.nio.ByteBuffer
import java.nio.DoubleBuffer
import java.util.concurrent.BlockingQueue
import java.util.concurrent.CountDownLatch
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.TimeUnit
/**
 * [VolatileDoubleAccess] backed by a [DoubleBuffer].
 *
 * Reads and writes go straight to the wrapped buffer using absolute indices, and the access
 * always reports itself as valid.
 */
class DoubleBufferAccess(private val buf: DoubleBuffer) : VolatileDoubleAccess {

    /** Returns the element stored at position [index] of the wrapped buffer. */
    override fun getValue(index: Int): Double = buf.get(index)

    /** Stores [value] at position [index] of the wrapped buffer. */
    override fun setValue(index: Int, value: Double) {
        buf.put(index, value)
    }

    /** Always `true`. */
    override fun isValid(): Boolean = true

    companion object {
        /** A fresh access over a zero-capacity buffer; a new instance is created on every read. */
        val empty: DoubleBufferAccess
            get() = DoubleBufferAccess(ByteBuffer.allocate(0).asDoubleBuffer())
    }
}
/**
 * A single unit of work together with a one-shot completion signal.
 *
 * @property buf buffer the task operates on
 * @property index index associated with the task
 * @property min per-dimension minimum coordinates
 * @property max per-dimension maximum coordinates
 * @property dim per-dimension sizes
 * @property code code string to run for this task
 * @property blockName optional name for the block variable; `null` by default
 */
class Task(
    val buf: DoubleBuffer,
    val index: Long,
    val min: LongArray,
    val max: LongArray,
    val dim: IntArray,
    val code: String,
    val blockName: String? = null) {

    // Counted down exactly once by complete(); further calls are no-ops.
    private val done = CountDownLatch(1)

    /** Marks the task as finished and releases every thread blocked in [awaitCompletion]. */
    fun complete() {
        done.countDown()
    }

    /** Blocks the calling thread until [complete] has been called. */
    fun awaitCompletion() {
        done.await()
    }
}
/**
 * Executes [Task]s taken from a shared queue on a dedicated daemon thread that owns one Jep
 * [SharedInterpreter].
 *
 * The constructor starts the thread and blocks until the interpreter has been created (or its
 * creation has failed). For every task the worker exposes the task's buffer as a numpy array,
 * wraps it together with index/min/max/dim into a Python `Block` object bound to the task's
 * block name (default `block`), and then executes the task's code. Array-valued arguments are
 * reversed before being handed to Python.
 *
 * Exceptions raised while processing a task are printed and the task is still marked complete,
 * so a waiting submitter is never left blocked by a failing task.
 *
 * @param queue queue the worker polls for tasks
 * @param init optional Python code executed once on the worker thread before any task
 * @param name optional name for the worker thread
 */
class Worker(
    queue: BlockingQueue<Task>,
    init: String? = null,
    name: String? = null) {

    // Written by close() on the caller's thread and read by the worker loop on its own thread;
    // @Volatile guarantees the worker actually observes the update and terminates.
    @Volatile
    private var closed = false

    private val pythonReady = CountDownLatch(1)

    private val workerThread = Thread {
        val python = try {
            SharedInterpreter().also { it.initialize() }
        } finally {
            // Release the constructor even if the interpreter could not be created.
            pythonReady.countDown()
        }
        try {
            init?.let { python.exec(it) }
            while (!closed) {
                // Poll with a short timeout so that the closed flag is re-checked regularly.
                queue.poll(10, TimeUnit.MILLISECONDS)?.let { task ->
                    try {
                        require(task.buf.isDirect)
                        python.set("_buf", DirectNDArray(task.buf, *task.dim.reversedArray()))
                        python.set("_index", task.index)
                        python.set("_min", task.min.reversedArray())
                        python.set("_max", task.max.reversedArray())
                        python.set("_dim", task.dim.reversedArray())
                        python.exec("${task.blockName ?: "block"} = Block(_buf, _index, _min, _max, _dim)")
                        python.exec(task.code)
                    } catch (e: Exception) {
                        e.printStackTrace()
                    } finally {
                        task.complete()
                    }
                }
            }
        } finally {
            // Jep interpreters are bound to the thread that created them, so the interpreter is
            // released here, on the worker thread, once the loop ends or fails.
            python.close()
        }
    }

    init {
        workerThread.isDaemon = true
        name?.let { workerThread.name = it }
        workerThread.start()
        pythonReady.await()
    }

    /** Asks the worker loop to stop; the thread exits after its current poll or task finishes. */
    fun close() {
        closed = true
    }

    companion object {
        /** Imports numpy as `np` and defines the `Block` dataclass used to hand tasks to Python. */
        fun SharedInterpreter.initialize() {
            exec(
                """
                from dataclasses import dataclass
                import numpy as np
                @dataclass
                class Block:
                    data: np.ndarray
                    index: int
                    min: tuple
                    max: tuple
                    dim: tuple
                """.trimIndent())
        }
    }
}
/**
 * A group of [Worker]s that all take their [Task]s from one shared queue.
 *
 * @param numWorkers number of workers to create; worker `i` is named `Python-i`
 * @param init optional initialization code passed on to every worker
 */
class WorkerQueue(numWorkers: Int, init: String? = null) {

    private val queue = LinkedBlockingDeque<Task>()

    private val workers = Array(numWorkers) { workerId -> Worker(queue, init, "Python-$workerId") }

    /**
     * Wraps the arguments into a [Task], enqueues it, and blocks the calling thread until the
     * task has been marked complete.
     */
    fun submitAndAwaitCompletion(buf: DoubleBuffer, index: Long, min: LongArray, max: LongArray, dim: IntArray, code: String, blockName: String? = null) {
        Task(buf, index, min, max, dim, code, blockName)
            .also { queue.add(it) }
            .awaitCompletion()
    }

    /** Closes every worker of this queue. */
    fun close() {
        for (worker in workers) {
            worker.close()
        }
    }
}
/**
 * [CacheLoader] that produces the cells of a [CellGrid] by running Python code on a pool of
 * workers.
 *
 * @param grid grid describing how the image is divided into cells
 * @param numWorkers number of Python workers to start
 * @param code Python code executed for every requested cell
 * @param init optional Python code passed to each worker for one-time initialization
 * @param blockName optional name under which the block is exposed to [code]
 */
class JepyterCacheLoader(
    private val grid: CellGrid,
    numWorkers: Int,
    private val code: String,
    init: String? = null,
    private val blockName: String? = null) : CacheLoader<Long, Cell<DoubleBufferAccess>> {

    private val workerQueue = WorkerQueue(numWorkers, init)

    /**
     * Loads the cell with index [key]: allocates a direct buffer sized for the cell, has a worker
     * run the Python code on it, and wraps the buffer in a [Cell] once the task is complete.
     *
     * @throws ArithmeticException if the cell's size in bytes does not fit into an `Int`
     */
    override fun get(key: Long): Cell<DoubleBufferAccess> {
        val n = grid.nDim
        val min = LongArray(n)
        val dim = IntArray(n)
        // Fills min and dim for the requested cell; no per-dimension pre-computation is needed.
        grid.getCellDimensions(key, min, dim)
        val max = LongArray(n) { min[it] + dim[it] - 1 }
        // Exact arithmetic: fail loudly instead of allocating a truncated or negative-sized buffer.
        val numBytes = Math.multiplyExact(Math.toIntExact(Intervals.numElements(*dim)), java.lang.Double.BYTES)
        val buf = ByteBuffer.allocateDirect(numBytes).asDoubleBuffer()
        workerQueue.submitAndAwaitCompletion(buf, key, min, max, dim, code, blockName)
        return Cell(dim, min, DoubleBufferAccess(buf))
    }

    companion object {
        private val CellGrid.nDim get() = numDimensions()
    }
}
// Python code run once per worker: imports, the test image, and the pretrained model.
val pythonSetup = """
    from stardist.data import test_image_nuclei_2d
    from stardist.models import StarDist2D
    from stardist.plot import render_label
    from csbdeep.utils import normalize
    img = test_image_nuclei_2d()
    model = StarDist2D.from_pretrained('2D_versatile_fluo')
""".trimIndent()

// Python code run for every cell: predict on the cell's region extended by a halo, then copy
// the part of the prediction that corresponds to the cell into the cell's data.
val predictBlock = """
    halo = 10
    offsets = tuple(min(m, halo) for m in block.min)
    print(f'{offsets=}')
    slicing = tuple(slice(m - o, M+1 + halo) for o, m, M in zip(offsets, block.min, block.max))
    # slicing = tuple(slice(m, M+1) for m, M in zip(block.min, block.max))
    labels, _ = model.predict_instances(normalize(img[slicing]))
    block.data[...] = labels[tuple(slice(o, o + s) for o, s in zip(offsets, block.data.shape))] # labels
""".trimIndent()

// Image dimensions and the size of the cells it is divided into.
val imageSize = longArrayOf(512, 512)
val cellSize = intArrayOf(80, 90)
val grid = CellGrid(imageSize, cellSize)

val loader = JepyterCacheLoader(
    grid = grid,
    numWorkers = 3,
    code = predictBlock,
    init = pythonSetup)

// A soft-reference cache does not work here: the cell data lives in native memory, which does
// not count towards the heap. A cache with a hard size limit makes sure unused memory is freed.
val cache = GuardedStrongRefLoaderCache<Long, Cell<DoubleBufferAccess>>(30).withLoader(loader)
val img = CachedCellImg(grid, DoubleType(), cache, DoubleBufferAccess.empty)

val viewerOptions = BdvOptions.options().numRenderingThreads(10).is2D()
val bdv = BdvFunctions.show(
    VolatileViews.wrapAsVolatile(img, SharedQueue(10, 1)),
    "numpy",
    viewerOptions)
bdv.setDisplayRange(0.0, 10.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment