Skip to content

Instantly share code, notes, and snippets.

@maweigert
Created July 26, 2021 11:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maweigert/98af7d72078ce3701dc186cc47b0c6a8 to your computer and use it in GitHub Desktop.
Save maweigert/98af7d72078ce3701dc186cc47b0c6a8 to your computer and use it in GitHub Desktop.
stardist_memory
import tensorflow as tf
import numpy as np
import psutil
import os
from stardist.models import Config3D, StarDist3D
os.environ['CUDA_VISIBLE_DEVICES']= "0"# '1,0'
from tensorflow.keras.utils import Sequence
import stardist
from csbdeep.utils.tf import limit_gpu_memory
# Adjust as necessary: limit the GPU memory TensorFlow may use, so that some is
# left for OpenCL-based computations.
# NOTE(review): 0.7 is presumably the fraction of GPU memory given to TensorFlow and
# total_memory=12000 the card's total memory (likely in MB) — verify against
# csbdeep's limit_gpu_memory documentation.
limit_gpu_memory(0.7, total_memory=12000,)
class MemoryCallback(tf.keras.callbacks.Callback):
    """Keras callback that records memory growth after every training epoch.

    When training starts, the current used memory is stored as a baseline.
    After each epoch, the difference to that baseline (in GB) is appended to
    ``epochs_memory`` as an ``(epoch, delta)`` pair. The full history is then
    written to ``memory.log`` with ``np.savetxt`` (one row per epoch, two
    columns: epoch and memory delta).
    """

    @staticmethod
    def get_memory():
        """Return the currently used memory in GB (1e9 bytes).

        NOTE(review): ``psutil.virtual_memory().used`` is system-wide, so memory
        used by other processes on the machine is included in this reading.
        """
        used_bytes = psutil.virtual_memory().used
        return used_bytes / 1e9

    def on_train_begin(self, logs=None):
        # Start a fresh history and take the baseline reading.
        self.epochs_memory = []
        self.memory_base = self.get_memory()

    def on_epoch_end(self, epoch, logs=None):
        delta = self.get_memory() - self.memory_base
        self.epochs_memory.append((epoch, delta))
        # The complete history is rewritten every epoch, so the file always
        # holds all epochs finished so far.
        history = np.array(self.epochs_memory)
        np.savetxt('memory.log', history)
# Custom helper functions and data loader used to reproduce the memory issue
def calculate_extents(lbl, func=np.median):
    """Aggregate bounding-box sizes of the labelled objects in label images.

    Intermediate function : https://github.com/stardist/stardist/issues/57

    Parameters
    ----------
    lbl : ndarray or iterable
        A single 2D or 3D label image, a 4D array (stack of 3D label images),
        or any non-array iterable (list, Sequence, ...) of label images.
        Collections are handled recursively: each image is reduced first, then
        the per-image results are reduced again with ``func``.
    func : callable
        Reduction applied along axis 0 (default ``np.median``).

    Returns
    -------
    ndarray
        One aggregated extent per spatial axis. An image without any labelled
        region contributes a vector of zeros.

    Raises
    ------
    ValueError
        If a single label image is neither 2- nor 3-dimensional.
    """
    from collections.abc import Iterable
    from skimage.measure import regionprops

    is_array = isinstance(lbl, np.ndarray)
    is_collection = (is_array and lbl.ndim == 4) or (not is_array and isinstance(lbl, Iterable))
    if is_collection:
        per_image = [calculate_extents(item, func) for item in lbl]
        return func(np.stack(per_image, axis=0), axis=0)

    ndim = lbl.ndim
    if ndim not in (2, 3):
        raise ValueError("label image should be 2- or 3-dimensional (or pass a list of these)")

    regions = regionprops(lbl)
    if not regions:
        return np.zeros(ndim)

    # bbox is (min_0, ..., min_{n-1}, max_0, ..., max_{n-1}); max - min = size.
    sizes = np.array([np.subtract(r.bbox[ndim:], r.bbox[:ndim]) for r in regions])
    return func(sizes, axis=0)
def random_image(shape):
    """Return a random test image with pixel values in the 8-bit range.

    Values are drawn uniformly from the integers 0..255 inclusive.
    ``np.random.randint`` treats its upper bound as exclusive, so 256 is passed
    to make 255 reachable.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the returned array.

    Returns
    -------
    ndarray
        Integer array (NumPy's default integer dtype) of the requested shape.
    """
    return np.random.randint(0, 256, size=shape)
def diag_mask(shape, n_cells=3, cell_size=2, cell_dist=5):
    """Return a label image with cubic "cells" placed along the main diagonal.

    Cell ``i`` (1-based) is centred at ``i * cell_dist`` on every axis. It covers
    ``[center - cell_size, center + cell_size)`` on each axis and is filled with
    label ``i``. Works for label images of any dimensionality (2D, 3D, ...).

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the label image.
    n_cells : int
        Number of cells; labels are 1..n_cells.
    cell_size : int
        Half edge length of every cell.
    cell_dist : int
        Distance between consecutive cell centres along each axis.

    Returns
    -------
    ndarray of int16
        Label image with 0 as background. Parts of cells that fall outside
        ``shape`` are cropped, and a later cell overwrites an earlier one where
        they overlap.
    """
    mask = np.zeros(shape, dtype=np.int16)
    for cell_id in range(1, n_cells + 1):
        center_pos = cell_id * cell_dist
        # Clamp at 0: a negative slice start would wrap around to the end of
        # the axis and silently yield an empty (or wrong) cell.
        start = max(center_pos - cell_size, 0)
        stop = center_pos + cell_size
        # The same slice on every axis, so any dimensionality is supported.
        mask[(slice(start, stop),) * mask.ndim] = cell_id
    return mask
class DataLoader(Sequence):
    """Dataloader to simulate reading images or labels from disk.

    Items are not stored; each one is generated from its shape when indexed.
    ``mode="X"`` yields a random image (``random_image``); ``mode="Y"`` yields
    a diagonal label mask (``diag_mask``).

    Parameters
    ----------
    shapes_list : sequence of tuple of int
        One shape per item.
    mode : {"X", "Y"}
        "X" for images, "Y" for label masks.

    Raises
    ------
    ValueError
        If ``mode`` is neither "X" nor "Y".
    """

    def __init__(self, shapes_list, mode):
        super().__init__()
        # Validate up front; previously an unknown mode was given a float
        # dtype here but produced masks in __getitem__.
        if mode not in ("X", "Y"):
            raise ValueError(f"mode must be 'X' or 'Y', got {mode!r}")
        self.x, self.mode = shapes_list, mode
        # Numpy-array-like attributes ("retro compatibility with numpy").
        # NOTE(review): presumably the training code inspects these — confirm.
        self.ndim = 3
        # Builtin types: np.int / np.float were plain aliases of int / float.
        # They were deprecated in NumPy 1.20 and removed in NumPy 1.24.
        self.dtype = int if mode == "Y" else float

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        shape = self.x[idx]
        if self.mode == "X":
            return random_image(shape)
        return diag_mask(shape)
# A set of random synthetic 3D samples whose shapes differ between samples.
# NOTE(review): (110, 190, 90) stands out among otherwise ~100-voxel axes — confirm intended.
SHAPES_DIFF = [
    (100, 100, 100),
    (110, 190, 90),
    (100, 100, 110),
    (110, 100, 100),
    (100, 110, 100),
] * 20
X_DIFF = DataLoader(SHAPES_DIFF, mode="X")  # images, generated on access
Y_DIFF = DataLoader(SHAPES_DIFF, mode="Y")  # matching label masks

# Median object extent per axis (calculate_extents' default reduction),
# turned into an anisotropy relative to the largest extent.
extents = calculate_extents(Y_DIFF)
anisotropy = tuple(np.max(extents) / extents)
print(f'empirical anisotropy of labeled objects = {anisotropy}')
from stardist import Rays_GoldenSpiral

# Rays for the 3D model, generated by Rays_GoldenSpiral using the empirical
# anisotropy of the labelled objects.
n_rays = 96
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

# Predict on a subsampled grid for increased efficiency and a larger field of
# view. Axes with anisotropy above 1.5 keep factor 1; all other axes use 2.
grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)
n_channel = 1
# Model configuration. The former trailing `vars(conf)` was removed: its result
# was discarded, so it only had an effect when displayed in a notebook.
conf = Config3D(
    rays=rays,
    grid=grid,
    anisotropy=anisotropy,
    use_gpu=True,
    n_channel_in=n_channel,
    train_patch_size=(64, 64, 64),
    train_batch_size=1,
    # NOTE(review): model.train(...) is later called with epochs=50, which
    # presumably overrides this value for that call — confirm which is intended.
    train_epochs=100,
    train_reduce_lr={'factor': 0.5, 'patience': 10, 'min_delta': 0},
    train_tensorboard=False,
)
print(conf)
# Plain literal: the former f-string had no placeholders, so the value is unchanged.
model_name = 'any_name'
model = StarDist3D(conf, name=model_name, basedir="stardist_models")

# prepare_for_training() runs before the logger is appended, presumably so
# that model.callbacks exists and is then reused by train() — TODO confirm.
model.prepare_for_training()
model.callbacks.append(MemoryCallback())

# Validation reuses the training loaders.
# NOTE(review): epochs=50 here vs train_epochs=100 in the config — confirm.
model.train(
    X_DIFF,
    Y_DIFF,
    validation_data=(X_DIFF, Y_DIFF),
    epochs=50,
    seed=42,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment