Skip to content

Instantly share code, notes, and snippets.

@ahnsws
Created April 27, 2023 19:24
Show Gist options
  • Save ahnsws/b82ed163c773c5d841585e182825f472 to your computer and use it in GitHub Desktop.
Save ahnsws/b82ed163c773c5d841585e182825f472 to your computer and use it in GitHub Desktop.
a parallelized version of ashlar
import copy
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
from typing import Callable
import numpy as np
import numpy.typing as npt
from ashlar import _version, utils
from ashlar.reg import (
EdgeAligner,
Mosaic,
LayerAligner,
warn_data,
TiffListWriter,
Reader,
BioformatsReader,
)
from ashlar.scripts.ashlar import process_axis_flip
from skimage.transform import rescale
from tifffile import tifffile
from tqdm import tqdm
def execute_indexed_parallel(
    func: Callable, *, args: list, tqdm_args: dict = None
) -> list:
    """Call ``func`` once per argument tuple on a thread pool, keeping input order.

    Args:
        func: Callable invoked as ``func(*arg)`` for every ``arg`` in ``args``.
        args: One argument tuple per call.
        tqdm_args: Extra keyword arguments forwarded to ``tqdm``; ``None``
            means no extras.

    Returns:
        A list in which element ``i`` is the return value of
        ``func(*args[i])``, regardless of the order the calls finished in.

    Raises:
        Any exception raised by ``func``; it is re-raised by
        ``Future.result()`` when that call's result is collected.
    """
    progress_options = {} if tqdm_args is None else tqdm_args
    ordered = [None] * len(args)
    with ThreadPoolExecutor() as pool, tqdm(
        total=len(args), **progress_options
    ) as bar:
        # Remember which input slot each future belongs to, because
        # as_completed() yields futures in completion order.
        slot_of = {}
        for slot, call_args in enumerate(args):
            slot_of[pool.submit(func, *call_args)] = slot
        for finished in as_completed(slot_of):
            ordered[slot_of[finished]] = finished.result()
            bar.update(1)
    return ordered
def execute_parallel(func: Callable, *, args: list, tqdm_args: dict = None):
    """Call ``func`` once per argument tuple on a thread pool, ignoring results.

    Args:
        func: Callable invoked as ``func(*arg)`` for every ``arg`` in ``args``.
        args: One argument tuple per call.
        tqdm_args: Extra keyword arguments forwarded to ``tqdm``; ``None``
            means no extras.

    Raises:
        The first exception (in completion order) raised by any call to
        ``func``. The executor still waits for the remaining submitted calls
        to finish before the exception leaves this function.
    """
    if tqdm_args is None:
        tqdm_args = {}
    with ThreadPoolExecutor() as executor:
        with tqdm(total=len(args), **tqdm_args) as pbar:
            futures = [executor.submit(func, *arg) for arg in args]
            for future in as_completed(futures):
                # result() re-raises whatever the worker raised. The return
                # value is not needed, but without this call a failed task
                # would be indistinguishable from a successful one.
                future.result()
                pbar.update(1)
class ThreadSafeBioformatsReader(BioformatsReader):
    """BioformatsReader whose ``read`` calls are serialized by a lock."""

    # Defined on the class, so every instance shares this one lock: reads from
    # different reader objects are serialized against each other as well.
    lock = Lock()

    def read(self, *args, **kwargs):
        """Forward to ``BioformatsReader.read`` while holding the shared lock."""
        with self.lock:
            image = super().read(*args, **kwargs)
        return image
class ParallelMosaic(Mosaic):
    """Mosaic whose channel assembly takes positions and reader as arguments."""

    def assemble_channel_parallel(
        self,
        channel: int,
        positions: list[npt.NDArray],
        reader: Reader,
        out: npt.NDArray = None,
        tqdm_args: dict = None,
    ):
        """Paste every tile of one channel into a single mosaic array.

        The tiles themselves are processed sequentially; the method only uses
        the arguments it is given plus ``self.shape``, ``self.dtype``, the
        flip flags and ``correct_illumination``.

        Args:
            channel: Channel index passed to ``reader.read`` and to
                ``correct_illumination``.
            positions: One position per tile; the list index is used as the
                series number when reading.
            reader: Source of tile images.
            out: Optional destination array. A zero-filled array of
                ``self.shape`` / ``self.dtype`` is allocated when omitted.
            tqdm_args: Extra keyword arguments for the ``tqdm`` progress bar.

        Returns:
            The filled (and, if requested, flipped) mosaic array.

        Raises:
            ValueError: If ``out`` is given and its shape differs from
                ``self.shape``.
        """
        progress_options = {} if tqdm_args is None else tqdm_args
        if out is None:
            out = np.zeros(self.shape, self.dtype)
        elif out.shape != self.shape:
            raise ValueError(
                f"out array shape {out.shape} does not match Mosaic"
                f" shape {self.shape}"
            )
        for series, tile_position in tqdm(enumerate(positions), **progress_options):
            tile = reader.read(c=channel, series=series)
            tile = self.correct_illumination(tile, channel)
            utils.paste(out, tile, tile_position, func=utils.pastefunc_blend)
        # Flip in place, one row (or one row pair) at a time, so that no
        # second full-size copy of the mosaic is ever allocated.
        if self.flip_mosaic_x:
            for row in range(len(out)):
                out[row] = out[row, ::-1]
        if self.flip_mosaic_y:
            for row in range(len(out) // 2):
                mirror = -row - 1
                out[[row, mirror]] = out[[mirror, row]]
        return out
def make_thumbnail(reader, channel=0, scale=0.05, verbose=False):
    """Assemble a downscaled overview of all tiles provided by ``reader``.

    Args:
        reader: Object exposing ``metadata`` (``positions``, ``origin``,
            ``size``, ``num_images``) and ``read(c=..., series=...)``.
        channel: Channel index to read from every tile.
        scale: Rescaling factor applied to each tile and to its position.
        verbose: Show a progress bar on stdout when true.

    Returns:
        A ``uint16`` array holding the assembled thumbnail.
    """
    metadata = reader.metadata
    # Tile positions relative to the metadata origin.
    tile_origins = metadata.positions - metadata.origin
    extent = (tile_origins + metadata.size).max(axis=0)
    thumbnail = np.zeros(((extent + 1) * scale).astype(int), dtype=np.uint16)
    tile_count = reader.metadata.num_images
    progress = tqdm(
        range(tile_count),
        desc=" assembling thumbnail",
        total=tile_count,
        file=sys.stdout,
        disable=not verbose,
    )
    for index in progress:
        tile = reader.read(c=channel, series=index)
        # Anti-aliasing is unnecessary as long as the coarse image features
        # are larger than the scale factor, and skipping it makes rescaling
        # dramatically faster.
        small_tile = rescale(tile, scale, anti_aliasing=False)
        # np.maximum is handed to paste as the combining function.
        utils.paste(thumbnail, small_tile, tile_origins[index] * scale, np.maximum)
    return thumbnail
class ParallelEdgeAligner(EdgeAligner):
    """EdgeAligner that runs its per-pair registrations on a thread pool."""

    def make_thumbnail(self):
        """Attach a thumbnail to the reader unless ``do_make_thumbnail`` is false."""
        if not self.do_make_thumbnail:
            return
        self.reader.thumbnail = make_thumbnail(
            self.reader, channel=self.channel, verbose=self.verbose
        )

    def compute_threshold(self):
        """Estimate the maximum acceptable registration error.

        Sets ``self.errors_negative_sampled`` (the sampled error scores) and
        ``self.max_error`` (their ``false_positive_ratio * 100`` percentile).
        With one edge or fewer, the sample is empty and ``max_error`` is
        infinite.
        """
        # Compute error threshold for rejecting alignments. We generate a
        # distribution of error scores for many known non-overlapping image
        # regions and take a certain percentile as the maximum allowable error.
        # The percentile becomes our accepted false-positive ratio.
        edges = self.neighbors_graph.edges
        num_tiles = self.metadata.num_images
        # If not enough tiles overlap to matter, skip this whole thing.
        if len(edges) <= 1:
            self.errors_negative_sampled = np.empty(0)
            self.max_error = np.inf
            return
        widths = np.array([self.intersection(t1, t2).shape.min() for t1, t2 in edges])
        w = widths.max()
        max_offset = self.metadata.size[0] - w
        # Number of possible pairs minus number of actual neighbor pairs.
        num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)
        # Reduce permutation count for small datasets -- there are fewer
        # possible truly distinct strips with fewer tiles. The calculation here
        # is just a heuristic, not rigorously derived.
        n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10
        pairs = np.empty((n, 2), dtype=int)
        offsets = np.empty((n, 2), dtype=int)
        # Generate n random non-overlapping image strips. Strips are always
        # horizontal, across the entire image width.
        max_tries = 100
        if self.randomize is False:
            # Fixed seed so repeated runs sample the same strips.
            random_state = np.random.RandomState(0)
        else:
            random_state = np.random.RandomState()
        for i in range(n):
            # Limit tries to avoid infinite loop in pathological cases.
            for _ in range(max_tries):
                t1, t2 = random_state.randint(self.metadata.num_images, size=2)
                o1, o2 = random_state.randint(max_offset, size=2)
                # Check for non-overlapping strips and abort the retry loop.
                if t1 != t2 and (t1, t2) not in edges:
                    # Different, non-neighboring tiles -- always OK.
                    break
                elif t1 == t2 and abs(o1 - o2) > w:
                    # Same tile OK if strips don't overlap within the image.
                    break
                elif (t1, t2) in edges:
                    # Neighbors OK if either strip is entirely outside the
                    # expected overlap region (based on nominal positions).
                    its = self.intersection(t1, t2, np.repeat(w, 2))
                    ioff1, ioff2 = its.offsets[:, 0]
                    if (
                        its.shape[0] > its.shape[1]
                        or o1 < ioff1 - w
                        or o1 > ioff1 + w
                        or o2 < ioff2 - w
                        or o2 > ioff2 + w
                    ):
                        break
            else:
                # Retries exhausted. This should be very rare. The last
                # candidate drawn is still recorded below.
                warn_data(
                    f"Could not find non-overlapping strips in {max_tries} tries"
                )
            pairs[i] = t1, t2
            offsets[i] = o1, o2

        def register(t1, t2, offset1, offset2):
            """Return the registration error between one strip from each tile."""
            img1 = self.reader.read(t1, self.channel)[offset1 : offset1 + w, :]
            img2 = self.reader.read(t2, self.channel)[offset2 : offset2 + w, :]
            _, error = utils.register(img1, img2, self.filter_sigma, upsample=1)
            return error

        # One argument tuple per sampled strip pair. Unpacking the array rows
        # yields independent scalar values, so no explicit copying is needed.
        args = [
            (t1, t2, offset1, offset2)
            for (t1, t2), (offset1, offset2) in zip(pairs, offsets)
        ]
        errors = execute_indexed_parallel(
            register,
            args=args,
            tqdm_args=dict(
                file=sys.stdout,
                disable=not self.verbose,
                desc=" quantifying alignment error",
            ),
        )
        errors = np.array(errors)
        self.errors_negative_sampled = errors
        self.max_error = np.percentile(errors, self.false_positive_ratio * 100)

    def register_all(self):
        """Register every neighboring tile pair, then reject poor alignments.

        ``register_pair`` (inherited, not shown here) is expected to store its
        result in ``self._cache``; this method reads the cache afterwards.
        ``self.all_errors`` receives every cached error, and entries whose
        error exceeds ``self.max_error`` or whose shift exceeds
        ``self.max_shift_pixels`` get their error replaced by infinity.
        """
        args = [(t1, t2) for t1, t2 in self.neighbors_graph.edges]
        execute_parallel(
            self.register_pair,
            args=args,
            tqdm_args=dict(
                file=sys.stdout,
                disable=not self.verbose,
                desc=" aligning edge",
            ),
        )
        if self.verbose:
            print()
        self.all_errors = np.array([x[1] for x in self._cache.values()])
        # Set error values above the threshold to infinity. Only values are
        # replaced (no keys added or removed), so mutating during iteration
        # is safe.
        for k, v in self._cache.items():
            if v[1] > self.max_error or np.any(np.abs(v[0]) > self.max_shift_pixels):
                self._cache[k] = (v[0], np.inf)
class ParallelLayerAligner(LayerAligner):
    """LayerAligner that registers its tiles concurrently on a thread pool."""

    def make_thumbnail(self):
        """Build a thumbnail for this aligner's channel and attach it to the reader."""
        thumbnail = make_thumbnail(
            self.reader, channel=self.channel, verbose=self.verbose
        )
        self.reader.thumbnail = thumbnail

    def register_all(self):
        """Run ``self.register`` for every tile and store shifts and errors.

        Each call is expected to return a ``(shift, error)`` pair; the pairs
        are collected in tile order into ``self.shifts`` (shape ``(n, 2)``)
        and ``self.errors`` (shape ``(n,)``).
        """
        tile_count = self.metadata.num_images
        call_args = []
        for tile in range(tile_count):
            call_args.append(copy.deepcopy((tile,)))
        progress_options = {
            "file": sys.stdout,
            "disable": not self.verbose,
            "desc": " aligning tile",
        }
        outcomes = execute_indexed_parallel(
            self.register, args=call_args, tqdm_args=progress_options
        )
        # Transpose [(shift, error), ...] into (shifts, errors).
        tile_shifts, tile_errors = list(zip(*outcomes))
        self.shifts = np.array(tile_shifts)
        self.errors = np.array(tile_errors)
        assert self.shifts.shape == (tile_count, 2)
        assert self.errors.shape == (tile_count,)
        if self.verbose:
            print()
class ParallelTiffListWriter(TiffListWriter):
    """TiffListWriter that writes each (cycle, channel) file as its own task."""

    def run(self):
        """Assemble and write one BigTIFF per cycle and channel on a thread pool.

        Output paths come from ``self.path_format`` formatted with ``cycle``
        and ``channel``. Each task assembles its channel through
        ``ParallelMosaic.assemble_channel_parallel`` and writes the result.
        """
        pixel_size = self.mosaics[0].aligner.metadata.pixel_size
        # assumes pixel_size is in micrometres (10000 um per cm) -- TODO confirm
        resolution_cm = 10000 / pixel_size
        # NOTE(review): `_version` is imported from the `ashlar` package; if
        # that name is a module rather than a version string, this tag will
        # contain the module's repr -- confirm.
        software = f"Ashlar v{_version}"

        def write_channel(
            cycle: int,
            mosaic: ParallelMosaic,
            channel: int,
            path: str,
            positions: list[npt.NDArray],
            reader: Reader,
            tqdm_ind: int,
        ):
            """Assemble one channel of one cycle and write it to ``path``."""
            progress_options = {
                "desc": f" cycle {cycle}, channel {channel}",
                "total": len(positions),
                "disable": not self.verbose,
                "file": sys.stdout,
                # Each task passes its own index as the bar position.
                "position": tqdm_ind,
            }
            with tifffile.TiffWriter(path, bigtiff=True) as tiff:
                image = mosaic.assemble_channel_parallel(
                    channel, positions, reader, tqdm_args=progress_options
                )
                tiff.write(
                    data=image,
                    software=software.encode("utf-8"),
                    resolution=(resolution_cm, resolution_cm, "centimeter"),
                    # FIXME Propagate this from input files (especially RGB).
                    photometric="minisblack",
                )

        jobs = []
        for cycle, mosaic in enumerate(self.mosaics):
            for channel in mosaic.channels:
                jobs.append(
                    (
                        cycle,
                        mosaic,
                        channel,
                        self.path_format.format(cycle=cycle, channel=channel),
                        mosaic.aligner.positions,
                        mosaic.aligner.reader,
                        # Running task count doubles as the progress-bar slot.
                        len(jobs),
                    )
                )
        # The outer progress bar is disabled; each task draws its own.
        execute_parallel(write_channel, args=jobs, tqdm_args={"disable": True})
def process_single(
    filepaths: list[str],
    output_path_format: str,
    flip_x: bool,
    flip_y: bool,
    aligner_args: dict | None,
    mosaic_args: dict | None,
    quiet: bool,
):
    """Stitch the first cycle, register later cycles to it, and write TIFFs.

    Args:
        filepaths: Input files, one per cycle. The first is stitched with a
            ``ParallelEdgeAligner``; every later one is registered against it
            with a ``ParallelLayerAligner``.
        output_path_format: Format string handed to the writer for the
            output paths.
        flip_x: Forwarded to ``process_axis_flip`` for every reader.
        flip_y: Forwarded to ``process_axis_flip`` for every reader.
        aligner_args: Keyword arguments for both aligner classes; ``None``
            means no extra arguments.
        mosaic_args: Keyword arguments for ``ParallelMosaic``; ``None`` means
            no extra arguments.
        quiet: Suppress progress output when true.

    Raises:
        IndexError: If ``filepaths`` is empty.
    """
    # Both option dicts are annotated as optional, so accept None as "no
    # options" instead of failing on None.copy().
    aligner_args = {} if aligner_args is None else aligner_args
    mosaic_args = {} if mosaic_args is None else mosaic_args.copy()
    mosaics = []
    if not quiet:
        print("Stitching and registering input images")
        print("Cycle 0:")
        print(" reading %s" % filepaths[0])
    reader = ThreadSafeBioformatsReader(filepaths[0])
    process_axis_flip(reader, flip_x, flip_y)
    # Copy so the caller's dict is not modified by the override below.
    ea_args = aligner_args.copy()
    if len(filepaths) == 1:
        # With a single cycle no layer aligner is created, so the edge
        # aligner is told to skip thumbnail generation.
        ea_args["do_make_thumbnail"] = False
    edge_aligner = ParallelEdgeAligner(reader, **ea_args)
    edge_aligner.run()
    mshape = edge_aligner.mosaic_shape
    mosaic_args_final = mosaic_args.copy()
    mosaics.append(ParallelMosaic(edge_aligner, mshape, **mosaic_args_final))
    for cycle, filepath in enumerate(filepaths[1:], 1):
        if not quiet:
            print("Cycle %d:" % cycle)
            print(" reading %s" % filepath)
        reader = ThreadSafeBioformatsReader(filepath)
        process_axis_flip(reader, flip_x, flip_y)
        layer_aligner = ParallelLayerAligner(reader, edge_aligner, **aligner_args)
        layer_aligner.run()
        mosaic_args_final = mosaic_args.copy()
        mosaics.append(ParallelMosaic(layer_aligner, mshape, **mosaic_args_final))
    # Disable reader caching to save memory during mosaicing and writing.
    edge_aligner.reader = edge_aligner.reader.reader
    if not quiet:
        print()
        print(f"Merging tiles and writing to {output_path_format}")
    writer = ParallelTiffListWriter(mosaics, output_path_format, verbose=not quiet)
    writer.run()
def run():
    """Stitch three hard-coded cycle files and write per-channel TIFFs.

    Inputs ``R1.tiff``, ``R2.tiff`` and ``R3.tiff`` and the output pattern are
    relative paths, so they resolve against the current working directory.
    """
    cycle_files = [Path(name) for name in ("R1.tiff", "R2.tiff", "R3.tiff")]
    output_pattern = Path("cycle_{cycle}_channel_{channel}.tiff")
    aligner_options = {
        "filter_sigma": 0.0,
        "max_shift": 15,
        "channel": 0,
        "verbose": True,
    }
    mosaic_options = {"verbose": True}
    process_single(
        filepaths=list(map(str, cycle_files)),
        output_path_format=str(output_pattern),
        flip_x=False,
        flip_y=False,
        aligner_args=aligner_options,
        mosaic_args=mosaic_options,
        quiet=False,
    )


# Run the example pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment