Last active
January 7, 2020 20:48
-
-
Save hanslovsky/7ca4809064e5e6cac369c071a2dccd53 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.Executors | |
import javafx.application.Platform | |
import javafx.util.Duration | |
import net.imglib2.algorithm.util.Grids | |
import net.imglib2.util.Intervals | |
import net.imglib2.view.Views | |
import org.janelia.saalfeldlab.paintera.state.LabelSourceState | |
import org.janelia.saalfeldlab.paintera.state.label.ConnectomicsLabelState | |
// Fetch the state object for the source that the viewer reports as current.
// `paintera` is not defined in this script; it is expected to be provided by
// the environment that runs it.
def sources = paintera.baseView.sourceInfo()
def labelSource = sources.currentSourceProperty().get()
def labels = sources.getState(labelSource)

// Two state types are supported. They expose the locked-segment set and the
// fragment-to-segment assignment through differently named accessors, so the
// type is checked once and the matching accessors are used below.
def isLabelSourceState = labels instanceof LabelSourceState
if (!isLabelSourceState && !(labels instanceof ConnectomicsLabelState)) {
    throw new RuntimeException("Need to select label data first!")
}

def lockedSegments
def assignment
if (isLabelSourceState) {
    labels = labels as LabelSourceState
    lockedSegments = labels.lockedSegments()
    assignment = labels.assignment
} else {
    labels = labels as ConnectomicsLabelState
    lockedSegments = labels.getLockedSegments()
    assignment = labels.getFragmentSegmentAssignment()
}
// Data of the selected source at indices (0, 0).
// NOTE(review): presumably time point 0 and the highest-resolution scale level - confirm.
def source = labelSource.getDataSource(0, 0)

// Cut the source's interval into 64x64x64 blocks (three block dimensions are
// hard-coded here) so that every block can be scanned by an independent task.
def gridBlocks = Grids.collectAllContainedIntervals(
        Intervals.minAsLongArray(source),
        Intervals.maxAsLongArray(source),
        [64, 64, 64] as int[])

// One worker thread per available processor.
def es = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors())

// One task per block. Each task walks all voxels of its block and returns a
// two-element long[]: [locked voxel count, voxels whose segment id is <= 0].
def tasks = gridBlocks.collect { blockInterval ->
    return { ->
        long lockedInBlock = 0L
        long specialInBlock = 0L
        def voxels = Views.interval(source, blockInterval).cursor()
        while (voxels.hasNext()) {
            def segmentId = assignment.getSegment(voxels.next().getIntegerLong())
            if (segmentId <= 0) {
                // Non-positive segment ids are tallied separately and never
                // checked against the locked set.
                specialInBlock++
            } else if (lockedSegments.isLocked(segmentId)) {
                lockedInBlock++
            }
        }
        return [lockedInBlock, specialInBlock] as long[]
    }
}
// Run every per-block task on the executor, sum the per-block results, and
// build a human-readable summary string.
def start = System.currentTimeMillis()
def count = 0L
def invalid = 0L
def futures
try {
    // invokeAll blocks until all tasks have completed (normally or with an exception).
    futures = es.invokeAll(tasks)
    for (f in futures) {
        // Each task result is a long[]: index 0 = locked voxels, index 1 = voxels
        // with a non-positive segment id. get() rethrows a task failure.
        def c = f.get()
        count += c[0]
        invalid += c[1]
    }
} finally {
    // Always release the worker threads, even when a task failed or the wait was
    // interrupted; otherwise the pool's threads would outlive this script.
    es.shutdown()
}
// Total voxel count is the product of the source's dimensions.
def totalVoxels = Intervals.dimensionsAsLongArray(source).inject(1L) { prod, val -> prod * val }
def time = System.currentTimeMillis() - start
// NOTE(review): `message` is not consumed within the lines visible here - confirm
// that it is displayed or returned further down.
def message = "${count}/${totalVoxels} voxels (${invalid} with special label) are locked (run time: ${time} milliseconds)."
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment