@jpcbertoldo
Created February 8, 2024 17:56
tmp-spro-scratch.py
# %%
# Scratch of spro on a batch of images
import torch
# ---- Dummy data ----
# a batch of predictions for two images
# shape: (2, 1, 256, 256)
predictions = torch.rand(2, 1, 256, 256)
# masks as list[list[Tensor]]
# the outer list is over the images in the batch
# each inner list holds the (possibly multiple) masks of one image
masks = [
    # each inner list corresponds to one image (same index as in `predictions`)
    [
        # each tensor corresponds to one mask file (i.e., one spro curve)
        torch.rand(256, 256) > .99,
        torch.rand(256, 256) > .90,
    ],
    [
        torch.rand(256, 256) > .80,
    ],
]
# ---- Concat masks & repeat predictions ----
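# Concatenate all masks into a single tensor:
# stack each image's masks into (N_i, 256, 256), then concatenate along dim 0
# (torch.cat on the raw (256, 256) masks would give (512, 256) for the first image)
# shape: (3, 256, 256)
masks_concatenated = torch.cat([torch.stack(image_masks) for image_masks in masks])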
# make a predictions tensor where the prediction of an image with N masks is repeated N times
# (i.e., the same prediction is repeated for each mask of the image)
# `pred` is (1, 256, 256); `repeat` with 4 sizes prepends a leading dim -> (N, 1, 256, 256)
# shape: (3, 1, 256, 256)
predictions_repeated = torch.cat([
    pred.repeat(len(image_masks), 1, 1, 1)
    for pred, image_masks in zip(predictions, masks)
])
# so the indexes in `predictions_repeated` match the indexes in `masks_concatenated`
# but do NOT match the indexes in `predictions`;
# the mapping is
# `predictions_repeated[0]` -> `predictions[0]`
# `predictions_repeated[1]` -> `predictions[0]`
# `predictions_repeated[2]` -> `predictions[1]`
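# Sketch, not part of the original scratch: the same repetition can be written with
# `torch.repeat_interleave`, which also yields an explicit map from each row of the
# repeated tensor back to its image in the batch (the mapping listed above).
# `repeat_per_mask`, `num_masks_per_image` and `image_index` are hypothetical names.
import torch


def repeat_per_mask(preds, masks_per_image):
    """Repeat each image's prediction once per mask of that image.

    preds: (B, 1, H, W) tensor; masks_per_image: B lists of (H, W) bool tensors.
    Returns (preds_repeated, masks_stacked, image_index).
    """
    num_masks_per_image = torch.tensor([len(m) for m in masks_per_image])
    # row k of the output is the prediction of image image_index[k]: (sum(N_i), 1, H, W)
    preds_repeated = torch.repeat_interleave(preds, num_masks_per_image, dim=0)
    # stack each image's masks into (N_i, H, W), then concatenate: (sum(N_i), H, W)
    masks_stacked = torch.cat([torch.stack(m) for m in masks_per_image if len(m) > 0])
    # e.g. [0, 0, 1] for the dummy data above (2 masks for image 0, 1 mask for image 1)
    image_index = torch.repeat_interleave(torch.arange(len(masks_per_image)), num_masks_per_image)
    return preds_repeated, masks_stacked, image_index


_demo_index = repeat_per_mask(
    torch.rand(2, 1, 8, 8),
    [[torch.rand(8, 8) > .5, torch.rand(8, 8) > .5], [torch.rand(8, 8) > .5]],
)[2]
print(_demo_index.tolist())  # [0, 0, 1]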
# ---- SPRO ----
# ...
# a fake result
# 10_000 is the number of points (thresholds) in the spro curve
spro_curves = torch.rand(3, 10_000)
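# Sketch, not the elided code above: one way a single per-mask curve could be
# computed, assuming the predicted anomalous set at threshold t is {pred >= t} and
# the saturated overlap for a mask region A is min(|A ∩ {pred >= t}| / s, 1).
# `spro_curve_sketch` and `saturation_area` (s) are hypothetical; when s is not
# given, this sketch falls back to the mask area, i.e. plain per-region overlap.
import torch


def spro_curve_sketch(prediction, mask, thresholds, saturation_area=None):
    """prediction: (1, H, W) or (H, W) scores; mask: (H, W) bool; thresholds: (T,).

    Returns a (T,) tensor with the saturated overlap at each threshold.
    """
    region_scores = prediction.reshape(-1)[mask.reshape(-1)]
    if saturation_area is None:
        saturation_area = max(region_scores.numel(), 1)
    # sort once, then count region pixels with score >= t for every t via binary search
    region_scores, _ = torch.sort(region_scores)
    num_below = torch.searchsorted(region_scores, thresholds.to(region_scores.dtype))
    overlap = region_scores.numel() - num_below
    return torch.clamp(overlap / saturation_area, max=1.0)


_demo_curve = spro_curve_sketch(
    torch.rand(1, 256, 256), torch.rand(256, 256) > .9, torch.linspace(0, 1, 10_000)
)
print(_demo_curve.shape)  # torch.Size([10000])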
#...
# the mean must be taken over all curves from all batches (NOT per batch), cf. Eq. 1 in the paper
spro = spro_curves.mean(dim=0)
# caveat: the input batch size is 2, but there are 3 curves (one per mask),
# so the number of curves does not match the batch size
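# Sketch, not part of the original scratch, illustrating the note above: because
# images carry different numbers of masks, the number of curves per batch need not
# equal the batch size, and the mean of per-batch means differs from the mean over
# all curves. `pooled_mean`, `mean_of_batch_means` and `_batches` are hypothetical;
# each batch is a (num_curves_in_batch, T) tensor of dummy values.
import torch


def pooled_mean(batches_of_curves):
    """Collect every curve from every batch, then average once (one weight per curve)."""
    return torch.cat(batches_of_curves, dim=0).mean(dim=0)


def mean_of_batch_means(batches_of_curves):
    """Per-batch averaging: every batch weighs the same regardless of its curve count."""
    return torch.stack([c.mean(dim=0) for c in batches_of_curves]).mean(dim=0)


_batches = [
    torch.tensor([[1.0, 1.0], [1.0, 0.0], [1.0, 0.0]]),  # batch with 3 curves
    torch.tensor([[0.0, 0.0]]),  # batch with 1 curve
]
print(pooled_mean(_batches))  # tensor([0.7500, 0.2500])
print(mean_of_batch_means(_batches))  # tensor([0.5000, 0.1667])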