@GenevieveBuckley
Last active September 10, 2023 23:54
micro-sam speed tests


8th September 2023. Results from Constantin's development/benchmark.py script.

Summary: something in the automatic mask generation (AMG) step is really killing performance on the MPS backend.

  • We'll need to do some line profiling on that part of the code to get more information.
  • I suspect the fallback to CPU may involve a lot of transferring large tensors back and forth between the CPU and the MPS device, and that this transfer overhead could be a large part of why it is so much slower than the CPU-only computation (see the transfer-timing sketch after the warning below).
    Running benchmark_amg ...
    [W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
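
A quick way to test the transfer suspicion would be to time a round trip of a mask-sized tensor between the MPS device and the CPU. This is a minimal sketch, not part of the benchmark script; the tensor shape is an assumption chosen to be roughly the size of one batch of masks:

```python
# Minimal sketch: time a cpu<->mps round trip for a large boolean tensor.
# The tensor shape is an assumption, not taken from micro-sam.
import time

import torch

if torch.backends.mps.is_available():
    masks = torch.rand(16, 3, 1024, 1024, device="mps") > 0.5

    start = time.perf_counter()
    for _ in range(10):
        cpu_masks = masks.cpu()     # copy MPS -> CPU
        _ = cpu_masks.to("mps")     # copy CPU -> MPS
    torch.mps.synchronize()         # make sure queued MPS work has finished
    print(f"average round trip: {(time.perf_counter() - start) / 10:.3f} s")
```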
    

Benchmark results

| model | device | benchmark             | runtime (s) |
|:------|:-------|:-----------------------|------------:|
| vit_h | cpu    | embeddings             | 16.9334     |
| vit_h | cpu    | prompt-p1n0            | 0.040143    |
| vit_h | cpu    | prompt-p2n4            | 0.0383821   |
| vit_h | cpu    | prompt-box             | 0.041321    |
| vit_h | cpu    | prompt-box-and-points  | 0.0410571   |
| vit_h | cpu    | amg                    | 52.4115     |
| vit_h | mps    | embeddings             | 11.4617     |
| vit_h | mps    | prompt-p1n0            | 0.0342882   |
| vit_h | mps    | prompt-p2n4            | 0.0313947   |
| vit_h | mps    | prompt-box             | 0.0284584   |
| vit_h | mps    | prompt-box-and-points  | 0.032028    |
| vit_h | mps    | amg                    | 220.851     |
| vit_l | cpu    | embeddings             | 10.1647     |
| vit_l | cpu    | prompt-p1n0            | 0.0386679   |
| vit_l | cpu    | prompt-p2n4            | 0.039386    |
| vit_l | cpu    | prompt-box             | 0.037823    |
| vit_l | cpu    | prompt-box-and-points  | 0.0404921   |
| vit_l | cpu    | amg                    | 41.5105     |
| vit_l | mps    | embeddings             | 4.63046     |
| vit_l | mps    | prompt-p1n0            | 0.0345128   |
| vit_l | mps    | prompt-p2n4            | 0.0316033   |
| vit_l | mps    | prompt-box             | 0.0278041   |
| vit_l | mps    | prompt-box-and-points  | 0.0320668   |
| vit_l | mps    | amg                    | 145.669     |
| vit_b | cpu    | embeddings             | 4.01707     |
| vit_b | cpu    | prompt-p1n0            | 0.0395901   |
| vit_b | cpu    | prompt-p2n4            | 0.0408981   |
| vit_b | cpu    | prompt-box             | 0.0390179   |
| vit_b | cpu    | prompt-box-and-points  | 0.0438592   |
| vit_b | cpu    | amg                    | 40.5832     |
| vit_b | mps    | embeddings             | 1.97907     |
| vit_b | mps    | prompt-p1n0            | 0.0335419   |
| vit_b | mps    | prompt-p2n4            | 0.0308049   |
| vit_b | mps    | prompt-box             | 0.0274649   |
| vit_b | mps    | prompt-box-and-points  | 0.0313318   |
| vit_b | mps    | amg                    | 148.047     |

7th September 2023

Summary: somehow the mps branch runs three and a half times SLOWER than the dev branch (with the CPU PyTorch backend) when predicting masks for point grid prompts.

Baseline (dev branch)

2.38 seconds per iteration to predict masks for point grid prompts.

(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % git branch
* dev
  master
  mps
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % time python examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/GitHub/temp/test-micro-sam/micro-sam/data
Precomputing the state for instance segmentation.
Predict masks for point grid prompts: 100%|█████| 16/16 [00:38<00:00,  2.38s/it]
python examples/annotator_2d.py  156.56s user 55.95s system 308% cpu 1:08.78 total
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % time python examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/GitHub/temp/test-micro-sam/micro-sam/data
Load the AMG state from ./embeddings/embeddings-hela2d.zarr/amg_state.pickle
python examples/annotator_2d.py  7.28s user 5.49s system 86% cpu 14.705 total

mps branch

8.37 seconds per iteration to predict masks for point grid prompts.

(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % time python examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/GitHub/temp/test-micro-sam/micro-sam/data
Using apple MPS device.
Precomputing the state for instance segmentation.
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
Predict masks for point grid prompts: 100%|█████| 16/16 [02:13<00:00,  8.37s/it]
python examples/annotator_2d.py  63.98s user 44.61s system 64% cpu 2:47.11 total
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % time python examples/annotator_2d.py
Example data directory is: /Users/genevieb/Documents/GitHub/temp/test-micro-sam/micro-sam/data
Using apple MPS device.
Load the AMG state from ./embeddings/embeddings-hela2d.zarr/amg_state.pickle
python examples/annotator_2d.py  7.26s user 5.62s system 77% cpu 16.597 total

Line profiling

amg.initialize -> _process_crop -> _process_batch -> _to_mask_data & predict_torch
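
The per-line timings below were collected with line_profiler/kernprof: each function in the call chain above is decorated with @profile and the script is run via kernprof -lv. A self-contained toy sketch of that workflow (threshold_masks is a hypothetical function, not part of micro-sam):

```python
# toy_profile.py -- run with:  kernprof -lv toy_profile.py
# kernprof injects `profile` into builtins, so the decorator only exists
# when the script is run under kernprof.
import torch

@profile  # noqa: F821
def threshold_masks(n_batches=16):
    results = []
    for _ in range(n_batches):
        logits = torch.rand(64, 3, 256, 256)   # stand-in for mask logits
        binary = logits > 0.0                  # this line gets its own timing row
        results.append(binary.sum().item())
    return results

if __name__ == "__main__":
    threshold_masks()
```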

CPU

(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d cpu -e -p
Running benchmarks for vit_h
with device: cpu
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | cpu      | amg         |   51.9006 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 4.04613 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _to_mask_data at line 255

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   255                                               @profile
   256                                               def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
   257        16          5.0      0.3      0.0          orig_h, orig_w = original_size
   258
   259                                                   # serialize predictions and store in MaskData
   260        16       1214.0     75.9      0.0          data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
   261        16          5.0      0.3      0.0          if points is not None:
   262        16        964.0     60.2      0.0              data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
   263
   264        16          0.0      0.0      0.0          del masks
   265
   266                                                   # calculate the stability scores
   267        32     897476.0  28046.1     22.2          data["stability_score"] = amg_utils.calculate_stability_score(
   268        16        145.0      9.1      0.0              data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
   269                                                   )
   270
   271                                                   # threshold masks and calculate boxes
   272        16     153194.0   9574.6      3.8          data["masks"] = data["masks"] > self._predictor.model.mask_threshold
   273        16        268.0     16.8      0.0          data["masks"] = data["masks"].type(torch.bool)
   274        16     370802.0  23175.1      9.2          data["boxes"] = batched_mask_to_box(data["masks"])
   275
   276                                                   # compress to RLE
   277        16         90.0      5.6      0.0          data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
   278        16    2611240.0 163202.5     64.5          data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
   279        16      10724.0    670.2      0.3          del data["masks"]
   280
   281        16          1.0      0.1      0.0          return data

Total time: 44.9853 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 366

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   366                                               @profile
   367                                               def _process_batch(self, points, im_size, crop_box, original_size):
   368                                                   # run model on this batch
   369        16       1258.0     78.6      0.0          transformed_points = self._predictor.transform.apply_coords(points, im_size)
   370        16        552.0     34.5      0.0          in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
   371        16        428.0     26.8      0.0          in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
   372        32   40890910.0    1e+06     90.9          masks, iou_preds, _ = self._predictor.predict_torch(
   373        16        303.0     18.9      0.0              in_points[:, None, :],
   374        16         33.0      2.1      0.0              in_labels[:, None],
   375        16          1.0      0.1      0.0              multimask_output=True,
   376        16          1.0      0.1      0.0              return_logits=True,
   377                                                   )
   378        16    4046698.0 252918.6      9.0          data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
   379        16      45095.0   2818.4      0.1          del masks
   380        16          5.0      0.3      0.0          return data

Total time: 50.0068 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 382

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   382                                               @profile
   383                                               def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
   384                                                   # crop the image and calculate embeddings
   385         1          1.0      1.0      0.0          x0, y0, x1, y1 = crop_box
   386         1          4.0      4.0      0.0          cropped_im = image[y0:y1, x0:x1, :]
   387         1          1.0      1.0      0.0          cropped_im_size = cropped_im.shape[:2]
   388
   389         1          0.0      0.0      0.0          if not precomputed_embeddings:
   390                                                       self._predictor.set_image(cropped_im)
   391
   392                                                   # get the points for this crop
   393         1          8.0      8.0      0.0          points_scale = np.array(cropped_im_size)[None, ::-1]
   394         1        187.0    187.0      0.0          points_for_image = self.point_grids[crop_layer_idx] * points_scale
   395
   396                                                   # generate masks for this crop in batches
   397         1         27.0     27.0      0.0          data = amg_utils.MaskData()
   398         2          1.0      0.5      0.0          n_batches = len(points_for_image) // self._points_per_batch +\
   399         1          1.0      1.0      0.0              int(len(points_for_image) % self._points_per_batch != 0)
   400        18       3620.0    201.1      0.0          for (points,) in tqdm(
   401         1          3.0      3.0      0.0              amg_utils.batch_iterator(self._points_per_batch, points_for_image),
   402         1          0.0      0.0      0.0              disable=not verbose, total=n_batches,
   403         1          1.0      1.0      0.0              desc="Predict masks for point grid prompts",
   404                                                   ):
   405        16   44996183.0    3e+06     90.0              batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
   406        16    4998693.0 312418.3     10.0              data.cat(batch_data)
   407        16       8108.0    506.8      0.0              del batch_data
   408
   409         1          0.0      0.0      0.0          if not precomputed_embeddings:
   410                                                       self._predictor.reset_image()
   411
   412         1          0.0      0.0      0.0          return data

Total time: 40.8883 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/predictor.py
Function: predict_torch at line 168

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   168                                               @torch.no_grad()
   169                                               @profile
   170                                               def predict_torch(
   171                                                   self,
   172                                                   point_coords: Optional[torch.Tensor],
   173                                                   point_labels: Optional[torch.Tensor],
   174                                                   boxes: Optional[torch.Tensor] = None,
   175                                                   mask_input: Optional[torch.Tensor] = None,
   176                                                   multimask_output: bool = True,
   177                                                   return_logits: bool = False,
   178                                               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
   179                                                   """
   180                                                   Predict masks for the given input prompts, using the currently set image.
   181                                                   Input prompts are batched torch tensors and are expected to already be
   182                                                   transformed to the input frame using ResizeLongestSide.
   183
   184                                                   Arguments:
   185                                                     point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
   186                                                       model. Each point is in (X,Y) in pixels.
   187                                                     point_labels (torch.Tensor or None): A BxN array of labels for the
   188                                                       point prompts. 1 indicates a foreground point and 0 indicates a
   189                                                       background point.
   190                                                     boxes (np.ndarray or None): A Bx4 array given a box prompt to the
   191                                                       model, in XYXY format.
   192                                                     mask_input (np.ndarray): A low resolution mask input to the model, typically
   193                                                       coming from a previous prediction iteration. Has form Bx1xHxW, where
   194                                                       for SAM, H=W=256. Masks returned by a previous iteration of the
   195                                                       predict method do not need further transformation.
   196                                                     multimask_output (bool): If true, the model will return three masks.
   197                                                       For ambiguous input prompts (such as a single click), this will often
   198                                                       produce better masks than a single prediction. If only a single
   199                                                       mask is needed, the model's predicted quality score can be used
   200                                                       to select the best mask. For non-ambiguous prompts, such as multiple
   201                                                       input prompts, multimask_output=False can give better results.
   202                                                     return_logits (bool): If true, returns un-thresholded masks logits
   203                                                       instead of a binary mask.
   204
   205                                                   Returns:
   206                                                     (torch.Tensor): The output masks in BxCxHxW format, where C is the
   207                                                       number of masks, and (H, W) is the original image size.
   208                                                     (torch.Tensor): An array of shape BxC containing the model's
   209                                                       predictions for the quality of each mask.
   210                                                     (torch.Tensor): An array of shape BxCxHxW, where C is the number
   211                                                       of masks and H=W=256. These low res logits can be passed to
   212                                                       a subsequent iteration as mask input.
   213                                                   """
   214        16        369.0     23.1      0.0          if not self.is_image_set:
   215                                                       raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
   216
   217        16          3.0      0.2      0.0          if point_coords is not None:
   218        16          3.0      0.2      0.0              points = (point_coords, point_labels)
   219                                                   else:
   220                                                       points = None
   221
   222                                                   # Embed prompts
   223        32      18334.0    572.9      0.0          sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
   224        16          0.0      0.0      0.0              points=points,
   225        16          0.0      0.0      0.0              boxes=boxes,
   226        16          0.0      0.0      0.0              masks=mask_input,
   227                                                   )
   228
   229                                                   # Predict masks
   230        32   38480121.0    1e+06     94.1          low_res_masks, iou_predictions = self.model.mask_decoder(
   231        16          6.0      0.4      0.0              image_embeddings=self.features,
   232        16      41064.0   2566.5      0.1              image_pe=self.model.prompt_encoder.get_dense_pe(),
   233        16          0.0      0.0      0.0              sparse_prompt_embeddings=sparse_embeddings,
   234        16          2.0      0.1      0.0              dense_prompt_embeddings=dense_embeddings,
   235        16          3.0      0.2      0.0              multimask_output=multimask_output,
   236                                                   )
   237
   238                                                   # Upscale the masks to the original image resolution
   239        16    2348373.0 146773.3      5.7          masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
   240
   241        16          7.0      0.4      0.0          if not return_logits:
   242                                                       masks = masks > self.model.mask_threshold
   243
   244        16          8.0      0.5      0.0          return masks, iou_predictions, low_res_masks

MPS

(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | mps      | amg         |   346.799 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 123.62 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _to_mask_data at line 255

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   255                                               @profile
   256                                               def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
   257        16        169.0     10.6      0.0          orig_h, orig_w = original_size
   258
   259                                                   # serialize predictions and store in MaskData
   260        16      88827.0   5551.7      0.1          data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
   261        16         20.0      1.2      0.0          if points is not None:
   262        16      71025.0   4439.1      0.1              data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
   263
   264        16          6.0      0.4      0.0          del masks
   265
   266                                                   # calculate the stability scores
   267        32     858545.0  26829.5      0.7          data["stability_score"] = amg_utils.calculate_stability_score(
   268        16      33751.0   2109.4      0.0              data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
   269                                                   )
   270
   271                                                   # threshold masks and calculate boxes
   272        16      15019.0    938.7      0.0          data["masks"] = data["masks"] > self._predictor.model.mask_threshold
   273        16        269.0     16.8      0.0          data["masks"] = data["masks"].type(torch.bool)
   274        16    2740052.0 171253.2      2.2          data["boxes"] = batched_mask_to_box(data["masks"])
   275
   276                                                   # compress to RLE
   277        16       5700.0    356.2      0.0          data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
   278        16  119806776.0    7e+06     96.9          data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
   279        16        246.0     15.4      0.0          del data["masks"]
   280
   281        16          2.0      0.1      0.0          return data

Total time: 336.425 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 366

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   366                                               @profile
   367                                               def _process_batch(self, points, im_size, crop_box, original_size):
   368                                                   # run model on this batch
   369        16      25639.0   1602.4      0.0          transformed_points = self._predictor.transform.apply_coords(points, im_size)
   370        16      71020.0   4438.8      0.0          in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
   371        16      42514.0   2657.1      0.0          in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
   372        32  212613510.0    7e+06     63.2          masks, iou_preds, _ = self._predictor.predict_torch(
   373        16        970.0     60.6      0.0              in_points[:, None, :],
   374        16         51.0      3.2      0.0              in_labels[:, None],
   375        16          2.0      0.1      0.0              multimask_output=True,
   376        16          0.0      0.0      0.0              return_logits=True,
   377                                                   )
   378        16  123671154.0    8e+06     36.8          data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
   379        16         37.0      2.3      0.0          del masks
   380        16          2.0      0.1      0.0          return data

Total time: 343.198 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 382

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   382                                               @profile
   383                                               def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
   384                                                   # crop the image and calculate embeddings
   385         1          1.0      1.0      0.0          x0, y0, x1, y1 = crop_box
   386         1          2.0      2.0      0.0          cropped_im = image[y0:y1, x0:x1, :]
   387         1          1.0      1.0      0.0          cropped_im_size = cropped_im.shape[:2]
   388
   389         1          0.0      0.0      0.0          if not precomputed_embeddings:
   390                                                       self._predictor.set_image(cropped_im)
   391
   392                                                   # get the points for this crop
   393         1         10.0     10.0      0.0          points_scale = np.array(cropped_im_size)[None, ::-1]
   394         1        185.0    185.0      0.0          points_for_image = self.point_grids[crop_layer_idx] * points_scale
   395
   396                                                   # generate masks for this crop in batches
   397         1         38.0     38.0      0.0          data = amg_utils.MaskData()
   398         2          2.0      1.0      0.0          n_batches = len(points_for_image) // self._points_per_batch +\
   399         1          1.0      1.0      0.0              int(len(points_for_image) % self._points_per_batch != 0)
   400        18       9459.0    525.5      0.0          for (points,) in tqdm(
   401         1          3.0      3.0      0.0              amg_utils.batch_iterator(self._points_per_batch, points_for_image),
   402         1          0.0      0.0      0.0              disable=not verbose, total=n_batches,
   403         1          1.0      1.0      0.0              desc="Predict masks for point grid prompts",
   404                                                   ):
   405        16  336429279.0    2e+07     98.0              batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
   406        16    6530072.0 408129.5      1.9              data.cat(batch_data)
   407        16     228971.0  14310.7      0.1              del batch_data
   408
   409         1          0.0      0.0      0.0          if not precomputed_embeddings:
   410                                                       self._predictor.reset_image()
   411
   412         1          0.0      0.0      0.0          return data

Total time: 212.395 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/predictor.py
Function: predict_torch at line 168

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   168                                               @torch.no_grad()
   169                                               @profile
   170                                               def predict_torch(
   171                                                   self,
   172                                                   point_coords: Optional[torch.Tensor],
   173                                                   point_labels: Optional[torch.Tensor],
   174                                                   boxes: Optional[torch.Tensor] = None,
   175                                                   mask_input: Optional[torch.Tensor] = None,
   176                                                   multimask_output: bool = True,
   177                                                   return_logits: bool = False,
   178                                               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
   179                                                   """
   180                                                   Predict masks for the given input prompts, using the currently set image.
   181                                                   Input prompts are batched torch tensors and are expected to already be
   182                                                   transformed to the input frame using ResizeLongestSide.
   183
   184                                                   Arguments:
   185                                                     point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
   186                                                       model. Each point is in (X,Y) in pixels.
   187                                                     point_labels (torch.Tensor or None): A BxN array of labels for the
   188                                                       point prompts. 1 indicates a foreground point and 0 indicates a
   189                                                       background point.
   190                                                     boxes (np.ndarray or None): A Bx4 array given a box prompt to the
   191                                                       model, in XYXY format.
   192                                                     mask_input (np.ndarray): A low resolution mask input to the model, typically
   193                                                       coming from a previous prediction iteration. Has form Bx1xHxW, where
   194                                                       for SAM, H=W=256. Masks returned by a previous iteration of the
   195                                                       predict method do not need further transformation.
   196                                                     multimask_output (bool): If true, the model will return three masks.
   197                                                       For ambiguous input prompts (such as a single click), this will often
   198                                                       produce better masks than a single prediction. If only a single
   199                                                       mask is needed, the model's predicted quality score can be used
   200                                                       to select the best mask. For non-ambiguous prompts, such as multiple
   201                                                       input prompts, multimask_output=False can give better results.
   202                                                     return_logits (bool): If true, returns un-thresholded masks logits
   203                                                       instead of a binary mask.
   204
   205                                                   Returns:
   206                                                     (torch.Tensor): The output masks in BxCxHxW format, where C is the
   207                                                       number of masks, and (H, W) is the original image size.
   208                                                     (torch.Tensor): An array of shape BxC containing the model's
   209                                                       predictions for the quality of each mask.
   210                                                     (torch.Tensor): An array of shape BxCxHxW, where C is the number
   211                                                       of masks and H=W=256. These low res logits can be passed to
   212                                                       a subsequent iteration as mask input.
   213                                                   """
   214        16         12.0      0.8      0.0          if not self.is_image_set:
   215                                                       raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
   216
   217        16          7.0      0.4      0.0          if point_coords is not None:
   218        16          3.0      0.2      0.0              points = (point_coords, point_labels)
   219                                                   else:
   220                                                       points = None
   221
   222                                                   # Embed prompts
   223        32     471902.0  14746.9      0.2          sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
   224        16          3.0      0.2      0.0              points=points,
   225        16          2.0      0.1      0.0              boxes=boxes,
   226        16          1.0      0.1      0.0              masks=mask_input,
   227                                                   )
   228
   229                                                   # Predict masks
   230        32  142891626.0    4e+06     67.3          low_res_masks, iou_predictions = self.model.mask_decoder(
   231        16          1.0      0.1      0.0              image_embeddings=self.features,
   232        16     209664.0  13104.0      0.1              image_pe=self.model.prompt_encoder.get_dense_pe(),
   233        16          3.0      0.2      0.0              sparse_prompt_embeddings=sparse_embeddings,
   234        16          2.0      0.1      0.0              dense_prompt_embeddings=dense_embeddings,
   235        16          1.0      0.1      0.0              multimask_output=multimask_output,
   236                                                   )
   237
   238                                                   # Upscale the masks to the original image resolution
   239        16   68821903.0    4e+06     32.4          masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
   240
   241        16         92.0      5.8      0.0          if not return_logits:
   242                                                       masks = masks > self.model.mask_threshold
   243
   244        16         54.0      3.4      0.0          return masks, iou_predictions, low_res_masks

amg.initialize -> _process_crop -> _process_batch

CPU

(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d cpu -e -p
Running benchmarks for vit_h
with device: cpu
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | cpu      | amg         |   45.2618 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 38.4909 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 365

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   365                                               @profile
   366                                               def _process_batch(self, points, im_size, crop_box, original_size):
   367                                                   # run model on this batch
   368        16        911.0     56.9      0.0          transformed_points = self._predictor.transform.apply_coords(points, im_size)
   369        16        875.0     54.7      0.0          in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
   370        16        393.0     24.6      0.0          in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
   371        32   34451581.0    1e+06     89.5          masks, iou_preds, _ = self._predictor.predict_torch(
   372        16        297.0     18.6      0.0              in_points[:, None, :],
   373        16         31.0      1.9      0.0              in_labels[:, None],
   374        16          2.0      0.1      0.0              multimask_output=True,
   375        16          0.0      0.0      0.0              return_logits=True,
   376                                                   )
   377        16    4001238.0 250077.4     10.4          data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
   378        16      35558.0   2222.4      0.1          del masks
   379        16          4.0      0.2      0.0          return data

Total time: 43.4143 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 381

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   381                                               @profile
   382                                               def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
   383                                                   # crop the image and calculate embeddings
   384         1          0.0      0.0      0.0          x0, y0, x1, y1 = crop_box
   385         1          1.0      1.0      0.0          cropped_im = image[y0:y1, x0:x1, :]
   386         1          0.0      0.0      0.0          cropped_im_size = cropped_im.shape[:2]
   387
   388         1          0.0      0.0      0.0          if not precomputed_embeddings:
   389                                                       self._predictor.set_image(cropped_im)
   390
   391                                                   # get the points for this crop
   392         1          7.0      7.0      0.0          points_scale = np.array(cropped_im_size)[None, ::-1]
   393         1         16.0     16.0      0.0          points_for_image = self.point_grids[crop_layer_idx] * points_scale
   394
   395                                                   # generate masks for this crop in batches
   396         1         37.0     37.0      0.0          data = amg_utils.MaskData()
   397         2          1.0      0.5      0.0          n_batches = len(points_for_image) // self._points_per_batch +\
   398         1          0.0      0.0      0.0              int(len(points_for_image) % self._points_per_batch != 0)
   399        18       1977.0    109.8      0.0          for (points,) in tqdm(
   400         1          4.0      4.0      0.0              amg_utils.batch_iterator(self._points_per_batch, points_for_image),
   401         1          0.0      0.0      0.0              disable=not verbose, total=n_batches,
   402         1          1.0      1.0      0.0              desc="Predict masks for point grid prompts",
   403                                                   ):
   404        16   38501125.0    2e+06     88.7              batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
   405        16    4903074.0 306442.1     11.3              data.cat(batch_data)
   406        16       8065.0    504.1      0.0              del batch_data
   407
   408         1          0.0      0.0      0.0          if not precomputed_embeddings:
   409                                                       self._predictor.reset_image()
   410
   411         1          0.0      0.0      0.0          return data

MPS

(test-micro-sam-mps) genevieb@192-168-1-102 development % kernprof -lv benchmark.py -m vit_h -d mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | mps      | amg         |   249.157 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 240.035 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_batch at line 365

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   365                                               @profile
   366                                               def _process_batch(self, points, im_size, crop_box, original_size):
   367                                                   # run model on this batch
   368        16      20809.0   1300.6      0.0          transformed_points = self._predictor.transform.apply_coords(points, im_size)
   369        16      62775.0   3923.4      0.0          in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
   370        16      26964.0   1685.2      0.0          in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
   371        32  126742884.0    4e+06     52.8          masks, iou_preds, _ = self._predictor.predict_torch(
   372        16       1792.0    112.0      0.0              in_points[:, None, :],
   373        16         47.0      2.9      0.0              in_labels[:, None],
   374        16          4.0      0.2      0.0              multimask_output=True,
   375        16          1.0      0.1      0.0              return_logits=True,
   376                                                   )
   377        16  113179665.0    7e+06     47.2          data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
   378        16         42.0      2.6      0.0          del masks
   379        16          2.0      0.1      0.0          return data

Total time: 245.755 s
File: /Users/genevieb/Documents/GitHub/micro-sam/micro_sam/instance_segmentation.py
Function: _process_crop at line 381

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   381                                               @profile
   382                                               def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
   383                                                   # crop the image and calculate embeddings
   384         1          1.0      1.0      0.0          x0, y0, x1, y1 = crop_box
   385         1          1.0      1.0      0.0          cropped_im = image[y0:y1, x0:x1, :]
   386         1          0.0      0.0      0.0          cropped_im_size = cropped_im.shape[:2]
   387
   388         1          1.0      1.0      0.0          if not precomputed_embeddings:
   389                                                       self._predictor.set_image(cropped_im)
   390
   391                                                   # get the points for this crop
   392         1          7.0      7.0      0.0          points_scale = np.array(cropped_im_size)[None, ::-1]
   393         1        176.0    176.0      0.0          points_for_image = self.point_grids[crop_layer_idx] * points_scale
   394
   395                                                   # generate masks for this crop in batches
   396         1        710.0    710.0      0.0          data = amg_utils.MaskData()
   397         2          0.0      0.0      0.0          n_batches = len(points_for_image) // self._points_per_batch +\
   398         1          0.0      0.0      0.0              int(len(points_for_image) % self._points_per_batch != 0)
   399        18      21123.0   1173.5      0.0          for (points,) in tqdm(
   400         1        134.0    134.0      0.0              amg_utils.batch_iterator(self._points_per_batch, points_for_image),
   401         1          0.0      0.0      0.0              disable=not verbose, total=n_batches,
   402         1          1.0      1.0      0.0              desc="Predict masks for point grid prompts",
   403                                                   ):
   404        16  240049242.0    2e+07     97.7              batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
   405        16    5672560.0 354535.0      2.3              data.cat(batch_data)
   406        16      11417.0    713.6      0.0              del batch_data
   407
   408         1          0.0      0.0      0.0          if not precomputed_embeddings:
   409                                                       self._predictor.reset_image()
   410
   411         1          0.0      0.0      0.0          return data
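
A quick back-of-the-envelope comparison from the two _process_batch profiles above (16 point-grid batches in each run):

```python
# Average time per batch, computed from the _process_batch totals shown above.
cpu_total_s = 38.4909   # Total time, cpu run
mps_total_s = 240.035   # Total time, mps run
n_batches = 16

print(f"cpu: {cpu_total_s / n_batches:.2f} s per batch")           # ~2.41 s
print(f"mps: {mps_total_s / n_batches:.2f} s per batch")           # ~15.00 s
print(f"mps is ~{mps_total_s / cpu_total_s:.1f}x slower overall")  # ~6.2x
```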

@GenevieveBuckley commented:

Benchmarks from Constantin's development/benchmark.py file:

(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % python development/benchmark.py -m vit_h -d cpu
Running benchmarks for vit_h
with device: cpu
Running benchmark_embeddings ...
Running benchmark_prompts ...
Running benchmark_amg ...
Traceback (most recent call last):
  File "/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/pandas/compat/_optional.py", line 142, in import_optional_dependency
    module = importlib.import_module(name)
  File "/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1004, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'tabulate'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/genevieb/Documents/GitHub/micro-sam/development/benchmark.py", line 176, in <module>
    main()
  File "/Users/genevieb/Documents/GitHub/micro-sam/development/benchmark.py", line 172, in main
    print(benchmark_results.to_markdown(index=False))
  File "/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/pandas/core/frame.py", line 2756, in to_markdown
    tabulate = import_optional_dependency("tabulate")
  File "/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/pandas/compat/_optional.py", line 145, in import_optional_dependency
    raise ImportError(msg)
ImportError: Missing optional dependency 'tabulate'.  Use pip or conda to install tabulate.
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % pip install tabulate
Collecting tabulate
  Using cached tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % python development/benchmark.py -m vit_h -d cpu
Running benchmarks for vit_h
with device: cpu
Running benchmark_embeddings ...
Running benchmark_prompts ...
Running benchmark_amg ...
| model   | device   | benchmark             |    runtime |
|:--------|:---------|:----------------------|-----------:|
| vit_h   | cpu      | embeddings            | 16.9334    |
| vit_h   | cpu      | prompt-p1n0           |  0.040143  |
| vit_h   | cpu      | prompt-p2n4           |  0.0383821 |
| vit_h   | cpu      | prompt-box            |  0.041321  |
| vit_h   | cpu      | prompt-box-and-points |  0.0410571 |
| vit_h   | cpu      | amg                   | 52.4115    |
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % python development/benchmark.py -m vit_h -d mps
Running benchmarks for vit_h
with device: mps
Running benchmark_embeddings ...
Running benchmark_prompts ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
Running benchmark_amg ...
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model   | device   | benchmark             |     runtime |
|:--------|:---------|:----------------------|------------:|
| vit_h   | mps      | embeddings            |  11.4617    |
| vit_h   | mps      | prompt-p1n0           |   0.0342882 |
| vit_h   | mps      | prompt-p2n4           |   0.0313947 |
| vit_h   | mps      | prompt-box            |   0.0284584 |
| vit_h   | mps      | prompt-box-and-points |   0.032028  |
| vit_h   | mps      | amg                   | 220.851     |
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % python development/benchmark.py -m vit_l -d cpu
Running benchmarks for vit_l
with device: cpu
Download https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth to
Download successful and checksums agree.
Running benchmark_embeddings ...
Running benchmark_prompts ...
Running benchmark_amg ...
| model   | device   | benchmark             |    runtime |
|:--------|:---------|:----------------------|-----------:|
| vit_l   | cpu      | embeddings            | 10.1647    |
| vit_l   | cpu      | prompt-p1n0           |  0.0386679 |
| vit_l   | cpu      | prompt-p2n4           |  0.039386  |
| vit_l   | cpu      | prompt-box            |  0.037823  |
| vit_l   | cpu      | prompt-box-and-points |  0.0404921 |
| vit_l   | cpu      | amg                   | 41.5105    |
(test-micro-sam-mps) genevieb@192-168-1-102 micro-sam % python development/benchmark.py -m vit_l -d mps
Running benchmarks for vit_l
with device: mps
Running benchmark_embeddings ...
Running benchmark_prompts ...
/Users/genevieb/mambaforge/envs/test-micro-sam-mps/lib/python3.10/site-packages/segment_anything/modeling/mask_decoder.py:126: UserWarning: MPS: no support for int64 repeats mask, casting it to int32 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343686130/work/aten/src/ATen/native/mps/operations/Repeat.mm:236.)
  src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
Running benchmark_amg ...
[W MPSFallback.mm:11] Warning: The operator 'torchvision::nms' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (function operator())
| model   | device   | benchmark             |     runtime |
|:--------|:---------|:----------------------|------------:|
| vit_l   | mps      | embeddings            |   4.63046   |
| vit_l   | mps      | prompt-p1n0           |   0.0345128 |
| vit_l   | mps      | prompt-p2n4           |   0.0316033 |
| vit_l   | mps      | prompt-box            |   0.0278041 |
| vit_l   | mps      | prompt-box-and-points |   0.0320668 |
| vit_l   | mps      | amg                   | 145.669     |


GenevieveBuckley commented Sep 8, 2023

More line profiling.

(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % kernprof -lv benchmark.py --model_type vit_h --device cpu -e -p
Running benchmarks for vit_h
with device: cpu
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | cpu      | amg         |    50.254 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 2.84928 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps-nightly/lib/python3.10/site-packages/segment_anything/utils/amg.py
Function: mask_to_rle_pytorch at line 107

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   107                                           @profile
   108                                           def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
   109                                               """
   110                                               Encodes masks to an uncompressed RLE, in the format expected by
   111                                               pycoco tools.
   112                                               """
   113                                               # Put in fortran order and flatten h,w
   114        16         54.0      3.4      0.0      b, h, w = tensor.shape
   115        16     164070.0  10254.4      5.8      tensor = tensor.permute(0, 2, 1).flatten(1)
   116
   117                                               # Compute change indices
   118        16     103694.0   6480.9      3.6      diff = tensor[:, 1:] ^ tensor[:, :-1]
   119        16     596115.0  37257.2     20.9      change_indices = diff.nonzero()
   120
   121                                               # Encode run length
   122        16         12.0      0.8      0.0      out = []
   123      3088        648.0      0.2      0.0      for i in range(b):
   124      3072    1778556.0    579.0     62.4          cur_idxs = change_indices[change_indices[:, 0] == i, 1]
   125      6144      35509.0      5.8      1.2          cur_idxs = torch.cat(
   126      3072        309.0      0.1      0.0              [
   127      3072      22064.0      7.2      0.8                  torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
   128      3072      15048.0      4.9      0.5                  cur_idxs + 1,
   129      3072       7463.0      2.4      0.3                  torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
   130                                                       ]
   131                                                   )
   132      3072      23499.0      7.6      0.8          btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
   133      3072      27327.0      8.9      1.0          counts = [] if tensor[i, 0] == 0 else [0]
   134      3072      73294.0     23.9      2.6          counts.extend(btw_idxs.detach().cpu().tolist())
   135      3072       1613.0      0.5      0.1          out.append({"size": [h, w], "counts": counts})
   136        16          9.0      0.6      0.0      return out

(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % kernprof -lv benchmark.py --model_type vit_h --device mps -e -p
Running benchmarks for vit_h
with device: mps
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | mps      | amg         |    474.32 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 166.182 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps-nightly/lib/python3.10/site-packages/segment_anything/utils/amg.py
Function: mask_to_rle_pytorch at line 107

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   107                                           @profile
   108                                           def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
   109                                               """
   110                                               Encodes masks to an uncompressed RLE, in the format expected by
   111                                               pycoco tools.
   112                                               """
   113                                               # Put in fortran order and flatten h,w
   114        16        110.0      6.9      0.0      b, h, w = tensor.shape
   115        16      15219.0    951.2      0.0      tensor = tensor.permute(0, 2, 1).flatten(1)
   116
   117                                               # Compute change indices
   118        16      38611.0   2413.2      0.0      diff = tensor[:, 1:] ^ tensor[:, :-1]
   119        16   25247303.0    2e+06     15.2      change_indices = diff.nonzero()
   120
   121                                               # Encode run length
   122        16        196.0     12.2      0.0      out = []
   123      3088       3986.0      1.3      0.0      for i in range(b):
   124      3072   62310013.0  20283.2     37.5          cur_idxs = change_indices[change_indices[:, 0] == i, 1]
   125      6144   24125294.0   3926.6     14.5          cur_idxs = torch.cat(
   126      3072        887.0      0.3      0.0              [
   127      3072    3651022.0   1188.5      2.2                  torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
   128      3072   24708532.0   8043.1     14.9                  cur_idxs + 1,
   129      3072    1452161.0    472.7      0.9                  torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
   130                                                       ]
   131                                                   )
   132      3072   21384123.0   6961.0     12.9          btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
   133      3072    2082247.0    677.8      1.3          counts = [] if tensor[i, 0] == 0 else [0]
   134      3072    1148525.0    373.9      0.7          counts.extend(btw_idxs.detach().cpu().tolist())
   135      3072      13608.0      4.4      0.0          out.append({"size": [h, w], "counts": counts})
   136        16          8.0      0.5      0.0      return out

(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % kernprof -lv benchmark.py --model_type vit_h --device cpu -e -p
Running benchmarks for vit_h
with device: cpu
Example data directory is: /Users/genevieb/Documents/GitHub/micro-sam/examples/data
Downloading data from 'https://owncloud.gwdg.de/index.php/s/2sr1DHQ34tV7WEb/download' to file '/Users/genevieb/Documents/GitHub/micro-sam/examples/data/hela-2d-image.png'.
100%|███████████████████████████████████████| 212k/212k [00:00<00:00, 1.02GB/s]
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | cpu      | amg         |   40.2803 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 2.06135 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps-nightly/lib/python3.10/site-packages/segment_anything/utils/amg.py
Function: mask_to_rle_pytorch at line 107

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   107                                           @profile
   108                                           def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
   109                                               """
   110                                               Encodes masks to an uncompressed RLE, in the format expected by
   111                                               pycoco tools.
   112                                               """
   113                                               # Put in fortran order and flatten h,w
   114        16         23.0      1.4      0.0      b, h, w = tensor.shape
   115        16     161868.0  10116.8      7.9      tensor = tensor.permute(0, 2, 1).flatten(1)
   116
   117                                               # Compute change indices
   118        16      54507.0   3406.7      2.6      diff = tensor[:, 1:] ^ tensor[:, :-1]
   119        16     413235.0  25827.2     20.0      change_indices = diff.nonzero()
   120
   121                                               # Encode run length
   122        16         14.0      0.9      0.0      out = []
   123      3088        496.0      0.2      0.0      for i in range(b):
   124      3072    1274800.0    415.0     61.8          cur_idxs = change_indices[change_indices[:, 0] == i, 1]
   125      6144      20794.0      3.4      1.0          cur_idxs = torch.cat(
   126      3072        302.0      0.1      0.0              [
   127      3072      15663.0      5.1      0.8                  torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
   128      3072      11924.0      3.9      0.6                  cur_idxs + 1,
   129      3072       8106.0      2.6      0.4                  torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
   130                                                       ]
   131                                                   )
   132      3072      17430.0      5.7      0.8          btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
   133      3072      24777.0      8.1      1.2          counts = [] if tensor[i, 0] == 0 else [0]
   134      3072      56226.0     18.3      2.7          counts.extend(btw_idxs.detach().cpu().tolist())
   135      3072       1176.0      0.4      0.1          out.append({"size": [h, w], "counts": counts})
   136        16          7.0      0.4      0.0      return out

(test-micro-sam-mps-nightly) genevieb@dyn-130-194-109-212 development % kernprof -lv benchmark.py --model_type vit_h --device mps -e -p
Running benchmarks for vit_h
with device: mps
Example data directory is: /Users/genevieb/Documents/GitHub/micro-sam/examples/data
Running benchmark_amg ...
| model   | device   | benchmark   |   runtime |
|:--------|:---------|:------------|----------:|
| vit_h   | mps      | amg         |   255.855 |
Wrote profile results to benchmark.py.lprof
Timer unit: 1e-06 s

Total time: 144.144 s
File: /Users/genevieb/mambaforge/envs/test-micro-sam-mps-nightly/lib/python3.10/site-packages/segment_anything/utils/amg.py
Function: mask_to_rle_pytorch at line 107

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   107                                           @profile
   108                                           def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
   109                                               """
   110                                               Encodes masks to an uncompressed RLE, in the format expected by
   111                                               pycoco tools.
   112                                               """
   113                                               # Put in fortran order and flatten h,w
   114        16        128.0      8.0      0.0      b, h, w = tensor.shape
   115        16       9596.0    599.8      0.0      tensor = tensor.permute(0, 2, 1).flatten(1)
   116
   117                                               # Compute change indices
   118        16      11275.0    704.7      0.0      diff = tensor[:, 1:] ^ tensor[:, :-1]
   119        16    2875329.0 179708.1      2.0      change_indices = diff.nonzero()
   120
   121                                               # Encode run length
   122        16         28.0      1.8      0.0      out = []
   123      3088       2200.0      0.7      0.0      for i in range(b):
   124      3072   57902234.0  18848.4     40.2          cur_idxs = change_indices[change_indices[:, 0] == i, 1]
   125      6144   25840480.0   4205.8     17.9          cur_idxs = torch.cat(
   126      3072       1227.0      0.4      0.0              [
   127      3072    3523664.0   1147.0      2.4                  torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
   128      3072   26443298.0   8607.8     18.3                  cur_idxs + 1,
   129      3072    1483706.0    483.0      1.0                  torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
   130                                                       ]
   131                                                   )
   132      3072   23292589.0   7582.2     16.2          btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
   133      3072    1905329.0    620.2      1.3          counts = [] if tensor[i, 0] == 0 else [0]
   134      3072     841710.0    274.0      0.6          counts.extend(btw_idxs.detach().cpu().tolist())
   135      3072      11502.0      3.7      0.0          out.append({"size": [h, w], "counts": counts})
   136        16          5.0      0.3      0.0      return out
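Both mps profiles tell the same story: the amg runtime is dominated by `mask_to_rle_pytorch`, and within it by the per-mask indexing, concatenation and subtraction inside the loop (lines 124-132), which are far slower on mps than on cpu. A possible workaround (just a sketch, untested here; the helper name `mask_to_rle_via_cpu` is only for illustration) would be to move the stacked masks to the cpu before RLE encoding, so the loop never touches the mps device:

    # Sketch of a possible workaround (untested): do the RLE encoding on cpu.
    # The profiles above show the per-mask ops in mask_to_rle_pytorch are far
    # slower on mps, so paying for one device-to-host transfer up front may be
    # cheaper than running the whole loop against mps tensors.
    import torch
    from segment_anything.utils.amg import mask_to_rle_pytorch


    def mask_to_rle_via_cpu(masks: torch.Tensor):
        """Encode boolean masks to uncompressed RLE after moving them to cpu."""
        return mask_to_rle_pytorch(masks.detach().cpu())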

GenevieveBuckley commented Sep 10, 2023

Run benchmarks

  1. Install pandas' `tabulate` dependency: `python -m pip install tabulate`
  2. Run the benchmark script, e.g.: `python benchmark.py --model_type vit_h --device cpu`

Line profiling

  1. Install line_profiler: `python -m pip install line_profiler`
  2. Add the `@profile` decorator to any function in the call stack you want timed (see the sketch below)
  3. Run `kernprof -lv benchmark.py --model_type vit_h --device cpu`
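
A minimal sketch of step 2, assuming we want line timings for `mask_to_rle_pytorch`: `kernprof -l` injects `profile` into builtins at runtime, so the decorator needs no import while profiling, and a small fallback keeps the module importable when run without kernprof.

    # Sketch of step 2: decorate the function to be timed (here, inside
    # segment_anything/utils/amg.py). kernprof -l injects `profile` into
    # builtins; this fallback makes the decorator a no-op otherwise.
    try:
        profile
    except NameError:
        def profile(func):
            return func


    @profile
    def mask_to_rle_pytorch(tensor):
        ...  # existing function body unchanged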

Snakeviz visualization

https://jiffyclub.github.io/snakeviz/

  1. Install snakeviz: `python -m pip install snakeviz`
  2. Generate a profile file: `python -m cProfile -o program.prof benchmark.py --model_type vit_h --device cpu`
  3. Visualize the profile file: `snakeviz program.prof`
