Skip to content

Instantly share code, notes, and snippets.

@NicoKiaru
Last active September 5, 2023 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NicoKiaru/1efa49eb23b2438a603ea1f531be77bd to your computer and use it in GitHub Desktop.
Save NicoKiaru/1efa49eb23b2438a603ea1f531be77bd to your computer and use it in GitHub Desktop.
[Spine analysis] Analyse and output cropped images around each spine. Needs a Simple Neurite Tracer input - #BIOP #Fiji #SNT
/**
* Dendrite Info Display and Analysis
* The PTBIOP update site should be enabled
* The NeuroMorpho update site should be enabled
*
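* This script takes as input a dendrite image and an SNT traces file in which:
* - each spine branch is tagged 'Spine' (its name must contain "Spine")
* - each branch carrying tagged spines is tagged 'Root' (its name must contain "Root", case-sensitive;
*   these branches are only used to compute the total root length)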
*
* The script outputs:
* - 2 images for reviewing the analysis :
* -- one stack image where each spine cropped region is reoriented in 3D
* -- this stack but as a montage image
* - one tab-separated results file (saved with a .csv extension) containing the measurements for the image and for each spine
*
* SNT API : https://morphonets.github.io/SNT/
* SNT Github : https://github.com/morphonets/SNT
*
* for Alessandro Chioino project, Sandi Lab
*
* @author : nicolas.chiaruttini@epfl.ch, BIOP, EPFL, 2020
*/
//---- Inputs
#@ImagePlus imp
#@File(label = "File containing traces") swc_filepath
#@int(label = "Number of pixels surrounding each spine (even number)") pixPerSpineXY
// The crop Z extent starts equal to the XY extent
pixPerSpineZ = pixPerSpineXY
#@int(label = "Number of slices taken with z averaging") nSlicesAvg
#@double(label = "Voxel Size", style="format:#.00") voxSize
#@boolean(label = "Interpolate Source") interpolate_source
#@SourceAndConverterService sac_service
#@File(label = "Output Folder", style = "directory") outputFolder
// Number formatting used for the values written to the results file
df = new DecimalFormat("0.000")
// ------------------------- Get Dendrite ROIs
// SNT traces; every branch of this tree is scanned for spines below
tree = new Tree(swc_filepath.getAbsolutePath())
cal = imp.getCalibration()
// Reuse the open ROI Manager if there is one (explicit null check: RoiManager is Iterable, so
// Groovy truthiness is not a safe test), then empty it so it only holds this run's spine end points
rm = RoiManager.getInstance()
if (rm == null) {
    rm = new RoiManager()
}
rm.reset()
// Counters filled by the branch loop. Declared with a type, so they are local to the script body
// and not visible from the methods defined further down.
int i = 0                   // number of branches visited so far (counts every branch)
double totalRootLength = 0  // summed length of the branches whose name contains "Root"
int idx_spine = 0           // number of spine branches found so far
// Collects one SpineInfo per SNT branch whose name contains "Spine", in the order tree.list()
// returns them (which is also the order their end-point ROIs are added to the ROI Manager).
final ArrayList<SpineInfo> spines = new ArrayList<>()
def imageCal = imp.getCalibration()
// Adds a point ROI at a traced node (converted to pixel coordinates) to the ROI Manager, for review.
// NOTE(review): the slice number is the calibrated z divided by the voxel depth (presumably a
// 0-based index) used directly as an ImageJ slice position, which is 1-based - confirm the convention.
def addNodeRoi = { node, String roiName ->
    def roi = new PointRoi(node.getX() / imageCal.pixelWidth, node.getY() / imageCal.pixelHeight)
    roi.setName(roiName)
    roi.setPosition(1, Math.round(node.getZ() / imageCal.pixelDepth) as int, 1)
    rm.addRoi(roi)
    return roi
}
for (branch : tree.list()) {
    i++
    String branchName = branch.getName()
    if (branchName.contains("Spine")) {
        println(branchName)
        idx_spine++
        def forkNode = branch.getNode(0)   // first node of the branch (the "fork" end)
        def tipNode = branch.lastPoint()   // last node of the branch (the "end"/tip)

        SpineInfo sInfo = new SpineInfo()
        sInfo.index = idx_spine
        sInfo.branchName = branchName
        sInfo.SNTLength = branch.getLength()
        sInfo.roi_fork = addNodeRoi(forkNode, "Dendrite_start_" + i)
        sInfo.roi_end = addNodeRoi(tipNode, "Dendrite_end_" + i)
        sInfo.idx_fork = -1
        sInfo.idx_end = -1
        // Geometry is taken straight from the traced nodes, in calibrated units. Reading it back from
        // the ROIs would quantise Z to a whole slice (and `as int` on a floating-point quotient such as
        // 6.9999999 would drop a slice).
        sInfo.px_fork = forkNode.getX()
        sInfo.py_fork = forkNode.getY()
        sInfo.pz_fork = forkNode.getZ()
        sInfo.px_end = tipNode.getX()
        sInfo.py_end = tipNode.getY()
        sInfo.pz_end = tipNode.getZ()
        spines.add(sInfo)
    }
    if (branchName.contains("Root")) {
        totalRootLength += branch.getLength()
    }
}
// ----------------------------- ANALYSIS
// Results-table columns, listed in output order: column name -> closure computing it for one spine.
// Registration order gives each column its index in SpineInfo.measurementOrder.
// 'Radius' must stay ahead of the two intensity columns: getRadius() is what fills
// fluoValueFork and maxFluoValue.
idxM = -1
[
    'Index'                : { spine -> spine.index },
    'Name'                 : { spine -> spine.branchName },
    'SNT Length'           : { spine -> df.format(spine.SNTLength) },
    'Length'               : { spine -> df.format(spine.getLength()) },
    'Radius'               : { spine -> df.format(spine.getRadius()) },
    'Intensity Spine Fork' : { spine -> df.format(spine.fluoValueFork) },
    'Intensity Spine Tip'  : { spine -> df.format(spine.getMaxFluoValue()) },
    'Z Position'           : { spine -> df.format(spine.getZPosition()) },
].each { String columnName, Closure measurement ->
    SpineInfo.measureFunctions[columnName] = measurement
    SpineInfo.measurementOrder[++idxM] = columnName
}
//------------------------------------ MAIN
// ----- State : list of roi in the order set by the user (changes on table click)
// NOTE(review): script-local and never read here; setNewList() assigns a binding variable of the
// same name, not this local.
List<SpineInfo> currentList;
//---- Variable initialisation
// Crop sizes are forced to be even
if (pixPerSpineXY % 2 == 1) pixPerSpineXY += 1
if (pixPerSpineZ % 2 == 1) pixPerSpineZ += 1
nSpines = idx_spine
if (spines.isEmpty()) {
    IJ.error("Spine analysis", "No branch whose name contains 'Spine' was found in " + swc_filepath.getName())
    return
}
source = getSource(imp)
spinesCrops = new ArrayList<>()
// Calibration of the reoriented crops: isotropic voxels of size voxSize, in the input image's unit
cal = new Calibration()
cal.pixelWidth = voxSize
cal.pixelHeight = voxSize
cal.pixelDepth = voxSize
cal.setUnit(imp.getCalibration().getUnit())
// 1-based, inclusive slice range averaged into each 2D spine image: exactly nSlicesAvg slices
// (clamped to [1, depth]) centred in the crop. The crop depth used here is pixPerSpineXY
// (pixPerSpineZ is initialised to the same value).
int nAveraged = Math.max(1, Math.min(nSlicesAvg, pixPerSpineXY as int))
int firstSlice = (pixPerSpineXY - nAveraged).intdiv(2) + 1
int lastSlice = firstSlice + nAveraged - 1
for (spine in spines) {
    cropAroundDendrite = getSpineImage(spine, source)
    spine.croppedImageReoriented = ZProjector.run(cropAroundDendrite, "average", firstSlice, lastSlice)
    spine.croppedImageReoriented.setCalibration(cal)
    spinesCrops.add(spine.croppedImageReoriented)
    // Must follow setCalibration: the measurements read the crop's pixel width
    spine.measure()
    spine.addOverlay()
}
allSpines = Concatenator.run(spinesCrops.toArray(new ImagePlus[0]))
// Overlay of spine k goes on slice k + 1 of the concatenated stack
for (int idx = 0; idx < spines.size(); idx++) {
    SpineInfo spine = spines.get(idx)
    spine.addToImagePlusOverlay(allSpines, idx + 1)
}
// Output :
// 1 - Spine Stack
titleWithoutExtension = FilenameUtils.removeExtension(imp.getTitle())
allSpines.setTitle(titleWithoutExtension + "-spinestack")
allSpines.show()
IJ.run(allSpines, "mpl-plasma", "")
IJ.saveAs(allSpines, "Tiff", outputFolder.getAbsolutePath() + File.separator + allSpines.getTitle())
// 2 - Spine Montage
// The commands without an explicit image act on the active image: presumably the stack shown above,
// then the 3x upscaled copy created by "Scale..." (which is closed at the end of the script)
IJ.run("Scale...", "x=3 y=3 z=1.0 interpolation=None average process create")
impToRemove = IJ.getImage()
IJ.run(impToRemove, "mpl-plasma", "")
IJ.run("Enhance Contrast", "saturated=0.35")
IJ.run("Flatten", "stack")
nColumns = 6
// Integer ceiling division: just enough rows for every spine. (Plain `/` on integers yields a
// BigDecimal in Groovy, and size / 6 + 1 adds an empty row when the count is a multiple of 6.)
nRows = Math.max(1, (spines.size() + nColumns - 1).intdiv(nColumns))
IJ.run("Make Montage...", "columns=" + nColumns + " rows=" + nRows + " scale=1")
montageImage = IJ.getImage()
montageImage.setTitle(titleWithoutExtension + "-spinemontage")
IJ.saveAs(montageImage, "Tiff", outputFolder.getAbsolutePath() + File.separator + montageImage.getTitle())
// 3 - Results file: tab-separated text, saved with a .csv extension
StringBuilder dataString = new StringBuilder()
// Script parameters and image-level values, one "label<TAB>value" line each
dataString.append("Script parameters\n")
dataString.append("Pixels per Spine\t" + pixPerSpineXY + "\n")
dataString.append("Number of slices for averaging\t" + nSlicesAvg + "\n")
dataString.append("Output voxel Size\t" + voxSize + "\n")
dataString.append("Source Interpolated\t" + interpolate_source + "\n")
dataString.append("Length Unit\t" + imp.getCalibration().getUnit() + "\n")
dataString.append("Total Root Branch Length\t" + df.format(totalRootLength) + "\n")
dataString.append("Number of spines\t" + nSpines + "\n")
// Column header. A spine row is "<image title><TAB>" followed by "<name><TAB><value><TAB>" per
// measurement (SpineInfo.toString), so the empty cell before each header name keeps the name
// above its value column.
dataString.append("Image Name\t")
int measuredColumns = SpineInfo.measurementOrder.keySet().size()
for (int col = 0; col < measuredColumns; col++) {
    dataString.append("\t" + SpineInfo.measurementOrder.get(col) + "\t")
}
dataString.append("\n")
for (spine in spines) {
    dataString.append(imp.getTitle() + "\t" + spine + "\n")
}
// Explicit UTF-8 (e.g. for a "µm" unit) rather than the platform-default charset used by the
// deprecated two-argument writeStringToFile
File resultsFile = new File(outputFolder.getAbsolutePath() + File.separator + titleWithoutExtension + "-results.csv")
FileUtils.writeStringToFile(resultsFile, dataString.toString(), "UTF-8")
// Clean Up: discard the temporary upscaled stack without a save prompt
impToRemove.changes = false
impToRemove.close()
//------------------------------------ END OF MAIN
// --------------- METHODS
/**
 * Resamples a region around one spine, aligned on the spine axis.
 *
 * The axis goes from the fork to the tip pushed one more spine length outwards (tip + (tip - fork)),
 * so the tip is the midpoint of the two points handed to SourceHelper.AlignAxisResample.
 * NOTE(review): presumably this centres the crop on the spine tip - confirm against AlignAxisResample.
 *
 * All temporaries are declared locally: undeclared assignments inside a script method would be
 * written into the script binding.
 *
 * @param spine     spine whose fork/end coordinates (calibrated units) define the axis
 * @param source_in source to resample
 * @return the timepoint-0, level-0 resampled crop, rotated with Views.rotate(.., 2, 1), wrapped as an ImagePlus
 */
ImagePlus getSpineImage(SpineInfo spine, Source source_in) {
    double dx = spine.px_end - spine.px_fork
    double dy = spine.py_end - spine.py_fork
    double dz = spine.pz_end - spine.pz_fork
    RealPoint pFork = new RealPoint(3)
    pFork.setPosition([spine.px_fork, spine.py_fork, spine.pz_fork] as double[])
    RealPoint pEnd = new RealPoint(3)
    pEnd.setPosition([spine.px_end + dx, spine.py_end + dy, spine.pz_end + dz] as double[])
    // voxSize, pixPerSpineXY, pixPerSpineZ and interpolate_source are read from the script binding
    def alignedSpine = SourceHelper.AlignAxisResample(source_in, pEnd, pFork, voxSize,
            pixPerSpineXY, pixPerSpineXY, pixPerSpineZ, true, interpolate_source)
    def raiOut = Views.rotate(alignedSpine.getSource(0, 0), 2, 1)
    return ImageJFunctions.wrap(raiOut, source_in.getName() + "_" + spine.getName())
}
/**
 * Wraps an ImagePlus as a BigDataViewer Source in calibrated (physical) coordinates.
 *
 * The image is zero-padded by pixPerSpineXY (X, Y) and pixPerSpineZ (Z) pixels on every side,
 * presumably so that crops near the image border read zeros instead of running off the data.
 * The pixel type is taken from the wrapped image rather than hard-coded to 8-bit.
 *
 * @param imp image to wrap; its calibration gives the pixel-to-physical scaling
 * @return a Source named after the image title
 */
Source getSource(ImagePlus imp) {
    def img = ImageJFunctions.wrap(imp)
    // Locals on purpose: an undeclared `cal` here would overwrite the script-level binding variable
    Calibration impCal = imp.getCalibration()
    AffineTransform3D pixelToPhysical = new AffineTransform3D()
    pixelToPhysical.scale(impCal.pixelWidth, impCal.pixelHeight, impCal.pixelDepth)
    long[] padding = [pixPerSpineXY, pixPerSpineXY, pixPerSpineZ] as long[]
    return new RandomAccessibleIntervalSource(
            Views.expandZero(img, padding),
            img.firstElement().createVariable(),
            pixelToPhysical,
            imp.getTitle()
    )
}
/** Stores a new spine ordering as the script binding variable currentList. */
void setNewList(List<SpineInfo> newList) {
    this.currentList = newList
}
// ------------------------- Spine Info Class
/**
 * One traced spine: its geometry (from the SNT branch), its reoriented 2D crop and the measurements
 * computed on that crop.
 *
 * Table columns are registered by the script in the static measureFunctions / measurementOrder maps.
 * measure() evaluates them in column order and stores the results in values; toString() renders
 * them as a tab-separated "name<TAB>value<TAB>" sequence.
 */
class SpineInfo {
    int index               // 1-based rank of this spine among the 'Spine' branches
    double SNTLength        // branch length as reported by SNT
    String branchName
    int idx_fork            // NOTE(review): always set to -1 by the script
    int idx_end
    Roi roi_fork
    Roi roi_end
    // End-point coordinates in calibrated units (the image's length unit)
    double px_fork
    double py_fork
    double pz_fork
    double px_end
    double py_end
    double pz_end
    String originalFile
    double maxFluoValue = -1        // peak of the horizontal profile through the crop centre; -1 if not measured
    double fluoValueFork = -1       // first value of the vertical profile ending at the crop centre
    double radiusValue = -1         // half the full width at half maximum of the horizontal profile
    double xCenterSpineShift = -100 // offset of the half-maximum midpoint from the crop centre (calibrated units)
    ImagePlus croppedImageReoriented
    Map<String, Object> values = new HashMap<>()
    static Map<String, Function<SpineInfo, Object>> measureFunctions = new HashMap<>()
    static Map<Integer, String> measurementOrder = [:]

    /** Evaluates every registered measurement, in column order, and stores the results in values. */
    public void measure() {
        for (int col = 0; col < measurementOrder.keySet().size(); col++) {
            String name = measurementOrder.get(col)
            Function<SpineInfo, Object> f = SpineInfo.measureFunctions.get(name)
            values.put(name, f.apply(this))
        }
    }

    /** @return straight-line distance between the fork and end points, in calibrated units */
    public double getLength() {
        double dx = px_fork - px_end
        double dy = py_fork - py_end
        double dz = pz_fork - pz_end
        return Math.sqrt(dx * dx + dy * dy + dz * dz)
    }

    /** @return the value computed by getRadius(), or -1 when not measured or when the measurement failed */
    public double getMaxFluoValue() {
        return maxFluoValue
    }

    /** Adds this spine's ROIs to its own crop, visible on all slices. */
    public void addOverlay() {
        addToImagePlusOverlay(croppedImageReoriented)
    }

    /** Adds this spine's ROIs to the overlay of imp, visible on all slices. */
    public void addToImagePlusOverlay(ImagePlus imp) {
        addToImagePlusOverlay(imp, 0)
    }

    /**
     * Adds this spine's ROIs to the overlay of imp, creating the overlay if needed.
     * @param iSlice stack position of the ROIs (0 = all slices)
     */
    public void addToImagePlusOverlay(ImagePlus imp, int iSlice) {
        if (imp.getOverlay() == null) {
            imp.setOverlay(new Overlay())
        }
        for (Roi roi : getRois()) {
            roi.setPosition(iSlice)
            imp.getOverlay().add(roi)
        }
    }

    /**
     * Builds fresh display ROIs in crop pixel coordinates: the spine length line (also set as the
     * crop's active ROI) and the measured diameter line.
     */
    public List<Roi> getRois() {
        List<Roi> rois = new ArrayList<>()
        double xCenter = croppedImageReoriented.getWidth() / 2.0
        double yCenter = croppedImageReoriented.getHeight() / 2.0
        double pixelSize = croppedImageReoriented.getCalibration().pixelWidth
        double spineLengthInPixel = getLength() / pixelSize
        // Spine length: from the crop centre, one spine length along +y
        Line spineRoiLength = new Line(xCenter, yCenter, xCenter, yCenter + spineLengthInPixel)
        croppedImageReoriented.setRoi(spineRoiLength)
        rois.add(spineRoiLength)
        // Diameter: 2 * radius wide, across the centre row, shifted to the half-maximum midpoint
        double halfWidth = radiusValue / pixelSize
        double xShift = xCenterSpineShift / pixelSize
        rois.add(new Line(xCenter - halfWidth + xShift, yCenter, xCenter + halfWidth + xShift, yCenter))
        return rois
    }

    /** @return mean of the fork and end z coordinates, in calibrated units */
    public double getZPosition() {
        return (pz_fork + pz_end) / 2.0
    }

    /**
     * Measures the spine on its crop and returns the radius.
     *
     * Side effects: sets fluoValueFork (first value of a vertical profile from one spine length below
     * the centre up to the centre), and maxFluoValue, radiusValue and xCenterSpineShift from a
     * horizontal profile, 1.5 spine lengths long, through the crop centre. The radius is half the
     * full width at half maximum around the local maximum reached by hill-climbing from the middle
     * of that profile. Leaves the horizontal line as the crop's active ROI.
     *
     * @return the radius in calibrated units, or -1 when no maximum or half-maximum crossing is found
     *         inside the profile (maxFluoValue is then -1 as well)
     */
    public double getRadius() {
        double xCenter = croppedImageReoriented.getWidth() / 2.0
        double yCenter = croppedImageReoriented.getHeight() / 2.0
        double pixelSize = croppedImageReoriented.getCalibration().pixelWidth
        double spineLengthInPixel = getLength() / pixelSize

        croppedImageReoriented.setRoi(new Line(xCenter, yCenter + spineLengthInPixel, xCenter, yCenter))
        fluoValueFork = new ProfilePlot(croppedImageReoriented).getProfile()[0]

        croppedImageReoriented.setRoi(new Line(xCenter - spineLengthInPixel * 0.75, yCenter,
                xCenter + spineLengthInPixel * 0.75, yCenter))
        double[] data = new ProfilePlot(croppedImageReoriented).getProfile()

        int indexMax = findLocalMaxFromMiddle(data)
        if (indexMax < 0) {
            maxFluoValue = -1
            return -1
        }
        double maxValue = data[indexMax]
        maxFluoValue = maxValue
        double halfMax = maxValue / 2.0
        double locationCrossedRight = findCrossing(data, indexMax, halfMax, 1)
        double locationCrossedLeft = findCrossing(data, indexMax, halfMax, -1)
        if (Double.isNaN(locationCrossedRight) || Double.isNaN(locationCrossedLeft)) {
            // Same sentinel whichever side fails
            maxFluoValue = -1
            return -1
        }
        xCenterSpineShift = (((locationCrossedRight + locationCrossedLeft) / 2.0) - data.length / 2.0) * pixelSize
        radiusValue = (locationCrossedRight - locationCrossedLeft) * pixelSize / 2.0
        return radiusValue
    }

    /**
     * Hill-climbs from the middle sample of data to a local maximum, trying the right neighbour first.
     * @return index of the maximum, or -1 if the climb reaches either end of data (or data is too short)
     */
    private static int findLocalMaxFromMiddle(double[] data) {
        int idx = data.length.intdiv(2)
        while (idx > 0 && idx < data.length - 1) {
            if (data[idx + 1] > data[idx]) {
                idx++
            } else if (data[idx - 1] > data[idx]) {
                idx--
            } else {
                return idx
            }
        }
        return -1
    }

    /**
     * Walks from start in direction step (+1 or -1) until the next sample is below level.
     * @return the linearly interpolated fractional index of the crossing, or NaN if the end of data
     *         is reached first
     */
    private static double findCrossing(double[] data, int start, double level, int step) {
        int idx = start
        while (true) {
            int next = idx + step
            if (next < 0 || next >= data.length) {
                return Double.NaN
            }
            if (data[next] < level) {
                return idx + step * (data[idx] - level) / (data[idx] - data[next])
            }
            idx = next
        }
    }

    /** @return "name<TAB>value<TAB>" for every measurement, in column order */
    public String toString() {
        String str = ""
        for (int i = 0; i < measurementOrder.keySet().size(); i++) {
            String measurement = measurementOrder.get(i)
            str += measurement + "\t" + this.values.get(measurement) + "\t"
        }
        return str
    }

    /** @return a per-spine name (idx_fork is always -1, so the 1-based index is used instead) */
    public String getName() {
        return "Spine_" + index
    }
}
//----------------------------------------------------
import ij.IJ
import ij.ImagePlus
import ij.gui.Roi
import ij.plugin.frame.RoiManager
import ij.plugin.ChannelSplitter
import java.io.File
import java.nio.file.Files
import java.nio.file.Path
import java.util.stream.Stream
import java.nio.file.Paths
import java.util.stream.Collectors
import java.util.function.Function
import java.util.ArrayList
import java.text.DecimalFormat
import java.awt.BorderLayout
import static net.imglib2.cache.img.DiskCachedCellImgOptions.options
import net.imglib2.cache.img.DiskCachedCellImgFactory
import net.imglib2.cache.img.DiskCachedCellImgOptions
import net.imglib2.img.Img
import net.imglib2.type.numeric.integer.UnsignedByteType
import net.imglib2.type.numeric.integer.UnsignedShortType
import net.imglib2.type.numeric.ARGBType
import net.imglib2.realtransform.AffineTransform3D
import net.imglib2.view.Views
import net.imglib2.Cursor
import net.imglib2.FinalInterval
import net.imglib2.img.display.imagej.ImageJFunctions
import net.imglib2.*
import net.imglib2.img.Img
import net.imglib2.util.Intervals
import net.imglib2.util.Util
import net.imglib2.view.Views
import net.imglib2.realtransform.AffineTransform3D
import net.imglib2.util.Util
import bdv.util.volatiles.VolatileViews
import bdv.util.BdvFunctions
import bdv.util.BdvOptions
import bdv.util.BdvHandle
import bdv.util.BdvSource
import bdv.util.volatiles.VolatileViews
import bdv.util.Affine3DHelpers
import bdv.util.RandomAccessibleIntervalSource
import javax.swing.JTable
import javax.swing.JScrollPane
import javax.swing.*
import javax.swing.ListSelectionModel
import javax.swing.JFrame;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import javax.swing.RowSorter;
import javax.swing.table.DefaultTableModel;
import javax.swing.table.TableModel;
import javax.swing.table.TableRowSorter;
import javax.swing.SwingUtilities
import javax.swing.event.RowSorterEvent
import ch.epfl.biop.sourceandconverter.SourceHelper
import bdv.viewer.Source
import net.imglib2.view.Views
import ij.plugin.ZProjector
import ij.plugin.Concatenator
import ij.measure.Calibration
import ij.gui.Line
import ij.gui.ProfilePlot
import ij.gui.Plot
import ij.gui.Overlay
import org.apache.commons.io.FilenameUtils
import org.apache.commons.io.FileUtils
import sc.fiji.snt.*
import sc.fiji.snt.io.*
import sc.fiji.snt.analysis.*
import sc.fiji.snt.annotation.*
import sc.fiji.snt.viewer.*
import sc.fiji.snt.util.*
import org.scijava.util.*
import ij.IJ
import ij.WindowManager
import ij.plugin.frame.RoiManager
import ij.gui.PointRoi
import ij.gui.OvalRoi
import java.text.DecimalFormat
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment