Skip to content

Instantly share code, notes, and snippets.

@prerakmody
Created April 19, 2024 09:42
Show Gist options
  • Save prerakmody/2db12ff4914d7c322dbae837d584c8be to your computer and use it in GitHub Desktop.
Save prerakmody/2db12ff4914d7c322dbae837d584c8be to your computer and use it in GitHub Desktop.
Spatially Varying Label Smoothing (SVLS)
def get_svls_filter_3d(kernel_size=3, sigma=1, verbose=False):
    """
    Build a frozen 3D SVLS label-smoothing convolution filter.

    Ref: https://github.com/mobarakol/SVLS/blob/main/svls.py (pytorch)
    - Alternative (for gauss kernel): https://gist.github.com/blzq/c87d42f45a8c5a53f5b393e27b1f5319
    Note: group parameter in Conv3D is giving an issue in tf==2.10.0 on Unix
    - "tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas'"

    The kernel starts as a 3D gaussian normalized to sum to 1. It is then rewritten so that
    the centre voxel has weight 1 and the neighbours share a total weight of 1 in proportion
    to their gaussian weights. The final kernel therefore sums to 2.

    Params
    ------
    kernel_size: int, side length of the cubic kernel. Expected to be odd so that a single
        centre voxel exists; the centre is taken as index kernel_size // 2 on every axis.
    sigma: float, standard deviation of the gaussian.
    verbose: bool, print the resulting kernel.

    Returns
    -------
    (svls_filter_3d, svls_kernel)
        svls_filter_3d: a non-trainable tf.keras.layers.Conv3D with 1 filter, no bias and
            'same' padding, built for a single input channel.
        svls_kernel: the kernel as a tensor of shape [kernel_size, kernel_size, kernel_size, 1].
    On any exception the traceback is printed and (None, None) is returned.
    """
    try:
        # Step 1 - Integer coordinate grid of shape [k,k,k,3]. The last axis holds (x, y, z):
        #          x varies along axis 2, y along axis 1 and z along axis 0.
        coord = tf.range(kernel_size)
        z_grid, y_grid, x_grid = tf.meshgrid(coord, coord, coord, indexing='ij')  # each [k,k,k]
        xyz_grid = tf.cast(tf.stack([x_grid, y_grid, z_grid], axis=-1), tf.float32)  # [k,k,k,3]

        # Step 2 - 3D gaussian centred on the middle of the grid. The leading constant does not
        #          matter because the kernel is renormalized in Step 3.
        mean = (kernel_size - 1) / 2.
        variance = sigma**2.
        squared_dist = tf.reduce_sum((xyz_grid - mean)**2., axis=-1)  # [k,k,k]
        gaussian_kernel = (1. / (2. * math.pi * variance + 1e-16)) * tf.exp(-squared_dist / (2 * variance + 1e-16))

        # Step 3 - Make sure the sum of values in the gaussian kernel equals 1
        gaussian_kernel = gaussian_kernel / tf.reduce_sum(gaussian_kernel)
        # The index of the centre voxel depends on kernel_size (it was hard-coded to 1 before,
        # which is only correct for kernel_size=3).
        center = kernel_size // 2
        neighbors_sum = 1 - gaussian_kernel[center, center, center]

        # Step 3.1 - TF tensors do not support item assignment, so build a one-hot tensor at the
        #            centre and subtract the gaussian from it:
        #              centre     -> 1 - g_c (= neighbors_sum)
        #              neighbours -> -g
        #            Dividing by neighbors_sum and taking abs() gives centre = 1 and
        #            neighbours = g / neighbors_sum.
        tensor_neighbours_sum = tf.scatter_nd(
            tf.constant([[center, center, center]]),
            tf.constant([1], dtype=gaussian_kernel.dtype),
            tf.constant([kernel_size, kernel_size, kernel_size]),
        )  # zeros of shape [k,k,k] with a 1 at the centre
        svls_kernel_3d = tf.abs((tensor_neighbours_sum - gaussian_kernel) / neighbors_sum)
        if verbose: print (' - [losses2.py][get_svls_filter_3d()] svls_kernel_3d: ', svls_kernel_3d)

        # Step 4 - Make a single-channel conv layer and set its weights. It is applied to each label
        #          separately, since the Conv3D group parameter raises a runtime error (see note above).
        svls_kernel_3d = tf.reshape(svls_kernel_3d, (kernel_size, kernel_size, kernel_size, 1, 1))
        svls_filter_3d = tf.keras.layers.Conv3D(filters=1, kernel_size=kernel_size, use_bias=False, padding='same')
        svls_filter_3d.build((None, None, None, None, 1))
        svls_filter_3d.set_weights([svls_kernel_3d])
        svls_filter_3d.trainable = False
        return svls_filter_3d, svls_kernel_3d[:, :, :, :, 0]  # [k,k,k,1]
    except Exception:
        # Best-effort: report and let the caller handle the (None, None) result.
        traceback.print_exc()
        return None, None
class CELossWithSVLS(tf.keras.Model):
    """
    Per-label binary cross-entropy computed against SVLS-smoothed ground truth.

    Smoothing: each label channel of y_true is convolved separately with the single-channel
    filter from get_svls_filter_3d(). The result is divided by the kernel sum, so a voxel whose
    whole neighbourhood belongs to the label keeps the value 1. The filter uses Keras 'same'
    (zero) padding, so voxels on the volume border receive smaller smoothed values.

    Loss: foreground and background cross-entropy terms are summed over the spatial axes. They
    are then masked per sample and label with label_mask, and reduced into training and
    reporting losses.
    """

    def __init__(self, classes, sigma, debug=False):
        """
        Params
        ------
        classes: converted with tf.constant and stored as self.cls (not read by call()).
        sigma: gaussian sigma forwarded to get_svls_filter_3d().
        debug: if True, every call() does the following before computing the loss:
            - prints the per-label ground-truth voxel counts,
            - plots a random slice of y_true next to its smoothed version,
            - enters pdb.
            Defaults to False. Previously this ran unconditionally and halted every loss evaluation.
        """
        super(CELossWithSVLS, self).__init__()
        print (' - [losses2.py][CELossWithSVLS] Using sigma=', sigma, ' for classes=', classes)
        self.cls = tf.constant(classes)
        self.sigma = sigma
        self.debug = debug
        self.svls_layer, self.svls_kernel = get_svls_filter_3d(sigma=sigma) # self.svls_kernel = [3,3,3,1]

    def _smooth_labels(self, y_true):
        """Apply the SVLS filter to every label channel of y_true [B,H,W,D,L]; returns [B,H,W,D,L]."""
        # Take the number of labels from y_true itself. It used to be counted from `weights`, which
        # smoothed zero labels whenever `weights` was empty (the unweighted branch of call()).
        label_count = tf.shape(y_true)[-1]

        def process_classID(classID):
            return self.svls_layer(tf.expand_dims(y_true[:, :, :, :, classID], axis=-1))  # [B,H,W,D,1]

        # The filter has a single input channel (grouped Conv3D fails on some TF builds), so map over labels
        smoothed = tf.map_fn(process_classID, tf.range(label_count), fn_output_signature=tf.float32)  # [L,B,H,W,D,1]
        smoothed = tf.transpose(smoothed, perm=[1, 2, 3, 4, 0, 5])[:, :, :, :, :, 0]  # [B,H,W,D,L]
        return smoothed / tf.math.reduce_sum(self.svls_kernel)

    def _debug_plot(self, y_true, yTrueSmoothed):
        """
        Debug aid (eager mode only).

        Prints per-label GT voxel counts, plots one random slice and its neighbours for y_true
        and for the smoothed labels, then enters pdb.
        """
        y_true = np.asarray(y_true)
        yTrueSmoothed = np.asarray(yTrueSmoothed)
        for batchId in range(y_true.shape[0]):
            print (' - [losses.py][CELossWithSVLS][batch={}] labels with GT: '.format(batchId), np.sum(y_true[batchId, :, :, :, :], axis=(0, 1, 2)))

        batchId = np.random.choice(y_true.shape[0])
        labels_with_gt = np.argwhere(np.sum(y_true[batchId, :, :, :, :], axis=(0, 1, 2))).flatten()
        if len(labels_with_gt) == 0:
            return  # nothing to visualize for this sample
        labelId = np.random.choice(labels_with_gt)
        sliceId = np.random.choice(np.argwhere(np.sum(y_true[batchId, :, :, :, labelId], axis=(0, 1))).flatten())
        # Clamp neighbour slices so the first/last slice doesn't wrap around or go out of range
        prevId = max(sliceId - 1, 0)
        nextId = min(sliceId + 1, y_true.shape[3] - 1)

        cmap = 'Oranges'
        f, axarr = plt.subplots(3, 2)
        plt.suptitle(' - [losses.py][CELossWithSVLS] batchId: '+str(batchId)+' || labelId: '+str(labelId)+' || sliceId: '+str(sliceId) + ' || sigma: '+str(self.sigma))
        yTrueSmoothedUniqueVals = ['{:.3f}'.format(each) for each in np.unique(yTrueSmoothed[batchId, :, :, sliceId, labelId])][:4]
        panels = [
            (prevId, 'sliceId-1: ' + str(prevId), 'sliceId-1: ' + str(prevId)),
            (sliceId, 'y_true | unique: ' + str(np.unique(y_true[batchId, :, :, sliceId, labelId])),
             'y_true_smoothed | unique: ' + str(yTrueSmoothedUniqueVals) + ' ... '),
            (nextId, 'sliceId+1: ' + str(nextId), 'sliceId+1: ' + str(nextId)),
        ]
        for row, (sid, title_true, title_smoothed) in enumerate(panels):
            axarr[row, 0].imshow(y_true[batchId, :, :, sid, labelId], vmin=0, vmax=1, cmap=cmap)
            axarr[row, 0].set_title(title_true)
            axarr[row, 1].imshow(yTrueSmoothed[batchId, :, :, sid, labelId], vmin=0, vmax=1, cmap=cmap)
            axarr[row, 1].set_title(title_smoothed)
        plt.show(block=False)
        pdb.set_trace()

    # @tf.function(jit_compile=config.JIT_COMPILE)
    def call(self, y_true, y_pred, label_mask, weights):
        """
        Compute the SVLS cross-entropy losses.

        Params
        ------
        y_true: ground truth, [B,H,W,D,L] (per the shape comments below). Each label channel is
            smoothed independently.
        y_pred: same shape as y_true. It is passed to log(p) and log(1-p), so it is assumed to
            hold per-label probabilities in (0,1).
        label_mask: [B,L], marks which labels have ground truth for each sample.
        weights: per-label weights. If empty, the unweighted losses are also used for training.

        Returns
        -------
        (loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report)
            Shapes are scalar, [L], scalar, [L].
            Returns (None, None, None, None) if anything raises; the traceback is printed.
        """
        try:
            # Step 0 - Init
            label_mask = tf.cast(label_mask, dtype=tf.float32)

            # Step 0.1 - Label (y_true) smoothing
            yTrueSmoothed = self._smooth_labels(y_true)  # [B,H,W,D,L]
            if self.debug:
                self._debug_plot(y_true, yTrueSmoothed)

            # Step 1.1 - Foreground loss
            loss_labels_pos = -1.0 * yTrueSmoothed * tf.math.log(y_pred + config._EPSILON) # [B,H,W,D,L]
            loss_labels_pos = label_mask * tf.math.reduce_sum(loss_labels_pos, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L]
            # Step 1.2 - Background loss
            loss_labels_neg = -1.0 * (1 - yTrueSmoothed) * tf.math.log(1 - y_pred + config._EPSILON) # [B,H,W,D,L]
            loss_labels_neg = label_mask * tf.math.reduce_sum(loss_labels_neg, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L]
            loss_labels = loss_labels_pos + loss_labels_neg # [B,L]

            # Step 2 - Mask results on the basis of ground truth availability.
            #          Zeros in the mask become epsilon so the divisions below never divide by zero.
            label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, config._EPSILON)
            loss_labels_for_report = tf.math.reduce_sum(loss_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0) # [B,L] -> [L], [B,L] -> [L], [L]/[L] = [L] (average of labels across batches)
            loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) # [B,L] -> [B], [B,L] -> [B], mean([B]) -> [1] (Average across batches of sum of labels)

            # Step 3 - Weighted loss (per-label weights normalized to sum to 1)
            if len(weights):
                label_weights = weights / tf.math.reduce_sum(weights) # normalized
                loss_labels_w = loss_labels * label_weights # [B,L]
                loss_labels_for_train = tf.math.reduce_sum(loss_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0) # [L]
                loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) # [1]
            else:
                loss_labels_for_train = loss_labels_for_report
                loss_for_train = loss_for_report

            # Step 4 - Return results
            return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report
        except Exception:
            traceback.print_exc()
            return None, None, None, None
import seaborn as sns
import matplotlib.pyplot as plt
# Visual sanity check: plot the three depth slices of the default 3x3x3 SVLS kernel for sigma=1..4.
# Layout: sigma=1 and 2 on the top row, sigma=3 and 4 on the bottom row, three columns (one per depth slice) each.
_, tmp1 = get_svls_filter_3d(sigma=1)
_, tmp2 = get_svls_filter_3d(sigma=2)
_, tmp3 = get_svls_filter_3d(sigma=3)
_, tmp4 = get_svls_filter_3d(sigma=4)  # was sigma=3, contradicting the 'sigma=4' panel title
f, axarr = plt.subplots(2, 6)
# vmax=0.1 keeps the neighbour weights visible; the centre weight (1.0) saturates the colour map
vmin, vmax, cmap = 0, 0.1, 'Oranges'
plt.suptitle('3D SVLS (Stochastically varying label smoothing) filter')
svls_panels = [
    (tmp1, 'sigma=1 (smooth gradient from center to edge)'),
    (tmp2, 'sigma=2'),
    (tmp3, 'sigma=3'),
    (tmp4, 'sigma=4 (sharp gradient from center to edge)'),
]
for panel_idx, (kernel, title) in enumerate(svls_panels):
    row, col_group = divmod(panel_idx, 2)
    col_start = 3 * col_group
    for depth in range(3):
        sns.heatmap(kernel[depth, :, :, 0], ax=axarr[row, col_start + depth], vmin=vmin, vmax=vmax, cmap=cmap)
    axarr[row, col_start + 1].set_title(title)  # title sits above the middle depth slice
for ax in axarr.flatten(): _ = ax.set_xticks([]); _ = ax.set_yticks([])
plt.show(block=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment